第一步:准备数据
头发分割数据,总共有1050张图片,里面的像素值为0和1,所以看起来全部是黑的,不影响使用
第二步:搭建模型
DeepLabV3+的网络结构如下图所示,主要为Encoder-Decoder结构。其中,Encoder为改进的DeepLabV3,Decoder为3+版本新提出的。
1.1、Encoder
在Encoder部分,主要包括了backbone(即:图1中的DCNN)、ASPP两大部分。
其中backbone有两种网络结构:将layer4改为空洞卷积的Resnet系列、改进的Xception。从backbone出来的feature map分两部分:一部分是最后一层卷积输出的feature maps,另一部分是中间的低级特征的feature maps;backbone输出的第一部分送入ASPP模块,第二部分则送入Decoder模块。
ASPP模块接受backbone的第一部分输出作为输入,使用了四种不同膨胀率的空洞卷积块(包括卷积、BN、激活层)和一个全局平均池化块(包括池化、卷积、BN、激活层)得到一共五组feature maps,将其concat起来之后,经过一个1*1卷积块(包括卷积、BN、激活、dropout层),最后送入Decoder模块。
1.2、Decoder
在Decoder部分,接收来自backbone中间层的低级feature maps和来自ASPP模块的输出作为输入。
首先,对低级feature maps使用1*1卷积进行通道降维,从256降到48(之所以需要降采样到48,是因为太多的通道会掩盖ASPP输出的feature maps的重要性,且实验验证48最佳);
然后,对来自ASPP的feature maps进行插值上采样,得到与低级featuremaps尺寸相同的feature maps;
接着,将通道降维的低级feature maps和线性插值上采样得到的feature maps使用concat拼接起来,并送入一组3*3卷积块进行处理;
最后,再次进行线性插值上采样,得到与原图分辨率大小一样的预测图。
第三步:代码
1)损失函数为:交叉熵损失函数
2)网络代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
from nets.xception import xception
from nets.mobilenetv2 import mobilenetv2class MobileNetV2(nn.Module):def __init__(self, downsample_factor=8, pretrained=True):super(MobileNetV2, self).__init__()from functools import partialmodel = mobilenetv2(pretrained)self.features = model.features[:-1]self.total_idx = len(self.features)self.down_idx = [2, 4, 7, 14]if downsample_factor == 8:for i in range(self.down_idx[-2], self.down_idx[-1]):self.features[i].apply(partial(self._nostride_dilate, dilate=2))for i in range(self.down_idx[-1], self.total_idx):self.features[i].apply(partial(self._nostride_dilate, dilate=4))elif downsample_factor == 16:for i in range(self.down_idx[-1], self.total_idx):self.features[i].apply(partial(self._nostride_dilate, dilate=2))def _nostride_dilate(self, m, dilate):classname = m.__class__.__name__if classname.find('Conv') != -1:if m.stride == (2, 2):m.stride = (1, 1)if m.kernel_size == (3, 3):m.dilation = (dilate//2, dilate//2)m.padding = (dilate//2, dilate//2)else:if m.kernel_size == (3, 3):m.dilation = (dilate, dilate)m.padding = (dilate, dilate)def forward(self, x):low_level_features = self.features[:4](x)x = self.features[4:](low_level_features)return low_level_features, x #-----------------------------------------#
# ASPP特征提取模块
# 利用不同膨胀率的膨胀卷积进行特征提取
#-----------------------------------------#
class ASPP(nn.Module):def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):super(ASPP, self).__init__()self.branch1 = nn.Sequential(nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate,bias=True),nn.BatchNorm2d(dim_out, momentum=bn_mom),nn.ReLU(inplace=True),)self.branch2 = nn.Sequential(nn.Conv2d(dim_in, dim_out, 3, 1, padding=6*rate, dilation=6*rate, bias=True),nn.BatchNorm2d(dim_out, momentum=bn_mom),nn.ReLU(inplace=True), )self.branch3 = nn.Sequential(nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True),nn.BatchNorm2d(dim_out, momentum=bn_mom),nn.ReLU(inplace=True), )self.branch4 = nn.Sequential(nn.Conv2d(dim_in, dim_out, 3, 1, padding=18*rate, dilation=18*rate, bias=True),nn.BatchNorm2d(dim_out, momentum=bn_mom),nn.ReLU(inplace=True), )self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=True)self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)self.branch5_relu = nn.ReLU(inplace=True)self.conv_cat = nn.Sequential(nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=True),nn.BatchNorm2d(dim_out, momentum=bn_mom),nn.ReLU(inplace=True), )def forward(self, x):[b, c, row, col] = x.size()#-----------------------------------------## 一共五个分支#-----------------------------------------#conv1x1 = self.branch1(x)conv3x3_1 = self.branch2(x)conv3x3_2 = self.branch3(x)conv3x3_3 = self.branch4(x)#-----------------------------------------## 第五个分支,全局平均池化+卷积#-----------------------------------------#global_feature = torch.mean(x,2,True)global_feature = torch.mean(global_feature,3,True)global_feature = self.branch5_conv(global_feature)global_feature = self.branch5_bn(global_feature)global_feature = self.branch5_relu(global_feature)global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)#-----------------------------------------## 将五个分支的内容堆叠起来# 然后1x1卷积整合特征。#-----------------------------------------#feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)result = self.conv_cat(feature_cat)return resultclass DeepLab(nn.Module):def __init__(self, num_classes, backbone="mobilenet", pretrained=True, downsample_factor=16):super(DeepLab, self).__init__()if backbone=="xception":#----------------------------------## 获得两个特征层# 浅层特征 [128,128,256]# 主干部分 [30,30,2048]#----------------------------------#self.backbone = xception(downsample_factor=downsample_factor, pretrained=pretrained)in_channels = 2048low_level_channels = 256elif backbone=="mobilenet":#----------------------------------## 获得两个特征层# 浅层特征 [128,128,24]# 主干部分 [30,30,320]#----------------------------------#self.backbone = MobileNetV2(downsample_factor=downsample_factor, pretrained=pretrained)in_channels = 320low_level_channels = 24else:raise ValueError('Unsupported backbone - `{}`, Use mobilenet, xception.'.format(backbone))#-----------------------------------------## ASPP特征提取模块# 利用不同膨胀率的膨胀卷积进行特征提取#-----------------------------------------#self.aspp = ASPP(dim_in=in_channels, dim_out=256, rate=16//downsample_factor)#----------------------------------## 浅层特征边#----------------------------------#self.shortcut_conv = nn.Sequential(nn.Conv2d(low_level_channels, 48, 1),nn.BatchNorm2d(48),nn.ReLU(inplace=True)) self.cat_conv = nn.Sequential(nn.Conv2d(48+256, 256, 3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Conv2d(256, 256, 3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.Dropout(0.1),)self.cls_conv = nn.Conv2d(256, num_classes, 1, stride=1)def forward(self, x):H, W = x.size(2), x.size(3)#-----------------------------------------## 获得两个特征层# low_level_features: 浅层特征-进行卷积处理# x : 主干部分-利用ASPP结构进行加强特征提取#-----------------------------------------#low_level_features, x = self.backbone(x)x = self.aspp(x)low_level_features = self.shortcut_conv(low_level_features)#-----------------------------------------## 将加强特征边上采样# 与浅层特征堆叠后利用卷积进行特征提取#-----------------------------------------#x = F.interpolate(x, size=(low_level_features.size(2), low_level_features.size(3)), mode='bilinear', align_corners=True)x = self.cat_conv(torch.cat((x, low_level_features), dim=1))x = self.cls_conv(x)x = F.interpolate(x, size=(H, W), mode='bilinear', align_corners=True)return x
第四步:统计一些指标(训练过程中的loss和miou)
第五步:搭建GUI界面
第六步:整个工程的内容
项目源码下载:
整套算法系列:语义分割实战演练_AI洲抿嘴的薯片的博客-CSDN博客
项目源码下载地址:关注文末【AI街潜水的八角】,回复【头发分割】即可下载
整套项目源码内容包含
有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码