昇思25天学习打卡营第11天 | mindspore 实现 ResNet 50 迁移学习

1. 背景:

使用 mindspore 学习神经网络,打卡第 11 天;主要内容也依据 mindspore 的学习记录。

2. 迁移学习介绍:

mindspore 实现 ResNet 50 迁移学习; 具体 ResNet 50 的模型原理以及实现,可以参考本博客的 ResNet50 分类;

  • 迁移学习背景:
    把已训练好的模型参数迁移到新的模型来帮助新模型训练;

  • 原因:
    a. 避免从 0 开始重复造轮子;
    b. 减少训练成本;如果采用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用CPU都完全无压力,没有深度学习机器也可以做。
    c. 适用于小数据集:对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠迁移学习。

  • 迁移学习的方式:
    a. 冻结预训练模型的全部卷积层,只训练自己定制的全连接层
    b. 冻结预训练模型的部分卷积层(通常是靠近输入的多数卷积层),训练剩下的卷积层(通常是靠近输出的部分卷积层)和全连接层。

3. 代码实现:

3.1 数据下载:

数据以 ImageNet 的狗和狼分类为例;

from download import downloaddataset_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/intermediate/Canidae_data.zip"download(dataset_url, "./datasets-Canidae", kind="zip", replace=True)

3.2 数据前处理:

加载数据集并进行图像增强操作

batch_size = 18                             # 批量大小
image_size = 224                            # 训练图像空间大小
num_epochs = 5                             # 训练周期数
lr = 0.001                                  # 学习率
momentum = 0.9                              # 动量
workers = 4                                 # 并行线程个数import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision# 数据集目录路径
data_path_train = "./datasets-Canidae/data/Canidae/train/"
data_path_val = "./datasets-Canidae/data/Canidae/val/"# 创建训练数据集def create_dataset_canidae(dataset_path, usage):"""数据加载"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 数据增强操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_setdataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()

3.3 构建 ResNet50 网络:

构建 ResNet50 网络:

from typing import Type, Union, List, Optional
from mindspore import nn, train
from mindspore.common.initializer import Normalweight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return outclass ResidualBlock(nn.Cell):expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍def __init__(self, in_channel: int, out_channel: int,stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=1, weight_init=weight_init)self.norm1 = nn.BatchNorm2d(out_channel)self.conv2 = nn.Conv2d(out_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.norm2 = nn.BatchNorm2d(out_channel)self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,kernel_size=1, weight_init=weight_init)self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):identity = x  # shortscuts分支out = self.conv1(x)  # 主分支第一层:1*1卷积层out = self.norm1(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm2(out)out = self.relu(out)out = self.conv3(out)  # 主分支第三层:1*1卷积层out = self.norm3(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return outdef make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义,self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return xdef _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrianed_ckpt, replace=True)param_dict = load_checkpoint(pretrianed_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

3.4 固定特征进行训练:

使用固定特征进行训练的时候,需要冻结除最后一层之外的所有网络层。通过设置 requires_grad == False 冻结参数,以便不在反向传播中计算梯度。

import mindspore as ms
import matplotlib.pyplot as plt
import os
import timenet_work = resnet50(pretrained=True)# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:param.requires_grad = False# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

3.5 模型训练与评估:

开始训练模型,与没有预训练模型相比,将节约一大半时间,因为此时可以不用计算部分梯度

import mindspore as ms
import matplotlib.pyplot as plt
import os
import time
dataset_train = create_dataset_canidae(data_path_train, "train")
step_size_train = dataset_train.get_dataset_size()dataset_val = create_dataset_canidae(data_path_val, "val")
step_size_val = dataset_val.get_dataset_size()num_epochs = 5# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best-freezing-param.ckpt"import mindspore as ms
import matplotlib.pyplot as plt
import os
import time
# 开始循环训练
print("Start Training Loop ...")best_acc = 0for epoch in range(num_epochs):losses = []net_work.set_train()epoch_start = time.time()# 为每轮训练读入数据for i, (images, labels) in enumerate(data_loader_train):labels = labels.astype(ms.int32)loss = train_step(images, labels)losses.append(loss)# 每个epoch结束后,验证准确率acc = model1.eval(dataset_val)['Accuracy']epoch_end = time.time()epoch_seconds = (epoch_end - epoch_start) * 1000step_seconds = epoch_seconds/step_size_trainprint("-" * 20)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, sum(losses)/len(losses), acc))print("epoch time: %5.3f ms, per step time: %5.3f ms" % (epoch_seconds, step_seconds))if acc > best_acc:best_acc = accif not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)ms.save_checkpoint(net_work, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

