pytorch之torch.save()和torch.load()方法详细说明

        torch.save()和torch.load()是PyTorch中用于模型保存和加载的函数。它们提供了一种方便的方式来保存和恢复模型的状态、结构和参数。可以使用它们来保存和加载整个模型或其他任意的Python对象,并且可以在加载模型时指定目标设备。

1.语法介绍

1.1 torch.save()语法

        torch.save()函数用于将PyTorch模型保存到磁盘上的文件中,以便以后可以重新加载和使用。它的基本语法如下:

torch.save(obj, f, pickle_module=<module 'pickle' from '...'>, pickle_protocol=2)

        参数说明:

                obj是要保存的对象,通常是一个模型的状态字典(state_dict())。

                f是文件的路径或文件对象,用于存储模型。

                pickle_module是用于序列化的Python模块,默认为pickle。

                pickle_protocol是序列化时使用的协议版本,默认为2。

1.2 torch.load()语法

        torch.load()函数用于从磁盘上的文件加载保存的模型。它的基本语法如下:

torch.load(f, map_location=None, pickle_module=<module 'pickle' from '...'>)

        参数说明:

                f是要加载的文件的路径或文件对象。

                map_location用于指定加载模型的设备(CPU或特定的GPU设备)。默认情况下,加载的模型将被存储在与保存模型时相同的设备上。

                pickle_module是用于反序列化的Python模块,默认为pickle。

2. 基本使用示例介绍

2.1 保存和加载整个模型

        除了保存和加载模型的状态字典外,torch.save()和torch.load()还可以用于保存和加载整个模型,包括模型的结构、参数和其他相关信息。

        要保存整个模型,使用以下代码:

torch.save(model, 'model.pth')

        要加载整个模型,使用以下代码: 

model = torch.load('model.pth')

        注意,加载整个模型时,需要确保模型的定义代码可用,因为它将用于重新创建模型的结构。

2.2 保存和加载其他对象

        torch.save()和torch.load()不仅限于保存和加载模型,还可以用于保存和加载其他任意的Python对象。只需将要保存的对象传递给torch.save(),然后使用torch.load()来加载该对象。

        例如:

data = [1, 2, 3, 4, 5]
torch.save(data, 'data.pth')loaded_data = torch.load('data.pth')

        这样可以方便地保存和加载各种数据,如训练集、测试集、预处理数据等。 

 2.3 跨设备加载模型

        torch.load()函数允许在加载模型时指定目标设备。通过使用map_location参数,可以将模型加载到不同的设备上,例如从GPU加载到CPU或从一种GPU加载到另一种GPU。

        以下是一个示例:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 从GPU加载到CPU
model = torch.load('model.pth', map_location='cpu')# 从一种GPU加载到另一种GPU
model = torch.load('model.pth', map_location='cuda:1')

         这对于在不同设备上运行模型或在没有GPU的机器上加载训练好的GPU模型非常有用。

2.4 序列化兼容性

        torch.save()使用Python的pickle模块进行序列化,默认使用协议版本2。这个默认版本在PyTorch 1.6及更高版本中是兼容的。如果您需要与旧版本的PyTorch或其他Python库进行兼容,您可以通过设置pickle_protocol参数来选择不同的协议版本。 

torch.save(model.state_dict(), 'model.pth', pickle_protocol=4)

        在选择协议版本时,需要权衡序列化的性能和兼容性。 

3. 模型保存和加载

        当涉及到模型保存和加载时,还有一些其他的注意事项和用法:

3.1 保存和加载模型的状态字典

        通常情况下,我们只保存和加载模型的状态字典(state_dict()),而不是整个模型。状态字典包含了模型的参数和缓冲区(如权重和偏置),但不包括模型的结构。这种做法更加灵活,因为它允许在加载模型时自由选择模型的结构,并且可以与不同的模型架构进行兼容。

#保存模型的状态字典:
torch.save(model.state_dict(), 'model.pth')#加载模型的状态字典:
model = MyModel()
model.load_state_dict(torch.load('model.pth'))

        请确保在加载模型之前,模型的定义与保存时的模型结构相匹配。 

3.2 冻结某些层或参数

        在某些情况下,可能希望冻结模型的某些层或参数,即在加载模型后不更新它们的参数。可以通过设置参数的requires_grad属性来实现这一点。

        例如,假设模型有一个名为fc的全连接层,您可以冻结该层的参数:

model = MyModel()
model.load_state_dict(torch.load('model.pth'))# 冻结全连接层的参数
for param in model.fc.parameters():
param.requires_grad = False

