【深度学习笔记】稠密连接网络(DenseNet)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.12 稠密连接网络(DenseNet)

ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个:稠密连接网络(DenseNet) [1]。 它与ResNet的主要区别如图5.10所示。

在这里插入图片描述

图5.10 ResNet(左)与DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结

图5.10中将部分前后相邻的运算抽象为模块 A A A和模块 B B B。与ResNet的主要区别在于,DenseNet里模块 B B B的输出不是像ResNet那样和模块 A A A的输出相加,而是在通道维上连结。这样模块 A A A的输出可以直接传入模块 B B B后面的层。在这个设计里,模块 A A A直接跟模块 B B B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。

DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。

5.12.1 稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block函数里实现这个结构。

import time
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def conv_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))return blk

稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。

class DenseBlock(nn.Module):def __init__(self, num_convs, in_channels, out_channels):super(DenseBlock, self).__init__()net = []for i in range(num_convs):in_c = in_channels + i * out_channelsnet.append(conv_block(in_c, out_channels))self.net = nn.ModuleList(net)self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数def forward(self, X):for blk in self.net:Y = blk(X)X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结return X

在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为 3 + 2 × 10 = 23 3+2\times 10=23 3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

5.12.2 过渡层

由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过 1 × 1 1\times1 1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。

def transition_block(in_channels, out_channels):blk = nn.Sequential(nn.BatchNorm2d(in_channels), nn.ReLU(),nn.Conv2d(in_channels, out_channels, kernel_size=1),nn.AvgPool2d(kernel_size=2, stride=2))return blk

对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

5.12.3 DenseNet模型

我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽,并减半通道数。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]for i, num_convs in enumerate(num_convs_in_dense_blocks):DB = DenseBlock(num_convs, num_channels, growth_rate)net.add_module("DenseBlosk_%d" % i, DB)# 上一个稠密块的输出通道数num_channels = DB.out_channels# 在稠密块之间加入通道数减半的过渡层if i != len(num_convs_in_dense_blocks) - 1:net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))num_channels = num_channels // 2

同ResNet一样,最后接上全局池化层和全连接层来输出。

net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

我们尝试打印每个子模块的输出维度确保网络无误:

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 48, 48])
1  output shape:	 torch.Size([1, 64, 48, 48])
2  output shape:	 torch.Size([1, 64, 48, 48])
3  output shape:	 torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:	 torch.Size([1, 224, 12, 12])
transition_block_1  output shape:	 torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:	 torch.Size([1, 240, 6, 6])
transition_block_2  output shape:	 torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:	 torch.Size([1, 248, 3, 3])
BN  output shape:	 torch.Size([1, 248, 3, 3])
relu  output shape:	 torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:	 torch.Size([1, 248, 1, 1])
fc  output shape:	 torch.Size([1, 10])

5.12.4 获取数据并训练模型

由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec
epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec
epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec
epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec
epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec

小结

  • 在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。
  • DenseNet的主要构建模块是稠密块和过渡层。

参考文献

