Pytorch(4)-模型保存-载入-eval()

模型保存与提取

  • 1. 整个模型 保存-载入
  • 2. 仅模型参数 保存-载入
  • 3. GPU/CPU模型保存与导入
  • 4. net.eval()--固定模型随机项

神经网络模型在线训练完之后需要保存下来,以便下次使用时可以直接导入已经训练好的模型。pytorch 提供两种方式保存模型:

方式1:保存整个网络,载入时直接载入整个网络,优点:代码简单,缺点需要的存储空间大

方式2:只保存网络参数,载入时需要先建立与原来网络一样结构的网络,然后将网络参数导入到该网络中,方式2的优缺点与方式1相反。

1. 整个模型 保存-载入

模型的结构参数都保存下来了

# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet1.pkl"
# 导入官方提供的预训练模型
net1=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1, PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/mynet1.pkl"
net1_1=torch.load(PATH)

2. 仅模型参数 保存-载入

保存时–只保存网络中的参数 (速度快, 占内存少), 载入时–需要提前创建好结构和net2是一样的

# 保存模型:设置 保存目录 和 保存文件名.扩展名,常用扩展名: .pkl .pth (扩展名只要好辨识就即可)
PATH="./model/mynet2.pkl"
# 导入官方提供的预训练模型
net2=torchvision.models.alexnet(pretrainend=True)
# 用数据集训练网络
.....
# 保存训练好的网络
torch.save(net1.state_dict(), PATH)
-----------------------------------------------------------
# 载入模型:设置载入路径,即模型保存的路径
PATH="./model/net2.pkl"
# 新建一个网络
net2_2=torchvision.models.alexnet(pretrained=True)
# 载入模型参数
net2_2.load_state_dict(torch.load(PATH))

迷糊的现象

在使用莫烦的文档做实验时,保存的两个文件:net.pkl,net_params.pkl大小差异比较大。保证在导入模型是比较快。
在这里插入图片描述
但是使用torchvision.models.模块中的一系列网络时,因为网络的参数很大,所以实验过程中用两种方法保存模型的文件大小是一致的。(猜测是内置模型使用torch.save(net1, ‘net.pkl’)时默认保存的是模型参数)

提供一个神经网络模型占用空间大小的计算方法:
在这里插入图片描述
参考文档:经典CNN模型计算量与内存需求分析

3. GPU/CPU模型保存与导入

在训练是模型是GPU/CPU,决定了模型载入时的模型原型。可以分为下面三种情况
(只展示导入整个网络模型的情况,具体实验还没做过):

1.CPU(原型)->CPU, GPU(原型)->GPU

torch.load( ‘net.pkl’)

2.GPU(原型)->CPU

torch.load(‘model_dict.pkl’, map_location=lambda storage, loc: storage)

3.CPU(模型文件)->GPU

torch.load(‘model_dic.pkl’, map_location=lambda storage, loc: storage.cuda)

参考文档:https://blog.csdn.net/u012135425/article/details/85217542

4. net.eval()–固定模型随机项

两种模型载入方式、.eval() 作用实验demo

step1: 载入模型

# 20191204 pytorch 模型载入测试
import torchvision as tvt
import torch
net1=tvt.models.alexnet(pretrained=True)  # 1.自动从网上下载的预先训练模型
net2=torch.load("./model/mynet1.pkl")     # 2.导入事先训练好的保存的整个网络net3=tvt.models.alexnet(pretrained=True)  # 3.导入只保存模型参数的网络,需要新建一个网络
net3.load_state_dict(torch.load("../model/mynet2.pkl"))
net3.eval()                              #   固定dropout和归一化层,否则每次推理会生成不同的结果。

step2:输出三个网络同一层参数的和,net2 和net3 对应参数相等。可以看出来,两种模型保存和导入方式是等价的。

net1 tensor(-21257.7656, grad_fn=<SumBackward0>)
net2 tensor(-21253.9473, device='cuda:0', grad_fn=<SumBackward0>)
net3 tensor(-21253.9551, grad_fn=<SumBackward0>)

step3: 产生一个随机输入a,输入到网络1,2,3,打印输出结果。