3.3 多个模型的保存和加载

        如果您需要保存和加载多个模型,您可以将它们保存为一个字典,并使用一个文件来存储整个字典。

        保存多个模型:

state = {'model1': model1.state_dict(),'model2': model2.state_dict()
}
torch.save(state, 'models.pth')

        加载多个模型: 

state = torch.load('models.pth')
model1.load_state_dict(state['model1'])
model2.load_state_dict(state['model2'])

         这种方法可以方便地保存和加载多个相关模型。

 3.3 保存和加载检查点

        在训练过程中,可以定期保存模型的检查点,以便在训练过程中发生意外情况时能够恢复模型。通过定期保存检查点,可以避免从头开始训练,并从最新的检查点继续训练。

# 训练循环中的保存检查点
if epoch % checkpoint_interval == 0:torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,}, f'checkpoint_{epoch}.pth')

        在发生中断或需要恢复训练时,可以加载最新的检查点: 

# 加载最新的检查点
latest_checkpoint = get_latest_checkpoint(checkpoint_dir)
checkpoint = torch.load(latest_checkpoint)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

        这样,可以从最新的检查点恢复训练。 

 

 

 

 

 

 

 

 

 

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/761567.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app开发---4.首页

一、创建 home 分支 运行如下的命令&#xff0c;基于 master 分支在本地创建 home 子分支&#xff0c;用来开发和 home 首页相关的功能&#xff1a; git checkout -b home 二、配置网络请求 由于平台的限制&#xff0c;小程序项目中不支持 axios&#xff0c;而且原生…

ARM-UART实验

串口控制三盏灯亮灭 视频 串口实验验证.mp4 代码 uart4.c #include "uart4.h"void uart4_init() {//使能GPIOB GPIOG UART4外设时钟RCC->MP_AHB4ENSETR | (0x1<<1);//GPIOBRCC->MP_AHB4ENSETR | (0x1<<6);//GPIOGRCC->MP_APB1ENSETR | (0X…

鸿蒙一次开发,多端部署(三)应用UX设计原则

设计原则 当为多种不同的设备开发应用时&#xff0c;有如下设计原则&#xff1a; 差异性 充分了解所要支持的设备&#xff0c;包括屏幕尺寸、交互方式、使用场景、用户人群等&#xff0c;对设备的特性进行针对性的设计。 一致性 除了要考虑每个设备的特性外&#xff0c;还…

SOCKS5代理、代理IP、HTTP与网络安全的融合之旅

在数字化世界的无边网络海洋中&#xff0c;数据以难以想象的速度流动&#xff0c;连接着世界的每一个角落。作为一名软件工程师&#xff0c;深入理解网络通信的基石——SOCKS5代理、代理IP、HTTP协议&#xff0c;并掌握这些技术在网络安全中的应用&#xff0c;是航行于这片海洋…

C# 读取二维数组集合输出到Word预设表格

目录 应用场景 设计约定 范例运行环境 配置Office DCOM 实现代码 组件库引入 核心代码 DataSet转二维数组 导出写入WORD表格 调用举例 小结 应用场景 存储或导出个人WORD版简历是招聘应用系统中的常用功能&#xff0c;我们通常会通过应用系统采集用户的个人简历信息…

android recyclerview 总结

面试官问我熟不熟 recyclerview&#xff0c;我说不熟 他就没再继续问&#xff0c;整个过程还是比较丝滑的 呵呵&#xff1f;&#xff1f;这么一个基础控件&#xff0c;你居然敢说不熟&#xff0c;真没想到 1 recyclerview相比listview的区别 1.1 ViewHolder 的编写规范化了 …

云主机搭建与服务软件部署

文章目录 登录访问云电脑与云电脑传输文件配置ssh服务ssh连接云电脑使用scp传输文件云端服务软件部署与实现外部访问首先购买云主机,以阿里云服务器 ECS为例子,官网购买就行了,选择默认安装了windows server 2022服务器系统 登录访问云电脑 购买完成进入控制台,能看到创建…

蓝桥杯第十三届蓝桥杯大赛软件赛决赛CC++ 研究生组之选素数

蓝桥杯第十三届蓝桥杯大赛软件赛决赛C/C 研究生组之选素数 [题目传送门](0选素数 - 蓝桥云课 (lanqiao.cn)) 问题大意&#xff1a; 小蓝有一个数字&#xff0c;要进行如下操作&#xff1a; 首先选出一个小于x 的质数p&#xff0c;然后将x变成要比原本大的最小的为p的倍数的…

js判断字符串是否为JSON格式

