[pytorch、学习] - 5.8 网络中的网络(NiN)

参考

5.8 网络中的网络(NiN)

前几节介绍的LeNet、AlexNet和VGG在设计上的共同之处是:先以由卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。其中,AlexNet和VGG对LeNet的改进主要在于如何对这两个模块加宽(增加通道数)和加深。本节我们介绍网络中的网络(NiN)。它提出了另外一个思路,即串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络。

5.8.1 NiN块

我们知道,卷积层的输入和输出通常是四维数组(样本, 通道, 高, 宽),而全连接层的输入和输出则通常是二维数组(样本、特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变成四维。
在这里插入图片描述

NiN块是NiN中的基础块。它由一个卷积层加两个充当全连接层的 1 * 1 卷积层串联而成。其中第一个卷积层的超参数可以自行设置,而第二和第三个卷积层的超参数一般是固定的。

import time
import torch
from torch import nn, optimimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def nin_block(in_channels, out_channels, kernel_size, stride, padding):blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())return blk

5.8.2 NiN模型

NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之处。NiN使用卷积窗口形状分别为 11×11、 5×5和3×3的卷积层,相应的输出通道也与AlexNet中的一致。每个NiN块后接一个步幅为2、窗口形状为3×3的最大池化层。

除使用NiN块以外,NiN还有一个设计与AlexNet显著不同:NiN去掉了AlexNet最后的3个全连接层,取而代之地,NiN使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。

import torch.nn.functional as Fclass GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])net = nn.Sequential(nin_block(1, 96, kernel_size = 11, stride = 4, padding = 0),nn.MaxPool2d(kernel_size = 3, stride = 2),nin_block(96, 256, kernel_size = 5, stride = 1, padding = 2),nn.MaxPool2d(kernel_size = 3, stride = 2),nin_block(256, 384, kernel_size = 3, stride = 1, padding = 1),nn.MaxPool2d(kernel_size=3, stride =2),nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size = 3, stride=1, padding = 1),GlobalAvgPool2d(),# 将四维的输出转成二维的输出,其形状为(批量, 10)d2l.FlattenLayer()
)print(net)

在这里插入图片描述
构建数据观察每一层的结构

X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children():X = blk(X)print(name, 'output shape: ', X.shape)

在这里插入图片描述

5.8.3 获取数据和训练模型

batch_size = 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250121.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot---集成mybits方法

SpringBoot集成mybatis配置 一个有趣的现象:传统企业大都喜欢使用hibernate,互联网行业通常使用mybatis;之所以出现这个问题感觉与对应的业务有关,比方说,互联网的业务更加的复杂,更加需要进行灵活性的处理&#xff0c…

jQuery源码解读

参考 : https://www.cnblogs.com/yuqingfamily/p/5785593.html 转载于:https://www.cnblogs.com/wfblog/p/9172622.html

info.plist文件里面添加描述 - 配置定位,相册等

<key>NSAppleMusicUsageDescription</key> <string>App需要您的同意,才能访问媒体资料库</string> <key>NSBluetoothPeripheralUsageDescription</key> <string>App需要您的同意,才能访问蓝牙</string> <key>NSCalendar…

[pytorch、学习] - 5.9 含并行连结的网络(GoogLeNet)

参考 5.9 含并行连结的网络&#xff08;GoogLeNet&#xff09; 在2014年的ImageNet图像识别挑战赛中&#xff0c;一个名叫GoogLeNet的网络结构大放异彩。它虽然在名字上向LeNet致敬&#xff0c;但在网络结构上已经很难看到LeNet的影子。GoogLeNet吸收了NiN中网络串联网络的思…

mybits注解详解

一、mybatis 简单注解 关键注解词 &#xff1a; Insert &#xff1a; 插入sql , 和xml insert sql语法完全一样 Select &#xff1a; 查询sql, 和xml select sql语法完全一样 Update &#xff1a; 更新sql, 和xml update sql语法完全一样 Delete &#xff1a; 删除sql, 和xml d…

使用python装饰器计算函数运行时间的实例

使用python装饰器计算函数运行时间的实例 装饰器在python里面有很重要的作用&#xff0c; 如果能够熟练使用&#xff0c;将会大大的提高工作效率 今天就来见识一下 python 装饰器&#xff0c;到底是怎么工作的。 本文主要是利用python装饰器计算函数运行时间 一些需要精确的计算…

SQLServer用存储过程实现插入更新数据

实现 1&#xff09;有同样的数据&#xff0c;直接返回&#xff08;返回值&#xff1a;0&#xff09;。 2&#xff09;有主键同样。可是数据不同的数据。进行更新处理&#xff08;返回值&#xff1a;2&#xff09;&#xff1b; 3&#xff09;没有数据&#xff0c;进行插入数据处…

