基于Mindspore,通过Resnet50迁移学习实现猫十二分类

使用平台介绍

使用平台:启智AI协作平台
使用数据集:百度猫十二分类

数据集介绍

有cat_12_train和cat_12_test和train_list.txt
train_list.txt内有每张图片所对应的标签

Minspore部分操作科普

数据集加载

Mindspore加载图片数据集就直接调整成这种格式就行,然后可以用这个函数加载,自动生成两个列,一列是图片,一列是标签;ImageFolderDataset函数会自动读取和处理数据集,标签就是文件夹的名称
在这里插入图片描述
在这里插入图片描述

数据处理

map函数里可以一键进行处理和映射,定义好数据处理函数,直接把路径和标签Map处理,后面可以带上是训练集还是测试集的标签;,要处理图片就指定 input_columns 参数为image,这个是前面数据读取形成的;
本项目这里用的是ImageFolderDataset,可以自动生成图片对应数据的两个列,要是相对数据处理就设置为数据列名即可
在这里插入图片描述

数据批处理和重复

batch就是批处理,把数据分成指定数量的一个个批次
在这里插入图片描述
最后repeat对数据进行重复
在这里插入图片描述

整体数据处理流程

就是读取数据形成数据和标签对应的列(读取数据函数有很多),然后定义数据预处理函数,在map函数里一键映射,指定要处理的列一键处理,最后对数据进行批次划分,就拿到可以放进训练网络函数的规范数据集了。当然使用时候还要用create_tuple_iterator或者create_dict_iterator函数形成可迭代的数据集。

迁移学习

在迁移学习中,固定特征训练和模型微调都是常用的技术

固定特征训练

在源任务上训练一个模型,并将其应用到目标任务上。在这个过程中,模型的特征提取器是固定的,只对输出层进行调整。这种方法可以利用源任务中已经学习到的特征,从而减少目标任务的训练时间和数据需求。固定特征训练通常适用于目标任务和源任务具有相似的特征空间,并且目标任务的数据量较小的情况。

模型微调

使用预训练模型作为初始模型,并在目标任务的数据集上进行进步训练。在微调阶段,可以根据目标任务的数据和特定要求调整模型的参数使其适应目标任务。模型微调的主要目的是通过在目标任务上的有限训练来调整预训练模型,以取得更好的性能。

举例

以训练一个猫、狗分类器为例,固定特征训练是指在一个大型的猫狗图片数据集上训练一个通用的图像识别模型,然后将该模型应用于特定的猫、狗分类任务。在这个过程中,我们只需要调整模型的输出层,使其能够正确地对猫和狗进行分类。而模型微调则是指使用一个已经在大型数据集上训练好的通用图像识别模型,然后在特定的猫、狗图片数据集上进行进一步训练,以优化模型的性能。在这个过程中我们可以根据猫、狗图片的特点来调整模型的参数,使其能够更好地识别猫和狗。

数据处理过程

由于本项目采用的数据集是百度所提供的猫十二分类,要使用ImageFolderDataset的话,形式不太匹配

现有形式

cat_12_train和cat_12_test里面都是一张一张的图片
train_list.txt内有每张图片所对应的标签
在这里插入图片描述

处理后形式

train和val文件夹内分别有十二个子文件夹,代表12类猫,每个子文件夹内又有一张张的图片
在这里插入图片描述在这里插入图片描述

处理代码

这里有相关代码进行自动划分,但是对于训练集和测试集的划分,我直接采用了手动操作,也可以用代码来实现的;

