昇思25天学习打卡营第2天|快速入门

快速入门

  • 操作步骤
    • 1.引入依赖包
    • 2.下载Mnist数据集
    • 3.划分训练集和测试集
    • 4.数据预处理
    • 5.网络构建
    • 6.模型训练
    • 7.保存模型
    • 8.加载模型
    • 9.模型预测

今天通过昇思大模型平台AI实验室提供的在线Jupyter工具,快速入门MindSpore。
目标:通过MindSpore的API快速实现一个简单的深度学习模型。

操作步骤

1.引入依赖包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

2.下载Mnist数据集

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

Mnist数据集目录结构如下:
MNIST_Data
└── train
├── train-images-idx3-ubyte (60000个训练图片)
├── train-labels-idx1-ubyte (60000个训练标签)
└── test
├── t10k-images-idx3-ubyte (10000个测试图片)
├── t10k-labels-idx1-ubyte (10000个测试标签)

3.划分训练集和测试集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

4.数据预处理

使用dataset模块的map操作对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。

def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

5.网络构建

继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

6.模型训练

一个完整的训练过程(step)需要实现以下三步:

  1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  3. 参数优化:将梯度更新到参数上。
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 定义正向计算函数
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 使用value_and_grad通过函数变换获得梯度计算函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return loss# 定义测试函数,用来评估模型的性能
def train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")# 预测,并输出每一轮的loss值和预测准确率(Accuracy)
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

训练结果:
训练结果

7.保存模型

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

8.加载模型

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

9.模型预测

model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

打印结果:Predicted: “[7 1 9 8 3 8 7 7 7 9]”, Actual: “[7 1 9 8 3 8 7 7 7 9]”

截图时间
截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/35130.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

云计算 | 期末梳理(下)

1.模运算 2. 拓展欧几里得算法 3.扩散和混淆、攻击的分类 香农的贡献:定义了理论安全性,提出扩散和混淆原则,奠定了密码学的理论基础。扩散:将每一位明文尽可能地散布到多个输出密文中去,以更隐蔽明文数字的统计特性。混淆:使密文的统计特性与明文密钥之间的关系尽量复杂…

深入解析直播带货系统源码:短视频商城APP开发全攻略

本篇文章,小编将深入解析直播带货系统的源码,并为开发短视频商城APP提供全攻略,助力开发者打造高效、稳定的带货平台。 一、直播带货系统概述 直播带货系统主要由直播模块、商品管理模块、订单处理模块、用户管理模块、以及支付模块等组成。…

Ubuntu20.04使用Samba

目录 一、Samba介绍 Samba 的主要功能 二、启动samba 三、主机操作 四、Ubuntu与windows系统中文件互联 五、修改samba路径 一、Samba介绍 Samba 是一个开源软件套件,用于在 Linux 和 Unix 系统上实现 SMB(Server Message Block)协议…

速卖通自养号测评:安全高效的推广手段

在速卖通平台上,卖家们常常寻求各种方法来提升商品的曝光、转化率和店铺权重。其中,自养号测评作为一种低成本、高回报的推广方式,备受关注。然而,若操作不当,也可能带来风险。以下是如何安全有效地进行自养号测评的指…

VS Code 使用 Makefile 运行 CPP项目

Installing the MinGW-w64 toolchainCMake Toolsmakelist.txt报错 1报错 2报错 3生成了 Makefile ,如何使用 make 命令 Installing the MinGW-w64 toolchain 参见文档 将 GCC 与 MinGW 结合使用 CMake Tools 参见文档 Linux 上的 CMake 工具入门 CMake 的使用 …

关于Pycharm右下角不显示解释器interpreter的问题解决

关于Pycharm右下角不显示解释器interpreter的问题 在安装新的Pycharm后,发现右下角的 interpreter 的选型消失了: 觉得还挺不习惯的,于是网上找解决办法,无果。 自己摸索了一番后,发现解决办法如下: 勾…

37岁,被裁员,失业三个月,被面试官嫌弃“太水”:就这也叫10年以上工作经验?

今年部门要招两个自动化测试,这几个月我面试了几十位候选人。发现一个很奇怪的现象,面试中一问到元素定位、框架api、脚本编写之类的,很多候选人都对答如流。但是一问到实际项目,比如“项目中UI自动化和接口自动化如何搭配使用&am…

电商平台家电以旧换新销售额增长超80%

记者近日从国家发展和改革委员会举办的新闻发布会上获悉,今年1—5月份,主要电商平台家电以旧换新销售额增长超过80%,以旧换新成为推动家电消费增长的重要因素。 今年3月,国务院印发了《推动大规模设备更新和消费品以旧换新行动方…

天花板国际幼儿园是怎样的?一起来听听天津惠灵顿幼儿园园长分享

