使用 TensorFlow 创建 DenseNet 121

一、说明

本篇示意DenseNet如何在tensorflow上实现,DenseNet与ResNet有类似的地方,都有层与层的“短路”方式,但两者对层的短路后处理有所不同,本文遵照原始论文的技术路线,完整复原了DenseNet的全部网络。

图1:DenseNet中的各种块和层(来源:原始DenseNet论文)

      

二、DenseNet综述

        DenseNet(密集卷积网络)是一种架构,专注于使深度学习网络更深入,但同时通过在层之间使用更短的连接来提高它们的训练效率。DenseNet 是一个卷积神经网络,其中每一层都连接到网络中更深的所有其他层,即第一层连接到第 2、3、4 层等,第二层连接到第 3、4、5 层等。这样做是为了在网络各层之间实现最大的信息流。为了保持前馈特性,每一层从前面的所有层获取输入,并将自己的特征图传递给它将要到达的所有层。与 Resnets 不同,它不是通过求和来组合特征,而是通过连接它们来组合特征。因此,“ith”层具有“i”输入,并且由其所有先前卷积块的特征图组成。它自己的特征图被传递到所有下一个“I-i”层。这在网络中引入了“(I(I+1)))/2”连接,而不是像传统深度学习架构中那样只是“I”连接。因此,与传统的卷积神经网络相比,它需要的参数更少,因为不需要学习不重要的特征图。

        DenseNet由两个重要的块组成,而不是基本的卷积层和池化层。它们是密集块和过渡层。

        接下来,我们看看所有这些块和层的外观,以及如何在 python 中实现它们。

图2:DenseNet121框架(来源:DenseNet原始论文,由作者编辑)

        DenseNet从基本的卷积和池化层开始。然后有一个密集块,然后是一个过渡层,另一个密集块后跟一个过渡层,另一个密集块后跟一个过渡层,最后是一个密集块,然后是一个分类层。

        第一个卷积块有 64 个大小为 7x7 的过滤器,步幅为 2。接下来是最大池化为 3x3 且步幅为 2 的 MaxPooling 层。这两行可以在 python 中用以下代码表示。

input = Input (input_shape)
x = Conv2D(64, 7, strides = 2, padding = 'same')(input)
x = MaxPool2D(3, strides = 2, padding = 'same')(x)

2.1 定义卷积块 —

        输入后的每个卷积块具有以下顺序:批处理归一化,然后是 ReLU 激活,然后是实际的 Conv2D 层。为了实现这一点,我们可以编写以下函数。

#batch norm + relu + conv
def bn_rl_conv(x,filters,kernel=1,strides=1):x = BatchNormalization()(x)x = ReLU()(x)x = Conv2D(filters, kernel, strides=strides,padding = 'same')(x)return x

图3.密集块(来源:DenseNet论文-作者编辑)

2.2 定义密集块 —

        如图 3 所示,每个密集块都有两个卷积,具有 1x1 和 3x3 大小的内核。在密集块 1 中,重复 6 次,在密集块 2 中重复 12 次,在密集块 3 中重复 24 次,最后在密集块 4 中重复 16 次。

在密集块中,每个 1x1 卷积都有 4 倍的滤波器数量。所以我们使用 4*过滤器,但 3x3 过滤器只存在一次。此外,我们必须将输入与输出张量连接起来。

每个块分别运行 6、12、24、16 次重复,使用 'for 循环'。

def dense_block(x, repetition):for _ in range(repetition):y = bn_rl_conv(x, 4*filters)y = bn_rl_conv(y, filters, 3)x = concatenate([y,x])return x

图4:过渡层(来源:DenseNet论文,作者编辑)

2.3 定义过渡层 

        — 在过渡层中,我们将通道数减少到现有通道的一半。有一个 1x1 卷积层和一个 2x2 平均池化层,步幅为 2。bn_rl_conv,函数中已经设置了 1x1 的内核大小,因此我们不需要明确地再次定义它。

        在过渡层中,我们必须将通道删除到现有通道的一半。我们有输入张量x,我们想找到有多少个通道,我们需要得到其中的一半。因此,我们可以使用 Keras 后端 (K) 获取张量 x 并返回一个维度为 x 的元组。而且,我们只需要该形状的最后一个数字,即过滤器的数量。所以我们加上 [-1]。最后,我们可以将这个数量的过滤器除以 2 以获得所需的结果。

