昇思25天学习打卡营第08天|模型训练

模型训练

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

现在我们有了数据集和模型后,可以进行模型的训练与评估。

ps:这里的训练和Stable Diffusion中的训练是一样的吗?欢迎评论。

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。

损失函数

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差

优化器

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
import time# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)epochs = 3
batch_size = 64
learning_rate = 1e-2def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logitsloss_fn = nn.CrossEntropyLoss()optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'skywp')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/41585.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解C#中的文件系统I/O操作

文件系统I/O操作是任何编程语言中的重要组成部分,C#也不例外。无论是读写文件、操作目录,还是处理文件流,C#都提供了丰富且强大的类库来实现这些功能。本文将详细介绍C#中的文件系统I/O操作,并通过代码示例展示如何高效地处理文件…

进程的概念

一.进程和程序的理解 首先抛出结论:进程是动态的,暂时存在于内存中,进程是程序的一次执行,而进程总是对应至少一个特定的程序。 程序是静态的,永久的存在于磁盘中。 程序是什么呢?程序其实就是存放在我们…

图像分类-数据驱动方法

K近邻算法(K-Nearest Neighbors,简称KNN) KNN算法通过比较新样本与训练集中的样本的距离,然后根据最近的K个邻居的投票结果来决定新样本的分类。 如图所示,K越大的边界会更加平滑,本质上是根据某一样本最近…

红薯小眼睛接口分析与Python脚本实现

文章目录 1. 写在前面2. 接口分析3. 算法脚本实现 【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Py…

04.27 - 05.18_111期_Linux_进程间通信

命名管道通信 特点,写端在没有往管道里面写内容时,读端会处于阻塞状态 共享内存 特点,读端在什么时候都可以进行读操作 ,拷贝次数少,通信次数快 makefile 中使用g进行编译 要实现将上述两个特点进行融合&#xff…

Microsoft Copilot Studio:定制 AI 解决方案的未来

微软最近为其生成式 AI 和大型语言模型工具套件添加了一项创新功能,即 Copilot Studio。这款新产品在 Microsoft Ignite 2023 大会上亮相,将彻底改变组织与 AI 助手的互动方式。 为每个用户提供定制能力 Copilot Studio 是一款出色的用户友好型平台&am…

【面试题】Reactor模型

Reactor模型 定义 Reactor模型是一种事件驱动的设计模式,用于处理服务请求。它通过将事件处理逻辑与事件分发机制解耦,实现高性能、可扩展的并发处理。Reactor模型适用于高并发、事件驱动的程序设计,如网络服务器等。 特点 事件驱动&#…

递归(三)—— 初识暴力递归之“字符串的全部子序列”

题目1 : 打印一个字符串的全部子序列 题目分析: 解法1:非递归方法 我们通过一个实例来理解题意,假设字符串str “abc”,那么它的子序列都有那些呢?" ", “a”, “b”,…

Vue的民族民俗文化分享平台-计算机毕业设计源码22552

基于Vue的民族民俗文化分享平台设计与实现 摘 要 本文介绍了一种基于Vue.js前端框架和Express后端框架的民族民俗文化分享平台的设计和实现。该平台旨在通过线上方式,促进民族民俗文化的传播与分享,增强公众对多元文化的了解和认同。 平台为普通用户提供…

图论·题解1

原题地址 P3371 【模板】单源最短路径(标准版) 注意的点: 边有重复,选择最小边!对于SPFA算法容易出现重大BUG,没有负权值的边时不要使用!!! 70分代码 朴素板dijsktra…

前后端的导入、导出、模板下载等写法

导入,导出、模板下载等的前后端写法 文章目录 导入,导出、模板下载等的前后端写法一、导入实现1.1 后端的导入1.2 前端的导入 二、基础的模板下载2.1 后端的模板下载-若依基础版本2.2 前端的模板下载2.3 后端的模板下载 - 基于资源文件读取2.4 excel制作…

JVM8为什么要增加元空间 ?

持久代 持久代的大小 为什么移除持久代 ? 元空间 元空间的特点: 持久代 持久代中包含了虚拟机中所有可通过反射获取到的数据,比如Class和Method对象。不同的Java虚拟机之间可能会进行类共享,因此持久代又分为只读区和读写区。…

24西安电子科技大学马克思主义学院—考研录取情况

01、马克思主义学院各个方向 02、24马克思主义学院近三年复试分数线对比 PS:马院24年院线相对于23年院线增加15分,反映了大家对于马克思主义理论学习与研究的热情高涨,也彰显了学院在人才培养、学科建设及学术研究等方面的不断进步与成就。 6…

直接更新flowable数据库的流程定义信息的一种方法

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 http://218.75.87.38:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码: h…

文章解读与仿真程序复现思路——太阳能学报EI\CSCD\北大核心《绿电交易场景下计及温控负荷的高铁站两阶段调度策略》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

【第21章】MyBatis-Plus多数据源支持

文章目录 前言一、dynamic-datasource1. 特性2. 约定3. 使用方法3.1 引入依赖3.2 配置数据源3.3 使用 DS 切换数据源 二、mybatis-mate1.特性2.使用方法2.1 配置数据源2.2 使用 Sharding 切换数据源2.3 切换指定数据库节点 三、实战1. 引入库2. 配置3. 使用 DS 切换数据源4. 测…

开关电源详解

一、开关电源的概述 1. 定义与简介 开关电源(Switched-Mode Power Supply,SMPS)是一种通过高频开关器件(如晶体管、MOSFET)进行电能转换的电源装置。与传统的线性电源相比,开关电源具有转换效率高、体积小…

vue项目打包部署后 浏览器自动清除缓存问题(解决方法)

vue打包部署后 浏览器缓存问题,导致控制台报错ChunkLoadError: Loading chunk failed的解决方案 一、报错如下: 每次build打包部署到服务器上时,偶尔会出现前端资源文件不能及时更新到最新,浏览器存在缓存问题,这时在…

Pandas数据可视化详解:大案例解析(第27天)

系列文章目录 Pandas数据可视化解决不显示中文和负号问题matplotlib数据可视化seaborn数据可视化pyecharts数据可视化优衣库数据分析案例 文章目录 系列文章目录前言1. Pandas数据可视化1.1 案例解析:代码实现 2. 解决不显示中文和负号问题3. matplotlib数据可视化…

ListBox自动滚动并限制显示条数

1、实现功能 限制ListBox显示的最大条数; ListBox自动滚动,显示最新行; 2、C#代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.IO; using Syst…