【深度学习实战(33)】训练之model.train()和model.eval()

一、model.train(),model.eval()作用?

model.train() 和 model.eval() 是 PyTorch 中的两个方法,用于设置模型的训练模式和评估模式。

model.train() 方法将模型设置为训练模式。在训练模式下,模型会启用 dropout 和 batch normalization 等正则化方法,并且可以计算梯度以进行参数更新,同时还可以追踪梯度计算的图。训练时,均值、方差分别是该批次内数据相应维度的均值与方差

model.eval() 方法将模型设置为评估模式。在评估模式下,模型会禁用 dropout 和 batch normalization 等正则化方法,这样可以保证每次评估的结果是确定的。评估模式下的模型通常用于模型的测试、验证或推理阶段。推理时,均值、方差是基于所有批次的期望计算所得

区分训练模式和评估模式的目的在于保证模型在不同阶段的行为一致性。例如,在训练模式下,模型需要计算并追踪梯度以进行反向传播和参数更新;而在评估模式下,模型不需要计算梯度,只需要给出确定的预测结果。

二、model.train(),model.eval()对dropout产生的影响

(1)使用model.train():有神经元被置零,且比例符合nn.Dropout(0.5)中的0.5设定

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
model.train()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

(2)使用model.eval():没有神经元置零,nn.Dropout(0.5)被关闭

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
#model.train()
model.eval()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

(3)不使用model.train()和model.eval():有神经元被置零,但是比例非常随机,不符合nn.Dropout(0.5)中的0.5设定

import torch
import torch.nn as nnmodel = nn.Dropout(0.5)
#model.train()
#model.eval()
input = torch.rand([3, 4])print("before dropout:",input)
output = model(input)
print("after dropout in train mode:",output)

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

三、model.train(),model.eval()对batch normalization产生的影响

(1)使用model.eval():bn中的均值,方差,不发生改变

# 1.导入所需的库:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# 2.定义数据集的转换方法。MNIST数据集是由28x28像素的手写数字组成的图像,将其转换为torch张量并进行标准化处理:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 3.下载MNIST数据集并进行转换:
trainset = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,download=True, transform=transform)# 4.创建数据加载器:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,shuffle=False, num_workers=0)# 5.现在你可以使用trainloader和testloader来获取训练集和测试集的批次数据了。例如,可以使用迭代器遍历数据集中的批次:
#dataiter = iter(trainloader)
#images, labels = dataiter.next()# 上述代码将返回一个批次的图像和对应的标签。可以使用images和labels来进行模型的训练和评估。
# 这就是使用torch库自带的MNIST数据集的基本流程。根据需要,你还可以添加其他的数据处理和增强步骤。# 定义模型
class Model(nn.Module):def __init__(self, hidden_num=32, out_num=10):super().__init__()self.fc1 = nn.Linear(28*28, hidden_num)self.bn  = nn.BatchNorm1d(hidden_num)self.fc2 = nn.Linear(hidden_num, out_num)self.softmax = nn.Softmax()def forward(self, inputs, **kwargs):x = inputs.flatten(1)x = self.fc1(x)print("========= bn之前存的数据: =========")print(self.bn.running_mean, self.bn.running_var)print()print("========= 当前 Batch 的数据: =========")x_mean = torch.mean(x,0)x_variance = torch.mean((x - x_mean)*(x - x_mean),0)print(x_mean, x_variance)print()print("========= torch官方计算之后的bn新数据: =========")x = self.bn(x)print(self.bn.running_mean, self.bn.running_var)print()# x = self.dropout(x)x = self.fc2(x)x = self.softmax(x)return xtorch.manual_seed(1)
model = Model()
#model.train()
model.eval()
for img, label in trainloader:label = nn.functional.one_hot(label.flatten(), 10)out = model(img)break

在这里插入图片描述

(2)使用model.train():bn中的均值,方差,通过滑动平均地方式发生改变,

torch.manual_seed(1)
model = Model()
model.train()
#model.eval()
for img, label in trainloader:label = nn.functional.one_hot(label.flatten(), 10)out = model(img)break

在这里插入图片描述

(3)不使用model.train()和model.eval():默认bn中的均值,方差,通过滑动平均地方式发生改变,

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/11558.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytorch常用内置loss函数与正则化技术(补充小细节)

