笔记小结:现代卷积神经网络之批量归一化

本文为李沐老师《动手学深度学习》笔记小结,用于个人复习并记录学习历程,适用于初学者

训练深层神经网络是十分困难的,特别是在较短的时间内使他们收敛更加棘手。 本节将介绍批量规范化(batch normalization),这是一种流行且有效的技术,可持续加速深层网络的收敛速度。

从零开始实现

张量的批量规范化函数
import torch
from torch import nndef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data
批量规范化层
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y
使用批量规范化层作用于LeNet

批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))
训练
准备工作

和之前多篇文章中提到的一样,不再赘述,只给出代码

from IPython import display
import torchvision
from torch.utils import data
from torchvision import transforms
import matplotlib.pyplot as pltdef load_data_fashion_mnist(batch_size, resize=None): """下载Fashion-MNIST数据集,然后将其加载到内存中"""trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize))trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=0)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=0)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))def get_dataloader_workers():  """使用4个进程来读取数据"""return 4batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = load_data_fashion_mnist(batch_size)def accuracy(y_hat, y):  #@save"""计算预测正确的数量"""if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1) #找出输入张量(tensor)中最大值的索引cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())
class Accumulator:  #@save"""在n个变量上累加"""def __init__(self, n):self.data = [0.0] * ndef add(self, *args):self.data = [a + float(b) for a, b in zip(self.data, args)]def reset(self):self.data = [0.0] * len(self.data)def __getitem__(self, idx):return self.data[idx]import matplotlib.pyplot as plt
from matplotlib_inline import backend_inlinedef use_svg_display(): """使⽤svg格式在Jupyter中显⽰绘图"""backend_inline.set_matplotlib_formats('svg')def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):"""设置matplotlib的轴"""axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()class Animator:  #@save"""在动画中绘制数据"""def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,figsize=(3.5, 2.5)):# 增量地绘制多条线if legend is None:legend = []use_svg_display()self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)if nrows * ncols == 1:self.axes = [self.axes, ]# 使用lambda函数捕获参数self.config_axes = lambda: set_axes(self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)self.X, self.Y, self.fmts = None, None, fmtsdef add(self, x, y):# 向图表中添加多个数据点if not hasattr(y, "__len__"):y = [y]n = len(y)if not hasattr(x, "__len__"):x = [x] * nif not self.X:self.X = [[] for _ in range(n)]if not self.Y:self.Y = [[] for _ in range(n)]for i, (a, b) in enumerate(zip(x, y)):if a is not None and b is not None:self.X[i].append(a)self.Y[i].append(b)self.axes[0].cla()for x, y, fmt in zip(self.X, self.Y, self.fmts):self.axes[0].plot(x, y, fmt)self.config_axes()display.display(self.fig)display.clear_output(wait=True)def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel())return metric[0] / metric[1]import time
class Timer:  #@save"""记录多次运行时间"""def __init__(self):self.times = []self.start()def start(self):"""启动计时器"""self.tik = time.time()def stop(self):"""停止计时器并将时间记录在列表中"""self.times.append(time.time() - self.tik)return self.times[-1]def avg(self):"""返回平均时间"""return sum(self.times) / len(self.times)def sum(self):"""返回时间总和"""return sum(self.times)def cumsum(self):"""返回累计时间"""return np.array(self.times).cumsum().tolist()def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')def try_gpu(i=0):  #@save"""如果存在,则返回gpu(i),否则返回cpu()"""if torch.cuda.device_count() >= i + 1:return torch.device(f'cuda:{i}')return torch.device('cpu')
训练

和以前一样,我们将在Fashion-MNIST数据集上训练网络。 这个代码与我们第一次训练LeNet时几乎完全相同,主要区别在于学习率大得多。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()
print(end - begin)

这个结果,对比当时不用批量归一化层的LeNet,训练的收敛速度快了许多,loss变小了,train acc提高了许多,但是test acc没有提高太多,出现了过拟合。 

简洁实现

