MindSpore基础教程:LeNet-5 神经网络在MindSpore中的实现与训练

MindSpore基础教程:LeNet-5 神经网络在MindSpore中的实现与训练

官方文档教程使用已经弃用的MindVision模块,本文是对官方文档的更新
深度学习在图像识别领域取得了显著的成功,LeNet-5 作为卷积神经网络的经典之作,在诸多研究和应用中占有重要地位。本文将详细介绍如何使用 MindSpore 框架实现并训练一个 LeNet-5 神经网络,专注于处理MNIST手写数字数据集。

前言

MindSpore 是华为推出的一种新型深度学习框架,旨在为用户提供高效、易用的编程体验。接下来,我们将通过实例来展示如何在 MindSpore 中构建、训练和评估一个经典的 LeNet-5 神经网络。

环境配置

MindSpore官网

LeNet-5 网络结构简介

LeNet-5 是一个简单的卷积神经网络,包含两个卷积层和三个全连接层。它经常被用于图像识别任务,特别是在处理像 MNIST 这样的手写数字数据集时表现出色。

数据集准备与预处理

首先,我们需要准备并预处理数据集。在这个例子中,我们将使用 MNIST 数据集。以下函数 create_dataset 负责加载数据集,并进行必要的预处理:

def create_dataset(data_path, batch_size=32, repeat_size=1):"""创建用于训练的MNIST数据集。此函数负责加载MNIST数据集,对数据进行预处理和转换,以便它们可以用于训练神经网络。数据预处理包括调整图像大小、重新缩放和类型转换。参数:data_path (str): MNIST数据集的路径。这应该是包含MNIST数据文件的目录路径。batch_size (int, 可选): 每个数据批次的大小。默认值为32。repeat_size (int, 可选): 数据集重复的次数。这用于增加数据集的大小。默认值为1。步骤:1. 加载MNIST数据集。2. 对图像执行大小调整操作,将图像大小统一调整为32x32像素。3. 对图像进行重新缩放和标准化处理。先将像素值缩放到0-1之间,然后进行标准化。4. 将图像的格式从高宽通道(HWC)转换为通道高宽(CHW)。5. 对标签进行类型转换,将其转换为整型(int32)。6. 对数据集进行洗牌、批处理和重复操作,以准备训练过程。返回:返回一个处理过的MNIST数据集,可以直接用于模型训练。注意:- 数据集的预处理步骤对于训练深度学习模型来说是非常重要的,它们会影响训练的效果和速度。- 调整batch_size和repeat_size可以影响模型训练时的内存消耗和速度。"""mnist_dataset = ds.MnistDataset(data_path)resize_operation = vision.Resize((32, 32), interpolation=Inter.LINEAR)rescale_normalization_op = vision.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081)rescale_op = vision.Rescale(1.0 / 255.0, 0.0)hwc_to_chw_op = vision.HWC2CHW()type_cast_op = transforms.TypeCast(mstype.int32)mnist_dataset = mnist_dataset.map(input_columns="label", operations=type_cast_op)mnist_dataset = mnist_dataset.map(input_columns="image",operations=[resize_operation, rescale_op, rescale_normalization_op,hwc_to_chw_op])mnist_dataset = mnist_dataset.shuffle(buffer_size=10000)mnist_dataset = mnist_dataset.batch(batch_size, drop_remainder=True)mnist_dataset = mnist_dataset.repeat(repeat_size)return mnist_dataset

这个函数将数据集中的图像调整为统一的大小,并进行重新缩放和标准化。

构建 LeNet-5 模型

LeNet-5 模型的构建在 LeNet5 类中实现。此类定义了网络的各层及其排列:

class LeNet5(nn.Cell):"""LeNet-5 神经网络结构。这是一个经典的卷积神经网络,通常用于图像识别任务。它包含了两个卷积层和三个全连接层。参数:num_class (int): 输出层的类别数量。默认为10,适用于MNIST数据集。num_channel (int): 输入图像的通道数。对于灰度图像,此值为1。组件:- conv1: 第一个卷积层,使用有效填充。- conv2: 第二个卷积层,同样使用有效填充。- fc1: 第一个全连接层。- fc2: 第二个全连接层。- fc3: 第三个全连接层,输出层。- relu: 激活函数,使用ReLU。- max_pool2d: 最大池化层。- flatten: 扁平化层,用于全连接层之前的数据转换。方法:- construct(x): 定义了前向传播的过程。"""def __init__(self, num_class=10, num_channel=1):super(LeNet5, self).__init__()self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))self.relu = nn.ReLU()self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()def construct(self, x):x = self.conv1(x)x = self.relu(x)x = self.max_pool2d(x)x = self.conv2(x)x = self.relu(x)x = self.max_pool2d(x)x = self.flatten(x)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.relu(x)x = self.fc3(x)return x