# 处理异常图片
dir_lit = os.listdir('./work/cat_12_train/')
# dir_lit为一个列表,里面是一张张图片的名称
for list in dir_lit:
# list是图片名称,这里的操作是把这个图片形成一个个的路径img_path=os.path.join('./work/cat_12_train/',list)print(img_path)# 如果不是RGB那就转换为RGBimg=Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img.save(img_path)dir_lit = os.listdir('./work/cat_12_test/')
for list in dir_lit:img_path=os.path.join('./work/cat_12_test/',list)img=Image.open(img_path)if img.mode != 'RGB':img = img.convert('RGB')img.save(img_path)
# 整理数据格式
# 创建12个文件夹分别对应标签
path='./work/MyDataset/'
for i in range(12):if not os.path.exists(path+str(i)):os.mkdir(path+str(i))else:continue
#读取每一行
with open(f'./work/train_list.txt','r')as f:img_path=f.readlines()print(img_path)# 里面是一个个的'cat_12_train/8GOkTtqw7E6IHZx4olYnhzvXLCiRsUfM.jpg\t0\n'# 把对应文件放到对应标签文件夹下
for img in img_path:# 拿取每一张图片路径# print(img)img_src= img.split('\t')[0]# img_src为一个个图片路径rel_src= img_src.split('cat_12_train/')[1]# rel_src为图片名称img_label = img.split('\t')[1]img_label = img_label.split('\n')[0]# img_label为图片标签print(img_src)print(rel_src)print(img_label)# os.system(f'cp ./work/{img_src} ./work/MyDataset/{img_label}/{rel_src}')shutil.copy(f'./work/{img_src}',f'./work/MyDataset/{img_label}/{rel_src}')
print('图片处理完毕')

整体代码

# 解压上传的数据集压缩包并查看数据集结构
!unzip MyDataset.zip -d data/
import os
print(os.listdir("data"))
print(os.listdir("data/train"))
print(os.listdir('data/train/1'))

输出