上周,天津惠灵顿幼儿园举行了精彩的毕业典礼。一如往常,这是一个回顾过去、展望未来的机会。这届毕业班有一些孩子是四年前园长加入惠灵顿学校的时入园的。他们从小小班启航,在这所天津国际幼儿园开始了他们的惠灵顿之旅。四年来,…

java基于ssm+jsp 班级同学录网站

1前台首页功能模块 班级同学录网站,在前台首页可以查看首页、公告信息、校友风采、论坛信息、我的、跳转到后台、客服等内容,如图1所示。 图1前台首页功能界面图 用户注册,在用户注册页面可以填写用户名、姓名、头像、性别、手机号码、邮箱等…

[leetcode]unique-paths 不同路径

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> f(m, vector<int>(n));for (int i 0; i < m; i) {f[i][0] 1;}for (int j 0; j < n; j) {f[0][j] 1;}for (int i 1; i &l…

站在巨人的肩膀上 C语言理解和简单练习(包含指针前的简单内容)

1.格式化的输入/输出 1.1printf函数 printf函数你需要了解的就是转换说明&#xff0c;转换说明的作用是将内存中的二进制转换成你所需要的格式入%d就是将内存中存储的变量的二进制转化为十进制并打印出来&#xff0c;同时我们可以在%X的转换说明对精度和最小字段宽度的指定&a…

ORA-6544[pevm_peruws_callback-1][604] is caused (Doc ID 2638095.1)

ORA-6544[pevm_peruws_callback-1][604] is caused (Doc ID 2638095.1)​编辑To Bottom In this Document Symptoms Cause Solution References Applies to: Oracle Database - Enterprise Edition - Version 12.2.0.1 and later Information in this document applies to an…

C++并发之环形队列(ring,queue)

目录 1 概述2 实现3 测试4 运行 1 概述 最近研究了C11的并发编程的线程/互斥/锁/条件变量&#xff0c;利用互斥/锁/条件变量实现一个支持多线程并发的环形队列&#xff0c;队列大小通过模板参数传递。 环形队列是一个模板类&#xff0c;有两个模块参数&#xff0c;参数1是元素…

[学习笔记] 禹神:一小时快速上手Electron笔记,附代码

课程地址 禹神&#xff1a;一小时快速上手Electron&#xff0c;前端Electron开发教程_哔哩哔哩_bilibili 笔记地址 https://github.com/sui5yue6/my-electron-app 进程通信 桌面软件 跨平台的桌面应用程序 chromium nodejs native api 流程模型 main主进程 .js文件 node…

Verilog HDL语法入门系列(二):Verilog的语言文字规则

目录 1 空白符和注释2 整数常量和实数常量3 整数常量和实数常量4 字符串&#xff08;string)5 格式符与转义符6 标识符(identifiers) 微信公众号获取更多FPGA相关源码&#xff1a; 1 空白符和注释 2 整数常量和实数常量 Verilog中&#xff0c;常量(literals)可是整数也可以是…

照片放大工具Topaz Gigapixel AI for Mac v7.1.2

Topaz Gigapixel AI软件是一款相当高效的PC端图像大小调整工具&#xff0c;更是一款能够为摄影师、设计师以及图像处理爱好者带来革命性体验的强大软件。它凭借先进的深度学习技术&#xff0c;打破了传统图像大小调整的限制&#xff0c;实现了真正意义上的无损放大和图像恢复。…

服务器硬件及RAID配置

目录 一、RAID磁盘阵列 1.概念 2.RAID 0 3.RAID 1 4.RAID 5 5.RAID 6 6.RAID 10 二、阵列卡 1.简介 2.缓存 三、创建 1.创建RAID 0 2.创建RAID 1 3.创建RAID 5 4.创建RAID 10 四、模拟故障 一、RAID磁盘阵列 1.概念 &#xff08;1&#xff09;是Redundant Array …

游戏服务器研究二:大世界的 scale 问题

这是一个非常陈旧的话题了&#xff0c;没什么新鲜的&#xff0c;但本人对 scale 比较感兴趣&#xff0c;所以研究得比较多。 本文不会探讨 MMO 类的网游提升单服承载人数有没有意义&#xff0c;只单纯讨论技术上如何实现。 像 moba、fps、棋牌、体育竞技等 “开房间类型的游戏…

调幅信号AM的原理与matlab实现

平台&#xff1a;matlab r2021b 本文知识内容摘自《软件无线电原理和应用》 调幅就是使载波的振幅随调制信号的变化规律而变化。用音频信号进行调幅时&#xff0c;其数学表达式可以写为: 式中&#xff0c;为调制音频信号&#xff0c;为调制指数&#xff0c;它的范围在(0&…