PyTorch 模型保存与加载的三种常用方式

在深度学习的训练过程中,我们不可避免地要保存模型,这是一个非常好的习惯。接下来,文章将通过一个简单的神经网络模型,带你了解 PyTorch 中主要的模型保存与加载方式。

文章目录

  • 为什么保存和加载模型很重要?
  • 代码示例
    • 模型准备
    • 方法一:保存和加载整个模型
    • 方法二:只保存模型的状态字典(state_dict)
      • 使用 `strict=False` 加载模型
    • 方法三:保存完整的训练状态(checkpoint)
    • 定义 checkpont 保存和加载的函数

为什么保存和加载模型很重要?

训练一个神经网络可能需要数小时甚至数天的时间,你需要认知到一点:时间是非常宝贵的,目前3090云服务器租赁一天的价格为 37.92 元。如果你的代码没有保存模型的模块,那就先不要开始,因为不保存基本等于没跑,你的效果再好也没有办法直接呈现给别人。如果你保存了模型,你就可以做到以下的事情:

  • 继续训练:通过保存检查点(checkpoint),你可以在意外中断后继续训练你的模型,这一点可能会节省你大量的时间。
  • 模型部署:训练好的模型可以被部署到生产环境中进行推理,比如 LLM,LoRA 等。
  • 分享模型:将训练好的模型分享给实验室其他成员或开源社区,以便进一步研究或复现结果。

代码示例

模型准备

为了演示,我们先定义一个简单的神经网络模型:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(784, 128)  # 输入层到隐藏层self.fc2 = nn.Linear(128, 64)   # 隐藏层到隐藏层self.fc3 = nn.Linear(64, 10)    # 隐藏层到输出层def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化模型和优化器
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

方法一:保存和加载整个模型

保存模型

torch.save(model, 'model.pth')

加载模型

model = torch.load('model.pth')
print(model)

输出

Net((fc1): Linear(in_features=784, out_features=128, bias=True)(fc2): Linear(in_features=128, out_features=64, bias=True)(fc3): Linear(in_features=64, out_features=10, bias=True)
)

这种方法非常简单直观,因为它保存了模型的整个结构和参数。

方法二:只保存模型的状态字典(state_dict)

保存模型状态字典

torch.save(model.state_dict(), 'model_state_dict.pth')

加载模型状态字典
需要注意的是,加载state_dict时你需要手动重新实例化模型。

model = Net()  # 你需要先定义好模型架构
model.load_state_dict(torch.load('model_state_dict.pth'))
print(model)

输出

Net((fc1): Linear(in_features=784, out_features=128, bias=True)(fc2): Linear(in_features=128, out_features=64, bias=True)(fc3): Linear(in_features=64, out_features=10, bias=True)
)

与保存整个模型相比,保存 state_dict 更加灵活,它只包含模型的参数,而不依赖于完整的模型定义,这意味着你可以在不同的项目中加载模型参数,甚至只加载部分模型的权重。举个例子,对于分类模型,即便你保存的是完整的网络参数,也可以仅导入特征提取层部分,当然,直接导入完整模型再拆分实际上是一样的。对于不完全匹配的模型,加载时可以通过设置 strict=False 来忽略某些不匹配的键:

model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)

这样,你可以灵活地只加载模型的某些部分。

使用 strict=False 加载模型

假设我们在原来的 Net 模型中新增了一个全连接层(fc4),此时如果我们直接加载之前保存的 state_dict,会因为 state_dict 中没有 fc4 的权重信息而导致报错。

import torch
import torch.nn as nn
import torch.nn.functional as F# 修改后的模型,新增了一层 fc4
class ModifiedNet(nn.Module):def __init__(self):super(ModifiedNet, self).__init__()self.fc1 = nn.Linear(784, 128)self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 10)self.fc4 = nn.Linear(10, 5)  # 新增的全连接层def forward(self, x):x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = F.relu(self.fc3(x))x = self.fc4(x)return x# 实例化模型
modified_model = ModifiedNet()# 尝试加载之前保存的 state_dict,但忽略不匹配的层
modified_model.load_state_dict(torch.load('model_state_dict.pth'), strict=False)# 输出模型结构
print(modified_model)

输出