使用场景&#xff1a;在有些项目中我们会将用户输入的JSON字符串转化为对象形式并展示出来&#xff0c;那么首先我们就要判断一个字符串是否为一个合法的JSON字符串。 代码如下&#xff1a; isJSON(str) { if (typeof str string) { try { let obj …

使用CUDA 为Tegra构建OpenCV

返回&#xff1a;OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;MultiArch与Ubuntu/Debian 的交叉编译 下一篇&#xff1a;在iOS中安装 警告&#xff1a; 本教程可能包含过时的信息。 使用CUDA for Tegra 的OpenCV 本文档是构建支持 CUD…

谷歌具身智能最新进展:RT-H 机器人通用灵巧抓取

随着 GPT-4 等大型语言模型与机器人研究的结合愈发紧密&#xff0c;人工智能正在越来越多地走向现实世界&#xff0c;因此具身智能相关的研究也正受到越来越多的关注。在众多研究项目中&#xff0c;谷歌的「RT」系列机器人始终走在前沿&#xff08;参见《大模型正在重构机器人&…

各位老板,你需要的工厂数字孪生可视化库在这

各位老板是不是很喜欢下面这种有逼格的大屏,下面介绍一下怎么实现的,保证有所收获。 Cesium是一个开源的WebGL JavaScript库&#xff0c;用于创建高性能的三维地球、地图和虚拟环境。它支持在浏览器中实现高质量的地球模拟&#xff0c;同时提供了丰富的功能特点&#xff0c;使得…

Superset二次开发之PostgreSQL 统计信息介绍

pg_stat_user_tables 视图提供了关于 PostgreSQL 数据库中用户定义表的统计信息。这些统计信息涵盖了从表的扫描操作到修改次数等多个方面。 以下是 pg_stat_user_tables 中所有字段的含义&#xff1a; relid: 表的 OID&#xff08;对象标识符&#xff09;。这是表在系统中的…

k8s系列之十五 Istio 部署Bookinfo 应用

Bookinfo 应用中的几个微服务是由不同的语言编写的。 这些服务对 Istio 并无依赖&#xff0c;但是构成了一个有代表性的服务网格的例子&#xff1a;它由多个服务、多个语言构成&#xff0c;并且 reviews 服务具有多个版本。 该应用由四个单独的微服务构成。 这个应用模仿在线书…

模板高级使用(非类型模板参数,特化,分离编译)

文章目录 模板没有实例化取内嵌类型报错问题非类型模板参数模板的特化函数模板的特化类模板的特化1.全特化2.偏特化 模板的分离编译 模板没有实例化取内嵌类型报错问题 首先在这里分享一个模板的常见报错问题。就是模板的在没有实例化的情况下去取模板类里面的内嵌类型这时候的…

代码随想录|Day25|回溯05|491.非递减子序列、46.全排列、47.全排列II

491. 非递减子序列 本题并不能像 90.子集II 那样&#xff0c;使用排序进行树层去重。虽然题目没有明确不能排序&#xff0c;但如果排序了&#xff0c;集合本身就是递增子序列&#xff0c;这是LeetCode示例2中没有出现的。 所以本题的关键在于&#xff0c;如何在不排序的情况下对…

引入spring security 403问题

禁用 csrf 即可 httpSecurity.csrf(csrf -> csrf.disable()); Configuration public class SecurityConfig {/*** 认证管理器** param authenticationConfiguration* return* throws Exception*/Beanpublic AuthenticationManager authenticationManager(AuthenticationConf…

请解释 VB.NET 中的多态性(Polymorphism)以及如何实现它。

请解释 VB.NET 中的多态性&#xff08;Polymorphism&#xff09;以及如何实现它。 多态性&#xff08;Polymorphism&#xff09;是面向对象编程中的一个重要概念&#xff0c;它允许不同的对象对同一个消息作出不同的响应。在VB.NET中&#xff0c;多态性通过继承和方法重写来实…

2024格行VS华为VS飞猫哪个是最值得购买随身WiFi?中兴随身WiFi好用吗?

经常出差旅行&#xff0c;或者户外工作的朋友因为长期在外&#xff0c;手机流量经常不够用&#xff0c;想必都是随身WiFi的忠实用户&#xff0c;但是也都被这款产品割韭菜割的头皮发麻。今天&#xff0c;我们统计了市面上最靠谱的、最热销、口碑最好的几款随身WiFi。排名依据来…

Java学习笔记(17)

集合进阶 单列集合 Collection List set Add clear remove contains isempty size Add方法可能也会添加失败 同理&#xff0c;可能删除失败 Contains细节 为什么要重写equals&#xff1f; 因为contains底层用的是object类中的equals方法&#xff0c;比较的是地址值&#xf…