PyTorch中保存模型的两种方式

文章目录

  • 一、状态字典(State Dictionary)
  • 二、序列化模型(Serialized Model)
  • 三、示例代码


一、状态字典(State Dictionary)

这种保存形式将模型的参数保存为一个字典,其中包含了所有模型的权重和偏置等参数。状态字典保存了模型在训练过程中学到的参数值,而不包含模型的结构。可以使用这个字典来加载模型的参数,并将其应用于相同结构的模型。
在 PyTorch 中,您可以使用 torch.save() 函数将模型的状态字典保存到文件中,例如:

torch.save(model.state_dict(), 'model.pth')

然后,可以使用 torch.load() 函数加载状态字典并将其应用于相同结构的模型:

model = MyModel()  # 创建模型对象
model.load_state_dict(torch.load('model.pth'))

这种保存形式非常适用于仅保存和加载模型的参数,而不需要保存和加载模型的结构。

二、序列化模型(Serialized Model)

这种保存形式将整个模型(包括模型的结构、参数等)保存为一个文件。序列化模型保存了模型的完整信息,可以完全恢复模型的状态,包括模型的结构、权重、偏置以及其他相关参数。
在 PyTorch 中,您可以使用 torch.save() 函数直接保存整个模型对象,例如:

torch.save(model, 'model.pth')

然后,您可以使用 torch.load() 函数加载整个序列化模型:

model = torch.load('model.pth')

这种保存形式适用于需要保存和加载完整模型信息的情况,包括模型的结构和参数。

三、示例代码

import torchclass LinearNet(torch.nn.Module):def __init__(self, input_size, output_size):super().__init__()self.net = torch.nn.Sequential(torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features= 5, out_features=5, bias=True),torch.nn.Sigmoid(),torch.nn.Linear(in_features=5, out_features=output_size, bias=True))def forward(self,x):return self.net(x)square_net = LinearNet(1,1)# square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重if __name__ == '__main__':# print(square_net(torch.tensor([3.16],dtype=torch.float32)))# save 方式1torch.save(square_net.state_dict(), "./w1.pth")my_state_dict = torch.load("./w1.pth")print("纯state_dict:\n", my_state_dict)print("type:", type(my_state_dict))# save 方式2torch.save(square_net, "./w2.pth")my_state_dict = torch.load("./w2.pth")print("\n\n模型结构:\n", my_state_dict)print("type:", type(my_state_dict))# 执行结果'''纯state_dict:OrderedDict([('net.0.weight', tensor([[ 0.0820],[-0.6923],[ 0.5066],[-0.8931],[ 0.0460]])), ('net.0.bias', tensor([ 0.1455,  0.5106,  0.2347,  0.4903, -0.6838])), ('net.2.weight', tensor([[-0.4055, -0.2721,  0.3770, -0.2285,  0.3025],[-0.0416,  0.0133, -0.3834, -0.2151,  0.1454],[ 0.0749, -0.3664, -0.1901, -0.2829,  0.3957],[-0.3567,  0.2668,  0.3343, -0.3351, -0.3808],[ 0.4375,  0.1000,  0.1185,  0.2295, -0.3997]])), ('net.2.bias', tensor([-0.2405, -0.2751,  0.1928,  0.3970, -0.0005])), ('net.4.weight', tensor([[-0.4388, -0.2654,  0.3038,  0.2008,  0.0381]])), ('net.4.bias', tensor([0.1847]))])模型结构:LinearNet((net): Sequential((0): Linear(in_features=1, out_features=5, bias=True)(1): Sigmoid()(2): Linear(in_features=5, out_features=5, bias=True)(3): Sigmoid()(4): Linear(in_features=5, out_features=1, bias=True)))'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/698365.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flink中的双流Join

1. Flink中双流Join介绍 Flink版本Join支持类型Join API1.4innerTable/SQL1.5inner,left,right,fullTable/SQL1.6inner,left,right,fullTable/SQL/DataStream Join大体分为两种:Window Join 和 Interval Join 两种。 Window Join又可以根据Window的类型细分为3种…

alist修改密码(docker版)

rootarmbian:~# docker exec -it [docker名称] ./alist admin set abcd123456 INFO[2024-02-20 11:06:29] reading config file: data/config.json INFO[2024-02-20 11:06:29] load config from env with prefix: ALIST_ INFO[2024-02-20 11:06:29] init logrus..…

Redis篇之缓存雪崩、击穿、穿透详解

学习材料:https://xiaolincoding.com/redis/cluster/cache_problem.html 缓存雪崩 什么是缓存雪崩 在面对业务量较大的查询场景时,会把数据库中的数据缓存至redis中,避免大量的读写请求同时访问mysql客户端导致系统崩溃。这种情况下&#x…

链表之“无头单向非循环链表”

目录 ​编辑 1.顺序表的问题及思考 2.链表 2.1链表的概念及结构 2.2无头单向非循环链表的实现 1.创建结构体 2.单链表打印 3.动态申请一个节点 3.单链表尾插 4.单链表头插 5.单链表尾删 6.单链表头删 7.单链表查找 8.单链表在pos位置之前插入x 9.单链表删除pos位…

Android引入aar包的方法

问题描述: 网上参考了许多引入aar包的方法是: 第一步:在build.gradle中的android{}外层添加如下代码 repositories {flatDir {dirs libs} }第二步:在dependencies中添加 implementation(name:名称, ext:aar)但是上述方法中的fl…

