机器学习周记(第三十七周:语义分割)2024.4.29~2024.5.5

目录

摘要

ABSTRACT

1 DeepLabV3

1.1 空间金字塔池化(ASPP)

1.2 解码器(Decoder)

1.3 Xception

2 相关代码


摘要

  DeepLabV3+ 是由Google Brain团队开发的深度学习模型,专注于语义分割任务。它采用深度卷积神经网络和空间金字塔池化等先进技术,能够准确地捕获图像的语义信息。通过空洞卷积和多尺度预测等策略,DeepLabV3+能够在不同尺度下有效地分割图像,同时通过解码器模块和融合级联特征进一步提高分割精度。这使得DeepLabV3+成为处理复杂场景和提取细节信息的强大工具,在图像分割、医学图像分析等领域具有广泛的应用前景。

ABSTRACT

DeepLabV3+ is a deep learning model developed by the Google Brain team, focusing on semantic segmentation tasks. It employs advanced techniques such as deep convolutional neural networks and spatial pyramid pooling to accurately capture the semantic information of images. Through strategies like dilated convolutions and multi-scale prediction, DeepLabV3+ can effectively segment images at different scales, while further enhancing segmentation accuracy through decoder modules and feature cascading fusion. This makes DeepLabV3+ a powerful tool for handling complex scenes and extracting detailed information, with wide applications in areas such as image segmentation and medical image analysis.

1 DeepLabV3

  DeepLabv3+模型的整体架构如下图所示,它的Encoder的主体是带有空洞卷积的DCNN,可以采用常用的分类网络如ResNet,然后是带有空洞卷积的空间金字塔池化模块(Atrous Spatial Pyramid Pooling,ASPP),主要是为了引入多尺度信息;相比DeepLabv3v3+引入了Decoder模块,其将底层特征与高层特征进一步融合,提升分割边界准确度。从某种意义上看,DeepLabv3+DilatedFCN基础上引入了EcoderDecoder的思路。

  对于DilatedFCN,主要是修改分类网络的后面block,用空洞卷积来替换stride=2的下采样层,如下图所示:其中a是原始FCN,由于下采样的存在,特征图不断降低;而bDilatedFCN,在第block3后引入空洞卷积,在维持特征图大小的同时保证了感受野和原始网络一致。

  在DeepLab中,将输入图片与输出特征图的尺度之比记为output_stride,如上图的output_stride为16,如果加上ASPP结构,就变成如下图所示。其实这就是DeepLabv3结构,v3+只不过是增加了Decoder模块。这里的DCNN可以是任意的分类网络,一般又称为backbone,如采用ResNet网络。

1.1 空间金字塔池化(ASPP)

  在DeepLab中,采用空间金字塔池化模块来进一步提取多尺度信息,这里是采用不同rate的空洞卷积来实现这一点。ASPP模块主要包含以下几个部分: (1) 一个1×1卷积层,以及三个3x3的空洞卷积,对于output_stride=16,其rate为(6, 12, 18) ,若output_stride=8,rate加倍(这些卷积层的输出channel数均为256,并且含有BN层); (2)一个全局平均池化层得到image-level特征,然后送入1x1卷积层(输出256个channel),并双线性插值到原始大小; (3)将(1)和(2)得到的4个不同尺度的特征在channel维度concat在一起,然后送入1x1的卷积进行融合并得到256-channel的新特征。

  ASPP主要是为了抓取多尺度信息,这对于分割准确度至关重要,一个与ASPP结构比较像的是[PSPNet](Pyramid Scene Parsing Network)中的金字塔池化模块,如下图所示,主要区别在于这里采用池化层来获取多尺度特征。

1.2 解码器(Decoder)

  对于DeepLabv3,经过ASPP模块得到的特征图的output_stride为8或者16,其经过1x1的分类层后直接双线性插值到原始图片大小,这是一种非常暴力的decoder方法,特别是output_stride=16。然而这并不利于得到较精细的分割结果,故v3+模型中借鉴了EncoderDecoder结构,引入了新的Decoder模块,如下图所示。首先将encoder得到的特征双线性插值得到4x的特征,然后与encoder中对应大小的低级特征concat,如ResNet中的Conv2层,由于encoder得到的特征数只有256,而低级特征维度可能会很高,为了防止encoder得到的高级特征被弱化,先采用1x1卷积对低级特征进行降维(paper中输出维度为48)。两个特征concat后,再采用3x3卷积进一步融合特征,最后再双线性插值得到与原始图片相同大小的分割预测。 

