昇思25天学习打卡营第15天|ResNet50图像分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

ResNet50图像分类

图像分类是最基础的计算机视觉应用,属于有监督学习类别,如给定一张图像(猫、狗、飞机、汽车等等),判断图像所属的类别。本章将介绍使用ResNet50网络对CIFAR-10数据集进行分类。

ResNet网络介绍

ResNet50网络是2015年由微软实验室的何恺明提出,获得ILSVRC2015图像分类竞赛第一名。在ResNet网络提出之前,传统的卷积神经网络都是将一系列的卷积层和池化层堆叠得到的,但当网络堆叠到一定深度时,就会出现退化问题。下图是在CIFAR-10数据集上使用56层网络与20层网络训练误差和测试误差图,由图中数据可以看出,56层网络比20层网络训练误差和测试误差更大,随着网络的加深,其误差并没有如预想的一样减小。

resnet-1

ResNet网络提出了残差网络结构(Residual Network)来减轻退化问题,使用ResNet网络可以实现搭建较深的网络结构(突破1000层)。论文中使用ResNet网络在CIFAR-10数据集上的训练误差与测试误差图如下图所示,图中虚线表示训练误差,实线表示测试误差。由图中数据可以看出,ResNet网络层数越深,其训练误差和测试误差越小。

resnet-4

了解ResNet网络更多详细内容,参见ResNet论文。

ImageNet 的示例网络架构。左:VGG-19 模型作为参考。中:一个具有 34 个参数层的普通网络。右:一个具有 34 个参数层的残差网络。虚线快捷连接(shortcut connections)用于增加维度。

数据集准备与加载

CIFAR-10数据集共有60000张32*32的彩色图像,分为10个类别,每类有6000张图,数据集一共有50000张训练图片和10000张评估图片。首先,如下示例使用download接口下载并解压,目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.2.14
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/nginx/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"download(url, "./datasets-cifar10-bin", kind="tar.gz", replace=True)
Creating data folder...
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz (162.2 MB)file_sizes: 100%|█████████████████████████████| 170M/170M [00:00<00:00, 198MB/s]
Extracting tar.gz file...
Successfully downloaded / unzipped to ./datasets-cifar10-bin
'./datasets-cifar10-bin'

下载后的数据集目录结构如下:

datasets-cifar10-bin/cifar-10-batches-bin
├── batches.meta.text
├── data_batch_1.bin
├── data_batch_2.bin
├── data_batch_3.bin
├── data_batch_4.bin
├── data_batch_5.bin
├── readme.html
└── test_batch.bin

然后,使用mindspore.dataset.Cifar10Dataset接口来加载数据集,并进行相关图像增强操作。

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstypedata_dir = "./datasets-cifar10-bin/cifar-10-batches-bin"  # 数据集根目录
batch_size = 256  # 批量大小
image_size = 32  # 训练图像空间大小
workers = 4  # 并行线程个数
num_classes = 10  # 分类数量def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers):data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir,usage=usage,num_parallel_workers=workers,shuffle=True)trans = []if usage == "train":trans += [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5)]trans += [vision.Resize(resize),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]target_trans = transforms.TypeCast(mstype.int32)# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)data_set = data_set.map(operations=target_trans,input_columns='label',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_set# 获取处理后的训练与测试数据集dataset_train = create_dataset_cifar10(dataset_dir=data_dir,usage="train",resize=image_size,batch_size=batch_size,workers=workers)
step_size_train = dataset_train.get_dataset_size()dataset_val = create_dataset_cifar10(dataset_dir=data_dir,usage="test",resize=image_size,batch_size=batch_size,workers=workers)
step_size_val = dataset_val.get_dataset_size()

下载CIFAR-10数据集及数据增强操作,如随机裁剪、水平翻转、调整大小、归一化等,增加数据的多样性,提高了模型的泛化能力。

对CIFAR-10训练数据集进行可视化。

import matplotlib.pyplot as plt
import numpy as npdata_iter = next(dataset_train.create_dict_iterator())images = data_iter["image"].asnumpy()
labels = data_iter["label"].asnumpy()
print(f"Image shape: {images.shape}, Label shape: {labels.shape}")# 训练数据集中,前六张图片所对应的标签
print(f"Labels: {labels[:6]}")classes = []with open(data_dir + "/batches.meta.txt", "r") as f:for line in f:line = line.rstrip()if line:classes.append(line)# 训练数据集的前六张图片
plt.figure()
for i in range(6):plt.subplot(2, 3, i + 1)image_trans = np.transpose(images[i], (1, 2, 0))mean = np.array([0.4914, 0.4822, 0.4465])std = np.array([0.2023, 0.1994, 0.2010])image_trans = std * image_trans + meanimage_trans = np.clip(image_trans, 0, 1)plt.title(f"{classes[labels[i]]}")plt.imshow(image_trans)plt.axis("off")
plt.show()
Image shape: (256, 3, 32, 32), Label shape: (256,)
Labels: [1 1 2 9 4 0]

