动手学深度学习—网络中的网络NiN(代码详解)

目录

      • 1. NiN块
      • 2. NiN模型
      • 3. 训练模型

LeNet、AlexNet和VGG都有一个共同的设计模式:
通过一系列的卷积层与汇聚层来提取空间结构特征;然后通过全连接层对特征的表征进行处理。

如果在过程的早期使用全连接层,可能会完全放弃表征的空间结构。

而NiN(网络中的网络)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机。

1. NiN块

卷积层的输入和输出由四维张量组成(样本,通道,高度,宽度)
全连接层的输入和输出通常是二维张量(样本,特征)

NiN在每个像素位置(针对每个高度和宽度)应用一个全连接层,可以将其视为1x1卷积层。将间维度中的每个像素视为单个样本,将通道维度视为不同特征。在这里插入图片描述
第一层为普通卷积层,之后的两个卷积层充当带有ReLU函数的逐像素全连接层。
NiN块

import torch
from torch import nn
from d2l import torch as d2l# 定义NiN块,(输入通道,输出通道,核大小,步幅,填充)
def nin_block(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding), nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

2. NiN模型

1、NiN使用窗口形状为11×11、5×5和3×3的卷积层,输出通道数量与AlexNet中的相同
2、每个NiN块后有一个最大汇聚层,汇聚窗口形状为3×3,步幅为2
3、取消了全连接层,使用一个NiN块,其输出通道数等于标签类别的数量
4、最后放一个全局平均汇聚层,生成一个对数几率

net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别nin_block(384, 10, kernel_size=3, strides=1, padding=1),nn.AdaptiveAvgPool2d((1, 1)),# 将四维的输出转成二维的输出,其形状为(批量大小,10)nn.Flatten())

观察每个块的输出形状

# 创建一个数据样本来查看每个块的输出形状
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)

在这里插入图片描述

3. 训练模型

定义精度评估函数

"""定义精度评估函数:1、将数据集复制到显存中2、通过调用accuracy计算数据集的精度
"""
def evaluate_accuracy_gpu(net, data_iter, device=None): #@save# 判断net是否属于torch.nn.Module类if isinstance(net, nn.Module):net.eval()# 如果不在参数选定的设备,将其传输到设备中if not device:device = next(iter(net.parameters())).device# Accumulator是累加器,定义两个变量:正确预测的数量,总预测的数量。metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:# 将X, y复制到设备中if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)# 计算正确预测的数量,总预测的数量,并存储到metric中metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

定义GPU训练函数

"""定义GPU训练函数:1、为了使用gpu,首先需要将每一小批量数据移动到指定的设备(例如GPU)上;2、使用Xavier随机初始化模型参数;3、使用交叉熵损失函数和小批量随机梯度下降。
"""
#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""# 定义初始化参数,对线性层和卷积层生效def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)# 在设备device上进行训练print('training on', device)net.to(device)# 优化器:随机梯度下降optimizer = torch.optim.SGD(net.parameters(), lr=lr)# 损失函数:交叉熵损失函数loss = nn.CrossEntropyLoss()# Animator为绘图函数animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])# 调用Timer函数统计时间timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# Accumulator(3)定义3个变量:损失值,正确预测的数量,总预测的数量metric = d2l.Accumulator(3)net.train()# enumerate() 函数用于将一个可遍历的数据对象for i, (X, y) in enumerate(train_iter):timer.start() # 进行计时optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将特征和标签转移到devicey_hat = net(X)l = loss(y_hat, y) # 交叉熵损失l.backward() # 进行梯度传递返回optimizer.step()with torch.no_grad():# 统计损失、预测正确数和样本数metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop() # 计时结束train_l = metric[0] / metric[2] # 计算损失train_acc = metric[1] / metric[2] # 计算精度# 进行绘图if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))# 测试精度test_acc = evaluate_accuracy_gpu(net, test_iter) animator.add(epoch + 1, (None, None, test_acc))# 输出损失值、训练精度、测试精度print(f'loss {train_l:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')# 设备的计算能力print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')

