在 PyTorch 中,nn.Sequential
和 nn.ModuleList
都是用来容纳多个子模块的容器,但它们的用途和行为有所不同。以下是它们的用法和区别的详细解释,以及样例代码。
nn.Sequential
nn.Sequential
是一个顺序容器,模块将按它们在传递给 Sequential
的顺序依次执行。它适用于那些输入经过多个模块顺序处理的简单模型。
用法:
import torch.nn as nnmodel = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()
)
nn.ModuleList
nn.ModuleList
是一个持有 nn.Module
的列表。与 nn.Sequential
不同的是,nn.ModuleList
并不会定义模块的前向传播顺序,您需要在 forward
方法中手动定义它们的顺序。这给了用户更大的灵活性,可以在前向传播时对模块进行不同的操作。
用法:
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])def forward(self, x):for layer in self.layers:x = layer(x)return xmodel = MyModel()
区别总结
nn.Sequential
:模块按顺序执行,适用于顺序的简单模型。
nn.ModuleList
:只是一个持有模块的列表,您需要手动定义前向传播的顺序,适用于需要更复杂控制的模型。
样例代码
以下是一个简单的例子,展示了如何使用这两个容器来构建模型:
import torch
import torch.nn as nn# 定义一个顺序执行的模型
model_seq = nn.Sequential(nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()
)# 创建一个示例输入
input_seq = torch.randn(1, 1, 28, 28)
# 前向传播
output_seq = model_seq(input_seq)
print(output_seq.shape)import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layers = nn.ModuleList([nn.Conv2d(1, 20, 5),nn.ReLU(),nn.Conv2d(20, 64, 5),nn.ReLU()])def forward(self, x):for layer in self.layers:x = layer(x)return x# 定义一个模型
model_list = MyModel()# 创建一个示例输入
input_list = torch.randn(1, 1, 28, 28)
# 前向传播
output_list = model_list(input_list)
print(output_list.shape)