昇思25天学习打卡营第9天|ResNet50图像分类

一、Resnet残差网络模型

  1. 构建残差网络结构;
  2. Building Block
  3. Bottleneck

残差结构由两个分支构成:一个主分支 𝐹(𝑥),一个shortcuts(图中弧线表示,𝑥)。 得到残差网络结构:𝐹(𝑥)+𝑥
通过Relu激活函数后即为残差网络最后的输出。
在这里插入图片描述
定义ResidualBlockBase类实现Building Block结构
在这里插入图片描述

from typing import Type, Union, List, Optional
import mindspore.nn as nn
from mindspore.common.initializer import Normal# 初始化卷积层与BatchNorm的参数
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)class ResidualBlockBase(nn.Cell):expansion: int = 1  # 最后一个卷积核数量与第一个卷积核数量相等def __init__(self, in_channel: int, out_channel: int,stride: int = 1, norm: Optional[nn.Cell] = None,down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlockBase, self).__init__()if not norm:self.norm = nn.BatchNorm2d(out_channel)else:self.norm = normself.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.conv2 = nn.Conv2d(in_channel, out_channel,kernel_size=3, weight_init=weight_init)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):"""ResidualBlockBase construct."""identity = x  # shortcuts分支out = self.conv1(x)  # 主分支第一层:3*3卷积层out = self.norm(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out

定义ResidualBlock类实现Bottleneck结构
在输入相同的情况下Bottleneck结构相对Building Block结构的参数数量更少,更适合层数较深的网络,ResNet50使用的残差结构就是Bottleneck。该结构的主分支有三层卷积结构,分别为1×1的卷积层、 3×3卷积层和 1×1的卷积层,其中 1×1 的卷积层分别起降维和升维的作用。
在这里插入图片描述

class ResidualBlock(nn.Cell):expansion = 4  # 最后一个卷积核的数量是第一个卷积核数量的4倍def __init__(self, in_channel: int, out_channel: int,stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=1, weight_init=weight_init)self.norm1 = nn.BatchNorm2d(out_channel)self.conv2 = nn.Conv2d(out_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)self.norm2 = nn.BatchNorm2d(out_channel)self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,kernel_size=1, weight_init=weight_init)self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)self.relu = nn.ReLU()self.down_sample = down_sampledef construct(self, x):identity = x  # shortscuts分支out = self.conv1(x)  # 主分支第一层:1*1卷积层out = self.norm1(out)out = self.relu(out)out = self.conv2(out)  # 主分支第二层:3*3卷积层out = self.norm2(out)out = self.relu(out)out = self.conv3(out)  # 主分支第三层:1*1卷积层out = self.norm3(out)if self.down_sample is not None:identity = self.down_sample(x)out += identity  # 输出为主分支与shortcuts之和out = self.relu(out)return out

二、Resnet50

堆叠残差网络来构建ResNet50网络:ResNet网络层结构如下图所示,以输入彩色图像 224×224 为例,首先通过数量64,卷积核大小为 7×7,stride为2的卷积层conv1,该层输出图片大小为 112×112 ,输出channel为64;然后通过一个 3×3 的最大下采样池化层,该层输出图片大小为 56×56,输出channel为64;再堆叠4个残差网络块(conv2_x、conv3_x、conv4_x和conv5_x),此时输出图片大小为 7×7 ,输出channel为2048;最后通过一个平均池化层、全连接层和softmax,得到分类概率。
在这里插入图片描述
对于每个残差网络块,以ResNet50网络中的conv2_x为例,其由3个Bottleneck结构堆叠而成,每个Bottleneck输入的channel为64,输出channel为256。
定义make_layer实现残差块的构建,其参数如下所示:
last_out_channel:上一个残差网络输出的通道数。
block:残差网络的类别,分别为ResidualBlockBase和ResidualBlock。
channel:残差网络输入的通道数。
block_nums:残差网络块堆叠的个数。
stride:卷积移动的步幅。

