神经网络 torch.nn---损失函数与反向传播

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.nn — PyTorch 2.3 documentation

Loss Function的作用

  • 每次训练神经网络的时候都会有一个目标,也会有一个输出。目标和输出之间的误差,就是用Loss Function来衡量的。所以,误差Loss是越小越好的。

  • 此外,我们可以根据误差Loss,指导输出output接近目标target。即我们可以以target为依据,不断训练神经网络,优化神经网络中各个模块,从而优化output

Loss Function的作用

  1. 计算实际输出和目标之间的差距
  2. 为我们更新输出提供一定的依据,这个提供依据的过程也叫反向传播

nn.L1Loss

创建一个衡量输入x(模型预测输出)和目标y之间差的绝对值的平均值的标准。

class torch.nn.L1Loss(size_average=None, reduce=None, reduction='mean')

参数说明:

  • reduction:默认为 ‘mean’ ,可选meansum

  • reduction='mean'时,计算误差采用公式:

  • reduction='sum'时,计算误差采用公式:

需要注意的是,计算的数据必须为浮点数

程序代码:

import torch
from torch.nn import L1Lossinput=torch.tensor([1,2,3],dtype=torch.float32)
target=torch.tensor([1,2,5],dtype=torch.float32)input=torch.reshape(input,(1,1,1,3))
target=torch.reshape(target,(1,1,1,3))loss1=L1Loss()  #reduction='mean'
loss2=L1Loss(reduction='sum')  
result1=loss1(input,target)
result2=loss2(input,target)print(result1,result2)

输出:

nn.MSELoss

创建一个衡量输入x(模型预测输出)和目标y之间均方误差标准。

  • x 和 y 可以是任意形状,每个包含n个元素。

  • n个元素对应的差值的绝对值求和,得出来的结果除以n

  • 如果在创建MSELoss实例的时候在构造函数中传入size_average=False,那么求出来的平方和将不会除以n

class torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')

参数说明:

reduction:默认为 ‘mean’ ,可选meansum

  • reduction='mean'时,计算误差采用公式:

  • reduction='sum'时,计算误差采用公式:

程序代码:

import torch
from torch.nn import L1Loss,MSELossinput = torch.tensor([1,2,3],dtype=torch.float32)
target = torch.tensor([1,2,5],dtype=torch.float32)input = torch.reshape(input,(1,1,1,3))
target = torch.reshape(target,(1,1,1,3))loss_mse1 = MSELoss()  #reduction='mean'
loss_mse2 = MSELoss(reduction='sum')
result_mse1 = loss_mse1(input, target)
result_mse2 = loss_mse2(input, target)print(result_mse1, result_mse2)

输出:

nn.CrossEntropyLoss(交叉熵)

当训练一个分类问题的时候,假设这个分类问题有C个类别,那么有:

 当weight参数被指定的时候,loss的计算公式变为:

计算出的lossmini-batch的大小取了平均。

形状(shape):

  • Input: (N,C)    其中N代表batch_size,C 是类别的数量即数据要分成几类(或有几个标签)。

  • Target: (N)     Nmini-batch的大小,0 <= targets[i] <= C-1

举个例子:

  • 我们对包含了人、狗、猫的图片进行分类,其标签的索引分别为0、1、2。这时候将一张的图片输入神经网络,即目标(target)为1(对应标签索引)。输出结果为[0.1,0.2,0.3],该列表中的数字分别代表分类标签对应的概率。

  • 根据上述分类结果,图片为的概率更大,即0.3。对于该分类的Loss Function,我们可以通过交叉熵去计算,即:

那么如何验证这个公式的合理性呢?根据上面的例子,分类结果越准确,Loss应该越小。这条公式由两个部分组成:

  • 1、log(∑jexp(x[j])

log(∑jexp(x[j])主要作用是控制或限制预测结果的概率分布。比如说,预测出来的人、狗、猫的概率均为0.9,每个结果概率都很高,这显然是不合理的。此时 log(∑jexp(x[j]) 的值会变大,误差loss(x,class)也会随之变大。同时该指标也可以作为分类器性能评判标准。

  • 2、−x[class]:在已知图片类别的情况下,预测出来对应该类别的概率x[class]越高,其预测结果误差越小。

程序代码:

import torch
from torch import nn
from torch.nn import L1Lossinputs = torch.tensor([1, 2, 3], dtype=torch.float)
targets = torch.tensor([1, 2, 5], dtype=torch.float)inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3))x = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float)
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))loss_cross_1 = nn.CrossEntropyLoss(reduction='mean')
result_cross_1 = loss_cross_1(x, y)
loss_cross_2 = nn.CrossEntropyLoss(reduction='sum')
result_cross_2 = loss_cross_2(x, y)
print(result_cross_1, result_cross_2)

