昇思25天学习打卡营第10天|FCN图像语义分割

一、简介:

本篇博客是昇思大模型打卡营应用实践部分的第一次分享,主题是计算机视觉(CV)领域的FCN图像语义分割,接下来几天还会陆续分享其他CV领域的知识(doge)。

全卷积网络(Fully Convolutional Networks,FCN)是UC Berkeley的Jonathan Long等人于2015年在Fully Convolutional Networks for Semantic Segmentation[1]一文中提出的用于图像语义分割的一种框架。FCN是首个端到端(end to end)进行像素级(pixel level)预测的全卷积网络。

二、语义分割:

在具体介绍FCN之前,首先介绍何为语义分割:

图像语义分割(semantic segmentation)是图像处理和机器视觉技术中关于图像理解的重要一环,AI领域中一个重要分支,常被应用于人脸识别、物体检测、医学影像、卫星图像分析、自动驾驶感知等领域。

语义分割的目的是对图像中每个像素点进行分类。与普通的分类任务只输出某个类别不同,语义分割任务输出与输入大小相同的图像,输出图像的每个像素对应了输入图像每个像素的类别。语义在图像领域指的是图像的内容,对图片意思的理解,下图是一些语义分割的实例:

可以看到最右边的原始图像经过语义分割之后,实现了像素级别的目标物体识别。

三、模型简介:

FCN主要用于图像分割领域,是一种端到端的分割方法,是深度学习应用在图像语义分割的开山之作。通过进行像素级的预测直接得出与原图大小相等的label map。因FCN丢弃全连接层替换为全卷积层,网络所有层均为卷积层,故称为全卷积网络。

全卷积神经网络主要使用以下三种技术:

1、卷积化:

使用VGG-16作为FCN的backbone。VGG-16的输入为224*224的RGB图像,输出为1000个预测值。VGG-16只能接受固定大小的输入,丢弃了空间坐标,产生非空间输出。VGG-16中共有三个全连接层,全连接层也可视为带有覆盖整个区域的卷积。将全连接层转换为卷积层能使网络输出由一维非空间输出变为二维矩阵,利用输出能生成输入图片映射的heatmap。

2、上采样:

在卷积过程的卷积操作和池化操作会使得特征图的尺寸变小,为得到原图的大小的稠密图像预测,需要对得到的特征图进行上采样操作。使用双线性插值的参数来初始化上采样逆卷积的参数,后通过反向传播来学习非线性上采样。在网络中执行上采样,以通过像素损失的反向传播进行端到端的学习。

3、跳跃结构:

利用上采样技巧对最后一层的特征图进行上采样得到原图大小的分割是步长为32像素的预测,称之为FCN-32s。由于最后一层的特征图太小,损失过多细节,采用skips结构将更具有全局信息的最后一层预测和更浅层的预测结合,使预测结果获取更多的局部细节。将底层(stride 32)的预测(FCN-32s)进行2倍的上采样得到原尺寸的图像,并与从pool4层(stride 16)进行的预测融合起来(相加),这一部分的网络被称为FCN-16s。随后将这一部分的预测再进行一次2倍的上采样并与从pool3层得到的预测融合起来,这一部分的网络被称为FCN-8s。 Skips结构将深层的全局信息与浅层的局部信息相结合。

四、 数据处理:

开始下面的操作之前,需要先下载Mindspore,还没有下载的宝子,可以回看我的昇思25天学习打卡营第1天|快速入门-CSDN博客。

1、数据集下载:

import time
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset_fcn8s.tar"download(url, "./dataset", kind="tar", replace=True)

2、数据集加载:

