深度学习损失计算

文章目录

  • 深度学习损失计算
    • 1.如何计算当前epoch的损失?
    • 2.为什么要计算样本平均损失,而不是计算批次平均损失?

深度学习损失计算

1.如何计算当前epoch的损失?

深度学习中的损失计算,通常为数据集的平均损失,即每个样本的平均损失值。计算步骤如下:

  • 计算单个批次的损失。每次迭代中,用当前模型预测值和真实值计算损失。假设 _loss 是这次迭代中计算得到的损失。
  • 转换为标量。利用item()方法将其转换为标量值。_loss.item()
  • 乘以批次大小。乘以批次大小的原因是,希望总损失是所有数据点的损失总和,而不是批次平均损失。
  • 累加损失loss += _loss.item() * batch_size 将当前批次的总损失累加到变量 loss 中。这样所有批次遍历结束后,就得到一个epoch的总损失。
  • 计算当前epoch的样本平均损失。通过总损失除以总的数据样本数,来得到平均损失。average_loss = loss/len(dataloader.dataset)【注意:除的是总的数据样本数(len(dataloader.dataset))!不是总的批次数(len(dataloader))!】

示例代码如下:

for epoch in total_epoch:  # epoch迭代total_loss = 0.0  # 初始化总损失for inputs, targets in dataloader:  # batch迭代outputs = model(inputs)  # 获取预测值_loss = criterion(outputs, targets)  # 计算当前批次损失,为批次平均损失batch_size = inputs.size(0)  # 获取批次大小total_loss += _loss.item() * batch_size  # 计算当前批次的总损失# 计算当前epoch的平均损失average_loss = total_loss / len(dataloader.dataset)  

2.为什么要计算样本平均损失,而不是计算批次平均损失?

由于每个批次的大小可能不一样,特别是在数据集的大小不是批次大小的整数倍时,所以使用 len(dataloader) 会导致错误的平均损失计算。

下面用一个简单的例子,解释这两种计算方式的不同:

假设数据集有 105 个样本,每个批次大小为 10,这样会有 11 个批次,其中最后一个批次只有 5 个样本。结合上面的伪代码,假设损失值 _loss.item() 是 1,对于 10 个批次的损失是 10,最后一个批次的损失是 5。那么:

  • t o t a l _ l o s s = ( 1 ∗ 10 ) ∗ 10 + ( 1 ∗ 5 ) ∗ 1 = 105 total\_loss = (1 * 10) * 10 + (1 * 5) * 1 = 105 total_loss=(110)10+(15)1=105
  • l e n ( d a t a l o a d e r . d a t a s e t ) = 105 len(dataloader.dataset) = 105 len(dataloader.dataset)=105
  • l e n ( d a t a l o a d e r ) = 11 len(dataloader) = 11 len(dataloader)=11

计算结果:

  • 样本平均损失计算:average_loss = total_loss / len(dataloader.dataset) 105 / 105 = 1 105/105 = 1 105/105=1
  • 批次平均损失计算:average_loss = total_loss / len(dataloader) 105 / 11 ≈ 9.545 105/11 \approx 9.545 105/119.545

显然,第一种方式是正确的,反映了每个样本的真实平均损失。

😃😃😃

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/872540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CREC晶振产品分类

CREC晶振大类有石英晶体谐振器、石英晶体振荡器、石英晶体滤波器 其中石英晶体谐振器: KHZ石英谐振器 车规级32.768KHz石英谐振器 专为汽车RTC应用而设计,通过AECQ-200可靠性测试,满足汽车电子的高标准时频需求,为客户提供可靠…

完整的优化流程需要做什么工作

