深度学习——批标准化Batch Normalization

什么是批标准化?

批标准化(Batch Normalization)是深度学习中常用的一种技术,旨在加速神经网络的训练过程并提高模型的收敛速度。
批标准化通过在神经网络的每一层中对输入数据进行标准化来实现。具体而言,对于每个输入样本,在每一层的前向传播过程中,都会计算其均值和方差,并使用批量内的均值和方差对输入进行标准化。标准化后的数据会经过缩放和平移操作,使得网络可以学习到适合当前任务的特定数据分布。这样做的好处包括:
1.收敛速度更快:批标准化有助于避免梯度消失和梯度爆炸问题,使得神经网络在训练过程中更快地收敛。
2.允许更高的学习率:标准化输入可以使学习率的选择更加宽松,使得学习过程更加稳定。
3.正则化作用:批标准化在一定程度上具有正则化的效果,有助于防止过拟合。
4.不那么依赖初始化:由于标准化的存在,对网络的初始权重设置并不像传统网络那样敏感,这简化了网络的初始化过程。

对比使用批标准化和不使用批标准化

import torch
from torch import nn
from torch.nn import init
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np# 用于可复现
# torch.manual_seed(1)    # reproducible
# np.random.seed(1)# Hyper parameters
# 样本点
N_SAMPLES = 2000
# 批大小
BATCH_SIZE = 64
# 轮次
EPOCH = 12
# 学习率
LR = 0.03
# 隐藏层层数
N_HIDDEN = 8
# 激活函数
ACTIVATION = torch.tanh
B_INIT = -0.2   # use a bad bias constant initializer# training data
# 生成-7到10之间的N_SAMPLES个值的等差数列,并将其转化为一个二维列向量
x = np.linspace(-7, 10, N_SAMPLES)[:, np.newaxis]
# 生成一个均值为0,标准差为2的和x相同形状的噪声数据
noise = np.random.normal(0, 2, x.shape)
# 生成x对应的y值
y = np.square(x) - 5 + noise# test data
test_x = np.linspace(-7, 10, 200)[:, np.newaxis]
noise = np.random.normal(0, 2, test_x.shape)
test_y = np.square(test_x) - 5 + noisetrain_x = torch.from_numpy(x).float()
train_y = torch.from_numpy(y).float()
test_x = torch.from_numpy(test_x).float()
test_y = torch.from_numpy(test_y).float()train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,)# show data
# plt.scatter(train_x.numpy(), train_y.numpy(), c='#FF9359', s=50, alpha=0.2, label='train')
# plt.scatter(test_x.numpy(), test_y.numpy(), c='blue', s=50, alpha=0.2, label='test')
# plt.legend(loc='best')class Net(nn.Module):def __init__(self, batch_normalization=False):super(Net, self).__init__()# 是否进行批标准化self.do_bn = batch_normalization# 全连接层的列表self.fcs = []# 批标准化层的列表self.bns = []self.bn_input = nn.BatchNorm1d(1, momentum=0.5)   # for input datafor i in range(N_HIDDEN):               # build hidden layers and BN layers# 如果是第一层,输入神经元个数为1,其余为10个input_size = 1 if i == 0 else 10# 全连接层fc = nn.Linear(input_size, 10)# 将全连接层重新命名然后设置为类属性setattr(self, 'fc%i' % i, fc)       # IMPORTANT set layer to the Module# 对全连接层的参数进行初始化self._set_init(fc)                  # parameters initialization# 添加到列表中self.fcs.append(fc)if self.do_bn:bn = nn.BatchNorm1d(10, momentum=0.5)setattr(self, 'bn%i' % i, bn)   # IMPORTANT set layer to the Moduleself.bns.append(bn)self.predict = nn.Linear(10, 1)         # output layerself._set_init(self.predict)            # parameters initializationdef _set_init(self, layer):init.normal_(layer.weight, mean=0., std=.1)init.constant_(layer.bias, B_INIT)# 前向传播def forward(self, x):pre_activation = [x]if self.do_bn:x = self.bn_input(x)     # input batch normalizationlayer_input = [x]for i in range(N_HIDDEN):x = self.fcs[i](x)pre_activation.append(x)if self.do_bn: x = self.bns[i](x)   # batch normalizationx = ACTIVATION(x)layer_input.append(x)out = self.predict(x)# 返回预测值、每个隐藏层的输入、激活函数的输出return out, layer_input, pre_activationnets = [Net(batch_normalization=False), Net(batch_normalization=True)]# print(*nets)    # print net architecture# 优化器
opts = [torch.optim.Adam(net.parameters(), lr=LR) for net in nets]
# MSE作为损失函数
loss_func = torch.nn.MSELoss()def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):for i, (ax_pa, ax_pa_bn, ax, ax_bn) in enumerate(zip(axs[0, :], axs[1, :], axs[2, :], axs[3, :])):[a.clear() for a in [ax_pa, ax_pa_bn, ax, ax_bn]]if i == 0:p_range = (-7, 10);the_range = (-7, 10)else:p_range = (-4, 4);the_range = (-1, 1)ax_pa.set_title('L' + str(i))ax_pa.hist(pre_ac[i].data.numpy().ravel(), bins=10, range=p_range, color='#FF9359', alpha=0.5);ax_pa_bn.hist(pre_ac_bn[i].data.numpy().ravel(), bins=10, range=p_range, color='#74BCFF', alpha=0.5)ax.hist(l_in[i].data.numpy().ravel(), bins=10, range=the_range, color='#FF9359');ax_bn.hist(l_in_bn[i].data.numpy().ravel(), bins=10, range=the_range, color='#74BCFF')for a in [ax_pa, ax, ax_pa_bn, ax_bn]: a.set_yticks(());a.set_xticks(())ax_pa_bn.set_xticks(p_range);ax_bn.set_xticks(the_range)axs[0, 0].set_ylabel('PreAct');axs[1, 0].set_ylabel('BN PreAct');axs[2, 0].set_ylabel('Act');axs[3, 0].set_ylabel('BN Act')plt.pause(0.01)if __name__ == "__main__":f, axs = plt.subplots(4, N_HIDDEN + 1, figsize=(10, 5))# 开启动态绘制plt.ion()  # something about plottingplt.show()# traininglosses = [[], []]  # recode loss for two networksfor epoch in range(EPOCH):print('Epoch: ', epoch)layer_inputs, pre_acts = [], []# 训练两个网络for net, l in zip(nets, losses):net.eval()              # set eval mode to fix moving_mean and moving_varpred, layer_input, pre_act = net(test_x)l.append(loss_func(pred, test_y).data.item())layer_inputs.append(layer_input)pre_acts.append(pre_act)net.train()             # free moving_mean and moving_varplot_histogram(*layer_inputs, *pre_acts)     # plot histogramfor step, (b_x, b_y) in enumerate(train_loader):for net, opt in zip(nets, opts):     # train for each network# 获取到预测值pred, _, _ = net(b_x)# 计算lossloss = loss_func(pred, b_y)# 梯度清零opt.zero_grad()# 误差反向传播loss.backward()# 逐步优化网络参数opt.step()    # it will also learns the parameters in Batch Normalization# 关闭动态绘制plt.ioff()# plot training loss# 绘制loss图plt.figure(2)plt.plot(losses[0], c='#FF9359', lw=3, label='Original')plt.plot(losses[1], c='#74BCFF', lw=3, label='Batch Normalization')plt.xlabel('step')plt.ylabel('test loss')plt.ylim((0, 2000))plt.legend(loc='best')# evaluation# set net to eval mode to freeze the parameters in batch normalization layers[net.eval() for net in nets]    # set eval mode to fix moving_mean and moving_varpreds = [net(test_x)[0] for net in nets]plt.figure(3)# 测试拟合效果plt.plot(test_x.data.numpy(), preds[0].data.numpy(), c='#FF9359', lw=4, label='Original')plt.plot(test_x.data.numpy(), preds[1].data.numpy(), c='#74BCFF', lw=4, label='Batch Normalization')plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='r', s=50, alpha=0.2, label='train')plt.legend(loc='best')plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/6839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux基本指令操作

