《动手学深度学习 Pytorch版》 7.3 网络中的网络(NiN)

LeNet、AlexNet和VGG的设计模式都是先用卷积层与汇聚层提取特征,然后用全连接层对特征进行处理。

AlexNet和VGG对LeNet的改进主要在于扩大和加深这两个模块。网络中的网络(NiN)则是在每个像素的通道上分别使用多层感知机。

import torch
from torch import nn
from d2l import torch as d2l

7.3.1 NiN

NiN的想法是在每个像素位置应用一个全连接层。 如果我们将权重连接到每个空间位置,我们可以将其视为 1 × 1 1\times 1 1×1 卷积层,即是作为在每个像素位置上独立作用的全连接层。 从另一个角度看,是将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)。

NiN块以一个普通卷积层开始,后面是两个 1 × 1 1\times 1 1×1 的卷积层。这两个卷积层充当带有ReLU激活函数的逐像素全连接层。

def nin_block(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

7.3.2 NiN 模型

最初的 NiN 网络是在 AlexNet 后不久提出的,显然 NiN 网络是从 AlexNet 中得到了一些启示的。 NiN 使用窗口形状为 11 × 11 11\times 11 11×11 5 × 5 5\times 5 5×5 3 × 3 3\times 3 3×3 的卷积层,输出通道数量与 AlexNet 中的相同。每个NiN块后有一个最大汇聚层,汇聚窗口形状为 3 × 3 3\times 3 3×3 ,步幅为 2。

NiN 和 AlexNet 之间的显著区别是 NiN 使用一个 NiN 块取代了全连接层。其输出通道数等于标签类别的数量。最后放一个全局平均汇聚层,生成一个对数几率。

NiN 设计的一个优点是显著减少了模型所需参数的数量。然而,在实践中,这种设计有时会增加训练模型的时间。

在这里插入图片描述

net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, strides=1, padding=1),nn.AdaptiveAvgPool2d((1, 1)),# 将四维的输出转成二维的输出,其形状为(批量大小,10)nn.Flatten())
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

7.3.3 训练模型

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要二十五分钟,慎跑
loss 0.600, train acc 0.769, test acc 0.775
447.9 examples/sec on cuda:0

在这里插入图片描述

练习

(1)调整 NiN 的超参数,以提高分类准确性。

net2 = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),nin_block(384, 10, kernel_size=3, strides=1, padding=1),nn.AdaptiveAvgPool2d((1, 1)),nn.Flatten())lr, num_epochs, batch_size = 0.15, 12, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net2, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要三十分钟,慎跑
loss 0.353, train acc 0.871, test acc 0.884
449.5 examples/sec on cuda:0

在这里插入图片描述

学习率调大一点点之后精度更高了,但是波动变的分外严重。


(2)为什么 NiN 块中有两个 1 × 1 1\times 1 1×1 的卷积层?删除其中一个,然后观察和分析实验现象。