import numpy as np
import cv2
import mindspore.dataset as dsclass SegDataset:def __init__(self,image_mean,image_std,data_file='',batch_size=32,crop_size=512,max_scale=2.0,min_scale=0.5,ignore_label=255,num_classes=21,num_readers=2,num_parallel_calls=4):self.data_file = data_fileself.batch_size = batch_sizeself.crop_size = crop_sizeself.image_mean = np.array(image_mean, dtype=np.float32)self.image_std = np.array(image_std, dtype=np.float32)self.max_scale = max_scaleself.min_scale = min_scaleself.ignore_label = ignore_labelself.num_classes = num_classesself.num_readers = num_readersself.num_parallel_calls = num_parallel_callsmax_scale > min_scaledef preprocess_dataset(self, image, label):image_out = cv2.imdecode(np.frombuffer(image, dtype=np.uint8), cv2.IMREAD_COLOR)label_out = cv2.imdecode(np.frombuffer(label, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)sc = np.random.uniform(self.min_scale, self.max_scale)new_h, new_w = int(sc * image_out.shape[0]), int(sc * image_out.shape[1])image_out = cv2.resize(image_out, (new_w, new_h), interpolation=cv2.INTER_CUBIC)label_out = cv2.resize(label_out, (new_w, new_h), interpolation=cv2.INTER_NEAREST)image_out = (image_out - self.image_mean) / self.image_stdout_h, out_w = max(new_h, self.crop_size), max(new_w, self.crop_size)pad_h, pad_w = out_h - new_h, out_w - new_wif pad_h > 0 or pad_w > 0:image_out = cv2.copyMakeBorder(image_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)label_out = cv2.copyMakeBorder(label_out, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=self.ignore_label)offset_h = np.random.randint(0, out_h - self.crop_size + 1)offset_w = np.random.randint(0, out_w - self.crop_size + 1)image_out = image_out[offset_h: offset_h + self.crop_size, offset_w: offset_w + self.crop_size, :]label_out = label_out[offset_h: offset_h + self.crop_size, offset_w: offset_w+self.crop_size]if np.random.uniform(0.0, 1.0) > 0.5:image_out = image_out[:, ::-1, :]label_out = label_out[:, ::-1]image_out = image_out.transpose((2, 0, 1))image_out = image_out.copy()label_out = label_out.copy()label_out = label_out.astype("int32")return image_out, label_outdef get_dataset(self):ds.config.set_numa_enable(True)dataset = ds.MindDataset(self.data_file, columns_list=["data", "label"],shuffle=True, num_parallel_workers=self.num_readers)transforms_list = self.preprocess_datasetdataset = dataset.map(operations=transforms_list, input_columns=["data", "label"],output_columns=["data", "label"],num_parallel_workers=self.num_parallel_calls)dataset = dataset.shuffle(buffer_size=self.batch_size * 10)dataset = dataset.batch(self.batch_size, drop_remainder=True)return dataset# 定义创建数据集的参数
IMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"# 定义模型训练参数
train_batch_size = 4
crop_size = 512
min_scale = 0.5
max_scale = 2.0
ignore_label = 255
num_classes = 21# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,image_std=IMAGE_STD,data_file=DATA_FILE,batch_size=train_batch_size,crop_size=crop_size,max_scale=max_scale,min_scale=min_scale,ignore_label=ignore_label,num_classes=num_classes,num_readers=2,num_parallel_calls=4)dataset = dataset.get_dataset()

 3、数据集可视化:

import numpy as np
import matplotlib.pyplot as pltplt.figure(figsize=(16, 8))# 对训练集中的数据进行展示
for i in range(1, 9):plt.subplot(2, 4, i)show_data = next(dataset.create_dict_iterator())show_images = show_data["data"].asnumpy()show_images = np.clip(show_images, 0, 1)
# 将图片转换HWC格式后进行展示plt.imshow(show_images[0].transpose(1, 2, 0))plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0)
plt.show()print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), 'VertexGeek')

np.clip这个函数之前,我们没有介绍过,这里补充一下,函数是NumPy库的一部分,用于将数组中的元素限制在给定的最小值和最大值之间。如果一个元素的值小于最小值,它将被设置为最小值;如果它的值大于最大值,它将被设置为最大值;如果在最小值和最大值之间,它将保持不变。np.clip(show_images,0,1)方法确保图像数据的值在0到1的范围内,这样它们就可以正确地显示或者用于后续的图像处理步骤 。

 五、网络构建:

FCN网络的流程如下图所示:

  1. 输入图像image,经过pool1池化后,尺寸变为原始尺寸的1/2。
  2. 经过pool2池化,尺寸变为原始尺寸的1/4。
  3. 接着经过pool3、pool4、pool5池化,大小分别变为原始尺寸的1/8、1/16、1/32。
  4. 经过conv6-7卷积,输出的尺寸依然是原图的1/32。
  5. FCN-32s是最后使用反卷积,使得输出图像大小与输入图像相同。
  6. FCN-16s是将conv7的输出进行反卷积,使其尺寸扩大两倍至原图的1/16,并将其与pool4输出的特征图进行融合,后通过反卷积扩大到原始尺寸。
  7. FCN-8s是将conv7的输出进行反卷积扩大4倍,将pool4输出的特征图反卷积扩大2倍,并将pool3输出特征图拿出,三者融合后通反卷积扩大到原始尺寸。

 使用以下代码构建一个FCN-8s网络:

import mindspore.nn as nnclass FCN8s(nn.Cell):def __init__(self, n_class):super().__init__()self.n_class = n_classself.conv1 = nn.SequentialCell(nn.Conv2d(in_channels=3, out_channels=64,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(64),nn.ReLU(),nn.Conv2d(in_channels=64, out_channels=64,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(64),nn.ReLU())self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.SequentialCell(nn.Conv2d(in_channels=64, out_channels=128,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(128),nn.ReLU(),nn.Conv2d(in_channels=128, out_channels=128,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(128),nn.ReLU())self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv3 = nn.SequentialCell(nn.Conv2d(in_channels=128, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU(),nn.Conv2d(in_channels=256, out_channels=256,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(256),nn.ReLU())self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv4 = nn.SequentialCell(nn.Conv2d(in_channels=256, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU())self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv5 = nn.SequentialCell(nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU(),nn.Conv2d(in_channels=512, out_channels=512,kernel_size=3, weight_init='xavier_uniform'),nn.BatchNorm2d(512),nn.ReLU())self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv6 = nn.SequentialCell(nn.Conv2d(in_channels=512, out_channels=4096,kernel_size=7, weight_init='xavier_uniform'),nn.BatchNorm2d(4096),nn.ReLU(),)self.conv7 = nn.SequentialCell(nn.Conv2d(in_channels=4096, out_channels=4096,kernel_size=1, weight_init='xavier_uniform'),nn.BatchNorm2d(4096),nn.ReLU(),)self.score_fr = nn.Conv2d(in_channels=4096, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore2 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=4, stride=2, weight_init='xavier_uniform')self.score_pool4 = nn.Conv2d(in_channels=512, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore_pool4 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=4, stride=2, weight_init='xavier_uniform')self.score_pool3 = nn.Conv2d(in_channels=256, out_channels=self.n_class,kernel_size=1, weight_init='xavier_uniform')self.upscore8 = nn.Conv2dTranspose(in_channels=self.n_class, out_channels=self.n_class,kernel_size=16, stride=8, weight_init='xavier_uniform')def construct(self, x):x1 = self.conv1(x)p1 = self.pool1(x1)x2 = self.conv2(p1)p2 = self.pool2(x2)x3 = self.conv3(p2)p3 = self.pool3(x3)x4 = self.conv4(p3)p4 = self.pool4(x4)x5 = self.conv5(p4)p5 = self.pool5(x5)x6 = self.conv6(p5)x7 = self.conv7(x6)sf = self.score_fr(x7)u2 = self.upscore2(sf)s4 = self.score_pool4(p4)f4 = s4 + u2u4 = self.upscore_pool4(f4)s3 = self.score_pool3(p3)f3 = s3 + u4out = self.upscore8(f3)return out

六、训练准备:

1、VGG16权重导入:

FCN使用VGG-16作为骨干网络,用于实现图像编码。使用下面代码导入VGG-16预训练模型的部分预训练权重。

from download import download
from mindspore import load_checkpoint, load_param_into_neturl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/fcn8s_vgg16_pretrain.ckpt"
download(url, "fcn8s_vgg16_pretrain.ckpt", replace=True)
def load_vgg16():ckpt_vgg16 = "fcn8s_vgg16_pretrain.ckpt"param_vgg = load_checkpoint(ckpt_vgg16)load_param_into_net(net, param_vgg)print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())), "VertexGeek")

2、损失函数:

语义分割是对图像中每个像素点进行分类,仍是分类问题,故损失函数选择交叉熵损失函数来计算FCN网络输出与mask之间的交叉熵损失。这里我们使用的是mindspore.nn.CrossEntropyLoss()作为损失函数。 

from mindspore.nn import CrossEntropyLossloss = CrossEntropyLoss()

3、评估指标:

 这一段实在是不好打出来,这里我偷个懒,直接给大家上图片了(doge)

import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.train as trainclass PixelAccuracy(train.Metric):def __init__(self, num_class=21):super(PixelAccuracy, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def eval(self):pixel_accuracy = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()return pixel_accuracyclass PixelAccuracyClass(train.Metric):def __init__(self, num_class=21):super(PixelAccuracyClass, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def eval(self):mean_pixel_accuracy = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)mean_pixel_accuracy = np.nanmean(mean_pixel_accuracy)return mean_pixel_accuracyclass MeanIntersectionOverUnion(train.Metric):def __init__(self, num_class=21):super(MeanIntersectionOverUnion, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def eval(self):mean_iou = np.diag(self.confusion_matrix) / (np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -np.diag(self.confusion_matrix))mean_iou = np.nanmean(mean_iou)return mean_iouclass FrequencyWeightedIntersectionOverUnion(train.Metric):def __init__(self, num_class=21):super(FrequencyWeightedIntersectionOverUnion, self).__init__()self.num_class = num_classdef _generate_matrix(self, gt_image, pre_image):mask = (gt_image >= 0) & (gt_image < self.num_class)label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]count = np.bincount(label, minlength=self.num_class**2)confusion_matrix = count.reshape(self.num_class, self.num_class)return confusion_matrixdef update(self, *inputs):y_pred = inputs[0].asnumpy().argmax(axis=1)y = inputs[1].asnumpy().reshape(4, 512, 512)self.confusion_matrix += self._generate_matrix(y, y_pred)def clear(self):self.confusion_matrix = np.zeros((self.num_class,) * 2)def eval(self):freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)iu = np.diag(self.confusion_matrix) / (np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -np.diag(self.confusion_matrix))frequency_weighted_iou = (freq[freq > 0] * iu[freq > 0]).sum()return frequency_weighted_iou

等把课程所有的内容都过一遍,我再仔细介绍这个网络的结构(先挖个坑)

七、模型训练:

准备完成之后,我们就可以开始训练了:

import mindspore
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, Modeldevice_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)train_batch_size = 4
num_classes = 21
# 初始化模型结构
net = FCN8s(n_class=21)
# 导入vgg16预训练参数
load_vgg16()
# 计算学习率
min_lr = 0.0005
base_lr = 0.05
train_epochs = 1
iters_per_epoch = dataset.get_dataset_size()
total_step = iters_per_epoch * train_epochslr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,base_lr,total_step,iters_per_epoch,decay_epoch=2)
lr = Tensor(lr_scheduler[-1])# 定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)
# 定义优化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
# 定义loss_scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)
# 初始化模型
if device_target == "Ascend":model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})# 设置ckpt文件保存的参数
time_callback = TimeMonitor(data_size=iters_per_epoch)
loss_callback = LossMonitor()
callbacks = [time_callback, loss_callback]
save_steps = 330
keep_checkpoint_max = 5
config_ckpt = CheckpointConfig(save_checkpoint_steps=10,keep_checkpoint_max=keep_checkpoint_max)
ckpt_callback = ModelCheckpoint(prefix="FCN8s",directory="./ckpt",config=config_ckpt)
callbacks.append(ckpt_callback)
model.train(train_epochs, dataset, callbacks=callbacks)print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), 'VertexGeek')