['val', 'train']
['9', '10', '5', '8', '11', '2', '7', '3', '0', '1', '6', '4']
['DKkQylbgdrWRjYap63MCJe0UBLhcHXPm.jpg', 'k9HWNaG2Z1wUKAOYdSDu7vRr4xBqmTCV.jpg', 'LfIoOrSNvKHQzsGtm5eMZc0lRuBXhTP6.jpg', '0esFjXNqc5xbMmUaJkRVPwQorWlu3LvA.jpg', 'PMoFIabq0W9U2wETZr7yf4JLYdBxv6hQ.jpg', 'QlUX4zHfPZ3LxRDswqm5FeMbnTWNaj6g.jpg', '7E9oOUcQjkLMvpAtNymHCRSqFfdGVDK4.jpg', 'pznq7EivBH9LwrNysIWxgGeomTlOP8cZ.jpg', '3Ndv9X6uTgzFtnoA01VECIBPj7xqlewG.jpg', 'o1g6adKmS4lBDw2F5buAYnetUWh7xXGz.jpg', '7QZTYlspK2fqdJUwjC0HDmOFrM5W4PX9.jpg', 'qgimvDE8Zaf4PJ32dkNhwVy5nxATOtrX.jpg', 'RJnThakSOGUzeFBdigXAm2NsL8jyYvu3.jpg', 'Ig61xq3ME78fdCRTDWhaKkcyuOQj24PG.jpg', 'tAdqSefI0DohNuU6wgVyPca7Qz5lYTOH.jpg', 'pTe31kYFqwyOGmV50sbhgoLQ9KcjJaxd.jpg', 'HUuwb4gRqoPWD3Lvrsa9hVcQ7FSfOT8t.jpg', 'cqkJDEpWiwS69UxFtKMPRgb4mXhj1LAs.jpg', '2pjBVbqF30cUTvRIYtsCGfgwPKOJz4ua.jpg', 'e7f1iucltpVQTXroFR96xawm2BDZYnNG.jpg', 'tWcMpXTe78zo2ikhbUOqPud6VJ5RfSEw.jpg', 'gcTxA5NwLztvWr7YPCMnDjdFyfqoa2uK.jpg', 'qbKjsR05lrFVYfLChtMGD7im36cUgAnE.jpg', 'S4hfUR5kOj7CXr826Gxa9t3bEBPioJq1.jpg', 'A8PMtHzoyFE0WgjpZ2qUYbduL4T9arxN.jpg', '8E0bSi3h1aVy2cNWpgOsKvxZCQtzqkLU.jpg', 'oDc9XxipzfBjUAEl0hOmyd4PNr5v8IsG.jpg', 'AVGJoCPsX3L6I5Y2M94kEO7vNHmt1Ure.jpg', 'SqyF8c0Rak1NedXpYvlI9TsVwzOhtGZJ.jpg', '04Iv3QNKtu2DAfRTgs9XZwBMb1Cm7l6P.jpg', '94NwrzYLo8iMtagTR3SfPGHmWvZXbyUx.jpg', 'oiU4YjnhNpI3JWagx8SuTCktA6qZXGRH.jpg', 'e416wAERYOQ7NutUJDcIVFk2oPWpC3q8.jpg', '9I3enpUrZ5xD2TvRAOFt4S7lBfVMdqsJ.jpg', 'oLNGFnUPmQhxOkdbv37HwSj8uql4z1sp.jpg', 'nHfDoId8SXKzMt1weR7bJlaWPcNx32yv.jpg', 'dzV1Psxncp6H8g4KWhX3mbrTfqwuLaNv.jpg', 'OW6e1GbpNsfmxFvLQKMnIByX4hDcS2io.jpg', 'MOGw0PDqjmnLdViez26b7WY3hU85vatH.jpg', '3R5BWakTdG2hKjJoiNxg0pr61LsDqSuM.jpg', '41OaVziAEuenpKqv7LCYMPsGH6BkQotD.jpg', 'fRbdkW0GDAhBjpTVeonPItycEia45Ns9.jpg', 'RsYG3VJi7NTXptoPQvWKhcFaDqIe4EkC.jpg', 'jWPXtA07yYrcRxNBUkwC9dpuSs2M463e.jpg', 'bqRVATEuI4x3kJSO9DitWjYms8KoG2Lz.jpg', 'OCXPGzodQsZHRnfMFaBkqW9hKYxA4glr.jpg', 'hAPzcCeE04sSadDIFB627ipyOWgjX8mq.jpg', 'PIbktpOd2DqHwVceLzUyE7CmohMjNA91.jpg', 'adAjP14SXL6vVJ95TRrMIYDiHl7BUqbF.jpg', 'hTxYnXrQ35vwKL1NEMSIot02djHy8i6D.jpg', 'kwuLVmg7n9I4iEOzMQC1NxJfvX6Bhoqs.jpg', '78WTn5auMmQshIZi2qDAYdR1oKcwEzfk.jpg', 'hOEPm8o26CBptkD5yT3fsgbM4dRuVZzw.jpg', 'PIpXbRiu68dm2s1DfxHJGAYLegOUMzoB.jpg', '8NxvitwMaCpsuEQT17nDXzFR5gAZ4rfc.jpg', 'JgxcdpvW7f6lKMShPjHFeZ2RDX13UCiN.jpg', 'BhHLRN0QTWOwl3UEG9J17XScni2P5gVe.jpg', '7Gw2o8LJTF4ecZI6nl1WuDrsAOSfQPaq.jpg', 't6xZhQkD2jWCOi1r3fK7T9slGHVbwgNd.jpg', 'gbfyYtlWaAO4iUCPK2cFVkoQX9MmwJTI.jpg', '9PQic6o4VyZ13pLAYu2avFWSbJRz57fC.jpg', 'Cf2Z3j6hYliVOduEvK5NJp4yba0wSGcA.jpg', '2QvYgMIzELXH4Fy8GNDBaPS3W0tVZ5xq.jpg', 'oFXrWl80gRMenqPbG2uZv5wk4KmHjaNd.jpg', 'l1NsjeJKdvFimRgh6IEZuqfxCw5An8o4.jpg', 'tvWgSwN9m3BZ5qOXjK7LexVCIn6F4AHk.jpg', 'ahcTZUbOmJsloVt8vGMjwPqIXd0x9iy1.jpg', 'hp0nNWXar81lB3eSYE9kGcdDL4tJfuwC.jpg', 'HNehmorRIS6M2iDj48gZL9OKva7Xck1n.jpg', '0GX4YKdcwBi15lTpR7ExWO2ZagseoVNI.jpg', 'b57SiGKYPaE1DrfxJtVeQdlOAUojLZhR.jpg', 'f4gLhHjyKdxTumna7F5pGWPqVIRY01et.jpg', 'AHxQ1GFgRLs802diT3VIlwOoqkW9Sar4.jpg', '09i4DcyrWktZb1naHFEpL5elhG3CvYxu.jpg', 'AieqOGKQC7fgDayW9kLuJ5mUHx4XP2I6.jpg', 'MwZIekE7oxPtRpTHQVf4l6qA2iC1zWLh.jpg', '6M8xAZBdQLkuDcF14HEz53J0IboiPfUa.jpg', 'eD4gfaQTFdoWUCnhPj6YmIZBl5AMxNik.jpg', 'PwBJK7rZHDhq46ynYoj9Saxip0IldMV2.jpg', 'neHaTbwPkdVmoOA3JyWxR7Lh92NDpzEf.jpg', 'E38k6xhQFYKALn4tDlwiPBfpdCygeSNs.jpg', 'N1VpzqjoRmPZCQ5KasAv72TwtMDFrG0d.jpg', 'RvXKbfDuF4W6exgVcInE3SktJa9LzBj8.jpg', 'aMoxSymjdiUwbJ6k5NzGR09uILQ4sEc3.jpg', 'l1WfIcvOZk9jAn2xQwtSCEgY5XhRyoFB.jpg', '0IWfLUGk53iHt9SElNzKsBCDwuMjpPbR.jpg', '80vskwDtCRAz9iWYjnrhIGfeXdUZx16b.jpg', '0WglfKCD5Gu2LqI1msTSZa6orBO7XAz9.jpg', 'os5kaDubPM1hY7f6gRrSOZqNQFEAU89v.jpg', 'm2azqs0NGPDdjR8rUTxWF4covLE5ikQZ.jpg', 'siKAzUrV9eykjlCQ0odZEhnW7FIgTuLm.jpg', 'jZWaf0ne2R1pDo9hTBkCbA8YOq3LlQ5x.jpg', 'Kny8zFiIt4vxNSO6g7Lu9kGfdVJoqPC3.jpg', 'LTMkHx9w2nfsRiZec3bEVtmujpv7qS1y.jpg', 'oZin4PuwTet39xWCYhUBfvlzGyISb5DV.jpg', 'oJ4HWQkZDvta8rUyinRu9fVNs3BX1Kj7.jpg', 'mUp6082yMQghXY13OtvxabTrNSEeiu4B.jpg', 'RZWpn9jGxcKSUb3Y56fVMQHlJhEIeNiA.jpg', 'L86JQlekn9Ko01TbXHYMFImdZ2upxg5h.jpg', 'I0YcgXB97QL6MtHlU3p1znqWdCGOD4mo.jpg', 'IG9NFCfMybKYiQhquOd5H3DjwlSakW6U.jpg', 'SZYosxl4cHRWyT3h5JFqNjGdnIag6907.jpg', '4dMVtGvRJbrjK6X17STZ9Lx3kgeEioqp.jpg', 'q9YrDFK73Jfv15SHpTWelGAIwnBxt0iE.jpg', 'spNU7J8uk6BXiAyQErHegYMzjOaFR2qV.jpg', 'p1ji352o6vhd8l0Q4uNVRZrIgkSLnfBq.jpg', '7WQ0ByMPtJAdZ8h5OkveLi3ScuU6bIY1.jpg', 'hajCi0GDlVP2ONg6FeSWrvubQ34ozwkx.jpg', 'RKLDkUwmFg0Oj5tPeIs31y7J8AQZ9dni.jpg', 'hWOAp1EV6nJzYxaHt03T8GPNe7ujUiF9.jpg', 'B0a2VHnwQv1byMDTlEiOJXxI7Scs9Zjf.jpg', 'gMzOoyTrGniBj1vxN0AeD9VQsFHU7aKR.jpg', 'CiBq0GVawv1rdYyLDjcWoIXP6SKbzH8F.jpg', 'G71cYNEBD6shJLkgVzwb52m8oRluKUHS.jpg', '7IdLnFCb3a25cKNV6tXuYi1fe0hJQMOU.jpg', 'jsThJuVYQxUKSz3btXdA5q8M1O9Cioaw.jpg', '2OpyK1cm85obujwEMqGWNv9V7PnQfJ3U.jpg', 'DLIZr5TjPepd8csioJXMbYHk94RmKx6v.jpg', 'Q80DEFkGlxJj2qR37t4ZKpY6zMdvuIyr.jpg', 'EFxXsVJ0qHkomcBhnLfY96W5U4yOliQG.jpg', 'okw9N05dAnsxgW2IuQy7eGhz1iLOqVrJ.jpg', 'obRL95fxtP1uCNBwQiTjWsdqUvgp2Z43.jpg', '6jTZ5sfCpGwJWIK3DaYQvixLbNt48nHr.jpg', 'gLqBoG3ah0AHXIYWS7dFTt6pxDw8snQv.jpg', '1lzs3kM8NiILvcgDYtn6fdCoSeXauJ5P.jpg', 'FEuyDnKSIJ1a6UtY5LB2rGpRqOm3xP9Z.jpg', 'OKvn2uJmWQi4R9Cs8B7fxbkZtoczq31Y.jpg', 'ByYKkZHb8omRPcvfe5GzXsxOQ3DlLuUq.jpg', 'SKoiaj8C3UGyvJQXh5zWwrxNmYkdEHqn.jpg', '2cKUvXCjm5HNWksY1b4ioIgdSFqyMtEJ.jpg', 'kJK9OA3hXpMWeUY7cifvrz0BItn1VS2T.jpg', 't1DnLxSZXwWTgeJsyE02lrjHfdM35po8.jpg', '5C76eISyb3vmPZuMYcARHU8aFQrBWf1k.jpg', 'puBcg8Fh6tXs27doz1aAIl4L0iVYC3wE.jpg', 'MV5C7YmuzG1LyZplFXvqOQkW4JStjcNP.jpg', 'ruleKNQvzwqmy5sn9MDd7I2RUJjVCWh8.jpg', 'l9Z3gPwjC5HbhINcfVO8dnz1qAxBrJkU.jpg', 'H9BcFOo8UI3jX2CyW0mzxn7agJNAsZQS.jpg', 'jHUJE37YZOGAXInPmyCSp9f0o4uvRe5W.jpg', 'I8jNkAVgZ1yqDw5K9b0Wm4rETfiGBcUF.jpg', 'kKzQrE6GjfpeFhsXx2Ddu9YaTHc3PUbB.jpg', 'jla5O2TkVhefr07XDLpMEonuG6yJWgYd.jpg', 'Km9BZsaSUoxQ4VArcXYyHThIDRbq2t7l.jpg', 'fBp0Yor4EQtWkM7I3TsnNHLXuvCFacjS.jpg']

