【昇思大模型平台体验篇】day1快速入门

早闻毕晟、昇思等平台,今日有机会能参加入门课程,非视频课程算是我第一次看,也算是对我自己的一个锻炼,之前也没有系统学习模型之类,每天抽出一点点时间来学习一下也是不错的

MindSpore

看来是和torch类似的结构

处理数据集

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

很方便的就下载了MNIST数据集

train_dataset = MnistDataset('MNIST_Data/train')
test_dataset = MnistDataset('MNIST_Data/test')

读取训练集和测试集,查看数据names print(train_dataset.get_col_names())

MindSpore的dataset使用数据处理流水线(Data Processing Pipeline),需指定map、batch、shuffle等操作。这里我们使用map对图像数据及标签进行变换处理,然后将处理好的数据集打包为大小为64的batch。

def datapipe(dataset, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return dataset
# Map vision transforms and batch dataset
train_dataset = datapipe(train_dataset, 64)
test_dataset = datapipe(test_dataset, 64)

网络构建

mindspore.nn类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承nn.Cell类,并重写__init__方法和construct方法。__init__包含所有网络层的定义,construct中包含数据(Tensor)的变换过程。

# Define model
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
print(model)

模型训练

  1. 正向计算:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。
  2. 反向传播:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。
  3. 参数优化:将梯度更新到参数上。

MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:
4. 定义正向计算函数。
5. 使用value_and_grad通过函数变换获得梯度计算函数。
6. 定义训练函数,使用set_train设置为训练模式,执行正向计算、反向传播和参数优化。

# Instantiate loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), 1e-2)# 1. Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# 2. Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# 3. Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
epochs = 3
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(model, train_dataset)test(model, test_dataset, loss_fn)
print("Done!")

保存模型

# Save checkpoint
mindspore.save_checkpoint(model, "model.ckpt")
print("Saved Model to model.ckpt")

加载模型

加载保存的权重分为两步:

  1. 重新实例化模型对象,构造模型。
  2. 加载模型参数,并将其加载至模型上。
# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

直接推理

model.set_train(False)
for data, label in test_dataset:pred = model(data)predicted = pred.argmax(1)print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/861626.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GPT5将引领第四次工业革命:人工智能、物联网、大数据、生物技术、量子计算等的综合体GPT大模型将改变很多现在的工作方式和生活方式,人人必读,人人必用

2024年6月22日,美国达特茅斯工程学院的一场采访引起了全球科技界的广泛关注。OpenAI首席技术官米拉穆拉蒂在采访中确认,备受期待的GPT-5将在一年半后发布。 这一消息不仅激起了科技界的热烈讨论,也让人们对人工智能(AI&#xff09…

LongRAG:利用长上下文大语言模型提升检索生成效果

一、前言 前面我们已经介绍了多种检索增强生成 (RAG) 技术,基本上在保证数据质量的前提下,检索增强生成(RAG)技术能够有效提高检索效率和质量,相对于大模型微调技术,其最大的短板还是在于有限的上下文窗口…

幽默证明题!高考成绩公布后,妈妈连夜写了一封信:孩子,这就是我不让你玩手机的原因——早读(逆天打工人爬取热门微信文章解读)

毛毛雨,五分钟结束,怎么证明今天早上有下雨呢? 引言Python 代码第一篇 洞见 高考成绩公布后,妈妈连夜写了一封信:孩子,这就是我不让你玩手机的原因第二篇 视频新闻结尾 引言 今天睡眠质量不错 发现一个问题…

10分钟微调专属于自己的大模型_10分钟微调大模型

1.环境安装 # 设置pip全局镜像 (加速下载) pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/ # 安装ms-swift pip install ms-swift[llm] -U# 环境对齐 (通常不需要运行. 如果你运行错误, 可以跑下面的代码, 仓库使用最新环境测试) pip install -r r…

面试专区|【88道Vue高频题整理(附答案背诵版)】

1、请简述Vue插件和组件的区别 ? Vue的插件(Plugin)和组件(Component)是Vue.js中非常重要的两个概念,它们在功能上有着明显的差异。 Vue组件(Component): Vue组件是Vue…

vb6多线程异步,VB.NET 全用API实现:CreateThread创建多线程,等待线程完成任务

在VB.NET中,你可以使用API函数来创建多线程并等待线程完成任务。以下是一个示例代码,展示如何使用API函数来实现这个功能: Imports System.Runtime.InteropServices Imports System.ThreadingPublic Class Form1Private Delegate Sub ThreadC…

大模型+多模态合规分析平台,筑牢金融服务安全屏障

随着金融市场的快速发展,金融产品和服务日趋多样化,消费者面临的风险也逐渐增加。 为保护消费者权益,促进金融市场长期健康稳定发展,国家监管机构不断加强金融监管,出台了一系列法律法规和政策文件。对于金融从业机构…

【DC-DC升压电推剪方案】FP6277,FP6296电源升压芯片在电推剪中扮演着一个怎样的角色?带你深入了解电推剪的功能和应用及工作原理

随着人们对个人形象要求的不断提高,理发器作为一个必备的家居用品,也在不断进行技术升级。而其中的核心装备之一,电推剪理发器升压芯片FP6277、FP6296,正在引领着现代理发技术的突破。本文将给大家带来的是电推剪在传统意义上运用…

keil仿真,查看函数执行时间和执行次数

Execution Profiler执行档案器 The Execution Profiler records timing and execution statistics about instructions for the complete program code. To view the values in the Editor or Disassembly Window, use Show Time or Show Calls from the menu Debug — Executi…

[vue]打包后发布相关问题总结

问题1: 1. history模式下,发布到nginx后,刷新页面或者地址栏回车404问题。 打开nginx的nginx.conf文件,添加注释【2】对应行的代码,可以解决这个问题。 events {worker_connections 1024; }http {include mime.typ…

6.18-6.26 旧c语言

第一章 概述 32关键字 9种控制语句 优点:能直接访问物理地址,位操作,代码质量高,执行效率高 可移植性好 面向过程:以事件为中心 面向对象:以实物为中心 printf:系统定义的标准函数 #include&l…

Java并发编程中的线程局部变量ThreadLocal variables的关键点是什么?

ThreadLocal变量的关键点 ThreadLocal变量是Java中用于实现线程局部存储的一种机制,它允许每个线程拥有自己的变量副本,从而避免了多线程环境下变量共享导致的并发问题。以下是ThreadLocal变量的几个关键点: 线程隔离:ThreadLoc…

探索 JQuery EasyUI:构建简单易用的前端页面

介绍 当我们站在网页开发的浩瀚世界中,眼花缭乱的选择让我们难以抉择。而就在这纷繁复杂的技术海洋中,JQuery EasyUI 如一位指路明灯,为我们提供了一条清晰的航线。 1.1 什么是 JQuery EasyUI? JQuery EasyUI,简单来…

DM达梦数据库转换、条件函数整理

💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 💝&#x1f49…

分享AI学习笔记之Python

当你说"抓取网站数据"时,通常指的是网络爬虫(web scraping)或网络抓取(web crawling)。Python提供了很多库可以帮助你实现这个功能,其中最常见的有requests(用于发送HTTP请求&#xf…

【乐吾乐2D可视化组态编辑器】画布

5.1 设置画布属性 默认颜色:预先设置默认颜色,拖拽到画布的节点(基础图形、文字、icon)自动统一默认颜色。 画笔填充颜色:预先设置画笔填充颜色,拖拽到画布的节点(基础图形)自动统…

QT自定义信号和槽函数

在QT中最重要也是必须要掌握的机制,就是信号与槽机制,在MFC上也就是类型的机制就是消息与响应函数机制 在QT中我们不仅要学会如何使用信号与槽机制,还要会自定义信号与槽函数,要自定义的原因是系统提供的信号,在一些情…

免费录制视频软件推荐,这3款软件超实用!

随着网络技术的发展,录制视频已经成为人们日常生活中的一个重要需求。无论是教学、会议、游戏还是娱乐,视频录制都为我们提供了极大的便利。然而,市场上的视频录制软件琳琅满目,如何选择一款适合自己的免费录制视频软件成为了一个…

Java基础知识-Map、HashMap、HashTable和TreeMap

1、HashMap 和 Hashtable 的区别? HashMap 和 Hashtable是Map接口的实现类,它们大体有一下几个区别: 1. 继承的父类不同。HashMap是继承自AbstractMap类,而HashTable是继承自Dictionary类。 2. 线程安全性不同。Hashtable 中的方…

Bootstrap 5 Flex

Bootstrap 5 Flex 简介 Bootstrap 5 是一个流行的前端框架,用于快速开发响应式和移动设备优先的网页。Flexbox 是 Bootstrap 5 中用于布局的强大工具,它提供了一种更加灵活和高效的方式来对齐和分布容器内的元素。在本篇文章中,我们将深入探讨 Bootstrap 5 中的 Flexbox 功…