PyTorch|保存与加载自己的模型

训练好一个模型之后,我们往往要对其进行保存,除非下次用时想再次训练一遍。

下面以一个简单的回归任务来详细讲解模型的保存和加载。

图片

来看这样一组数据:

x=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)

画图:

plt.scatter(x.numpy(),y.numpy())

图片

很显然,x与y基本呈二次函数关系,那么接下来我们就来拟合整个函数

import torchimport matplotlib.pyplot as pltimport torch.nn as nnimport torch.optim as optimx=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)net1=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))criterion=nn.MSELoss()optimizer=optim.SGD(net1.parameters(),lr=0.2)#训练模型for i in range(1000):    pred=net1(x)    loss=criterion(pred,y)    optimizer.zero_grad()    loss.backward()    optimizer.step()
#测试模型net1.eval()with torch.no_grad():    y1=net1(x)    plt.plot(x.numpy(),y1.numpy(),'r-')    plt.scatter(x.numpy(),y.numpy())

图片

结果似乎不错!

这里我们得到了一个网络net1,它可以被当作一个二次函数,用于描述之前的x,y数据的关系

得到这个网络后,我们想保存它,主要有两种方式

1,保存整个网络,包括训练后的各个层的参数

​​​​​​​

#保存整个网络,包括训练后的各个层的参数torch.save(net1,'net1weight.pkl')

2,只保存训练好的网络的参数,速度更快

​​​​​​​

#只保存训练好的网络的参数,速度更快torch.save(net1.state_dict(),'net1_params.pkl')

假设我们按第一种方式保存,那么下次想要使用次网络时需要这样做:

network=torch.load('net1weight.pkl')
#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'b-')    plt.scatter(x.numpy(),y.numpy())

图片

假设我们按第二种方式保存,那么下次想要使用次网络时需要这样做:

network=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))network.load_state_dict(torch.load('net1_params.pkl'))​​​​​​​
#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'g-')    plt.scatter(x.numpy(),y.numpy())

图片

可以看出,第二次首先需要构造出一个一模一样的模型,接着再导入参数即可。当然,这只是个简单的回归模型,其它模型保存与加载同样如此。

总结一下:

模型保存与导入有两种方式:

方式一:​​​​​​​

#模型保存torch.save(net1,'net1weight.pkl')#模型导入network=torch.load('net1weight.pkl')

方式二:​​​​​​​

#模型保存torch.save(net1.state_dict(),'net1_params.pkl')#模型导入network.load_state_dict(torch.load('net1_params.pkl'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/606176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【HarmonyOS】深入了解 ArkUI 的动画交互以提高用户体验

从今天开始,博主将开设一门新的专栏用来讲解市面上比较热门的技术 “鸿蒙开发”,对于刚接触这项技术的小伙伴在学习鸿蒙开发之前,有必要先了解一下鸿蒙,从你的角度来讲,你认为什么是鸿蒙呢?它出现的意义又是…

解压方法之一 zip

文章目录 解压方法之一 zip语法参数参考实例仅保存文件名更多信息 解压方法之一 zip … _linux-beginner-zip: Linux zip命令的功能是用于压缩文件,解压命令为unzip。 通过zip命令可以将很多文件打包成.zip格式的压缩包,里面会包含文件的名称、路径、…

性能分析与调优: Linux 实现 CPU剖析与火焰图

目录 一、实验 1.环境 2.CPU 剖析 3.CPU火焰图 一、实验 1.环境 (1)主机 表1-1 主机 主机架构组件IP备注prometheus 监测 系统 prometheus、node_exporter 192.168.204.18grafana监测GUIgrafana192.168.204.19agent 监测 主机 node_exporter192…

【AI视野·今日CV 计算机视觉论文速览 第284期】Fri, 5 Jan 2024

AI视野今日CS.CV 计算机视觉论文速览 Fri, 5 Jan 2024 Totally 62 papers 👉上期速览✈更多精彩请移步主页 Daily Computer Vision Papers Learning to Prompt with Text Only Supervision for Vision-Language Models Authors Muhammad Uzair Khattak, Muhammad F…

jenkins忘记admin密码

jenkins忘记admin密码,重置密码: 1.找打jenkins目录下面的config.xml [rootVM-0-15-centos .jenkins]# find ./* -name config.xml ./config.xml [rootVM-0-15-centos .jenkins]# pwd /root/.jenkins删除下面的这部分内容: [rootVM-0-15-c…

网站被篡改怎么办,如何进行有效的防护

随着互联网的飞速发展,信息传播的速度和范围得到了极大的提升。然而,这也为网页篡改行为提供了可乘之机。网页被篡改不仅会损害网站的形象,还可能对用户造成误导,甚至导致安全漏洞。因此,网页防篡改技术成为了网络安全…

如何翻译整本书并制作为双语对照?