超参数设置

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 10                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数

数据预处理

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision# 数据集目录路径
data_path_train = "data/train/"
data_path_val = "data/val/"# 创建训练数据集def create_dataset_canidae(dataset_path, usage):"""数据加载"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 数据增强操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_setdataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
print(step_size_train)
print(step_size_val)
data = next(dataset_val.create_dict_iterator())
images = data["image"]
labels = data["label"]
print("Tensor of image", images.shape)
print("Labels:", labels)

输出

96
24
Tensor of image (18, 3, 224, 224)
Labels: [ 1  2  4  3  0 10  5  4 11  9 11  6  7  1 11  5  1  3]

数据集可视化查看

import matplotlib.pyplot as plt
import numpy as np# class_name对应label,按文件夹字符串从小到大的顺序标记label
class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}plt.figure(figsize=(5, 5))
for i in range(4):# 获取图像及其对应的labeldata_image = images[i].asnumpy()data_label = labels[i]# 处理图像供展示使用data_image = np.transpose(data_image, (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])data_image = std * data_image + meandata_image = np.clip(data_image, 0, 1)# 显示图像plt.subplot(2, 2, i+1)plt.imshow(data_image)plt.title(class_name[int(labels[i].asnumpy())])plt.axis("off")plt.show()

在这里插入图片描述