除了使用我们刚刚定义的BatchNorm,我们也可以直接使用深度学习框架中定义的BatchNorm。 该代码看起来几乎与我们上面的代码相同。

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))

下面,我们使用相同超参数来训练模型。 请注意,通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。

begin = time.time()
train_ch6(net, train_iter, test_iter, num_epochs, lr, try_gpu())
end = time.time()

从结果可以看到,运行速度快了,并且过拟合也小了许多。 

小结

  • 在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。
  • 批量规范化在全连接层和卷积层的使用略有不同。
  • 批量规范化层和暂退层一样,在训练模式和预测模式下计算不同。
  • 批量规范化有许多有益的副作用,主要是正则化。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/49363.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Redis-10大数据类型理解与测试

Redis10大数据类型 我要打10个1.redis字符串(String)2.redis列表(List)3.redis哈希表(Hash)4.redis集合(Set)5.redis有序集合(ZSet)6.redis地理空间(GEO)7.redis基数统计(HyperLogLog)8.redis位图(bitmap)9.redis位域(bitfield)10.redis流(Stream) 官网地址Redis 键(key)常用案…

鸿蒙界面开发

界面开发 //构建 → 界面 build() {//行Row(){//列Column(){//文本 函数名(参数) 对象.方法名(参数) 枚举名.变量名Text(this.message).fontSize(40)//设置文本大小.fontWeight(FontWeight.Bold)//设置文本粗细.fontColor(#ff2152)//设置文本颜色}.widt…

MMROTATE的混淆矩阵confusion matrix生成

mmdetection中加入了混淆矩阵生成并可视化的功能,具体的代码在tools/analysis_tools/confusion_matrix.py。 mmrotate由于主流遥感数据集中的DOTA数据集标注格式问题,做了一些修改,所以我们如果是做遥感图像检测的Dota数据集的混淆矩阵&…

安装CUDA Cudnn Pytorch(GPU版本)步骤

