基于Pytorch构建DenseNet网络对cifar-10进行分类

DenseNet是指Densely connected convolutional networks(密集卷积网络)。它的优点主要包括有效缓解梯度消失、特征传递更加有效、计算量更小、参数量更小、性能比ResNet更好。它的缺点主要是较大的内存占用。

DenseNet网络与Resnet、GoogleNet类似,都是为了解决深层网络梯度消失问题的网络。

Resnet从深度方向出发,通过建立前面层与后面层之间的“短路连接”或“捷径”,从而能训练出更深的CNN网络。

GoogleNet从宽度方向出发,通过Inception(利用不同大小的卷积核实现不同尺度的感知,最后进行融合来得到图像更好的表征)。

DenseNet从特征入手,通过对前面所有层与后面层的密集连接,来极致利用训练过程中的所有特征,进而达到更好的效果和减少参数。

DenseNet网络

Dense Block:像GoogLeNet网络由Inception模块组成、ResNet网络由残差块(Residual Building Block)组成一样,DenseNet网络由Dense Block组成,论文截图如下所示:每个层从前面的所有层获得额外的输入,并将自己的特征映射传递到后续的所有层,使用级联(Concatenation)方式,每一层都在接受来自前几层的”集体知识(collective knowledge)”。增长率(growth rate)k是每个层的额外通道数。

58c8038c8e5f0cf7dea34eb09bd15c88.png

其实说了那么多我也不大明白原理和数学推理,只需要按照相关代码做就行了