网络结构搭建

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
class ResidualBlock(nn.Cell):expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍def __init__(self, in_channel: int, out_channel: int,stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=1, weight_init=weight_init)self.norm1 = nn.BatchNorm2d(out_channel)self.conv2 = nn.Conv2d(out_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.norm2 = nn.BatchNorm2d(out_channel)self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,kernel_size=1, weight_init=weight_init)self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):identity = x  # shortscuts分支out = self.conv1(x)  # 主分支第一层:1*1卷积层out = self.norm1(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm2(out)out = self.relu(out)out = self.conv3(out)  # 主分支第三层:1*1卷积层out = self.norm3(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out
def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)
from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义,self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return xdef _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型# download(url=model_url, path=pretrianed_ckpt, replace=True)param_dict = load_checkpoint(pretrianed_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

形式一:模型微调

模型训练
from mindspore import nn, train
from mindspore.nn import Loss, Accuracy
!pip install download
import mindspore as ms
from download import download
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channels = network.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 12)
# 重置全连接层
network.fc = head# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
network.avg_pool = avg_poolimport mindspore as ms
import mindspore# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=momentum)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')# 实例化模型
model = train.Model(network, loss_fn, opt, metrics={"Accuracy": Accuracy()})def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = mindspore.ops.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
# 最佳模型保存路径
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"
import os
import time# 开始循环训练
print("Start Training Loop ...")best_acc = 0for epoch in range(num_epochs):losses = []network.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