展示训练数据集的前六张图片。

构建网络

残差网络结构(Residual Network)是ResNet网络的主要亮点,ResNet使用残差网络结构后可有效地减轻退化问题,实现更深的网络结构设计,提高网络的训练精度。本节首先讲述如何构建残差网络结构,然后通过堆叠残差网络来构建ResNet50网络。

构建残差网络结构

残差网络结构图如下图所示,残差网络由两个分支构成:一个主分支,一个shortcuts(图中弧线表示)。主分支通过堆叠一系列的卷积操作得到,shortcuts从输入直接到输出,主分支输出的特征矩阵𝐹(𝑥)加上shortcuts输出的特征矩阵𝑥𝑥得到𝐹(𝑥)+𝑥,通过Relu激活函数后即为残差网络最后的输出。

residual

残差网络结构主要由两种,一种是Building Block,适用于较浅的ResNet网络,如ResNet18和ResNet34;另一种是Bottleneck,适用于层数较深的ResNet网络,如ResNet50、ResNet101和ResNet152。

Building Block

Building Block结构图如下图所示,主分支有两层卷积网络结构:

  • 主分支第一层网络以输入channel为64为例,首先通过一个3×3的卷积层,然后通过Batch Normalization层,最后通过Relu激活函数层,输出channel为64;
  • 主分支第二层网络的输入channel为64,首先通过一个3×3的卷积层,然后通过Batch Normalization层,输出channel为64。

最后将主分支输出的特征矩阵与shortcuts输出的特征矩阵相加,通过Relu激活函数即为Building Block最后的输出。

building-block-5

主分支与shortcuts输出的特征矩阵相加时,需要保证主分支与shortcuts输出的特征矩阵shape相同。如果主分支与shortcuts输出的特征矩阵shape不相同,如输出channel是输入channel的一倍时,shortcuts上需要使用数量与输出channel相等,大小为1×1的卷积核进行卷积操作;若输出的图像较输入图像缩小一倍,则要设置shortcuts中卷积操作中的stride为2,主分支第一层卷积操作的stride也需设置为2。

如下代码定义ResidualBlockBase类实现Building Block结构。

from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
Bottleneck

Bottleneck结构图如下图所示,在输入相同的情况下Bottleneck结构相对Building Block结构的参数数量更少,更适合层数较深的网络,ResNet50使用的残差结构就是Bottleneck。该结构的主分支有三层卷积结构,分别为1×1的卷积层、3×3卷积层和1×1的卷积层,其中1×1的卷积层分别起降维和升维的作用。

  • 主分支第一层网络以输入channel为256为例,首先通过数量为64,大小为1×1的卷积核进行降维,然后通过Batch Normalization层,最后通过Relu激活函数层,其输出channel为64;
  • 主分支第二层网络通过数量为64,大小为3×3的卷积核提取特征,然后通过Batch Normalization层,最后通过Relu激活函数层,其输出channel为64;
  • 主分支第三层通过数量为256,大小1×1的卷积核进行升维,然后通过Batch Normalization层,其输出channel为256。

最后将主分支输出的特征矩阵与shortcuts输出的特征矩阵相加,通过Relu激活函数即为Bottleneck最后的输出。

building-block-6

主分支与shortcuts输出的特征矩阵相加时,需要保证主分支与shortcuts输出的特征矩阵shape相同。如果主分支与shortcuts输出的特征矩阵shape不相同,如输出channel是输入channel的一倍时,shortcuts上需要使用数量与输出channel相等,大小为1×1的卷积核进行卷积操作;若输出的图像较输入图像缩小一倍,则要设置shortcuts中卷积操作中的stride为2,主分支第二层卷积操作的stride也需设置为2。

如下代码定义ResidualBlock类实现Bottleneck结构。

