【昇思初学入门】第七天打卡-模型训练

训练模型

学习心得

  1. 构建数据集。这通常包括训练集、验证集(可选)和测试集。训练集用于训练模型,验证集用于调整超参数和监控过拟合,测试集用于评估模型的泛化能力。
    (mindspore提供数据集https://www.mindspore.cn/docs/zh-CN/r2.3.0rc2/api_python/mindspore.dataset.html)
  2. 定义神经网络模型。这通常涉及到选择适当的网络架构(如卷积神经网络CNN、循环神经网络RNN、全连接网络等)和激活函数。
    创建模型类:使用mindspore.nn.Cell作为基类,创建一个自定义的神经网络模型类。
    义网络层:定义所需的网络,如卷积层、全连接层、激活函数和池化层等
    实现construct方法:在construct方法中,使用定义好的网络层构建前向网络
  3. 定义超参、损失函数和优化器。
    设置超参数:设置超参数,如学习率、批次大小、训练轮数等。
    定义损失函数:选择适当的损失函数,如均方误差(MSE)用于回归问题,交叉熵损失(Cross-Entropy Loss)用于分类问题等。
    设置优化器:选择合适的优化器,如随机梯度下降(SGD)、Adam等,用于根据损失函数的梯度更新模型参数。
  4. 训练和评估。
    循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:
    训练:迭代训练数据集,并尝试收敛到最佳参数。
    验证/测试:迭代测试数据集,以检查模型性能是否提升。

笔记

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()epochs = 3
batch_size = 64
learning_rate = 1e-2loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

结果
训练结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/860555.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Python和NLTK进行NLP分析的高级指南

在本文中,将利用数据集来比较和分析自然语言。 本文涵盖的基本构建块是: WordNet和同义词集相似度比较树和树岸命名实体识别 WordNet和同义词集 WordNet是NLTK中的大型词汇数据库语料库。WordNet维护与名词,动词,形容词&#…

Unity 弧形图片位置和背景裁剪

目录 关键说明 Unity 设置如下 代码如下 生成和部分数值生成 角度转向量 计算背景范围 关键说明 效果图如下 来自红警ol游戏内的截图 思路:确定中心点为圆的中心点 然后 计算每个的弧度和距离 Unity 设置如下 没什么可以说的主要是背景图设置 代码如下 …

攻克PS之路——Day1(A1-A8)

#暑假到了,作为可能是最后一个快乐的暑假,我打算学点技能来傍身,首先,开始PS之旅 这个帖子作为我跟着B站up主学习PS的记录吧,希望我可以坚持下去! 学习的链接在这里:A02-PS软件安装&#xff0…

基于SSM+VUE的网上订餐系统(带1w+文档)

基于SSMVUE的网上订餐系统(带1w文档) 网上订餐系统的数据库里面存储的各种动态信息,也为上层管理人员作出重大决策提供了大量的事实依据。总之,网上订餐系统是一款可以真正提升管理者的办公效率的软件系统。 项目简介 基于SSMVUE的网上订餐系统(带1w文档…

亚马逊云科技官方活动:一个月拿下助理架构师SAA+云从业者考试认证(送半价折扣券)

为了帮助大家考取AWS SAA和AWS云从业者认证,小李哥争取到了大量考试半价50%折扣券,使用折扣券考试最多可省75刀(545元人民币)。 领取折扣券需要加入云师兄必过班群,在群中免费领取。目前必过班群招募到了超过200名小伙伴,名额有限…

从0到1使用vite搭建react项目保姆级教程(持续更新中)

一、vite创建react项目 要使用Vite创建一个React项目,你需要按照以下步骤操作: 1、确保你已经安装了Node.js(建议使用最新的稳定版本)。 2、 使用npm命令安装Vite CLI工具,再来创建项目 npm create vitelatest my-vi…

解决ChatGPT遇到“抱歉,我无法完成你的请求”问题

在使用ChatGPT时,可能会遇到这样的问题:当多次重复输入相同的内容时,系统会返回 抱歉,我无法完成你的请求 。本文将解释为什么会出现这种情况,并提供一些避免这种情况的解决方法。 为什么会出现“抱歉,我…

TSLANet:时间序列模型的新构思

实时了解业内动态,论文是最好的桥梁,专栏精选论文重点解读热点论文,围绕着行业实践和工程量产。若在某个环节出现卡点,可以回到大模型必备腔调或者LLM背后的基础模型重新阅读。而最新科技(Mamba,xLSTM,KAN)…

2024-6-20 Windows AndroidStudio SDK(首次加载)基础配置,SDK选项无法勾选,以及下载失败的一些解决方法

2024-6-20 Windows AndroidStudio SDK(首次加载)基础配置,SDK选项无法勾选,以及下载失败的一些解决方法 注意:仅仅是SDK这种刚安装时的配置的下载,不要和开源库的镜像源扯到一起!!!! 最近想玩AndroidStudio的JNI开发, 想着安装后…

Java三层框架的解析

引言:欢迎各位点击收看本篇博客,在历经很多的艰辛,我也是成功由小白浅浅进入了入门行列,也是收货到很多的知识,每次看黑马的JavaWeb课程视频,才使一个小菜鸡见识到了Java前后端是如何进行交互访问的&#x…

项目实训-vue(十二)

项目实训-vue(十二) 文章目录 项目实训-vue(十二)1.概述2.处理进度可视化 1.概述 本篇博客将记录我在图片上传页面中的工作。 2.处理进度可视化 除了导航栏之外,我们还需要对上传图片以及图片处理的过程以及流程进行…

数据结构-----【链表:刷题】

-------------------------------------------基础题参照leetcode---------------------------------------------------------------------------------------------------------- 【2】两数相加 /*** Definition for singly-linked list.* struct ListNode {* int val;…

浦语·灵笔2 模型部署图片理解实战

效果图镇楼 1、使用 huggingface_hub 下载模型中的部分文件(演示练习与模型实战无关) 使用 Hugging Face 官方提供的 huggingface-cli 命令行工具。安装依赖: pip install -U huggingface_hub 然后新建 python 文件,填入以下代码&#xf…

upload-labs第14关

upload-labs第14关 第十四关一、源代码分析代码审计 二、绕过分析a. 制作图片码首先需要一个照片,然后其次需要一个eval.php。 b.上传图片码上传成功 c.结合文件包含漏洞进行访问访问:http://192.168.1.110/upload-labs-master/include.php?filehttp://…

封装了一个iOS联动滚动效果

效果图 实现逻辑和原理 就是在 didEndDisplayingCell 方法中通过indexPathsForVisibleItems 接口获取当前可见的cell对应的indexPath, 然后获取到item最小的那一个,即可,同时,还要在 willDisplayCell 方法中直接设置标题的选中属…

cropperjs 裁剪/框选图片

1.效果 2.使用组件 <!-- 父级 --><Cropper ref"cropperRef" :imgUrl"url" searchImg"searchImg"></Cropper>3.封装组件 <template><el-dialog :title"title" :visible.sync"dialogVisible" wi…

Steam怎么卸载DLC Steam怎么只卸载DLC不卸载游戏教程

我们玩家在steam中玩游戏&#xff0c;有一个功能特别重要&#xff0c;那就是DLC&#xff0c;其实也就是一款游戏的扩展&#xff0c;很多游戏都有DLC&#xff0c;让游戏玩法特别丰富&#xff0c;比如都市天际线的DLC&#xff0c;给城市中就增加了很多建筑&#xff0c;或者更便捷…

web前端——CSS

目录 一、css概述 二、基本语法 1.行内样式表 2.内嵌样式表 3.外部样式表 4.三者对比 三、选择器 1.常用的选择器 2. 选择器优先级 3.由高到低优先级排序 四、文本,背景,列表,伪类,透明 1.文本 2.背景 3.列表 4.伪类 5.透明 五、块级,行级,行级块标签, dis…