3.6 可视化模型预测:

使用固定特征得到的best.ckpt文件对对验证集的狼和狗图像数据进行预测。若预测字体为蓝色即为预测正确,若预测字体为红色则预测错误。

import matplotlib.pyplot as plt
import mindspore as msdef visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全连接层输入层的大小in_channels = net.fc.in_channels# 输出通道数大小为狼狗分类数2head = nn.Dense(in_channels, 2)# 重置全连接层net.fc = head# 平均池化层kernel size为7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化层net.avg_pool = avg_pool# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)# 加载验证集的数据进行验证data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()labels = data["label"].asnumpy()class_name = {0: "dogs", 1: "wolves"}# 预测图像类别output = model.predict(ms.Tensor(data['image']))pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()

4. 相关链接:

  • https://xihe.mindspore.cn/events/mindspore-training-camp
  • mindspore - ResNet50 迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/47092.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

谈谈我对李彦宏说的“不要卷模型,要卷应用”的理解

我理解李彦宏的发言是强调人工智能技术应该关注应用而不是仅仅关注模型的发展。他认为AI技术已经从辨别式(discriminative)转向了生成式(generative),意味着AI不再仅仅是识别和分类问题,而是能够创造新的内…

设计模式-领域逻辑模式-事务脚本(Transaction Script)

事务脚本的特点 多数应用可看成由多个事务组成事务脚本将多个业务逻辑组织成单个过程事务间相互修改各自产生的数据 事务脚本的运行机制 使用事务脚本时,领域逻辑主要通过系统所执行的事务来组织。例如:预定酒店过程。 事务脚本的组织 将整个事务脚本放…

【BUG】已解决:IndexError: list index out of range

已解决:IndexError: list index out of range 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市开发者社区…

微信小程序:2.全局开发

