学习pytorch14 损失函数与反向传播

神经网络-损失函数与反向传播

  • 官网
  • 损失函数
    • L1Loss MAE 平均
    • MSELoss 平方差
    • CROSSENTROPYLOSS 交叉熵损失
      • 注意
      • code
  • 反向传播
    • 在debug中的显示
      • code

B站小土堆pytorch视频学习

官网

https://pytorch.org/docs/stable/nn.html#loss-functions
在这里插入图片描述

损失函数

在这里插入图片描述

L1Loss MAE 平均

在这里插入图片描述
在这里插入图片描述

import torchinput = torch.tensor([1, 2, 3], dtype=float)
# target = torch.tensor([1, 2, 5], dtype=float)
target = torch.tensor([[[[1, 2, 5]]]], dtype=float) # shape [1, 1, 1, 3]
input = torch.reshape(input, (1,1,1,3))
# target = torch.reshape(target, (1,1,1,3))
print(input.shape)
print(target.shape)loss1 = torch.nn.L1Loss()
loss2 = torch.nn.L1Loss(reduction="sum")
result1 = loss1(input, target)
print(result1) # tensor(0.6667, dtype=torch.float64)
result2 = loss2(input, target)
print(result2) # tensor(2., dtype=torch.float64)

MSELoss 平方差

在这里插入图片描述
在这里插入图片描述

import torchinput = torch.tensor([1, 2, 3], dtype=float)
# target = torch.tensor([1, 2, 5], dtype=float)
target = torch.tensor([[[[1, 2, 5]]]], dtype=float) # shape [1, 1, 1, 3]
input = torch.reshape(input, (1,1,1,3))
# target = torch.reshape(target, (1,1,1,3))
print(input.shape)
print(target.shape)loss_mse = torch.nn.MSELoss(reduction='mean')
result_mse = loss_mse(input, target)
print(result_mse) # tensor(1.3333, dtype=torch.float64)
loss_mse2 = torch.nn.MSELoss(reduction='sum')
result_mse2 = loss_mse2(input, target)
print(result_mse2)   # tensor(4., dtype=torch.float64)

CROSSENTROPYLOSS 交叉熵损失

https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss
在这里插入图片描述
在这里插入图片描述
在神经网络中,默认log是以e为底的,所以也可以写成ln
在这里插入图片描述
在这里插入图片描述

注意

  1. 根据需求选择对应的loss函数
  2. 注意loss函数的输入输出shape