def transition_layer(x):x = bn_rl_conv(x, K.int_shape(x)[-1] //2 )x = AvgPool2D(2, strides = 2, padding = 'same')(x)return x

        因此,我们完成了定义密集块和过渡层的工作。现在我们需要将密集块和过渡层堆叠在一起。所以我们写了一个 for 循环来运行 6,12,24,16 次重复。因此,循环运行 4 次,每次使用 6、12、24 或 16 中的值之一。这样就完成了 4 个密集块和过渡层。

for repetition in [6,12,24,16]:d = dense_block(x, repetition)x = transition_layer(d)

        最后,是GlobalAveragePooling,然后是最终的输出层。正如我们在上面的代码块中看到的,密集块由“d”定义,而在最后一层,在密集块 4 之后,没有过渡层 4,而是直接进入分类层。因此,“d”是应用GlobalAveragePooling的连接,而不是“x”。另一种选择是从上面的代码中删除“for”循环,并将层一个接一个地堆叠,而不使用最终的过渡层。

x = GlobalAveragePooling2D()(d)
output = Dense(n_classes, activation = 'softmax')(x)

现在我们已经将所有块放在一起,让我们将它们合并以查看整个DenseNet架构。

三、完整的 DenseNet 121 架构 

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Dense
from tensorflow.keras.layers import AvgPool2D, GlobalAveragePooling2D, MaxPool2D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import ReLU, concatenate
import tensorflow.keras.backend as K
# Creating Densenet121
def densenet(input_shape, n_classes, filters = 32):#batch norm + relu + convdef bn_rl_conv(x,filters,kernel=1,strides=1):x = BatchNormalization()(x)x = ReLU()(x)x = Conv2D(filters, kernel, strides=strides,padding = 'same')(x)return xdef dense_block(x, repetition):for _ in range(repetition):y = bn_rl_conv(x, 4*filters)y = bn_rl_conv(y, filters, 3)x = concatenate([y,x])return xdef transition_layer(x):x = bn_rl_conv(x, K.int_shape(x)[-1] //2 )x = AvgPool2D(2, strides = 2, padding = 'same')(x)return xinput = Input (input_shape)x = Conv2D(64, 7, strides = 2, padding = 'same')(input)x = MaxPool2D(3, strides = 2, padding = 'same')(x)for repetition in [6,12,24,16]:d = dense_block(x, repetition)x = transition_layer(d)x = GlobalAveragePooling2D()(d)output = Dense(n_classes, activation = 'softmax')(x)model = Model(input, output)return model
input_shape = 224, 224, 3
n_classes = 3
model = densenet(input_shape,n_classes)
model.summary()

输出:(假设 3 个最终类 — 模型摘要的最后几行)

四、 查看体系结构关系图 

        可以使用以下代码。

from tensorflow.python.keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
import pydot
import graphvizSVG(model_to_dot(model, show_shapes=True, show_layer_names=True, rankdir='TB',expand_nested=False, dpi=60, subgraph=False
).create(prog='dot',format='svg'))

        输出 — 图表的前几个块

        这就是我们如何实现DenseNet 121架构。

五、引用 

  1. 黄高、刘壮、劳伦斯·范德马滕和基利安·温伯格,密集连接的卷积网络,arXiv 1608.06993 (2016)

    阿琼·萨卡尔

       2 密网论文链接:https://arxiv.org/pdf/1608.06993.pdf 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/99484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java @Override 注解

在代码中,你可能会看到大量的 Override 注解。 这个注解简单来说就是让编译器去读的,能够避免你在写代码的时候犯一些低级的拼写错误。 Java Override 注解用来指定方法重写(Override),只能修饰方法并且只能用于方法…

怎么将Linux上的文件上传到github上

文章目录 1. 先在window浏览器中创建一个存储项目的仓库2. 复制你的ssh下的地址1) 生成ssh密钥 : 在Linux虚拟机的终端中,运行以下命令生成ssh密钥2)将ssh密钥添加到github账号 : 运行以下命令来获取公钥内容: 3. 克隆GitHub存储库:在Linux虚拟机的终端中&#xff0…

leetcode42 接雨水

题目 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图,计算按此排列的柱子,下雨之后能接多少雨水。 示例 输入:height [0,1,0,2,1,0,1,3,2,1,2,1] 输出:6 解释:上面是由数组 [0,1,0,2,1,0,1,3,2,1,2,1] 表示的高…

Golang网络编程:即时通讯系统Instance Messaging System

系统基本架构 版本迭代 项目改造 无人机是client,我们是server,提供注册登入,场景选择等。信道模拟器是server,我们是client,我们向信道模拟器发送数据,等待信道模拟器计算结果,返回给无人机。…

使用ChatGPT和MindShow一分钟生成PPT模板

对于最近学校组织的实习答辩,由于时间太短了,而且小编也特别的忙,于是就用ChatGPT结合MindShow一分钟快速生成PPT,确实很实用。只要你跟着小编后面,你也可以快速制作出这个PPT,下面小编就来详细介绍一下&am…

联想M7216NWA一体机连接WiFi及手机添加打印机方法

联想M7216NWA一体机连接WiFi方法: 1、首先按打印机操作面板上的“功能键”;【用“”(上翻页)“-”(下翻页)来选择菜单的内容】 2、下翻页键找到并选择“网络”,然后“确认键”; 3…

自动驾驶技术的基础知识

自动驾驶技术是现代汽车工业中的一项革命性发展,它正在改变着我们对交通和出行的理解。本文将介绍自动驾驶技术的基础知识,包括其概念、历史发展、分类以及关键技术要素。 1. 自动驾驶概念 自动驾驶是一种先进的交通技术,它允许汽车在没有人…

Docker安装——Ubuntu (Jammy 22.04)