文章目录 前言一、常用损失函数(后面用到了新的会一一补充)1.1 回归中的损失函数1.1.1 nn.MSELoss()示例1:向量-向量示例2:矩阵--矩阵(维度必须一致才行)1.2 分类中的损失函数1.2.1 二分类(1)nn.BCELoss --- 二分类交叉熵损失函数示例1:向量-向量示例2:矩阵--矩阵(维…

基于SSM的“基于协同过滤的在线通用旅游平台网站”的设计与实现(源码+数据库+文档)

基于SSM的“基于协同过滤的在线通用旅游平台网站”的设计与实现(源码数据库文档) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 系统主界面 景点信息界面 后台界面 部分源码…

【每日刷题】Day39

【每日刷题】Day39 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. 622. 设计循环队列 - 力扣(LeetCode) 2. 387. 字符串中的第一个唯一字符 - …

百度云防护502 Bad Gateway原因总结和处理方法

最近,随着原百度云加速用户新接入百度云防护后,很多站长反馈网站打不开,出现了502 Bad Gateway的情况。 为此,百度云这里给大家总结下,出现502的大概几个原因: 1.服务器防火墙拦截了百度云防护的IP节点请求…

vivado Kintex UltraScale+ 配置存储器器件

Kintex UltraScale 配置存储器器件 下表所示闪存器件支持通过 Vivado 软件对 Kintex UltraScale 器件执行擦除、空白检查、编程和验证等配置操作。 本附录中的表格所列赛灵思系列非易失性存储器将不断保持更新 , 并支持通过 Vivado 软件对其中所列非易失性存…

【VUE.js】前端框架——未完成

基于脚手架创建前端工程 环境 当安装node.js时,它本身就携带有npm命令。(-v 查版本号)安装VUE CLI npm i vue/cli -g(全局) 创建 vue create 【project name】 镜像源解决方案 输入创建命令后,提示检查更…

【JAVA】JAVA的垃圾回收机制详解

对于Java的垃圾回收机制,它是Java虚拟机(JVM)提供的一种自动内存管理机制,主要负责回收不再使用的对象以释放内存空间。垃圾回收机制主要包括以下几个方面的内容: 垃圾对象的识别:Java虚拟机通过一些算法&…

C++学习笔记3

A. 求出那个数 题目描述 喵喵是一个爱睡懒觉的姑娘,所以每天早上喵喵的妈妈都花费很大的力气才能把喵喵叫起来去上学。 在放学的路上,喵喵看到有一家店在打折卖闹钟,她就准备买个闹钟回家叫自己早晨起床,以便不让妈妈这么的辛苦…

Caddy2使用阿里云DNS申请https证书,利用阿里云DNS境内外不同解析给Gone文档做一个同域名的国内镜像站点

我从头到尾实现了一个Golang的依赖注入框架,并且集成了gin、xorm、redis、cron、消息中间件等功能;自己觉得还挺好用的,并且打算长期维护! github地址:https://github.com/gone-io/gone 文档原地址:https:/…

2024CCPC郑州站超详细题解(含题面)ABFHJLM(河南全国邀请赛)

文章目录 前言A Once In My LifeB 扫雷 1F 优秀字符串H 随机栈J 排列与合数L Toxel 与 PCPC IIM 有效算法 前言 这是大一博主第一次参加xcpc比赛,虽然只取得了铜牌,但是收获满满,在了解了和别人的差距后会更加激励自己去学习,下面…

Python从0到POC编写--函数

数学函数: 1. len len() 函数返回对象(字符、列表、元组等)长度或项目个数, 例如: str "python" len(str)2. range range() 函数返回的是一个可迭代对象(类型是对象),…

并行执行的4种类别——《OceanBase 并行执行》系列 4

OceanBase 支持多种类型语句的并行执行。在本篇博客中,我们将根据并行执行的不同类别,分别详细阐述:并行查询、并行数据操作语言(DML)、并行数据定义语言(DDL)以及并行 LOAD DATA 。 《并行执行…

vm虚拟机扩容centos磁盘内存

1.查看虚拟机扩展前磁盘内存 df -h 2.关机情况下扩展磁盘内存 3.对扩容的磁盘分区 fdisk /dev/sda 输入n新增分区,回车,选择p,回车 为分区设置分区格式,在Fdisk命令处输入:t 分区号用默认 3(或回车&…

OSS证书自动续签,一分钟轻松搞定,解决阿里云SSL免费证书每3个月失效问题

文章目录 一、🔥httpsok-v1.11.0支持OSS证书自动部署介绍支持特点 二、废话不多说上教程:1、场景2、实战Stage 1:ssh登录阿里云 ECSStage 2:进入nginx (docker)容器Stage 3:执行如下指令Stage 3…

测试环境搭建整套大数据系统(十六:超级大文件处理遇到的问题)

一:yarn出现损坏的nodemanger 报错现象 日志:1/1 local-dirs usable space is below configured utilization percentage/no more usable space [ /opt/hadoop-3.2.4/data/nm-local-dir : used space above threshold of 90.0% ] ; 1/1 log-dirs usabl…

01-02-2

1、typedef的使用 a.语法 typedef 原名 别名;。 ​ typedef struct student {int num;char name[20];char sex; }stu,*pstu;//stu相当于struct student这个类型,*pstu相当于struct student * 别名的理解方法:若是字母前面有符号&#xff0…

SOUI4里使用字体回退

在新版本的SOUI里render-skia导出了一个新的函数用于字体回退功能。Render_Skia_SetFontFallback 函数原型如下: EXTERN_C void SOUI_COM_API Render_Skia_SetFontFallback(FontFallback fontFallback);因为我的工程是使用动态库,这里可以直接获取到这…

如何用微信小程序实现远程控制4路控制器/断路器

如何用微信小程序实现远程控制4路控制器/断路器呢? 本文描述了使用微信小程序调用HTTP接口,实现控制4路控制器/断路器,支持4路输出,均可独立控制,可接入各种电器。 可选用产品:可根据实际场景需求&#xf…

内容与图像一对多问题解决

场景复现 分析: 其实这是两给表,一个内容表,一个图片表,一对多的关系。 解决思路: 1. 先上传图片拿到图片的List集合ids,返回值是集合的ids,给到前端 2. 再添加内容表的数据生成了id,遍历查…

佳博打印机如何设置打印模式为热敏模式

1、打开电脑搜索框,如下图输入打印机: 2、点击打印机设置,如下图: 3、点击打印机首选项,如下图: 4、点击下图“卷”进行设置 也可对打印机间距高度进行调整