昇思25天学习打卡营第2天|快速入门

快速入门

  • 操作步骤
    • 1.引入依赖包
    • 2.下载Mnist数据集
    • 3.划分训练集和测试集
    • 4.数据预处理
    • 5.网络构建
    • 6.模型训练
    • 7.保存模型
    • 8.加载模型
    • 9.模型预测

今天通过昇思大模型平台AI实验室提供的在线Jupyter工具,快速入门MindSpore。
目标:通过MindSpore的API快速实现一个简单的深度学习模型。

操作步骤

1.引入依赖包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

2.下载Mnist数据集

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

Mnist数据集目录结构如下:
MNIST_Data
└── train
├── train-images-idx3-ubyte (60000个训练图片)
├── train-labels-idx1-ubyte (60000个训练标签)
└── test
├── t10k-images-idx3-ubyte (10000个测试图片)
├── t10k-labels-idx1-ubyte (10000个测试标签)

3.划分训练集和测试集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

4.数据预处理

使用dataset模块的map操作对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。

def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

5.网络构建

继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

6.模型训练

一个完整的训练过程(step)需要实现以下三步:

  1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  3. 参数优化:将梯度更新到参数上。
# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 定义正向计算函数
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 使用value_and_grad通过函数变换获得梯度计算函数
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return loss# 定义测试函数,用来评估模型的性能
def train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")# 预测,并输出每一轮的loss值和预测准确率(Accuracy)
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

训练结果:
训练结果

7.保存模型

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

8.加载模型

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

9.模型预测

model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

打印结果:Predicted: “[7 1 9 8 3 8 7 7 7 9]”, Actual: “[7 1 9 8 3 8 7 7 7 9]”

截图时间
截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/35130.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《昇思 25 天学习打卡营第 6 天 | 函数式自动微分 》

《昇思 25 天学习打卡营第 6 天 | 函数式自动微分 》 活动地址:https://xihe.mindspore.cn/events/mindspore-training-camp 签名:Sam9029 函数式自动微分 自动微分是深度学习中的一个核心概念,它允许我们自动计算模型参数的梯度&#xff0c…

云计算 | 期末梳理(下)

1.模运算 2. 拓展欧几里得算法 3.扩散和混淆、攻击的分类 香农的贡献:定义了理论安全性,提出扩散和混淆原则,奠定了密码学的理论基础。扩散:将每一位明文尽可能地散布到多个输出密文中去,以更隐蔽明文数字的统计特性。混淆:使密文的统计特性与明文密钥之间的关系尽量复杂…

深入解析直播带货系统源码:短视频商城APP开发全攻略

本篇文章,小编将深入解析直播带货系统的源码,并为开发短视频商城APP提供全攻略,助力开发者打造高效、稳定的带货平台。 一、直播带货系统概述 直播带货系统主要由直播模块、商品管理模块、订单处理模块、用户管理模块、以及支付模块等组成。…

Ubuntu20.04使用Samba

目录 一、Samba介绍 Samba 的主要功能 二、启动samba 三、主机操作 四、Ubuntu与windows系统中文件互联 五、修改samba路径 一、Samba介绍 Samba 是一个开源软件套件,用于在 Linux 和 Unix 系统上实现 SMB(Server Message Block)协议…

速卖通自养号测评:安全高效的推广手段

在速卖通平台上,卖家们常常寻求各种方法来提升商品的曝光、转化率和店铺权重。其中,自养号测评作为一种低成本、高回报的推广方式,备受关注。然而,若操作不当,也可能带来风险。以下是如何安全有效地进行自养号测评的指…

VS Code 使用 Makefile 运行 CPP项目

Installing the MinGW-w64 toolchainCMake Toolsmakelist.txt报错 1报错 2报错 3生成了 Makefile ,如何使用 make 命令 Installing the MinGW-w64 toolchain 参见文档 将 GCC 与 MinGW 结合使用 CMake Tools 参见文档 Linux 上的 CMake 工具入门 CMake 的使用 …

关于Pycharm右下角不显示解释器interpreter的问题解决

关于Pycharm右下角不显示解释器interpreter的问题 在安装新的Pycharm后,发现右下角的 interpreter 的选型消失了: 觉得还挺不习惯的,于是网上找解决办法,无果。 自己摸索了一番后,发现解决办法如下: 勾…

37岁,被裁员,失业三个月,被面试官嫌弃“太水”:就这也叫10年以上工作经验?

今年部门要招两个自动化测试,这几个月我面试了几十位候选人。发现一个很奇怪的现象,面试中一问到元素定位、框架api、脚本编写之类的,很多候选人都对答如流。但是一问到实际项目,比如“项目中UI自动化和接口自动化如何搭配使用&am…