尝试了训练一把,时间太长了,我这里手动截断了:

 当然如果实在想体验一把完全体的训练,可以考虑手动调大batch_size和learning_rate。

八、模型评估:

import mindspore
from mindspore.nn import Adam
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor, ModelIMAGE_MEAN = [103.53, 116.28, 123.675]
IMAGE_STD = [57.375, 57.120, 58.395]
DATA_FILE = "dataset/dataset_fcn8s/mindname.mindrecord"# 下载已训练好的权重文件
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/FCN8s.ckpt"
download(url, "FCN8s.ckpt", replace=True)
net = FCN8s(n_class=num_classes)ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
device_target = "Ascend"
mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target=device_target)# 定义损失函数
loss = nn.CrossEntropyLoss(ignore_index=255)
# 定义优化器
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=0.005, momentum=0.9, weight_decay=0.0001)
# 定义loss_scale
scale_factor = 4
scale_window = 3000
loss_scale_manager = ms.amp.DynamicLossScaleManager(scale_factor, scale_window)if device_target == "Ascend":model = Model(net, loss_fn=loss, optimizer=optimizer, loss_scale_manager=loss_scale_manager, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})
else:model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"pixel accuracy": PixelAccuracy(), "mean pixel accuracy": PixelAccuracyClass(), "mean IoU": MeanIntersectionOverUnion(), "frequency weighted IoU": FrequencyWeightedIntersectionOverUnion()})# 实例化Dataset
dataset = SegDataset(image_mean=IMAGE_MEAN,image_std=IMAGE_STD,data_file=DATA_FILE,batch_size=train_batch_size,crop_size=crop_size,max_scale=max_scale,min_scale=min_scale,ignore_label=ignore_label,num_classes=num_classes,num_readers=2,num_parallel_calls=4)
dataset_eval = dataset.get_dataset()print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), 'VertexGeek')