[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/729862.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5个实用的PyCharm插件

大家好,本文向大家推荐五个顶级插件,帮助开发人员提升PyCharm工作流程,将生产力飞升到新高度。 1.CodiumAI 安装链接:https://plugins.jetbrains.com/plugin/21206-codiumate--code-test-and-review-with-confidence--by-codium…

Windows上基于名称快速定位文件和文件夹的免费工具Everything

在Windows上搜索文件时,使用windows上内置搜索会很慢,这里推荐使用Everything工具进行搜索。 "Everything"是Windows上一款搜索引擎,它能够基于文件名快速定位文件和文件夹位置。不像Windows内置搜索,"Everything&…

容器:Docker部署

docker 是容器,可以将项目的环境(比如 java、nginx)和项目的代码一起打包成镜像,所有同学都能下载镜像,更容易分发和移植。 再启动项目时,不需要敲一大堆命令,而是直接下载镜像、启动镜像就可以…

echarts x轴名称过长tip显示全称

xAxis的axisLabel的内容如下: axisLabel: { rotate: -45, color: document.body.className.indexOf(custom-f4c46d) > -1 ? #fff : #343434, // 显示省略号操作(第一步) formatter: function (value) { var val if (value.length >…

NTP协议介绍

知识改变命运,技术就是要分享,有问题随时联系,免费答疑,欢迎联系! 网络时间协议NTP(Network Time Protocol)是TCP/IP协议族里面的一个应用层协议,用来使客户端和服务器之间进行时…

C while 循环

只要给定的条件为真,C 语言中的 while 循环语句会重复执行一个目标语句。 语法 C 语言中 while 循环的语法: while(condition) {statement(s); }在这里,statement(s) 可以是一个单独的语句,也可以是几个语句组成的代码块。 co…

IOS开发0基础入门UIkit-1cocoapod安装、更新和使用 , 安装中出现的错误及解决方案 M1或者M2安装cocoapods

cocoapod是ios开发时常用的包管理工具 1.M1或者是M2系统安装cocoapods先操作一下两个设置 1、打开访达->应用->实用工具->终端->右键点击终端->显示简介->勾选使用 Rosetta 打开,关闭终端,重新打开。 2、打开访达->应用->Xcod…

ApiPost设置预执行脚本获取token,并设置给请求头

ApiPost设置预执行脚本获取token,并设置给请求头 预执行脚本 这个地方获取字段为 {"msg": "操作成功","code": 200,"token": "eyJhbGciOixMiJ9.123-NQQPPKGr4Yxa1_H_JIrUXJQ" }修改head 里面参数

OpenAI劲敌吹新风! Claude 3正式发布,Claude3使用指南

Claude 3是什么? 是Anthropic 实验室近期推出的 Claude 3 大规模语言模型(Large Language Model,LLM)系列,代表了人工智能技术的一个显著飞跃。 该系列包括三个不同定位的子模型:Claude 3 Haiku、Claude 3…

BUUCTF-Misc3

LSB1 1.打开附件 得到一张图片,像是某个大学的校徽 2.Stegsolve工具 根据标题LSB,可能是LSB隐写 放到Stegsolve中,点Analyse在点Data Extract 数据提取 因为是LSB隐写,发现含以.png结尾的图片 3.保存图片 4.得到flag 扫描二维…

一招教你优化TCP提高大文件传输效率

在当今企业的数据传输实践中,传统的传输控制协议(TCP)在处理大型文件传输时,其固有的可靠性和复杂性有时会导致效率不足。为了提升大文件传输的效率,对TCP进行优化成为了一个关键任务。 TCP传输的可靠性是其核心优势&a…

UnityShader常用算法笔记(颜色叠加混合、RGB-HSV-HSL的转换、重映射、UV序列帧动画采样等,持续更新中)

一.颜色叠加混合 1.Blend混合 // 正常,透明度混合 Normal Blend SrcAlpha OneMinusSrcAlpha //柔和叠加 Soft Additive Blend OneMinusDstColor One //正片叠底 相乘 Multiply Blend DstColor Zero //两倍叠加 相加 2x Multiply Blend DstColor SrcColor //变暗…

聊聊 HTTP 性能优化

哈喽大家好,我是咸鱼。 作为用户的我们在 “上网冲浪” 的时候总是希望快一点,尤其是抢演唱会门票的时候,但是现实并非如此,有时候我们会遇到页面加载缓慢、响应延迟的情况。 而 HTTP 协议作为互联网世界的基础,从网…

穷人想赚钱该怎么选打工VS创业?2024年如何把握新机遇?

在贫穷的困境中,打工与创业似乎成为了两条截然不同的道路,摆在每一个渴望改变命运的人面前。然而,这并非简单的选择题,而是一场关于勇气、智慧与机遇的较量。打工,对于许多人来说,是稳定且相对安全的收入来…

Aigtek前置微小信号放大器有什么作用

前置微小信号放大器是一种被广泛应用于无线通信、雷达、射频等领域中的低噪声放大器。相较于传统的放大器,前置微小信号放大器具有更高的灵敏度和更低的噪声系数。下面安泰Aigtek将介绍前置微小信号放大器的作用和意义。 一、前置微小信号放大器的作用 放大弱信号 前…

C语言实现回调函数

C语言实现回调函数 一、回调函数概念1.1 什么叫函数指针 二、回调函数案例 一、回调函数概念 回调函数就是一个被作为参数传递的函数。在C语言中,回调函数只能使用函数指针实现,在C、Python、ECMAScript等更现代的编程语言中还可以使用仿函数或匿名函数…

IDEA启动项目读取nacos乱码导致启动失败

新安装的2023社区版IDEA,启动项目报错。 forest: interceptors: - com.gdsz.b2b.frontend.api.Interceptors.ApiInterceptor org.yaml.snakeyaml.error.YAMLException: java.nio.charset.MalformedInputException: Input length 1 at org.yaml.snakeyaml.reader.S…

Java面试题总结10之MySQL索引和锁

索引的基本原理 把无需的数据变成有序的查询 1,把创建了索引的列的内容进行排序 2,对排序结果生成倒排表 3,到倒排表内容上拼上数据地址链 4,在查询的时候,先拿到倒排表内容,再取出数据地址链&#xf…

7-3 前世档案(Python)

网络世界中时常会遇到这类滑稽的算命小程序,实现原理很简单,随便设计几个问题,根据玩家对每个问题的回答选择一条判断树中的路径(如下图所示),结论就是路径终点对应的那个结点。 现在我们把结论从左到右顺序…

基于Leatlet标注Geojson下载器实现

在上一篇文章中,我们学习了Leaflet的基础知识,包括如何创建地图、添加图层等。在本文中,我们将深入学习Leaflet中标注的创建和管理,包括如何添加标注、自定义标注图标、创建图层组、批量添加和删除标注、为标注添加属性和弹出框等…