def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],channel: int, block_nums: int, stride: int = 1):down_sample = None  # shortcuts分支if stride != 1 or last_out_channel != channel * block.expansion:down_sample = nn.SequentialCell([nn.Conv2d(last_out_channel, channel * block.expansion,kernel_size=1, stride=stride, weight_init=weight_init),nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)])layers = []layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))in_channel = channel * block.expansion# 堆叠残差网络for _ in range(1, block_nums):layers.append(block(in_channel, channel))return nn.SequentialCell(layers)

在这里插入图片描述

from mindspore import load_checkpoint, load_param_into_netclass ResNet(nn.Cell):def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],layer_nums: List[int], num_classes: int, input_channel: int) -> None:super(ResNet, self).__init__()self.relu = nn.ReLU()# 第一个卷积层,输入channel为3(彩色图像),输出channel为64self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)self.norm = nn.BatchNorm2d(64)# 最大池化层,缩小图片的尺寸self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')# 各个残差网络结构块定义self.layer1 = make_layer(64, block, 64, layer_nums[0])self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)# 平均池化层self.avg_pool = nn.AvgPool2d()# flattern层self.flatten = nn.Flatten()# 全连接层self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)def construct(self, x):x = self.conv1(x)x = self.norm(x)x = self.relu(x)x = self.max_pool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avg_pool(x)x = self.flatten(x)x = self.fc(x)return x

网络模型构建函数

def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,input_channel: int):model = ResNet(block, layers, num_classes, input_channel)if pretrained:# 加载预训练模型download(url=model_url, path=pretrained_ckpt, replace=True)param_dict = load_checkpoint(pretrained_ckpt)load_param_into_net(model, param_dict)return modeldef resnet50(num_classes: int = 1000, pretrained: bool = False):"""ResNet50模型"""resnet50_url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt"resnet50_ckpt = "./LoadPretrainedModel/resnet50_224_new.ckpt"return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,pretrained, resnet50_ckpt, 2048)

三、 CIFAR10数据集

CIFAR10数据集共有60000个样本(了50000个训练样本和10000个测试样本):每个样本图都是32x32像素,RGB图像-分为3个通道。用于监督学习模型训练。CIFAR10数据集标签有10种物体类别(标签值:0~9):飞机( airplane )、汽车( automobile )、鸟( bird )、猫( cat )、鹿( deer )、狗( dog )、青蛙( frog )、马( horse )、船( ship )和卡车( truck )。

3.1数据集加载

数据下载

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"download(url, "./datasets-cifar10-bin", kind="tar.gz", replace=True)

数据集加载与数据增强

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstypedata_dir = "./datasets-cifar10-bin/cifar-10-batches-bin"  # 数据集根目录
batch_size = 256  # 批量大小
image_size = 32  # 训练图像空间大小
workers = 4  # 并行线程个数
num_classes = 10  # 分类数量def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers):data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir,usage=usage,num_parallel_workers=workers,shuffle=True)trans = []if usage == "train":trans += [vision.RandomCrop((32, 32), (4, 4, 4, 4)),vision.RandomHorizontalFlip(prob=0.5)]trans += [vision.Resize(resize),vision.Rescale(1.0 / 255.0, 0.0),vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),vision.HWC2CHW()]target_trans = transforms.TypeCast(mstype.int32)# 数据映射操作data_set = data_set.map(operations=trans,input_columns='image',num_parallel_workers=workers)data_set = data_set.map(operations=target_trans,input_columns='label',num_parallel_workers=workers)# 批量操作data_set = data_set.batch(batch_size)return data_set# 获取处理后的训练与测试数据集dataset_train = create_dataset_cifar10(dataset_dir=data_dir,usage="train",resize=image_size,batch_size=batch_size,workers=workers)
step_size_train = dataset_train.get_dataset_size()dataset_val = create_dataset_cifar10(dataset_dir=data_dir,usage="test",resize=image_size,batch_size=batch_size,workers=workers)
step_size_val = dataset_val.get_dataset_size()

3.2图像数据可视化展示

