PyTorch|保存与加载自己的模型

训练好一个模型之后,我们往往要对其进行保存,除非下次用时想再次训练一遍。

下面以一个简单的回归任务来详细讲解模型的保存和加载。

图片

来看这样一组数据:

x=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)

画图:

plt.scatter(x.numpy(),y.numpy())

图片

很显然,x与y基本呈二次函数关系,那么接下来我们就来拟合整个函数

import torchimport matplotlib.pyplot as pltimport torch.nn as nnimport torch.optim as optimx=torch.linspace(-1,1,50)x=x.view(50,1)y=x.pow(2)+0.3*torch.rand(50).view(50,1)net1=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))criterion=nn.MSELoss()optimizer=optim.SGD(net1.parameters(),lr=0.2)#训练模型for i in range(1000):    pred=net1(x)    loss=criterion(pred,y)    optimizer.zero_grad()    loss.backward()    optimizer.step()
#测试模型net1.eval()with torch.no_grad():    y1=net1(x)    plt.plot(x.numpy(),y1.numpy(),'r-')    plt.scatter(x.numpy(),y.numpy())

图片

结果似乎不错!

这里我们得到了一个网络net1,它可以被当作一个二次函数,用于描述之前的x,y数据的关系

得到这个网络后,我们想保存它,主要有两种方式

1,保存整个网络,包括训练后的各个层的参数

​​​​​​​

#保存整个网络,包括训练后的各个层的参数torch.save(net1,'net1weight.pkl')

2,只保存训练好的网络的参数,速度更快

​​​​​​​

#只保存训练好的网络的参数,速度更快torch.save(net1.state_dict(),'net1_params.pkl')

假设我们按第一种方式保存,那么下次想要使用次网络时需要这样做:

network=torch.load('net1weight.pkl')
#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'b-')    plt.scatter(x.numpy(),y.numpy())

图片

假设我们按第二种方式保存,那么下次想要使用次网络时需要这样做:

network=nn.Sequential(nn.Linear(1,10),                  nn.ReLU(),                  nn.Linear(10,1))network.load_state_dict(torch.load('net1_params.pkl'))​​​​​​​
#测试模型network.eval()with torch.no_grad():    y1=network(x)    plt.plot(x.numpy(),y1.numpy(),'g-')    plt.scatter(x.numpy(),y.numpy())

图片

可以看出,第二次首先需要构造出一个一模一样的模型,接着再导入参数即可。当然,这只是个简单的回归模型,其它模型保存与加载同样如此。

总结一下:

模型保存与导入有两种方式:

方式一:​​​​​​​

#模型保存torch.save(net1,'net1weight.pkl')#模型导入network=torch.load('net1weight.pkl')

方式二:​​​​​​​

#模型保存torch.save(net1.state_dict(),'net1_params.pkl')#模型导入network.load_state_dict(torch.load('net1_params.pkl'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/606176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【HarmonyOS】深入了解 ArkUI 的动画交互以提高用户体验

从今天开始,博主将开设一门新的专栏用来讲解市面上比较热门的技术 “鸿蒙开发”,对于刚接触这项技术的小伙伴在学习鸿蒙开发之前,有必要先了解一下鸿蒙,从你的角度来讲,你认为什么是鸿蒙呢?它出现的意义又是…

解压方法之一 zip

文章目录 解压方法之一 zip语法参数参考实例仅保存文件名更多信息 解压方法之一 zip … _linux-beginner-zip: Linux zip命令的功能是用于压缩文件,解压命令为unzip。 通过zip命令可以将很多文件打包成.zip格式的压缩包,里面会包含文件的名称、路径、…

uView Avatar 头像

本组件一般用于展示头像的地方,如个人中心,或者评论列表页的用户头像展示等场所。 #平台差异说明 App(vue)App(nvue)H5小程序√√√√ #基本使用 通过src指定头像的路径即可简单使用,如果传…

性能分析与调优: Linux 实现 CPU剖析与火焰图

目录 一、实验 1.环境 2.CPU 剖析 3.CPU火焰图 一、实验 1.环境 (1)主机 表1-1 主机 主机架构组件IP备注prometheus 监测 系统 prometheus、node_exporter 192.168.204.18grafana监测GUIgrafana192.168.204.19agent 监测 主机 node_exporter192…

【AI视野·今日CV 计算机视觉论文速览 第284期】Fri, 5 Jan 2024

AI视野今日CS.CV 计算机视觉论文速览 Fri, 5 Jan 2024 Totally 62 papers 👉上期速览✈更多精彩请移步主页 Daily Computer Vision Papers Learning to Prompt with Text Only Supervision for Vision-Language Models Authors Muhammad Uzair Khattak, Muhammad F…

jenkins忘记admin密码

