昇思25天学习打卡营第1天 | 快速入门

内容介绍:通过MindSpore的API来快速实现一个简单的深度学习模型。

具体内容:

1. 导包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

2. 处理数据

from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

3. 获取数据对象

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

4. 数据处理

def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

5. 使用 create_dict_iterator或create_dict_iterator对数据集进行迭代访问

for image, label in test_dataset.create_tuple_iterator():print(f"Shape of image [N, C, H, W]: {image.shape} {image.dtype}")print(f"Shape of label: {label.shape} {label.dtype}")breakfor data in test_dataset.create_dict_iterator():print(f"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}")print(f"Shape of label: {data['label'].shape} {data['label'].dtype}")break

6. 网络构建

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

7. 模型训练

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

8. 测试函数

def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

9. 训练过程

epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

通过训练可以看出loss不断降低,Accuracy不断升高,可以通过调参到达更好的效果。

10. 保存模型

mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

11. 加载模型

model = Network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/29986.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何快速使用向量检索服务DashVector?

免费体验阿里云高性能向量检索服务:https://www.aliyun.com/product/ai/dashvector 本文将介绍如何快速上手使用向量检索服务DashVector。 前提条件 已创建Cluster:创建Cluster。 已获得API-KEY:API-KEY管理。 已安装最新版SDK&#xff1a…

【网络安全学习】漏洞扫描:-01- 漏洞数据库searchsploit的使用

漏洞数据库是收集和存储各种软件漏洞信息的资源库。 漏洞数据库通常包含漏洞的名称、编号、描述、影响范围、危害等级、解决方案等信息,有些还提供漏洞的分析报告、演示视频、利用代码等内容。 1.常用的在线漏洞库: 国家信息安全漏洞共享平台 https:/…

Unity 天空盒制作使用教程

文章目录 1.概念2.制作天空盒3.使用天空盒3.1 为场景添加3.2 为相机添加 1.概念 天空盒是包裹整个场景的环境效果。 2.制作天空盒 1、创建材质球。 2、设置材质球Shader为SkyBox/6 Sided,将六张贴图放到对应位置。 3.使用天空盒 3.1 为场景添加 方法一、直接…

STM32F103ZET6_移植uC/OS_HAL

