PyTorch nn.MSELoss() 均方误差损失函数详解和要点提醒

文章目录

  • nn.MSELoss() 均方误差损失函数
    • 参数
    • 数学公式
      • 元素版本
    • 要点
    • 附录
  • 参考链接

nn.MSELoss() 均方误差损失函数

torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input x x x and target y y y.

计算输入和目标之间每个元素的均方误差(平方 L2 范数)。

参数

  • size_average (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失在批次中的每个损失元素上取平均(True);否则(False),在每个小批次中对损失求和。
    • reduceFalse 时忽略该参数。
    • 默认值是 True
  • reduce (bool, 可选):
    • 已弃用。请参阅 reduction 参数。
    • 默认情况下,损失根据 size_average 参数进行平均或求和。
    • reduceFalse 时,返回每个批次元素的损失,并忽略 size_average 参数。
    • 默认值是 True
  • reduction (str, 可选):
    • 指定应用于输出的归约方式。
    • 可选值为 'none''mean''sum'
      • 'none':不进行归约。
      • 'mean':输出的和除以输出的元素总数。
      • 'sum':输出的元素求和。
    • 注意:size_averagereduce 参数正在被弃用,同时指定这些参数中的任何一个都会覆盖 reduction 参数。
    • 默认值是 'mean'

数学公式

附录部分会验证下述公式和代码的一致性。

假设有 N N N 个样本,每个样本的输入为 x n x_n xn,目标为 y n y_n yn。均方误差损失的计算步骤如下:

  1. 单个样本的损失
    计算每个样本的均方误差:
    l n = ( x n − y n ) 2 l_n = (x_n - y_n)^2 ln=(xnyn)2
    其中 l n l_n ln 是第 n n n 个样本的损失。
  2. 总损失
    计算所有样本的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ n = 1 N l n = 1 N ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \frac{1}{N} \sum_{n=1}^{N} l_n = \frac{1}{N} \sum_{n=1}^{N} (x_n - y_n)^2 L=N1n=1Nln=N1n=1N(xnyn)2
    如果 reduction 参数为 'sum',总损失为所有样本损失的和:
    L = ∑ n = 1 N l n = ∑ n = 1 N ( x n − y n ) 2 \mathcal{L} = \sum_{n=1}^{N} l_n = \sum_{n=1}^{N} (x_n - y_n)^2 L=n=1Nln=n=1N(xnyn)2
    如果 reduction 参数为 'none',则返回每个样本的损失 l n l_n ln 组成的张量:
    L = [ l 1 , l 2 , … , l N ] = [ ( x 1 − y 1 ) 2 , ( x 2 − y 2 ) 2 , … , ( x N − y N ) 2 ] \mathcal{L} = [l_1, l_2, \ldots, l_N] = [(x_1 - y_1)^2, (x_2 - y_2)^2, \ldots, (x_N - y_N)^2] L=[l1,l2,,lN]=[(x1y1)2,(x2y2)2,,(xNyN)2]

元素版本

假设输入张量 x \mathbf{x} x 和目标张量 y \mathbf{y} y 具有相同的形状,每个张量包含 N N N 个元素。均方误差损失的计算步骤如下:

  1. 单个元素的损失
    计算每个元素的均方误差:
    l i j = ( x i j − y i j ) 2 l_{ij} = (x_{ij} - y_{ij})^2 lij=(xijyij)2
    其中 l i j l_{ij} lij 是输入张量和目标张量在位置 ( i , j ) (i, j) (i,j) 的元素损失。
  2. 总损失
    计算所有元素的平均损失(reduction 参数默认为 'mean'):
    L = 1 N ∑ i , j l i j = 1 N ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \frac{1}{N} \sum_{i,j} l_{ij} = \frac{1}{N} \sum_{i,j} (x_{ij} - y_{ij})^2 L=N1i,jlij=N1i,j(xijyij)2
    如果 reduction 参数为 'sum',总损失为所有元素损失的和:
    L = ∑ i , j l i j = ∑ i , j ( x i j − y i j ) 2 \mathcal{L} = \sum_{i,j} l_{ij} = \sum_{i,j} (x_{ij} - y_{ij})^2 L=i,jlij=i,j(xijyij)2
    如果 reduction 参数为 'none',则返回每个元素的损失 l i j l_{ij} lij 组成的张量:
    L = { l i j } = { ( x i j − y i j ) 2 } \mathcal{L} = \{l_{ij}\} = \{(x_{ij} - y_{ij})^2 \} L={lij}={(xijyij)2}

要点

  1. nn.MSELoss() 接受的输入和目标应具有相同的形状和类型。
    使用示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失
    criterion = nn.MSELoss()
    loss = criterion(input, target)print(f"Loss using nn.MSELoss: {loss.item()}")
    
    >>> Loss using nn.MSELoss: 0.25
    
  2. nn.MSELoss()reduction 参数指定了如何归约输出损失。默认值是 'mean',计算的是所有样本的平均损失。
    • 如果 reduction 参数为 'mean',损失是所有样本损失的平均值。
    • 如果 reduction 参数为 'sum',损失是所有样本损失的和。
    • 如果 reduction 参数为 'none',则返回每个样本的损失组成的张量。
      代码示例
    import torch
    import torch.nn as nn# 定义输入和目标张量
    input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 使用 nn.MSELoss 计算损失(reduction='mean')
    criterion_mean = nn.MSELoss(reduction='mean')
    loss_mean = criterion_mean(input, target)
    print(f"Loss with reduction='mean': {loss_mean.item()}")# 使用 nn.MSELoss 计算损失(reduction='sum')
    criterion_sum = nn.MSELoss(reduction='sum')
    loss_sum = criterion_sum(input, target)
    print(f"Loss with reduction='sum': {loss_sum.item()}")# 使用 nn.MSELoss 计算损失(reduction='none')
    criterion_none = nn.MSELoss(reduction='none')
    loss_none = criterion_none(input, target)
    print(f"Loss with reduction='none': {loss_none}")
    
    >>> Loss with reduction='mean': 0.25
    >>> Loss with reduction='sum': 1.0
    >>> Loss with reduction='none': tensor([[0.2500, 0.2500],[0.2500, 0.2500]], grad_fn=<MseLossBackward0>)
    

附录

用于验证数学公式和函数实际运行的一致性

import torch
import torch.nn.functional as F# 假设有两个样本,每个样本有两个维度
input = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.tensor([[1.5, 2.5], [3.5, 4.5]])# 根据公式实现均方误差损失
def mse_loss(input, target):return ((input - target) ** 2).mean()# 使用 nn.MSELoss 计算损失
criterion = torch.nn.MSELoss(reduction='mean')
loss_torch = criterion(input, target)# 使用根据公式实现的均方误差损失
loss_custom = mse_loss(input, target)# 打印结果
print("PyTorch 计算的均方误差损失:", loss_torch.item())
print("根据公式实现的均方误差损失:", loss_custom.item())# 验证结果是否相等
assert torch.isclose(loss_torch, loss_custom), "数学公式验证失败"
>>> PyTorch 计算的均方误差损失: 0.25
>>> 根据公式实现的均方误差损失: 0.25

输出没有抛出 AssertionError,验证通过。

参考链接

MSELoss - Docs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/861673.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python25 Numpy基础

1.什么是Numpy NumPy&#xff08;Numerical Python 的简称&#xff09;是 Python 语言的一个扩展程序库&#xff0c;支持大量的维度数组与矩阵运算&#xff0c;此外也针对数组运算提供大量的数学函数库。NumPy 的前身是 Numeric&#xff0c;这是一个由 Jim Hugunin 等人开发的…

SAP ALV 负号提前

FUNCTION CONVERSION_EXIT_ZSIGN_OUTPUT. *"---------------------------------------------------------------------- *"*"本地接口&#xff1a; *" IMPORTING *" REFERENCE(INPUT) *" EXPORTING *" REFERENCE(OUTPUT) *"…

PNAS|这样也可以?拿别人数据发自己Paper?速围观!

还在为数据量小&#xff0c;说服力不足发愁&#xff1f; 想研究脱颖而出、眼前一亮&#xff1f; 想从更高层次的探索微生物的奥秘&#xff0c;发出一篇好文章&#xff1f; 近期&#xff0c;有一篇发表在PNAS(IF11.1)的文章“Deforestation impacts soil biodiversity and ecos…

量子计算与AI融合:IBM引领未来计算新纪元

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

docker-本地部署-后端

前置条件 后端文件 这边是一个简单项目的后端文件目录 docker服务 镜像文件打包 #命令行 docker build -t author/chatgpt-ai-app:1.0 -f ./Dockerfile .红框是docker所在文件夹 author&#xff1a;docker用户名chatgpt-ai-app&#xff1a;打包的镜像文件名字:1.0 &#…

YOLOv10改进 | 卷积模块 | 将Conv替换为轻量化的GSConv【轻量又涨点】

秋招面试专栏推荐 &#xff1a;深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转 &#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 专栏目录&#xff1a;《YOLOv8改进有效…

Spring Boot中如何集成ElasticSearch进行全文搜索

Spring Boot中如何集成ElasticSearch进行全文搜索 大家好&#xff0c;我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天&#xff0c;我们将探讨如何在Spring Boot应用中集成Elas…

技术流 | 运维平台大型“生产事故”录播和实战重现

【本文作者&#xff1a;擎创科技 资深专家禹鼎侯】 本文写于2021年&#xff0c;最近重读觉得特别有现场感。这也是运维人面对生产环境时遇到的各种惊心动魄的事件之一。惊险&#xff0c;但又顺利解决。是最好的结果。 事情是酱紫的。 那天上午&#xff0c;轻轻松松完成了一个新…

昇思MindSpore基本介绍

昇思MindSpore是一个全场景深度学习框架&#xff0c;旨在实现易开发、高效执行、全场景统一部署三大目标。 其中&#xff0c;易开发表现为API友好、调试难度低&#xff1b;高效执行包括计算效率、数据预处理效率和分布式训练效率&#xff1b;全场景则指框架同时支持云、边缘以…

C语言之进程学习

进程打开的文件列表&#xff1a;就是0 1 2 stdin stdout stderro等 类似于任务管理器是动态分ps是静态的 Zombie状态&#xff1a; 在Linux进程的状态中&#xff0c;僵尸进程是非常特殊的一种&#xff0c;它是已经结束了的进程&#xff0c;但是没有从进程表中删除。太多了会导…

轻量级仿 SpringBoot 程序

但凡 Java 程序&#xff0c;想必就是 Spring 程序&#xff1b;但凡 Spring 程序&#xff0c;想必就是 SpringBoot 程序——且慢&#xff0c;当今尚有不是 SpringBoot 即 SpringMVC 的程序不&#xff1f;有——老旧的遗留系统不就是嘛~——不&#xff0c;其实只要稍加“调教”&a…

TikTok网页版使用指南:如何登录TikTok网页版?

海外版抖音TikTok&#xff0c;已成为连接全球观众的重要平台。据统计&#xff0c;在美国&#xff0c;TikTok的用户数量已达到近1.3亿&#xff0c;并且在国外的95后用户群体中很受欢迎。 TikTok网页版也提供了一个广阔的平台&#xff0c;让品牌和创作者在电脑端与全球观众互动&…

智能语音抽油烟机:置入WTK6900L离线语音识别芯片 掌控厨房新风尚

一、抽油烟机语音识别芯片开发背景 在繁忙的现代生活中&#xff0c;人们对于家居生活的便捷性和舒适性要求越来越高。传统的抽油烟机操作方式往往需要用户手动调节风速、开关等功能&#xff0c;不仅操作繁琐&#xff0c;而且在烹饪过程中容易分散注意力&#xff0c;增加安全隐…

单点登录方法

一、父域cookie:两个有相同父域名的二级域名之间可以跨域传递cookie //注意该接口的地址也是baidu.com下属的二级域名:a.baidu.com //全部接口地址为:a.baidu.com/dev-api/system/ecdWeb/login。如果不是a.baidu.com那么根本带不过去 //其实可以理解为通过该方法将cookie传给…

获取股票列表关键信息

获取股票列表的关键信息通常包括以下几个方面: 1. **股票代码**:股票的唯一标识符,通常由字母和数字组成,如"AAPL"代表苹果公司的股票。 2. **公司名称**:股票所代表的公司全称。 3. **行业板块**:股票所属的行业领域,如科技、金融、医疗等。 4. **市场类…

大数据处理引擎选型之 Hadoop vs Spark vs Flink

随着大数据时代的到来&#xff0c;处理海量数据成为了各个领域的关键挑战之一。为了应对这一挑战&#xff0c;多个大数据处理框架被开发出来&#xff0c;其中最知名的包括Hadoop、Spark和Flink。本文将对这三个大数据处理框架进行比较&#xff0c;以及在不同场景下的选择考虑。…

Linux内存管理(七十三):cgroup v2 简介

版本基于: Linux-6.6 约定: 芯片架构:ARM64内存架构:UMACONFIG_ARM64_VA_BITS:39CONFIG_ARM64_PAGE_SHIFT:12CONFIG_PGTABLE_LEVELS :31. cgroup 简介 术语: cgroup:control group 的缩写,永不大写(never capitalized); 单数形式的 cgroup 用于指定整个特性,也用…

ubuntu篇---添加环境变量并且在pycharm中使用

ubuntu篇—添加环境变量并且在pycharm中使用 一. 添加环境变量 vim ~/.bashrc 在文件末尾加上 保存退出 source ~/.bashrc二. 在pycharm中添加环境变量 1.打开pycharm&#xff0c;并打开你的项目 2.点击菜单栏中的“Run”&#xff0c; 选择“Edit Configurations” 3.在弹…

pytorch为自己的extension backend添加profiler功能

pytorch为自己的extension backend添加profiler功能 1.参考文档2.your-extension-for-pytorch需要增加的代码3.pytorch demo及如何调整chrome trace json文件4.[可视化](https://ui.perfetto.dev/) 本文演示了pytorch如何为自己的extension backend添加profiler功能 背景介绍 …

Taro +vue3 中的微信小程序中的分享

微信小程序 右上角分享 的触发 以及配 useShareAppMessage(() > {return {title: "电影属全国通兑券",page: /pages/home/index,imageUrl: "http:///chuanshuo.jpg",};}); 置 就是Taro框架中提供的一个分享Api 封装好的