[pytorch、学习] - 5.9 含并行连结的网络(GoogLeNet)

参考

5.9 含并行连结的网络(GoogLeNet)

在2014年的ImageNet图像识别挑战赛中,一个名叫GoogLeNet的网络结构大放异彩。它虽然在名字上向LeNet致敬,但在网络结构上已经很难看到LeNet的影子。GoogLeNet吸收了NiN中网络串联网络的思想,并在此基础上做了很大改进。在随后的几年里,研究人员对GoogLeNet进行了数次改进,本节将介绍这个模型系列的第一个版本。

5.9.1 Inception块

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fn8Tzn2Z-1594258539068)(attachment:image.png)]

Inception块中可以自定义的超参数是每个层的输出通道,我们以此来控制模型的复杂度

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device("cuda" if torch.cuda.is_available() else 'cpu')class Inception(nn.Module):def __init__(self, in_c, c1, c2, c3, c4):super(Inception, self).__init__()# 路线1: 单 1 x 1 卷积层self.p1_1 = nn.Conv2d(in_c, c1, kernel_size =1)# 路线2: 1 x 1 卷积层后接3 x  3卷积层self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size = 1)self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size = 3, padding = 1)# 路线3: 1 x 1卷积层后接5 x 5卷积层self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size = 1)self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size = 5, padding = 2)# 路线4: 3 x 3 最大池化层后接1 x 1卷积层self.p4_1 = nn.MaxPool2d(kernel_size = 3, stride = 1, padding = 1)self.p4_2 = nn.Conv2d(in_c, c4, kernel_size = 1)def forward(self, x):p1 = F.relu(self.p1_1(x))p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))p4 = F.relu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim = 1)  # 在通道维上连结输出

5.9.2 GoogLeNet模型

GoogLeNet跟VGG一样,在主体卷积部分中使用5个模块(block),每个模块之间使用步幅为2的3x3最大池化层来减少输出宽高。第一模块使用一个64通道的7x7卷积层。

b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size =7, stride =2, padding =3),nn.ReLU(),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
)

第二模块使用2个卷积层: 首先通过64通道的 1x1 卷积层,然后是将通道增大3倍的 3x3 卷积层。它对应Inception块中的第二条线路。

b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size = 1),nn.Conv2d(64, 192, kernel_size = 3, padding =1),nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
)

第三模块串联2个完整的Inception块。第一个Inception块的输出通道数为 64 + 128 + 32 + 32 = 256,其中4条路线的输出通道数比例为 64: 128 : 32: 32 = 2: 4: 1: 1。其中第二、第三条路线先分别将输入通道数减少至96/192 = 1/2 和 16/192 = 1/ 12后,再接上第二层卷积层。第二个Inception块输出通道数增至 128 +192 +96 + 64 = 480,每条线路的输出通道数之比为 128: 192: 96: 64 = 4: 6 : 3 : 2。其中第二、第三条路线先分别将输入通道数减少至 128/ 256 = 1/2 和 32/ 256 = 1/8。

b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192),(32, 96), 64),nn.MaxPool2d(kernel_size = 3, stride =2, padding =1)
)

第四模块更加复杂。它串联了5个Inception块,其输出通道数分别是192 + 208 + 48 + 64 = 512、160 + 224 + 64 + 64 = 512、128 + 256 + 64 + 64 = 512、 112 + 288 + 64 + 64 = 528 和 256 + 320 + 128 + 128 = 832。这些线路的通道数分配和第三模块中类似,首先含3x3卷积层的第二条线路输出最多通道,其次是仅含1x1卷积层的第一条线路,之后是含5x5卷积层的第3条线路和含3x3最大池化层的第4条线路。其中第二、第三条线路都会按比例减小通道数。这些比例在各个Inception块中都略有不同。

b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (144, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size = 3, stride=2, padding =1)
)

第五模块有输出通道数为 256 + 320 + 128 + 128 = 832 和 834 + 384 + 128 + 128 = 1024的两个Inception块。其中每条线路的通道数的分配思路和第三、第四模块中的一致,只是在具体数值上有所不同。需要注意的是,第五模块的后面紧跟输出层,该模块同NiN一样使用全局平均池化层来将每个通道的宽和高变成1。最后我们将输出变成二维数组后接上一个输出个数为标签类别数的全连接层。

b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),d2l.GlobalAvgPool2d()
)net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))print(net)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
将输入的高和宽从224降到96来简化计算,下面演示各个模块之间的输出的形状变化

net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))
X = torch.rand(1, 1, 96, 96)
for blk in net.children():X = blk(X)print("output shape: ", X.shape)

在这里插入图片描述

5.9.3 获取数据和训练模型

batch_size = 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mybits注解详解

一、mybatis 简单注解 关键注解词 : Insert : 插入sql , 和xml insert sql语法完全一样 Select : 查询sql, 和xml select sql语法完全一样 Update : 更新sql, 和xml update sql语法完全一样 Delete : 删除sql, 和xml d…

使用python装饰器计算函数运行时间的实例

使用python装饰器计算函数运行时间的实例 装饰器在python里面有很重要的作用, 如果能够熟练使用,将会大大的提高工作效率 今天就来见识一下 python 装饰器,到底是怎么工作的。 本文主要是利用python装饰器计算函数运行时间 一些需要精确的计算…

SQLServer用存储过程实现插入更新数据

实现 1)有同样的数据,直接返回(返回值:0)。 2)有主键同样。可是数据不同的数据。进行更新处理(返回值:2); 3)没有数据,进行插入数据处…

[pytorch、学习] - 9.1 图像增广

