分类神经网络3:DenseNet模型复现

目录

DenseNet网络架构

DenseNet部分实现代码


DenseNet网络架构

论文原址:https://arxiv.org/pdf/1608.06993.pdf

稠密连接神经网络(DenseNet)实质上是ResNet的进阶模型(了解ResNet模型请点击),二者均是通过建立前面层与后面层之间的“短路连接”,但不同的是,DenseNet建立的是前面所有层与后面层的密集连接,其一大特点是通过特征在通道上的连接来实现特征重用,这让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。DenseNet 网络的模型结构如下:

DenseNet 的网络结构主要由DenseBlockTransition Layer组成。

DenseBlock:密集连接机制。互相连接所有的层,即每一层的输入都来自于它前面所有层的特征图,每一层的输出均会直接连接到它后面所有层的输入,这可以实现特征重用(即对不同“级别”的特征——不同表征进行总体性地再探索),提升效率。具体的连接方式如下图示:

在同一个DenseBlock当中,特征层的高宽不会发生改变,但是通道数会发生改变可以看出DenseBlock中采用了BN+ReLU+Conv的结构,然而一般网络是用Conv+BN+ReLU的结构。这是由于卷积层的输入包含了它前面所有层的输出特征,它们来自不同层的输出,因此数值分布差异比较大,所以它们在输入到下一个卷积层时,必须先经过BN层将其数值进行标准化,然后再进行卷积操作。通常为了减少参数,一般还会先加一个1x1 卷积来减少参数量。所以DenseBlock中的每一层采用BN+ReLU+1x1Conv 、Conv+BN+ReLU+3x3 Conv的结构。

Transition Layer:用于将不同DenseBlock之间进行连接,整合上一个DenseBlock获得的特征,并且缩小上一个DenseBlock的宽高,达到下采样的效果,实质上起到压缩模型的作用。Transition Layer中一般包含一个1x1卷积(用于调整通道数)和2x2平均池化(用于降低特征图大小),结构为BN+ReLU+1x1 Conv+2x2 AvgPooling

DenseNet网络的具体配置信息如下:

可以看出,一个DenseNet中一般有3个或4个DenseBlock,最后的DenseBlock后连接了一个最大池化层,然后是一个包含1000个类别的全连接层,通过softmax激活函数得到类别属性。

DenseNet部分实现代码

直接上干货

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from collections import OrderedDict__all__ = ["densenet121", "densenet161", "densenet169", "densenet201"]class DenseLayer(nn.Module):def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient = False):super(DenseLayer,self).__init__()self.norm1 = nn.BatchNorm2d(num_input_features)self.relu1 = nn.ReLU(inplace=True)self.conv1 = nn.Conv2d(num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)self.relu2 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)self.drop_rate = float(drop_rate)self.memory_efficient = memory_efficientdef bn_function(self, inputs):concated_features = torch.cat(inputs, 1)bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features)))return bottleneck_outputdef any_requires_grad(self, input):for tensor in input:if tensor.requires_grad:return Truereturn False@torch.jit.unuseddef call_checkpoint_bottleneck(self, input):def closure(*inputs):return self.bn_function(inputs)return cp.checkpoint(closure, *input)def forward(self, input):if isinstance(input, torch.Tensor):prev_features = [input]else:prev_features = inputif self.memory_efficient and self.any_requires_grad(prev_features):if torch.jit.is_scripting():raise Exception("Memory Efficient not supported in JIT")bottleneck_output = self.call_checkpoint_bottleneck(prev_features)else:bottleneck_output = self.bn_function(prev_features)new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))if self.drop_rate > 0:new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)return new_featuresclass DenseBlock(nn.ModuleDict):def __init__(self,num_layers,num_input_features,bn_size,growth_rate,drop_rate,memory_efficient = False,):super(DenseBlock,self).__init__()for i in range(num_layers):layer = DenseLayer(num_input_features + i * growth_rate,growth_rate=growth_rate,bn_size=bn_size,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.add_module("denselayer%d" % (i + 1), layer)def forward(self, init_features):features = [init_features]for name, layer in self.items():new_features = layer(features)features.append(new_features)return torch.cat(features, 1)class Transition(nn.Sequential):"""Densenet Transition Layer:1 × 1 conv2 × 2 average pool, stride 2"""def __init__(self, num_input_features, num_output_features):super(Transition,self).__init__()self.norm = nn.BatchNorm2d(num_input_features)self.relu = nn.ReLU(inplace=True)self.conv = nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False)self.pool = nn.AvgPool2d(kernel_size=2, stride=2)class DenseNet(nn.Module):def __init__(self,growth_rate = 32,num_init_features = 64,block_config = None,num_classes = 1000,bn_size = 4,drop_rate = 0.,memory_efficient = False,):super(DenseNet,self).__init__()# First convolutionself.features = nn.Sequential(OrderedDict([("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),("norm0", nn.BatchNorm2d(num_init_features)),("relu0", nn.ReLU(inplace=True)),("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),]))# Each denseblocknum_features = num_init_featuresfor i, num_layers in enumerate(block_config):block = DenseBlock(num_layers=num_layers,num_input_features=num_features,bn_size=bn_size,growth_rate=growth_rate,drop_rate=drop_rate,memory_efficient=memory_efficient,)self.features.add_module("denseblock%d" % (i + 1), block)num_features = num_features + num_layers * growth_rateif i != len(block_config) - 1:trans = Transition(num_input_features=num_features, num_output_features=num_features // 2)self.features.add_module("transition%d" % (i + 1), trans)num_features = num_features // 2# Final batch normself.features.add_module("norm5", nn.BatchNorm2d(num_features))# Linear layerself.classifier = nn.Linear(num_features, num_classes)# Official init from torch repo.for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.constant_(m.bias, 0)def forward(self, x):features = self.features(x)out = F.relu(features, inplace=True)out = F.adaptive_avg_pool2d(out, (1, 1))out = torch.flatten(out, 1)out = self.classifier(out)return outdef densenet121(num_classes):"""Densenet-121 model"""return DenseNet(32, 64, (6, 12, 24, 16),num_classes=num_classes)def densenet161(num_classes):"""Densenet-161 model"""return DenseNet(48, 96, (6, 12, 36, 24),  num_classes=num_classes)def densenet169(num_classes):"""Densenet-169 model"""return DenseNet(32, 64, (6, 12, 32, 32),   num_classes=num_classes)def densenet201(num_classes):"""Densenet-201 model"""return DenseNet(32, 64, (6, 12, 48, 32), num_classes=num_classes)if __name__=="__main__":# from torchsummaryX import summarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = densenet121(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)# summary(net, torch.ones((1, 3, 224, 224)).to(device))

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/829714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java面试八股文-2024

