pytorch model.train() 和model.eval() 对 BN 层的影响

model.train()

  • BN做归一化时,使用的均值和方差是当前这个Batch的
  • 如果这时 track_running_stats=True, 则会更新running_meanrunning_var
  • 但是,running_meanrunning_var不用在训练阶段

model.eval()

  • BN 做归一化时,使用的均值和方差是BN存储的running_meanrunning_var
  • 不管这时track_running_stats 是 True 还是 False, 都不会更新 running_meanrunning_var

感兴趣可以在以下测试代码下调整测试

'''
Author: Chae Luv
Date: 2022-08-17 22:40:13
LastEditors: Chae Luv
LastEditTime: 2022-08-17 23:15:22
FilePath: /re-record-audio-watermark/10-base_model/test_bn.py
Description: Copyright (c) 2022 by Chae Luv/USTC, All Rights Reserved. 
'''
import torch
import torch.nn as nndef create_inputs():return torch.randn(8, 3, 20, 20)def simulated_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):if mean_val is None:mean_val = x.mean([0, 2, 3])if var_val is None:var_val = x.var([0, 2, 3], unbiased=False)x = x - mean_val[None, ..., None, None]x = x / torch.sqrt(var_val[None, ..., None, None] + eps)x = x * bn_weight[..., None, None] + bn_bias[..., None, None]return mean_val, var_val, xpytorch_bn = nn.BatchNorm2d(num_features=3, momentum=None)
running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)# 切换到eval模式
pytorch_bn.train(mode=False)
test_input = create_inputs()
print(f'pytorch_bn running_mean is {pytorch_bn.running_mean}')
print(f'pytorch_bn running_var is {pytorch_bn.running_var}')
bn_outputs = pytorch_bn(test_input)
print(f'Now pytorch_bn running_mean is {pytorch_bn.running_mean}')
print(f'Now pytorch_bn running_var is {pytorch_bn.running_var}')
# 用之前统计的running_mean和running_var替代输入的running_mean和running_var
_, _, simulated_outputs = simulated_bn_forward(test_input, pytorch_bn.weight,pytorch_bn.bias, pytorch_bn.eps,running_mean, running_var)
assert torch.allclose(simulated_outputs, bn_outputs)# 关闭track_running_stats后,即使在eval模式下,也会去计算输入的mean和var
pytorch_bn.train(mode=True)
pytorch_bn.track_running_stats = False
bn_outputs_notrack = pytorch_bn(test_input)
_, _, simulated_outputs_notrack = simulated_bn_forward(test_input, pytorch_bn.weight,pytorch_bn.bias, pytorch_bn.eps)print(torch.sum(simulated_outputs_notrack - bn_outputs_notrack))
assert torch.allclose(simulated_outputs_notrack, bn_outputs_notrack)
assert not torch.allclose(bn_outputs, bn_outputs_notrack)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/507899.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

联想用u盘重装系统步骤_详解联想如何使用u盘重装win10系统

联想是国内知名的品牌之一,很多朋友都购买了联想品牌的电脑,但是在使用的过程中难免会出现些磕磕碰碰的问题。所以今天小编就大家详细的介绍一下联想电脑使用u盘重装win10系统的方法。联系怎么使用u盘重装win10系统呢?最近有不少朋友在询问这…

笔记本电脑键盘切换_真想本小新13pro搭档,笔记本电脑周边好物清单推荐

原标题:真想本小新13pro搭档,笔记本电脑周边好物清单推荐真想本小新13pro搭档,笔记本电脑周边好物清单推荐 2020-10-24 15:21:493点赞4收藏2评论9月28日 - 11月12日,参与#双11购物攻略#征稿活动,赢取苹果全家桶8888元超…

pytorch 训练模型很慢,卡在数据读取,卡I/O的有效解决方案

