学习pytorch14 损失函数与反向传播

神经网络-损失函数与反向传播

  • 官网
  • 损失函数
    • L1Loss MAE 平均
    • MSELoss 平方差
    • CROSSENTROPYLOSS 交叉熵损失
      • 注意
      • code
  • 反向传播
    • 在debug中的显示
      • code

B站小土堆pytorch视频学习

官网

https://pytorch.org/docs/stable/nn.html#loss-functions
在这里插入图片描述

损失函数

在这里插入图片描述

L1Loss MAE 平均

在这里插入图片描述
在这里插入图片描述

import torchinput = torch.tensor([1, 2, 3], dtype=float)
# target = torch.tensor([1, 2, 5], dtype=float)
target = torch.tensor([[[[1, 2, 5]]]], dtype=float) # shape [1, 1, 1, 3]
input = torch.reshape(input, (1,1,1,3))
# target = torch.reshape(target, (1,1,1,3))
print(input.shape)
print(target.shape)loss1 = torch.nn.L1Loss()
loss2 = torch.nn.L1Loss(reduction="sum")
result1 = loss1(input, target)
print(result1) # tensor(0.6667, dtype=torch.float64)
result2 = loss2(input, target)
print(result2) # tensor(2., dtype=torch.float64)

MSELoss 平方差

在这里插入图片描述
在这里插入图片描述

import torchinput = torch.tensor([1, 2, 3], dtype=float)
# target = torch.tensor([1, 2, 5], dtype=float)
target = torch.tensor([[[[1, 2, 5]]]], dtype=float) # shape [1, 1, 1, 3]
input = torch.reshape(input, (1,1,1,3))
# target = torch.reshape(target, (1,1,1,3))
print(input.shape)
print(target.shape)loss_mse = torch.nn.MSELoss(reduction='mean')
result_mse = loss_mse(input, target)
print(result_mse) # tensor(1.3333, dtype=torch.float64)
loss_mse2 = torch.nn.MSELoss(reduction='sum')
result_mse2 = loss_mse2(input, target)
print(result_mse2)   # tensor(4., dtype=torch.float64)

CROSSENTROPYLOSS 交叉熵损失

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
在这里插入图片描述
在这里插入图片描述
在神经网络中,默认log是以e为底的,所以也可以写成ln
在这里插入图片描述
在这里插入图片描述

注意

  1. 根据需求选择对应的loss函数
  2. 注意loss函数的输入输出shape

code

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
myseq = MySeq()
print(myseq)
for data in dataloader:imgs, targets = dataprint(imgs.shape)output = myseq(imgs)result = loss(output, targets)print(result)

反向传播

在debug中的显示

显示在网络结构中,每一层的保护属性中,都有weight属性,梯度属性在weitht属性里面
先找模型结构 在找每一层 在找weight权重,梯度在weight权重里面

在这里插入图片描述

code

核心代码:result_loss.backward() # 要在最后获取 backward函数要挂在通过loss函数计算后的结果上。

# 模型定义、数据加载 同上个代码
for data in dataloader:imgs, targets = dataprint(imgs.shape)output = myseq(imgs)result_loss= loss(output, targets)result_loss.backward()  # 要在最后获取print(result_loss)print(result_loss.grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/115836.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nginx安装详细步骤和使用说明

下载地址: https://download.csdn.net/download/jinhuding/88463932 详细说明和使用参考: 地址:http://www.gxcode.top/code 一 nginx安装步骤: 1.nginx安装与运行 官网 http://nginx.org/1.1安装gcc环境 # yum install gcc-c…

2022最新版-李宏毅机器学习深度学习课程-P26 Recurrent Neural Network

RNN 应用场景:填满信息 把每个单词表示成一个向量的方法:独热向量 还有其他方法,比如:Word hashing 单词哈希 输入:单词输出:该单词属于哪一类的概率分布 由于输入是文字序列,这就产生了一个问…

如何能够获取到本行业的能力架构图去了解自己的能力缺陷与短板,从而能清晰的去弥补差距?

如何能够获取到本行业的能力架构图去了解自己的能力缺陷与短板,从而能清晰的去弥补差距? 获取并利用能力架构图(Competency Model)来了解自己在特定行业或职位中的能力缺陷和短板,并据此弥补差距,是一个非常…

【PyTorch实战演练】自调整学习率实例应用(附代码)

目录 0. 前言 1. 自调整学习率的常用方法 1.1 ExponentialLR 指数衰减方法 1.2 CosineAnnealingLR 余弦退火方法 1.3 ChainedScheduler 链式方法 2. 实例说明 3. 结果说明 3.1 余弦退火法训练过程 3.2 指数衰减法训练过程 3.3 恒定学习率训练过程 3.4 结果解读 4. …

