PyTorch模型的保存加载

一、引言

我们今天来看一下模型的保存与加载~

我们平时在神经网络的训练时间可能会很长,为了在每次使用模型时避免高代价的重复训练,我们就需要将模型序列化到磁盘中,使用的时候反序列化到内存中。

PyTorch提供了两种主要的方法来保存和加载模型,分别是直接序列化模型对象和存储模型的网络参数。

二、直接序列化模型对象

直接序列化模型对象:方法使用torch.save()函数将整个模型对象保存为一个文件,然后使用torch.load()函数将其加载回内存。这种方法可以方便地保存和加载整个模型,包括其结构、参数以及优化器等信息。

import torch
import torch.nn as nn
import pickleclass Model(nn.Module):def __init__(self, input_size, output_size):super(Model, self).__init__()self.linear1 = nn.Linear(input_size, input_size * 2)self.linear2 = nn.Linear(input_size * 2, output_size)def forward(self, inputs):inputs = self.linear1(inputs)output = self.linear2(inputs)return outputdef test01():model = Model(12, 2)# 第一个参数: 这是要保存的模型# 第二个参数: 这是模型保存的路径# 第三个参数: 指定了用于序列化和反序列化的模块# 第四个参数: 这是使用的pickle协议的版本,协议引入了二进制格式,提高了序列化数据的效率torch.save(model, 'model/test_model_save.pth', pickle_module=pickle, pickle_protocol=2)def test02():# 第一个参数: 加载的路径# 第二个参数: 模型加载的设备# 第三个参数: 加载的模块model = torch.load('model/test_model_save.pth', map_location='cpu', pickle_module=pickle)

在使用 torch.save() 保存模型时,需要注意一些关于 CPU 和 GPU 的问题,特别是在加载模型时需要注意 :

  1. 保存和加载设备一致性:

    • 当你在 GPU 上训练了一个模型,并使用 torch.save() 保存了该模型的状态字典(state_dict),然后尝试在一个没有 GPU 的环境中加载该模型时,会引发错误,因为 PyTorch 期望在相同的设备上执行操作。
    • 为了解决这个问题,你可以在没有 GPU 的机器上保存整个模型(而不是仅保存 state_dict),这样 PyTorch 会将权重数据移动到 CPU 上,并且在加载时不会引发错误。
  2. 移动模型到 CPU:

    • 如果你在 GPU 上保存了模型的 state_dict,并且想在 CPU 上加载它,你需要确保在加载 state_dict 之前将模型移动到 CPU。这可以通过调用模型的 to('cpu') 方法来实现。
  3. 移动模型到 GPU:

    • 如果你在 CPU 上保存了模型的 state_dict,并且想在 GPU 上加载它,你需要确保在加载 state_dict 之前将模型移动到 GPU。这可以通过调用模型的 to(device) 方法来实现,其中 device 是一个包含 CUDA 信息的对象(如果 GPU 可用)。
三、存储模型的网络参数
import torch
import torch.nn as nn
import torch.optim as optimclass Model(nn.Module):def __init__(self, input_size, output_size):super(Model, self).__init__()self.linear1 = nn.Linear(input_size, input_size * 2)self.linear2 = nn.Linear(input_size * 2, output_size)def forward(self, inputs):inputs = self.linear1(inputs)output = self.linear2(inputs)return outputdef test01():model = Model(12, 2)optimizer = optim.Adam(model.parameters(), lr=0.01)# 定义存储参数save_params = {'init_params': {'input_size': 12,'output_size': 2},'acc_score': 0.96,'avg_loss': 0.85,'iter_numbers': 100,'optim_params': optimizer.state_dict(),'model_params': model.state_dict()}# 存储模型参数torch.save(save_params, 'model/model_params.pth')def test02():# 加载模型参数model_params = torch.load('model/model_params.pth')# 初始化模型model = Model(model_params['init_params']['input_size'], model_params['init_params']['output_size'])# 初始化优化器optimizer = optim.Adam(model.parameters())optimizer.load_state_dict(model_params['optim_params'])# 显示其他参数print('迭代次数:', model_params['iter_numbers'])print('准确率:', model_params['acc_score'])print('平均损失:', model_params['avg_loss'])

