昇思25天学习打卡营第08天 | 模型训练

昇思25天学习打卡营第08天 | 模型训练

文章目录

  • 昇思25天学习打卡营第08天 | 模型训练
    • 超参数
    • 损失函数
    • 优化器
      • 优化过程
    • 训练与评估
    • 总结
    • 打卡

模型训练一般遵循四个步骤:

  1. 构建数据集
  2. 定义神经网络模型
  3. 定义超参数、损失函数和优化器
  4. 输入数据集进行训练和评估

构建数据集和网络模型在之前的内容在已经涉及,不再赘述。

超参数

超参数(Hyperparameters)是可以调整的参数,可以控制模型训练的过程。

深度学习模型多采用随机梯度下降算法SGD进行优化:
w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_t- \eta\frac1n\sum_{x\in B}\nabla l(x,w_t) wt+1=wtηn1xBl(x,wt)
其中, η \eta η是学习率, n n n是batch大小,都是超参数,这两个参数是直接影响模型性能收敛的重要参数。
一般会定义三个超参数:

  • epoch:遍历数据集的次数
  • batch size:每个批次数据的大小。size 过小导致花费时间多,梯度震荡严重,不利于收敛;size 过大容易陷入局部极小值。
  • learning rate:学习率国小会导致收敛速度变慢;过大则可能会导致训练不收敛。

损失函数

损失函数用于评估模型预测值和目标值之间的误差。
常见的损失函数包括:

  • nn.MSELoss:均方误差,用于回归
  • nn.NLLLoss:负对数似然,用于分类
  • nn.CrossEntropyLoss:结合了nn.LogSoftmaxnn.NLLLoss,可以对logits进行归一化并计算预测误差
loss_fn = nn.CrossEntropyLoss()

优化器

优化器内部定义了模型参数的优化过程,所有的优化逻辑都封装在优化器对象中。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

优化过程

通过自动微分获得的微分函数,计算参数对应的梯度,并传入优化器中,即可实现参数优化。

grads = grad_fn(inputs)
optimizer(grads)

训练与评估

遍历一次数据集被称为一轮(epoch),每轮执行训练时包含两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,检查模型性能是否提升。
# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练过程一般为:

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)

总结

这一节的内容对深度学习模型训练的一般过程进行了详细的介绍,从数据集构建到模型定义,接着定义超参数并选择合适的值,创建损失函数和优化器对象完成训练前的准备。通过封装一个模型调用和loss计算的前向计算函数并自动微分,在每个epoch中计算loss并优化参数,从而完成模型的训练。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/42166.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 去掉小数末尾多余的0

文章目录 在uniapp或者一般的JavaScript环境中,要去掉小数末尾的0,可以使用以下几种方法: 使用parseFloat()函数 let num 123.4500; let result parseFloat(num); console.log(result); // 输出: 123.45字符串处理 将数字转换为字符串&am…

24年沈阳教师招聘报名全流程(电脑版)

公开招聘1876名事业编制工作人员 报名时间:7月8日10:00至7月11日16:00; 报名网站::沈阳市考试院 报名流程: 1.报名 ①登录沈阳市考试院填写报名信息,上传照片 本人近期免冠2寸正面电子证jian照片&#xff0…

接口依赖-动态参数+数据依赖的代码

怎么写测试用例 接口名:被依赖的返回值的jsonpath表达式,有几个依赖往后写就可以 代码 如果caseinfo里的getisdep等于yes,并且,depkey不为空 用正则解析 在testrun里增加判断 判断的内容是正则表达式,先去匹配,再去…

2024年世界人工智能大会(WAIC)各大佬的精彩发言

2024年世界人工智能大会(WAIC)在上海举行,受到了广泛关注和参与。以下是大会首日的主要观点和议题的总结: AI 应用落地:大会讨论了AI应用如何落地,即如何在当前阶段利用大模型技术实现实际应用。 AI 安全&…

flask缓存、信号的使用

【 一 】flask-ache ​ 它为 Flask 应用程序提供了缓存支持。缓存是 Web 应用程序中非常常见的做法,用于存储频繁访问但不太可能经常更改的数据,以减少对数据库或其他慢速存储系统的访问,从而提高应用程序的性能和响应速度。 ​ Flask-Cach…

041基于SSM+Jsp的高校校园点餐系统

开发语言:Java框架:ssm技术:JSPJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包…

昇思MindSpore学习笔记4-05生成式--Pix2Pix实现图像转换

摘要: 记录昇思MindSpore AI框架使用Pix2Pix模型生成图像、判断图像真实概率的原理和实际使用方法、步骤。包括环境准备、下载数据集、数据加载和处理、创建cGAN神经网络生成器和判别器、模型训练、模型推理等。 一、概念 1.Pix2Pix模型 条件生成对抗网络&#x…