训练模型

接下来,我们定义 train_network 函数来训练模型。此函数接受模型实例、数据集路径和其他训练参数:

def train_network(model, epoch_size, data_path, repeat_size, checkpoint_callback):"""训练神经网络模型。此函数负责初始化数据集,然后使用指定的模型进行训练。在训练过程中,它将记录损失并保存模型的检查点。参数:model (Model): 要训练的神经网络模型。epoch_size (int): 训练过程中遍历数据集的次数。data_path (str): 训练数据集的路径。repeat_size (int): 数据集的重复次数,用于扩充数据集。checkpoint_callback (Callback): 用于保存模型检查点的回调函数。过程:- 使用 `create_dataset` 函数创建训练数据集。- 调用模型的 `train` 方法进行训练。- 在训练过程中,会通过回调函数记录损失和保存检查点。注意:- 确保提供的 `data_path` 包含适当格式的数据。"""print("============== 开始训练 ==============")ds_train = create_dataset(data_path, 32, repeat_size)model.train(epoch_size, ds_train, callbacks=[checkpoint_callback, LossMonitor(), TimeMonitor()],dataset_sink_mode=False)print("============== 训练结束 ==============")

主函数

最后,我们通过 train 函数和 parse_arguments 函数将所有步骤串联起来。train 函数负责初始化模型、损失函数、优化器和检查点回调,然后调用 train_network 进行训练:

def train(args):"""初始化并训练LeNet-5神经网络模型。此函数设置了网络模型、损失函数、优化器,并定义了模型检查点。然后,使用指定的参数调用 `train_network` 函数来进行模型的训练。参数:args (Namespace): 一个包含训练参数的命名空间对象。此对象应该包含以下属性:- epochs (int): 模型训练的迭代次数。- data_url (str): 训练数据集的路径。- output_path (str): 保存模型检查点的路径。过程:1. 创建 LeNet-5 网络实例。2. 定义损失函数为 Softmax Cross-Entropy。3. 定义优化器为 Momentum 优化器。4. 创建模型实例,并指定网络、损失函数、优化器和评估指标。5. 设置模型检查点配置。6. 初始化模型检查点回调函数。7. 调用 `train_network` 函数进行训练。注意:- 确保 `args` 对象包含正确和完整的训练参数。- 调整优化器和损失函数的参数可以对训练结果产生影响。- 模型检查点将保存在 `args.output_path` 指定的路径中。"""net = LeNet5()net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")net_opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)model = Model(net, net_loss, net_opt, metrics={"Accuracy": nn.Accuracy()})config_checkpoint = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)checkpoint_callback = ModelCheckpoint(prefix="checkpoint_lenet", directory=args.output_path,config=config_checkpoint)train_network(model, args.epochs, args.data_url, 1, checkpoint_callback)

推理

# 加载网络
param_dict = load_checkpoint("/root/MyCode/pycharm/lenet5/ckpt/checkpoint_lenet-19_1884.ckpt")
network = LeNet5(num_class=NUM_CLASS, num_channel=1)  # 用您定义的LeNet5类创建模型实例
load_param_into_net(network, param_dict)  # 将参数加载到网络中
model = Model(network)def predict_digit(img):# 图像预处理img = cv2.resize(img, (32, 32))  # 调整图像大小为32x32img = np.array(img, dtype=np.float32)  # 转换图像数据类型img = (img - 0.1307) / 0.3081  # 对图像进行标准化处理img = img[np.newaxis, np.newaxis, :, :]  # 改变图像形状以符合网络输入要求(1, 1, 32, 32)# 将图像数据转换为MindSpore张量img_tensor = Tensor(img)# 使用模型进行预测output = model.predict(img_tensor)# 将输出转换为概率分布probabilities = Softmax()(output)# 获取每个类别的概率probabilities_np = probabilities.asnumpy()[0]# 将概率转换为字典格式labels = [str(i) for i in range(10)]  # 类别标签,例如"0", "1", "2", ..., "9"probabilities_dict = {label: prob for label, prob in zip(labels, probabilities_np)}return probabilities_dictgr.Interface(fn=predict_digit,inputs=gr.Image(image_mode='L'),outputs=gr.Label(num_top_classes=NUM_CLASS),live=False,css=".footer {display:none !important}",title="0-9数字画板",description="画0-9数字",thumbnail="https://raw.githubusercontent.com/gradio-app/real-time-mnist/master/thumbnail2.png"
).launch()