[pytorch、学习] - 9.1 图像增广

参考 9.1 图像增广 在5.6节(深度卷积神经网络)里我们提过,大规模数据集是成功应用神经网络的前提。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不相同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以…

mysql绿色版安装

导读&#xff1a;MySQL是一款关系型数据库产品&#xff0c;官网给出了两种安装包格式&#xff1a;MSI和ZIP。MSI格式是图形界面安装方式&#xff0c;基本只需下一步即可&#xff0c;这篇文章主要介绍ZIP格式的安装过程。ZIP Archive版是免安装的。只要解压就行了。 一、首先下…

在微信浏览器字体被调大导致页面错乱的解决办法

iOS的解决方案是覆盖掉微信的样式&#xff1a; body { /* IOS禁止微信调整字体大小 */-webkit-text-size-adjust: 100% !important; } 安卓的解决方案是通过 WeixinJSBridge 对象将网页的字体大小设置为默认大小&#xff0c;并且重写设置字体大小的方法&#xff0c;让用户不能在…

[pytorch、学习] - 9.2 微调

参考 9.2 微调 在前面得一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最广泛规模图像数据集ImageNet,它有超过1000万的图像和1000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。 假设我们想从图…

Springboot默认加载application.yml原理

Springboot默认加载application.yml原理以及扩展 SpringApplication.run(…)默认会加载classpath下的application.yml或application.properties配置文件。公司要求搭建的框架默认加载一套默认的配置文件demo.properties&#xff0c;让开发人员实现“零”配置开发&#xff0c;但…

java 集合(Set接口)

Set接口&#xff1a;无序集合&#xff0c;不允许有重复值&#xff0c;允许有null值 存入与取出的顺序有可能不一致 HashSet:具有set集合的基本特性&#xff0c;不允许重复值&#xff0c;允许null值 底层实现是哈希表结构 初始容量为16 保存自定义对象时&#xff0c;保证数据的唯…

关于mac机抓包的几点基础知识

1. 我使用的抓包工具为WireShark&#xff0c;以下操作按我当前的版本(Version 2.6.1)做的&#xff0c;以前的版本或者以后的版本可能有稍微的区别。 2. 将mac设置为热点&#xff1a;打开系统偏好设置&#xff0c;点击共享&#xff1a; 然后点击WIFI选项&#xff0c;设置WIFI名…

SpringBoot启动如何加载application.yml配置文件

一、前言 在spring时代配置文件的加载都是通过web.xml配置加载的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多数都是通过指定路径的文件名的形式去告诉spring该加载哪个文件&#xff1b; <context-param><param-name>contextConfigLocat…

[github] - git使用小结(分支拉取、版本回退)

1. 首次(fork项目之后) $ git clone [master] $ git branch -a $ git checkout -b [自己的分支名] [远程仓库的分支名]克隆的是主干网络 2. 再次拉取代码 $ git pull [master下选择分支名] [分支名] $ git push origin HEAD:[分支名]拉取首先得进入主仓(不是自己的远程仓)然后…

MYSQL 查看最大连接数和修改最大连接数

MySQL查看最大连接数和修改最大连接数 1、查看最大连接数show variables like %max_connections%;2、修改最大连接数set GLOBAL max_connections 200; 以下的文章主要是向大家介绍的是MySQL最大连接数的修改&#xff0c;我们大家都知道MySQL最大连接数的默认值是100, 这个数值…

阿里云服务器端口开放对外访问权限

登陆阿里云管理控制台 点击自己的实例 点击安全组配置 点击配置规则 点击添加安全组规则 配置出入放心&#xff0c;和开放的端口号&#xff0c;以及那些网段可以访问&#xff0c;这里设置所有网段都可以访问 转自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

数据分页功能是我们软件系统中必备的功能&#xff0c;在持久层使用mybatis的情况下&#xff0c;pageHelper来实现后台分页则是我们常用的一个选择&#xff0c;所以本文专门类介绍下。 PageHelper原理 相关依赖 <dependency><groupId>org.mybatis</groupId>&…

10-多写一个@Autowired导致程序崩了

再是javaweb实验六中&#xff0c;是让我们改代码&#xff0c;让它跑起来&#xff0c;结果我少注释了一个&#xff0c;导致一直报错&#xff0c;检查许久没有找到&#xff0c;最后通过代码替换逐步查找&#xff0c;才发现问题。 转载于:https://www.cnblogs.com/zhumengdexiaoba…