1.3 Xception

  DeepLabv3所采用的backboneResNet网络,在v3+模型作者尝试了改进的XceptionXception网络主要采用depthwise separable convolution,这使得Xception计算量更小。改进的Xception主要体现在以下几点: (1)参考MSRA的修改(Deformable Convolutional Networks),增加了更多的层; (2)所有的最大池化层使用stride=2的depthwise separable convolutions替换,这样可以改成空洞卷积 ; (3)与MobileNet类似,在3x3 depthwise convolution后增加BN和ReLU。

  采用改进的Xception网络作为backboneDeepLab网络分割效果上有一定的提升。作者还尝试了在ASPP中加入depthwise separable convolution,发现在基本不影响模型效果的前提下减少计算量。

2 相关代码

class DeepLab:输入x经过backbone得到16倍下采样的feature map1和低级feature map2;feature map1送入ASPP模块,得到结果,然后和feature map2一起送入Decoder模块;最后经过插值得到与原图大小相等的预测图。代码如下:

class DeepLab(BaseModel):def __init__(self, num_classes, in_channels=3, backbone='xception', pretrained=True, output_stride=16, freeze_bn=False, **_):super(DeepLab, self).__init__()assert ('xception' or 'resnet' in backbone)if 'resnet' in backbone:self.backbone = ResNet(in_channels=in_channels, output_stride=output_stride, pretrained=pretrained)low_level_channels = 256else:self.backbone = Xception(output_stride=output_stride, pretrained=pretrained)low_level_channels = 128self.ASSP = ASSP(in_channels=2048, output_stride=output_stride)self.decoder = Decoder(low_level_channels, num_classes)if freeze_bn: self.freeze_bn()def forward(self, x):H, W = x.size(2), x.size(3)x, low_level_features = self.backbone(x)x = self.ASSP(x)x = self.decoder(x, low_level_features)x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)return x# Two functions to yield the parameters of the backbone# & Decoder / ASSP to use differentiable learning rates# FIXME: in xception, we use the parameters from xception and not aligned xception# better to have higher lr for this backbonedef get_backbone_params(self):return self.backbone.parameters()def get_decoder_params(self):return chain(self.ASSP.parameters(), self.decoder.parameters())def freeze_bn(self):for module in self.modules():if isinstance(module, nn.BatchNorm2d): module.eval()

需要注意的是:如果使用ResNet系列作为backbone,中间的低级feature map输出维度为256,如果使用Xception作为backbone,中间的低级feature map维度为128。不过,不管是256还是128,最终都要在送入Decoder后降采样到48通道。

backbone-ResNet:对于ResNet系列,一共有layer0~4,共五个layer。其中,前三个layers,也即layer0~layer2不变,仅针对layer3、layer4进行了改进,将普通卷积改为了空洞卷积。如果输出步幅(输入尺寸与输出feature map尺寸之比)为8,需要改动layer3和layer4;如果输出步幅为16,则仅改动layer4:

if output_stride == 16: s3, s4, d3, d4 = (2, 1, 1, 2)
elif output_stride == 8: s3, s4, d3, d4 = (1, 1, 2, 4)if output_stride == 8: for n, m in self.layer3.named_modules():if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'):m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3)elif 'conv2' in n:m.dilation, m.padding, m.stride = (d3,d3), (d3,d3), (s3,s3)elif 'downsample.0' in n:m.stride = (s3, s3)for n, m in self.layer4.named_modules():if 'conv1' in n and (backbone == 'resnet34' or backbone == 'resnet18'):m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4)elif 'conv2' in n:m.dilation, m.padding, m.stride = (d4,d4), (d4,d4), (s4,s4)elif 'downsample.0' in n:m.stride = (s4, s4)

此外,中间的低级feature maps在ResNet系列中,是layer1的输出。