不过在运行model.eval()的时候遇到了同步流的问题,有知道怎么解决的宝子,私信我哈。

 

九、模型推理:

import cv2
import matplotlib.pyplot as pltnet = FCN8s(n_class=num_classes)
# 设置超参
ckpt_file = "FCN8s.ckpt"
param_dict = load_checkpoint(ckpt_file)
load_param_into_net(net, param_dict)
eval_batch_size = 4
img_lst = []
mask_lst = []
res_lst = []
# 推理效果展示(上方为输入图片,下方为推理效果图片)
plt.figure(figsize=(8, 5))
show_data = next(dataset_eval.create_dict_iterator())
show_images = show_data["data"].asnumpy()
mask_images = show_data["label"].reshape([4, 512, 512])
show_images = np.clip(show_images, 0, 1)
for i in range(eval_batch_size):img_lst.append(show_images[i])mask_lst.append(mask_images[i])
res = net(show_data["data"]).asnumpy().argmax(axis=1)
for i in range(eval_batch_size):plt.subplot(2, 4, i + 1)plt.imshow(img_lst[i].transpose(1, 2, 0))plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)plt.subplot(2, 4, i + 5)plt.imshow(res[i])plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), 'VertexGeek')

模型推理也出现了这个问题(苦笑),有佬帮忙解决一下吗?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38647.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于Java实现图像浏览器的设计与实现

图像浏览器的设计与实现 前言一、需求分析选题意义应用意义功能需求关键技术系统用例图设计JPG系统用例图图片查看系统用例图 二、概要设计JPG.javaPicture.java 三、详细设计类图JPG.java UML类图picture.java UML类图 界面设计JPG.javapicture.java 四、源代码JPG.javapictur…

深入理解pytest fixture:提升测试的灵活性和可维护性!

在现代软件开发中&#xff0c;测试是保证代码质量的重要环节。pytest作为一个强大的测试框架&#xff0c;以其灵活的fixture系统脱颖而出。本文将详细介绍pytest中的fixture概念&#xff0c;通过具体案例展示其应用&#xff0c;并说明如何利用fixture提高测试的灵活性和可维护性…