训练完成后,通常需要保存模型的参数值,以便用于后续的测试过程。使用torch.save()函数来保存模型的状态字典(state_dict),这个状态字典包含了模型的可学习参数(权重和偏置值)

optimizer = optim.Adam(model.parameters(), lr=0.01)

  • 创建一个Adam优化器对象,在PyTorch中,优化器用于更新模型的参数以最小化损失函数。Adam是一种常用的优化算法,它结合了Momentum和RMSProp的优点,具有自适应学习率调整的特性。
  • model.parameters()表示要优化的模型参数,即模型中所有可学习的权重和偏置值。lr=0.01表示学习率(learning rate)为0.01,这是控制参数更新步长的重要超参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/8478.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

值模板参数Value Template Parameters

模板通常使用类型作为参数&#xff0c;但它们也可以使用值。使用类型和可选名称声明一个值模板参数&#xff0c;方式与声明函数参数类似。值模板参数仅限于可以指定编译时常量的类型是bool、char、int等&#xff0c;但不允许使用浮点类型、字符串字面值和类。 #include <io…

控制反转(IOC)和依赖注入(DI)

什么是IOC&#xff08;控制反转&#xff09;&#xff1f; IoC 的思想就是将原本在程序中手动创建对象的控制权&#xff0c;交由 Spring 框架来管理。 控制&#xff1a;指的是对象创建&#xff08;实例化、管理&#xff09;的权力 反转&#xff1a;控制权交给外部环境&#xf…

缓存雪崩、击穿、击穿

缓存雪崩&#xff1a; 就是大量数据在同一时间过期或者redis宕机时&#xff0c;这时候有大量的用户请求无法在redis中进行处理&#xff0c;而去直接访问数据库&#xff0c;从而导致数据库压力剧增&#xff0c;甚至有可能导致数据库宕机&#xff0c;从而引发的一些列连锁反应&a…

MATLAB 基于规则格网的点云抽稀方法(自定义实现)(65)

MATLAB 基于规则格网的点云抽稀方法(自定义实现)(65) 一、算法介绍二、算法实现1.代码2.结果一、算法介绍 海量点云的处理,需要提前进行抽稀预处理,相比MATLAB预先给出的抽稀方法,这里提供一种基于规则格网的自定义抽稀方法,步骤清晰,便于理解抽稀内涵, 主要涉及到使…

springboot整合rabbitmq的不同工作模式详解

前提是已经安装并启动了rabbitmq&#xff0c;并且项目已经引入rabbitmq&#xff0c;完成了配置。 不同模式所需参数不同&#xff0c;生产者可以根据参数不同使用重载的convertAndSend方法。而消费者均是直接监听某个队列。 不同的交换机是实现不同工作模式的关键组件.每种交换…

DCL 的学习

-- 创建用户 itcast , 只能够在当前主机localhost访问, 密码123456; create user itcastlocalhost identified by 123456; -- 创建用户 heima , 可以在任意主机访问该数据库, 密码123456 ; create user heima% identified by 123456; -- 修改用户 heima 的访问密码为 1234 ; a…

赶紧收藏!2024 年最常见 100道 Java 基础面试题(三十五)

上一篇地址&#xff1a;赶紧收藏&#xff01;2024 年最常见 100道 Java 基础面试题&#xff08;三十四&#xff09;-CSDN博客 六十九、spring mvc和struts的区别是什么&#xff1f; Spring MVC和Struts都是Java EE&#xff08;Java Enterprise Edition&#xff09;中流行的MV…

三层交换机与防火墙连通上网实验

防火墙是一种网络安全设备&#xff0c;用于监控和控制网络流量。它可以帮助防止未经授权的访问&#xff0c;保护网络免受攻击和恶意软件感染。防火墙可以根据预定义的规则过滤流量&#xff0c;例如允许或阻止特定IP地址或端口的流量。它也可以检测和阻止恶意软件、病毒和其他威…

20240508日记

今天工作内容&#xff1a; 1.二号机S3点位焊接测试&#xff0c;调整位置精度。 2.一号机送针位置调整 3.自定义焊接功能测试 4.EAP服务启动测试 明日计划&#xff1a; 1.EAP流程修改功能开发 1.1 Read Barcode Complete 事件&#xff0c;上传料盘码和设备ID&#xff0c;等EA…

SlowFast报错:ValueError: too many values to unpack (expected 4)