backbone-Xception:如果以Xception作为backbone,则需要对Xception的中间流(Middle Flow)和出口流(Exit flow)进行改动:去掉原有的池化层,并将原有的卷积层替换为带有步长的可分离卷积,但是入口流(Entry Flow)不变:

# Stride for block 3 (entry flow), and the dilation rates for middle flow and exit flow
if output_stride == 16: b3_s, mf_d, ef_d = 2, 1, (1, 2)
if output_stride == 8: b3_s, mf_d, ef_d = 1, 2, (2, 4)# Entry Flow
self.conv1 = nn.Conv2d(in_channels, 32, 3, 2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(32)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64)self.block1 = Block(64, 128, stride=2, dilation=1, use_1st_relu=False)
self.block2 = Block(128, 256, stride=2, dilation=1)
self.block3 = Block(256, 728, stride=b3_s, dilation=1)# Middle Flow
for i in range(16):exec(f'self.block{i+4} = Block(728, 728, stride=1, dilation=mf_d)')# Exit flow
self.block20 = Block(728, 1024, stride=1, dilation=ef_d[0], exit_flow=True)self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=ef_d[1])
self.bn3 = nn.BatchNorm2d(1536)
self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=ef_d[1])
self.bn4 = nn.BatchNorm2d(1536)
self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=ef_d[1])
self.bn5 = nn.BatchNorm2d(2048)

而中间的低级feature maps在Xception系列中,是Entry Flow中block1的输出。

class ASPP:从backbone出来的输出步幅为16的feature maps被送入了ASPP模块,在该模块中经过不同膨胀率的卷积块和一个全局信息提取块后,concat起来,最后经过一个1*1卷积块之后,即为ASPP模块的输出。注意,这里之所以说是“块”,是因为其不单单包含一个操作,也包含了多个其他的操作,如BN、RELU、Dropout等,上文的1.1节等地方均有类似描述。如ASPP的不同膨胀率的分支定义如下:

def assp_branch(in_channels, out_channles, kernel_size, dilation):padding = 0 if kernel_size == 1 else dilationreturn nn.Sequential(nn.Conv2d(in_channels, out_channles, kernel_size, padding=padding, dilation=dilation, bias=False),nn.BatchNorm2d(out_channles),nn.ReLU(inplace=True))

全局信息提取块定义如下:

self.avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),nn.Conv2d(in_channels, 256, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True))

ASPP类定义的完整代码如下:

class ASSP(nn.Module):def __init__(self, in_channels, output_stride):super(ASSP, self).__init__()assert output_stride in [8, 16], 'Only output strides of 8 or 16 are suported'if output_stride == 16: dilations = [1, 6, 12, 18]elif output_stride == 8: dilations = [1, 12, 24, 36]self.aspp1 = assp_branch(in_channels, 256, 1, dilation=dilations[0])self.aspp2 = assp_branch(in_channels, 256, 3, dilation=dilations[1])self.aspp3 = assp_branch(in_channels, 256, 3, dilation=dilations[2])self.aspp4 = assp_branch(in_channels, 256, 3, dilation=dilations[3])self.avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),nn.Conv2d(in_channels, 256, 1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True))self.conv1 = nn.Conv2d(256*5, 256, 1, bias=False)self.bn1 = nn.BatchNorm2d(256)self.relu = nn.ReLU(inplace=True)self.dropout = nn.Dropout(0.5)initialize_weights(self)def forward(self, x):x1 = self.aspp1(x)x2 = self.aspp2(x)x3 = self.aspp3(x)x4 = self.aspp4(x)x5 = F.interpolate(self.avg_pool(x), size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)x = self.conv1(torch.cat((x1, x2, x3, x4, x5), dim=1))x = self.bn1(x)x = self.dropout(self.relu(x))return x

class Decoder:Decoder部分属于最后一部分了,其接受backbone的低级feature maps和ASPP输出的feature maps,并对其分别进行了降维、上采样,然后concat,最后经过一组3*3卷积块后输出。其类定义代码如下:

class Decoder(nn.Module):def __init__(self, low_level_channels, num_classes):super(Decoder, self).__init__()self.conv1 = nn.Conv2d(low_level_channels, 48, 1, bias=False)self.bn1 = nn.BatchNorm2d(48)self.relu = nn.ReLU(inplace=True)# Table 2, best performance with two 3x3 convsself.output = nn.Sequential(nn.Conv2d(48+256, 256, 3, stride=1, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Conv2d(256, 256, 3, stride=1, padding=1, bias=False),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Dropout(0.1),nn.Conv2d(256, num_classes, 1, stride=1),)initialize_weights(self)def forward(self, x, low_level_features):low_level_features = self.conv1(low_level_features)low_level_features = self.relu(self.bn1(low_level_features))H, W = low_level_features.size(2), low_level_features.size(3)x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)x = self.output(torch.cat((low_level_features, x), dim=1))return x

需要注意的是,该代码将最后的4倍上采样插值的操作放到Decoder外面了,这一点与论文稍有差别,但只是归属不同,效果是一样的,不影响使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/6806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QtWindows任务栏

目录 引言任务栏进度右键菜单缩略图工具栏完整代码 引言 针对Windows系统的任务栏,Qt基于系统的原生接口封装有一些非常见类,如QWinTaskbarButton、QWinTaskbarButton、QWinThumbnailToolBar等,用于利用工具栏提供更多的信息,诸如…

【CTF Web】XCTF GFSJ0482 weak_auth Writeup(弱口令+密码爆破)