EKF+UKF+CKF+PF的效果对比|三维非线性滤波|MATLAB例程

前言 标题里的EKF、UKF、CKF、PF分别为&#xff1a;扩展卡尔曼滤波、无迹卡尔曼滤波、容积卡尔曼滤波、粒子滤波。 EKF是扩展卡尔曼滤波&#xff0c;计算快&#xff0c;最常用于非线性状态方程或观测方程下的卡尔曼滤波。 但是EKF应对强非线性的系统时&#xff0c;估计效果不如…

头文件没有string.h ----- 怎么统计字符串的长度?

字符串的逆序&#xff08;看收藏里面的题&#xff09; 第一种方式&#xff1a; #include <stdio.h> void f(char *p);int main() {char s[1000];gets(s);f(s);printf("%s",s);return 0; }void f(char *p) {int i0;int q,k0;while(p[i]!\0){i;}while(k<i){…

SaaS增长:小型SaaS企业可以使用推荐奖励计划吗

在SaaS&#xff08;软件即服务&#xff09;行业的激烈竞争中&#xff0c;如何快速有效地增长用户数量是每个企业都面临的挑战。对于小型SaaS企业来说&#xff0c;资源有限&#xff0c;如何最大化利用现有资源实现用户增长成为了一个重要议题。在这样的背景下&#xff0c;推荐奖…

git clone中的报错问题解决:git@github.com: Permission denied (publickey)

报错&#xff1a; Submodule path ‘kernels/3rdparty/llm-awq’: checked out ‘19a5a2c9db47f69a2851c83fea90f81ed49269ab’ Submodule path ‘kernels/3rdparty/nvbench’: checked out ‘75212298727e8f6e1df9215f2fcb47c8c721ffc9’ Submodule path ‘kernels/3rdparty/t…

自动点赞,自动评论,自动刷

最近周六日家里没事干了个自动程序。需要的找我&#xff01; 仅供学习&#xff01;&#xff01;&#xff01;&#xff01;目前实现的功能 1.自动打开痘印&#xff0c;头条等多个app 2.自动点赞&#xff0c;自动评论 3.自动养号 4.自动关注 后期逐步实现: 1.继续内容的自动…

阿里云:云通信号码认证服务,node.js+uniapp(vue),完整代码

api文档&#xff1a;云通信号码认证服务_云产品主页-阿里云OpenAPI开发者门户 (aliyun.com) reg.vue <template> <div> <input class"sl-input" v-model"phone" type"number" maxlength"11" placeholder"手机号…

TopK问题与如何在有限内存找出前几最大(小)项(纯c语言版)

目录 0.前言 1.知识准备 2.实现 1.首先是必要的HeapSort 2.造数据 其他注意事项 3.TopK的实现 0.前言 在我们的日常生活中总有排名系统&#xff0c;找出前第k个分数最高的人&#xff0c;而现在让我们用堆来在有限内存中进行实现 1.知识准备 想要实现topk问题首先我们要…

Linux运维:mysql高级查询语句(2)

目 录 一、创建数据库&#xff1a; 二、创建表结构&#xff1a;DDL 2.1 学生表s&#xff1a; 2.2 成绩表sc&#xff1a; 2.3 课程表c&#xff1a; 三、录入数据&#xff1a;DML 3.1 对学生表s的数据录入&#xff1a; 3.2 对成绩表sc的数据录入&#xff1a; 3.3 对课…

【Kaggle】Telco Customer Churn 电信用户流失预测案例

⭐️前言&#xff1a;案例学习说明与案例建模流程 我们将围绕Kaggle中的电信用户流失数据集&#xff08;Telco Customer Churn&#xff09;进行用户流失预测。在此过程中&#xff0c;将综合应用此前所介绍的各种方法与技巧&#xff0c;并在实践中提炼总结更多实用技巧。 ⭐️对…