SlowFast报错&#xff1a;ValueError: too many values to unpack (expected 4) 报错细节 File "/home/user/yuanjinmin/SlowFast/tools/visualization.py", line 81, in run_visualizationfor inputs, labels, _, meta in tqdm.tqdm(vis_loader): ValueError: too …

牛客题-链表内区间反转

链表内区间反转 这是代码 typedef struct ListNode listnode; struct ListNode* reverseBetween(struct ListNode* head, int m, int n ) {if (head NULL) {return NULL;}listnode* findhead head;listnode* findtail head;listnode* prev NULL;int count1 m;int count2…

nginx的rewrite重定向

rewrite地址重定向&#xff0c;实现URL重定向的重要指令&#xff0c;它根据regex&#xff08;正则表达式&#xff09;来匹配内容跳转 语法&#xff1a;rewrite regex replacement[flag] rewrite ^/(.*) https://www.baidu.com/$1 permanent; # 这是一个正则表达式&#xff0c;匹…

pdf 文件版面分析--pdfplumber (python 文档解析提取)

pdfplumber 的特点 1、它是一个纯 python 第三方库&#xff0c;适合 python 3.x 版本 2、它用来查看pdf各类信息&#xff0c;能有效提取文本、表格 3、它不支持修改或生成pdf&#xff0c;也不支持对pdf扫描件的处理 import glob import pdfplumber import re from collection…

[前后端基础]图片传输与异步

前后端之间传递照片 在前后端之间传递照片&#xff0c;通常可以采用以下几种方式&#xff1a; Base64 编码传输&#xff1a;将图片转换为 Base64 编码的字符串&#xff0c;然后通过接口传递到后端&#xff0c;后端再将 Base64 字符串转换回图片格式。这种方式简单易行&#xff…

OpenCV 入门(二)—— 车牌定位

OpenCV 入门系列&#xff1a; OpenCV 入门&#xff08;一&#xff09;—— OpenCV 基础 OpenCV 入门&#xff08;二&#xff09;—— 车牌定位 OpenCV 入门&#xff08;三&#xff09;—— 车牌筛选 OpenCV 入门&#xff08;四&#xff09;—— 车牌号识别 OpenCV 入门&#xf…

C#面:简要谈对微软.NET 构架下 remoting 和 webservice 两项技术的理解以及实际中的应用

在微软 .NET 框架下&#xff0c;Remoting 和 WebService 是两种常用的技术&#xff0c;用于实现分布式应用程序的通信和交互。 Remoting&#xff08;远程调用&#xff09;&#xff1a; Remoting是一种用于在不同应用程序域之间进行通信的技术。它允许对象在不同的进程或计算机…

十分钟掌握Java集合之List接口

哈喽&#xff0c;各位小伙伴们&#xff0c;你们好呀&#xff0c;我是喵手。运营社区&#xff1a;C站/掘金/腾讯云&#xff1b;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点&#xff0c;并以文字的形式跟大家一起交流&#xff0c;互相学习&#xff0c;一…

培养逻辑思考力的7大基本方法笔记

《逻辑思考力》一书中介绍的培养逻辑思考力的七大基本方法如下&#xff0c;每种方法都旨在帮助人们更有效地分析问题、制定策略和做出决策&#xff1a; 1. 使解决问题的过程透明化 解释&#xff1a;这种方法强调的是清晰地界定问题解决的步骤&#xff0c;确保每一步都可追踪和…

C++反射之检测struct或class是否实现指定函数

目录 1.引言 2.检测结构体或类的静态函数 3.检测结构体或类的成员函数 3.1.方法1 3.2.方法2 1.引言 诸如Java, C#这些语言是设计的时候就有反射支持的。c没有原生的反射支持。并且&#xff0c;c提供给我们的运行时类型信息非常少&#xff0c;只是通过typeinfo提供了有限的…

微信小程序开发秘籍:揭秘基础库版本与客户端版本的不解之缘【代码示例】

微信小程序开发秘籍&#xff1a;揭秘基础库版本与客户端版本的不解之缘【代码示例】 基础概念&#xff1a;何为基础库&#xff1f;何为客户端&#xff1f;基础库&#xff08;Weixin Mini Program Base Library&#xff09;客户端&#xff08;WeChat Client&#xff09; 版本关系…