图像语义分割是计算机视觉领域中的一个重要任务,它旨在对图像中的每个像素进行分类,从而实现对图像内容的详细理解。在众多图像语义分割算法中,全卷积网络(Fully Convolutional Networks, FCN)因其端到端的训练方式和高效的像素级预测能力而备受关注。本教程将带领你使用MindSpore实现FCN-8s模型,进行图像语义分割任务。通过该教程,你将学习数据预处理、网络构建、模型训练、评估和推理的完整流程。
全卷积网络(Fully Convolutional Networks,FCN)是UC Berkeley的Jonathan Long等人于2015年在论文《Fully Convolutional Networks for Semantic Segmentation》中提出的用于图像语义分割的一种框架。FCN是首个端到端(end to end)进行像素级(pixel level)预测的全卷积网络。
语义分割简介
语义分割(semantic segmentation)是图像处理和机器视觉技术中关于图像理解的重要一环,常被应用于人脸识别、物体检测、医学影像、卫星图像分析、自动驾驶感知等领域。语义分割的目的是对图像中每个像素点进行分类。与普通的分类任务只输出某个类别不同,语义分割任务输出与输入大小相同的图像,输出图像的每个像素对应了输入图像每个像素的类别。
模型简介
FCN主要用于图像分割领域,是一种端到端的分割方法。通过进行像素级的预测直接得出与原图大小相等的label map。因FCN丢弃全连接层替换为全卷积层,网络所有层均为卷积层,故称为全卷积网络。
全卷积神经网络主要使用以下三种技术:
-
卷积化(Convolutional):
使用VGG-16作为FCN的backbone。VGG-16的输入为224*224的RGB图像,输出为1000个预测值。VGG-16中共有三个全连接层,全连接层也可视为带有覆盖整个区域的卷积。将全连接层转换为卷积层能使网络输出由一维非空间输出变为二维矩阵,利用输出能生成输入图片映射的heatmap。 -
上采样(Upsample):
在卷积过程的卷积操作和池化操作会使得特征图的尺寸变小,为得到原图的大小的稠密图像预测,需要对得到的特征图进行上采样操作。使用双线性插值的参数来初始化上采样逆卷积的参数,后通过反向传播来学习非线性上采样。 -
跳跃结构(Skip Layer):
利用上采样技巧对最后一层的特征图进行上采样得到原图大小的分割是步长为32像素的预测,称之为FCN-32s。由于最后一层的特征图太小,损失过多细节,采用skips结构将更具有全局信息的最后一层预测和更浅层的预测结合,使预测结果获取更多的局部细节。
网络特点
- 不含全连接层(fc)的全卷积(fully conv)网络,可适应任意尺寸输入。
- 增大数据尺寸的反卷积(deconv)层,能够输出精细的结果。
- 结合不同深度层结果的跳级(skip)结构,同时确保鲁棒性和精确性。
数据处理
数据预处理和标准化:
- 为什么要标准化数据? 标准化数据有助于加速模型的收敛速度,并且可以防止数值过大或过小导致的数值不稳定问题。
- 为什么要进行数据增强? 数据增强(如随机裁剪、翻转等)可以增加数据的多样性,从而提高模型的泛化能力,减少过拟合。
下载并解压数据集:
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"
download(url, "./dataset", kind="tar", replace=True)
数据预处理
由于PASCAL VOC 2012数据集中图像的分辨率大多不一致,无法放在一个tensor中,故输入前需做标准化处理。
数据加载
将PASCAL VOC 2012数据集与SDB数据集进行混合。
import numpy as np
import cv2
import mindspore.dataset as dsclass SegDataset:def __init__(self,image_mean,image_std,data_file='',batch_size=32,crop_size=512,max_scale=2.0,min_scale=0.5,ignore_label=255,num_classes=21,num_readers=2,num_parallel_calls=4):self.data_file = data_fileself.batch_size = batch_sizeself.crop_size = crop_sizeself.image_mean = np.array(image_mean, dtype=np.float32)self.image_std = np.array(image_std, dtype=np.float32)self.max_scale = max_scaleself.min_scale = min_scaleself.ignore_label = ignore_labelself.num_classes = num_classesself.num_readers = num_readersself.num_parallel_calls = num_parallel_callsmax_scale > min_scaledef preprocess_dataset(self, image, label):image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)sc = np.random.uniform(self.min_scale, self.max_scale)new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)image_out = (image_out - self.image_mean) / self.image_stdout_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)pad_h, pad_w = out_h - new_h, out_w - new_wif pad_h > 0 or pad_w > 0:image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)offset_h = np.random.randint(0, out_h - self.crop_size + 1)offset_w = np.random.randint(0, out_w - self.crop_size + 1)image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]if np.random.uniform(0.0, 1.0) > 0.5:image_out = image_out[:, ::-1, :]label_out = label_out[:, ::-1]image_out = image_out.transpose((2, 0, 1))image_out = image_out.copy()label_out = label_out.copy()label_out = label_out.astype("int32")return image_out, label_outdef get_dataset(self):ds.config.set_numa_enable(True)dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],shuffle=True, num_parallel_workers=self.num_readers)transforms_list = self.preprocess_datasetdataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],output_columns=["data", "label"],num_parallel_workers=self.num_parallel_calls)dataset = dataset.shuffle(buffer_size=self.batch_size * 10)dataset = dataset.batch(self.batch_size, drop_remainder=True)return dataset# 定义创建数据集的参数
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"# 定义模型训练参数
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,image_std=IMAGE_STD,data_file=DATA_FILE,batch_size=train_batch_size,crop_size=crop_size,max_scale=max_scale,min_scale=min_scale,ignore_label=ignore_label,num_classes=num_classes,num_readers=2,num_parallel_calls=4)dataset = dataset.get_dataset()
训练集可视化
运行以下代码观察载入的数据集图片(数据处理过程中已做归一化处理)。
import numpy as np
import matplotlib.pyplot as pltplt.figure(figsize=(16, 8))# 对训练集中的数据进行展示
for i in range(1, 9):plt.subplot(2, 4, i)show_data = next(dataset.create_dict_iterator())show_images = show_data["data"].asnumpy()show_images = np.clip(show_images, 0, 1)
# 将图片转换HWC格式后进行展示plt.imshow(show_images[0].transpose(1, 2, 0))plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()
网络构建
网络流程
FCN网络的流程如下图所示:
- 输入图像image,经过pool1池化后,尺寸变为原始尺寸的1/2。
- 经过pool2池化,尺寸变为原始尺寸的1/4。
- 接着经过pool3、pool4、pool5池化,大小分别变为原始尺寸的1/8、1/16、1/32。
- 经过conv6-7卷积,输出的尺寸依然是原图的1/32。
- FCN-32s是最后使用反卷积,使得输出图像大小与输入图像相同。
- FCN-16s是将conv7的输出进行反卷积,使其尺寸扩大两倍至原图的1/16,并将其与pool4输出的特征图进行融合,后通过反卷积扩大到原始尺寸。
- FCN-8s是将conv7的输出进行反卷积扩大4倍,将pool4输出的特征图反卷积扩大2倍,并将pool3输出特征图拿出,三者融合后通反卷积扩大到原始尺寸。
使用以下代码构建FCN-8s网络。
import mindspore.nn as nnclass FCN8s(nn.Cell):def __init__(self, n_class):super().__init__()self.n_class = n_classself.conv1 = nn.SequentialCell(nn.Conv2d(in_channels=3, out_channels=64,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=64,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(64),nn.ReLU())self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.SequentialCell(nn.Conv2d(in_channels=64, out_channels=128,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(in_channels=128, out_channels=128,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(128),nn.ReLU())self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv3 = nn.SequentialCell(nn.Conv2d(in_channels=128, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU())self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv4 = nn.SequentialCell(nn.Conv2d(in_channels=256, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU())self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv5 = nn.SequentialCell(nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU())self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv6 = nn.SequentialCell(nn.Conv2d(in_channels=512, out_channels=4096,kernel_size=7, weight_init='xavier_uniform'),nn.BatchNorm2d(4096),nn.ReLU(),)self.conv7 = nn.SequentialCell(nn.Conv2d(in_channels=4096, out_channels=4096,kernel_size=1, weight_init='xavier_uniform'),nn.BatchNorm2d(4096),nn.ReLU(),)self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=4, stride=2, weight_init='xavier_uniform')self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=4, stride=2, weight_init='xavier_uniform')self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=16, stride=8, weight_init='xavier_uniform')def construct(self, x):x1 = self.conv1(x)p1 = self.pool1(x1)x2 = self.conv2(p1)p2 = self.pool2(x2)x3 = self.conv3(p2)p3 = self.pool3(x3)x4 = self.conv4(p3)p4 = self.pool4(x4)x5 = self.conv5(p4)p5 = self.pool5(x5)x6 = self.conv6(p5)x7 = self.conv7(x6)sf = self.score_fr(x7)u2 = self.upscore2(sf)s4 = self.score_pool4(p4)f4 = s4 + u2u4 = self.upscore_pool4(f4)s3 = self.score_pool3(p3)f3 = s3 + u4out = self.upscore8(f3)return out
训练准备
导入VGG-16部分预训练权重
加载预训练的VGG-16权重:
- 为什么要使用预训练模型? 使用预训练模型可以利用在大规模数据集上预训练的权重,从而加速训练过程,并且在数据量较小时,预训练模型可以提供更好的初始权重,提升模型性能。
FCN使用VGG-16作为骨干网络,用于实现图像编码。使用下面代码导入VGG-16预训练模型的部分预训练权重。
from download import download
from mindspore import load_checkpoint, load_param_into_neturl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16.ckpt"
download(url, "./fcn8s_vgg16.ckpt", kind="file", replace=True)# 加载预训练模型的参数
param_dict = load_checkpoint("fcn8s_vgg16.ckpt")# 创建FCN8s网络实例
net = FCN8s(n_class=21)# 加载参数到网络中
load_param_into_net(net, param_dict)
定义损失函数和优化器
- 为什么选择交叉熵损失函数? 交叉熵损失函数适用于多分类任务,它可以衡量模型输出的概率分布与真实分布之间的差异,是图像语义分割任务中常用的损失函数。
- 为什么选择Adam优化器? Adam优化器结合了动量和RMSProp的优点,具有较快的收敛速度和较好的鲁棒性,是深度学习中常用的优化器。
使用交叉熵损失函数(Cross Entropy Loss)和Adam优化器。
import mindspore.nn as nn
import mindspore.ops as opsclass CrossEntropyLoss(nn.Cell):def __init__(self, num_classes, ignore_label=255):super(CrossEntropyLoss, self).__init__()self.num_classes = num_classesself.ignore_label = ignore_labelself.one_hot = ops.OneHot()self.on_value = ops.OnesLike()(1.0)self.off_value = ops.ZerosLike()(0.0)self.ce = ops.SoftmaxCrossEntropyWithLogits()def construct(self, logits, labels):labels = ops.Reshape()(labels, (-1,))valid_mask = ops.NotEqual()(labels, self.ignore_label)labels = ops.MaskedSelect()(labels, valid_mask)logits = ops.Reshape()(logits, (-1, self.num_classes))logits = ops.MaskedSelect()(logits, valid_mask)logits = ops.Reshape()(logits, (-1, self.num_classes))one_hot_labels = self.one_hot(labels, self.num_classes, self.on_value, self.off_value)loss = self.ce(logits, one_hot_labels)return loss.mean()loss_fn = CrossEntropyLoss(num_classes=21)
optimizer = nn.Adam(net.trainable_params(), learning_rate=0.001)
定义训练函数
from mindspore import Model, context
from mindspore.train.callback import LossMonitor, TimeMonitorcontext.set_context(mode=context.GRAPH_MODE, device_target="CPU")# 创建模型
model = Model(net, loss_fn=loss_fn, optimizer=optimizer, metrics={"accuracy"})# 训练模型
epoch_size = 10
model.train(epoch_size, dataset, callbacks=[LossMonitor(), TimeMonitor()], dataset_sink_mode=False)
模型评估
为什么使用mIoU作为评估指标? mIoU(Mean Intersection over Union)是语义分割任务中的常用评估指标,它可以衡量预测结果与真实标签的重叠程度,反映模型的分割性能。
定义评估函数
from mindspore import Tensor
import numpy as npdef evaluate_model(model, dataset):model.set_train(False)total_inter, total_union = 0, 0for data in dataset.create_dict_iterator():image, label = data["data"], data["label"]pred = model.predict(image)pred = ops.ArgMaxWithValue(axis=1)(pred)[0].asnumpy()label = label.asnumpy()inter = np.logical_and(pred == label, label != 255)union = np.logical_or(pred == label, label != 255)total_inter += np.sum(inter)total_union += np.sum(union)return total_inter / total_union# 评估模型
miou = evaluate_model(model, dataset)
print("Mean Intersection over Union (mIoU):", miou)
模型推理
定义推理函数
def infer(model, image):model.set_train(False)pred = model.predict(image)pred = ops.ArgMaxWithValue(axis=1)(pred)[0].asnumpy()return pred# 推理示例
image = next(dataset.create_dict_iterator())["data"]
pred = infer(model, image)# 可视化结果
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Input Image")
plt.imshow(image[0].asnumpy().transpose(1, 2, 0))
plt.axis("off")plt.subplot(1, 2, 2)
plt.title("Predicted Segmentation")
plt.imshow(pred[0])
plt.axis("off")plt.show()
[外链图片转存中…(img-IzGdBwQG-1719545128292)]