class Bottleneck(nn.Module):def __init__(self, input_channel, growth_rate):super(Bottleneck, self).__init__()self.bn1 = nn.BatchNorm2d(input_channel)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(input_channel, 4 * growth_rate, kernel_size=1)self.bn2 = nn.BatchNorm2d(4 * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(4 * growth_rate, growth_rate, kernel_size=3, padding=1)def forward(self, x):out = self.conv1(self.relu1(self.bn1(x)))out = self.conv2(self.relu2(self.bn2(out)))out = torch.cat([out, x], 1)return out
class Transition(nn.Module):def __init__(self, input_channels, out_channels):super(Transition, self).__init__()self.bn = nn.BatchNorm2d(input_channels)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(input_channels, out_channels, kernel_size=1)def forward(self, x):out = self.conv(self.relu(self.bn(x)))out = F.avg_pool2d(out, 2)return out
class DenseNet(nn.Module):def __init__(self, nblocks, growth_rate, reduction, num_classes):super(DenseNet, self).__init__()self.growth_rate = growth_ratenum_planes = 2 * growth_rateself.basic_conv = nn.Sequential(nn.Conv2d(3, 2 * growth_rate, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(2 * growth_rate),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.dense1 = self._make_dense_layers(num_planes, nblocks[0])num_planes += nblocks[0] * growth_rateout_planes = int(math.floor(num_planes * reduction))self.trans1 = Transition(num_planes, out_planes)num_planes = out_planesself.dense2 = self._make_dense_layers(num_planes, nblocks[1])num_planes += nblocks[1] * growth_rateout_planes = int(math.floor(num_planes * reduction))self.trans2 = Transition(num_planes, out_planes)num_planes = out_planesself.dense3 = self._make_dense_layers(num_planes, nblocks[2])num_planes += nblocks[2] * growth_rateout_planes = int(math.floor(num_planes * reduction))self.trans3 = Transition(num_planes, out_planes)num_planes = out_planesself.dense4 = self._make_dense_layers(num_planes, nblocks[3])num_planes += nblocks[3] * growth_rateself.AdaptiveAvgPool2d = nn.AdaptiveAvgPool2d(1)# 全连接层self.fc = nn.Sequential(nn.Linear(num_planes, 256),nn.ReLU(inplace=True),# 使一半的神经元不起作用,防止参数量过大导致过拟合nn.Dropout(0.5),nn.Linear(256, 128),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(128, 10))def _make_dense_layers(self, in_planes, nblock):layers = []for i in range(nblock):layers.append(Bottleneck(in_planes, self.growth_rate))in_planes += self.growth_ratereturn nn.Sequential(*layers)def forward(self, x):out = self.basic_conv(x)out = self.trans1(self.dense1(out))out = self.trans2(self.dense2(out))out = self.trans3(self.dense3(out))out = self.dense4(out)out = self.AdaptiveAvgPool2d(out)out = out.view(out.size(0), -1)out = self.fc(out)return out
def DenseNet121():return DenseNet([6, 12, 24, 16], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet169():return DenseNet([6, 12, 32, 32], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet201():return DenseNet([6, 12, 48, 32], growth_rate=32, reduction=0.5, num_classes=10)
def DenseNet265():return DenseNet([6, 12, 64, 48], growth_rate=32, reduction=0.5, num_classes=10)
# 初始化模型
from torchstat import stat
# 定义模型输出模式,GPU和CPU均可
model = DenseNet121().to(DEVICE)

在NVIDIA GeForce GTX 1660 SUPER显卡上训练了100轮,大致上一轮1分钟,这是DenseNet网络训练的损失率和准确率,在验证集也是保持80%的准确率。

fef4e1c3ccec7ba873ae14960d444595.png

DenseNet也是一个系列,包括DenseNet-121、DenseNet-169等等,论文中给出了4种层数的DenseNet,论文截图如下所示:所有网络的增长率k是32,表示每个Dense Block中每层输出的feature map个数。

410bfbcfc3a6141a28efec184547aa49.png

关于图像分类的模型算法,热情也没了,到此也就告一段落了,后续再讨论一些新的话题。

最后欢迎关注公众号:python与大数据分析

47d362ba65d9cc25fac0dea80aa05dc0.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45596.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QChart:数据可视化(用图像形式显示数据内容)

1、数据可视化的图形有:柱状/线状/条形/面积/饼/点图、仪表盘、走势图,弦图、金字塔、预测曲线图、关系图、数学公式图、行政地图、GIS地图等。 2、在QT Creator的主页面,点击 欢迎》示例》右侧输入框 输入Chart,即可查看到QChar…

go es实例

go es实例 1、下载第三方库 go get github.com/olivere/elastic下载过程中出现如下报错: 解决方案: 2、示例 import package mainimport ("context""encoding/json""fmt""reflect""time""…

LabVIEW模拟化学反应器的工作

LabVIEW模拟化学反应器的工作 近年来,化学反应器在化学和工业过程领域有许多应用。高价值产品是通过混合产品,化学反应,蒸馏和结晶等多种工业过程转换原材料制成的。化学反应器通常用于大型加工行业,例如酿酒厂公司饮料产品的发酵…

微信小程序中如何动态添加 class 属性

在微信小程序中,你可以使用setData方法来动态添加class。首先,在你的页面的js文件中,定义一个变量来存储需要动态添加的class,例如: data: {dynamicClass: }然后,在需要动态添加class的地方,使…

list元素

列表元素 列表元素分为有序列表和无序列表 有序列表 ol – order list – 有序列表 li – list item – 列表元素 <ol type"1"><li>有序列表1</li><li>有序列表2</li><li>有序列表3</li> </ol>属性 type type属…

提示词4大经典框架;将AI融入动画工作流的案例和实践经验;构建基于LLM的系统和产品的模式;提示工程的艺术 | ShowMeAI日报

&#x1f440;日报&周刊合集 | &#x1f3a1;生产力工具与行业应用大全 | &#x1f9e1; 点赞关注评论拜托啦&#xff01; &#x1f916; 高效提示词的4大经典框架&#xff1a;ICIO、CRISPE、BROKE、RASCEF ICIO 框架 Intruction (任务) &#xff1a;你希望AI去做的任务&am…

2023年目标检测研究进展

综述 首先关于写这个笔记&#xff0c;我个人思考了很久关于以下几点。1&#xff1a;19年开始从做OCR用到图像和文本这种多模态联合处理的后&#xff0c;也就有意识的开始关注自然语言处理&#xff0c;这样的结果导致可能停留在前期图像上的学习和实践&#xff0c;停滞的研究如…

微服务中间件--Ribbon负载均衡

Ribbon负载均衡 a.Ribbon负载均衡原理b.Ribbon负载均衡策略 (IRule)c.Ribbon的饥饿加载 a.Ribbon负载均衡原理 1.发起请求http://userservice/user/1&#xff0c;Ribbon拦截该请求 2.Ribbon通过EurekaServer拉取userservice 3.EurekaServer返回服务列表给Ribbon做负载均衡 …

bug记录:微信小程序 给button使用all: initial重置样式

场景&#xff1a;通过uniapp开发微信小程序 &#xff0c;使用uview的u-popup弹窗&#xff0c;里面内嵌了一个原生button标签&#xff0c;因为微信小程序的button是有默认样式的&#xff0c;所以通过all: initial重置样式 。但是整个弹窗的点击事件都会被button上面的点击事件覆…

数据库结构差异对比工具

简介 前几年写了一个数据库对比工具&#xff0c;但是由于实现方式的原因&#xff0c;数据库支持有限&#xff0c;所以重新设计了一下&#xff0c;便于支持多种数据库&#xff0c;并且更新了UI。 新版地址&#xff1a;https://gitee.com/xgpxg/db-diff 旧版地址&#xff1a;h…

[K8s]问题描述:k8s拉起来的容器少了cuda的so文件

问题解决&#xff1a;需要设置Runtimes&#xff1a;nvidia的同时设置Default Runtimenvidia

NVIDIA Jetson 项目:机器人足球比赛

推荐&#xff1a;使用 NSDT场景编辑器 助你快速搭建可二次编辑器的3D应用场景 事实上&#xff0c;整个比赛都致力于这个想法。RoboCup小型联盟&#xff08;SSL&#xff09;视觉停电技术挑战赛鼓励团队“探索本地传感和处理&#xff0c;而不是非车载计算机和全球摄像机感知环境的…

go语言中channel类型

目录 一、什么是channel 二、为什么要有channel 三、channel操作使用 初始化 操作 单向channel 双向channel&#xff0c;可读可写 四、close下什么场景会出现panic 五、总结 一、什么是channel Channels are a typed conduit through which you can send and receive …

C# --- Struct and Record

C# --- Struct and Record StructRecord Struct struct是一种数据类型, 和class非常类似, 主要有以下的不同 struct是value type, class是reference type 因为是value type所以strcut不是必须储存在heap上struct不能等于null, The default value for a struct is an empty inst…

第6步---MySQL的控制流语句和窗口函数

第6步---MySQL的控制流语句和窗口函数 1.IF关键字 -- 控制流语句 SELECT IF(5>3,大于,小于);-- 会单独生成一列的 SELECT *,IF(score >90 , 优秀, 一般) 等级 FROM stu_score;-- IFNULL(expr1,expr2) SELECT id,name ,IFNULL(salary,0),dept_id FROM emp4;-- ISNULL() …

Java-类与对象(上)

什么是面向对象 Java是一门纯面向对象的语言(Object Oriented Program&#xff0c;简称OOP)&#xff0c;在面向对象的世界里&#xff0c;一切皆为对象。 面向对象是解决问题的一种思想&#xff0c;主要依靠对象之间的交互完成一件事情。 以面向对象方式来进行处理&#xff0c;就…

第6章 分布式文件存储

mini商城第6章 分布式文件存储 一、课题 分布式文件存储 二、回顾 1、理解Oauth2.0的功能作模式 2、实现mini商城项目的权限登录 三、目标 1、了解文件存储系统的概念 2、了解常用文件服务器的区别 3、掌握Minio的应用 四、内容 第1章 MinIO简介 官

CentOS 7重置root密码

CentOS 7 如何找回被您 遗忘得 root密码呢&#xff1f; 步骤如下&#xff1a; 步骤一&#xff1a;在开机出现如下界面的时候就按“e”键 步骤二&#xff1a;在步骤一按下”e”键之后&#xff0c;出现如下界面&#xff0c;按 ↓键一直到底部找到“LANGzh_CN.UTF-8”这句&…

【物联网无线通信技术】NFC从理论到实践(FM17XX)

NFC&#xff0c;全称是Near Field Communication&#xff0c;即“近场通信”&#xff0c;也叫“近距离无线通信”。NFC诞生于2004年&#xff0c;是基于RFID非接触式射频识别技术演变而来&#xff0c;由当时的龙头企业NXP(原飞利浦半导体)、诺基亚以及索尼联合发起。NFC采用13.5…

Excel VBA 复制除指定工作表外所有的工作表的内容到一张工作表中

当我们有一张表里面有很多sheet 具有相同的表结构&#xff0c;如果需要汇总到一张表中&#xff0c;那么我们可以借助VBA 去实现汇总自动化 Sub 复制所有工作表内容()Dim ws As WorksheetDim targetSheet As WorksheetDim lastRow As Long 设置目标表格&#xff0c;即要将所有…