import matplotlib.pyplot as plt
import numpy as npdata_iter = next(dataset_train.create_dict_iterator())images = data_iter["image"].asnumpy()
labels = data_iter["label"].asnumpy()
print(f"Image shape: {images.shape}, Label shape: {labels.shape}")# 训练数据集中,前六张图片所对应的标签
print(f"Labels: {labels[:6]}")classes = []with open(data_dir + "/batches.meta.txt", "r") as f:for line in f:line = line.rstrip()if line:classes.append(line)# 训练数据集的前六张图片
plt.figure()
for i in range(6):plt.subplot(2, 3, i + 1)image_trans = np.transpose(images[i], (1, 2, 0))mean = np.array([0.4914, 0.4822, 0.4465])std = np.array([0.2023, 0.1994, 0.2010])image_trans = std * image_trans + meanimage_trans = np.clip(image_trans, 0, 1)plt.title(f"{classes[labels[i]]}")plt.imshow(image_trans)plt.axis("off")
plt.show()

在这里插入图片描述

四、加载数据集-开展Resnet50模型训练

构建网络模型,调整输出

# 定义ResNet50网络
network = resnet50(pretrained=True)# 全连接层输入层的大小
in_channel = network.fc.in_channels
fc = nn.Dense(in_channels=in_channel, out_channels=10)
# 重置全连接层
network.fc = fc

设置学习率

# 设置学习率
num_epochs = 10  # 设置训练batch
lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,step_per_epoch=step_size_train, decay_epoch=num_epochs)
# 定义优化器和损失函数
opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)
loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')def forward_fn(inputs, targets):logits = network(inputs)loss = loss_fn(logits, targets)return lossgrad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)def train_step(inputs, targets):loss, grads = grad_fn(inputs, targets)opt(grads)return loss

定义训练、验证函数

import mindspore.ops as opsdef train(data_loader, epoch):"""模型训练"""losses = []network.set_train(True)for i, (images, labels) in enumerate(data_loader):loss = train_step(images, labels)if i % 100 == 0 or i == step_size_train - 1:print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %(epoch + 1, num_epochs, i + 1, step_size_train, loss))losses.append(loss)return sum(losses) / len(losses)def evaluate(data_loader):"""模型验证"""network.set_train(False)correct_num = 0.0  # 预测正确个数total_num = 0.0  # 预测总数for images, labels in data_loader:logits = network(images)pred = logits.argmax(axis=1)  # 预测结果correct = ops.equal(pred, labels).reshape((-1, ))correct_num += correct.sum().asnumpy()total_num += correct.shape[0]acc = correct_num / total_num  # 准确率return acc

读数据集,开始训练数据

import os# 创建迭代器
data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)
data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)# 最佳模型存储路径
best_acc = 0
best_ckpt_dir = "./BestCheckpoint"
best_ckpt_path = "./BestCheckpoint/resnet50-best.ckpt"if not os.path.exists(best_ckpt_dir):os.mkdir(best_ckpt_dir)# 开始循环训练
print("Start Training Loop ...")for epoch in range(num_epochs):curr_loss = train(data_loader_train, epoch)curr_acc = evaluate(data_loader_val)print("-" * 50)print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (epoch+1, num_epochs, curr_loss, curr_acc))print("-" * 50)# 保存当前预测准确率最高的模型if curr_acc > best_acc:best_acc = curr_accms.save_checkpoint(network, best_ckpt_path)print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "f"save the best ckpt file in {best_ckpt_path}", flush=True)

在这里插入图片描述

五、可视化模型预测