def nin_block2(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())net3 = nn.Sequential(nin_block2(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block2(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block2(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别数是10nin_block2(384, 10, kernel_size=3, strides=1, padding=1),nn.AdaptiveAvgPool2d((1, 1)),# 将四维的输出转成二维的输出,其形状为(批量大小,10)nn.Flatten())lr, num_epochs, batch_size = 0.15, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
d2l.train_ch6(net3, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())  # 大约需要二十分钟,慎跑
loss 0.309, train acc 0.884, test acc 0.890
607.5 examples/sec on cuda:0

在这里插入图片描述

有时候会更好,有时候会不收敛。


(3)计算 NiN 的资源使用情况。

a. 参数的数量是多少?b. 计算量是多少?c. 训练期间需要多少显存?d. 预测期间需要多少显存?

a. 参数数量:

[ 11 × 11 + 2 ] + [ 5 × 5 + 2 ] + [ 3 × 3 + 2 ] + [ 3 × 3 + 2 ] = 123 + 27 + 11 + 11 = 172 \begin{align} &[11\times 11 + 2] + [5\times 5 + 2] + [3\times 3 + 2] + [3\times 3 + 2]\\ =& 123+27+11+11\\ =& 172 \end{align} ==[11×11+2]+[5×5+2]+[3×3+2]+[3×3+2]123+27+11+11172

b. 计算量:

{ [ ( 224 − 11 + 4 ) / 4 ] 2 × 1 1 2 × 96 + 22 4 2 × 2 } + [ ( 26 − 5 + 2 + 1 ) 2 × 5 2 × 96 × 256 + 2 6 2 × 2 ] + [ ( 12 − 3 + 1 + 1 ) 2 × 3 2 × 256 × 384 + 1 2 2 × 2 ] + [ ( 5 − 3 + 1 + 1 ) 2 × 3 2 × 384 × 10 + 5 2 × 2 ] = 34286966 + 353895752 + 107053344 + 553010 = 495789072 \begin{align} &\{[(224-11+4)/4]^2\times 11^2\times 96 + 224^2\times 2\} + [(26-5+2+1)^2\times 5^2\times 96\times 256 + 26^2\times 2] + \\ &[(12-3+1+1)^2\times 3^2\times 256\times 384 + 12^2\times 2]+[(5-3+1+1)^2\times 3^2\times 384\times 10 + 5^2\times 2]\\ =&34286966+353895752+107053344+553010\\ =&495789072 \end{align} =={[(22411+4)/4]2×112×96+2242×2}+[(265+2+1)2×52×96×256+262×2]+[(123+1+1)2×32×256×384+122×2]+[(53+1+1)2×32×384×10+52×2]34286966+353895752+107053344+553010495789072


(4)一次性直接将 384 × 5 × 5 384\times 5\times 5 384×5×5 的表示压缩为 10 × 5 × 5 10\times 5\times 5 10×5×5 的表示,会存在哪些问题?

压缩太快可能导致特征损失过多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/85578.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科技云报道:云安全的新战场上,如何打破“云威胁”的阴霾?

科技云报道原创。 近年来,在云计算和网络安全产业的蓬勃发展下,我国云安全行业市场规模呈现高速增长态势,在网络安全市场总体规模中占比不断上升。 据统计,近5年我国云安全市场保持高速增长,2021年我国云安全市场规模…

代码随想录算法训练营第58天|739. 每日温度,496.下一个更大元素 I (单调栈开始)

链接: 739. 每日温度 链接: 496.下一个更大元素 I 739. 每日温度 单调栈入门题 这题的关键时保证了栈内所有的元素都是单调递增的&#xff08;单调栈&#xff09; class Solution {public int[] dailyTemperatures(int[] temperatures) {Deque<Integer> stack new L…

(25)(25.1) 光学流量传感器的测试和设置

文章目录 25.1.1 测试传感器 25.1.2 校准传感器 25.1.3 测距传感器检查 25.1.4 预解锁检查 25.1.5 首次飞行 25.1.6 第二次飞行 25.1.7 正常操作设置 25.1.8 视频示例&#xff08;Copter-3.4&#xff09; 25.1.9 空中校准 25.1.1 测试传感器 将传感器连接至自动驾驶仪…

uniapp-引入模块失效的问题

在前段时间的开发中&#xff0c;我遇到了一个bug 就是引入了一个模块但是却没有生效&#xff0c;由于静态页面不是我完成的&#xff0c;所以我花了一些时间才发现问题所在 1. 查看是否忘记添加components components: {addCart,myCart, }, 2. 查看是否引入路径有误 Compone…

【C语言】指针的进阶(四)—— 企业笔试题解析

笔试题1&#xff1a; int main() {int a[5] { 1, 2, 3, 4, 5 };int* ptr (int*)(&a 1);printf("%d,%d", *(a 1), *(ptr - 1));return 0; } 【答案】在x86环境下运行 【解析】 &a是取出整个数组的地址&#xff0c;&a就表示整个数组&#xff0c;因此…

Biome-BGC生态系统模型与Python融合技术

Biome-BGC是利用站点描述数据、气象数据和植被生理生态参数&#xff0c;模拟日尺度碳、水和氮通量的有效模型&#xff0c;其研究的空间尺度可以从点尺度扩展到陆地生态系统。 在Biome-BGC模型中&#xff0c;对于碳的生物量积累&#xff0c;采用光合酶促反应机理模型计算出每天…

使用Chatgpt编写的PHP数据库pdo操作类(增删改查)

摘要 将PDO封装成PHP类进行调用有很多好处&#xff0c;包括&#xff1a; 1、封装性和抽象性&#xff1a; 通过将PDO封装到一个类中&#xff0c;您可以将数据库操作逻辑与应用程序的其他部分分离开来&#xff0c;提高了代码的组织性和可维护性。这样&#xff0c;您只需在一个地…

soildwork2022怎么恢复软件界面的默认设置?

1.点击菜单中的” 视图” 2.在弹出的子菜单中选择”工作区” 3.选择工作区中的”默认” 4.点击默认后软件界面就恢复了默认设置。

FPGA 图像缩放 千兆网 UDP 网络视频传输,基于B50610 PHY实现,提供工程和QT上位机源码加技术支持

目录 1、前言版本更新说明免责声明 2、相关方案推荐UDP视频传输--无缩放FPGA图像缩放方案我这里已有的以太网方案 3、设计思路框架视频源选择IT6802解码芯片配置及采集动态彩条跨时钟FIFO图像缩放模块详解设计框图代码框图2种插值算法的整合与选择 UDP协议栈UDP视频数据组包UDP…

二叉树层序遍历及判断完全二叉树

个人主页:Lei宝啊 愿所有美好如期而遇 目录 二叉树层序遍历&#xff1a; 判断完全二叉树&#xff1a; 二叉树层序遍历&#xff1a; 层序遍历就是一层一层&#xff0c;从上到下遍历&#xff0c;上图遍历结果为&#xff1a;4 2 7 1 3 6 9 思路&#xff1a; 通过队列来实现层序…

发送实时音频数据到udp服务

由于浏览器不能直接连接udp服务&#xff0c;所以需要搭建一个websocket服务做中转&#xff0c;让websocket服务连接udp服务 1、vue开发获取实时音频数据并按4096分包后添加rtp协议头发送到websocket服务&#xff08;连接websocket自行编写连接到127.0.0.1:8889&#xff09; da…

Flowable 工作流 删除任务

文章目录 任务删减重点表说明ACT_RU_EXECUTION 查询SQL删除方法运行中数据历史数据 包容网关任务删减多实例节点任务删多实例节点生产数据说明 任务删减 重点表说明 ACT_RU_EXECUTION ACT_RU_EXECUTION 表(执行实例表&#xff09;&#xff1a;这个表和act_run_task表&#x…

[k8s] kubectl port-forward 和kubectl expose的区别

kubectl port-forward 和 kubectl expose 是 Kubernetes 命令行工具 kubectl 提供的两种不同方式来公开服务。 kubectl port-forward kubectl port-forward 命令用于在本地主机和集群内部的 Pod 之间建立一个临时的端口转发通道。 该命令将本地机器上的一个端口绑定到集群内部…

scrapy框架--

Scrapy是一个用于爬取数据的Python框架。下面是Scrapy框架的基本操作步骤&#xff1a; 安装Scrapy&#xff1a;首先&#xff0c;确保你已经安装好了Python和pip。然后&#xff0c;在命令行中运行以下命令安装Scrapy&#xff1a;pip install scrapy 创建Scrapy项目&#xff1a;…

购物H5商城架构运维之路

一、引言 公司属于旅游行业&#xff0c;需要将旅游&#xff0c;酒店&#xff0c;购物&#xff0c;聚合到线上商城。通过对会员数据进行聚合&#xff0c;形成大会员系统&#xff0c;从而提供统一的对客窗口。 二、业务场景 围绕更加有效地获取用户&#xff0c;提升用户的LTV&a…

Linphone3.5.2 ARM RV1109音视频对讲开发记录

Linphone3.5.2 ARM RV1109音视频对讲开发记录 说明 这是一份事后记录&#xff0c;主要记录的几个核心关键点&#xff0c;有可能很多细节没有记上&#xff0c;主要是方便后面自己再找回来! 版本 3.5.2 一些原因选的是这样一个旧的版本&#xff01; 新的开发最好选新一些的版…

Python线程和进程

1、深度解析Python线程和进程 一篇文章带你深度解析Python线程和进程 - 知乎使用Python中的线程模块&#xff0c;能够同时运行程序的不同部分&#xff0c;并简化设计。如果你已经入门Python&#xff0c;并且想用线程来提升程序运行速度的话&#xff0c;希望这篇教程会对你有所帮…

stm32之看门狗

STM32 有两个看门狗&#xff0c;独立看门狗和窗口看门狗&#xff0c;独立看门狗又称宠物狗&#xff0c;窗 口看门狗又称警犬。可用来检测和解决由软件错误引起的故障。两个看门狗的原理都是当计数器达到给定的超时值时&#xff0c;产生系统复位&#xff0c;对于窗口型看门狗同…

FL Studio21.2中文版数字音乐制作软件

现在的FL也可以像splice一样啦&#xff0c;需要什么样的声音只需在fl里搜索&#xff0c;就会自动展示给你! FL Studio 简称FL&#xff0c;全称&#xff1a;Fruity Loops Studio&#xff0c;国人习惯叫它"水果"。软件现有版本是 FL Studio 21&#xff0c;已全面升级支…

如何利用播放器节省20%点播成本

点播成本节省的点其实涉及诸多部分&#xff0c;例如&#xff1a;CDN、转码、存储等&#xff0c;而利用播放器降本却是很多客户比较陌生的部分。火山引擎基于内部支撑抖音集团相关业务的实践&#xff0c;播放器恰恰是成本优化中最重要和最为依赖的部分。 火山引擎的视频团队做了…