jenkins忘记admin密码,重置密码: 1.找打jenkins目录下面的config.xml [rootVM-0-15-centos .jenkins]# find ./* -name config.xml ./config.xml [rootVM-0-15-centos .jenkins]# pwd /root/.jenkins删除下面的这部分内容: [rootVM-0-15-c…

网站被篡改怎么办,如何进行有效的防护

随着互联网的飞速发展,信息传播的速度和范围得到了极大的提升。然而,这也为网页篡改行为提供了可乘之机。网页被篡改不仅会损害网站的形象,还可能对用户造成误导,甚至导致安全漏洞。因此,网页防篡改技术成为了网络安全…

Linux部署前后端项目

部署SpringBoot项目 创建SpringBoot项目 先确保有一个可以运行的springboot项目,这里就记录创建项目的流程了,可以自行百度。 命令行启动 2.1、在linux中,我是在data目录下新创建的一个project目录(此目录创建位置不限制&…

智慧园区运维:1500路摄像头故障监控及多机房一体化运维

一、引言 随着智慧园区的快速发展,对园区内IT设施的运维管理提出了更高的要求。本解决方案旨在满足智慧园区对1500路摄像头故障监控及视频画面质量分析的需求,同时具备可扩充性,适应未来园区规模的不断扩大。通过监控易的解决方案&#xff0c…

【C++11】可调用对象

C中存在可调用对象(callable objects)的一个概念。其具体定义为: 1)函数指针 2)具有operator()的类对象(仿函数) 3)可以被转换为函数指针的对象 4&#xff09…

如何翻译整本书并制作为双语对照?

随着人工智能技术的快速发展,机器翻译已经不再是遥不可及的梦想。众多大互联网公司如谷歌、百度等都相继推出了免费的翻译工具,使得跨语言沟通变得触手可及。今年,数百家公司更是开发出大型AI语言模型,其中以ChatGPT 4引人瞩目&am…

外延炉及其相关的小知识

外延炉是一种用于生产半导体材料的设备,其工作原理是在高温高压环境下将半导体材料沉积在衬底上。 硅外延生长,是在具有一定晶向的硅单晶衬底上,生长一层具有和衬底相同晶向的电阻率且厚度不同的晶格结构完整性好的晶体。 外延生长的特点&am…

Java 8升级Java 11,升级必知要点!竟然有这些坑…

随着技术的不断进步,Java作为一种广泛使用的编程语言,其版本更新带来了许多新特性和性能提升。从Java 8升级到Java 11,是一个重要的转变,它不仅带来了新的编程范式,还引入了对现代软件开发的多项优化。然而&#xff0c…

Redis分布式锁(二)基于Redis的分布式锁

一、redis锁 1、思路 利用set nx ex获取锁,并设置过期时间,保存线程标识释放锁时先判断线程标识是否与自己一致,一致则删除 2、特性 利用set nx满足互斥性利用set ex保证故障时锁依然能释放,避免死锁,提高安全性利…

.net6解除文件上传限制。Multipart body length limit 16384 exceeded

在C#中上传文件时如果不修改默认文件的上传大小会提示Multipart body length limit 16384 exceeded这个错误提示表明你的请求中的Multipart body长度超过了16384字节的限制。这通常意味着你正在尝试发送一个太大的请求体,可能是因为包含了太多数据或者太大的文件。要…

29道memcached面试题含答案(很全)

点击下载《29道memcached面试题含答案(很全)》 1. Memcached是什么,有什么作用? Memcached是一个开源的,高性能的内存缓存软件,从名称上看Mem就是内存的意思,而Cache就是缓存的意思。Memcache…

如何进行深入的竞品分析:掌握这些技巧让你更加了解市场

随着互联网行业的快速发展,产品经理需要对竞品进行深入分析,才能更好地把握市场需求和趋势,为公司带来更好的商业价值。那么,如何做好竞品分析呢?以下是我对于这个问题的思考和建议。 一、确定分析的目的和范围 在开…

vs 修改系统环境变量putenv、_putenv

事情起因是某一天需要在vs2010的工程中去动态配置adb环境变量,win10环境 一开始,使用了putenv,很快进入代码调试,死活无法达成目的(奇怪的是另外一个工程就能修改成功) 一番面向运气编程,最后…

3D Gaussian Splatting 应用场景及最新进展【附10篇前沿论文和代码】

CV玩家们,知道3D高斯吗?对,就是计算机视觉最近的新宠,在几个月内席卷三维视觉和SLAM领域的3D高斯。不太了解也没关系,我今天就来和同学们一起聊聊这个话题。 3D Gaussian Splatting(3DGS)是用于…

vue项目中px单位转rem插件

一、安装插件: "postcss-px2rem": "^0.3.0", "postcss-px2rem-exclude": "0.0.6",二、新建postcss.config.js module.exports {plugins: {autoprefixer: {},"postcss-px2rem-exclude": {"remUnit":…