class ResidualBlock(nn.Cell):expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍def __init__(self, in_channel: int, out_channel: int,stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=1, weight_init=weight_init)self.norm1 = nn.BatchNorm2d(out_channel)self.conv2 = nn.Conv2d(out_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.norm2 = nn.BatchNorm2d(out_channel)self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,kernel_size=1, weight_init=weight_init)self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):identity = x  # shortscuts分支out = self.conv1(x)  # 主分支第一层:1*1卷积层out = self.norm1(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm2(out)out = self.relu(out)out = self.conv3(out)  # 主分支第三层:1*1卷积层out = self.norm3(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
构建ResNet50网络

ResNet网络层结构如下图所示,以输入彩色图像224×224为例,首先通过数量64,卷积核大小为7×7,stride为2的卷积层conv1,该层输出图片大小为112×112,输出channel为64;然后通过一个3×3的最大下采样池化层,该层输出图片大小为56×56,输出channel为64;再堆叠4个残差网络块(conv2_x、conv3_x、conv4_x和conv5_x),此时输出图片大小为7×7,输出channel为2048;最后通过一个平均池化层、全连接层和softmax,得到分类概率。

resnet-layer

对于每个残差网络块,以ResNet50网络中的conv2_x为例,其由3个Bottleneck结构堆叠而成,每个Bottleneck输入的channel为64,输出channel为256。

如下示例定义make_layer实现残差块的构建,其参数如下所示:

  • last_out_channel:上一个残差网络输出的通道数。
  • block:残差网络的类别,分别为ResidualBlockBaseResidualBlock
  • channel:残差网络输入的通道数。
  • block_nums:残差网络块堆叠的个数。
  • stride:卷积移动的步幅。
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)

ResNet50网络共有5个卷积结构,一个平均池化层,一个全连接层,以CIFAR-10数据集为例:

  • conv1:输入图片大小为32×32,输入channel为3。首先经过一个卷积核数量为64,卷积核大小为7×7,stride为2的卷积层;然后通过一个Batch Normalization层;最后通过Reul激活函数。该层输出feature map大小为16×16,输出channel为64。
  • conv2_x:输入feature map大小为16×16,输入channel为64。首先经过一个卷积核大小为3×3,stride为2的最大下采样池化操作;然后堆叠3个[1×1,64;3×3,64;1×1,256]结构的Bottleneck。该层输出feature map大小为8×8,输出channel为256。
  • conv3_x:输入feature map大小为8×8,输入channel为256。该层堆叠4个[1×1,128;3×3,128;1×1,512]结构的Bottleneck。该层输出feature map大小为4×4,输出channel为512。
  • conv4_x:输入feature map大小为4×4,输入channel为512。该层堆叠6个[1×1,256;3×3,256;1×1,1024]结构的Bottleneck。该层输出feature map大小为2×2,输出channel为1024。
  • conv5_x:输入feature map大小为2×2,输入channel为1024。该层堆叠3个[1×1,512;3×3,512;1×1,2048]结构的Bottleneck。该层输出feature map大小为1×1,输出channel为2048。
  • average pool & fc:输入channel为2048,输出channel为分类的类别数。

如下示例代码实现ResNet50模型的构建,通过用调函数resnet50即可构建ResNet50模型,函数resnet50参数如下:

  • num_classes:分类的类别数,默认类别数为1000。
  • pretrained:下载对应的训练模型,并加载预训练模型中的参数到网络中。
from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return x

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrained_ckpt, replace=True)param_dict = load_checkpoint(pretrained_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"""ResNet50模型"""resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

残差网络通过跳跃连接(shortcuts)将输入直接添加到输出残差网络结构主要由两种,一种是Building Block,适用于较浅的ResNet网络;另一种是Bottleneck,适用于层数较深的ResNet网络。ResNet50模型由多个残差块(Residual Block)组成,每个残差块包含多个卷积层和批归一化层。堆叠不同数量的残差块,可以构建不同深度的ResNet模型。

模型训练与评估

本节使用ResNet50预训练模型进行微调。调用resnet50构造ResNet50模型,并设置pretrained参数为True,将会自动下载ResNet50预训练模型,并加载预训练模型中的参数到网络中。然后定义优化器和损失函数,逐个epoch打印训练的损失值和评估精度,并保存评估精度最高的ckpt文件(resnet50-best.ckpt)到当前路径的./BestCheckPoint下。

由于预训练模型全连接层(fc)的输出大小(对应参数num_classes)为1000, 为了成功加载预训练权重,我们将模型的全连接输出大小设置为默认的1000。CIFAR10数据集共有10个分类,在使用该数据集进行训练时,需要将加载好预训练权重的模型全连接层输出大小重置为10。

此处我们展示了5个epochs的训练过程,如果想要达到理想的训练效果,建议训练80个epochs。

# 定义ResNet50网络
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=10)
# 重置全连接层
network.fc = fc
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt (97.7 MB)file_sizes: 100%|█████████████████████████████| 102M/102M [00:00<00:00, 131MB/s]
Successfully downloaded file to ./LoadPretrainedModel/resnet50_224_new.ckpt
# 设置学习率
num_epochs = 5
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss

import os# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)# 最佳模型存储路径
best_acc = 0
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"if not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)

import mindspore.ops as opsdef train(data_loader, epoch):"""模型训练"""losses = []network.set_train(True)for i, (images, labels) in enumerate(data_loader):loss = train_step(images, labels)if i % 100 == 0 or i == step_size_train - 1:print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %(epoch + 1, num_epochs, i + 1, step_size_train, loss))losses.append(loss)return sum(losses) / len(losses)def evaluate(data_loader):"""模型验证"""network.set_train(False)correct_num = 0.0  # 预测正确个数total_num = 0.0  # 预测总数for images, labels in data_loader:logits = network(images)pred = logits.argmax(axis=1)  # 预测结果correct = ops.equal(pred, labels).reshape((-1, ))correct_num += correct.sum().asnumpy()total_num += correct.shape[0]acc = correct_num / total_num  # 准确率return acc

# 开始循环训练
print("Start Training Loop ...")for epoch in range(num_epochs):curr_loss = train(data_loader_train, epoch)curr_acc = evaluate(data_loader_val)print("-" * 50)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, curr_loss, curr_acc))print("-" * 50)# 保存当前预测准确率最高的模型if curr_acc > best_acc:best_acc = curr_accms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)
Start Training Loop ...
Epoch: [  1/  5], Steps: [  1/196], Train Loss: [2.378]
Epoch: [  1/  5], Steps: [101/196], Train Loss: [1.535]
Epoch: [  1/  5], Steps: [196/196], Train Loss: [1.096]
--------------------------------------------------
Epoch: [  1/  5], Average Train Loss: [1.614], Accuracy: [0.598]
--------------------------------------------------
Epoch: [  2/  5], Steps: [  1/196], Train Loss: [0.990]
Epoch: [  2/  5], Steps: [101/196], Train Loss: [0.947]
Epoch: [  2/  5], Steps: [196/196], Train Loss: [0.964]
--------------------------------------------------
Epoch: [  2/  5], Average Train Loss: [1.006], Accuracy: [0.684]
--------------------------------------------------
Epoch: [  3/  5], Steps: [  1/196], Train Loss: [0.825]
Epoch: [  3/  5], Steps: [101/196], Train Loss: [0.843]
Epoch: [  3/  5], Steps: [196/196], Train Loss: [0.822]
--------------------------------------------------
Epoch: [  3/  5], Average Train Loss: [0.845], Accuracy: [0.721]
--------------------------------------------------
Epoch: [  4/  5], Steps: [  1/196], Train Loss: [0.713]
Epoch: [  4/  5], Steps: [101/196], Train Loss: [0.792]
Epoch: [  4/  5], Steps: [196/196], Train Loss: [0.772]
--------------------------------------------------
Epoch: [  4/  5], Average Train Loss: [0.774], Accuracy: [0.732]
--------------------------------------------------
Epoch: [  5/  5], Steps: [  1/196], Train Loss: [0.720]
Epoch: [  5/  5], Steps: [101/196], Train Loss: [0.790]
Epoch: [  5/  5], Steps: [196/196], Train Loss: [0.731]
--------------------------------------------------
Epoch: [  5/  5], Average Train Loss: [0.742], Accuracy: [0.736]
--------------------------------------------------
================================================================================
End of validation the best Accuracy is:  0.736, save the best ckpt file in ./BestCheckpoint/resnet50-best.ckpt

使用预训练的ResNet50模型进行微调,加快训练速度并提高模型性能。定义优化器、损失函数和训练循环,对模型进行训练,在验证集上评估模型性能。

可视化模型预测

定义visualize_model函数,使用上述验证精度最高的模型对CIFAR-10测试数据集进行预测,并将预测结果可视化。若预测字体颜色为蓝色表示为预测正确,预测字体颜色为红色则表示预测错误。

由上面的结果可知,5个epochs下模型在验证数据集的预测准确率在70%左右,即一般情况下,6张图片中会有2张预测失败。如果想要达到理想的训练效果,建议训练80个epochs。

