内部协变量偏移问题(有无BN的代码比较)

1.什么是内部协变量偏移问题:

比如1000条数据,batch_size=4,相当于要练250批次,当第一次批次的4条数据进行模型的训练时,此时网络学习动态已经养成,当第二批次进行训练时,极大可能导致差异较大,即参数变化很大,那么下一层的输入就会收到很大的影响,导致整个网络的学习动态发生改变。

这样结果主要归结于前向传播中的变化的累积,每一层的输出都是下一层的输入。如果上一层的参数在训练中发生较大的变化(特别是在训练初期,毕竟样本太少,很难得到一个方差较小、大家都认可的方案),这将直接影响到下一层接收的输入分布。如果每一层都在接收到与前一次迭代时【前一次batch_size】不同分布的输入,它们就需要不断调整自己来适应这种变化,这会使得网络的训练过程变得复杂且低效。

另外,输入分布不断变化这将导致每一层都需要学习不同的学习速率。这使得设置一个全局学习率变得非常困难。

举个例子:

你在驾车时,道路和交通规则每隔几分钟就会改变。即使你已经适应了当前的驾驶条件,突然的变化也会迫使你不断重新学习如何驾驶,这显然会降低你的驾驶效率和安全性。同样地,在神经网络中,如果每层的输入规则(即数据分布)持续变化网络层就需要不断调整反应,这降低了学习的效率。

2.解决方案:Batch Normalization

批归一化(Batch Normalization)通过在每一层后规范化输入【均值为0,方差为1】,使得输入分布保持相对稳定【总的来说就是规划输入分布,即调整前一层的输出,允许更高的学习效率】,从而缓解了内部协变量偏移的问题。这意味着网络的每一层都可以预期到它将会接收到具有相似分布的输入,从而使训练过程更稳定,加快收敛速度,使得网络可以使用更高的学习率,而不会那么容易发生训练发散的问题。

举个例子: 在原始的模型中增加BN层,并训练该模型;接着,我们将训练一个不使用BN的模型,并对比两者的训练误差和测试准确率。

class ModelBN(nn.Module):def __init__(self):super(ModelBN, self).__init__()# 定义卷积层,添加BN层self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),  # 添加批归一化层nn.ReLU(),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),  # 添加批归一化层nn.ReLU(),nn.MaxPool2d(stride=2, kernel_size=2))# 定义全连接层,添加BN层self.dense = nn.Sequential(nn.Linear(14 * 14 * 128, 1024),nn.BatchNorm1d(1024),  # 添加批归一化层nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14 * 14 * 128)x = self.dense(x)return x# 实例化带BN的模型,并移至GPU
model_bn = ModelBN().to(device)

训练和测试:

def train_and_test(model, optimizer, epochs=3):cost = nn.CrossEntropyLoss()for epoch in range(epochs):model.train()  # 设置模型为训练模式running_loss = 0.0for data in data_loader_train:inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = cost(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()model.eval()  # 设置模型为评估模式correct = 0total = 0with torch.no_grad():for data in data_loader_test:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(data_loader_train)}, 'f'Accuracy: {100 * correct / total}%')# 定义优化器
optimizer_bn = optim.Adam(model_bn.parameters())# 训练和测试带BN的模型
print("Training with Batch Normalization:")
train_and_test(model_bn, optimizer_bn)# 训练和测试不带BN的模型
print("Training without Batch Normalization:")
optimizer = optim.Adam(model.parameters())  # 不带BN的模型使用同样的优化器设置
train_and_test(model, optimizer)

3.完整代码:

# 导入必要的库
import torch
import torch.nn as nn
from torchvision import datasets, transforms, utils
import torch.optim as optim
import torchvision
import numpy as np
import matplotlib.pyplot as plt# 检查并设置设备,优先使用GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('设备状态:', device)# 定义数据转换步骤
transform = transforms.Compose([transforms.ToTensor(),  # 将图像转换为Tensortransforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # 将单通道图像复制到三通道transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))  # 标准化图像数据
])# 下载并加载MNIST数据集
data_train = datasets.MNIST(root='./data/', transform=transform, train=True, download=True)
data_test = datasets.MNIST(root='./data/', transform=transform, train=False)# 定义数据加载器
data_loader_train = torch.utils.data.DataLoader(dataset=data_train, batch_size=64, shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test, batch_size=64, shuffle=True)# 预览数据
images, labels = next(iter(data_loader_train))
img = utils.make_grid(images)  # 组合图像以便可视化
img = img.numpy().transpose(1, 2, 0)  # 调整图像维度以适配matplotlib
img = img * np.array([0.5, 0.5, 0.5]) + np.array([0.5, 0.5, 0.5])  # 反标准化显示图像
print([labels[i] for i in range(64)])  # 打印标签检查
plt.imshow(img)  # 显示图像
plt.show()  # 确保图像显示# 定义卷积神经网络模型
class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.ReLU(),nn.MaxPool2d(stride=2, kernel_size=2))self.dense = nn.Sequential(nn.Linear(14 * 14 * 128, 1024),nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14 * 14 * 128)x = self.dense(x)return x# 定义含批归一化的卷积神经网络模型
class ModelBN(nn.Module):def __init__(self):super(ModelBN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(),nn.MaxPool2d(stride=2, kernel_size=2))self.dense = nn.Sequential(nn.Linear(14 * 14 * 128, 1024),nn.BatchNorm1d(1024),nn.ReLU(),nn.Dropout(p=0.5),nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14 * 14 * 128)x = self.dense(x)return x# 实例化并设置模型至GPU
model_bn = ModelBN().to(device)
model = Model().to(device)# 设置优化器
optimizer_bn = optim.Adam(model_bn.parameters())
optimizer = optim.Adam(model.parameters())# 训练和测试函数,记录损失和
print(model)  # 打印模型结构'''
7.训练模型
'''def train_and_test(model, optimizer, n_epochs=3):cost = nn.CrossEntropyLoss()losses = []accuracies = []for epoch in range(n_epochs):model.train()total_loss = 0correct = 0total = 0for data in data_loader_train:inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = cost(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()avg_loss = total_loss / len(data_loader_train)accuracy = 100 * correct / totallosses.append(avg_loss)accuracies.append(accuracy)print(f'Epoch {epoch + 1}/{n_epochs}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%')return losses, accuracies# 使用带BN和不带BN的模型进行训练
optimizer_bn = optim.Adam(model_bn.parameters())
losses_bn, acc_bn = train_and_test(model_bn, optimizer_bn)model_without_bn = Model().to(device)
optimizer_nobn = optim.Adam(model_without_bn.parameters())
losses_nobn, acc_nobn = train_and_test(model_without_bn, optimizer_nobn)plt.figure(figsize=(10, 5))
plt.plot(losses_bn, label='With BatchNorm')
plt.plot(losses_nobn, label='Without BatchNorm')
plt.title('Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

BN有无的对比

可能原因:

短期增加,长期收益: BN有一定的正则化效果,这可能在短期内增加训练损失,但长期看有助于提高模型的泛化能力

批大小效应: BN通过对每个批量的数据进行归一化,依赖于批内数据的统计特性。如果批量大小不足以提供稳定的统计估计,或者批数据本身的变异性较大,初期的BN表现可能不够稳定。【小的干大的】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/21681.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多模态融合目标检测新SOTA!推理速度提升2.7倍,实现最先进性能

为解决传统目标检测在复杂环境下效果不佳等问题,研究者们提出了多模态融合目标检测。 通过整合来自多个传感器的数据,充分利用不同传感器的优点,多模态融合目标检测能够更全面地捕捉目标信息,显著提高检测的准确性和鲁棒性&#…

弘君资本策略:短期博弈情绪边际降温 关注这四条线索

弘君资本指出,随着商场进入地产政策调查期,短期博弈情绪边沿降温,注重景气边沿改善和工业政策指向的结构性头绪。一是受供应侧节能降碳影响且可继续的提价链;二是获益于全球制造业向上的出口制造链;三是具有全球竞争力…

隐藏饼图的legend,重写legend列表。

因为要实现的饼图效果较复杂,所以,需要重新写列表。 点击右侧列表的圆点,实现隐藏左侧饼图相应环状。 // 饼图,点击自定义列表,显示和隐藏饼图对应的环状数据<template> <div class="index_div"> <a-spin :spinning="aLoading">&l…

Unity开发——编辑器打包、3种方式加载AssetBundle资源

一、创建ab资源 &#xff08;一&#xff09;Unity资源设置ab格式 1、选中要打包成assetbundle的资源&#xff1b; 可以是图片&#xff0c;材质球&#xff0c;预制体等&#xff0c;这里方便展示用预制体打包设置展示&#xff1b; 2、AssetBundle面板说明 &#xff08;1&…

【YOLOv5进阶】——模型结构与模型原理YOLOv5源码解析

一、基础知识 1、backbone backbone是核心组成部分&#xff0c;主要负责提取图像特征。具体来说&#xff0c;backbone通过一系列的卷积层和池化层对输入图像进行处理&#xff0c;逐渐降低特征图的尺寸同时增加通道数&#xff0c;从而保留和提取图像中重要的特征。这些提取出的…

Unity3D获得服务器时间/网络时间/后端时间/ServerTime,适合单机游戏使用

说明 一些游戏开发者在做单机游戏功能时&#xff08;例如&#xff1a;每日奖励、签到等&#xff09;&#xff0c;可能会需要获得服务端标准时间&#xff0c;用于游戏功能的逻辑处理。 问题分析 1、自己如果有服务器&#xff1a;自定义一个后端API&#xff0c;客户端按需请求…

使用Obfuscar 混淆WPF(Net6)程序

Obfuscar 是.Net 程序集的基本混淆器&#xff0c;它使用大量的重载将.Net程序集中的元数据&#xff08;方法&#xff0c;属性、事件、字段、类型和命名空间的名称&#xff09;重命名为最小集。详细使用方式参见&#xff1a;Obfuscar 在NetFramework框架进行的WPF程序的混淆比较…

Spring @Transactional 事务注解

一、spring 事务注解 1、实现层(方法上加) import org.springframework.transaction.annotation.Transactional;Transactional(rollbackFor Exception.class)public JsonResult getRtransactional() {// 手动标记事务回滚TransactionAspectSupport.currentTransactionStatus…

抖店入驻门槛,一降再降,2024年商家入驻抖店最佳的时机来了!

大家好&#xff0c;我是电商糖果 抖店已经发展有四年多的时间了&#xff0c;现在也算是比较成熟的电商平台. 这几年因为直播带货的火爆&#xff0c;再加上抖音的流量支撑&#xff0c;还有抖音在背后的扶持和推广。 让抖店成了电商行业的黑马项目&#xff0c;吸引了不少商家入…

ACWC:Worst-Case to Average-Case Decryption Error

参考文献&#xff1a; [LS19] Lyubashevsky V, Seiler G. NTTRU: Truly Fast NTRU Using NTT[J]. IACR Transactions on Cryptographic Hardware and Embedded Systems, 2019: 180-201.[DHK23] Duman J, Hvelmanns K, Kiltz E, et al. A thorough treatment of highly-efficie…

[element-ui]el-form自定义校验-图片上传验证(手动触发部分验证方法)

背景&#xff1a; 在做导入文件功能的时候&#xff0c;需要校验表单&#xff0c;如图所示 店铺字段绑定在表单数据对象上&#xff0c;在点击确定的时候正常按照表单验证规则去校验&#xff0c;就不再赘述。 文件上传是个异步过程&#xff0c;属性值改变后不会去触发验证规则…

智能管理,无忧报修——高校校园报事报修系统小程序全解析

随着数字化、智能化的发展&#xff0c;高校生活也迎来了前所未有的变革。你是否还在为宿舍的水龙头漏水、图书馆的灯光闪烁而烦恼&#xff1f;你是否还在为报修流程繁琐、等待时间长而焦虑&#xff1f;今天&#xff0c;这一切都将成为过去式&#xff01;因为一款震撼高校圈的新…

【QT5】<总览一> QT环境搭建、快捷键及编程规范

文章目录 前言 一、简单介绍QT 二、安装QT Creator 三、第一个QT项目 四、常用快捷键 五、QT中的编程规范 前言 在嵌入式Linux应用层开发时&#xff0c;经常使用QT作为图形化界面显示工具。为学习Linux下的QT编程&#xff0c;在Ubuntu和开发板中搭建QT开发环境&#xff…

TMS320F280049 ECAP模块--应用(2)

例1-上升沿触发 如下图所示&#xff0c;evt1-4设置为上升沿触发&#xff0c;在每个上升沿ctr值依次加载到cap1-4. 例2-上升下降沿触发 每个边沿都可选为事件&#xff0c;每次事件到来&#xff0c;依次把ctr加载到cap1-4。 例3-差异模式下上升沿触发 差异模式下每次事件到来时…

Qt_C++ RFID网络读卡器Socket Udp通讯示例源码

本示例使用的设备&#xff1a; WIFI/TCP/UDP/HTTP协议RFID液显网络读卡器可二次开发语音播报POE-淘宝网 (taobao.com) #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QHostInfo> #include <QNetworkInterface> #include <…

PyQt5串口测试工具

笔者经常会遇到使用上位机进行相关测试的场景&#xff0c;但现成的上位机并不能完全满足自己的需求&#xff0c;或是上位机缺乏使用说明。所以&#xff0c;自己写&#xff1f; 环境说明 pycharm 2023.2.25 python 3.10 anaconda 环境配置 conda create -n envsram ##…

学生信息管理系统C++

设计目的 使学生进一步理解和掌握课堂上所学的面向对象C编程知识&#xff0c;巩固和加深学生对C面向对象课程的基本知识的理解和掌握。掌握C面向对象编程和程序调试的基本技能&#xff0c;学会利用C语言进行基本的软件设计&#xff0c;着重提高运用C面向对象语言解决实际问题的…

Go Modules 使用

文章参考https://blog.csdn.net/wohu1104/article/details/110505489 不使用Go Modules&#xff0c;所有的依赖包都是存放在 GOPATH /pkg下&#xff0c;没有版本控制。如果 package 没有做到完全的向前兼容&#xff0c;会导致多个项目无法运行(包版本需求不同)。 于是推出了g…

秋招突击——第四弹——Java的SSN框架快速入门——Spring(2)

文章目录 前言其他Spring加载properties 容器创建容器获取beanBeanFactory容器总结 注解注解开发对定义bean纯注解开发Bean管理Bean作用范围Bean生命周期 注解开发依赖注入第三方bean管理第三方bean管理第三方bean注入 注解开发总结 Spring整合整合mybatis整合Junit AOPAOP核心…

【C、C++编译工具】CLion工具介绍与安装

一、问题 最近突发奇想想学学最开始接触的语言C&#xff0c;之前大学的时候用的更多的工具还是VC&#xff0c;工作后慢慢接触了CLion&#xff0c;跟pycharm其实差不多&#xff0c;都是集成开发环境&#xff08;IDE&#xff09; 解释&#xff1a;什么是 IDE&#xff1f; 根据计…