昇思25天学习打卡营第08天|模型训练

模型训练

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

现在我们有了数据集和模型后,可以进行模型的训练与评估。

ps:这里的训练和Stable Diffusion中的训练是一样的吗?欢迎评论。

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。

损失函数

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差

优化器

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
import time# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)epochs = 3
batch_size = 64
learning_rate = 1e-2def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logitsloss_fn = nn.CrossEntropyLoss()optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'skywp')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/41585.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进程的概念

一.进程和程序的理解 首先抛出结论:进程是动态的,暂时存在于内存中,进程是程序的一次执行,而进程总是对应至少一个特定的程序。 程序是静态的,永久的存在于磁盘中。 程序是什么呢?程序其实就是存放在我们…

图像分类-数据驱动方法

K近邻算法(K-Nearest Neighbors,简称KNN) KNN算法通过比较新样本与训练集中的样本的距离,然后根据最近的K个邻居的投票结果来决定新样本的分类。 如图所示,K越大的边界会更加平滑,本质上是根据某一样本最近…

红薯小眼睛接口分析与Python脚本实现

文章目录 1. 写在前面2. 接口分析3. 算法脚本实现 【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Py…

递归(三)—— 初识暴力递归之“字符串的全部子序列”

题目1 : 打印一个字符串的全部子序列 题目分析: 解法1:非递归方法 我们通过一个实例来理解题意,假设字符串str “abc”,那么它的子序列都有那些呢?" ", “a”, “b”,…

Vue的民族民俗文化分享平台-计算机毕业设计源码22552

基于Vue的民族民俗文化分享平台设计与实现 摘 要 本文介绍了一种基于Vue.js前端框架和Express后端框架的民族民俗文化分享平台的设计和实现。该平台旨在通过线上方式,促进民族民俗文化的传播与分享,增强公众对多元文化的了解和认同。 平台为普通用户提供…

前后端的导入、导出、模板下载等写法

导入,导出、模板下载等的前后端写法 文章目录 导入,导出、模板下载等的前后端写法一、导入实现1.1 后端的导入1.2 前端的导入 二、基础的模板下载2.1 后端的模板下载-若依基础版本2.2 前端的模板下载2.3 后端的模板下载 - 基于资源文件读取2.4 excel制作…

24西安电子科技大学马克思主义学院—考研录取情况

01、马克思主义学院各个方向 02、24马克思主义学院近三年复试分数线对比 PS:马院24年院线相对于23年院线增加15分,反映了大家对于马克思主义理论学习与研究的热情高涨,也彰显了学院在人才培养、学科建设及学术研究等方面的不断进步与成就。 6…

直接更新flowable数据库的流程定义信息的一种方法

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 http://218.75.87.38:9666/ 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码: h…

文章解读与仿真程序复现思路——太阳能学报EI\CSCD\北大核心《绿电交易场景下计及温控负荷的高铁站两阶段调度策略》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

【第21章】MyBatis-Plus多数据源支持

文章目录 前言一、dynamic-datasource1. 特性2. 约定3. 使用方法3.1 引入依赖3.2 配置数据源3.3 使用 DS 切换数据源 二、mybatis-mate1.特性2.使用方法2.1 配置数据源2.2 使用 Sharding 切换数据源2.3 切换指定数据库节点 三、实战1. 引入库2. 配置3. 使用 DS 切换数据源4. 测…

vue项目打包部署后 浏览器自动清除缓存问题(解决方法)

vue打包部署后 浏览器缓存问题,导致控制台报错ChunkLoadError: Loading chunk failed的解决方案 一、报错如下: 每次build打包部署到服务器上时,偶尔会出现前端资源文件不能及时更新到最新,浏览器存在缓存问题,这时在…

Pandas数据可视化详解:大案例解析(第27天)

系列文章目录 Pandas数据可视化解决不显示中文和负号问题matplotlib数据可视化seaborn数据可视化pyecharts数据可视化优衣库数据分析案例 文章目录 系列文章目录前言1. Pandas数据可视化1.1 案例解析:代码实现 2. 解决不显示中文和负号问题3. matplotlib数据可视化…

ListBox自动滚动并限制显示条数

1、实现功能 限制ListBox显示的最大条数; ListBox自动滚动,显示最新行; 2、C#代码 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.IO; using Syst…

JUC并发编程基础(包含线程概念,状态等具体实现)

一.JUC并发编程基础 1. 并行与并发 1.1 并发: 是在同一实体上的多个事件是在一台处理器上"同时处理多个任务"同一时刻,其实是只有一个事件在发生. 即多个线程抢占同一个资源. 1.2 并行 是在不同实体上的多个事件是在多台处理器上同时处理多个任务同一时刻,大家…

【C++】main函数及返回值深度解析

一.main函数介绍 1.main函数怎么写 #include <iostream>int main() {// 程序的代码放在这里std::cout << "Hello, World!" << std::endl;return 0; }在这个例子中&#xff1a; #include <iostream> 是预处理指令&#xff0c;它告诉编译器…

在昇腾服务器上使用llama-factory对baichuan2-13b模型进行lora微调

什么是lora微调 LoRA 提出在预训练模型的参数矩阵上添加低秩分解矩阵来近似每层的参数更新&#xff0c;从而减少适配下游任务所需要训练的参数。 环境准备 这次使用到的微调框架是llama-factory。这个框架集成了对多种模型进行各种训练的代码&#xff0c;少量修改就可使用。 …

小红书矩阵系统源码:赋能内容创作与电商营销的创新工具

在内容驱动的电商时代&#xff0c;小红书凭借其独特的社区氛围和用户基础&#xff0c;成为品牌营销和个人创作者不可忽视的平台。小红书矩阵系统源码&#xff0c;作为支撑这一平台的核心技术&#xff0c;提供了一系列的功能和优势&#xff0c;助力用户在小红书生态中实现更高效…

Windows 安装hadoop 3.4

目录 安装 下载 设置环境变量 配置 修改&#xff1a;hadoop-env.cmd 修改&#xff1a;core-sit.xml 修改&#xff1a;hdfs-site.xml 修改&#xff1a;mapred-site.xml 修改&#xff1a;yarn-site.xml 运行 格式化HDFS文件系统 启动&#xff1a;hadoop 启动&#xf…

python-21-零基础自学python 写了一个彩票 发现买彩票中了真的是天选

学习内容&#xff1a;《python编程&#xff1a;从入门到实践》第二版 知识点&#xff1a; from random import choice、choice&#xff08;&#xff09;函数用法、while循环 练习内容&#xff1a; 练习9-14&#xff1a;彩票 创建一个列表或元组&#xff0c;其中包含10个数…

JAVA基础知识(上)

# 一、说说&和&&的区别? 作为运算符&#xff1a;& 将二进制的每一位进行与运算 作为逻辑运算符&#xff1a;两者都是与&#xff0c;&& 如果左边为假则终止右边运算&#xff0c;即短路运算。& 则需要把两边的比较执行完 # 二、int和Integer的区…