结论

通过本文的指南,您可以在 MindSpore 框架中实现并训练一个经典的 LeNet-5 神经网络。LeNet-5 在图像识别任务中展现了卓越的性能,而 MindSpore 的高效和易用性使得深度学习研究和开发更加便捷。您可以根据本文的指导进行实验,并根据需要调整网络结构和训练参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/165273.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux | 从虚拟地址到物理地址

前言 本章主要讲解虚拟地址是怎么转化成物理地址的,以及页表相关知识;本文环境默认为32位机器下;如果你连什么是虚拟地址都不知道可以先看看下面这篇文章; Linux | 进程地址空间-CSDN博客 一、概念补充 页表:是一种数据…

【性能优化】CPU利用率飙高与内存飙高问题

📫作者简介:小明java问道之路,2022年度博客之星全国TOP3,专注于后端、中间件、计算机底层、架构设计演进与稳定性建设优化,文章内容兼具广度、深度、大厂技术方案,对待技术喜欢推理加验证,就职于…

2023APMCM亚太杯数学建模选题建议及初步思路

大家好呀,亚太杯数学建模开始了,来说一下初步的选题建议吧: 首先定下主基调,本次亚太杯推荐选择B题。 C题如果想做好,搜集数据难度并不低,并且模型比较简单,此外目前选择的人数过多&#xff0c…

java项目之消防物资存储系统(ssm+vue)

项目简介 消防物资存储系统实现了以下功能: 管理员功能: 管理员登陆后,主要模块包括首页,个人中心,用户管理,仓库管理,物资入库管理,物资出库管理,仓库管理,物资详情管…

23年下半年软考成绩查询时间是什么时候?

一、成绩查询时间 2023年下半年软考成绩查询时间预计2023年12月份公布,成绩查询入口为计算机技术职业资格网(全国统一成绩查询时间,统一查询入口)。 二、成绩查询方法 登陆中国计算机技术职业资格网,点击“成绩查询”…

7-9 jmu-python-班级人员信息统计

7-9 jmu-python-班级人员信息统计 分数 15 作者 郑如滨 单位 集美大学 输入a,b班的名单,并进行如下统计。 输入格式: 第1行::a班名单,一串字符串,每个字符代表一个学生,无空格,可能有重复字符。 第2行:&am…

WPF实战项目十六(客户端):备忘录接口

1、新增IMemoService接口&#xff0c;继承IBaseService接口 public interface IMemoService : IBaseService<MemoDto>{} 2、新增MemoService类&#xff0c;继承BaseService和IMemoService接口 public class MemoService : BaseService<MemoDto>, IMemoService{pub…

DRF-通用分页器(PageNumberPagination):ListModelMixin可以使用的通用分页器

一、ListModelMixin 和GenericAPIView源码 ListModelMixin 是一个单一功能类&#xff0c;必须配合GenericAPIView&#xff08;或其子类&#xff09;来一起使用&#xff0c;才能完成其视图的功能 class ListModelMixin:"""List a queryset."""d…

腾讯云点播小程序端上传 SDK

云点播是专门应对上传大视频文件的。 腾讯云点播文档&#xff1a;https://cloud.tencent.com/document/product/266/18177 这个文档比较简单&#xff0c;实在不行&#xff0c;把demo下载下来&#xff0c;一看就明白了&#xff0c;然后再揉一下挪到自己的项目里。完事。 getSign…

芯知识 | 混音播报语音芯片的优势:革新音频应用的新力量