85、字符串操作的优化

上一节介绍了在模型的推理优化过程中,动态内存申请会带来额外的性能损失。 Python 语言在性能上之所以没有c++高效,有一部分原因就在于Python语言将内存的动态管理过程给封装起来了,我们作为 Python 语言的使用者是看不到这个过程的。 这一点有点类似于 c++ 标准库中的一些…

探索Promise异步模式抽象的变体——Promise.race篇

如果阅读有疑问的话,欢迎评论或私信!! 本人会很热心的阐述自己的想法!谢谢!!! 文章目录 前言初识Promise.race探索Promise.raceAPI实例 前言 在本栏前一篇Promise.all中,我们可以实…

谷歌最新黑科技:Gemini 1.5携100万Token挑战AI多模态极限

最近科技圈再次迎来震撼弹!除了火爆全球的openAI Sora文生视频模型外,谷歌发布了其大模型矩阵的最新成员——Gemini 1.5,一举将上下文窗口长度扩展至惊人的100万个tokens。这不仅仅是一个简单的数字增加,而是一次划时代的飞跃&…

万界星空科技电子机电行业MES系统,2000元/年起

电子行业在生产管理上具有典型的离散制造特点,采用多品种、多批量或单件的生产组织方式。产品升级换代迅速,生命周期短,变更频繁,版本控制复杂。 同时产品的种类较多,非标准产品多,加工工序复杂&#xff0…

三种标注格式VOC、COCO、YOLO及其转换

最近在做基于深度学习的目标检测,数据标注软件选择的LabelImg。 常用的几种标注格式及目录安排 一、VOC(标注文件xml结尾) 首先看一下VOC格式的分布: 在VOC这些文件夹中,我们主要用到: ① JPEGImages文件夹:图片 ②…

Dapp的优势与前景,具唯一性公开可追溯

​小编介绍:10年专注商业模式设计及软件开发,擅长企业生态商业模式,商业零售会员增长裂变模式策划、商业闭环模式设计及方案落地;扶持10余个电商平台做到营收过千万,数百个平台达到百万会员,欢迎咨询。 在…

B3768 [语言月赛202305] 独行

传送门: https://www.luogu.com.cn/problem/B3768 直接手推模拟,找规律,照着遍就行。 以下是找规律部分: int sq, sh, sw, nt;1.sq v0 * T1;if (sq > s) {nt s / v0;break;}sh v1 * t1;sw sq - sh;sw max (sw, 0);nt …

216699-36-4,6-Rhodamine X NHS ester,具有良好的脂溶性

117491-83-5,1890922-83-4,216699-36-4,6-Rhodamine X NHS ester,ROX SE, 6-isomer,6-ROX NHS 活化酯 您好,欢迎来到新研之家 文章关键词:117491-83-5,1890922-83-4,21…

【知识整理】Git Commit Message 规范

一. 概述 前面咱们整理过 Code Review 一文,提到了 Review 的重要性,已经同过gitlab进行CodeReview 的方式,那么本文详细说明一下对CodeReivew非常重要的Git Commit Message 规范。 我们在每次提交代码时,都需要编写 Commit Mes…

【C语言】指针变量未初始化

我们知道:全局变量未赋初值,编译器会直接赋值为0;局部变量如果未赋初值,则会维持上一状态保存在该地址上的值,这个值是随机的。把这个值赋值给局部变量是没有意义的。 但是指针变量是如何解决不赋初值? 指…

备战蓝桥杯—— 双指针技巧巧答链表2

对于单链表相关的问题,双指针技巧是一种非常广泛且有效的解决方法。以下是一些常见问题以及使用双指针技巧解决: 合并两个有序链表: 使用两个指针分别指向两个链表的头部,逐一比较节点的值,将较小的节点链接到结果链表…

基于STM32的宠物箱温度湿度监控系统

基于STM32的宠物箱温度湿度监控系统 一、引言 随着人们生活水平的提高,养宠物已经成为越来越多人的选择。宠物作为家庭的一员,其生活环境和健康状况受到了广泛关注。温度和湿度是影响宠物舒适度和健康的重要因素之一。因此,开发一款能够实时监控宠物箱温度和湿度的系统具有…

编程学习线上提问现场解答流程,零基础学编程从入门到精通

编程学习线上提问现场解答流程 一、前言 之前给大家分享的一款中文编程工具,越来越多的学员使用这个工具学习编程。 在学习中有疑难问题寻求解答流程 1、可以在本平台留言或发私信联系老师 2、可以在群提问及时解答问题 3、通过线上会议的方式,电脑…

Hudi程序导致集群RPC偏高问题分析

1、背景 Hudi程序中upsert操作频繁,过多的删除和回滚操作,导致集群RPC持续偏高 2、描述 hudi采用的是mvcc设计,提供了清理工具cleaner来把旧版本的文件分片删除,默认开启了清理功能,可以防止文件系统的存储空间和文件数量的无限…

企业计算机服务器中了crypt勒索病毒怎么办,crypt勒索病毒解密数据恢复

计算机服务器设备为企业的生产运营提供了极大便利,企业的重要核心数据大多都存储在计算机服务器中,保护企业计算机服务器免遭勒索病毒攻击,是一项艰巨的工作任务。但即便很多企业都做好的了安全运维工作,依旧免不了被勒索病毒攻击…