输出

Start Training Loop ...
--------------------
Epoch: [  1/ 10], Average Train Loss: [1.774], Accuracy: [0.838]
epoch time: 60892.337 ms, per step time: 634.295 ms
--------------------
Epoch: [  2/ 10], Average Train Loss: [0.762], Accuracy: [0.905]
epoch time: 8745.406 ms, per step time: 91.098 ms
--------------------
Epoch: [  3/ 10], Average Train Loss: [0.568], Accuracy: [0.921]
epoch time: 8449.129 ms, per step time: 88.012 ms
--------------------
Epoch: [  4/ 10], Average Train Loss: [0.508], Accuracy: [0.910]
epoch time: 8199.763 ms, per step time: 85.414 ms
--------------------
Epoch: [  5/ 10], Average Train Loss: [0.459], Accuracy: [0.900]
epoch time: 7856.060 ms, per step time: 81.834 ms
--------------------
Epoch: [  6/ 10], Average Train Loss: [0.405], Accuracy: [0.931]
epoch time: 8138.927 ms, per step time: 84.780 ms
--------------------
Epoch: [  7/ 10], Average Train Loss: [0.368], Accuracy: [0.919]
epoch time: 8333.523 ms, per step time: 86.808 ms
--------------------
Epoch: [  8/ 10], Average Train Loss: [0.354], Accuracy: [0.912]
epoch time: 8271.008 ms, per step time: 86.156 ms
--------------------
Epoch: [  9/ 10], Average Train Loss: [0.338], Accuracy: [0.928]
epoch time: 8457.969 ms, per step time: 88.104 ms
--------------------
Epoch: [ 10/ 10], Average Train Loss: [0.338], Accuracy: [0.907]
epoch time: 8183.743 ms, per step time: 85.247 ms
================================================================================
End of validation the best Accuracy is:  0.931, save the best ckpt file in ./BestCheckpoint/resnet50-best.ckpt
模型评估
import matplotlib.pyplot as plt
import mindspore as msdef visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全连接层输入层的大小in_channels = net.fc.in_channels# 输出通道数大小为分类数12head = nn.Dense(in_channels, 12)# 重置全连接层net.fc = head# 平均池化层kernel size为7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化层net.avg_pool = avg_pool# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)#print(net)# 加载验证集的数据进行验证data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()print(type(images))print(images.shape)#print(images)labels = data["label"].asnumpy()#print(labels)class_name = {0: "0", 1: "1",2: "2", 3: "3",4: "4", 5: "5",6: "6", 7: "7",8: "8", 9: "9",10: "10", 11: "11",12: "12"}# 预测图像类别data_pre=ms.Tensor(data["image"])print(data_pre.shape)print(type(data_pre))output = model.predict(data_pre)#print(output)pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()
visualize_model('BestCheckpoint/resnet50-best.ckpt', dataset_val)

输出
在这里插入图片描述

形式二:固定特征训练