ModifiedNet((fc1): Linear(in_features=784, out_features=128, bias=True)(fc2): Linear(in_features=128, out_features=64, bias=True)(fc3): Linear(in_features=64, out_features=10, bias=True)(fc4): Linear(in_features=10, out_features=5, bias=True)
)

如果不设置 strict=False,将会报错,提示缺少 fc4 的权重:

RuntimeError: Error(s) in loading state_dict for ModifiedNet: Missing key(s) in state_dict: "fc4.weight", "fc4.bias". 

注意,减少层也可以使用 strict=False。例如,如果修改后的网络只保留前两层,仍然可以成功加载原始的 state_dict,并跳过缺失的部分。

方法三:保存完整的训练状态(checkpoint)

有时候,你可能不仅仅需要保存模型参数,还需要保存训练进度,比如当前的轮数、优化器状态等。此时可以使用检查点保存更多信息。

保存检查点

torch.save({'epoch': 100,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': 0.01,
}, 'checkpoint.pth')

加载检查点

checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
print(f"Epoch: {epoch}, Loss: {loss}")

输出:

Epoch: 100, Loss: 0.01

这种方式适合长时间训练时,可以从中断的地方继续训练。但文件体积相比前面会更大,具体原因见《7. 探究模型参数与显存的关系以及不同精度造成的影响》,加载过程也稍微复杂一些,我们可以写一个函数来打包这个过程。

定义 checkpont 保存和加载的函数