import matplotlib.pyplot as pltdef visualize_model(best_ckpt_path, dataset_val):num_class = 10  # 对狼和狗图像进行二分类net = resnet50(num_class)# 加载模型参数param_dict = ms.load_checkpoint(best_ckpt_path)ms.load_param_into_net(net, param_dict)# 加载验证集的数据进行验证data = next(dataset_val.create_dict_iterator())images = data["image"]labels = data["label"]# 预测图像类别output = net(data['image'])pred = np.argmax(output.asnumpy(), axis=1)# 图像分类classes = []with open(data_dir + "/batches.meta.txt", "r") as f:for line in f:line = line.rstrip()if line:classes.append(line)# 显示图像及图像的预测值plt.figure()for i in range(6):plt.subplot(2, 3, i + 1)# 若预测正确,显示为蓝色;若预测错误,显示为红色color = 'blue' if pred[i] == labels.asnumpy()[i] else 'red'plt.title('predict:{}'.format(classes[pred[i]]), color=color)picture_show = np.transpose(images.asnumpy()[i], (1, 2, 0))mean = np.array([0.4914, 0.4822, 0.4465])std = np.array([0.2023, 0.1994, 0.2010])picture_show = std * picture_show + meanpicture_show = np.clip(picture_show, 0, 1)plt.imshow(picture_show)plt.axis('off')plt.show()# 使用测试数据集进行验证
visualize_model(best_ckpt_path=best_ckpt_path, dataset_val=dataset_val)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42559.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC常见的注解

一、Spring MVC Spring Web MVC是基于ServletAPI构建的原始web 框架,一开始就包含在Spring 框架中,通常被称为“Spring MVC”。 1.MVC 是什么? MVC(Model、View、Controller)是软件工程中的一种软件架构设计模型。它把软件系统分…

STM32-输入捕获IC和编码器接口

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. 输入捕获IC1.1 输入捕获IC简介1.2 频率测量1.3 输入捕获通道1.4 主从触发模式1.5 输入捕获基本结构1.6 PWMI基本结构 2. 输入捕获库函数及代码2.1 输入捕获库函数2.2 6-6 输入捕获模式测频率2.2.1 硬件连接2.2.2 硬…

2024暑假集训

Day1——枚举 Day2——测试 Day3——贪心 Day4、5——测试 ——————————————————————————————————————————— Day3T7&Day5T7:没思路 Day3T8:不知道怎么排序筛选 Day5T5:没有算法难度,但是不知道怎么处理2队奶牛的情…

什么牌子的头戴式蓝牙耳机好性价比高?

说起性价比高的头戴式蓝牙耳机,就不得不提倍思H1s,作为倍思最新推出的新款,在各项功能上都实现了不错的升级,二字开头的价格,配置却毫不含糊, 倍思H1s的音质表现堪称一流。它采用了40mm天然生物纤维振膜,这种振膜柔韧而有弹性,能够显著提升低音的量感。无论是深沉的低音还是清…

数据跨境法案:美国篇上

近年来随着全球数字化的加速发展,数据已成为国家竞争力的重要基石。在这样的背景下,中国软件和技术出海的场景日益丰富。本系列邀请到在跨境数据方面的研究人员针对海外的数据跨境政策进行解读。 本期将针对美国对数据跨境流动的态度和政策进行阐释。过…

非比较排序 计数排序

1.核心思路 首先要找出max 和 min,最大值 - 最小值 1,就可以计算出数据在什么范围然后创建计数数组大小,a[i] - min 在数组的相对位置计数 通过自然序列排序然后把计数好的值,按照顺序依次放回原数组即可 动图解释,其…

Linux—网络设置

目录 一、ifconfig——查看网络配置 1、查看网络接口信息 1.1、查看所有网络接口 1.2、查看具体的网络接口 2、修改网络配置 3、添加网络接口 4、禁用/激活网卡 二、hostname——查看主机名称 1、查看主机名称 2、临时修改主机名称 3、永久修改主机名称 4、查看本…

Java引用的4种类型:强、软、弱、虚

在Java中,引用的概念不仅限于强引用,还包括软引用、弱引用和虚引用(也称为幻影引用)。这些引用类型主要用于不同的内存管理策略,尤其是在垃圾收集过程中。以下是对这四种引用类型的详细解释: 1. 强引用&am…

algorithm算法库学习之——不修改序列的操作

algorithm此头文件是算法库的一部分。本篇介绍不修改序列的操作函数。 不修改序列的操作 all_ofany_ofnone_of (C11)(C11)(C11) 检查谓词是否对范围中所有、任一或无元素为 true (函数模板) for_each 应用函数到范围中的元素 (函数模板) for_each_n (C17) 应用一个函数对象到序…