import matplotlib.pyplot as pltdef visualize_model(best_ckpt_path, dataset_val):num_class = 10  # 对狼和狗图像进行二分类net = resnet50(num_class)# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)# 加载验证集的数据进行验证data = next(dataset_val.create_dict_iterator())images = data["image"]labels = data["label"]# 预测图像类别output = net(data['image'])pred = np.argmax(output.asnumpy(), axis=1)# 图像分类classes = []with open(data_dir + "/batches.meta.txt", "r") as f:for line in f:line = line.rstrip()if line:classes.append(line)# 显示图像及图像的预测值plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels.asnumpy()[i] else 'red'plt.title('predict:{}'.format(classes[pred[i]]), color=color)picture_show = np.transpose(images.asnumpy()[i], (1, 2, 0))mean = np.array([0.4914, 0.4822, 0.4465])std = np.array([0.2023, 0.1994, 0.2010])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()# 使用测试数据集进行验证
visualize_model(best_ckpt_path=best_ckpt_path, dataset_val=dataset_val)

可视化模型的预测结果,直观查看模型的预测,包括预测正确的样本和预测错误的样本。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/41500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JAVA入门】Day13 - 代码块

【JAVA入门】Day13 - 代码块 文章目录 【JAVA入门】Day13 - 代码块一、局部代码块二、构造代码块三、静态代码块 在 Java 中&#xff0c;两个大括号 { } 中间的部分叫一个代码块&#xff0c;代码块又分为&#xff1a;局部代码块、构造代码块、静态代码块三种。 一、局部代码块…

c++11新特性-3-自动类型推导

文章目录 自动类型推导1.auto1.1 const修饰1.2 auto不能使用的场景1.3 auto应用场景 2.decltype1.1 基本语法 自动类型推导 1.auto 注意&#xff0c;auto必须进行初始化 auto i 10; //int类型auto k 3.14; //double类型auto db; //错误1.1 const修饰 当const修改指针或者…

C++:构造函数是什么东西

一、构造函数是什么 在C中&#xff0c;构造函数是一种特殊成员函数&#xff0c;它有一下几个明显的特征&#xff1a; 1、它自动在创建新对象时被调用。 2、其名称与类名相同&#xff0c; 3、没有返回类型&#xff0c; 4、通常没有参数&#xff08;除了默认情况下的隐式thi…

跟《经济学人》学英文:2024年06月01日这期 The side-effects of the TikTok tussle

The side-effects of the TikTok tussle tussle&#xff1a;美 [ˈtəsəl] 激烈扭打&#xff1b;争夺 注意发音 side-effects&#xff1a;副作用&#xff1b;&#xff08;side-effect的复数&#xff09; As the app’s future hangs in the balance, the ramifications of …

MySQL的并发控制、事务、日志

目录 一.并发控制 1.锁机制 2.加锁与释放锁 二.事务&#xff08;transactions&#xff09; 1.事物的概念 2.ACID特性 3.事务隔离级别 三.日志 1.事务日志 2.错误日志 3.通用日志 4.慢查询日志 5.二进制日志 备份 一.并发控制 在 MySQL 中&#xff0c;并发控制是确…

都有哪些离线翻译器软件?没网就用这4个

经历完痛苦的期末考&#xff0c;可算是千盼万盼等来了日思夜想的暑假&#xff01;趁着这大好时光&#xff0c;怎么能不来一场出国游呢~ 不知道有多少小伙伴和我一样&#xff0c;出国玩最怕的就是语言不通&#xff0c;不管是吃饭还是游玩体验感都会大受影响~好在多出国玩了几趟…

ES6模块化学习

1. 回顾&#xff1a;node.js 中如何实现模块化 node.js 遵循了 CommonJS 的模块化规范。其中&#xff1a; 导入其它模块使用 require() 方法 模块对外共享成员使用 module.exports 对象 模块化的好处&#xff1a; 大家都遵守同样的模块化规范写代码&#xff…

Linux 时区文件格式【man 5 tzfile】

时区文件格式标准&#xff1a;https://datatracker.ietf.org/doc/html/rfc8536 1. NAME&#xff08;名&#xff09; tzfile - 时区文件。&#xff08;非文本文件&#xff09; 2. DESCRIPTION&#xff08;描述&#xff09; 本页介绍被 tzset(3) 函数使用的时区文件的结构。这…

006 线程安全

文章目录 临界资源线程安全基本概念*何谓竞态条件**何谓线程安全* 对象的安全局部基本类型变量局部的对象引用对象成员(成员变量) 不可变性 临界资源 临界资源是一次仅允许一个进程使用的共享资源。各进程采取互斥的方式&#xff0c;实现共享的资源称作临界资源。属于临界资源…