一.先看自己的电脑NVIDIA 支持CUDA版本是多少? 1.打开NVIDIA控制面板 2.点击帮助---系统信息--组件 我的支持CUDA11.6 二.再看支持Pytorch的CUDA版本 三.打开CUDA官网 下载CUDA 11.6 下载好后,安装 选择 自定义 然后安装位置 (先去F盘…

【ffmpeg命令入门】ffplay常用命令

文章目录 前言ffplay的简介FFplay 的基本用法常用参数及其作用示例 效果演示图播放普通视频播放网络媒体流RTSP 总结 前言 FFplay 是 FFmpeg 套件中的一个强大的媒体播放器,它基于命令行接口,允许用户以灵活且高效的方式播放音频和视频文件。作为一个简…

uniapp原生插件开发实战——iOS打开文件到自己的app

用原生开发获取文件的名称、路径等能力封装为一个插件包供前端使用 首先根据ios插件开发教程,创建一个插件工程,template 选framework 开始编写代码: iOS 9 及以下版本会调用以下方法: - (BOOL)application:(UIApplication *_N…

【数据分析详细教学】全球气温变迁:一个多世纪的数据分析

全球气温变迁:一个多世纪的数据分析 1. 数据集选择与获取 数据可以从NASA的GISTEMP数据集获取,通常提供的格式有TXT和CSV。我们假设数据是以CSV格式提供。 2. 数据预处理 使用Python的pandas库读取数据并进行预处理。 import pandas as pd# 加载数…

C#知识|账号管理系统:修改登录密码界面的UI设计

哈喽,你好啊!我是雷工! 本节记录添加修改登录密码界面的过程,以下为练习笔记。 01 效果演示 演示跳转打开修改登录密码子窗体效果: 02 添加窗体 在UI层添加一个Windows窗体,命名为:FrmModifyPwd.cs; 03 设置窗体属性 按照下表的内容设置窗体的相关属性: 设置属性 …

物联网架构之Hadoop

一:系统环境设置(所有节点都设置) 1:关闭selinux和防火墙 setenforce 0 sed -i /^SELINUX/s/enforcing/disabled/ /etc/selinux/config systemctl stop firewalld systemctl disable firewalld 2:为各个节点设置主机名…

mysql面试(二)

前言 这是mysql面试基础的第二节,主要是了解一下mysql数据更新的基本流程,还有三大日志的作用。但是具体的比如undolog是如何应用在mvcc机制中的,由于篇幅问题就放在下一在章节 数据更新流程 上面是说了更新真正数据之前的大致流程&#x…

requets库传data和传json的区别

传data和传json的qubie 被测对象,白月黑羽系统 系统下载地址: https://www.byhy.net/prac/pub/info/bysms/ 测试用例下载地址: https://cdn2.byhy.net/files/selenium/testcases.xlsx 一、传data import json import requests import pytes…

7、Qt5开发及实列(笔记3-系统操作)

说明&#xff1a;此示例包含了基本的常使用的系统操作 效果如下: mainwindos.cpp #pragma execution_character_set("utf-8") #include "mainwindow.h"#include <QDesktopWidget> #include <QApplication> #include <QHostInfo> #in…

docker基础镜像

一、配置 docker 本地源 [docker-ce-stable] nameDocker CE Stable baseurlhttp://10.35.186.181/docker-ce-stable/ enabled1 gpgcheck0 配置阿里云Docker Yum源 yum install -y yum-utils device-mapper-persistent-data lvm2 git yum-config-manager --add-repo http://mirr…

Windows安装Visual Studio2022 + QT5.15开发环境

最近&#xff0c;把系统换成了Windows11&#xff0c;想重新安装QT5.12&#xff0c;结果发现下载不了离线安装包。 最后索性安装QT5.15了&#xff0c;特此记录下。 预祝大家&#xff1a;不论是何时安装&#xff0c;都可以安装到指定版本的QT。 一、VS2022安装 VS2022官网下…

ubuntu 22.04 安装部署gitlab详细过程

目录 gitlab介绍 gitlab安装 步骤1&#xff1a;更新系统 步骤2&#xff1a;添加 GitLab 的 GPG 密钥 gitlab企业版 gitlab社区版 步骤3&#xff1a;安装 GitLab 社区版 社区版 步骤4&#xff1a;初始化 GitLab 步骤5&#xff1a;访问 GitLab 步骤6&#xff1a;查看r…

C++ - 基于多设计模式下的同步异步⽇志系统

1.项目介绍 项⽬介绍 本项⽬主要实现⼀个⽇志系统&#xff0c; 其主要⽀持以下功能: • ⽀持多级别⽇志消息 • ⽀持同步⽇志和异步⽇志 • ⽀持可靠写⼊⽇志到控制台、⽂件以及滚动⽂件中 • ⽀持多线程程序并发写⽇志 • ⽀持扩展不同的⽇志落地⽬标地 2.开发环境 • Cent…

AI学习记录 - 导数在神经网络训练中的作用(自己画的图,很丑不要介意!)

导数的作用 我们去调整神经网络的权重&#xff0c;一般不会手动去调整&#xff0c;如果只有很少的神经元&#xff0c;人工调整确实可以实现&#xff0c;当我们有几十层&#xff0c;一层几百上千个神经元的时候&#xff0c;人工调整就不可能了。 一个权重的调整涉及到两个问题&…

TCP的FIN报文可否携带数据

问题发现&#xff1a; 发现FTP-DATA数据传输完&#xff0c;TCP的挥手似乎只有两次 实际发现FTP-DATA报文中&#xff0c;TCP层flags中携带了FIN标志 piggyback FIN 问题转化为 TCP packet中如果有FIN flag&#xff0c;该报文还能携带data数据么&#xff1f; 答案是肯定的 RFC7…

【LeetCode:3098. 求出所有子序列的能量和 + 记忆化缓存】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第四十七章 字符设备和杂项设备总结回顾

i.MX8MM处理器采用了先进的14LPCFinFET工艺&#xff0c;提供更快的速度和更高的电源效率;四核Cortex-A53&#xff0c;单核Cortex-M4&#xff0c;多达五个内核 &#xff0c;主频高达1.8GHz&#xff0c;2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…