PyTorch模型的保存加载

一、引言

我们今天来看一下模型的保存与加载~

我们平时在神经网络的训练时间可能会很长,为了在每次使用模型时避免高代价的重复训练,我们就需要将模型序列化到磁盘中,使用的时候反序列化到内存中。

PyTorch提供了两种主要的方法来保存和加载模型,分别是直接序列化模型对象和存储模型的网络参数。

二、直接序列化模型对象

直接序列化模型对象:方法使用torch.save()函数将整个模型对象保存为一个文件,然后使用torch.load()函数将其加载回内存。这种方法可以方便地保存和加载整个模型,包括其结构、参数以及优化器等信息。

import torch
import torch.nn as nn
import pickleclass Model(nn.Module):def __init__(self, input_size, output_size):super(Model, self).__init__()self.linear1 = nn.Linear(input_size, input_size * 2)self.linear2 = nn.Linear(input_size * 2, output_size)def forward(self, inputs):inputs = self.linear1(inputs)output = self.linear2(inputs)return outputdef test01():model = Model(12, 2)# 第一个参数: 这是要保存的模型# 第二个参数: 这是模型保存的路径# 第三个参数: 指定了用于序列化和反序列化的模块# 第四个参数: 这是使用的pickle协议的版本,协议引入了二进制格式,提高了序列化数据的效率torch.save(model, 'model/test_model_save.pth', pickle_module=pickle, pickle_protocol=2)def test02():# 第一个参数: 加载的路径# 第二个参数: 模型加载的设备# 第三个参数: 加载的模块model = torch.load('model/test_model_save.pth', map_location='cpu', pickle_module=pickle)

在使用 torch.save() 保存模型时,需要注意一些关于 CPU 和 GPU 的问题,特别是在加载模型时需要注意 :

  1. 保存和加载设备一致性:

    • 当你在 GPU 上训练了一个模型,并使用 torch.save() 保存了该模型的状态字典(state_dict),然后尝试在一个没有 GPU 的环境中加载该模型时,会引发错误,因为 PyTorch 期望在相同的设备上执行操作。
    • 为了解决这个问题,你可以在没有 GPU 的机器上保存整个模型(而不是仅保存 state_dict),这样 PyTorch 会将权重数据移动到 CPU 上,并且在加载时不会引发错误。
  2. 移动模型到 CPU:

    • 如果你在 GPU 上保存了模型的 state_dict,并且想在 CPU 上加载它,你需要确保在加载 state_dict 之前将模型移动到 CPU。这可以通过调用模型的 to('cpu') 方法来实现。
  3. 移动模型到 GPU:

    • 如果你在 CPU 上保存了模型的 state_dict,并且想在 GPU 上加载它,你需要确保在加载 state_dict 之前将模型移动到 GPU。这可以通过调用模型的 to(device) 方法来实现,其中 device 是一个包含 CUDA 信息的对象(如果 GPU 可用)。
三、存储模型的网络参数
import torch
import torch.nn as nn
import torch.optim as optimclass Model(nn.Module):def __init__(self, input_size, output_size):super(Model, self).__init__()self.linear1 = nn.Linear(input_size, input_size * 2)self.linear2 = nn.Linear(input_size * 2, output_size)def forward(self, inputs):inputs = self.linear1(inputs)output = self.linear2(inputs)return outputdef test01():model = Model(12, 2)optimizer = optim.Adam(model.parameters(), lr=0.01)# 定义存储参数save_params = {'init_params': {'input_size': 12,'output_size': 2},'acc_score': 0.96,'avg_loss': 0.85,'iter_numbers': 100,'optim_params': optimizer.state_dict(),'model_params': model.state_dict()}# 存储模型参数torch.save(save_params, 'model/model_params.pth')def test02():# 加载模型参数model_params = torch.load('model/model_params.pth')# 初始化模型model = Model(model_params['init_params']['input_size'], model_params['init_params']['output_size'])# 初始化优化器optimizer = optim.Adam(model.parameters())optimizer.load_state_dict(model_params['optim_params'])# 显示其他参数print('迭代次数:', model_params['iter_numbers'])print('准确率:', model_params['acc_score'])print('平均损失:', model_params['avg_loss'])