模型训练
net_work = resnet50(pretrained=True)
# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为分类数12
head = nn.Dense(in_channels, 12)
# 重置全连接层
net_work.fc = head
# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool
# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:param.requires_grad = False
# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return loss
grad_fn = ms.ops.value_and_grad(forward_fn, None, opt.parameters)
def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss
# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": Accuracy()})
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()
dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()
num_epochs = 10
# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"
# 开始循环训练
print("Start Training Loop ...")
best_acc = 0
for epoch in range(num_epochs):losses = []net_work.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model1.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(net_work, best_ckpt_path)
print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)
模型评估
visualize_model(best_ckpt_path, dataset_val)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/726685.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么?!你居然连个内存泄漏都排查不出来

公众号:程序员白特,欢迎一起交流学习~ 在日常的业务开发中,偶尔会出现内存泄漏的情况,那么我们该怎么排查呢?现在跟着文章一起学习下吧~ 使用Chrome devTools查看内存情况 打开Chrome的无痕模式,以屏蔽Ch…

k个一组反转链表

题目 题目链接 . - 力扣&#xff08;LeetCode&#xff09; 题目描述 代码实现 class Solution { public:ListNode* reverseKGroup(ListNode* head, int k) {if(k 1) return head;//特殊情况ListNode *cur head;for(int i 1; i < k; i){if(cur nullptr ||cur->nex…

Seurat 中的数据可视化方法

