Pytorch(4)-模型保存-载入-eval()

模型保存与提取

  • 1. 整个模型 保存-载入
  • 2. 仅模型参数 保存-载入
  • 3. GPU/CPU模型保存与导入
  • 4. net.eval()--固定模型随机项

神经网络模型在线训练完之后需要保存下来,以便下次使用时可以直接导入已经训练好的模型。pytorch 提供两种方式保存模型:

方式1:保存整个网络,载入时直接载入整个网络,优点:代码简单,缺点需要的存储空间大

方式2:只保存网络参数,载入时需要先建立与原来网络一样结构的网络,然后将网络参数导入到该网络中,方式2的优缺点与方式1相反。

1. 整个模型 保存-载入

模型的结构参数都保存下来了

# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet1.pkl"
# 导入官方提供的预训练模型
net1=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1, PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/mynet1.pkl"
net1_1=torch.load(PATH)

2. 仅模型参数 保存-载入

保存时–只保存网络中的参数 (速度快, 占内存少), 载入时–需要提前创建好结构和net2是一样的

# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet2.pkl"
# 导入官方提供的预训练模型
net2=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1.state_dict(), PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/net2.pkl"
# 新建一个网络
net2_2=torchvision.models.alexnet(pretrained=True)
# 载入模型参数
net2_2.load_state_dict(torch.load(PATH))

迷糊的现象

在使用莫烦的文档做实验时,保存的两个文件:net.pkl,net_params.pkl大小差异比较大。保证在导入模型是比较快。
在这里插入图片描述
但是使用torchvision.models.模块中的一系列网络时,因为网络的参数很大,所以实验过程中用两种方法保存模型的文件大小是一致的。(猜测是内置模型使用torch.save(net1, ‘net.pkl’)时默认保存的是模型参数)

提供一个神经网络模型占用空间大小的计算方法:
在这里插入图片描述
参考文档:经典CNN模型计算量与内存需求分析

3. GPU/CPU模型保存与导入

在训练是模型是GPU/CPU,决定了模型载入时的模型原型。可以分为下面三种情况
(只展示导入整个网络模型的情况,具体实验还没做过):

1.CPU(原型)->CPU, GPU(原型)->GPU

torch.load( ‘net.pkl’)

2.GPU(原型)->CPU

torch.load(‘model_dict.pkl’, map_location=lambda storage, loc: storage)

3.CPU(模型文件)->GPU

torch.load(‘model_dic.pkl’, map_location=lambda storage, loc: storage.cuda)

参考文档:https://blog.csdn.net/u012135425/article/details/85217542

4. net.eval()–固定模型随机项

两种模型载入方式、.eval() 作用实验demo

step1: 载入模型

# 20191204 pytorch 模型载入测试
import torchvision as tvt
import torch
net1=tvt.models.alexnet(pretrained=True)  # 1.自动从网上下载的预先训练模型
net2=torch.load("./model/mynet1.pkl")     # 2.导入事先训练好的保存的整个网络net3=tvt.models.alexnet(pretrained=True)  # 3.导入只保存模型参数的网络,需要新建一个网络
net3.load_state_dict(torch.load("../model/mynet2.pkl"))
net3.eval()                              #   固定dropout和归一化层,否则每次推理会生成不同的结果。

step2:输出三个网络同一层参数的和,net2 和net3 对应参数相等。可以看出来,两种模型保存和导入方式是等价的。

net1 tensor(-21257.7656, grad_fn=<SumBackward0>)
net2 tensor(-21253.9473, device='cuda:0', grad_fn=<SumBackward0>)
net3 tensor(-21253.9551, grad_fn=<SumBackward0>)

step3: 产生一个随机输入a,输入到网络1,2,3,打印输出结果。

a=torch.randn([1,3,224,224])
y1=net1(a)
y2=net2(a)
y3=net3(a)
# 第二次输入
y11=net1(a)
y22=net2(a)
y33=net3(a)
# 打印y1,y2,y3,y11,y22,y33(1000维的和)
y1: tensor(-5.2689, grad_fn=<SumBackward0>)
y2: tensor(-1.6695, device='cuda:0', grad_fn=<SumBackward0>)
y3: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)y11: tensor(-4.4205, grad_fn=<SumBackward0>)
y22: tensor(-5.9475, device='cuda:0', grad_fn=<SumBackward0>)
y33: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)