软件工程第七周

内聚 耦合 (Coupling): 描述的是两个模块之间的相互依赖程度。控制耦合是耦合度的一种,表示一个模块控制另一个模块的流程。高度的耦合会导致软件维护困难,因为改变一个模块可能会对其他模块产生意外的影响。 内聚 (Cohesion): 描述的是模块内部各个元素…

虚拟机weblogic服务搭建及访问(物理机 )

第一、安装环境: weblogic10.3.6.jar, jdk1.6.bin(开始安装jdk1.8后,安装域的时候报错 ,版本很重要) centos7虚拟机(VMware9) 本机系统windows7 以上安装包如果需要可以私信我,上传资源提示…

yolov8x-p2 实现 tensorrt 推理

简述 在最开始的yolov8提供的不同size的版本,包括n、s、m、l、x(模型规模依次增大,通过depth, width, max_channels控制大小),这些都是通过P3、P4和P5提取图片特征; 正常的yolov8对象检测模型输出层是P3、…

【WCA-KELM预测】基于水循环算法优化核极限学习机回归预测研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

Python实现一个简单的http服务,Url传参输出html页面

摘要 要实现一个可以接收参数的HTTP服务器,您可以使用Python标准库中的http.server模块。该模块提供了一个简单的HTTP服务器,可以用于开发和测试Web应用程序。 下面是一个示例代码,它实现了一个可以接收参数的HTTP服务器: 代码…

跨境商城源码可以定制开发吗?

跨境电商已经成为了一个全球性的趋势,而跨境商城源码定制开发是否可行,一直是广大电商从业者心中的疑问。跨境商城源码定制开发是指在已有的商城源码的基础上,进行个性化需求的修改和开发,以满足商家在跨境电商中的特定需求。下面…

一、XSS加解密编码解码工具

一、XSS加解密编码解码工具 解释:使用大佬开发的工具,地址:https://github.com/Leon406/ToolsFx/blob/dev/README-zh.md 在线下载地址: https://leon.lanzoui.com/b0d9av2kb(提取码:52pj)(建议下载jdk8-w…

javaEE -6(10000详解文件操作)

一:认识文件 我们先来认识狭义上的文件(file)。针对硬盘这种持久化存储的I/O设备,当我们想要进行数据保存时,往往不是保存成一个整体,而是独立成一个个的单位进行保存,这个独立的单位就被抽象成文件的概念&#xff0c…

Linux:firewalld防火墙-基础使用(2)

上一章 Linux:firewalld防火墙-介绍(1)-CSDN博客https://blog.csdn.net/w14768855/article/details/133960695?spm1001.2014.3001.5501 我使用的系统为centos7 firewalld启动停止等操作 systemctl start firewalld 开启防火墙 systemct…

文件的基本操作(创建文件,删除文件,读写文件,打开文件,关闭文件)

1.创建文件(create系统调用) 1.进行Create系统调用时, 需要提供的几个主要参数: 1.所需的外存空间大小(如:一个盘块,即1KB) 2.文件存放路径(“D:/Demo”) 3.文件名(这个地方默认为“新建文本文档.txt”) …

linux进程管理,一个进程的一生(喂饭级教学)

这篇文章谈谈linux中的进程管理。 一周爆肝,创作不易,望支持! 希望对大家有所帮助!记得收藏! 要理解进程管理,重要的是周边问题,一定要知其然,知其所以然。看下方目录就知道都是干货…

MD5生成和校验

MD5生成和校验 2021年8月19日席锦 任何类型的一个文件,它都只有一个MD5值,并且如果这个文件被修改过或者篡改过,它的MD5值也将改变。因此,我们会对比文件的MD5值,来校验文件是否是有被恶意篡改过。 什么是MD5&#xff…

cmd命令快速打开MATLAB

文章目录 复制快捷方式添加 -nojvm打开 复制快捷方式 添加 -nojvm 打开 唯一的缺点是无法使用plot,这一点比不上linux系统,不过打开速度还是挺快的。

Java面试(基础篇)——解构Java常见的基础面试题 结合Java源码分析

fail-safe 和fail-fast机制分别有什么作用? Fail-fast:快速失败 Fail-fast : 表示快速失败,在集合遍历过程中,一旦发现容器中的数据被修改了,会立刻抛出ConcurrentModificationException 异常&#xff0c…

如何学会从产品经理角度去思考问题?

如何学会从产品经理角度去思考问题? 从产品经理的角度思考问题意味着你需要关注产品从构思到上市全过程中的各个方面,包括用户需求、市场趋势、设计、开发、测试、上市后的用户反馈等。以下是一些策略和方法,帮助你培养从产品经理角度思考问…

Python爬虫:ad广告引擎的模拟登录

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据…