训练完成后,通常需要保存模型的参数值,以便用于后续的测试过程。使用torch.save()函数来保存模型的状态字典(state_dict),这个状态字典包含了模型的可学习参数(权重和偏置值)

optimizer = optim.Adam(model.parameters(), lr=0.01)

  • 创建一个Adam优化器对象,在PyTorch中,优化器用于更新模型的参数以最小化损失函数。Adam是一种常用的优化算法,它结合了Momentum和RMSProp的优点,具有自适应学习率调整的特性。
  • model.parameters()表示要优化的模型参数,即模型中所有可学习的权重和偏置值。lr=0.01表示学习率(learning rate)为0.01,这是控制参数更新步长的重要超参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/8478.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

缓存雪崩、击穿、击穿

缓存雪崩: 就是大量数据在同一时间过期或者redis宕机时,这时候有大量的用户请求无法在redis中进行处理,而去直接访问数据库,从而导致数据库压力剧增,甚至有可能导致数据库宕机,从而引发的一些列连锁反应&a…

MATLAB 基于规则格网的点云抽稀方法(自定义实现)(65)

MATLAB 基于规则格网的点云抽稀方法(自定义实现)(65) 一、算法介绍二、算法实现1.代码2.结果一、算法介绍 海量点云的处理,需要提前进行抽稀预处理,相比MATLAB预先给出的抽稀方法,这里提供一种基于规则格网的自定义抽稀方法,步骤清晰,便于理解抽稀内涵, 主要涉及到使…

springboot整合rabbitmq的不同工作模式详解

前提是已经安装并启动了rabbitmq,并且项目已经引入rabbitmq,完成了配置。 不同模式所需参数不同,生产者可以根据参数不同使用重载的convertAndSend方法。而消费者均是直接监听某个队列。 不同的交换机是实现不同工作模式的关键组件.每种交换…

三层交换机与防火墙连通上网实验

防火墙是一种网络安全设备,用于监控和控制网络流量。它可以帮助防止未经授权的访问,保护网络免受攻击和恶意软件感染。防火墙可以根据预定义的规则过滤流量,例如允许或阻止特定IP地址或端口的流量。它也可以检测和阻止恶意软件、病毒和其他威…

SlowFast报错:ValueError: too many values to unpack (expected 4)

SlowFast报错:ValueError: too many values to unpack (expected 4) 报错细节 File "/home/user/yuanjinmin/SlowFast/tools/visualization.py", line 81, in run_visualizationfor inputs, labels, _, meta in tqdm.tqdm(vis_loader): ValueError: too …

牛客题-链表内区间反转