在这里插入图片描述

训练模型

# 训练模型
lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/116949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Day 1 Vue 页面框架

现在前端框架越来越像后端了,特别是TypeScript这样的语言出现后,开发前端的体验跟后端渐渐接近了。当然,作为一个后端,直接上手前端,还是有很多坑要填的。 本次开发,前端页面框架直接选择Vue。原因很简单&…

深入浅出排序算法之归并排序

目录 1. 归并排序的原理 1.1 二路归并排序执行流程 2. 代码分析 2.1 代码设计 3. 性能分析 4. 非递归版本 1. 归并排序的原理 “归并”一词的中文含义就是合并、并入的意思,而在数据结构中的定义是将两个或者两个以上的有序表组合成一个新的有序表。 归并排序…

Python-pptx教程之一从零开始生成PPT文件

简介 python-pptx是一个用于创建、读取和更新PowerPoint(.pptx)文件的python库。 典型的用途是根据动态内容(如数据库查询、分析数据等),将这些内容自动化生成PowerPoint演示文稿,将数据可视化&#xff0c…

【IDEA配置】IDEA配置

参考视频:【idea必知必会】优化设置 告别卡顿 1. 显示内存 右击底下空白区域,出现memory indicator内存指示器,点击勾选即可显示。有的是在Settings->Appearance->Window Options里,如图所示: 2. 内存设置 …

虚拟世界游戏定制开发:创造独一无二的虚拟体验

在游戏开发领域,虚拟世界游戏定制开发是一项引人注目的任务,旨在满足客户独特的需求和愿景,创造一个完全个性化的虚拟世界游戏。这种类型的游戏开发需要专业的技能、深刻的游戏开发知识和密切的与客户合作,以确保游戏满足客户的期…

CI2454 2.4g无线MCU芯片应用

Ci2454集成MCU芯片 | Ci2454是一款集成无线收发器和 8 位 RISC(精简指令集)MCU 的SOC芯片。 #Ci2454芯片 集成MCU芯片# 中国芯片# 无线收发器特性: 工作在 2.4GHz ISM 频段 调制方式:GFSK/FSK 数据速率:2Mbps/1Mbps…

数据库基础(一)【MySQL】

文章目录 安装 MySQL修改密码连接和退出数据库服务器使用 systemctl 管理服务器进程配置数据库从文件角度看待数据库查看连接情况 安装 MySQL 这是在 Linux 中安装 MySQL 的教程:Linux 下 MySQL 安装。本系列测试用的 MySQL 版本是 5.7,机器是 centOS7.…

LabVIEW中将枚举与条件结构一起使用

LabVIEW中将枚举与条件结构一起使用 枚举是一个具有相应数值的字符串标签型列表。在LabVIEW(U8 , U16-默认值和U32)中以无符号整数形式应用。 例如,可以有一个枚举保存四个季节,在这种情况下,每个字符串都…

Go之流程控制大全: 细节、示例与最佳实践

引言 在计算机编程中,流程控制是核心的组成部分,它决定了程序应该如何根据给定的情况执行或决策。以下是Go语言所支持的流程控制结构的简要概览: 流程控制类型代码if-else条件分支if condition { } else { }for循环for initialization; con…

GoLong的学习之路(一)语法之变量与常量

目录 GoLang变量批量声明变量的初始化类型推导短变量声明匿名变量 常量iota(特殊)(需要重点记忆) GoLang go的诞生为了解决在21世纪多核和网络化环境越来越复杂的变成问题而发明的Go语言。 go语言是从Ken Thomepson发明的B语言和…

分享个包含各省、市、区的编码数据的在线静态资源脚本

