pytorch小记(七):pytorch中的保存/加载模型操作
- 1. 加载模型参数 (`state_dict`)
- 1.1 保存模型参数
- 1.2 加载模型参数
- 1.3 常见变种
- 1.3.1 指定加载设备
- 1.3.2 非严格加载(跳过部分层)
- 1.3.3 打印加载的参数
- 2. 加载整个模型
- 2.1 保存整个模型
- 2.2 加载整个模型
- 2.3 注意事项
- 3. 总结
- 4. 加载模型的完整代码示例
- 4.1 保存和加载参数
- 4.2 保存和加载整个模型
- 4.3 加载到不同设备
- 4.4 忽略部分参数(非严格加载)
- 5. 检查模型是否加载成功
在 PyTorch 中,加载模型通常分为两种情况:加载模型参数(state_dict) 和 加载整个模型。以下是加载模型的所有相关操作及其详细步骤:
1. 加载模型参数 (state_dict
)
当仅保存了模型的参数时(使用 model.state_dict()
保存),加载模型的步骤如下:
1.1 保存模型参数
torch.save(model.state_dict(), 'model.pth')
- 文件内容:只保存模型的参数(权重和偏置)。
- 优点:
- 节省存储空间。
- 灵活性更高,可以与不同的模型架构配合使用。
- 缺点:
- 需要手动重新定义模型结构。
1.2 加载模型参数
-
重新定义模型架构:
model = MyModel() # 替换为你的模型类
-
加载参数:
state_dict = torch.load('model.pth') # 加载参数字典 model.load_state_dict(state_dict) # 加载参数到模型
-
选择运行设备:
model.to('cuda') # 如果需要运行在 GPU 上
1.3 常见变种
1.3.1 指定加载设备
- 如果保存时模型在 GPU 上,而加载时在 CPU 环境中,可以使用
map_location
:state_dict = torch.load('model.pth', map_location='cpu')
1.3.2 非严格加载(跳过部分层)
- 如果保存的参数与模型结构不完全匹配(例如额外的层或不同的顺序),可以使用
strict=False
:model.load_state_dict(state_dict, strict=False)
1.3.3 打印加载的参数
- 可以检查参数字典的内容:
print(state_dict.keys())
2. 加载整个模型
当模型是通过 torch.save(model)
保存时,文件包含了模型的结构和参数,加载更为简单。
2.1 保存整个模型
torch.save(model, 'model_full.pth')
- 文件内容:包含模型的架构和参数。
- 优点:
- 无需重新定义模型结构。
- 直接加载并使用。
- 缺点:
- 文件依赖于保存时的代码版本(如模型定义)。
- 文件体积较大。
2.2 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda') # 如果需要在 GPU 上运行
2.3 注意事项
- 动态定义的模型:
- 如果模型结构是动态定义的(如包含条件逻辑),保存和加载整个模型可能会依赖于代码的一致性。
- 确保在加载时导入了与保存时相同的模型类。
3. 总结
操作 | 使用场景 | 优点 | 缺点 |
---|---|---|---|
保存参数 (state_dict ) | 推荐大多数情况 | 文件小、灵活性高 | 需要手动定义模型架构 |
保存整个模型 | 模型复杂且固定时 | 不需要重新定义模型,直接加载 | 文件大、依赖保存时的代码版本 |
4. 加载模型的完整代码示例
4.1 保存和加载参数
import torch
import torch.nn as nn# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 保存参数
model = MyModel()
torch.save(model.state_dict(), 'model.pth')# 加载参数
model = MyModel() # 重新定义模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda') # 运行在 GPU
4.2 保存和加载整个模型
# 保存整个模型
torch.save(model, 'model_full.pth')# 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda') # 运行在 GPU
4.3 加载到不同设备
# 保存参数
torch.save(model.state_dict(), 'model.pth')# 加载到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)# 加载到 GPU
model.to('cuda')
4.4 忽略部分参数(非严格加载)
# 保存参数
torch.save(model.state_dict(), 'model.pth')# 加载参数(非严格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)
5. 检查模型是否加载成功
-
验证权重是否加载
for name, param in model.named_parameters():print(f"{name}: {param.data}")
-
进行推理验证
x = torch.randn(1, 10).to('cuda') # 假设输入维度为 10 output = model(x) print(output)
通过以上操作,你可以灵活加载 PyTorch 模型,无论是仅加载参数还是加载整个模型结构和权重。