只有net3的输出是固定的,因为在模型导入的时候执行了net3.eval().
结论:无论采用 方式1 还是 方式2 导入的模型, 在模型测试时,都需要用.eval()方法固定一下网络在训练过程中的随机项目,如dropout 等,避免网络在同一个输入下产生不一样的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/445237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据学习(08)--Hadoop中的数据仓库Hive

文章目录目录1.什么是数据仓库&#xff1f;1.1数据仓库概念1.2传统数据仓库面临的挑战1.3 Hive介绍1.4 Hive与传统数据库的对比1.5 Hive在企业中的部署与应用2.Hive系统架构3.Hive工作原理3.1 SQL转换为MapReduce作业的基本原理3.2 Hive中SQL查询转换MapReduce作业的过程4.Hive…

dubbo知识点总结 持续更新

Dubbo 支持哪些协议&#xff0c;每种协议的应用场景&#xff0c;优缺点&#xff1f;  dubbo&#xff1a; 单一长连接和 NIO 异步通讯&#xff0c;适合大并发小数据量的服务调用&#xff0c; 以及消费者远大于提供者。传输协议 TCP&#xff0c;异步&#xff0c;Hessian 序列化…

使用Linux auto Makefile自动生成的运行步骤

首先创建一个 Linux Makefile.am.这一步是创建Linux Makefile很重要的一步&#xff0c;automake要用的脚本配置文件是Linux Makefile.am&#xff0c;用户需要自己创建相应的文件。之后&#xff0c;automake工具转换成Linux Makefile.in。AD&#xff1a; 在向大家详细介绍Linux …

无限踩坑系列(6)-mySQL数据库链接错误

mySQL数据库链接错误错误1错误2长链接短连接应用场景需要一直访问mySQL数据库&#xff0c;遇到如下错误&#xff1a;错误1 释放已经释放的数据库链接conn.&#xff0c;或者&#xff0c;操作已经释放的数据库链接conn.或者失去链接后再操作数据库都可能会报这个错误 aise err.I…

初探函数式编程和面对对象式编程

文章目录目录1.函数式编程和面向对象编程概念1.1 函数式编程1.2 面向对象编程2.函数式编程和面向对象编程的优缺点2.1 函数式编程优点缺点2.2 面对对象编程优点缺点3.为什么在并行计算中函数式编程比较好3.1 什么是并行计算3.2 函数式编程兴起原因目录 1.函数式编程和面向对象…

linux常用解压和压缩文件的命令

linux常用解压和压缩文件的命令 .tar 解包&#xff1a;tar xvf FileName.tar打包&#xff1a;tar cvf FileName.tar DirName&#xff08;注&#xff1a;tar是打包&#xff0c;不是压缩&#xff01;&#xff09;———————————————.gz解压1&#xff1a;gunzip FileN…

Python外(4)-读写mat文件

读写mat文件1.读取2.写入.mat 是matlab中数据存储的标准格式&#xff0c;Python中能够通过库scipy读取和保存。导入scipy库 from scipy import io 1.读取 io.loadmat(file_name, mdictNone, appendmatTrue, **kwargs) 简便方式&#xff1a; io.loadmat(file_name) append mat–…

Linux下的xml文件的创建

创建一个xml文档流程如下&#xff1a; l 用xmlNewDoc函数创建一个文档指针doc&#xff1b; l 用xmlNewNode函数创建一个节点指针root_node&#xff1b; l 用xmlDocSetRootElement将root_node设置为doc的根结点&#xff1b; l 给root_node添加一系列的子节点&#x…

压力测试http_load 通过修改配置测试https协议成功了。

到http://www.acme.com/software/http_load/ 下载http_load &#xff0c;安装也很简单直接make;make instlall 就行。 如果你需要测试https&#xff0c;你必须将 Makefile中 # CONFIGURE: If you want to compile in support for https, uncomment these # definitions. You w…

面向对象设计与分析40讲(16)静态工厂方法模式