a=torch.randn([1,3,224,224])
y1=net1(a)
y2=net2(a)
y3=net3(a)
# 第二次输入
y11=net1(a)
y22=net2(a)
y33=net3(a)
# 打印y1,y2,y3,y11,y22,y33(1000维的和)
y1: tensor(-5.2689, grad_fn=<SumBackward0>)
y2: tensor(-1.6695, device='cuda:0', grad_fn=<SumBackward0>)
y3: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)y11: tensor(-4.4205, grad_fn=<SumBackward0>)
y22: tensor(-5.9475, device='cuda:0', grad_fn=<SumBackward0>)
y33: tensor(-4.4349, device='cuda:0', grad_fn=<SumBackward0>)

只有net3的输出是固定的,因为在模型导入的时候执行了net3.eval().
结论:无论采用 方式1 还是 方式2 导入的模型, 在模型测试时,都需要用.eval()方法固定一下网络在训练过程中的随机项目,如dropout 等,避免网络在同一个输入下产生不一样的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/445237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据学习(08)--Hadoop中的数据仓库Hive

文章目录目录1.什么是数据仓库&#xff1f;1.1数据仓库概念1.2传统数据仓库面临的挑战1.3 Hive介绍1.4 Hive与传统数据库的对比1.5 Hive在企业中的部署与应用2.Hive系统架构3.Hive工作原理3.1 SQL转换为MapReduce作业的基本原理3.2 Hive中SQL查询转换MapReduce作业的过程4.Hive…

dubbo知识点总结 持续更新

Dubbo 支持哪些协议&#xff0c;每种协议的应用场景&#xff0c;优缺点&#xff1f;  dubbo&#xff1a; 单一长连接和 NIO 异步通讯&#xff0c;适合大并发小数据量的服务调用&#xff0c; 以及消费者远大于提供者。传输协议 TCP&#xff0c;异步&#xff0c;Hessian 序列化…

无限踩坑系列(6)-mySQL数据库链接错误

mySQL数据库链接错误错误1错误2长链接短连接应用场景需要一直访问mySQL数据库&#xff0c;遇到如下错误&#xff1a;错误1 释放已经释放的数据库链接conn.&#xff0c;或者&#xff0c;操作已经释放的数据库链接conn.或者失去链接后再操作数据库都可能会报这个错误 aise err.I…

初探函数式编程和面对对象式编程

文章目录目录1.函数式编程和面向对象编程概念1.1 函数式编程1.2 面向对象编程2.函数式编程和面向对象编程的优缺点2.1 函数式编程优点缺点2.2 面对对象编程优点缺点3.为什么在并行计算中函数式编程比较好3.1 什么是并行计算3.2 函数式编程兴起原因目录 1.函数式编程和面向对象…

搜索详解

搜索 一.dfs和bfs简介 深度优先遍历(dfs) 本质&#xff1a; 遍历每一个点。 遍历流程&#xff1a; 从起点开始&#xff0c;在其一条分支上一条路走到黑&#xff0c;走不通了就往回走&#xff0c;只要当前有分支就继续往下走&#xff0c;直到将所有的点遍历一遍。 剪枝&a…

特征工程总结

目录1 特征工程是什么&#xff1f; 2 数据预处理   2.1 无量纲化     2.1.1 标准化     2.1.2 区间缩放法     2.1.3 标准化与归一化的区别   2.2 对定量特征二值化   2.3 对定性特征哑编码   2.4 缺失值计算   2.5 数据变换 3 特征选择   3.1 Filter …

Jmeter测试并发https请求成功了

Jmeter2.4 如何测试多个并发https请求&#xff0c;终于成功了借此机会分享给大家 首先要安装jmeter2.4版本的&#xff0c;而且不建议大家使用badboy&#xff0c;因为这存在兼容性问题。对于安装&#xff0c;我就不讲了&#xff0c;我就说说如何测试https&#xff0c;想必大家都…

关系数据库——sql基础1定义

关系数据库标准语言SQL 基本概念 SQL语言是一个功能极强的关系数据库语言。同时也是一种介于关系代数与关系演算之间的结构化查询语言&#xff08;Structured Query Language&#xff09;&#xff0c;其功能包括数据定义、数据查询、数据操纵和数据控制。 SQL的特点&#xff…

大数据学习(09)--Hadoop2.0介绍

文章目录目录1.Hadoop的发展与优化1.1 Hadoop1.0 的不足与局限1.2 Hadoop2.0 的改进与提升2.HDFS2.0 的新特性2.1 HDFS HA2.2 HDFS Federation3. 新一代的资源管理器YARN3.1 MapReduce1.0 缺陷3.2 YARN的设计思路3.3 YARN 体系结构3.4 YARN工作流程3.5 YARN框架与MapReduce1.0框…