期权交易指南:为什么要交易场外个股期权?

今天带你了解期权交易指南&#xff1a;为什么要交易场外个股期权&#xff1f;随着金融市场的发展和创新&#xff0c;投资者寻求更多的工具来管理风险和获得更高的回报。场外期权交易应运而生&#xff0c;成为一种重要的金融衍生品交易方式。 简单来说就是期权是一种合约&#…

Mysql 的账户管理,索引,存储引擎

目录 一.MySQL的账户管理 1.存放用户信息的表 2.查看当前使用的用户 3.新建用户 4.修改用户名称 5.删除用户 6.修改用户密码 7.破解密码 8. 远程登录 9.用户权限管理 9.1 权限类别 9.2 查看权限 9.3 授予权限 9.4 撤销权限 二.索引 1. 索引管理 1.1 查看索…

麒麟v10-sp3安装kkfileview

1、上传包到服务器 执行&#xff1a;/bin/startup.sh 会自动安装LibreOffice&#xff0c;因为/bin/install.sh判断了不是redhat-release就是ubuntu&#xff0c;导致麒麟系统会走ubuntu&#xff0c;所以会失败&#xff0c;这里改一下如果是麒麟也走install_redhat就可以了 也…

服务器日志事件ID4107:从自动更新 cab 中提取第三方的根目录列表失败,错误为: 已处理证书链,但是在不受信任提供程序信任的根证书中终止。

在查看Windows系统日志时&#xff0c;你是否有遇到过事件ID4107错误&#xff0c;来源CAPI2&#xff0c;详细信息在 http://www.download.windowsupdate.com/msdownload/update/v3/static/trustedr/en/authrootstl.cab 从自动更新 cab 中提取第三方的根目录列表失败&#xff0c;…

【存储】相关内容

【存储】相关内容 1. 存储类型1. 块存储2. 文件存储3. 对象存储4. 三种存储类型对比 2. 常见的存储分类1. DAS2. SAN3. NAS4. 存储分类分析比较 3. 一些存储的概念1. LUN2. volume3. HBA4. iSCSI 1. 存储类型 块存储和文件存储是我们比较熟悉的两种主流的存储类型&#xff0c;…

解决OneDrive “拒绝访问文件” 问题

问题描述&#xff1a; 在尝试将其他文件拖入oneDrive或是打开OneDrive中的文件时。出现如下报错&#xff1a; 拒绝访问文件 无法访问XXXXXXX中的文件。可能已移动或删除了此文件&#xff0c;或者受制于文件权限而不能访问。 ERR_ACCESS_DENIED 解决办法&#xff1a; 1. 找到O…

JavaEE—什么是服务器?以及Tomcat安装到如何集成到IDEA中?

目录 ▐ 前言 ▐ JavaEE是指什么? ▐ 什么是服务器&#xff1f; ▐ Tomcat安装教程 * 修改服务端口号 ▐ 将Tomcat集成到IDEA中 ▐ 测试 ▐ 结语 ▐ 前言 至此&#xff0c;这半年来我已经完成了JavaSE&#xff0c;Mysql数据库&#xff0c;以及Web前端知识的学习了&am…

notepad++安装并打开json文件

1、notepad安装 1、首先下载Notepad.exe 2、选择简体中文安装 点击下一步 点击“我接受” 选择安装目录&#xff0c;进行下一步安装 默认下一步 选择安装 等待安装完成 点击完成 2、保存json文件 复制返回结果 先把返回结果复制出来。保存到text里面 把文件另存为json格式 3、…

安卓手机的自带录屏在哪里找?5个软件帮助你快速给手机录屏

安卓手机的自带录屏在哪里找&#xff1f;5个软件帮助你快速给手机录屏 在安卓手机上进行屏幕录制是一项非常实用的功能&#xff0c;特别是对于需要录制游戏操作、制作教程或演示的用户来说。虽然部分安卓手机可能已经预装了屏幕录制功能&#xff0c;但有时候这些功能可能隐藏在…