输出:

反向传播

如何根据Loss Function为更新神经网络数据提供依据?

  • 对于每个卷积核当中的参数,设置一个grad(梯度)。

  • 当我们进行反向传播的时候,对每一个节点的参数都会求出一个对应的梯度。之后我们根据梯度对每一个参数进行优化,最终达到降低Loss的一个目的。比较典型的一个方法——梯度下降法

代码举例:

 result_loss = loss(outputs, targets)result_loss.backward()
  • 上面就是反向传播的使用方法,它的主要作用是计算一个grad。使用debug功能并删掉上面这行代码,会发现单纯由result_loss=loss(output,targets)计算出来的结果,是没有grad这个参数的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/23625.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt图像处理技术十一:得到QImage图像的马赛克图像

效果图 指数5 指数15 指数40 原理 马赛克的原理很简单&#xff0c;就是取一个值&#xff0c;让这个值作为一个方格子的长宽&#xff0c;如40*40px的格子&#xff0c;取这个区域的平均R G B值&#xff0c;然后这个区域的所有像素点都是这个RGB值即可 源码 QImage applyM…

力扣2968.执行操作使频率分数最大

力扣2968.执行操作使频率分数最大 方法一&#xff1a;滑窗 前缀和 求前缀和数组s 求一个数组补齐到中位数的差值 枚举右端点 class Solution {public:int maxFrequencyScore(vector<int>& nums, long long k) {int res0,n nums.size();sort(nums.begin(),nums…

27-unittest之断言(assert)

在测试方法中需要判断结果是pass还是fail&#xff0c;自动化测试脚本里面一般把这种生成测试结果的方法称为断言&#xff08;assert&#xff09;。 使用unittest测试框架时&#xff0c;有很多的断言方法&#xff0c;下面介绍几种常用的断言方法&#xff1a;assertEqual、assert…

Vue可视化表单设计 FcDesigner v3.1.0 发布,新增 12 个组件,支持事件配置等

FcDesigner 是一款可视化表单设计器组件。可以通过拖拽的方式快速创建表单&#xff0c;提高开发者对表单的开发效率&#xff0c;节省开发者的时间。 本项目采用 Vue 和 ElementPlus 进行页面构建&#xff0c;内置多语言解决方案&#xff0c;支持二次扩展开发&#xff0c;支持自…

【cmake】cmake cache

cmake cache是什么 cmake cache是cmake在配置好后生成的一个CMakeCache.txt的文件&#xff0c;里面存储了一堆变量&#xff0c;这些变量一般都是关于项目的配置和环境的。 比如你用的什么编译器&#xff0c;编译器选项&#xff0c;还有项目目录。 例如&#xff08;在cmakelist…

前端工程化工具系列(九)—— mddir(v1.1.1):自动生成文件目录结构工具

mddir 是一个基于项目目录结构动态生成 Markdown 格式目录结构的工具&#xff0c;方便开发者在文档中展示文件和文件夹的组织结构。 1. 安装 全局安装改工具&#xff0c;方便用于各个项目。 pnpm i -g mddir2. 使用 在想要生成目录接口的项目内打开命令行工具&#xff0c;输…

太阳能航空障碍灯在航空安全发挥什么作用_鼎跃安全

随着我国经济的快速发展&#xff0c;空域已经成为经济发展的重要领域。航空运输、空中旅游、无人机物流、飞行汽车等经济活动为空域经济发展提供了巨大潜力。然而&#xff0c;空域安全作为空域经济发展的关键因素&#xff0c;受到了广泛关注。 随着空域经济活动的多样化和密集…

Waymo视角革新!MoST:编码视觉世界,刷新轨迹预测SOTA!

论文标题&#xff1a; MoST: Multi-modality Scene Tokenization for Motion Prediction 论文作者&#xff1a; Norman Mu, Jingwei Ji, Zhenpei Yang, Nate Harada, Haotian Tang, Kan Chen, Charles R. Qi, Runzhou Ge, Kratarth Goel, Zoey Yang, Scott Ettinger, Rami A…

锁存器(Latch)的产生与特点

Latch 是什么 Latch 其实就是锁存器&#xff0c;是一种在异步电路系统中&#xff0c;对输入信号电平敏感的单元&#xff0c;用来存储信息。锁存器在数据未锁存时&#xff0c;输出端的信号随输入信号变化&#xff0c;就像信号通过一个缓冲器&#xff0c;一旦锁存信号有效&#…

