昇思25天学习打卡营第13天|应用实践之ResNet50迁移学习

基本介绍

        今日的应用实践的模型是计算机实践领域中十分出名的模型----ResNet模型。ResNet是一种残差网络结构,它通过引入“残差学习”的概念来解决随着网络深度增加时训练困难的问题,从而能够训练更深的网络结构。现很多网络极深的模型或多或少都受此影响。今日的主要任务是使用ResNet50进行迁移学习,所谓的迁移学习是不从头到尾训练一个模型,而是在一个特别大的数据集上面训练得到一个预训练模型,然后使用该模型的权重作为初始化参数,最后应用于特定的任务。对特定的任务来说也可以对模型进行训练,但是需要冻结模型的某些参数,一般只训练模型的分类器部分,不会训练特征提取部分。下面我们详细讲讲今日的应用实践。

数据集准备

        本次使用的数据集是来自ImageNet数据集中抽取出来的狼狗数据集,每个类别大概有120张训练图像与30张验证图像,数据集可通过华为云OBS和相关API直接下载,然后进行加载。可借助MindSpore提供的API对数据集进行加载,同时,因为数据集在送入模型之前需做些处理,这些可封装到一个函数内,具体代码如下:

def create_dataset_canidae(dataset_path, usage):"""数据加载"""data_set = ds.ImageFolderDataset(dataset_path,num_parallel_workers=workers,shuffle=True,)# 数据增强操作mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]std = [0.229 * 255, 0.224 * 255, 0.225 * 255]scale = 32if usage == "train":# Define map operations for training datasettrans = [vision.RandomCropDecodeResize(size=image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),vision.RandomHorizontalFlip(prob=0.5),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]else:# Define map operations for inference datasettrans = [vision.Decode(),vision.Resize(image_size + scale),vision.CenterCrop(image_size),vision.Normalize(mean=mean, std=std),vision.HWC2CHW()]# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_set

通过上述操作后,我们可以很方便的调用数据集,无论是进行训练还是可视化,都可以。

模型搭建

        准备好数据集后,自然就是进行模型搭建,ResNet50是一个非常常见的模型,具体模型结构和不同AI框架下的代码都很容易获取,MindSpore官方也有相关的实现代码,我们直接使用,模型的代码如下:

class ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义,self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return xdef _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrianed_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrianed_ckpt, replace=True)param_dict = load_checkpoint(pretrianed_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"ResNet50模型"resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

训练模型

        我们采用定特征进行训练,需要冻结除最后一层之外的所有网络层,最后一层其实就是一个分类层,前面是特征提取层,MindSpore可以通过设置 requires_grad == False 冻结参数,以便不在反向传播中计算梯度,具体代码如下:

import mindspore as ms
import matplotlib.pyplot as plt
import os
import timenet_work = resnet50(pretrained=True)# 全连接层输入层的大小
in_channels = net_work.fc.in_channels
# 输出通道数大小为狼狗分类数2
head = nn.Dense(in_channels, 2)
# 重置全连接层
net_work.fc = head# 平均池化层kernel size为7
avg_pool = nn.AvgPool2d(kernel_size=7)
# 重置平均池化层
net_work.avg_pool = avg_pool# 冻结除最后一层外的所有参数
for param in net_work.get_parameters():if param.name not in ["fc.weight", "fc.bias"]:param.requires_grad = False# 定义优化器和损失函数
opt = nn.Momentum(params=net_work.trainable_params(), learning_rate=lr, momentum=0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = net_work(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss# 实例化模型
model1 = train.Model(net_work, loss_fn, opt, metrics={"Accuracy": train.Accuracy()})

一切准备妥当后,便可以进行训练,由于有预训练模型,模型的训练速度非常快,比从头到尾训练快了好几倍。训练了5轮,效果就非常好了:

模型可视化预测

        有了训练好的模型,自然要看看训练得好不好,除了看评价指标,最直观的就是实际使用一下模型,由于这是图像分类任务,所以将其可视化。这一整个流程的代码如下:

def visualize_model(best_ckpt_path, val_ds):net = resnet50()# 全连接层输入层的大小in_channels = net.fc.in_channels# 输出通道数大小为狼狗分类数2head = nn.Dense(in_channels, 2)# 重置全连接层net.fc = head# 平均池化层kernel size为7avg_pool = nn.AvgPool2d(kernel_size=7)# 重置平均池化层net.avg_pool = avg_pool# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)model = train.Model(net)# 加载验证集的数据进行验证data = next(val_ds.create_dict_iterator())images = data["image"].asnumpy()labels = data["label"].asnumpy()class_name = {0: "dogs", 1: "wolves"}# 预测图像类别output = model.predict(ms.Tensor(data['image']))pred = np.argmax(output.asnumpy(), axis=1)# 显示图像及图像的预测值plt.figure(figsize=(5, 5))for i in range(4):plt.subplot(2, 2, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels[i] else 'red'plt.title('predict:{}'.format(class_name[pred[i]]), color=color)picture_show = np.transpose(images[i], (1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()

调用该函数后,可视化结果如下:可以看出还是很准的

Jupyter运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43404.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据链路层(超详细)

引言 数据链路层是计算机网络协议栈中的第二层,位于物理层之上,负责在相邻节点之间的可靠数据传输。数据链路层使用的信道主要有两种类型:点对点信道和广播信道。点对点信道是指一对一的通信方式,而广播信道则是一对多的通信方式…

风险评估:Tomcat的安全配置,Tomcat安全基线检查加固

「作者简介」:冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础著作 《网络安全自学教程》,适合基础薄弱的同学系统化的学习网络安全,用最短的时间掌握最核心的技术。 这一章节我们需…

grafana数据展示

目录 一、安装步骤 二、如何添加喜欢的界面 三、自动添加注册客户端主机 一、安装步骤 启动成功后 可以查看端口3000是否启动 如果启动了就在浏览器输入IP地址:3000 账号密码默认是admin 然后点击 log in 第一次会让你修改密码 根据自定义密码然后就能登录到界面…

高职物联网实训室

一、高职物联网实训室建设背景 随着《中华人民共和国国民经济和社会发展第十四个五年规划和2035年远景目标纲要》的发布,中国正式步入加速数字化转型的新时代。在数字化浪潮中,物联网技术作为连接物理世界与数字世界的桥梁,其重要性日益凸显…

Golang | Leetcode Golang题解之第224题基本计算器

题目&#xff1a; 题解&#xff1a; func calculate(s string) (ans int) {ops : []int{1}sign : 1n : len(s)for i : 0; i < n; {switch s[i] {case :icase :sign ops[len(ops)-1]icase -:sign -ops[len(ops)-1]icase (:ops append(ops, sign)icase ):ops ops[:len(o…

2024年有多少程序员转行了?

疫情后大环境下行&#xff0c;各行各业的就业情况都是一言难尽。互联网行业更是极不稳定&#xff0c;频频爆出裁员的消息。大家都说2024年程序员的就业很难&#xff0c;都很焦虑。 在许多人眼里&#xff0c;程序员可能是一群背着电脑、进入高大上写字楼的职业&#xff0c;他们…

Datawhale AI 夏令营 机器学习挑战赛

一、赛事背景 在当今科技日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;技术正以前所未有的深度和广度渗透到科研领域&#xff0c;特别是在化学及药物研发中展现出了巨大潜力。精准预测分子性质有助于高效筛选出具有优异性能的候选药物。以PROTACs为例&#x…

Hi3861 OpenHarmony嵌入式应用入门--MQTT

MQTT 是机器对机器(M2M)/物联网(IoT)连接协议。它被设计为一个极其轻量级的发布/订阅消息传输 协议。对于需要较小代码占用空间和/或网络带宽非常宝贵的远程连接非常有用&#xff0c;是专为受限设备和低带宽、 高延迟或不可靠的网络而设计。这些原则也使该协议成为新兴的“机器…

AutoMQ 生态集成 Kafdrop-ui

Kafdrop [1] 是一个为 Kafka 设计的简洁、直观且功能强大的Web UI 工具。它允许开发者和管理员轻松地查看和管理 Kafka 集群的关键元数据&#xff0c;包括主题、分区、消费者组以及他们的偏移量等。通过提供一个用户友好的界面&#xff0c;Kafdrop 大大简化了 Kafka 集群的监控…

量产工具一一UI系统(四)

目录 前言 一、按钮数据结构抽象 1.ui.h 二、按键处理 1.button.c 2.disp_manager.c 3.disp_manager.h 三、单元测试 1.ui_test.c 2.上机测试 前言 前面我们实现了显示系统框架&#xff0c;输入系统框架和文字系统框架&#xff0c;链接&#xff1a; 量产工具一一显…

集控中心操作台材质选择如何选择

作为集控中心的核心组成部分&#xff0c;操作台不仅承载着各种设备和工具&#xff0c;更是工作人员进行监控、操作和管理的重要平台。因此&#xff0c;选择适合的集控中心操作台材质显得尤为重要。 一、材质选择的考量因素 在选择集控中心操作台材质时&#xff0c;我们需要综合…

SpringCloud跨微服务的远程调用,如何发起网络请求,RestTemplate

在我们的业务流程之中不一定都会是自己模块查询自己模块的信息&#xff0c;有些时候就需要去结合其他模块的信息来进行一些查询完成相应的业务流程&#xff0c;但是在SpringCloud每个模块都相对独立&#xff0c;数据库也有数据隔离。所以当我们需要其他微服务模块的信息的时候&…

前端javascript中的排序算法之选择排序

选择排序&#xff08;Selection Sort&#xff09;基本思想&#xff1a; 是一种原址排序法&#xff1b; 将数组分为两个区间&#xff1a;左侧为已排序区间&#xff0c;右侧为未排序区间。每趟从未排序区间中选择一个值最小的元素&#xff0c;放到已排序区间的末尾&#xff0c;从…

小米采取措施禁止国行版设备安装国际版系统 刷机后将报错无法进入系统

据知名官改版系统 Xiaomi.EU 测试者 Kacper Skrzypek 发布的消息&#xff0c;小米目前已经在开机引导中新增区域检测机制&#xff0c;该机制将识别硬件所属的市场版本&#xff0c;例如中国大陆市场销售的小米即将在安装国际版系统后将无法正常启动。 测试显示该检测机制是在开…

1.DDR3 SO-DIMM 内存条硬件总结

最近在使用fpga读写DDR3&#xff0c;板子上的DDR3有两种形式与fpga相连&#xff0c;一种是直接用ddr3内存颗粒&#xff0c;另一种是通过内存条的形式与fpga相连。这里我们正好记录下和ddr3相关的知识&#xff0c;先从DDR3 SO-DIMM 内存条开始。 1.先看内存条的版本 从JEDEC下载…

《算法笔记》总结No.5——递归

一.分而治之 将原问题划分为若干个规模较小而结构与原问题相同或相似的子问题&#xff0c;然后分别解决这些子问题&#xff0c;最后合并子问题的解&#xff0c;即可得到原问题的解&#xff0c;步骤抽象如下&#xff1a; 分解&#xff1a;将原问题分解为若干子问题解决&#x…

用VLM训练实时计算机视觉模型

经过数十亿个参数训练的 AI 模型非常强大&#xff0c;但并不总是适合实时使用。但是&#xff0c;它们可以通过自动监督快速专用模型的标注来减少人力投入。 ‍ 如果你曾经构建过计算机视觉模型&#xff0c;就就会知道监督需要大量工作——人类花时间&#xff08;数小时或数天&a…

自动化测试全攻略:从入门到精通!

1、自动化测试专栏 随着技术的发展和工作需求的增长&#xff0c;自动化测试已成为软件质量保障体系中不可或缺的一环。 为了帮助广大测试工程师、开发者和对自动化测试感兴趣的读者们更好地掌握这一技能&#xff0c;今年特别推出了全新的《自动化测试全攻略&#xff1a;从入门…

scratch绘制四个三角形 2024年6月中国电子学会 图形化编程 scratch编程等级考试二级真题和答案解析

scratch绘制四个三角形 一、题目要求 2024年6月电子学会图形化编程Scratch等级考试二级真题 1、准备工作 1.保留默认角色小猫; 2.添加背景Stars。 2、功能实现 1 .隐藏角色小猫&#xff0c;设置画笔裙始位置为(0,0)&#xff0c;画笔颜色为黄色&#xff0c;画笔的粗细为5…

NET Core 中的空对象设计模式

介绍 一种称为“空对象模式”的行为设计模式提供了一个对象来表示接口缺少的对象。在空对象会导致空引用异常的情况下&#xff0c;这是一种提供替代行为的方法。在本文中&#xff0c;我们将深入探讨 C# 空对象模式&#xff0c;并逐步解决更复杂的情况。 空对象设计模式它是什…