weak_auth 小宁写了一个登陆验证页面,随手就设了一个密码。 解法 随便输入一些字符,提示以 admin 登录。 使用 Burp 抓包。 导入密码字典。 进行爆破。 得到密码。 账号:admin 密码:123456取得 flag。 Flag cyberpeace{42c9664…

Servlet框架

简介 Servlet是运行在web服务器或应用服务器上的程序,他是作为来自web浏览器或其他http客户端的请求和HTTP服务器上的数据库或应用程序之间的中间层。 使用Servlet可以手机来自网页表单的用户输入,呈现来自数据库或者其他源记录,还可以动态创…

解决网络ping不通问题

网络ping不通可能有多种原因,以下是一些常见的解决方法: 1. 检查IP地址和域名:确保你使用的是正确的IP地址或者域名来ping目标设备。如果IP地址或者域名错误,ping请求将无法到达目标设备。 2. 检查网络连接:首先确保…

【LeetCode刷题】153. 寻找旋转排序数组中的最小值

1. 题目链接2. 题目描述3. 解题方法4. 代码 1. 题目链接 153. 寻找旋转排序数组中的最小值 2. 题目描述 3. 解题方法 根据题目分析,可以明确一点,无论该数组如何旋转,都会有这样的一个性质,就是nums[0] > nums[n-1]&#xf…

RK3568 学习笔记 : u-boot 千兆网络无法 ping 通PC问题的解决方法二

参考 RK3568 学习笔记 : u-boot 千兆网络无法 ping 通PC问题的解决 前言 rk3568 rockchip 提供的 u-boot,默认的设备树需要读取 单独分区 resouce.img 镜像中的 设备树文件,也就是 Linux 内核的设备树 dtb 文件,gmac 网络才能正常的 ping 通…

STM32F1之FLASH闪存

目录 1. 简介 2. 闪存模块组织 3. FLASH基本结构 4. FLASH解锁 5. 使用指针访问存储器 6. 程序存储器全擦除 7. 程序存储器页擦除 8. 程序存储器编程 9. 选项字节 1. 简介 STM32F1系列的FLASH包含程序存储器、系统存储器和选项字节三个部分,通过…

【Android】Android应用性能优化总结

AndroidApp应用性能优化总结 最近大半年的时间里,大部分投在了某国内新能源汽车的某款AndroidApp开发上。 由于该App是该款车上,常用重点应用。所以车厂对应用性能的要求比较高。 主要包括: 应用冷启动达到***ms。应用热(温)启动达到***ms应…

RK3568笔记二十四:基于Flask的网页监控系统

若该文为原创文章,转载请注明原文出处。 此实验参考 《鲁班猫监控检测》,原代码有点BUG,已经下载不了。2. 鲁班猫监控检测 — [野火]嵌入式AI应用开发实战指南—基于LubanCat-RK系列板卡 文档 (embedfire.com) 一、简介 记录简单的摄像头监…

易语言IDE界面美化支持库

易语言IDE界面美化支持库 下载下来可以看到,是一个压缩包。 那么,怎么安装到易语言中呢? 解压之后,得到这两个文件。 直接将clr和lib丢到易语言安装目录中,这样子就安装完成了。 打开易语言,在支持库配置…

在营销的世界,你一定要记住:营满,则销

营销的世界中,有一个非常重要的一件事,这几个字一定要记住: 营满,则销;营未满,则不销。 你有没有把握,这是一个没办法可以复杂的东西,真得看营销人的直觉,营跟销是独立的两件事,营在营势,销是自然而然的。这里, 什么样的客户,看到什么样的产品。会有什么样的抗…

HCIP的学习(11)

OSPF的LSA详解 LSA头部信息 ​ [r2]display ospf lsdb router 1.1.1.1----查看OSPF某一条LSA的详细信息,类型以及LS ID参数。 链路状态老化时间 指一条LSA的老化时间,即存在了多长时间。当一条LSA被始发路由器产生时,该参数值被设定为0之后…

32 OpenCV Harris角点检测

文章目录 cornerHarris 算子示例 角点检测 cornerHarris 算子 void cv::cornerHarris ( InputArray src,OutputArray dst,int blockSize,int ksize,double K,int borderType BORDER_DEFAULT) src:待检测Harris角点的输入图像,图像必须是CV 8U或者CV 32F的单通道…

Maven 在项目的 pom.xml 文件中 指定 阿里云的景象仓库

配置 在 项目的 pom.xml 文件中添加如下配置即可 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation&…

【Unity】位图字体制作工具:蒲公英

一般来讲&#xff0c;如果需要制作位图字体&#xff0c;一般是使用 BMFont 这种第三方工具&#xff1a;BMFont - AngelCode.comhttp://www.angelcode.com/products/bmfont/ 然而这个工具对于非程序员来说&#xff0c;操作起来较为繁琐困难。每次美术修改了字体之后&…

【区块链】比特币架构

比特币架构 2009年1月&#xff0c;在比特币系统论文发表两个月之后&#xff0c;比特币系统正式运行并开放了源码&#xff0c;标志着比特币网络的正式诞生。通过其构建的一个公开透明、去中心化、防篡改的账本系统&#xff0c;比特币开展了一场规模空前的加密数字货币体验。在区…

C++手写协程项目(协程实现线程结构体、线程调度器定义,线程挂起、切换、恢复函数,模块测试)

协程结构体定义 之前我们使用linux下协程函数实现了线程切换&#xff0c;使用的是ucontext_t结构体&#xff0c;和基于这个结构体的四个函数。现在我们要用这些工具来实现我们自己的一个线程结构体&#xff0c;并实现线程调度和线程切换、挂起。 首先我们来实现以下线程结构体…

Linux常用软件安装(JDK、MySQL、Tomcat、Redis)

目录 一、上传与下载工具Filezilla1. filezilla官网 二、JDK安装1. 在opt中创建JDK目录2.上传JDK压缩文件到新建目录中3.卸载系统自代jdk4.安装JDK5.JDK环境变量配置6. 验证是否安装成功 三、安装MySQL1.创建mysql文件夹2.下载mysql安装压缩包3.上传到文件夹里面4. 卸载系统自带…

ThreeJS:光线投射与3D场景交互

光线投射Raycaster 光线投射详细介绍可参考&#xff1a;https://en.wikipedia.org/wiki/Ray_casting&#xff0c; ThreeJS中&#xff0c;提供了Raycaster类&#xff0c;用于进行鼠标拾取&#xff0c;即&#xff1a;当三维场景中鼠标移动时&#xff0c;利用光线投射&#xff0c;…

SpringCloudAlibaba:4.1云原生网关higress的搭建

概述 简介 Higress是基于阿里内部的Envoy Gateway实践沉淀、以开源Istio Envoy为核心构建的下一代云原生网关&#xff0c; 实现了流量网关 微服务网关 安全网关三合一的高集成能力&#xff0c;深度集成Dubbo、Nacos、Sentinel等微服务技术栈 定位 在虚拟化时期的微服务架构…