一.7.(2)基本运算电路,包括比例运算电路、加减运算电路、积分运算电路、微分电路等常见电路的分析、计算及应用;(未完待续)

what id the 虚短虚断虚地? 虚短:运放的正相输入端和反相输入端貌似连在一起了,所以两端的电压相等,即UU- 虚断:输入端输入阻抗无穷大 虚地:运放正相输入端接地,导致U=U-=0。 虚…

对话大模型Prompt是否需要礼貌点?

大模型相关目录 大模型,包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步,扬帆起航。 基于Dify的QA数据集构建(附代码)Qwen-2-7B和GLM-4-9B&#x…

秋招突击——7/5——复习{}——新作{跳跃游戏II、划分字母区间、数组中的第K个大的元素(模板题,重要)、前K个高频元素}

文章目录 引言正文贪心——45 跳跃游戏II个人实现参考实现 划分字母区间个人实现参考实现 数组中的第K个最大元素个人实现参考做法 前K个高频元素个人实现参考实现 总结 引言 今天就开始的蛮早的,现在是九点多,刚好开始做算法,今天有希望能够…

【leetcode周赛记录——405】

405周赛记录 #1.leetcode100339_找出加密后的字符串2.leetcode100328_生成不含相邻零的二进制字符串3.leetcode100359_统计X和Y频数相等的子矩阵数量4.leetcode100350_最小代价构造字符串 刷了一段时间算法了,打打周赛看看什么水平了 #1.leetcode100339_找出加密后的…

【UML用户指南】-30-对体系结构建模-模式和框架

目录 1、机制 2、框架 3、常用建模技术 3.1、对设计模式建模 3.2、对体系结构模式建模 用模式来详述形成系统体系结构的机制和框架。通过清晰地标识模式的槽、标签、按钮和刻度盘 在UML中, 对设计模式(也叫做机制)建模,将它…

【web前端HTML+CSS+JS】--- CSS学习笔记02

一、CSS(层叠样式表)介绍 1.优势 2.定义解释 如果有多个选择器共同作用的话,只有优先级最高那层样式决定最终的效果 二、无语义化标签 div和span:只起到描述的作用,不带任何样式 三、标签选择器 1.标签/元素选择器…

【算法笔记自学】第 8 章 提高篇(2)——搜索专题

8.1深度优先搜索&#xff08;DFS&#xff09; #include <cstdio>const int MAXN 5; int n, m, maze[MAXN][MAXN]; bool visited[MAXN][MAXN] {false}; int counter 0;const int MAXD 4; int dx[MAXD] {0, 0, 1, -1}; int dy[MAXD] {1, -1, 0, 0};bool isValid(int …

【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【21】【购物车】

持续学习&持续更新中… 守破离 【雷丰阳-谷粒商城 】【分布式高级篇-微服务架构篇】【21】【购物车】 购物车需求描述购物车数据结构数据Model抽取实现流程&#xff08;参照京东&#xff09;代码实现参考 购物车需求描述 用户可以在登录状态下将商品添加到购物车【用户购物…

CSS技巧:纯CSS实现文字渐变动画效果

文字渐变动画&#xff0c;可以实现的有两种&#xff1a;一种是一行文字整体变化颜色&#xff1b;另一种一行文字依次变化颜色。接下来&#xff0c;我就介绍一下这两种文字渐变的实现过程。 布局代码&#xff1a; <div class"con"><div class"animate…

7.pwn 工具安装和使用

关闭保护的方法 pie: -no-pie Canary:-fno-stack-protector aslr:查看:cat /proc/sys/kernel/randomize_va_space 2表示打开 关闭:echo 0>/proc/sys/kernel/randomize_va_space NX:-z execstack gdb使用以及插件安装 是GNU软件系统中的标准调试工具&#xff0c;此外GD…

electron 初始使用

electron electron文档地址deno下载地址安装命令 yarn config set electron_mirror https://cdn.npm.taobao.org/dist/electron/ npm install下载文件 文件下载完成后&#xff0c;新建dist目录&#xff0c;解压到list目录下&#xff1b;path文件中写入electron.exe 运行命令 …