深入解析Java中volatile关键字

前言 我们都听说过volatile关键字&#xff0c;也许大家都知道它在Java多线程编程编程中可以减少很多的线程安全问题&#xff0c;但是会用或者用好volatile关键字的开发者可能少之又少&#xff0c;包括我自己。通常在遇到同步问题时&#xff0c;首先想到的一定是加锁&#xff0…

移动web性能测试工具有哪些呢?

摘要&#xff1a;本文将介绍一系列移动Web性能测试工具&#xff0c;以帮助开发人员评估和优化移动网站和应用程序的性能。我们将从基本概念开始&#xff0c;逐步深入&#xff0c;详细介绍每种工具的特点、用途和使用方法。 1. 概述 1.1 移动Web性能测试的重要性 1.2 测试工具…

微信小程序-wx.showToast超长文字展示不全

wx.showToast超长文字展示不全 问题解决方法1 问题 根据官方文档&#xff0c;iconnone&#xff0c;最多显示两行文字。所以如果提示信息较多&#xff0c;超过两行&#xff0c;就需要用其他方式解决。 解决方法1 使用vant组件里面的tost 根据官方例子使用&#xff1a; 1、在…

【Python报错】已解决ModuleNotFoundError: No module named ‘packaging’

成功解决“ModuleNotFoundError: No module named ‘packaging’”错误的全面指南 在Python编程中&#xff0c;遇到ModuleNotFoundError: No module named packaging这样的错误&#xff0c;通常意味着你的Python环境中缺少名为packaging的模块&#xff0c;或者该模块没有被正确…

YOLOV5 图像分割:利用yolov5进行图像分割

1、介绍 本章将介绍yolov5的分割部分,其他的yolov5分类、检测项目参考之前的博文 分类:YOLOV5 分类:利用yolov5进行图像分类_yolov5 图像分类-CSDN博客 检测:YOLOV5 初体验:简单猫和老鼠数据集模型训练-CSDN博客 yolov5的分割和常规的分割项目有所区别,这里分割的结果…

网络编程(UPD和TCP)

//发送数据 //UDP协议发送数据 package com.example.mysocketnet.a02UDPdemo;import java.io.IOException; import java.net.*;public class SendMessageDemo {public static void main(String[] args) throws IOException {//发送数据//1.创建DatagramSocket对象(快递公司)//…

【Linux】线程安全的艺术:解锁互斥量在并发编程中的应用

文章目录 前言&#xff1a;1. 进程线程间的互斥相关背景概念1.1. 操作共享变量会有问题的售票系统代码&#xff1a; 2. 互斥量的接口2.1. 解决方案2.1.1. 使用全局的锁&#xff1a;2.1.2. 使用局部的锁&#xff1a;2.1.3. 封装为RAII风格的加锁和解锁&#xff1a;2.1.4. C 11 中…

Liunx音频

一. echo -e "\a" echo 通过向控制台喇叭设备发送字符来发声&#xff1a; echo -e "\a"&#xff08;这里的 -e 选项允许解释反斜杠转义的字符&#xff0c;而 \a 是一个响铃(bell)字符&#xff09; 二. beep 下载对应的包 yum -y install beep 发声命令 be…

YashanDB携手宏杉科技助力国产软件生态发展

近日&#xff0c;深圳计算科学研究院崖山数据库系统YashanDB与宏杉科技系列存储、系列服务器与数据库一体机等多款产品顺利完成兼容性互认证。经严格测试&#xff0c;双方产品完全兼容&#xff0c;稳定运行&#xff0c;共同提供高效、稳定、安全的国产软硬件一体化解决方案&…

《精通ChatGPT:从入门到大师的Prompt指南》大纲目录

第一部分&#xff1a;入门指南 第1章&#xff1a;认识ChatGPT 1.1 ChatGPT是什么 1.2 ChatGPT的应用领域 1.3 为什么需要了解Prompt 第2章&#xff1a;Prompt的基本概念 2.1 什么是Prompt 2.2 好Prompt的特征 2.3 常见的Prompt类型 第二部分&#xff1a;Prompt设计技巧 第…

解决 iOS 端小程序「saveVideoToPhotosAlbum:fail invalid video」问题

场景复现&#xff1a; const url https://mobvoi-digitalhuman-video-public.weta365.com/1788148372310446080.mp4uni.downloadFile({url,success: (res) > {uni.saveVideoToPhotosAlbum({filePath: res.tempFilePath,success: (res) > {console.log("res > &…