登陆指令(云服务器版) 当我们获取公网IP地址后,我们就可以打开xshell。 此时会有这样的界面,我们若是想的登陆,则需要输入以下的指令 ssh 用户名公网IP地址 然后会跳出以下的窗口 接着输入密码——密码便是先前定好…

微服务安全简介

​由于其可扩展性、灵活性和敏捷性,微服务架构已经变得越来越受欢迎。然而,随着这种架构的分布和复杂性增加,确保强大的安全措施变得至关重要。微服务的安全性超越了传统的方法,需要采用全面的策略来保护免受不断演变的威胁和漏洞…

Promise 讲解,js知识,es6

文章目录 一、Promise的三种状态1. 初始态pending2. 成功态fulfilled,调用resolve方法3. 失败态rejected,调用reject方法 二、Promise的方法then方法catch方法 三、async和awaitasync 函数await 表达式 四、代码举例帮助理解1、Promise的值通过then方法获…

在vsCode 中执行Electron 项目时,出现中文乱码问题

问题:vscode 中执行Electron 项目时,控制台出现乱码 解决方法: 在 terminal 修改编码格式:65001代表UTF-8,936代表GBK

IC设计从业者必备的宝藏网站!

对于IC设计从业者而言,获取准确的学习资源,行业资讯直观重要,今日我们推荐ic行业专业的宝藏网站,希望对从业者有所帮助。 01-找开源项目的网站 GitHub除了Git代码仓库托管及基本的 Web管理界面以外,还提供了订阅、讨论…

用 Generative AI 构建企业专属的用户助手机器人

原文来源: https://tidb.net/blog/a9cdb8ec 关于作者:李粒,PingCAP PM TL;DR 本文介绍了如何用 Generative AI 构建一个使用企业专属知识库的用户助手机器人。除了使用业界常用的基于知识库的回答方法外,还尝试使用模型在 fe…

JAVA面试总结-Redis篇章(三)——缓存雪崩

