动手学深度学习(Pytorch版)代码实践 -卷积神经网络-28批量规范化

28批量规范化

"""可持续加速深层网络的收敛速度"""
import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as pltdef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):"""实现一个具有张量的批量规范化层。"""# 如果启用了梯度计算,torch.is_grad_enabled() 返回 True;否则返回 False。if not torch.is_grad_enabled():# torch.no_grad() 是一个上下文管理器,用于临时禁用梯度计算# torch.enable_grad() 是一个上下文管理器,用于在禁用梯度计算的上下文中重新启用梯度计算。X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0) # 计算张量 X 沿着第 0 维的平均值# 维度 0 代表样本数量,即沿着每个特征计算所有样本的平均值。var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * var# gamma 和 beta 的更新是通过反向传播和优化器自动完成的Y = gamma * X_hat + beta # 缩放和移位return Y, moving_mean.data, moving_var.dataclass BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1# 经过归一化处理后的数据均值接近于零。因此,将滑动均值初始化为0,是对数据初始均值的一种合理假设。self.moving_mean = torch.zeros(shape)# 方差表示数据的离散程度。将滑动方差初始化为1,意味着假设数据的初始方差为1,# 即数据分布接近标准正态分布。这样初始化可以避免初始阶段的数值不稳定。self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在GPU上                              if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y#使用批量规范化层的 LeNet
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2BatchNorm(num_features=6, num_dims=4),nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5BatchNorm(num_features=16, num_dims=4),nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120BatchNorm(num_features=120, num_dims=2),nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84BatchNorm(num_features=84, num_dims=2),nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size)
# lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
# plt.show()# loss 0.200, train acc 0.925, test acc 0.812
# 34957.3 examples/sec on cuda:0# loss 0.189, train acc 0.928, test acc 0.894
# 33471.2 examples/sec on cuda:0#简明实现
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.ReLU(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.ReLU(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.ReLU(),nn.Linear(84, 10)
)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()# nn.Sigmoid()
# loss 0.263, train acc 0.902, test acc 0.833
# 46935.0 examples/sec on cuda:0# nn.ReLU()
# loss 0.224, train acc 0.914, test acc 0.874
# 44479.2 examples/sec on cuda:0
"""
通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。
"""

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/32571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Swift 中的动态数组

Swift 的 Array 类型是一种强大而灵活的集合类型,可以根据需要自动扩展或缩减其容量。 动态数组的基本概念 Swift 中的数组是基于动态数组(dynamic array)的概念实现的。动态数组能够根据需要自动调整其容量,以容纳新增的元素&a…

Benchmarking Panoptic Scene Graph Generation (PSG), ECCV‘22 场景图生成,利用PSG数据集

2080-ti显卡复现 源代码地址 Jingkang50/OpenPSG: Benchmarking Panoptic Scene Graph Generation (PSG), ECCV22 (github.com) 安装 pytorch 1.7版本 cuda10.1 按照readme的做法安装 我安装的过程如下图所示,这个截图是到了pip install openmim这一步 下一步 下一步 这一步…

C语言 | Leetcode C语言题解之第167题两数之和II-输入有序数组

题目&#xff1a; 题解&#xff1a; int* twoSum(int* numbers, int numbersSize, int target, int* returnSize) {int* ret (int*)malloc(sizeof(int) * 2);*returnSize 2;int low 0, high numbersSize - 1;while (low < high) {int sum numbers[low] numbers[high]…

如何设置MySQL远程访问权限?

MySQL是一种流行的关系型数据库管理系统&#xff0c;它广泛应用于各种Web应用程序和数据驱动的应用中。在默认情况下&#xff0c;MySQL只允许本地访问&#xff0c;为了能够从远程服务器或客户端访问MySQL数据库&#xff0c;我们需要进行一些额外的设置和配置。 安装和配置MySQ…

在阿里云使用Docker部署MySQL服务,并且通过IDEA进行连接

阿里云使用Docker部署MySQL服务&#xff0c;并且通过IDEA进行连接 这里演示如何使用阿里云来进行MySQL的部署&#xff0c;系统使用的是Linux系统 (Ubuntu)。 为什么使用Docker? 首先是因为它的可移植性可以在任何有Docker环境的系统上运行应用&#xff0c;避免了在不通操作系…

【html】用html+css实现银行的账户信息表格

我们先来看一看某银行的账户信息表格 我们自己也可以实现类似的效果 效果图: 大家可以看到&#xff0c;其实效果差不多 接下来看看我们实现的代码 源码&#xff1a; <!DOCTYPE html> <html lang"zh"><head><meta charset"UTF-8"&…

机械师硬盘数据清空怎么办?机械师硬盘数据清空怎么恢复

机械师硬盘数据清空怎么恢复&#xff1f;随着数字化时代的到来&#xff0c;数据已成为我们生活和工作中不可或缺的一部分。然而&#xff0c;硬盘数据的意外清空往往会给我们带来极大的困扰。本文将探讨在机械师硬盘数据清空后&#xff0c;我们应该如何快速有效地恢复数据。 图片…

二本毕业,我是如何成为BAT-安卓开发工程师?

1.对基础原理不断挖掘 进入公司&#xff0c;我的职位是Linux应用开发工程师&#xff0c;做App网络传输模块&#xff0c;本质上就是把本地的数据通过socket传输到服务端。用到的技术是C语言&#xff0c;网络编程&#xff0c;多线程编程。 那时是最痛苦的几个月&#xff0c;因为…

[FreeRTOS 功能应用] 互斥访问与回环队列 功能应用

文章目录 一、基础知识点二、代码讲解三、结果演示四、代码下载 一、基础知识点 [FreeRTOS 基础知识] 互斥访问与回环队列 概念 [FreeRTOS 内部实现] 互斥访问与回环队列 [FreeRTOS 内部实现] 创建任务 xTaskCreate函数解析 本实验是基于STM32F103开发移植FreeRTOS实时操作系…

解决WebStorm中不显示npm任务面板

鼠标右键项目的package.json文件&#xff0c;然后点击show npm scripts选项。 然后npm工具窗口就显示了&#xff1a;

02--MySQL数据库概述

目录 第10章 子查询 10.1 SELECT的SELECT中嵌套子查询 10.2 SELECT的WHERE或HAVING中嵌套子查询 10.3 SELECT中的EXISTS型子查询 10.4 SELECT的FROM中嵌套子查询 第11章 MySQL支持的数据类型 11.1 数值类型:包括整数和小数 1、整数类型 2、bit类型 3、小数类型 11.2…

【Python系列】探索 NumPy 中的 mean 函数:计算平均值的利器

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

软件串口接收子程序

代码; stduart.c /*《AVR专题精选》随书例程3.通信接口使用技巧项目&#xff1a;使用延时法实现半双工软件串口文件&#xff1a;sfuart.c说明&#xff1a;软件串口驱动文件作者&#xff1a;邵子扬时间&#xff1a;2012年12月13日*/ #include "sfduart.h"// 循环中延…

WinMerge v2 (开源的文件比较/合并工具)

前言 WinMerge 是一款运行于Windows系统下的免费开源的文件比较/合并工具&#xff0c;使用它可以非常方便地比较多个文档内容甚至是文件夹与文件夹之间的文件差异。适合程序员或者经常需要撰写文稿的朋友使用。 一、下载地址 下载链接&#xff1a;http://dygod/source 点击搜…

【Linux】进程间通信3——线程安全

1.Linux线程互斥 1.1.进程线程间的互斥相关背景概念 临界资源&#xff1a; 多线程执行流共享的资源叫做临界资源。临界区&#xff1a; 每个线程内部&#xff0c;访问临界资源的代码&#xff0c;就叫做临界区。互斥&#xff1a; 任何时刻&#xff0c;互斥保证有且只有一个执行…

动态创建接口地址

和SpringBoot版本有关系 这里用的boot 2.2.2

LabVIEW电控旋翼测控系统

开发基于LabVIEW开发的电控旋翼测控系统&#xff0c;通过高效监控和控制提升旋翼系统的性能和安全性。系统集成了多种硬件设备&#xff0c;采用模块化设计&#xff0c;实现复杂的控制和数据处理功能&#xff0c;适用于现代航空航天领域。 项目背景 传统旋翼系统依赖机械和液压…

IO模型详解

阻塞IO模型 假设应用程序的进程发起IO调用&#xff0c;但是如果内核的数据还没准备好的话&#xff0c;那应用程序进程就一直在阻塞等待&#xff0c;一直等到内核数据准备好了&#xff0c;从内核拷贝到用户空间&#xff0c;才返回成功提示&#xff0c;此次IO操作&#xff0c;称…

C# 中的静态关键字

C# 语言中的 static 关键字用于声明静态类和静态类成员。静态类和静态类成员&#xff08;如构造函数、字段、属性、方法和事件&#xff09;在只需要一个对象&#xff08;类或类成员&#xff09;副本并在类型&#xff08;和成员&#xff09;的所有实例&#xff08;对象&#xff…

React+TS前台项目实战(十五)-- 全局常用组件Table封装

文章目录 前言Table组件1. 功能分析2. 代码详细注释3. 使用方式4. 效果展示 总结 前言 在这篇文章中&#xff0c;我们将对本系列项目中常用的表格组件Table进行自定义封装&#xff0c;以提高性能并适应项目需求。后期也可进行修改和扩展&#xff0c;以满足项目的需求。 Table组…