电商平台家电以旧换新销售额增长超80%

记者近日从国家发展和改革委员会举办的新闻发布会上获悉,今年1—5月份,主要电商平台家电以旧换新销售额增长超过80%,以旧换新成为推动家电消费增长的重要因素。 今年3月,国务院印发了《推动大规模设备更新和消费品以旧换新行动方…

天花板国际幼儿园是怎样的?一起来听听天津惠灵顿幼儿园园长分享

上周,天津惠灵顿幼儿园举行了精彩的毕业典礼。一如往常,这是一个回顾过去、展望未来的机会。这届毕业班有一些孩子是四年前园长加入惠灵顿学校的时入园的。他们从小小班启航,在这所天津国际幼儿园开始了他们的惠灵顿之旅。四年来,…

java基于ssm+jsp 班级同学录网站

1前台首页功能模块 班级同学录网站,在前台首页可以查看首页、公告信息、校友风采、论坛信息、我的、跳转到后台、客服等内容,如图1所示。 图1前台首页功能界面图 用户注册,在用户注册页面可以填写用户名、姓名、头像、性别、手机号码、邮箱等…

[leetcode]unique-paths 不同路径

. - 力扣&#xff08;LeetCode&#xff09; class Solution { public:int uniquePaths(int m, int n) {vector<vector<int>> f(m, vector<int>(n));for (int i 0; i < m; i) {f[i][0] 1;}for (int j 0; j < n; j) {f[0][j] 1;}for (int i 1; i &l…

站在巨人的肩膀上 C语言理解和简单练习(包含指针前的简单内容)

1.格式化的输入/输出 1.1printf函数 printf函数你需要了解的就是转换说明&#xff0c;转换说明的作用是将内存中的二进制转换成你所需要的格式入%d就是将内存中存储的变量的二进制转化为十进制并打印出来&#xff0c;同时我们可以在%X的转换说明对精度和最小字段宽度的指定&a…

根据模型log文件画loss曲线

根据模型log文件画loss曲线 思想&#xff1a;使用Python的matplotlib库来绘制loss曲线。首先需要解析log文件&#xff0c;提取出每个epoch对应的loss值&#xff0c;然后再进行绘制。 import re import matplotlib.pyplot as plt# 初始化数据列表 epochs [] losses []# 读取…

1000. 合并石头的最低成本

Problem: 1000. 合并石头的最低成本 文章目录 思路解题方法复杂度Code 思路 这道题目的核心在于理解合并石头的过程和寻找最优策略。给定一个数组 stones 表示石堆&#xff0c;以及一个整数 k 表示每次可以合并的石堆数量&#xff0c;目标是找到将所有石堆合并成一个石堆的最小…

ORA-6544[pevm_peruws_callback-1][604] is caused (Doc ID 2638095.1)

ORA-6544[pevm_peruws_callback-1][604] is caused (Doc ID 2638095.1)​编辑To Bottom In this Document Symptoms Cause Solution References Applies to: Oracle Database - Enterprise Edition - Version 12.2.0.1 and later Information in this document applies to an…

嵌入式工具:VI、GCC、GDB、makefile

目录 1.Vi 2.gcc 3 gdb 4.makefile 参考 1.Vi VI 是一种编辑器,有很多版本。它有三种工作模式:编辑模式(启动时默认的模式)、插入模式(按下i或a键即可进入)、最后一行模式(又叫命令模式)。 1.1 最后一行模式常用命令 W:保存文件; q:退出; q!:强行退出不保…

C++并发之环形队列(ring,queue)

目录 1 概述2 实现3 测试4 运行 1 概述 最近研究了C11的并发编程的线程/互斥/锁/条件变量&#xff0c;利用互斥/锁/条件变量实现一个支持多线程并发的环形队列&#xff0c;队列大小通过模板参数传递。 环形队列是一个模板类&#xff0c;有两个模块参数&#xff0c;参数1是元素…

[学习笔记] 禹神:一小时快速上手Electron笔记,附代码

课程地址 禹神&#xff1a;一小时快速上手Electron&#xff0c;前端Electron开发教程_哔哩哔哩_bilibili 笔记地址 https://github.com/sui5yue6/my-electron-app 进程通信 桌面软件 跨平台的桌面应用程序 chromium nodejs native api 流程模型 main主进程 .js文件 node…

Verilog HDL语法入门系列(二):Verilog的语言文字规则

目录 1 空白符和注释2 整数常量和实数常量3 整数常量和实数常量4 字符串&#xff08;string)5 格式符与转义符6 标识符(identifiers) 微信公众号获取更多FPGA相关源码&#xff1a; 1 空白符和注释 2 整数常量和实数常量 Verilog中&#xff0c;常量(literals)可是整数也可以是…