JAVA面试总结-Redis篇章(三)——缓存雪崩

go性能分析工具之trace

参考文章: https://eddycjy.gitbook.io/golang/di-9-ke-gong-ju/go-tool-trace https://mp.weixin.qq.com/s__bizMzUxMDQxMDMyNg&mid2247484297&idx1&sn7a01fa4f454189fc3ccdb32a6e0d6897&scene21#wechat_redirect 你有没有考虑过,你的g…

0基础学习VR全景平台篇 第70篇:VR直播-如何设置付费观看、试看

对于拥有优质内容的VR直播,可以通过付费观看的方式进行内容变现,是当下非常流行的商业模式。 付费价格>0时便会自动弹出“试看时间”的设置项。试看时间=0秒时,用户进入直播间需要先付费才可观看;试看时间&…

Python中字符串拼接有哪些方法

目录 什么是字符串拼接 为什么要进行字符串拼接 Python中字符串拼接有哪些方法? 什么是字符串拼接 字符串拼接是将多个字符串连接在一起形成一个新的字符串的操作。在编程中,字符串拼接经常用于将不同的字符串组合在一起,以创建更长或更有…

勘探开发人工智能技术:地震层位解释

1 地震层位解释 层位解释是地震构造解释的重要内容,是根据目标层位的地震反射特征如振幅、相位、形态、连续性、特征组合等信息在地震数据体上进行追踪解释获得地震层位数据的方法。 1.1 地震信号、层位与断层 图1.1 所示为地震信号采集的过程,地面炮…

图像处理之canny边缘检测(非极大值抑制和高低阈值)

Canny 边缘检测方法 Canny算子是John F.Canny 大佬在1986年在其发表的论文 《Canny J. A computational approach to edge detection [J]. IEEE Transactions on Pattern Analysis and Machine Intelligence, 1986 (6): 679-698.》提出来的。 检测目标: 低错误率…

学好Elasticsearch系列-Mapping

本文已收录至Github,推荐阅读 👉 Java随想录 文章目录 Mapping 的基本概念查看索引 Mapping 字段数据类型数字类型基本数据类型Keywords 类型Dates(时间类型)对象类型空间数据类型文档排名类型文本搜索类型 两种映射类型自动映射&…

动手学DL——深度学习预备知识随笔【深度学习】【PyTorch】

文章目录 2、预备知识2.1、数据操作2.2、线性代数&矩阵计算2.3、导数2.4、基础优化方法 2、预备知识 2.1、数据操作 batch:以图片数据为例,一次读入的图片数量。 小批量样本可以充分利用GPU进行并行计算提高计算效率。 数据访问 数组:np…

Android 实现阅读用户协议的文字控件效果

开发中&#xff0c;经常要用到一些阅读隐私协议的场景&#xff0c;原生的textview控件很难做到在一个控件里有两个点击事件&#xff0c;那现在就来安利一个强大的组件——SpannableStringBuilder。 先看看效果&#xff1a; 直接上代码&#xff0c;布局文件&#xff1a; <Li…

【图像处理】使用自动编码器进行图像降噪(改进版)

阿里雷扎凯沙瓦尔兹 一、说明 自动编码器是一种学习压缩和重建输入数据的神经网络。它由一个将数据压缩为低维表示的编码器和一个从压缩表示中重建原始数据的解码器组成。该模型使用无监督学习进行训练&#xff0c;旨在最小化输入和重建输出之间的差异。自动编码器可用于降维、…

【iOS】动态链接器dyld

参考&#xff1a;认识 dyld &#xff1a;动态链接器 dyld简介 dyld&#xff08;Dynamic Linker&#xff09;是 macOS 和 iOS 系统中的动态链接器&#xff0c;它是负责在运行时加载和链接动态共享库&#xff08;dylib&#xff09;或可执行文件的组件。在 macOS 系统中&#xf…

STM32MP157驱动开发——按键驱动(定时器)

“定时器 ”机制&#xff1a; 内核函数 定时器涉及函数参考内核源码&#xff1a;include\linux\timer.h 给定时器的各个参数赋值&#xff1a; setup_timer(struct timer_list * timer, void (*function)(unsigned long),unsigned long data)&#xff1a;设置定时器&#xf…

多元函数的概念

目录 多元函数的极限&#xff1a; 例题1&#xff1a; 例题2&#xff1a; 多元函数的连续性 连续函数的性质 偏导数 高阶偏导数 定理1&#xff1a; 全微分 可微的必要条件 用定义来判断是否可微 可微的充分条件 连续偏导可微的关系 多元函数的极限&#xff1a; 对于一个二元…

macOS Ventura 13.5 (22G74) 正式版发布,ISO、IPSW、PKG 下载

macOS Ventura 13.5 (22G74) 正式版发布&#xff0c;ISO、IPSW、PKG 下载 本站下载的 macOS Ventura 软件包&#xff0c;既可以拖拽到 Applications&#xff08;应用程序&#xff09;下直接安装&#xff0c;也可以制作启动 U 盘安装&#xff0c;或者在虚拟机中启动安装。另外也…