链表内区间反转 这是代码 typedef struct ListNode listnode; struct ListNode* reverseBetween(struct ListNode* head, int m, int n ) {if (head NULL) {return NULL;}listnode* findhead head;listnode* findtail head;listnode* prev NULL;int count1 m;int count2…

OpenCV 入门(二)—— 车牌定位

OpenCV 入门系列: OpenCV 入门(一)—— OpenCV 基础 OpenCV 入门(二)—— 车牌定位 OpenCV 入门(三)—— 车牌筛选 OpenCV 入门(四)—— 车牌号识别 OpenCV 入门&#xf…

十分钟掌握Java集合之List接口

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

C++反射之检测struct或class是否实现指定函数

目录 1.引言 2.检测结构体或类的静态函数 3.检测结构体或类的成员函数 3.1.方法1 3.2.方法2 1.引言 诸如Java, C#这些语言是设计的时候就有反射支持的。c没有原生的反射支持。并且,c提供给我们的运行时类型信息非常少,只是通过typeinfo提供了有限的…

leetcode刷题(5): STL的使用

文章目录 56. 合并区间解题思路c实现 55. 跳跃游戏解题思路c 实现 75. 颜色分类解题思路c 实现 36 下一个排列解题思路c 实现 56. 合并区间 题目: 以数组 intervals 表示若干个区间的集合,其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间&a…

Linux(openEuler、CentOS8)企业内网samba服务器搭建(Windows与Linux文件共享方案)

本实验环境为openEuler系统<以server方式安装>&#xff08;CentOS8基本一致&#xff0c;可参考本文) 目录 知识点实验1. 安装samba2. 启动smb服务并设置开机启动3. 查看服务器监听状态4. 配置共享访问用户5. 创建共享文件夹6. 修改配置文件7. 配置防火墙8. 使用windows…

Hypack 2024 简体中文资源完整翻译汉化已经全部完成

Hypack 2024 简体中文资源完整翻译汉化已经全部完成 Hypack 2024&#xff0c;资源汉化共翻译11065条。毕竟涉及测绘、水文、疏浚等专业术语太多&#xff0c;翻译有很多理解不正确的地方&#xff0c;望各位专业人员指正。 压缩包内包含Hypack 2024、Hypack 2022、Hypack 2021、…

Autoxjs 实践-Spring Boot 集成 WebSocket

概述 最近弄了福袋工具&#xff0c;由于工具运行中&#xff0c;不好查看福袋结果&#xff0c;所以我想将福袋工具运行数据返回到后台&#xff0c;做数据统计、之后工具会越来越多&#xff0c;就弄了个后台&#xff0c;方便管理。 实现效果 WebSocket&#xff1f; websocket是…

Qt应用开发(拓展篇)——图表 QChart

一、前言 QChart是一个图形库模块&#xff0c;它可以实现不同类型的序列和其他图表相关对象(如图例和轴)的图形表示。要在布局中简单地显示图表&#xff0c;可以使用QChartView来代替QChart。此外&#xff0c;线条、样条、面积和散点序列可以通过使用QPolarChart类表示为极坐标…

Python的奇妙之旅——回顾其历史

我们这个神奇的宇宙里&#xff0c;有一个名叫Python的小家伙&#xff0c;它不仅聪明&#xff0c;而且充满活力。它一路走来&#xff0c;从一个小小的编程语言成长为如今全球最受欢迎的编程语言之一。今天&#xff0c;我们就来回顾一下Python的历史&#xff0c;看看它如何从一个…

字体设计_西文字体设计(英文字体设计)

一 西文字体设计基础知识 设计目标和历史成因 设计目标&#xff1a;让眼睛看着舒服的字体 那什么样的字体让眼睛看着舒服呢&#xff1f; 让眼睛看着舒服的字体造型其实是我们记忆里的手写体、自然造型。 所以就能理解西文字体为什么同一笔画&#xff0c;有的地方粗有的地方…

DDPM与扩散模型

很早之前就新建了一个专栏从0开始弃坑扩散模型 ,但发了一篇文章就没有继续这一系列&#xff0c;在这个AIGC的时代&#xff0c;于是我准备重启这个专栏。 整个专栏的学习顺序可以见这篇汇总文章 这是本专栏的第一章 目录 引言生成模型的发展历程 引言 扩散模型( Diffusion Mode…

52页 | 2024大型语言模型行业图谱研究报告(免费下载)

【1】关注本公众号&#xff0c;转发当前文章到微信朋友圈 【2】私信发送 【2024大型语言模型行业图谱研究报告】 【3】获取本方案PDF下载链接&#xff0c;直接下载即可。 如需下载本方案PPT原格式&#xff0c;请加入微信扫描以下方案驿站知识星球&#xff0c;获取上万份PPT解…

STC8增强型单片机开发——库函数

一、使用库函数点灯 导入库函数。 下载STC8H的库函数&#xff1a;&#x1f4ce;STC8G-STC8H-LIB-DEMO-CODE_2023.07.17_优化版.zip 来到库函数的目录下&#xff0c;拷贝以下文件&#xff1a; Config.hType_def.hGPIO.hGPIO.c 新建项目&#xff0c;将拷贝的4个文件放到项目目录…

WEB基础--JDBC操作数据库

使用JDBC操作数据库 使用JDBC查询数据 五部曲&#xff1a;建立驱动&#xff0c;建立连接&#xff0c;获取SQL语句&#xff0c;执行SQL语句&#xff0c;释放资源 建立驱动 //1.加载驱动Class.forName("com.mysql.cj.jdbc.Driver"); 建立连接 //2.连接数据库 Stri…