Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

1 model.train() 和 model.eval()用法和区别

1.1 model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

1.2 model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

1.3 分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
在这里插入图片描述
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

2.model.eval()和torch.no_grad()的区别

1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/117966.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaEE重点知识归纳】第11节:认识异常

目录 一:异常的概念和体系结构 1.概念 2.体系结构 3.异常分类 二:异常的处理 1.防御式编程 2.异常的抛出 3.异常的捕获 4.异常的处理流程 三:自定义异常 一:异常的概念和体系结构 1.概念 (1)在…

vue3使用Element ui plus中MessageBox消息框+radio框配合使用

想要达到的效果 首先安装element ui plus 省略~~ 官网地址: https://element-plus.gitee.io/zh-CN/component/message-box.htmlhttps://element-plus.gitee.io/zh-CN/component/message-box.html 需要用到的 引入 import { h } from "vue"; import {E…

为什么需要山洪灾害监测预警系统?

在山洪高发地区,安装山洪灾害监测预警系统能够通过实时监测,预警山洪信息,对于保障我们的生命财产安全具有重要意义。 监测山洪不仅需要对山体进行监测,还要监测降雨量以及水位上升情况。山洪灾害监测预警系统是由GNSS监测站和水…

天锐绿盾加密软件——企业数据透明加密、防泄露系统

天锐绿盾是一种企业级数据透明加密、防泄密系统,旨在保护企业的核心数据,防止数据泄露和恶意攻击。它采用内核级透明加密技术,可以在不影响员工正常工作的前提下,对需要保护的数据进行加密操作。 PC访问地址: https:/…

bootstrap_study

<meta http-equiv"X-UA-Compatible" content"IEedge"> <meta name"viewport" content"widthdevice-width, initial-scale1"> <!-- 新 Bootstrap 核心 CSS 文件 --> <link href"https://cdn.staticfile.org/…

选择工业交换机时,需要关注哪些方面的性能?

在工业自动化、能源、交通等领域的网络通信中&#xff0c;工业交换机是一种非常重要的网络设备。它的性能和可靠性直接影响到整个网络的稳定性和安全性。因此&#xff0c;在选择工业交换机时&#xff0c;我们需要关注以下几个方面的性能&#xff1a; 1. 抗干扰性能&#xff1a;…

树上形态改变统计贡献:1025T4

http://cplusoj.com/d/senior/p/SS231025D 答案为 ∑ w [ x ] − w [ s o n [ x ] ] \sum w[x]-w[son[x]] ∑w[x]−w[son[x]]&#xff0c; x x x 非儿子 要维护断边&#xff0c;LCT固然可以&#xff0c;但不一定需要 发现如果发生了变化&#xff0c;只会由重儿子变成次重儿子…

2023最新js数组常用方法大全

一、增删改方法 增删改查四大天王是数组中最常见也是最简单的方法&#xff0c;需要留意的是哪些方法会对原数组产生影响&#xff0c;哪些方法不会,查找方法较多&#xff0c;单独说明 下面前五种增删方法都对原数组产生影响 push()unshift()pop()shift()splice() push() pu…

HashMap 哈希碰撞、负载因子、插入方式、扩容倍数

HashMap 怎么解决的哈希碰撞问题&#xff1f; 主要采用了链地址法。具体来说&#xff1a; 每个哈希桶不仅存储一个键-值对&#xff0c;而是存储一个链表或树结构。这样&#xff0c;具有相同哈希值的键-值对可以被存储在同一个哈希桶中&#xff0c;并通过链表或树结构来解决碰…

如何将Linux上部署的5.7MySql数据库编码修改utf8(最新版)

如何将Linux&#xff08;服务器&#xff09;上部署的5.7MySql数据库编码修改utf8&#xff08;最新版&#xff09; 一、解决办法步骤1步骤2&#xff08;此处为问题描述吐槽&#xff0c;可以直接跳过该步骤到步骤三&#xff09;步骤3步骤4步骤5 二、结果 # 前言 提示&#xff1a…

【实战】Kubernetes安装持久化工具NFS-StorageClass