面试指南 TMD,一个后端为什么要了解那么多的知识,真是服了。啥啥都得了解 MySQL MySQL索引可能在以下几种情况下失效: 不遵循最左匹配原则:在联合索引中,如果没有使用索引的最左前缀,即查询条件中没有包含…

linux demo

1.1)if case test the results #!bin/bash read -p “请输入你的成绩:” num if [ $num -ge 0 ] && [ $num -le 100 ];then if [ $num -ge 80 ] && [ $num -le 100 ];then echo “成绩优秀” elif [ $num -ge 60 ] && [ $num …

Altera FPGA 配置flash读写

目录 一、读写控制器的配置 二、生成flash的配置文件 三、关于三种配置文件的大小 四、其他 一、读写控制器的配置 Altera ASMI Parallel(下文简称ASMI)这个IP就仅仅是个Flash读写控制器,可以自由的设计数据来源。 关于这个IP的使用,可以…

【ARMv9 DSU-120 系列 2. -- DSU-120 Cluster 中组件详细介绍】

请阅读【Arm DynamIQ™ Shared Unit-120 专栏 】 文章目录 DynamIQ cluster componentsCoresComplexescluster shared logic主要特点小结Shared Logic ComponentsSnoop Control Unit缓存直接传输窥探过滤器自动大小调整Clock manag

MAC有没有免费NTFS tuxera激活码 tuxera破解 tuxera for mac2023序列号直装版 ntfs formac教程

Tuxera NTFS 2023破解版是一款非常好用的在线磁盘读写工具,该软件允许mac用户在Windows NTFS格式的硬盘上进行读写操作,Mac的文件系统是HFS,而Windows则使用NTFS格式,这导致在Mac系统上不能直接读写Windows格式的硬盘。然而&#…

【C++】指针与引用

文章目录 指针什么是指针?使用指针C++ 传递指针给函数C++ 从函数返回指针C++ Null 指针C++ 指针的算术运算指针递增与递减C++ 指针 vs 数组C++ 指向指针的指针(多级间接寻址)引用C++ 引用 vs 指针创建引用把引用作为参数把引用作为返回值参考链接指针 每一个变量都有一个内…

程序员:写好代码就行了,为什么要学写作

🍁 展望:关注我, AI 学习之旅上,我与您一同成长! 一、引言 在当今这个信息爆炸的时代,程序员们往往沉浸在代码的世界里,用代码来解决问题。然而,你是否曾想过,除了代码,…

智慧校园-教务管理系统建设要素