在翻《SpringBootVue3》——十三尼克陈作者的大型前后端分离项目实战里面&#xff0c;在看到地址管理的部分时&#xff0c;发现了该作者记录有一个静态的地址资源脚本 这里做个记录&#xff0c;打点 一、引入js <script src"https://s.yezgea02.com/1641120061385/td…

2024王道考研计算机组成原理——指令系统

零、本章概要 指令寻址&#xff1a;解决的是PC"1"的问题 数据寻址&#xff1a;使用寄存器/内存/结合 基址寻址&#xff1a;用于多道程序的并发执行 直接寻址&#xff1a;call 0x12345678 变址寻址&#xff1a;esi edi用于循环&#xff0c;因为使用直接寻址需要一堆…

超市商品管理系统 JAVA语言设计实现

目录 一、系统介绍 二、系统下载 三、系统截图 一、系统介绍 基于VueSpringBootMySQL的超市商品管理系统&#xff0c;超市区域模块、超市货架模块、商品类型模块、商品档案模块&#xff0c;分为用户网页端和管理后台&#xff0c;基于角色的访问控制&#xff0c;可将权限精确…

【Java】<泛型>,在编译阶段约束操作的数据结构,并进行检查。

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ JAVA泛型 泛型介绍&#xff1a; ①泛型&#…

HTML+CSS+JS+Django 实现前后端分离的科学计算器、利率计算器(附全部代码在gitcode链接)

&#x1f9ee;前后端分离计算器 &#x1f4da;git仓库链接和代码规范链接&#x1f4bc;PSP表格&#x1f387;成品展示&#x1f3c6;&#x1f3c6;科学计算器&#xff1a;1. 默认界面与页面切换2. 四则运算、取余、括号3. 清零Clear 回退Back4. 错误提示 Error5. 读取历史记录Hi…

2023年中职组“网络安全”赛项云南省竞赛任务书

2023年中职组“网络安全”赛项 云南省竞赛任务书 一、竞赛时间 总计&#xff1a;360分钟 竞赛阶段 竞赛阶段 任务阶段 竞赛任务 竞赛时间 分值 A模块 A-1 登录安全加固 180分钟 200分 A-2 本地安全策略配置 A-3 流量完整性保护 A-4 事件监控 A-5 服务加固…

VSCode 自动格式化

1.打开应用商店&#xff0c;搜索 prettier code formatter &#xff0c;选择第一个&#xff0c;点击安装。 2.安装完成后&#xff0c;点击文件&#xff0c;选择首选项&#xff0c;选择设置。 3.在搜索框内输入 save &#xff0c;勾选在保存时格式化文件。 4.随便打开一个文件&a…

Access denied for user ‘root‘@‘localhost‘ (using password:YES) 解决方案(禅道相关)

如果是忘记Mysql密码或更改权限后访问不了的问题请直接跳转以下链接&#xff1a; MySQL登录时出现Access denied for user ‘root‘‘localhost‘ (using password: YES)无法打开的解决方法 关于这个问题&#xff0c;网上查到的解决方法基本都是因为忘记Mysql密码或者用户权限问…

【Java 进阶篇】Java XML解析:从入门到精通

XML&#xff08;可扩展标记语言&#xff09;是一种常用的数据格式&#xff0c;用于存储和交换数据。在Java中&#xff0c;XML解析是一项重要的任务&#xff0c;它允许您从XML文档中提取和操作数据。本篇博客将从基础开始&#xff0c;详细介绍如何在Java中解析XML文档&#xff0…

前端AJAX入门到实战,学习前端框架前必会的(ajax+node.js+webpack+git)(二)

阳光总在风雨后&#xff0c;请相信有彩虹。 案例 - 图书管理 bootstrap弹框 需求&#xff0c;点击添加按钮&#xff0c;没有离开当前页面&#xff0c;在当前页面弹出弹框&#xff08;弹窗&#xff09; 先学着实现一个简单的弹框&#xff0c;如下图右下角 bootstrap有两种方式…