文章目录 前言技术积累存储类&#xff08;storage class&#xff09;什么是NFS什么是PV\PVC为什么要用NFS-StorageClass 安装NFS-StorageClass保证N8S集群正常投用安装NFS工具与客户端NFS安装常见错误安装NFS-StorageClass存储器 前言 前面的博文我们介绍了如何用kuberadmin的…

交流会|合同交付类业务的项目管理方法和实践分享

10月19日&#xff0c;由深圳市软件行业协会、易趋&#xff08;深圳蓝云软件&#xff09;、上海清晖、宁波银行深圳分行联合主办的第八期“项目管理技术与实践交流会议”在深圳成功举办。 本期沙龙邀请了易趋&#xff08;蓝云软件&#xff09;资深咨询顾问刘苗老师、协会特聘专…

iOS .a类型静态库使用终端进行拆解和合并生成

项目中会用到许多第三方的.a类型的静态库&#xff0c;有时候会有一些静态库回包含相同文件而产生冲突&#xff0c;我们就需要对这个库进行去重的一个操作。一般有哪些文件冲突了&#xff0c;xcode报错都会有详细的提示。我们可以将这两个库合并&#xff0c;也可以其中一方中的文…

推荐一款可以识别m3u8格式ts流批量下载并且合成mp4视频的chrome插件——猫抓

https://chrome.google.com/webstore/detail/%E7%8C%AB%E6%8A%93/jfedfbgedapdagkghmgibemcoggfppbb?utm_sourceext_app_menuhttps://chrome.google.com/webstore/detail/%E7%8C%AB%E6%8A%93/jfedfbgedapdagkghmgibemcoggfppbb?utm_sourceext_app_menu 网页媒体嗅探工具 一…

JS DataTable中导出PDF中文乱码

JS DataTable中导出PDF中文乱码 文章目录 JS DataTable中导出PDF中文乱码一. 问题二. 原因三. vfs_fonts.js四. pdfmake.js五. 解决六.参考资料 一. 问题 二. 原因 DataTable使用pdfmake&#xff0c;pdfmake默认字体为Roboto&#xff0c;不支持中文字体。添加自己的字体&#…

Linux 挂载磁盘到指定目录

问题&#xff1a;公司分配了数据磁盘&#xff0c;但是分区也没有挂载到目录 首先 df -h 查看一下挂载点的情况 查看服务器上未挂载的磁盘 fdisk -l 注&#xff1a;图中sda、sdb &#xff08;a、b指的是硬盘的序号&#xff09; 分区操作 我们可以看到b硬盘有536G未分区&…

LinkedList概念+MyLinkedList的实现

文章目录 LinkedList笔记一、 LinkedList1.概念2.LinkedList的构造方法3.LinkedList的遍历 二、MyLinkedList的实现1.定义内部类2.打印链表、求链表长度、判断是否包含关键字3. 头插法和尾插法4.在任意位置插入5.删除结点6.清空链表 LinkedList笔记 一、 LinkedList 1.概念 L…

VSCode 自动修改闭合标签

1.打开应用商店&#xff0c;搜索 auto rename tag &#xff0c;选择第一个&#xff0c;点击安装。 2.安装完毕后随便打开一个 HTML 文件&#xff0c;当我们修改起始标签时&#xff0c;闭合标签也会自动更改。 原创作者&#xff1a;吴小糖 创建时间&#xff1a;2023.10.25

ChineseChess2

中国象棋&#xff1a;黑将&#xff0c;红帅双炮&#xff0c;只要红帅中间露头怎么走怎么赢 卡主黑将的走位&#xff0c;控制住就好了 ChineseChess-CSDN博客

hdlbits系列verilog解答(向量级联)-18

文章目录 一、问题描述二、verilog源码三、仿真结果一、问题描述 级联运算符允许将向量连接在一起以形成更大的向量。但是有时您希望将同一个数据级联在一起很多次,而做类似 assign a = {b,b,b,b,b,b}; .复制运算符允许重复一个向量并将它们连接在一起:{num{vector}}。这将按…