前面我们介绍了简单工厂模式&#xff0c;在创建对象前&#xff0c;我们需要先创建工厂&#xff0c;然后再通过工厂去创建产品。 如果将工厂的创建方法static化&#xff0c;那么无需创建工厂即可通过静态方法直接调用的方式创建产品&#xff1a; // 工厂类&#xff0c;定义了静…

搜索详解

搜索 一.dfs和bfs简介 深度优先遍历(dfs) 本质&#xff1a; 遍历每一个点。 遍历流程&#xff1a; 从起点开始&#xff0c;在其一条分支上一条路走到黑&#xff0c;走不通了就往回走&#xff0c;只要当前有分支就继续往下走&#xff0c;直到将所有的点遍历一遍。 剪枝&a…

Python外(5)-for-enumerate()-zip()

for循环小技巧技巧1&#xff1a;enumerate()技巧2&#xff1a;打包两个可遍历数据&#xff0c;一起循环-zip()技巧1&#xff1a;enumerate() 在使用pytorch训练网络的过程中&#xff0c;官方教程给出了 for i, data in enumerate(trainloader, 0): 这涉及到enumerate函数的使用…

特征工程总结

目录1 特征工程是什么&#xff1f; 2 数据预处理   2.1 无量纲化     2.1.1 标准化     2.1.2 区间缩放法     2.1.3 标准化与归一化的区别   2.2 对定量特征二值化   2.3 对定性特征哑编码   2.4 缺失值计算   2.5 数据变换 3 特征选择   3.1 Filter …

Jmeter测试并发https请求成功了

Jmeter2.4 如何测试多个并发https请求&#xff0c;终于成功了借此机会分享给大家 首先要安装jmeter2.4版本的&#xff0c;而且不建议大家使用badboy&#xff0c;因为这存在兼容性问题。对于安装&#xff0c;我就不讲了&#xff0c;我就说说如何测试https&#xff0c;想必大家都…

关系数据库——sql基础1定义

关系数据库标准语言SQL 基本概念 SQL语言是一个功能极强的关系数据库语言。同时也是一种介于关系代数与关系演算之间的结构化查询语言&#xff08;Structured Query Language&#xff09;&#xff0c;其功能包括数据定义、数据查询、数据操纵和数据控制。 SQL的特点&#xff…

libcurl编程

一、curl简介 curl是一个利用URL语法在命令行方式下工作的文件传输工具。它支持的协议有&#xff1a;FTP, FTPS, HTTP, HTTPS, GOPHER, TELNET, DICT, FILE 以及 LDAP。curl同样支持HTTPS认证&#xff0c;HTTP POST方法, HTTP PUT方法, FTP上传, kerberos认证, HTTP上传, 代理服…

大数据学习(09)--Hadoop2.0介绍

文章目录目录1.Hadoop的发展与优化1.1 Hadoop1.0 的不足与局限1.2 Hadoop2.0 的改进与提升2.HDFS2.0 的新特性2.1 HDFS HA2.2 HDFS Federation3. 新一代的资源管理器YARN3.1 MapReduce1.0 缺陷3.2 YARN的设计思路3.3 YARN 体系结构3.4 YARN工作流程3.5 YARN框架与MapReduce1.0框…

Java多线程常用方法

start()与run() start() 启动线程并执行相应的run()方法 run() 子线程要执行的代码放入run()方法 getName()和setName() getName() 获取此线程的名字 setName() 设置此线程的名字 isAlive() 是判断当前线程是否处于活动状态。活动状态就是已经启动尚未终止。 curren…

MachineLearning(2)-图像分类常用数据集

图像分类常用数据集1 CIFAR-102.MNIST3.STL_104.Imagenet5.L-Sun6.caltech-101在训练神经网络进行图像识别分类时&#xff0c;常会用到一些通用的数据集合。利用这些数据集合可以对比不同模型的性能差异。下文整理常用的图片数据集合&#xff08;持续更新中)。基本信息对比表格…

Linux网络编程实例详解

本文介绍了在Linux环境下的socket编程常用函数用法及socket编程的一般规则和客户/服务器模型的编程应注意的事项和常遇问题的解决方法&#xff0c;并举了具体代 码实例。要理解本文所谈的技术问题需要读者具有一定C语言的编程经验和TCP/IP方面的基本知识。要实习本文的示例&…