本文[1]将使用从 2,700 PBMC 教程计算的 Seurat 对象来演示 Seurat 中的可视化技术。您可以从 SeuratData[2] 下载此数据集。 SeuratData::InstallData("pbmc3k")library(Seurat)library(SeuratData)library(ggplot2)library(patchwork)pbmc3k.final <- LoadData(…

【wine】解决 0024:fixme:msctf:KeystrokeMgr_TestKeyUp STUB:(00A3D508)

故障日志 0024:fixme:msctf:KeystrokeMgr_TestKeyUp STUB:(00A3D508) AI分析 这些消息表示Wine对IE内核组件以及IME&#xff08;Input Method Editor&#xff0c;输入法编辑器&#xff09;的支持不完全。特别是涉及文本输入、拖放事件、属性变化通知等功能。 解决 winetrick…

【论文阅读】单词级文本攻击TAAD2.2

TAAD2.2论文概览 0.前言1-101.Bridge the Gap Between CV and NLP! A Gradient-based Textual Adversarial Attack Frameworka. 背景b. 方法c. 结果d. 论文及代码 2.TextHacker: Learning based Hybrid Local Search Algorithm for Text Hard-label Adversarial Attacka. 背景b…

python爬虫(一)

一、python中的NumPy模块&#xff08;数据的存储和处理&#xff09; 这里是下载完成之后的表现 &#xff08;1&#xff09;创建数组 1、使用array&#xff08;&#xff09;函数创建数组 使用array函数可以创建任意维度的的数组 下面是一个创建二维数组的代码示例 下面是代码…

java集合(泛型数据结构)

1.泛型 1.1泛型概述 泛型的介绍 泛型是JDK5中引入的特性&#xff0c;它提供了编译时类型安全检测机制 泛型的好处 把运行时期的问题提前到了编译期间 避免了强制类型转换 泛型的定义格式 <类型>: 指定一种类型的格式.尖括号里面可以任意书写,一般只写一个字母.例如: …

【力扣 - 三数之和】

题目描述 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。 注意&#xff1a;答案中不可以包含重复的三元组。…

PostgreSQL开发与实战(6.2)体系结构2

作者&#xff1a;太阳 二、逻辑架构 graph TD A[database] -->B(schema) B -->C[表] B -->D[视图] B -->E[触发器] C -->F[索引] tablespace 三、内存结构 Postgres内存结构主要分为 共享内存 与 本地内存 两部分。共享内存为所有的 background process提供内…

excel中去除公式,仅保留值

1.单个单元格去除公式 双击单元格&#xff0c;按F9. 2.批量去除公式 选中列然后复制&#xff0c;选择性粘贴&#xff0c;选值粘贴

windows server 2019 激活系统时点击“更改产品密钥”无反应的解决方案

一、问题现象 点击“更改产品密钥”没反应。 二、解决方案 使用slmgr命令&#xff1a; 打开命令提示符&#xff08;管理员&#xff09;&#xff0c;然后尝试使用slmgr命令来手动输入密钥和激活Windows。例如&#xff1a; slmgr.vbs /ipk <您的产品密钥>slmgr.vbs /ato 备…

软件测试技术分享 | 测试环境搭建

被测系统的环境搭建&#xff0c;是我们作为软件测试人员需要掌握的技能。 被测系统AUT (Application Under Test) 常见的被测系统即需要被测试的 app&#xff0c;网页和后端服务。大致分为两个方面移动端测试和服务端测试&#xff0c;如下图所示&#xff1a; 常见的被测系统类…

3、Redis Cluster集群运维与核心原理剖析

Redis集群方案比较 哨兵模式 在redis3.0以前的版本要实现集群一般是借助哨兵sentinel工具来监控master节点的状态&#xff0c;如果master节点异常&#xff0c;则会做主从切换&#xff0c;将某一台slave作为master&#xff0c;哨兵的配置略微复杂&#xff0c;并且性能和高可用性…

【C语言】冒泡排序

概念 冒泡排序&#xff08;Bubble Sort&#xff09;是一种简单的排序算法&#xff0c;它重复地遍历要排序的列表&#xff0c;一次比较两个元素&#xff0c;并且如果它们的顺序错误就把它们交换过来。通过多次的遍历和比较&#xff0c;最大&#xff08;或最小&#xff09;的元素…

数智化转型的新篇章:企业如何在「数据飞轮」理念中寻求增长?_光点科技

在当今的数字化浪潮中&#xff0c;企业对数据的渴求与日俱增。数据不再仅是辅助决策的工具&#xff0c;而是成为推动业务增长的核心动力。自从「数据中台」概念降温后&#xff0c;企业纷纷探寻新的数智化路径。在这个过程中&#xff0c;「数据飞轮」作为一种新兴的理念&#xf…

Blazor系统教程(.net8)

Blazor系统教程 1.认识 Blazor 简单来讲&#xff0c;Blazor旨在使用C#来替代JavaScript的Web应用程序的UI框架。其主要优势有&#xff1a; 使用C#编写代码&#xff0c;这可提高应用开发和维护的效率利用现有的NET库生态系统受益于NET的性能、可靠性和安全性与新式托管平台(如…

第三方软件测试报告有效期是多久?专业软件测试报告获取

第三方软件测试报告是在软件开发过程中&#xff0c;由独立的第三方机构对软件进行全面测试和评估后发布的报告。这些第三方机构通常是与软件开发商和用户无关的专业技术机构&#xff0c;具备丰富的测试经验和专业知识。    第三方测试报告具有以下几个好处&#xff1a;   …

阿里云Linux系统MySQL8忘记密码修改密码

相关版本 操作系统&#xff1a;Alibaba Cloud Linux 3.2104 LTS 64位MySQL&#xff1a;mysql Ver 8.0.34 for Linux on x86_64 (Source distribution) MySQL版本可通过下方命令查询 mysql --version一、修改my.cnf文件 文件位置&#xff1a;etc/my.cnf进入远程连接后可以打…

落地灯哪个牌子好?实机测评喜爱度爆表的五款落地灯!

近些年来&#xff0c;由于使用电子产品以及学习压力大的人越来越多&#xff0c;而且越加年轻化&#xff0c;而平时用眼时的不良光线影响着人们的视力健康&#xff0c;不少眼科专家都推荐使用能够带来更好光线效果的落地灯&#xff0c;对此&#xff0c;作为专业的电器测评员&…

Pygame教程05:帧动画原理+边界值检测,让小球来回上下运动

------------★Pygame系列教程★------------ Pygame教程01&#xff1a;初识pygame游戏模块 Pygame教程02&#xff1a;图片的加载缩放旋转显示操作 Pygame教程03&#xff1a;文本显示字体加载transform方法 Pygame教程04&#xff1a;draw方法绘制矩形、多边形、圆、椭圆、弧…