一、为什么要用 Ubuntu?(centos和ubuntu有什么区别) 使用lsb_release命令:lsb_release -a ,即可查看ubantu的版本,但是为什么要使用ubantu 呢? 区别:1、centos基于EHEL开发,而ubunt…

三十三、【进阶】索引的分类

1、索引的分类 (1)总分类 主键索引、唯一索引、常规索引、全文索引 (2)InnoDB存储引擎中的索引分类 2、 索引的选取规则(InnoDB存储引擎) 如果存在主键,主键索引就是聚集索引; 如果不存在主键&#xff…

最新 SpringCloud微服务技术栈实战教程 微服务保护 分布式事务 课后练习等

SpringCloud微服务技术栈实战教程,涵盖springcloud微服务架构Nacos配置中心分布式服务等 SpringCloud及SpringCloudAlibaba是目前最流行的微服务技术栈。但大家学习起来的感受就是组件很多,不知道该如何应用。这套《微服务实战课》从一个单体项目入手&am…

C++项目:仿mudou库one thread one loop式并发服务器实现

目录 1.实现目标 2.HTTP服务器 3.Reactor模型 3.1分类 4.功能模块划分: 4.1SERVER模块: 4.2HTTP协议模块: 5.简单的秒级定时任务实现 5.1Linux提供给我们的定时器 5.2时间轮思想: 6.正则库的简单使用 7.通用类型any类型的实现 8.日志宏的实现 9.缓冲区…

深度学习 图像分割 PSPNet 论文复现(训练 测试 可视化)

Table of Contents 一、PSPNet 介绍1、原理阐述2、论文解释3、网络模型 二、部署实现1、PASCAL VOC 20122、模型训练3、度量指标4、结果分析5、图像测试 一、PSPNet 介绍 PSPNet(Pyramid Scene Parsing Network)来自于CVPR2017的一篇文章,中文翻译为金字塔场景解析…

YOLOv7暴力涨点:Gold-YOLO,遥遥领先,超越所有YOLO | 华为诺亚NeurIPS23

💡💡💡本文独家改进:提出了全新的信息聚集-分发(Gather-and-Distribute Mechanism)GD机制,Gold-YOLO,替换yolov7 head部分 实现暴力涨点 Gold-YOLO | 亲测在多个数据集能够实现大幅涨点,适用各个场景的涨点 收录: YOLOv7高阶自研专栏介绍: http://t.csdnim…

【产品经理】国内企业服务SAAS平台的生存与发展

SaaS在国外发展的比较成熟,甚至已经成为了主流,但在国内这几年才掀起热潮;企业服务SaaS平台在少部分行业发展较快,大部分行业在国内还处于起步、探索阶段;SaaS将如何再国内生存和发展? 在企业服务行业做了五…

钡铼BL124EC实现EtherCAT转Ethernet/IP的优势

钡铼技术的BL124EC是一款用于将EtherCAT从站转换为Ethernet/IP从站的网关设备。它是钡铼技术开发的高性能、可靠的工业自动化通信解决方案之一。 添加图片注释,不超过 140 字(可选) BL124EC网关可以应用于多种工业自动化场景,以下…

OSPF的7大状态和5大报文详讲

- Down OSPF的初始状态 - Init 初始化——我刚刚给别人发Hello报文 我们可以将OSPF邻居建立的过程理解为:我和你打招呼,你和我打招呼,然后咱俩成了邻居 比如: R1和R2要建立OSPF邻居 R1给R2发送了Hello报文,但是R1此时…

很烦的Node报错积累

目录 1. 卡在sill idealTree buildDeps2、Node Sass老是安装不上的问题3、unable to resolve dependency tree4、nvm相关命令5、设置淘宝镜像等基操5.1 镜像 5.2 npm清理缓存6、Browserslist: caniuse-lite is outdated loader 1. 卡在sill idealTree buildDeps 参考&#xf…

想要精通算法和SQL的成长之路 - 恢复二叉搜索树和有序链表转换二叉搜索树

想要精通算法和SQL的成长之路 - 恢复二叉搜索树和有序链表转换二叉搜索树 前言一. 恢复二叉搜索树二. 有序链表转换二叉搜索树 前言 想要精通算法和SQL的成长之路 - 系列导航 一. 恢复二叉搜索树 原题链接 首先,一个正常地二叉搜索树在中序遍历下,遍历…

antd的upload上传组件,上传成功后清除表单校验——基础积累

今天在写后台管理系统时&#xff0c;发现之前的一个bug&#xff0c;就是antd的upload上传组件&#xff0c;需要进行表单校验。 直接上代码&#xff1a; 1.html部分 <a-form-modelref"ruleForm":model"form":label-col"labelCol":wrapper-col…

通道剪枝channel pruning

1、相关定义 过参数化&#xff1a;主要是指在训练阶段&#xff0c;在数学上需要进行大量的微分求解&#xff0c;去捕捉数据中微小的变化信息&#xff0c;一旦完成迭代式的训练之后&#xff0c;网络模型在推理的时候就不需要这么多参数。剪枝算法&#xff1a;核心思想就是减少网…