Java多线程常用方法

start()与run() start() 启动线程并执行相应的run()方法 run() 子线程要执行的代码放入run()方法 getName()和setName() getName() 获取此线程的名字 setName() 设置此线程的名字 isAlive() 是判断当前线程是否处于活动状态。活动状态就是已经启动尚未终止。 curren…

MachineLearning(2)-图像分类常用数据集

图像分类常用数据集1 CIFAR-102.MNIST3.STL_104.Imagenet5.L-Sun6.caltech-101在训练神经网络进行图像识别分类时&#xff0c;常会用到一些通用的数据集合。利用这些数据集合可以对比不同模型的性能差异。下文整理常用的图片数据集合&#xff08;持续更新中)。基本信息对比表格…

大数据学习(09)--spark学习

文章目录目录1.spark介绍1.1 spark介绍1.2 scale介绍1.3 spark和Hadoop比较2.spark生态系统3.spark运行框架3.1 基本概念3.2 架构的设计3.3 spark运行基本流程3.4 spark运行原理3.5 RDD运行原理3.5.1 设计背景3.5.2 RDD概念和特性3.5.3 RDD之间的依赖关系3.5.4 stage的划分3.5.…

机器学习中的聚类方法总结

聚类定义 定义 聚类就是对大量未知标注 的数据集&#xff0c;按数据 的内在相似性将数据集划分为多个类别&#xff0c;使 类别内的数据相似度较大而类别间的数据相 似度较小。是无监督的分类方式。 聚类思想 给定一个有N个对象的数据集&#xff0c;构造数据的k 个簇&#x…

关系数据库——关系数据语言

关系 域&#xff1a;一组具有相同数据类型的值的集合&#xff08;即取值范围&#xff09; 笛卡尔积&#xff1a;域上的一种集合运算。结果为一个集合&#xff0c;集合的每一个元素是一个元组&#xff0c;元组的每一个分量来自不同的域。 基数&#xff1a;一个域允许的不同取值…

机器学习问题总结(01)

文章目录1.请描述推荐系统中协同过滤算法CF的原理2.请描述决策树的原理、过程、终止条件&#xff0c;以及如何防止过拟合2.1决策树生成算法2.2 剪枝处理&#xff08;防止过拟合&#xff09;2.3 停止条件2.4 棵决策树的生成过程2.5 决策树的损失函数3.请描述K-means的原理&#…

Python实例讲解 -- 解析xml

Xml代码 <?xml version"1.0" encoding"utf-8"?> <info> <intro>信息</intro> <list id001> <head>auto_userone</head> <name>Jordy</name> <number&g…

python(22)--面向对象1-封装

python面向对象1面向过程/面向对象2面向对象核心概念-类3类的设计3.1类三要素-类名、属性、方法3.2面向对象基础语法3.2.1查看对象的常用方法3.2.2类定义3.2.3创建类对象3.2.4__init__()方法3.2.5 self参数3.2.6类内置方法和属性_del_()方法--销毁对象_str_()方法--定制化输出对…

机器学习问题总结(02)

文章目录1.stacking模型以及做模型融合的知识1.1 从提交结果中融合1.2 stacking1.3 blending2. 怎样去优化SVM算法模型的&#xff1f;2.1 SMO优化算法2.2 libsvm 和 Liblinear3.现有底层是tensorflow的keras框架&#xff0c;如果现在有一个tensorflow训练好的模型&#xff0c;k…

C/C++常见面试题(四)

C/C面试题集合四 目录 1、什么是C中的类&#xff1f;如何定义和实例化一个类&#xff1f; 2、请解释C中的继承和多态性。 3、什么是虚函数&#xff1f;为什么在基类中使用虚函数&#xff1f; 4、解释封装、继承和多态的概念&#xff0c;并提供相应的代码示例 5、如何处理内…

机器学习问题总结(03)

文章目录1.struct和class区别&#xff0c;你更倾向用哪个2.kNN&#xff0c;朴素贝叶斯&#xff0c;SVM的优缺点&#xff0c;各种算法优缺点2.1 KNN算法2.2 朴素贝叶斯2.3SVM算法2.4 ANN算法2.5 DT算法3. 10亿个整数&#xff0c;1G内存&#xff0c;O(n)算法&#xff0c;统计只出…