def save_checkpoint(model, optimizer, epoch, loss, filepath='checkpoint.pth'):torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, filepath)def load_checkpoint(filepath, model, optimizer):checkpoint = torch.load(filepath)model.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])return checkpoint['epoch'], checkpoint['loss']# 保存
save_checkpoint(model, optimizer, 100, 0.01)# 加载
epoch, loss = load_checkpoint('checkpoint.pth', model, optimizer)
print(f"Loaded checkpoint at epoch {epoch} with loss {loss}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/55220.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新品:新一代全双工音频对讲模块SA618F22-C1

SA618F22-C1是我司一款升级版的无线数字和音频二合一全双工传输模块,支持8路并发高音质通话。用户不仅可以通过串口实现数据的无线传输,还可以通过I2S数字音频或模拟音频接口来传输语音信号。该模块内置高速微控制器、回声消除电路、ESD静电防护、高性能…

计算机网络各层有哪些协议?计算机网络协议解析:从拟定到实现,全面了解各层协议的作用与区别

在数字化时代,计算机网络无处不在,已经成为不可或缺的一部分。为了让不同设备能够有效地进行通信,网络协议作为一种约定和规则,确保了数据在网络中的可靠传输。今天,我们将深入探讨计算机网络的各层协议,详…

c#代码介绍23种设计模式_10组合模式

目录 1. 组合模式的定义 2. 组合模式的实现 3. 组合模式中涉及到三个角色 4. 组合模式的优缺点 5、实现思路 在软件开发过程中,我们经常会遇到处理简单对象和复合对象的情况,例如对操作系统中目录的处理就是这样的一个例子,因为目录可以…

四、Drf认证组件

四、Drf认证组件 4.1 快速使用 from django.shortcuts import render,HttpResponse from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.authentication import BaseAuthentication from rest_framework.exception…

【Linux】修改用户名用户家目录

0、锁定旧用户登录 如果旧用户olduser正在运行中是无法操作的,需要先禁用用户登录,然后杀掉所有此用户的进程。 1. 使用 usermod 命令禁用用户 这将锁定用户账户,使其无法登录: sudo usermod -L olduser2. 停止用户的进程 如…

【Python】FeinCMS:轻量级且可扩展的Django内容管理系统

在互联网飞速发展的今天,内容管理系统(CMS)成为了网站开发中的核心工具,尤其对于需要频繁更新内容的企业和个人站点而言,CMS 提供了极大的便利。市场上有许多不同的 CMS 工具可供选择,其中基于 Django 框架…

CentOS 6文件系统

由冯诺依曼在 1945 年提出的计算机五大组成部分:运算器,控制器,存储器,输入设 备,输出设备。 1. 硬盘结构: (1)机械硬盘结构: 磁盘拆解图: 扇区,…

前端BOM常用操作

BOM操作常用命令详解及代码案例 BOM(Browser Object Model)是浏览器对象模型,是浏览器提供的JavaScript操作浏览器的API。BOM提供了与网页无关的浏览器的功能对象,虽然没有正式的标准,但现代浏览器已经几乎实现了Java…

前端动态创建svg不起效果?

document.createElement(path);诸如此类的创建一般都是不太行的 我在创建这个之后,虽然在网页上是有相应的结构,但是完全不显示 一般正确的创建方式为 document.createElementNS(http://www.w3.org/2000/svg,path);在使用document.createElementNS(“ht…

【重学 MySQL】四十五、数据库的创建、修改与删除

【重学 MySQL】四十五、数据库的创建、修改与删除 一条数据存储的过程数据输入数据验证数据处理数据存储数据持久化反馈与日志注意事项 标识符命名规则基本规则长度限制保留字与特殊字符命名建议示例 MySQL 中的数据类型创建数据库创建数据库时指定字符集和排序规则 查看数据库…

影刀---实现我的第一个抓取数据的机器人

你们要的csdn自动回复机器人在这里文末哦! 这个上传的资源要vip下载,如果想了解影刀这个软件的话可以私聊我,我发你 目录 1.网页对象2.网页元素3.相似元素组4.元素操作设置下拉框复选框滚动条获取元素的信息 5.变量6.数据的表达字符串变量列…

CNN+Transformer解说

CNN(卷积神经网络)和Transformer是两种在深度学习领域广泛使用的模型架构,它们在处理不同类型的数据和任务时各有优势。 CNN擅长捕捉局部特征和空间层次结构,而Transformer擅长处理序列数据和长距离依赖关系。 将CNN与Transform…

解开 Golang‘for range’的神秘面纱:易错点剖析与解读

前言 在 Go 语言的编程世界中,充满了各种有趣的特性和挑战。其中,一些看似简单的代码结构可能会隐藏着意想不到的结果。今天,我们就来探讨一下在 Golang 中一个容易让人产生疑惑的地方——for range循环。相信很多 Go 开发者在日常编程中都会…

github项目--crawl4ai

github项目--crawl4ai 输出html输出markdown格式输出结构化数据与BeautifulSoup的对比 crawl4ai github上这个项目,没记错的话,昨天涨了3000多的star,今天又新增2000star。一款抓取和解析工具,简单写个demo感受下 这里我们使用cra…

另外知识与网络总结

一、重谈NAT(工作在网络层) 为什么会有NAT 为了解决ipv4地址太少问题,到了公网的末端就会有运营商路由器来构建私网,在不同私网中私有IP可以重复,这就可以缓解IP地址太少问题,但是这就导致私有IP是重复的…

车辆重识别(2021ICML改进的去噪扩散概率模型)论文阅读2024/9/29

所谓改进的去噪扩散概率模型主要改进在哪些方面: ①对数似然值的改进 通过对噪声的那个方差和T进行调参,来实现改进。 ②学习 这个参数也就是后验概率的方差。通过数据分析,发现在T非常大的情况下对样本质量几乎没有影响,也就是说…

酒店新科技,飞睿智能毫米波雷达人体存在感应器,智能照明创新节能新风尚

在这个日新月异的时代,科技正以未有的速度改变着我们的生活。从智能手机到智能家居,每一个细微之处都渗透着科技的魅力。而今,这股科技浪潮已经席卷到了酒店行业,为传统的住宿体验带来了翻天覆地的变化。其中,引人注目…

什么是托管安全信息和事件管理 SIEM?

什么是 SIEM? 安全信息和事件管理 ( SIEM ) 解决方案最初是一种集中式日志聚合解决方案。SIEM 解决方案会从整个组织网络中的系统收集日志数据,使组织能够从单一集中位置监控其网络。 随着时间的推移,SIEM解决方案已发展成为一个完整的威胁…

曲线图异常波形检测系统源码分享

曲线图异常波形检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comput…

(最新已验证)stm32 + 新版 onenet +dht11+esp8266/01s + mqtt物联网(含微信小程序)上报温湿度和控制单片机(保姆级教程)

物联网实践教程:微信小程序结合OneNET平台MQTT实现STM32单片机远程智能控制 远程上报和接收数据——汇总 前言 之前在学校获得了一个新玩意:ESP-01sWIFI模块,去搜了一下这个小东西很有玩点,远程控制LED啥的,然后我就想…