1下载源码 网址 GitHub - weston-embedded/uC-OS2: C/OS-II is a preemptive, highly portable, and scalable real-time kernels. Designed for ease of use on a huge number of CPU architectures. 需要下载三个文件 1看你使用是ucos2还是3(第一个文件&#…

【Python】类和对象高级特性

目录 前言 类变量与实例变量 类方法 静态方法 私有属性和方法 多重继承 元类 描述符 总结 前言 在前一篇文章中,我们讨论了 Python 类和对象的基本概念。本文将深入探讨一些高级特性,这些特性可以帮助你更有效地使用 Python 进行面向对象编程。…

Next.js开发中使用useRouter实现点击返回到上一页

在使用Next.js框架做前端页面开发时,如果想返回到上一页,可以利用useRouter钩子提供的back()方法,可以这样做: import {useRouter} from "next/navigation"; import {Space} from "antd"; import {ArrowLeftOutlined} f…

Mendix 创客访谈录|医疗设备领域的数字化转型利器

本期创客 尚衍亮 爱德亚(北京)医疗科技有限公司 应用开发和数字化事业部开发经理 大家好,我叫尚衍亮。毕业于软件工程专业,有6年的软件开发经验。从2021年开始,我在爱德亚(北京)医疗科技有限公司…

智能合约开发的过程

智能合约是一种运行在区块链上的程序,可以自动执行预先设定的条款和条件。智能合约具有去中心化、透明、不可篡改等特点,因此被广泛应用于金融、供应链、物联网等领域。北京木奇移动技术有限公司,专业的软件外包开发公司,欢迎交流…

Spring Boot集成Minio插件快速入门

1 Minio介绍 MinIO 是一个基于 Apache License v2.0 开源协议的对象存储服务。它兼容亚马逊 S3 云存储服务接口,非常适合于存储大容量非结构化的数据,例如图片、视频、日志文件、备份数据和容器/虚拟机镜像等,而一个对象文件可以是任意大小&…

LSM-Tree数据结构原理

LSM-Tree树原理 什么是LSM-Tree LSM-Tree 即 Log Structrued Merge Tree,这是一种分层有序,硬盘友好的数据结构。核心思想是利用磁盘顺序写性能远高于随机写。 LSM-Tree 并不是一种严格的树结构,而是一种内存磁盘的多层存储结构。HBase、L…

基于Baichuan2的新冠流感中医自我诊断治疗(大模型微调+Gradio)

一、项目说明 项目使用paddleNLP提供的大模型套件对Baichuan2-7b/13b进行微调,使用《中医治疗新冠流感支原体感染等有效病历集》进行Lora训练,使大模型具备使用中医方案诊断和治疗新冠、流感等上呼吸道感染的能力。 二、PaddleNLP PaddleNLP提供的飞桨…

css 文字两端对齐

<body><div class"box"><p>姓名</p><p>性与别</p><p>家庭住址</p><p>how are you</p><p>hello</p><p>1234</p><p>1 2 3 4</p></div> </body> text-a…

Ubuntu-24.04-live-server-amd64启用ssh

系列文章目录 Ubuntu-24.04-live-server-amd64安装界面中文版 Ubuntu安装qemu-guest-agent Ubuntu乌班图安装VIM文本编辑器工具 文章目录 系列文章目录前言一、输入安装命令二、使用私钥登录&#xff08;可选&#xff09;1.创建私钥2.生成三个文件说明3.将公钥复制到服务器 三…

面向对象进阶--继承(Java继承(超详解))

目录 1. 继承 1.1 继承概述 1.2 继承特点 1.3练习 1.4继承父类的内容 构造方法是否被子类继承 成员变量是否被子类继承 成员方法是否被子类继承 1.5总结 继承中&#xff1a;成员变量的访问特点 继承中&#xff1a;成员方法的访问特点 方法重写概述 方法重写的本质 …

飞睿智能LR-WIFI无线数据采集模块,6公里视频图传,安防监控、工业传输数据更高效

在数字化浪潮席卷全球的今天&#xff0c;无线数据采集技术已经成为推动社会进步的重要力量。特别是在安防监控和工业领域&#xff0c;高效、稳定的数据传输成为了实现智能化、自动化的关键。飞睿智能LR-WiFi无线数据采集模块不仅具备可靠的传输性能&#xff0c;还能在复杂环境下…

尚硅谷爬虫学习第一天(3) 请求对象定制

#url的组成 #协议 http&#xff0c;https&#xff0c;一个安全&#xff0c;一个不安全。 #主机&#xff0c; 端口号 学过java 的肯定知道 沃日&#xff0c;以前面试运维的时候&#xff0c;问到主机地址&#xff0c;我懵逼了下&#xff0c;回了个8080 # 主机地址 80 # …

关于微信小程序(必看)

前言 为规范开发者的用户个人信息处理行为&#xff0c;保障用户的合法权益&#xff0c;自2023年9月15日起&#xff0c;对于涉及处理用户个人信息的小程序开发者&#xff0c;微信要求&#xff0c;仅当开发者主动向平台同步用户已阅读并同意了小程序的隐私保护指引等信息处理规则…

Datacom HCIE实验考试通过率90%!深圳智汇云校传来5月捷报!

坚持不懈地努力&#xff0c;才能取得成功的果实 这是不变的真理 深圳云校传来5月捷报 在Datacom HCIE实验考试中 共有10名学员应战 其中9名学员凭借出色的表现 一次性通过了考试 展现出了扎实的技术能力 通过率高达90% &#xff08;华为历年考试平均通过率约60%&#…

超级棒的时钟屏保 芝麻时钟颜值高 屏保界的天花板

太酷了&#xff01;这个时钟屏保太有个性了 屏保时钟软件推荐&#xff01;超级棒的时钟屏保 芝麻时钟颜值高 屏保界的天花板&#xff0c;今天小编给大家分享一个非常实用好看的时钟屏保&#xff08;芝麻时钟&#xff09;&#xff0c;从美观、功能、效果、操作方面去评估&#x…

【机器学习】机器学习重要方法——无监督学习:理论、算法与实践

文章目录 引言第一章 无监督学习的基本概念1.1 什么是无监督学习1.2 无监督学习的主要任务 第二章 无监督学习的核心算法2.1 聚类算法2.1.1 K均值聚类2.1.2 层次聚类2.1.3 DBSCAN聚类 2.2 降维算法2.2.1 主成分分析&#xff08;PCA&#xff09;2.2.2 t-SNE 2.3 异常检测算法2.3…