👽System.out.println(“👋🏼嗨,大家好,我是代码不会敲的小符,目前工作于上海某电商服务公司…”); 📚System.out.println(“🎈如果文章中有错误的地方,恳请大家指正&…

三生随记——空调的诅咒

在一个炎热的夏日,小镇上的居民们都在忍受着高温的煎熬。阳光无情地炙烤着大地,空气仿佛凝固了一般,让人喘不过气来。 杰克和艾米是一对年轻的夫妻,他们刚刚搬进了这座小镇边缘的一座古老房子。这座房子虽然宽敞,但却透…

前后端,数据库以及分布式系统

1. 前端(Frontend) 定义: 前端是用户直接与之交互的部分,通常在浏览器中运行。它负责呈现和展示数据,与用户进行交互。 关键点: HTML/CSS/JavaScript: HTML定义了页面结构,CSS负责…

Interface中的方法被default修饰

Interface中的方法被default修饰 在Java 8中,引入了default方法的概念,使得接口可以包含具体的方法实现。被default修饰的方法称为默认方法。这意味着接口不仅可以声明方法,还可以提供方法的默认实现。 默认方法的主要作用包括:…

工业网络通信教学平台-工业互联网综合教学的实验平台-工业互联网应用实训

工业互联网(Industrial Internet),也称为工业物联网或IIoT,是一个开放的、全球化的工业网络,将人、数据和机器进行连接,将工业、技术和互联网深度融合。 工业互联网产业发展离不开信息技术产业人才&#xf…

Elasticsearch 聚合查询简介

Elasticsearch 聚合查询简介 在 Elasticsearch 中,聚合(Aggregations)是一种强大的数据分析工具,可以让你从数据中提取有意义的信息。通过聚合查询,可以对数据进行分类、统计、过滤和分组等操作,从而帮助用…

Android TEE SE

在Android平台上,Trusted Execution Environment (TEE) 和 Secure Element (SE) 是用来增强设备安全性的关键技术。TEE提供了隔离的执行环境,可以执行敏感的安全操作,而SE则是一个独立的、高度安全的微控制器,用于存储和处理非常敏…

Log4j的原理及应用详解(五)

本系列文章简介: 在软件开发的广阔领域中,日志记录是一项至关重要的活动。它不仅帮助开发者追踪程序的执行流程,还在问题排查、性能监控以及用户行为分析等方面发挥着不可替代的作用。随着软件系统的日益复杂,对日志管理的需求也日…

Qt Creator的好用的功能

(1)ctrlf: 在当前文档进行查询操作 (2)f3: 找到后,按f3,查找下一个 (3)shiftf3: 查找上一个 右键菜单: (4)f4:在…

LabVIEW异步和同步通信详细分析及比较

1. 基本原理 异步通信: 原理:异步通信(Asynchronous Communication)是一种数据传输方式,其中数据发送和接收操作在独立的时间进行,不需要在特定时刻对齐。发送方在任何时刻可以发送数据,而接收…

GitHub+Picgo图片上传

Picgo下载,修改安装路径,其他一路下一步! 地址 注册GitHub,注册过程不详细展开,不会的百度一下 地址 新建GitHub仓库存放图片 ——————————————————————————————————————————…

第二十章 Nest 大文件分片上传

在前端的文件上传功能中,只要请求头里定义 content-type 为 multipart/form-data,内容就会以下面形式传递到服务端,接着服务器再按照multipart/form-data的格式去提取数据 获取文件数据但是当文件体积很大时 就会出现一个问题 文件越大 请求的…

08-8.6.2 败者树

👋 Hi, I’m Beast Cheng 👀 I’m interested in photography, hiking, landscape… 🌱 I’m currently learning python, javascript, kotlin… 📫 How to reach me --> 458290771qq.com 喜欢《数据结构》部分笔记的小伙伴可以…

QT使用QPainter绘制多边形维度图

多边形统计维度图是一种用于展示多个维度的数据的图表。它通过将各个维度表示为图表中的多边形的边,根据数据的大小和比例来确定各个维度的长度。 一、简述 本示例实现六边形战力统计维度图,一种将六个维度的战力统计以六边形图形展示的方法。六个维度是…

六、元组、字典、集合

文章目录 学习目标一、元组的使用二、字典的基本使用2.1 字典使用注意事项2.2 字典的增删改查2.3 update方法的使用2.4 合并多个dict2.5 字典的遍历2.6 字典的推导式三、集合的使用3.1 set的使用3.2 set的操作3.3 集合的高级用法四、eval与json五、可迭代对象通用方法5.1 运算符…

五、python列表

文章目录 学习目标一、列表的基本使用二、列表的遍历2.1 while循环遍历2.2 for...in 循环遍历三、列表的排序3.1 交换两个变量的值3.2 冒泡排序3.3 列表的排序与反转方法四、列表的复制4.1 可变数据类型与不可变数据类型4.2 列表的复制五、列表的嵌套(略)六、列表推导式学习目…

基于形状匹配原始版放出来(给有用的人参考2)

我们仍然讲学习。 昨天已经把80万像素1024*768的图像变成256*192图像了,并且使用iir高斯平滑保留了特征。 下面做的就是用roi把特征图扣出来,也就是所谓的模板,你在原图中的roi假定是200*200,那么在256*192中,就变成…

怎样在 PostgreSQL 中优化对复合索引的选择性?

🍅关注博主🎗️ 带你畅游技术世界,不错过每一次成长机会!📚领书:PostgreSQL 入门到精通.pdf 文章目录 怎样在 PostgreSQL 中优化对复合索引的选择性一、理解复合索引的概念二、选择性的重要性三、优化复合索…

ollama编译安装@focal jammy Ubuntu @FreeBSD jail

Ollama是一个用于在本地运行大型语言模型(LLM)的开源框架。它支持多种操作系统,但是唯独不支持FreeBSD,于是尝试在FreeBSD的jail里安装Ubuntu,Ubuntu里再安装ollama。 先上结论,好像focal jail无法编译成功…