参考 9.1 图像增广 在5.6节(深度卷积神经网络)里我们提过,大规模数据集是成功应用神经网络的前提。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不相同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以…

mysql绿色版安装

导读:MySQL是一款关系型数据库产品,官网给出了两种安装包格式:MSI和ZIP。MSI格式是图形界面安装方式,基本只需下一步即可,这篇文章主要介绍ZIP格式的安装过程。ZIP Archive版是免安装的。只要解压就行了。 一、首先下…

在微信浏览器字体被调大导致页面错乱的解决办法

iOS的解决方案是覆盖掉微信的样式: body { /* IOS禁止微信调整字体大小 */-webkit-text-size-adjust: 100% !important; } 安卓的解决方案是通过 WeixinJSBridge 对象将网页的字体大小设置为默认大小,并且重写设置字体大小的方法,让用户不能在…

[pytorch、学习] - 9.2 微调

参考 9.2 微调 在前面得一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型。我们还描述了学术界当下使用最广泛规模图像数据集ImageNet,它有超过1000万的图像和1000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。 假设我们想从图…

Springboot默认加载application.yml原理

Springboot默认加载application.yml原理以及扩展 SpringApplication.run(…)默认会加载classpath下的application.yml或application.properties配置文件。公司要求搭建的框架默认加载一套默认的配置文件demo.properties,让开发人员实现“零”配置开发,但…

java 集合(Set接口)

Set接口:无序集合,不允许有重复值,允许有null值 存入与取出的顺序有可能不一致 HashSet:具有set集合的基本特性,不允许重复值,允许null值 底层实现是哈希表结构 初始容量为16 保存自定义对象时,保证数据的唯…

关于mac机抓包的几点基础知识

1. 我使用的抓包工具为WireShark,以下操作按我当前的版本(Version 2.6.1)做的,以前的版本或者以后的版本可能有稍微的区别。 2. 将mac设置为热点:打开系统偏好设置,点击共享: 然后点击WIFI选项,设置WIFI名…

SpringBoot启动如何加载application.yml配置文件

一、前言 在spring时代配置文件的加载都是通过web.xml配置加载的(Servlet3.0之前)&#xff0c;可能配置方式有所不同&#xff0c;但是大多数都是通过指定路径的文件名的形式去告诉spring该加载哪个文件&#xff1b; <context-param><param-name>contextConfigLocat…

[github] - git使用小结(分支拉取、版本回退)

1. 首次(fork项目之后) $ git clone [master] $ git branch -a $ git checkout -b [自己的分支名] [远程仓库的分支名]克隆的是主干网络 2. 再次拉取代码 $ git pull [master下选择分支名] [分支名] $ git push origin HEAD:[分支名]拉取首先得进入主仓(不是自己的远程仓)然后…

MYSQL 查看最大连接数和修改最大连接数

MySQL查看最大连接数和修改最大连接数 1、查看最大连接数show variables like %max_connections%;2、修改最大连接数set GLOBAL max_connections 200; 以下的文章主要是向大家介绍的是MySQL最大连接数的修改&#xff0c;我们大家都知道MySQL最大连接数的默认值是100, 这个数值…

阿里云服务器端口开放对外访问权限

登陆阿里云管理控制台 点击自己的实例 点击安全组配置 点击配置规则 点击添加安全组规则 配置出入放心&#xff0c;和开放的端口号&#xff0c;以及那些网段可以访问&#xff0c;这里设置所有网段都可以访问 转自&#xff1a;https://jingyan.baidu.com/article/95c9d20d624d1e…

PageHelper工作原理

数据分页功能是我们软件系统中必备的功能&#xff0c;在持久层使用mybatis的情况下&#xff0c;pageHelper来实现后台分页则是我们常用的一个选择&#xff0c;所以本文专门类介绍下。 PageHelper原理 相关依赖 <dependency><groupId>org.mybatis</groupId>&…

10-多写一个@Autowired导致程序崩了

再是javaweb实验六中&#xff0c;是让我们改代码&#xff0c;让它跑起来&#xff0c;结果我少注释了一个&#xff0c;导致一直报错&#xff0c;检查许久没有找到&#xff0c;最后通过代码替换逐步查找&#xff0c;才发现问题。 转载于:https://www.cnblogs.com/zhumengdexiaoba…

Java class不分32位和64位

1、32位JDK编译的java class在32位系统和64位系统下都可以运行&#xff0c;64位系统兼容32位程序&#xff0c;可以理解。2、无论是Linux还是Windows平台下的JDK编译的java class在Linux、Windows平台下通用&#xff0c;Java跨平台特性。3、64位JDK编译的java class在32位的系统…

包装对象

原文地址&#xff1a;https://wangdoc.com/javascript/ 定义 对象是JavaScript语言最主要的数据类型&#xff0c;三种原始类型的值--数值、字符串、布尔值--在一定条件下&#xff0c;也会自动转为对象&#xff0c;也就是原始类型的包装对象。所谓包装对象&#xff0c;就是分别与…

[C++] 转义序列

参考 C Primer(第5版)P36 名称转义序列换行符\n横向制表符\t报警(响铃)符\a纵向制表符\v退格符\b双引号"反斜杠\问号?单引号’回车符\r进纸符\f

vue使用(二)

本节目标&#xff1a; 1.数据路径的三种方式 2.{{}}和v-html的区别 1.绑定图片的路径 方法一&#xff1a;直接写路径 <img src"http://pic.baike.soso.com/p/20140109/20140109142534-188809525.jpg"> 方法二&#xff1a;在data中写路径&#xff0c;在…