如何使用GPT进行科研:详细指令指南

如何使用GPT进行科研&#xff1a;详细指令指南 随着GPT模型的流行&#xff0c;越来越多的科研人员开始利用这项技术来辅助科学研究&#xff0c;特别是在文本处理任务如论文翻译、文本润色和降低抄袭率方面。本文将提供详细的指令&#xff0c;帮助科研人员有效地使用GPT进行科研…

计算机相关专业入门,高考假期预习指南

一&#xff1a;学习资源推荐 跟着b站的“黑马程序员”学c&#xff0c;黑马程序员匠心之作|C教程从0到1入门编程,学习编程不再难_哔哩哔哩_bilibili&#xff0c;把这个编程语言基础打好&#xff0c;然后看“蓝桥杯算法”&#xff0c;到了大一直接就能打蓝桥杯比赛了 看完上面的 …

TRILL简介

介绍TRILL的定义及目的。 定义 TRILL(Transparent Interconnection of Lots of Links)是一种把三层链路状态路由技术应用于二层网络的协议。TRILL通过扩展IS-IS路由协议实现二层路由&#xff0c;可以很好地满足数据中心大二层组网需求&#xff0c;为数据中心业务提供解决方案…

用数组手搓一个小顶堆

堆默认从数组下标为1开始存储。 const int N201000; int heap[N]; int len; 插入操作&#xff1a; 将元素插入到堆的末尾位置向上调整。 void up(int k){while(k>1&&heap[k/2]>heap[k]){swap(heap[k],heap[k/2]);k/2;} } //len为当前存在元素长度 void Inser…

水利水库大坝结构安全自动化监测主要测哪些内容?

在大坝安全自动化监测系统建设中&#xff0c;应根据坝型、坝体结构和地质条件等因素选定监测项目&#xff1b;主要监测对象包括坝体、坝基及有关的各种主要水工建筑物、大坝附近的不稳定岸坡和大坝周边的气象环境。深圳安锐科技建议参考下列表格适当调整。 &#xff08;一&am…

计算机网络(2

计算机网络续 一. 网络编程 网络编程, 指网络上的主机, 通过不同的进程, 以编程的方式实现网络通信(或网络数据传输). 即便是同一个主机, 只要不同进程, 基于网络来传输数据, 也属于网络编程. 二. 网络编程套接字(socket) socket: 操作系统提供的网络编程的 API 称作 “soc…

(0)2024年基于财务的数据科学项目Python编程基础(Jupyter Notebooks)

目录 前言学习目标&#xff1a;学习内容&#xff1a;大纲 前言 随着数据科学的迅猛发展&#xff0c;其在财务领域的应用也日益广泛。财务数据的分析和预测对于企业的决策过程至关重要。 本专栏旨在通过Jupyter Notebooks这一强大的交互式计算工具&#xff0c;介绍基于财务的数…

【车载开发系列】常见单片机调试接口的区别

【车载开发系列】常见单片机调试接口的区别 【车载开发系列】常见单片机调试接口的区别 【车载开发系列】常见单片机调试接口的区别一. JTAG协议二. SWD接口三. RDI接口四. 仿真器1&#xff09;J-Link仿真器2&#xff09;ULink仿真器3&#xff09;ST-LINK仿真器 五. SWD / JTAG…

Day05-组织架构-角色管理

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 1.组织架构-编辑部门-弹出层获取数据2.组织架构-编辑部门-编辑表单校验3.组织架构-编辑部门-确认取消4.组织架构-删除部门5.角色管理-搭建页面结构6.角色管理-获取数…

MySQL中的DDL语句

第一题 输入密码登录mysql&#xff0c;创建数据库zoo&#xff0c;转换到zoo数据库&#xff0c; mysql> create database zoo character set gbk; mysql> use zoo查看创建数据库zoo信息 mysql> show create database zoo;删除数据库zoo mysql> drop database zo…

【后端面试题】【中间件】【NoSQL】MongoDB查询优化2(优化排序、mongos优化)

优化排序 在MongoDB里面&#xff0c;如果能够利用索引来排序的话&#xff0c;直接按照索引顺序加载数据就可以了。如果不能利用索引来排序的话&#xff0c;就必须在加载了数据之后&#xff0c;再次进行排序&#xff0c;也就是进行内存排序。 可想而知&#xff0c;如果内存排序…