app实例 简介 app.js中注册小程序实例的方法App拥有生命周期回调函数、错误监听函数、页面不存在监听函数等 生命周期回调函数 onLaunch(options) {//监听小程序初始化 console.log("监听小程序初始化",options); }, onShow (options) {//监听小程序启动或切前台…

Ubuntu串口调试单片机

来自🥬🐶程序员 Truraly | 田园 的博客,最新文章首发于:田园幻想乡 | 原文链接 | github (欢迎关注) 文章目录 brltty 导致 USB 转串口连接失败串口调试工具直接操作串口puttySerial Monitor | vscode 插件minicom(推荐)舍友在搞 c-sky 的单片机,囚来了一块玩玩,尝…

【C++】C++中的getcwd函数详解

目录 一.getcwd函数是什么 二.getcwd函数怎么用 一.getcwd函数是什么 在C中&#xff0c; getcwd 是一个用于获取当前工作目录的函数&#xff0c;它是POSIX标准的一部分&#xff0c;定义在 <unistd.h> 头文件中&#xff08;在Windows上&#xff0c;它定义在 <direct…

对服务器进行基本了解(二)

目录 一. 云服务器数据库 1.查看MYSQL版本 2.查看mysql的运行状态 3.运行mysql 4. 进入mysql的用户 5. 更改用户密码 6. 查找mysql端口号 7. 创建一个数据库 8. 查看用户 9. 查看数据库 10. 显示数据库的表 11. 修改用户的host 12. 对用户赋权 13. 开放指定端…

python程序设定定时任务

在 Windows 系统上,您可以使用任务计划程序(Task Scheduler)来设置定时任务,执行 Python 文件。以下是具体步骤: 步骤 1:准备 Python 文件 假设有一个名为 script.py 的 Python 脚本。确保它可以在命令行中正确运行。 步骤2:找到Python可执行文件的位置 知道Python可…

【学习笔记】无人机系统(UAS)的连接、识别和跟踪(一)-3GPP TS 23.256 技术规范概述

3GPP TS 23.256 技术规范&#xff0c;主要定义了3GPP系统对无人机&#xff08;UAV&#xff09;的连接性、身份识别、跟踪及A2X&#xff08;Aircraft-to-Everything&#xff09;服务的支持。 3GPP TS 23.256 技术规范&#xff1a; 以下是文档的核心内容总结&#xff1a; UAV系…

nginx的access.log日志输出请求数

适用格式 #log_format main $remote_addr - $remote_user [$time_local] "$request" # $status $body_bytes_sent "$http_referer" # "$http_user_agent" "$http_x_forwarded_for"; 形如&#…

自然语言处理中的本体/分类/同义相似

本体相似、分类相似与同义相似这三个概念都是在信息检索和自然语言处理中经常用到的相似性度量方法&#xff0c;它们的区别如下&#xff1a; 本体相似&#xff1a;本体是领域内的概念及其关系的模型&#xff0c;本体相似性度量是指计算两个本体之间的相似程度。本体相似性度量通…

【JVM实战篇】内存调优:内存问题诊断+案例实战

文章目录 诊断内存快照在内存溢出时生成内存快照MAT分析内存快照MAT内存泄漏检测的原理支配树介绍如何在不内存溢出情况下生成堆内存快照&#xff1f;MAT查看支配树MAT如何根据支配树发现内存泄漏 运行程序的内存快照导出和分析快照**大文件的处理** 案例实战案例1&#xff1a;…

CH390H+STM32F1+LWIP

文章目录 1、CH390芯片介绍2、电路部分3、LWIP调试3.1修改点13.2 修改点2 4、结果展示参考 1、CH390芯片介绍 官网地址&#xff1a; 南京沁恒微电子股份有限公司 特点&#xff1a; 2、电路部分 CH390及接口&#xff1a; STM32F1引脚&#xff1a; 不含LWIP的demo及LWIP…

【WPF】图片剪裁-ImageCropping

【WPF】图片剪裁-ImageCropping 背景技术栈实现思路核心代码界面布局Style处理逻辑使用技巧预览下载背景 机缘巧合吧,当时在全网寻找图像剪裁工具,但大都不能满足需求,于是决定动手写。当然如果只是为了完成这么一个功能就没有必要记录了,主要是不依赖与第三方图像库,且实…

【数据结构】二叉树全攻略,从实现到应用详解

​ &#x1f48e;所属专栏&#xff1a;数据结构与算法学习 &#x1f48e; 欢迎大家互三&#xff1a;2的n次方_ ​ &#x1f341;1. 树形结构的介绍 树是一种非线性的数据结构&#xff0c;它是由n&#xff08;n>0&#xff09;个有限结点组成一个具有层次关系的集合。把它叫做…

Docker Machine 深入解析

Docker Machine 深入解析 引言 Docker Machine 是 Docker 生态系统中的一个重要工具,它简化了 Docker 容器环境的配置和管理过程。本文将深入探讨 Docker Machine 的概念、功能、使用场景以及如何在实际环境中高效利用它。 什么是 Docker Machine? Docker Machine 是一个…

Qt纯代码绘制一个等待提示Ui控件

等待样式控件是我们在做UI时出场率还挺高的控件之一&#xff0c;通常情况下有如下的几种实现方式&#xff1a;1、自定义绘图&#xff0c;然后重写paintEvent函数&#xff0c;在paintEvent中绘制等待图标&#xff0c;通过QTimer更新绘制达到转圈圈的效果。2、 获取一张gif的资源…

SpringBoot下的定时魔法:揭秘@Scheduled注解的无限可能

在这个快节奏的时代&#xff0c;自动化与定时任务成为了提升效率的不二法门。而在Java的Spring Boot框架中&#xff0c;Scheduled注解就像是一位精通时间魔法的巫师&#xff0c;悄无声息地让你的应用按部就班地执行着各种定时任务。今天&#xff0c;就让我们一起揭开它的神秘面…

Ubuntu上安装配置samba服务

Ubuntu上安装配置samba服务 在Ubuntu中安装配置samba共享服务&#xff0c;可以让你在网络上共享文件和打印机。以下是一个相对详细的步骤指南&#xff0c;介绍如何在Ubuntu上安装和配置Samba。 1. 安装Samba 首先&#xff0c;需要安装Samba软件包。打开终端并运行以下命令&a…

Gocator Acquisition for Cognex VisionPro(LMI相机图像获取)

概述 VisionPro 是个很强大的视觉软件, 我们很乐意我们的客户在VisionPro 环境中使用Gocator产品。 实现方法 在 VisionPro 环境下配置 Gocator 产品两种方法: ● 方法一: 创建一个 QuickBuild Job,在 Job 编辑器添加 Job Script,插入 Gocator 的 SDK,编辑简 单脚本就 OK。 …