多线程加载 在 datalaoder中指定num_works > 0,多线程加载数据集,最大可设置为 cpu 核数设置 pin_memory True, 固定内存访问单元,节约内存调度时间示例如下: loader DataLoader(dataset,batch_sizebatch_size * group_size,shuffleTr…

python达梦数据库_python 操作达 梦数据库

python 达梦数据库操作流程连接数据库 dm.connect( ... )获取游标 dm_conn.cursor()编写SQL语句 sql_str执行SQL语句 dm_cursor.execute()获取结果列表 dt_breakpoint dm_cursor.fetchall()关闭游标 dm_cursor.close()关闭数据库连接 dm_conn.close()代码示例import pandas as…

C++求复数的角度_11.初中数学:方程5x2m=4x的解,在2与10之间,怎么求m的取值范围?...

欢迎您来到方老师数学课堂,请点击上方蓝色字体,关注方老师数学课堂。所有的视频内容,全部免费,请大家放心关注,放心订阅。初中数学:方程5x-2m-4-x的解,在2与10之间,怎么求m的取值范围…

python3 beautifulsoup 模块详解_关于beautifulsoup模块的详细介绍

这篇文章主要给大家介绍了python中 Beautiful Soup 模块的搜索方法函数。 方法不同类型的过滤参数能够进行不同的过滤,得到想要的结果。文中介绍的非常详细,对大家具有一定的参考价值,需要的朋友们下面来一起看看吧。前言我们将利用 Beautifu…

python解zuobiaoxi方程_欧式期权定价的python实现

0. pre 在《给你的二叉树期权定价》中就挖了坑要写期权定价的代码,这会有时间来填坑啦。本文将会用python实现欧式期权定价。具体的定价算法分别是基于BS公式的、蒙特卡洛的以及二叉树的。对于二叉树和BS公式还不熟悉的小伙伴可以移步至往期关于二叉树期权定价和BS公…

去除标签_有效去除“狗皮膏药”标签,快学起来吧

去除商品标签向来是比较头疼一件事,有时候在去掉标签后会留下粘性残留物,它会粘上灰尘和其他脏东西,把表面变成脏兮兮的颜色,让人看着太不舒服了。其实去除标签残留粘胶并不难,可能家里就有去除它的工具哦~那今天小编就…

win10很多软件显示模糊_还在使用第三方软件?Win10可以直接显示显卡温度啦

微软刚刚开始向参与快速通道测试的用户推送Windows 10 20H1 Build 18963 版带来部分新功能和优化等。这个版本也是常规优化版本因此带来的新功能很少,但这次更新为任务管理器带来原生的显示显卡温度功能。用户打开任务管理器点击性能选项卡然后找到「独立显卡」即可…

分数怎么化成带分数_小升初数学总复习第三个基础模块:分数的认识

今天我们开始小升初数学总复习第三个基础模块的复习:分数的认识分数的认识一共分为8个知识考点。第一,分数的意义把单位“1”.平均分成若干份,表示这样的一份或者几份的数叫做分数。表示其中一份的数叫做分数单位。第二&#xff0…

active mq topic消费后删除_《我想进大厂》之MQ夺命连环11问

继之前的mysql夺命连环之后,我发现我这个标题被好多套用的,什么夺命zookeeper,夺命多线程一大堆,这一次,开始面试题系列MQ专题,消息队列作为日常常见的使用中间件,面试也是必问的点之一&#xf…

嘀嗒还是滴答_2021年顺风车车主口碑榜!滴滴、滴答、一喂顺风车成TOP3

出行平台烧钱抢用户抢司机,大家都见怪不怪了,只是近期平台为自身利益而牺牲司机的例子层出不穷,在司机刚进入平台补贴多流水多,没多久司机收入都不够交车租的,司机踩坑,全家受罪,很多司机表示自…

wampserver橙色如何变成绿色_实验室如何自建数据库和网站主页

本文首发于微信公众号:火行(ID:firegotech)实验室如何自建数据库和网站主页作者:沐倾(火行科研Club创始成员)编辑:火花(声明:本文适用于非计算机专业领域人士&#xff09…

mysql ( )连接_MySQL中concat函数(连接字符串)

MySQL中concat函数使用方法:CONCAT(str1,str2,…)返回结果为连接参数产生的字符串。如有任何一个参数为NULL ,则返回值为 NULL。注意:如果所有参数均为非二进制字符串,则结果为非二进制字符串。如果自变量中含有任一二进制字符串&…

远程连接电脑_Python黑科技:在家远程遥控公司电脑,python+微信一键连接!

有时候需要远程家里的台式机使用,因为我平时都是用 MAC 多,但是远程唤醒只能针对局域网,比较麻烦,于是我想用微信实现远程唤醒机器。准备工作本程序主要是实现远程管理 Windows10操作系统的开机和关机:在 Windows机器的…

mysql count 优化索引_如何通过使用索引在InnoDB上优化COUNT(*)性能

我有一个小而狭窄的InnoDB表,大约有900万条记录。在桌子上count(*)或count(id)桌子上做的速度非常慢(超过6秒):DROP TABLE IF EXISTS perf2;CREATE TABLE perf2 (id int(11) NOT NULL AUTO_INCREMENT,channel_id int(11) DEFAULT NULL,timestamp bigint(…

ppt生成器_9款魔性#傻瓜生成器#,上班可以划水一天

有些 #傻瓜生成器#,表面上叫傻瓜,实际上一玩就停不下来。不分享出来,未免不近人情。毕竟,大伙上班累了,也需要一些奇怪的东西划划水不是?1 沙雕DIY跳舞生成器这款傻瓜生成器,真不能叫傻瓜&#…

vue组件prop变量和内部变量数据格式不一样时,变量同步prop值,变量改变通知父组件.

vue组件含有v-model的props,当对其进行封装,想对该属性进行双向绑定时,可以采用computed的方式包一层get(){return props.xxx},set(v)>{emit(update:xxx,v)},或者使用vueuse的useModel来深层代理,但是只适合要封装的组件prop的内部的变量数据类型一致,不一致就只能拆开写,通…