pytorch模型定义的几种方式
主要包括 标准继承模式 和 使用常用容器 两种方式
1. 继承nn.Module实现模型
通过继承nn模块的Module基类,并实现初始化init()以及forward()方法,实现模型定义
1 | import torch.nn as nn |
模型的使用
1 | test_input = torch.rand((256,128)) |
2. 快速定义模型
pytorch提供了一系列继承自nn.Module的实现类,类如Sequential等用来快速定义模型
1. nn.Sequential
A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.
向Sequential中传入一系列的层/其他模型(Module),按照传入的顺序或者OrderedDict中的顺序进行前向传播计算,实现模型定义
1 | # Example of using Sequential |
2.nn.ModuleList
ModuleList
can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by allModule
methods.
存储module的列表,支持列表的append,insert操作,并可通过坐标形式访问
- 与sequence不同点,在于ModuList就是个List,不支持forward前向传播运算
- 与List不同点,在于ModuList中所有模型的参数均加入到了反向传播梯度计算监控中
1 | class MyModule(nn.Module): |
与列表存储模型不同点比较
1 | class MyModule1(nn.Module): |