随着科技的进步&#xff0c;语音芯片在各个领域的应用越来越广泛。而在众多语音芯片中&#xff0c;混音播报语音芯片以其独特的优势&#xff0c;正逐渐成为音频应用领域的翘楚。本文将重点探讨混音播报语音芯片的优势及其在现代科技应用中的价值。 一、混音播报语音芯片概述 …

element-vue实现网页锁屏功能

1.写一个锁屏页面&#xff0c;这里比较简单&#xff0c;自己定义一下,需要放到底层HTML中哦&#xff0c;比如index.html <div id"appIndex"><el-dialog title"请输入密码解锁屏幕" :visible.sync"lockScreenFlag" :close-on-click-mod…

力扣236. 二叉树的最近公共祖先(java DFS解法)

Problem: 236. 二叉树的最近公共祖先 文章目录 题目描述思路解题方法复杂度Code 题目描述 给定一个二叉树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为&#xff1a;“对于有根树 T 的两个节点 p、q&#xff0c;最近公共祖先表示为一个节点 x&am…

Android逆向一-frida操作

系列文章目录 第一章 frida操作 文章目录 系列文章目录前言一、两种模式二、frida命令行执行及参数三、frida使用python执行四、动静态域调用1. 静态域调用2.动态域调用 五. 远程rpc调用六. 补充总结 前言 熟悉frida操作&#xff0c;hook手机app的关键位置进行逆向操作 一、…

芯知识 | Flash可更换声音语音芯片—引领音频IC技术革新的新篇章

随着科技的飞速发展&#xff0c;人们对于电子产品的音频性能要求越来越高。在这种背景下&#xff0c;Flash可更换声音语音芯片应运而生&#xff0c;成为音频技术领域的一颗璀璨明星。本文将详细介绍Flash可更换声音语音芯片的特点、优势以及应用场景&#xff0c;展望其在未来科…

【Docker】从零开始:10.registry搭建私有仓库

【Docker】从零开始&#xff1a;10.registry搭建私有仓库 为什么要使用私有仓库关于Docker Registry基于容器搭建registry私有仓库1.下载镜像2. 启动镜像3.修改系统配置文件4.下载ubuntu镜像&#xff0c;修改名称3.提交镜像4.查看镜像 本地搭建私有仓库(目前编译报错找不到包&a…

【管理运筹学】背诵手册(五)| 动态规划

五、动态规划 基本概念 阶段&#xff08;Stage&#xff09;&#xff1a;将所给问题的过程&#xff0c;按时间或空间特征分解成若干相互联系的阶段&#xff0c;以便按次序去求解每阶段的解&#xff0c;常用字母 k k k 表示。 状态&#xff08;State&#xff09;&#xff1a;…

java实现连接linux(上传文件,执行shell命令等)

1 导入pom <dependency><groupId>com.jcraft</groupId><artifactId>jsch</artifactId><version>0.1.55</version></dependency> 2 编写配置类 package com.budwk.app.atest;import com.budwk.app.common.config.AppExceptio…

计算机网络之网络层

一、概述 主要任务是实现网络互连&#xff0c;进而实现数据包在各网络之间的传输 1.1网络引入的目的 从7层结构上看&#xff0c;网络层下是数据链路层 从4层结构上看&#xff0c;网络层下面是网络接口层 至少我们看到的网络层下面是以太网 以太网解决了什么问题&#xff1f; 答…

【Python 千题 —— 基础篇】删除列表值

题目描述 题目描述 删除列表的指定值。有一个列表 [1, 3, 5, 2, 44, 1, 9, 10, 32] &#xff0c;请使用 for 循环删除该列表中与 [44, 1, 9] 列表相同的值&#xff0c;并输出该列表。 输入描述 无输入。 输出描述 输出操作后的列表。 示例 示例 ① 输出&#xff1a; …

记录:通过day.js获取两个日期相差的时间,并转化为年月日的格式

day.js这个日期库真的是很不错的日期库&#xff0c;足够满足日常的开发需求。 Day.js中文网 (fenxianglu.cn) 需求&#xff1a;获取两个日期相差的时间&#xff0c;转化为年月日的形式&#xff1b;话不多少&#xff0c;直接放代码 import dayjs from "dayjs"; imp…