笔记小结:现代卷积神经网络之批量归一化

本文为李沐老师《动手学深度学习》笔记小结,用于个人复习并记录学习历程,适用于初学者

训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。 本节将介绍批量规范化(batch normalization),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。

从零开始实现

张量的批量规范化函数
import torch
from torch import nndef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data
批量规范化层
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y
使用批量规范化层作用于LeNet

批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))
训练
准备工作

和之前多篇文章中提到的一样,不再赘述,只给出代码

from IPython import display
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as pltdef load_data_fashion_mnist(batch_size, resize=None): """下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=0)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=0)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))def get_dataloader_workers():  """使用4个进程来读取数据"""return 4batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1) #找出输入张量(tensor)中最大值的索引cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]import matplotlib.pyplot as plt
from matplotlib_inline import backend_inlinedef use_svg_display(): """使⽤svg格式在Jupyter中显⽰绘图"""backend_inline.set_matplotlib_formats('svg')def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置matplotlib的轴"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()class Animator:  #@save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []use_svg_display()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]import time
class Timer:  #@save"""记录多次运行时间"""def __init__(self):self.times = []self.start()def start(self):"""启动计时器"""self.tik = time.time()def stop(self):"""停止计时器并将时间记录在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均时间"""return sum(self.times) / len(self.times)def sum(self):"""返回时间总和"""return sum(self.times)def cumsum(self):"""返回累计时间"""return np.array(self.times).cumsum().tolist()def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')def try_gpu(i=0):  #@save"""如果存在,则返回gpu(i),否则返回cpu()"""if torch.cuda.device_count() >= i + 1:return torch.device(f'cuda:{i}')return torch.device('cpu')
训练

和以前一样,我们将在Fashion-MNIST数据集上训练网络。 这个代码与我们第一次训练LeNet时几乎完全相同,主要区别在于学习率大得多。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()
print(end - begin)

这个结果,对比当时不用批量归一化层的LeNet,训练的收敛速度快了许多,loss变小了,train acc提高了许多,但是test acc没有提高太多,出现了过拟合。 

简洁实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))

下面,我们使用相同超参数来训练模型。 请注意,通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()

从结果可以看到,运行速度快了,并且过拟合也小了许多。 

小结

  • 在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量规范化在全连接层和卷积层的使用略有不同。
  • 批量规范化层和暂退层一样,在训练模式和预测模式下计算不同。
  • 批量规范化有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/49363.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

React中引入使用本地图片

1、css样式中,可以写成如下: .login {// 映射路径background: url(images/img-login.png);background-size: 100% 100%; } .box{// 相对路径background: url(../assets/images/box.png);background-size: 100% 100%; }2、在jsx文…

Redis-10大数据类型理解与测试

Redis10大数据类型 我要打10个1.redis字符串(String)2.redis列表(List)3.redis哈希表(Hash)4.redis集合(Set)5.redis有序集合(ZSet)6.redis地理空间(GEO)7.redis基数统计(HyperLogLog)8.redis位图(bitmap)9.redis位域(bitfield)10.redis流(Stream) 官网地址Redis 键(key)常用案…

Odoo Registry 源码解读:前端世界的魔法师

亲爱的Odoo探险家们,准备好踏上一段奇妙的代码冒险了吗?今天,我们要深入探索Odoo前端世界的一位神秘大师——Registry。它可能不像那些花哨的UI组件那样引人注目,但要知道,真正的高手都是低调的。Registry就像是Odoo世界里的甘道夫,默默地用魔法维持着整个中土世界的平衡…

卷与nfs实现多台主机容器之间的数据共享

容器存放数据到宿主机里 数据持久化:永久保存 数据共享:容器和容器 容器和宿主机 就是宿主机里的文件夹 /var/lib/docker/volumes docker 容器:容器运行起来是一个进程,进程的数据默认都在内存里,内存里的数据停电…

鸿蒙界面开发

界面开发 //构建 → 界面 build() {//行Row(){//列Column(){//文本 函数名(参数) 对象.方法名(参数) 枚举名.变量名Text(this.message).fontSize(40)//设置文本大小.fontWeight(FontWeight.Bold)//设置文本粗细.fontColor(#ff2152)//设置文本颜色}.widt…

邮件流量分析

邮件流量分析是指对邮件系统中进出的邮件数量、类型、流向以及相关性能指标进行统计和评估的过程。这种分析有助于优化邮件系统的性能、检测异常活动、防止垃圾邮件和网络攻击,以及改善邮件营销策略。邮件流量分析可以分为几个主要方面: 邮件统计&#x…

MMROTATE的混淆矩阵confusion matrix生成

mmdetection中加入了混淆矩阵生成并可视化的功能,具体的代码在tools/analysis_tools/confusion_matrix.py。 mmrotate由于主流遥感数据集中的DOTA数据集标注格式问题,做了一些修改,所以我们如果是做遥感图像检测的Dota数据集的混淆矩阵&…

安装CUDA Cudnn Pytorch(GPU版本)步骤

一.先看自己的电脑NVIDIA 支持CUDA版本是多少? 1.打开NVIDIA控制面板 2.点击帮助---系统信息--组件 我的支持CUDA11.6 二.再看支持Pytorch的CUDA版本 三.打开CUDA官网 下载CUDA 11.6 下载好后,安装 选择 自定义 然后安装位置 (先去F盘…

Python - conda使用大全

如何使用Conda? 环境 创建环境 conda create -n spider_env python3.10.11查看环境 conda env listconda info -e激活环境 conda activate spider_env退出环境 conda deactivate删除环境 conda env remove -n spider_env包 导出包 说明:导出当前虚拟…

孙健提到的实验室的研究方向之一是什么?()

孙健提到的实验室的研究方向之一是什么?() 点击查看答案 A.虚拟现实B.环境感知和理解 C.智能体博弈D.所有选项都正确 图灵奖是在哪一年设立的?() A.1962B.1966 C.1976D.1986 孙健代表的实验室的前身主要研究什么?&…

【ffmpeg命令入门】ffplay常用命令

文章目录 前言ffplay的简介FFplay 的基本用法常用参数及其作用示例 效果演示图播放普通视频播放网络媒体流RTSP 总结 前言 FFplay 是 FFmpeg 套件中的一个强大的媒体播放器,它基于命令行接口,允许用户以灵活且高效的方式播放音频和视频文件。作为一个简…

英语中英镑与美元表达方式的异同点(英镑£(GBP)和美元$(USD))便士p(pence),美分¢(cents)

文章目录 英语中英镑与美元表达方式的异同点分析货币符号与位置货币符号位置规则 数字格式与分隔数字表示小数点使用 小数单位英镑的便士美元的分 示例分析(😵)示例1:高额交易示例2:零售价格 总结 英语中英镑与美元表达…

自定义QDialog使用详解

自定义QDialog使用详解 一、创建 QDialog 对象二、QDialog设置布局三、QDialog控制模态行为3.1 模态和非模态区别3.2 QDialog的模态使用四、使用 QDialogButtonBox五、处理对话框的结果六、使用 QDialog 的信号和槽QDialog是Qt框架中用于创建对话框窗口的基本类。对话框窗口通常…

uniapp原生插件开发实战——iOS打开文件到自己的app

用原生开发获取文件的名称、路径等能力封装为一个插件包供前端使用 首先根据ios插件开发教程,创建一个插件工程,template 选framework 开始编写代码: iOS 9 及以下版本会调用以下方法: - (BOOL)application:(UIApplication *_N…

【数据分析详细教学】全球气温变迁:一个多世纪的数据分析

全球气温变迁:一个多世纪的数据分析 1. 数据集选择与获取 数据可以从NASA的GISTEMP数据集获取,通常提供的格式有TXT和CSV。我们假设数据是以CSV格式提供。 2. 数据预处理 使用Python的pandas库读取数据并进行预处理。 import pandas as pd# 加载数…

C#知识|账号管理系统:修改登录密码界面的UI设计

哈喽,你好啊!我是雷工! 本节记录添加修改登录密码界面的过程,以下为练习笔记。 01 效果演示 演示跳转打开修改登录密码子窗体效果: 02 添加窗体 在UI层添加一个Windows窗体,命名为:FrmModifyPwd.cs; 03 设置窗体属性 按照下表的内容设置窗体的相关属性: 设置属性 …

物联网架构之Hadoop

一:系统环境设置(所有节点都设置) 1:关闭selinux和防火墙 setenforce 0 sed -i /^SELINUX/s/enforcing/disabled/ /etc/selinux/config systemctl stop firewalld systemctl disable firewalld 2:为各个节点设置主机名…

编译期计算

关于编译期计算&#xff0c;直接能够想到的应用是决定是否启用某个模板&#xff0c;或者多个模板之间做选择。但如果有足够多的信息&#xff0c;编译器甚至可以计算控制流的结果。 模板元编程 模板元编程的简单例子&#xff0c;如下&#xff1a; #include <iostream>te…

vue練習--prop

<!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>Prop練習</title> <!-- 开发环境版本…

mysql面试(二)

前言 这是mysql面试基础的第二节&#xff0c;主要是了解一下mysql数据更新的基本流程&#xff0c;还有三大日志的作用。但是具体的比如undolog是如何应用在mvcc机制中的&#xff0c;由于篇幅问题就放在下一在章节 数据更新流程 上面是说了更新真正数据之前的大致流程&#x…