科研绘图系列:R语言两组数据散点分布图(scatter plot)

介绍 展示两组数据的散点分布图是一种图形化表示方法,用于显示两个变量之间的关系。在散点图中,每个点代表一个数据点,其x坐标对应于第一组数据的值,y坐标对应于第二组数据的值。以下是散点图可以展示的一些结果: 线性关系:如果两组数据之间存在线性关系,散点图将显示出…

HUAWEI VRRP 实验

实验要求:在汇聚交换机上SW1和SW2中实施VRRP以保证终端网关的高可靠性(当某一个网关设备失效时,其他网关设备依旧可以实现业务数据的转发。) 1.在SW1和SW2之间配置链路聚合,以提高带宽速度。 2.PC1 访问远端网络8.8.8.8 ,优先走…

视频汇聚/安防监控/GB28181国标EasyCVR视频综合管理平台出现串流的原因排查及解决

安防视频监控系统/视频汇聚EasyCVR视频综合管理平台,采用了开放式的网络结构,能在复杂的网络环境中(专网、局域网、广域网、VPN、公网等)将前端海量的设备进行统一集中接入与视频汇聚管理,视频汇聚EasyCVR平台支持设备…

Git-Unity项目版本管理

目录 准备GitHub新建项目并添加ssh密钥Unity文件夹 本文记录如何用git对unity 项目进行版本管理,并可传至GitHub远端。 准备 名称版本windows11Unity2202.3.9.f1gitN.A.githubN.A. GitHub新建项目并添加ssh密钥 GitHub新建一个repositorywindows11 生成ssh-key&…

【python】python母婴数据分析模型预测可视化(数据集+论文+PPT+源码)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉公众号👈:测试开发自动化【获取源码商业合作】 👉荣__誉👈:阿里云博客专家博主、5…

(阿里云在线播放)基于SpringBoot+Vue前后端分离的在线教育平台项目

💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

阿里通义音频生成大模型 FunAudioLLM 开源!

01 导读 人类对自身的研究和模仿由来已久,在我国2000多年前的《列子汤问》里就描述了有能工巧匠制作出会说话会舞动的类人机器人的故事。声音包含丰富的个体特征及情感情绪信息,对话作为人类最常使用亲切自然的交互模式,是连接人与智能世界…

uniapp报错--app.json: 在项目根目录未找到 app.json

【问题】 刚创建好的uni-app项目,运行微信小程序控制台报错如下: 【解决方案】 1. 程序根目录打开project.config.json文件 2. 配置miniprogramRoot,指定小程序代码的根目录 我的小程序代码编译后的工程文件目录为:dist/dev/mp…

Java | Leetcode Java题解之第220题存在重复元素III

题目&#xff1a; 题解&#xff1a; class Solution {public boolean containsNearbyAlmostDuplicate(int[] nums, int k, int t) {int n nums.length;Map<Long, Long> map new HashMap<Long, Long>();long w (long) t 1;for (int i 0; i < n; i) {long i…

CANoe的capl调用Qt制作的dll

闲谈 因为Qt封装了很多个人感觉很好用的库&#xff0c;所以一直想通过CAPL去调用Qt实现一些功能&#xff0c;但是一直没机会&#xff08;网络上也没看到这方面的教程&#xff09;&#xff0c;这次自己用了两天&#xff0c;踩了很多坑&#xff0c;终于是做成了一个初步的调用方…

文件系统技术架构分析

一文读懂&#xff1a;什么是文件系统 &#xff0c;有哪几类&#xff1f; ▉ 什么是文件系统&#xff1f; 技术大拿眉头皱了皱&#xff0c;忍住快要爆发的情绪。解释到&#xff1a; 数据以二进制形式存储于介质&#xff0c;但高低电平含义难解。文件系统揭秘这些二进制背后的意…

运维Tips | Ubuntu 24.04 安装配置 xrdp 远程桌面服务

[ 知识是人生的灯塔&#xff0c;只有不断学习&#xff0c;才能照亮前行的道路 ] Ubuntu 24.04 Desktop 安装配置 xrdp 远程桌面服务 描述&#xff1a;Xrdp是一个微软远程桌面协议&#xff08;RDP&#xff09;的开源实现&#xff0c;它允许我们通过图形界面控制远程系统。这里使…

前端面试题(CSS篇四)

一、CSS 优化、提高性能的方法有哪些&#xff1f; 加载性能&#xff1a; &#xff08;1&#xff09;css压缩&#xff1a;将写好的css进行打包压缩&#xff0c;可以减少很多的体积。 &#xff08;2&#xff09;css单一样式&#xff1a;当需要下边距和左边距的时候&#xff0c;很…