MindSpore Study Notes 6-01, LLM Principles and Practice: FCN Image Semantic Segmentation

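Abstract:

        Notes on using the MindSpore AI framework with an FCN (fully convolutional network) to understand images and perform image semantic segmentation: the process, the steps and the methods. Covers environment setup, dataset download, dataset loading and preprocessing, network construction, training preparation, model training, model evaluation and model inference.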

I. Principles

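1. Semantic segmentation

Image semantic segmentation (semantic segmentation)

        Image processing

        Machine vision

                Image understanding

        An important branch of AI

        Applications

                Face recognition

                Object detection

                Medical imaging

                Satellite image analysis

                Perception for autonomous driving

        Goal

                Classify every pixel of the image

                Output an image with the same size as the input

                Each pixel of the output gives the class of the corresponding input pixel

        "Semantics" in the image domain

                The content of the image

                An understanding of what the image means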
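
A minimal NumPy sketch of the "classify every pixel" goal above. It assumes the network emits one score map per class, shape (C, H, W); the logits are made-up numbers for illustration. Taking the argmax over the class axis gives an H x W label map, the same size as the input, holding one class index per pixel.

```python
import numpy as np

# Hypothetical network output: scores for 3 classes on a 2x3 image, shape (C, H, W)
logits = np.array([
    [[2.0, 0.1, 0.0], [0.3, 0.2, 5.0]],   # class 0 (e.g. background)
    [[0.5, 3.0, 0.0], [0.1, 4.0, 0.0]],   # class 1
    [[0.1, 0.2, 1.0], [2.0, 0.0, 0.0]],   # class 2
])

# Pick the highest-scoring class independently for every pixel
label_map = logits.argmax(axis=0)
print(label_map.shape)   # (2, 3): same H x W as the input
print(label_map)         # rows: [0 1 2] and [2 1 0]
```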

Example: (figure not preserved)

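2. FCN (fully convolutional network)

Fully Convolutional Networks

A framework for image semantic segmentation

        Proposed by UC Berkeley in 2015

        A fully convolutional network that makes end-to-end, pixel-level predictions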

A fully convolutional network relies mainly on three techniques:

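1. Convolutionalization

VGG-16 (a classification network)

        The backbone of FCN

        Takes a 224×224 RGB image as input

                The input size is fixed

                Its fully connected layers discard spatial coordinates

                and produce non-spatial output

        Outputs 1000 class scores

Convolutional layers (replacing the fully connected layers)

        Output two-dimensional maps

        Produce a heatmap over the input image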
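
A small NumPy sketch of why convolutionalization works. A fully connected layer applied to a flattened 7x7 feature map computes the same numbers as a 7x7 convolution with the same weights evaluated at one position. The convolution can also slide over a larger input and return a spatial score map, which a fully connected layer cannot do. The sizes are toy values, and `fc` and `conv_valid` are illustrative helpers, not MindSpore APIs.

```python
import numpy as np

def fc(x, w, b):
    # Fully connected layer on a flattened (C, 7, 7) feature map -> (OUT,)
    return w.reshape(w.shape[0], -1) @ x.reshape(-1) + b

def conv_valid(x, w, b):
    # 'valid' convolution (cross-correlation) of x: (C, H, W) with w: (OUT, C, k, k)
    out_c, _, k, _ = w.shape
    h_out, w_out = x.shape[1] - k + 1, x.shape[2] - k + 1
    y = np.empty((out_c, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = x[:, i:i + k, j:j + k]
            y[:, i, j] = np.tensordot(w, patch, axes=([1, 2, 3], [0, 1, 2])) + b
    return y

rng = np.random.default_rng(0)
C, K, OUT = 8, 7, 10                     # toy sizes; VGG-16's fc6 would be C=512, K=7, OUT=4096
w = rng.standard_normal((OUT, C, K, K))
b = rng.standard_normal(OUT)

x7 = rng.standard_normal((C, 7, 7))
print(np.allclose(fc(x7, w, b), conv_valid(x7, w, b)[:, 0, 0]))   # True: identical outputs

x9 = rng.standard_normal((C, 9, 9))      # larger input: the FC layer cannot take it, the conv can
print(conv_valid(x9, w, b).shape)         # (10, 3, 3): a spatial map of scores (a "heatmap")
```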

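2. Upsampling

Convolution stage

        Convolution operations

        Pooling operations

The feature maps shrink

Upsampling operation

        Produces a dense prediction at the original image size

The parameters of the upsampling transposed convolution (deconvolution) are initialized with bilinear interpolation weights

Backpropagation then learns a nonlinear upsampling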
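
A NumPy sketch of the bilinear initialization described above. `bilinear_kernel` builds the bilinear filter commonly used to initialize FCN upsampling layers. `conv_transpose2d` is a hypothetical single-channel transposed convolution with no padding, so one pixel is cropped per side afterwards to get exactly 2x. The kernel size 4 and stride 2 match `upscore2` in the FCN-8s code in section IV, but that code initializes its transposed convolutions with `xavier_uniform`, so this shows the bilinear variant described in the text, not what the notes' network does.

```python
import numpy as np

def bilinear_kernel(size):
    # (size, size) bilinear interpolation filter used to initialize an upsampling layer
    factor = (size + 1) // 2
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    og = np.ogrid[:size, :size]
    return (1 - np.abs(og[0] - center) / factor) * (1 - np.abs(og[1] - center) / factor)

def conv_transpose2d(x, kernel, stride):
    # Single-channel transposed convolution: every input value adds a scaled copy of the kernel
    h, w = x.shape
    k = kernel.shape[0]
    out = np.zeros(((h - 1) * stride + k, (w - 1) * stride + k))
    for i in range(h):
        for j in range(w):
            out[i * stride:i * stride + k, j * stride:j * stride + k] += x[i, j] * kernel
    return out

kernel = bilinear_kernel(4)              # kernel_size=4, stride=2
print(kernel[0])                         # [0.0625 0.1875 0.1875 0.0625]
x = np.ones((4, 4))
up = conv_transpose2d(x, kernel, stride=2)[1:-1, 1:-1]   # crop 1 pixel per side -> exactly 2x
print(up.shape)                          # (8, 8)
print(np.allclose(up[1:-1, 1:-1], 1.0))  # True: the interior of a constant input stays constant
```

Because training then updates these weights by backpropagation, the layer starts as plain bilinear interpolation and can move toward a learned, nonlinear upsampling.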

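3. Skip layer (skip architecture)

Combines global information from deep layers with local information from shallow layers:

        The prediction from the last layer (stride 32) gives FCN-32s; upsample it 2×

        Fuse (add) it with the prediction from pool4 (stride 16) to get FCN-16s; upsample 2× again

        Fuse (add) it with the prediction from pool3 (stride 8) to get FCN-8s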
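
A shape-level NumPy sketch of the skip fusion above for a toy 64×64 input and 21 classes. The score maps are random stand-ins for the 1×1 "score" convolutions. `up2` uses nearest-neighbour repetition in place of the learned 2× transposed convolution, purely to keep the example short. The point is that each upsampled coarse map has exactly the size of the next shallower map, so the two can be added element-wise.

```python
import numpy as np

def up2(x):
    # 2x nearest-neighbour upsampling of (C, H, W); stand-in for a learned transposed conv
    return x.repeat(2, axis=1).repeat(2, axis=2)

rng = np.random.default_rng(0)
n_class, H = 21, 64                                          # toy input size, 21 VOC classes
score32 = rng.standard_normal((n_class, H // 32, H // 32))   # from the last layer (stride 32)
score16 = rng.standard_normal((n_class, H // 16, H // 16))   # from pool4 (stride 16)
score8 = rng.standard_normal((n_class, H // 8, H // 8))      # from pool3 (stride 8)

fuse16 = up2(score32) + score16                      # FCN-16s fusion (element-wise add)
fuse8 = up2(fuse16) + score8                         # FCN-8s fusion
out = fuse8.repeat(8, axis=1).repeat(8, axis=2)      # final 8x upsampling to the input size
pred = out.argmax(axis=0)                            # (H, W) label map
print(fuse16.shape, fuse8.shape, pred.shape)         # (21, 4, 4) (21, 8, 8) (64, 64)
```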

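Characteristics:

(1) A fully convolutional (fully conv) network with no fully connected (fc) layers, so it accepts input of any size.

(2) Deconvolution (deconv) layers enlarge the data size, so the network can output fine-grained results.

(3) A skip structure combines results from layers at different depths, giving both robustness and accuracy.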

II. Environment Setup

%%capture captured_output
# The lab environment comes with mindspore==2.2.14 preinstalled; to switch versions, change the version number below
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# Check the current mindspore version
!pip show mindspore

Output:

Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 

III. Data Processing

1. Download the dataset

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"

download(url, "./dataset", kind="tar", replace=True)

Output:

Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar (537.2 MB)file_sizes: 100%|█████████████████████████████| 563M/563M [00:03<00:00, 160MB/s]
Extracting tar file...
Successfully downloaded / unzipped to ./dataset
'./dataset'

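2. Data preprocessing

The images in the PASCAL VOC 2012 dataset do not share a common resolution

        so they are standardized: normalized with a per-channel mean and standard deviation, then randomly scaled, padded and cropped to a fixed crop size (see SegDataset below)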
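
A NumPy-only sketch of the normalize, pad and crop steps, for experimenting without cv2 or MindSpore. `normalize_pad_crop` is a hypothetical helper that mirrors the corresponding part of `preprocess_dataset` below, using the same mean/std values (in BGR order, matching cv2's channel order). It leaves out the random rescaling and flipping. Note that the mask is padded with `ignore_label` (255), so padded pixels are later skipped by the loss and the metrics.

```python
import numpy as np

IMAGE_MEAN = np.array([103.53, 116.28, 123.675], dtype=np.float32)   # per-channel mean (BGR)
IMAGE_STD = np.array([57.375, 57.120, 58.395], dtype=np.float32)     # per-channel std (BGR)

def normalize_pad_crop(image, label, crop_size, ignore_label=255, rng=None):
    # image: (H, W, 3) uint8; label: (H, W) uint8 class ids
    if rng is None:
        rng = np.random.default_rng()
    image = (image.astype(np.float32) - IMAGE_MEAN) / IMAGE_STD
    h, w = label.shape
    out_h, out_w = max(h, crop_size), max(w, crop_size)
    # Pad bottom/right: zeros for the normalized image, ignore_label for the mask
    image = np.pad(image, ((0, out_h - h), (0, out_w - w), (0, 0)), constant_values=0)
    label = np.pad(label, ((0, out_h - h), (0, out_w - w)), constant_values=ignore_label)
    # Random crop_size x crop_size window
    top = rng.integers(0, out_h - crop_size + 1)
    left = rng.integers(0, out_w - crop_size + 1)
    image = image[top:top + crop_size, left:left + crop_size]
    label = label[top:top + crop_size, left:left + crop_size]
    return image.transpose(2, 0, 1), label.astype(np.int32)   # CHW image, int32 mask

img = np.full((3, 5, 3), 128, dtype=np.uint8)   # tiny 3x5 "image"
lbl = np.ones((3, 5), dtype=np.uint8)           # every pixel is class 1
x, y = normalize_pad_crop(img, lbl, crop_size=4, rng=np.random.default_rng(0))
print(x.shape, y.shape)                          # (3, 4, 4) (4, 4)
```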

3. Data loading

The PASCAL VOC 2012 dataset is mixed with the SBD (Semantic Boundaries Dataset) dataset

import numpy as np
import cv2
import mindspore.dataset as ds

class SegDataset:
    def __init__(self,
                 image_mean,
                 image_std,
                 data_file='',
                 batch_size=32,
                 crop_size=512,
                 max_scale=2.0,
                 min_scale=0.5,
                 ignore_label=255,
                 num_classes=21,
                 num_readers=2,
                 num_parallel_calls=4):

        self.data_file = data_file
        self.batch_size = batch_size
        self.crop_size = crop_size
        self.image_mean = np.array(image_mean, dtype=np.float32)
        self.image_std = np.array(image_std, dtype=np.float32)
        self.max_scale = max_scale
        self.min_scale = min_scale
        self.ignore_label = ignore_label
        self.num_classes = num_classes
        self.num_readers = num_readers
        self.num_parallel_calls = num_parallel_calls
        assert max_scale > min_scale

    def preprocess_dataset(self, image, label):
        # Decode the raw bytes: image as 3-channel BGR, label as single-channel class ids
        image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)
        label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        # Random rescale; nearest-neighbour for the label so class ids are not blended
        sc = np.random.uniform(self.min_scale, self.max_scale)
        new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])
        image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

        # Normalize, then pad to at least crop_size (the label is padded with ignore_label)
        image_out = (image_out - self.image_mean) / self.image_std
        out_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)
        pad_h, pad_w = out_h - new_h, out_w - new_w
        if pad_h > 0 or pad_w > 0:
            image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
            label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)
        # Random crop_size x crop_size crop
        offset_h = np.random.randint(0, out_h - self.crop_size + 1)
        offset_w = np.random.randint(0, out_w - self.crop_size + 1)
        image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]
        label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size]
        # Random horizontal flip
        if np.random.uniform(0.0, 1.0) > 0.5:
            image_out = image_out[:, ::-1, :]
            label_out = label_out[:, ::-1]
        # HWC -> CHW; copy() makes the (possibly flipped) views contiguous
        image_out = image_out.transpose((2, 0, 1))
        image_out = image_out.copy()
        label_out = label_out.copy()
        label_out = label_out.astype("int32")
        return image_out, label_out

    def get_dataset(self):
        ds.config.set_numa_enable(True)
        dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],
                                 shuffle=True, num_parallel_workers=self.num_readers)
        transforms_list = self.preprocess_dataset
        dataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],
                              output_columns=["data", "label"],
                              num_parallel_workers=self.num_parallel_calls)
        dataset = dataset.shuffle(buffer_size=self.batch_size * 10)
        dataset = dataset.batch(self.batch_size, drop_remainder=True)
        return dataset


# Parameters for creating the dataset
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Model training parameters
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21

# Instantiate the Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)

dataset = dataset.get_dataset()

4. Visualize the training set

import numpy as np
import matplotlib.pyplot as plt

plt.figure(figsize=(16, 8))

# Show samples from the training set
for i in range(1, 9):
    plt.subplot(2, 4, i)
    show_data = next(dataset.create_dict_iterator())
    show_images = show_data["data"].asnumpy()
    show_images = np.clip(show_images, 0, 1)
    # Convert the image to HWC layout for display
    plt.imshow(show_images[0].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()

Output: (figure not preserved: eight preprocessed training samples)

IV. Network Construction

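FCN network pipeline

        Input image

        pool1 pooling

                Size becomes 1/2 of the original

        pool2 pooling

                Size becomes 1/4 of the original

        pool3 pooling

                Size becomes 1/8 of the original

        pool4 pooling

                Size becomes 1/16 of the original

        pool5 pooling

                Size becomes 1/32 of the original

        conv6-7 convolutions

                Output is still 1/32 of the original size

        FCN-32s

                A transposed convolution (deconvolution) enlarges the output back to the original size

        FCN-16s

                Fuse

                        the conv7 output, enlarged 2× by transposed convolution to 1/16 of the original size

                        the pool4 feature map

                A transposed convolution enlarges the result to the original size

        FCN-8s

                Fuse

                        the conv7 output, enlarged 4× by transposed convolution

                        the pool4 feature map, enlarged 2× by transposed convolution

                        the pool3 feature map

                A transposed convolution enlarges the result to the original size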
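
A pure-Python sketch that tracks the side length of each feature map in the FCN-8s pipeline above. It assumes size-preserving ('same') convolutions, 2×2 stride-2 pooling that rounds down, and transposed convolutions whose output is input × stride. `fcn8s_sizes` is a hypothetical helper; the layer names follow the FCN8s code below. It shows why the 512 crop size works: the fused maps only line up when the input side is divisible by 32.

```python
def fcn8s_sizes(side):
    # Side length after each stage under the assumptions stated above
    sizes = {}
    for k in range(1, 6):
        side //= 2
        sizes[f"pool{k}"] = side
    sizes["conv7"] = sizes["pool5"]                # conv6 (7x7) and conv7 (1x1) keep 1/32
    sizes["upscore2"] = sizes["conv7"] * 2         # must equal pool4 for the addition
    sizes["upscore_pool4"] = sizes["pool4"] * 2    # must equal pool3 for the addition
    sizes["upscore8"] = sizes["pool3"] * 8         # back to the input size
    return sizes

s = fcn8s_sizes(512)
print(s)                               # pool1..pool5: 256, 128, 64, 32, 16; upscore8: 512
t = fcn8s_sizes(500)                   # 500 is not divisible by 32
print(t["upscore2"], t["pool4"])       # 30 31: shapes differ, so the element-wise add would fail
```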

Code for building the FCN-8s network:

import mindspore.nn as nn

class FCN8s(nn.Cell):
    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        # VGG-16 backbone: five conv blocks, each followed by 2x2 max pooling with stride 2
        self.conv1 = nn.SequentialCell(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.SequentialCell(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.SequentialCell(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv4 = nn.SequentialCell(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv5 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, weight_init='xavier_uniform'),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        # VGG-16's fully connected layers, convolutionalized: fc6 -> 7x7 conv, fc7 -> 1x1 conv
        self.conv6 = nn.SequentialCell(
            nn.Conv2d(in_channels=512, out_channels=4096, kernel_size=7, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        self.conv7 = nn.SequentialCell(
            nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1, weight_init='xavier_uniform'),
            nn.BatchNorm2d(4096),
            nn.ReLU(),
        )
        # 1x1 "score" convs produce n_class maps; transposed convs upsample them
        self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,
                                  kernel_size=1, weight_init='xavier_uniform')
        self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                                kernel_size=4, stride=2, weight_init='xavier_uniform')
        self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,
                                     kernel_size=1, weight_init='xavier_uniform')
        self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,
                                           kernel_size=16, stride=8, weight_init='xavier_uniform')

    def construct(self, x):
        x1 = self.conv1(x)
        p1 = self.pool1(x1)           # 1/2
        x2 = self.conv2(p1)
        p2 = self.pool2(x2)           # 1/4
        x3 = self.conv3(p2)
        p3 = self.pool3(x3)           # 1/8
        x4 = self.conv4(p3)
        p4 = self.pool4(x4)           # 1/16
        x5 = self.conv5(p4)
        p5 = self.pool5(x5)           # 1/32
        x6 = self.conv6(p5)
        x7 = self.conv7(x6)
        sf = self.score_fr(x7)        # class scores at 1/32
        u2 = self.upscore2(sf)        # 2x upsampling -> 1/16
        s4 = self.score_pool4(p4)
        f4 = s4 + u2                  # FCN-16s fusion
        u4 = self.upscore_pool4(f4)   # 2x upsampling -> 1/8
        s3 = self.score_pool3(p3)
        f3 = s3 + u4                  # FCN-8s fusion
        out = self.upscore8(f3)       # 8x upsampling -> input size
        return out

V. Training Preparation

1. Load part of the VGG-16 pretrained weights

from download import download
from mindspore import load_checkpoint, load_param_into_net

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)

def load_vgg16():
    # Loads the VGG-16 backbone weights into the global `net` (an FCN8s instance created before this is called)
    ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"
    param_vgg = load_checkpoint(ckpt_vgg16)
    load_param_into_net(net, param_vgg)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt (513.2 MB)file_sizes: 100%|█████████████████████████████| 538M/538M [00:03<00:00, 179MB/s]
Successfully downloaded file to fcn8s_vgg16_pretrain.ckpt

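2. Loss function

Cross-entropy loss

mindspore.nn.CrossEntropyLoss()

Computes the cross-entropy loss between the FCN output and the mask (during training, pixels labeled 255 are skipped via ignore_index=255)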
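
A NumPy sketch of per-pixel cross-entropy with an ignore index, for intuition about what `nn.CrossEntropyLoss(ignore_index=255)` is used for in training. `cross_entropy_ignore` is a hypothetical reimplementation. It averages over the non-ignored pixels, which I understand to be the behaviour of the default mean reduction, though this is an assumption rather than a statement about MindSpore internals. With all-zero logits the prediction is uniform over 21 classes and the loss is ln 21 ≈ 3.04, close to the first logged training loss (about 3.05).

```python
import numpy as np

def cross_entropy_ignore(logits, target, ignore_index=255):
    # logits: (N, C, H, W) raw scores; target: (N, H, W) integer labels
    z = logits - logits.max(axis=1, keepdims=True)               # numerical stability
    log_prob = z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # log-softmax over classes
    valid = target != ignore_index
    safe_t = np.where(valid, target, 0)                          # dummy index for ignored pixels
    picked = np.take_along_axis(log_prob, safe_t[:, None], axis=1)[:, 0]
    return -picked[valid].mean()                                 # mean over non-ignored pixels only

logits = np.zeros((1, 21, 2, 2))             # uniform prediction over 21 classes
target = np.array([[[0, 255], [3, 255]]])    # two labeled pixels, two ignored pixels
loss = cross_entropy_ignore(logits, target)
print(round(loss, 4))                        # 3.0445 (= ln 21)

logits2 = logits.copy()
logits2[0, 5, 0, 1] = 100.0                  # change scores only at an ignored pixel
print(np.isclose(cross_entropy_ignore(logits2, target), loss))   # True: ignored pixels do not matter
```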

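3. Custom evaluation metrics (Metrics)

Used to evaluate the model.

Assume there are k+1 classes in total

        from L_0 to L_k

        including one void class or background

P_{ij} is the number of pixels that belong to class i but are predicted as class j

P_{ii} is the number of true positives (correctly classified pixels)

For a class i, P_{ji} (j≠i) counts its false positives and P_{ij} (j≠i) counts its false negatives

Pixel Accuracy

PA (pixel accuracy)

        The ratio of correctly labeled pixels to the total number of pixels.

PA=\frac{\sum_{i=0}^{k}P_{ii}}{\sum_{i=0}^{k}\sum_{j=0}^{k}P_{ij}}

Mean Pixel Accuracy

MPA (mean pixel accuracy)

For each class, compute the proportion of its pixels that are classified correctly

Average over all classes

MPA=\frac{1}{k+1}\sum_{i=0}^{k}\frac{P_{ii}}{\sum_{j=0}^{k}P_{ij}}

Mean Intersection over Union

MIoU (mean intersection over union)

        The standard metric for semantic segmentation

                Computes the ratio of the intersection to the union of two sets

                        the ground truth

                        the predicted segmentation

                The ratio is: true positives (intersection) / (true positives + false negatives + false positives) (union)

                Compute the IoU for each class

                Average over all classes

MIoU=\frac{1}{k+1}\sum_{i=0}^{k}\frac{P_{ii}}{\sum_{j=0}^{k}P_{ij}+\sum_{j=0}^{k}P_{ji}-P_{ii}}

Frequency Weighted Intersection over Union

FWIoU (frequency weighted intersection over union)

Weights each class's IoU by how often the class appears

FWIoU=\frac{1}{\sum_{i=0}^{k}\sum_{j=0}^{k}P_{ij}}\sum_{i=0}^{k}\frac{\left(\sum_{j=0}^{k}P_{ij}\right)P_{ii}}{\sum_{j=0}^{k}P_{ij}+\sum_{j=0}^{k}P_{ji}-P_{ii}}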
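
A compact NumPy sketch that evaluates the four formulas above on a hand-checkable 2-class confusion matrix, where `cm[i, j]` is P_{ij}. `seg_metrics` is a hypothetical helper that combines the `eval` methods of the metric classes below into one function, including their NaN-skipping average for MPA/MIoU and their `freq > 0` mask for FWIoU.

```python
import numpy as np

def seg_metrics(cm):
    # cm[i, j] = P_ij: pixels of true class i predicted as class j
    tp = np.diag(cm)                  # P_ii
    gt = cm.sum(axis=1)               # sum_j P_ij: ground-truth pixels per class
    pred = cm.sum(axis=0)             # sum_j P_ji: predicted pixels per class
    iou = tp / (gt + pred - tp)       # per-class IoU
    freq = gt / cm.sum()              # class frequency used by FWIoU
    return {
        "PA": tp.sum() / cm.sum(),
        "MPA": np.nanmean(tp / gt),
        "MIoU": np.nanmean(iou),
        "FWIoU": (freq[freq > 0] * iou[freq > 0]).sum(),
    }

cm = np.array([[5, 1],
               [1, 1]])
m = seg_metrics(cm)
print(m)   # PA = 6/8, MPA = 2/3, MIoU = 11/21, FWIoU = 13/21
```

Working it by hand: class 0 has IoU 5/(6+6-5) = 5/7 and class 1 has IoU 1/(2+2-1) = 1/3. MIoU averages the two equally, while FWIoU weights them 6/8 and 2/8, so the frequent class dominates.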

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as train

class PixelAccuracy(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracy, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        # Keep only pixels with a valid label (drops ignore_label 255)
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        # Encode each (true, predicted) pair as one index, count, and reshape into a confusion matrix
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def update(self, *inputs):
        # inputs[0]: network logits (N, C, H, W); inputs[1]: labels
        # Note: the reshape hard-codes batch size 4 and 512x512 crops
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def eval(self):
        # PA: correctly classified pixels / all pixels
        pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return pixel_accuracy


class PixelAccuracyClass(train.Metric):
    def __init__(self, num_class=21):
        super(PixelAccuracyClass, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        # MPA: per-class accuracy, averaged while skipping classes with no ground-truth pixels (NaN)
        mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)
        return mean_pixel_accuracy


class MeanIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(MeanIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        # MIoU: per-class TP / (TP + FN + FP), then a NaN-skipping average
        mean_iou = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        mean_iou = np.nanmean(mean_iou)
        return mean_iou


class FrequencyWeightedIntersectionOverUnion(train.Metric):
    def __init__(self, num_class=21):
        super(FrequencyWeightedIntersectionOverUnion, self).__init__()
        self.num_class = num_class

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def update(self, *inputs):
        y_pred = inputs[0].asnumpy().argmax(axis=1)
        y = inputs[1].asnumpy().reshape(4, 512, 512)
        self.confusion_matrix += self._generate_matrix(y, y_pred)

    def clear(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)

    def eval(self):
        # FWIoU: per-class IoU weighted by each class's share of ground-truth pixels
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()
        return frequency_weighted_iou

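VI. Model Training

Load the VGG-16 pretrained parameters

Instantiate the loss function and the optimizer

Compile the network with the Model API

Train the FCN-8s network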

import mindspore
import mindspore as ms
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Model

device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)

train_batch_size = 4
num_classes = 21
# Initialize the model structure
net = FCN8s(n_class=21)
# Load the VGG-16 pretrained parameters
load_vgg16()
# Compute the learning rate
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochs

lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                            base_lr,
                                            total_step,
                                            iters_per_epoch,
                                            decay_epoch=2)
# Only the last value of the schedule is used, as a constant learning rate
lr = Tensor(lr_scheduler[-1])

# Define the loss function (pixels labeled 255 are ignored)
loss = nn.CrossEntropyLoss(ignore_index=255)
# Define the optimizer
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# Define the loss scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
# Initialize the model
if device_target == "Ascend":
    model = Model(net,
                  loss_fn=loss,
                  optimizer=optimizer,
                  loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net,
                  loss_fn=loss,
                  optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Callbacks and checkpoint-saving parameters
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330   # defined but unused: checkpoints below are saved every 10 steps
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,
                               keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",
                                directory="./ckpt",
                                config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)

Output:

epoch: 1 step: 1, loss is 3.0504844
epoch: 1 step: 2, loss is 3.017057
epoch: 1 step: 3, loss is 2.9523003
epoch: 1 step: 4, loss is 2.9488814
epoch: 1 step: 5, loss is 2.666231
epoch: 1 step: 6, loss is 2.7145326
epoch: 1 step: 7, loss is 1.796408
epoch: 1 step: 8, loss is 1.5167583
epoch: 1 step: 9, loss is 1.6862022
epoch: 1 step: 10, loss is 2.4622822
......
epoch: 1 step: 1141, loss is 1.70966
epoch: 1 step: 1142, loss is 1.434751
epoch: 1 step: 1143, loss is 2.406475
Train epoch time: 762889.258 ms, per step time: 667.445 ms
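
A pure-Python sketch of what the cosine schedule returns with the settings used above. `cosine_decay_lr_ref` is a hypothetical reimplementation based on my reading of the formula documented for `mindspore.nn.cosine_decay_lr`: lr[i] = min_lr + 0.5 × (max_lr − min_lr) × (1 + cos(π × current_epoch / decay_epoch)), where current_epoch = floor(i / step_per_epoch), clamped at decay_epoch. Treat the exact formula as an assumption. The 1143 steps per epoch come from the training log. Under this formula, one training epoch keeps every step at max_lr, so `lr_scheduler[-1]` is simply 0.05.

```python
import math

def cosine_decay_lr_ref(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    # Assumed formula; the learning rate changes once per epoch, not once per step
    delta = 0.5 * (max_lr - min_lr)
    lrs = []
    for i in range(total_step):
        current_epoch = min(math.floor(i / step_per_epoch), decay_epoch)
        lrs.append(min_lr + delta * (1.0 + math.cos(math.pi * current_epoch / decay_epoch)))
    return lrs

steps = 1143                                    # steps per epoch, from the training log
one_epoch = cosine_decay_lr_ref(0.0005, 0.05, steps * 1, steps, decay_epoch=2)
print(round(one_epoch[-1], 6))                  # 0.05: all of epoch 0 uses max_lr
two_epochs = cosine_decay_lr_ref(0.0005, 0.05, steps * 2, steps, decay_epoch=2)
print(round(two_epochs[-1], 6))                 # 0.02525: epoch 1 sits halfway down the cosine
```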

VII. Model Evaluation

IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"

# Download the trained weight file
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)

ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)

if device_target == "Ascend":
    model = Model(net,
                  loss_fn=loss,
                  optimizer=optimizer,
                  loss_scale_manager=loss_scale_manager,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:
    model = Model(net,
                  loss_fn=loss,
                  optimizer=optimizer,
                  metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})

# Instantiate the Dataset (it reads the same MindRecord file, with the training-time augmentation)
dataset = SegDataset(image_mean=IMAGE_MEAN,
                     image_std=IMAGE_STD,
                     data_file=DATA_FILE,
                     batch_size=train_batch_size,
                     crop_size=crop_size,
                     max_scale=max_scale,
                     min_scale=min_scale,
                     ignore_label=ignore_label,
                     num_classes=num_classes,
                     num_readers=2,
                     num_parallel_calls=4)
dataset_eval = dataset.get_dataset()
model.eval(dataset_eval)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt (1.00 GB)file_sizes: 100%|██████████████████████████| 1.08G/1.08G [00:10<00:00, 99.7MB/s]
Successfully downloaded file to FCN8s.ckpt
{'pixel accuracy': 0.9734831394168291,'mean pixel accuracy': 0.9423324801371116,'mean IoU': 0.8961453779807752,'frequency weighted IoU': 0.9488883312345654}

VIII. Model Inference

Run the trained network on sample images and display the inference results.

import numpy as np
import cv2
import matplotlib.pyplot as plt

net = FCN8s(n_class=num_classes)
# Set hyperparameters and load the trained weights
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# Show inference results (top row: input images; bottom row: predictions)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):
    img_lst.append(show_images[i])
    mask_lst.append(mask_images[i])
# Per-pixel class = argmax over the class channel
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):
    plt.subplot(2, 4, i + 1)
    plt.imshow(img_lst[i].transpose(1, 2, 0))
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 4, i + 5)
    plt.imshow(res[i])
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

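Output: (figure not preserved: top row, four input images; bottom row, the predicted segmentation maps)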
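
`plt.imshow(res[i])` in the inference code displays raw class indices through matplotlib's default colormap. To get the familiar PASCAL VOC colours, the label map can be indexed into the standard VOC palette, which spreads the bits of the class index over the high bits of R, G and B. `voc_colormap` is a hypothetical helper written for this sketch; it is not part of MindSpore.

```python
import numpy as np

def voc_colormap(n=256):
    # PASCAL VOC palette: bit k of the class index -> bit (7 - k//3) of R, G or B
    cmap = np.zeros((n, 3), dtype=np.uint8)
    for i in range(n):
        c, r, g, b = i, 0, 0, 0
        for j in range(8):
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        cmap[i] = (r, g, b)
    return cmap

cmap = voc_colormap()
res = np.array([[0, 1], [15, 255]])      # toy label map: background, aeroplane, person, void
rgb = cmap[res]                           # (H, W, 3) uint8 image, ready for plt.imshow
print(rgb.shape)                          # (2, 2, 3)
print(cmap[1].tolist(), cmap[15].tolist(), cmap[255].tolist())   # [128, 0, 0] [192, 128, 128] [224, 224, 192]
```

In the inference loop, `plt.imshow(cmap[res[i]])` could then replace `plt.imshow(res[i])`.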

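IX. Summary

FCN

        Uses only convolutional layers

        Learns to segment images end to end.

        Advantages:

                Accepts input images of any size

                Efficient: avoids the repeated storage and convolution computation caused by working on pixel patches.

        Room for improvement:

                Results are not fine enough: they are rather blurry and smooth, and insensitive to details at boundaries.

                Classifies each pixel independently, without considering relations between pixels (such as discontinuity and similarity).

                Omits the spatial regularization step, so the output lacks spatial consistency.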

