Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别

1 model.train() 和 model.eval()用法和区别

1.1 model.train()

model.train()的作用是启用 Batch Normalization 和 Dropout

如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train()。model.train()是保证BN层能够用到每一批数据的均值和方差。对于Dropout,model.train()是随机取一部分网络连接来训练更新参数。

1.2 model.eval()

model.eval()的作用是不启用Batch Normalization 和 Dropout
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。

训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是model中含有BN层和Dropout所带来的的性质。

在做one classification的时候,训练集和测试集的样本分布是不一样的,尤其需要注意这一点。

1.3 分析原因

使用PyTorch进行训练和测试时一定注意要把实例化的model指定train/eval。model.eval()时,框架会自动把BN和Dropout固定住,不会取平均,而是用训练好的值,不然的话,一旦test的batch_size过小,很容易就会被BN层导致生成图片颜色失真极大!!!!!!

# 定义一个网络
class Net(nn.Module):def __init__(self, l1=120, l2=84):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, l1)self.fc2 = nn.Linear(l1, l2)self.fc3 = nn.Linear(l2, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 实例化这个网络Model = Net()# 训练模式使用.train()Model.train(mode=True)# 测试模型使用.eval()Model.eval()

为什么PyTorch会关注我们是训练还是评估模型?最大的原因是dropout和BN层(以dropout为例)。这项技术在训练中随机去除神经元。
在这里插入图片描述
想象一下,如果右边被删除的神经元(叉号)是唯一促成正确结果的神经元。一旦我们移除了被删除的神经元,它就迫使其他神经元训练和学习如何在没有被删除神经元的情况下保持准确。这种dropout提高了最终测试的性能,但它对训练期间的性能产生了负面影响,因为网络是不全的。

2.model.eval()和torch.no_grad()的区别

1.在PyTorch中进行validation/test时,会使用model.eval()切换到测试模式,在该模式下:

主要用于通知dropout层和BN层在train和validation/test模式间切换:
在train模式下,dropout网络层会按照设定的参数p设置保留激活单元的概率(保留概率=p); BN层会继续计算数据的mean和var等参数并更新。
在eval模式下,dropout层会让所有的激活单元都通过,而BN层会停止计算和更新mean和var,直接使用在训练阶段已经学出的mean和var值。
2. 该模式不会影响各层的gradient计算行为,即gradient计算和存储与training模式一样,只是不进行反向传播(back probagation)。

而with torch.no_grad()则主要是用于停止autograd模块的工作,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。

如果不在意显存大小和计算时间的话,仅仅使用model.eval()已足够得到正确的validation/test的结果;而with torch.no_grad()则是更进一步加速和节省gpu空间(因为不用计算和存储梯度),从而可以更快计算,也可以跑更大的batch来测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/117966.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaEE重点知识归纳】第11节:认识异常

目录 一:异常的概念和体系结构 1.概念 2.体系结构 3.异常分类 二:异常的处理 1.防御式编程 2.异常的抛出 3.异常的捕获 4.异常的处理流程 三:自定义异常 一:异常的概念和体系结构 1.概念 (1)在…

vue3使用Element ui plus中MessageBox消息框+radio框配合使用

想要达到的效果 首先安装element ui plus 省略~~ 官网地址: https://element-plus.gitee.io/zh-CN/component/message-box.htmlhttps://element-plus.gitee.io/zh-CN/component/message-box.html 需要用到的 引入 import { h } from "vue"; import {E…

为什么需要山洪灾害监测预警系统?

在山洪高发地区,安装山洪灾害监测预警系统能够通过实时监测,预警山洪信息,对于保障我们的生命财产安全具有重要意义。 监测山洪不仅需要对山体进行监测,还要监测降雨量以及水位上升情况。山洪灾害监测预警系统是由GNSS监测站和水…

天锐绿盾加密软件——企业数据透明加密、防泄露系统

天锐绿盾是一种企业级数据透明加密、防泄密系统,旨在保护企业的核心数据,防止数据泄露和恶意攻击。它采用内核级透明加密技术,可以在不影响员工正常工作的前提下,对需要保护的数据进行加密操作。 PC访问地址: https:/…

选择工业交换机时,需要关注哪些方面的性能?

在工业自动化、能源、交通等领域的网络通信中,工业交换机是一种非常重要的网络设备。它的性能和可靠性直接影响到整个网络的稳定性和安全性。因此,在选择工业交换机时,我们需要关注以下几个方面的性能: 1. 抗干扰性能:…

如何将Linux上部署的5.7MySql数据库编码修改utf8(最新版)

如何将Linux(服务器)上部署的5.7MySql数据库编码修改utf8(最新版) 一、解决办法步骤1步骤2(此处为问题描述吐槽,可以直接跳过该步骤到步骤三)步骤3步骤4步骤5 二、结果 # 前言 提示&#xff1a…

【实战】Kubernetes安装持久化工具NFS-StorageClass

文章目录 前言技术积累存储类(storage class)什么是NFS什么是PV\PVC为什么要用NFS-StorageClass 安装NFS-StorageClass保证N8S集群正常投用安装NFS工具与客户端NFS安装常见错误安装NFS-StorageClass存储器 前言 前面的博文我们介绍了如何用kuberadmin的…

交流会|合同交付类业务的项目管理方法和实践分享

10月19日,由深圳市软件行业协会、易趋(深圳蓝云软件)、上海清晖、宁波银行深圳分行联合主办的第八期“项目管理技术与实践交流会议”在深圳成功举办。 本期沙龙邀请了易趋(蓝云软件)资深咨询顾问刘苗老师、协会特聘专…

推荐一款可以识别m3u8格式ts流批量下载并且合成mp4视频的chrome插件——猫抓

https://chrome.google.com/webstore/detail/%E7%8C%AB%E6%8A%93/jfedfbgedapdagkghmgibemcoggfppbb?utm_sourceext_app_menuhttps://chrome.google.com/webstore/detail/%E7%8C%AB%E6%8A%93/jfedfbgedapdagkghmgibemcoggfppbb?utm_sourceext_app_menu 网页媒体嗅探工具 一…

JS DataTable中导出PDF中文乱码

JS DataTable中导出PDF中文乱码 文章目录 JS DataTable中导出PDF中文乱码一. 问题二. 原因三. vfs_fonts.js四. pdfmake.js五. 解决六.参考资料 一. 问题 二. 原因 DataTable使用pdfmake,pdfmake默认字体为Roboto,不支持中文字体。添加自己的字体&#…

Linux 挂载磁盘到指定目录

问题:公司分配了数据磁盘,但是分区也没有挂载到目录 首先 df -h 查看一下挂载点的情况 查看服务器上未挂载的磁盘 fdisk -l 注:图中sda、sdb (a、b指的是硬盘的序号) 分区操作 我们可以看到b硬盘有536G未分区&…

LinkedList概念+MyLinkedList的实现

文章目录 LinkedList笔记一、 LinkedList1.概念2.LinkedList的构造方法3.LinkedList的遍历 二、MyLinkedList的实现1.定义内部类2.打印链表、求链表长度、判断是否包含关键字3. 头插法和尾插法4.在任意位置插入5.删除结点6.清空链表 LinkedList笔记 一、 LinkedList 1.概念 L…

VSCode 自动修改闭合标签

1.打开应用商店,搜索 auto rename tag ,选择第一个,点击安装。 2.安装完毕后随便打开一个 HTML 文件,当我们修改起始标签时,闭合标签也会自动更改。 原创作者:吴小糖 创建时间:2023.10.25

ChineseChess2

中国象棋:黑将,红帅双炮,只要红帅中间露头怎么走怎么赢 卡主黑将的走位,控制住就好了 ChineseChess-CSDN博客

酒石酸盐晶体是优质葡萄酒的一个特征?

来自云仓酒庄品牌雷盛红酒分享澄清一下,葡萄酒中的酒石酸盐晶体,在德国被称为“温斯坦”,既无害也不是质量差的标志,相反,它们是富含矿物质的葡萄酒的特征。虽然酒石酸盐可以在年轻的葡萄酒中结晶,但它们最…

unity脚本_Mathf和Math c#

首先创建一个脚本 当我们要做一个值趋近于一个值变化时 可以用Mathf.Lerp(start,end,time);方法实现 比如物体跟随

通过IP地址可以做什么

通过IP地址可以做很多事情,因为它是互联网通信的基础之一。本文将探讨IP地址的定义、用途以及一些可能的应用。 IP地址的用途 1. 设备标识:IP地址用于标识互联网上的每个设备,这包括计算机、服务器、路由器、智能手机等。它类似于我们日常生…

折纸问题

折纸的次数 —— 从上到下的折痕 本质上是中序遍历的问题,因为每一次在已有的折痕后折的时候,当前折痕上的折痕一定为凹,当前折痕下的折痕一定为凸 。实际模拟了一个不存在的二叉树结构的中序遍历。 注:折纸折几次整颗二叉树就有…

超好用的数据可视化工具推荐,小白也适用!

Excel、Tableau……可以做数据可视化的工具不少,但简单、好用又高效,甚至连无SQL基础的小白也能轻松使用的就真没几个。奥威BI数据可视化工具是少有的操作难度低、成本支出低、灵活自助分析能力强的BI工具。 1、操作难度低 奥威BI数据可视化工具的操作…

C++(Qt)软件调试---线程死锁调试(15)

C(Qt)软件调试—线程死锁调试(15) 文章目录 C(Qt)软件调试---线程死锁调试(15)1、前言2、常见死锁3、linux下gdb调试C死锁1.1 使用代码1.2 gdb调试 3、linux下gdb调试Qt死锁1.1 使用代码1.2 gdb调试 4、Windows下gdb调试C死锁5、W…