自友科技教务管理系统是一种在现代化的教务管理理念和信息化管理技术之上的一种能够将学籍管理、教研备课、教学计划、教师工作、考务管理、试卷管理、成绩、选修课等紧密地联系起来,实现教务信息管理的一体化的系统,这样能够大大减少教务管理的人工操作…

windows ubuntu sed,awk,grep篇:2:sed 替换命令

目录 6.sed 替换命令语法 7.全局标志 g 8.数字标志(1,2,3 ….) 9.打印标志 p(print) 10.写标志 w 11.忽略大小写标志 i (ignore) 12.执行命令标志 e (excuate) 13.使用替换标志组合 14.sed 替换命令分界符 15.单行内容上执行多个命令 16.&的作用——获取匹配到的模式 17.分…

INSTEAD OF 触发器的创建

Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 INSTEAD OF 触发器,也称替换触发器,是一种特殊的触发器,和其他建立在数据表上的触发器不同,INSTEAD OF 触发器建立在视图上。…

什么是vue,vue怎样使用?

Vue (读音 /vjuː/,类似于 view) 是一套用于构建用户界面的渐进式框架。与其它大型框架不同的是,Vue 被设计为可以自底向上逐层应用。Vue 的核心库只关注视图层,不仅易于上手,还便于与第三方库或既有项目整合。另一方面&#xff0…

C++ 重载 [] 运算符

刚开始我是震惊的! 我从未想过[]下居然有逻辑! 从学步开始 曾因会使用a[0]访问数组元素而沾沾自喜 曾固步自封的认为[] ,理应是访问数组的一种方式 天真快乐的同时,认为[]只是一个无情的标识! 所以 当我们写下a[0]时,究竟是为了什么? 是为了找到a[0]对应的值 那么如何…

Podman入门全指南:安装、配置与运行容器

欢迎来到我的博客,代码的世界里,每一行都是一个故事 Podman入门全指南:安装、配置与运行容器 前言Podman简介什么是 Podman?Podman 与 Docker 的主要区别 安装Podman支持的操作系统和环境安装步骤详解LinuxUbuntuCentOS/RHEL MacO…

双系统下删除ubuntu

絮絮叨叨 由于我在安装Ubuntu的时候没有自定义安装位置,而是使用与window共存的方式让Ubuntu自己选择安装位置,导致卸载时我不知道去格式化哪个分区,查阅多方资料后无果,后在大佬帮助下找到解决方案 解决步骤 1、 插上Ubuntu安…

Axure如何调起浏览器的打印功能

Axure如何调起浏览器的打印功能 答:javascript:window.print(); 不明白的继续往下看 应用场景: 原型设计中,页面上的打印按钮,需要模拟操作演示,需要点击指定的按钮时,唤起浏览器的打印功能&#xff08…

使用Pandas从Excel文件中提取满足条件的数据并生成新的文件

目录 一、引言 二、环境准备 三、读取Excel文件 四、数据筛选 五、保存为新的Excel文件 六、案例与代码总结 七、进阶用法与注意事项 八、结语 在数据处理的日常工作中,我们经常需要从大量数据中筛选出满足特定条件的数据集。Pandas是一个强大的Python数据分…

比 PSD.js 更强的下一代 PSD 解析器,支持 WebAssembly

比 PSD.js 更强的下一代 PSD 解析器,支持 WebAssembly 1.什么是 webtoon/ps webtoon/ps 是 Typescript 中轻量级 Adobe Photoshop .psd/.psb 文件解析器,对 Web 浏览器和 NodeJS 环境提供支持,且做到零依赖。 Fast zero-dependency PSD par…

2024 年最好的免费数据恢复软件,您可以尝试的几个数据恢复软件

由于系统崩溃而丢失数据可能会给用户带来麻烦。我们将重要的宝贵数据和个人数据保存在我们的 PC、笔记本电脑和其他数字设备上。您可能会因分区丢失、意外删除文件和文件夹、格式化硬盘驱动器而丢失数据。数据丢失是不幸的,如果您不小心从系统中删除了文件或数据&am…

初识ChatGPT

初识ChatGPT AIGC这么火热,了解一下?本文主要通过ChatGPT整理了人工智能和GPT相关的很多概念,看完之后,应该能瞥见人工智能的冰山一角。 参考 GPT-4预示着前端开发的终结?你准备好面对无法预测的技术挑战了吗&#…

MATLAB初学者入门(23)—— 旅行商问题(TSP)优化

旅行商问题(TSP, Traveling Salesman Problem)是一个经典的优化问题,要求找到一个最短的路线,使得旅行商从一个城市出发,经过所有城市一次后,回到原出发点。这是一个NP难问题,在数学优化和计算机…