随着人工智能技术的快速发展,机器翻译已经不再是遥不可及的梦想。众多大互联网公司如谷歌、百度等都相继推出了免费的翻译工具,使得跨语言沟通变得触手可及。今年,数百家公司更是开发出大型AI语言模型,其中以ChatGPT 4引人瞩目&am…

外延炉及其相关的小知识

外延炉是一种用于生产半导体材料的设备,其工作原理是在高温高压环境下将半导体材料沉积在衬底上。 硅外延生长,是在具有一定晶向的硅单晶衬底上,生长一层具有和衬底相同晶向的电阻率且厚度不同的晶格结构完整性好的晶体。 外延生长的特点&am…

Redis分布式锁(二)基于Redis的分布式锁

一、redis锁 1、思路 利用set nx ex获取锁,并设置过期时间,保存线程标识释放锁时先判断线程标识是否与自己一致,一致则删除 2、特性 利用set nx满足互斥性利用set ex保证故障时锁依然能释放,避免死锁,提高安全性利…

如何进行深入的竞品分析:掌握这些技巧让你更加了解市场

随着互联网行业的快速发展,产品经理需要对竞品进行深入分析,才能更好地把握市场需求和趋势,为公司带来更好的商业价值。那么,如何做好竞品分析呢?以下是我对于这个问题的思考和建议。 一、确定分析的目的和范围 在开…

vs 修改系统环境变量putenv、_putenv

事情起因是某一天需要在vs2010的工程中去动态配置adb环境变量,win10环境 一开始,使用了putenv,很快进入代码调试,死活无法达成目的(奇怪的是另外一个工程就能修改成功) 一番面向运气编程,最后…

3D Gaussian Splatting 应用场景及最新进展【附10篇前沿论文和代码】

CV玩家们,知道3D高斯吗?对,就是计算机视觉最近的新宠,在几个月内席卷三维视觉和SLAM领域的3D高斯。不太了解也没关系,我今天就来和同学们一起聊聊这个话题。 3D Gaussian Splatting(3DGS)是用于…

ShardingSphere-JDBC学习笔记

引言 开源产品的小故事 Sharding-JDBC是2015年开源的,早期的定位就是一个分布式数据库的中间件,而在它之前有一个MyCat的产品。MyCat也是从阿里开源出来的,作为分库分表的代名词火了很长一段时间,而MyCat早年的目标就是想进入ap…

一致性 Hash

一致性 Hash 一致性哈希算法(Consistent Hashing Algorithm)是一种分布式算法,常用于负载均衡。Memcached client 也选择这种算法,解决将 key-value 均匀分配到众多 Memcached server 上的问题。它可以取代传统的取模操作,解决了取模操作无法…

关于网盘下载速度提升的一些技巧!!

这里写自定义目录标题 前言:步骤:一、下载IDM二、安装油猴三、添加到Google拓展程序上PS:四、添加脚本五、IDM配置六、打开网页版网盘 前言: 18G的网盘资源下载时间仅仅3-5分钟 步骤: 一、下载IDM 这里我以IDM举例…

kettle分页抽取数据

背景 kettle抽取数据大家还是比较熟悉的,kettle在抽取数据的时候会开启很多通道,同时抽取,但是我现在遇到一个场景: 从一个mysql数据库里获取“已办”状态的数据id,然后拿这些id去一个oracle数据库里查询&#xff0c…

【MATLAB】ICEEMDAN_LSTM神经网络时序预测算法

有意向获取代码,请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 ICEEMDAN-LSTM神经网络时序预测算法是一种结合了改进的完全扩展经验模态分解(ICEEMDAN)和长短期记忆神经网络(LSTM)的时间序列预测方法。 …

【UE Niagara学习笔记】02 - 制作燃烧的火焰

目录 效果 步骤 一、添加资产 二、制作材质 三、制作粒子 3.1 循环播放 3.2 粒子生成的数量 3.3 粒子的生命周期和初始大小 3.4 火焰高度 3.5 火焰范围 3.6 火焰颜色 效果 步骤 一、添加资产 1. 在虚幻商城中搜索“M5 VFX Vol2. Fire and Flames(Niagara)”…

遇见狂神说 Spring MVC 学习笔记(完整笔记+代码)

MVC架构介绍 MVC是模型(Model)、视图(View)、控制器(Controller)的简写,是一种软件设计规范MVC是将业务逻辑、数据、显示分离的方式来组织代码MVC主要作用是降低了视图与业务逻辑间的双向偶合MVC不是一种设计模式,是一种架构模式。当然不同的MVC存在差异…

python 文件

open """ def open(file: FileDescriptorOrPath, //路径mode: OpenTextMode "r", //设置打开文件的模式 r 以只读方式打开文件。文件的指针将会放在文件的开头。这是默认模式。 w 打开一个文件只用写入。如果该文件已存在则打开文件&#…