code

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWritertest_set = torchvision.datasets.CIFAR10("./dataset", train=False, transform=torchvision.transforms.ToTensor(),download=True)dataloader = DataLoader(test_set, batch_size=1)class MySeq(nn.Module):def __init__(self):super(MySeq, self).__init__()self.model1 = Sequential(Conv2d(3, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 32, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Conv2d(32, 64, kernel_size=5, stride=1, padding=2),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = nn.CrossEntropyLoss()
myseq = MySeq()
print(myseq)
for data in dataloader:imgs, targets = dataprint(imgs.shape)output = myseq(imgs)result = loss(output, targets)print(result)

反向传播

在debug中的显示

显示在网络结构中,每一层的保护属性中,都有weight属性,梯度属性在weitht属性里面
先找模型结构 在找每一层 在找weight权重,梯度在weight权重里面

在这里插入图片描述

code

核心代码:result_loss.backward() # 要在最后获取 backward函数要挂在通过loss函数计算后的结果上。

# 模型定义、数据加载 同上个代码
for data in dataloader:imgs, targets = dataprint(imgs.shape)output = myseq(imgs)result_loss= loss(output, targets)result_loss.backward()  # 要在最后获取print(result_loss)print(result_loss.grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/115836.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux使用find命令查找文件

find命令 简介语法格式基本参数 参考实例根目录下文件名称的例子指定路径下特定类型的例子指定路径、文件类型特定文件名称的例子指定路径、文件类型特定文件大小的例子指定路径、文件类型 查找近期修改时间的例子指定路径、文件类型 查找空文件或目录的例子指定路径、文件类型…

nginx安装详细步骤和使用说明

下载地址: https://download.csdn.net/download/jinhuding/88463932 详细说明和使用参考: 地址:http://www.gxcode.top/code 一 nginx安装步骤: 1.nginx安装与运行 官网 http://nginx.org/1.1安装gcc环境 # yum install gcc-c…

RustDay06------Exercise[71-80]

71.box的使用 说实话这题没太看懂.敲了个模板跟着提示就过了 // box1.rs // // At compile time, Rust needs to know how much space a type takes up. This // becomes problematic for recursive types, where a value can have as part of // itself another value of th…

2022最新版-李宏毅机器学习深度学习课程-P26 Recurrent Neural Network

RNN 应用场景:填满信息 把每个单词表示成一个向量的方法:独热向量 还有其他方法,比如:Word hashing 单词哈希 输入:单词输出:该单词属于哪一类的概率分布 由于输入是文字序列,这就产生了一个问…

如何能够获取到本行业的能力架构图去了解自己的能力缺陷与短板,从而能清晰的去弥补差距?

如何能够获取到本行业的能力架构图去了解自己的能力缺陷与短板,从而能清晰的去弥补差距? 获取并利用能力架构图(Competency Model)来了解自己在特定行业或职位中的能力缺陷和短板,并据此弥补差距,是一个非常…

【PyTorch实战演练】自调整学习率实例应用(附代码)

目录 0. 前言 1. 自调整学习率的常用方法 1.1 ExponentialLR 指数衰减方法 1.2 CosineAnnealingLR 余弦退火方法 1.3 ChainedScheduler 链式方法 2. 实例说明 3. 结果说明 3.1 余弦退火法训练过程 3.2 指数衰减法训练过程 3.3 恒定学习率训练过程 3.4 结果解读 4. …

软件工程第七周

内聚 耦合 (Coupling): 描述的是两个模块之间的相互依赖程度。控制耦合是耦合度的一种,表示一个模块控制另一个模块的流程。高度的耦合会导致软件维护困难,因为改变一个模块可能会对其他模块产生意外的影响。 内聚 (Cohesion): 描述的是模块内部各个元素…

虚拟机weblogic服务搭建及访问(物理机 )

第一、安装环境: weblogic10.3.6.jar, jdk1.6.bin(开始安装jdk1.8后,安装域的时候报错 ,版本很重要) centos7虚拟机(VMware9) 本机系统windows7 以上安装包如果需要可以私信我,上传资源提示…

yolov8x-p2 实现 tensorrt 推理

简述 在最开始的yolov8提供的不同size的版本,包括n、s、m、l、x(模型规模依次增大,通过depth, width, max_channels控制大小),这些都是通过P3、P4和P5提取图片特征; 正常的yolov8对象检测模型输出层是P3、…

【WCA-KELM预测】基于水循环算法优化核极限学习机回归预测研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

springboot实现消息通知需求

springboot实现消息通知需求 参考: Springboot整合Websocket(推送消息通知) SpringBoot使用SSE进行实时通知前端 vuespringbootwebsocket实现消息通知,含应用场景

CNN记录】pytorch中flatten函数

pytorch原型 torch.flatten(input, start_dim0, end_dim- 1) 作用:将连续的维度范围展平维张量,一般写再某个nn后用于对输出处理, 参数: start_dim:开始的维度 end_dim:终止的维度,-1为最后…

Python实现一个简单的http服务,Url传参输出html页面

摘要 要实现一个可以接收参数的HTTP服务器,您可以使用Python标准库中的http.server模块。该模块提供了一个简单的HTTP服务器,可以用于开发和测试Web应用程序。 下面是一个示例代码,它实现了一个可以接收参数的HTTP服务器: 代码…

关于单机流程编排技术——docker compose安装使用的问题

最近在学习docker相关的东西,当我在docker上部署了一个nest应用,其中该应用中依赖了一个基于mysql镜像的容器,一个基于redis镜像的容器。那我,当我进行部署上线时,在启动nest容器时,必须保证redis容器和mys…

华为OD 完全二叉树非叶子部分后序遍历(200分)【java】A卷+B卷

华为OD统一考试A卷+B卷 新题库说明 你收到的链接上面会标注A卷还是B卷。目前大部分收到的都是B卷。 B卷对应往年部分考题以及新出的题目,A卷对应的是新出的题目。 我将持续更新最新题目 获取更多免费题目可前往夸克网盘下载,请点击以下链接进入: 我用夸克网盘分享了「华为OD…

跨境商城源码可以定制开发吗?

跨境电商已经成为了一个全球性的趋势,而跨境商城源码定制开发是否可行,一直是广大电商从业者心中的疑问。跨境商城源码定制开发是指在已有的商城源码的基础上,进行个性化需求的修改和开发,以满足商家在跨境电商中的特定需求。下面…

mybatis plus中json格式实战

1.pom.xml <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0…

一、XSS加解密编码解码工具

一、XSS加解密编码解码工具 解释&#xff1a;使用大佬开发的工具&#xff0c;地址&#xff1a;https://github.com/Leon406/ToolsFx/blob/dev/README-zh.md 在线下载地址&#xff1a; https://leon.lanzoui.com/b0d9av2kb(提取码&#xff1a;52pj)&#xff08;建议下载jdk8-w…

kubesphere 一键部署K8Sv1.21.5版本

1. 在centos上的安装流程 1.1 安装需要的环境 yum install -y socat conntrack ebtables ipset curl1.2 下载KubeKey #电脑必须可以访问github&#xff0c;很重要。不然安装过程会出问题 curl -sfL https://get-kk.kubesphere.io | VERSIONv1.2.1 sh - chmod x kk1.3 开始安…

mysql 优化 聚簇索引=主键索引吗

在 InnoDB 引擎中&#xff0c;每张表都会有一个特殊的索引“聚簇索引”&#xff0c;也被称之为聚集索引&#xff0c;它是用来存储行数据的。一般情况下&#xff0c;聚簇索引等同于主键索引&#xff0c;但这里有一个前提条件&#xff0c;那就是这张表需要有主键&#xff0c;只有…