Pytorch个人学习记录总结 09

目录

损失函数与反向传播

L1Loss

MSELOSS

CrossEntropyLoss


损失函数与反向传播

所需的Loss计算函数都在torch.nn的LossFunctions中,官方网址是:torch.nn — PyTorch 2.0 documentation。举例了L1Loss、MSELoss、CrossEntropyLoss。


在这些Loss函数的使用中,有以下注意的点:
(1) 参数reduction='mean',默认是'mean'表示对差值的和求均值,还可以是'sum'则不会求均值。
(2) 一定要注意Input和target的shape。 

L1Loss

创建一个标准,用于测量中每个元素之间的Input: x xx 和 target: y yy。

创建一个标准,用来测量Input: x xx 和 target: y yy 中的每个元素之间的平均绝对误差(MAE)(L 1 L_1L 1范数)。

 

 

Shape:

Input: (∗ *∗), where ∗ *∗ means any number of dimensions. 会对所有维度的loss求均值
Target: (∗ *∗), same shape as the input. 与Input的shape相同
Output: scalar.返回值是标量。

假设 a aa 是标量,则有:

type(a) = torch.Tensor
a.shape = torch.Size([])
a.dim = 0
 

MSELOSS

创建一个标准,用来测量Input: x xx 和 target: y yy 中的每个元素之间的均方误差(平方L2范数)。


 

Shape:

Input: (∗ *∗), where ∗ *∗ means any number of dimensions. 会对所有维度求loss
Target: (∗ *∗), same shape as the input. 与Input的shape相同
Output: scalar.返回值是标量。


CrossEntropyLoss

该标准计算 input 和 target 之间的交叉熵损失。

非常适用于当训练 C CC 类的分类问题(即多分类问题,若是二分类问题,可采用BCELoss)。如果要提供可选参数 w e i g h t weightweight ,那 w e i g h t weightweight 应设置为1维tensor去为每个类分配权重。这在训练集不平衡时特别有用。

期望的 input应包含每个类的原始的、未标准化的分数。input必须是大小为C CC(input未分批)、(m i n i b a t c h , C minibatch,Cminibatch,C) or (m i n i b a t c h , C , d 1 , d 2 , . . . d kminibatch,C,d_1,d_2,...d_kminibatch,C,d 1,d 2,...d k )的Tensor。最后一种方法适用于高维输入,例如计算2D图像的每像素交叉熵损失。

期望的 target应包含以下内容之一:

(1) (target包含了)在[ 0 , C ) [0,C)[0,C)区间的类别索引,C CC是类别总数量。如果指定了 ignore_index,则此损失也接受此类索引(此索引不一定在类别范围内)。reduction='none'情况下的loss为

注意:l o g loglog默认是以10为底的。
 

x是input,y yy是target,w ww是权重weight,C CC是类别数量,N NN涵盖minibatch维度且d 1 , d 2 . . . , d k d_1,d_2...,d_kd 1,d 2...,d k分别表示第k个维度。(N太难翻译了,总感觉没翻译对)如果reduction='mean'或'sum',

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torchvision.transforms import transformsdataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)class Model(nn.Module):def __init__(self):super(Model, self).__init__()self.model = Sequential(Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),MaxPool2d(kernel_size=2, stride=2),Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),MaxPool2d(kernel_size=2, stride=2),Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),MaxPool2d(kernel_size=2, stride=2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):  # 模型前向传播return self.model(x)model = Model()  # 定义模型
loss_cross = nn.CrossEntropyLoss()  # 定义损失函数for data in dataloader:imgs, targets = dataoutputs = model(imgs)# print(outputs)    # 先打印查看一下结果。outputs.shape=(2, 10) 即(N,C)# print(targets)    # target.shape=(2) 即(N)# 观察outputs和target的shape,然后选择使用哪个损失函数res_loss = loss_cross(outputs, targets)res_loss.backward()  # 损失反向传播print(res_loss)#
# inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
# targets = torch.tensor([1, 2, 5], dtype=torch.float32)
#
# inputs = torch.reshape(inputs, (1, 1, 1, 3))
# targets = torch.reshape(targets, (1, 1, 1, 3))
#
# # -------------L1Loss--------------- #
# loss = nn.L1Loss()
# res = loss(inputs, targets)  # 返回的是一个标量,ndim=0
# print(res)  # tensor(1.6667)
#
# # -------------MSELoss--------------- #
# loss_mse = nn.MSELoss()
# res_mse = loss_mse(inputs, targets)
# print(res_mse)
#
# # -------------CrossEntropyLoss--------------- #
# x = torch.tensor([0.1, 0.2, 0.3])  # (N,C)
# x = torch.reshape(x, (1, 3))
# y = torch.tensor([1])  # (N)
# loss_cross = nn.CrossEntropyLoss()
# res_cross = loss_cross(x, y)
# print(res_cross)


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/18536.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux系统将OpenSSH升级到最高版本

一、背景: 公司安全扫描到主机的OpenSSH安全漏洞,由于是虚拟机只能由自己修复,很多OpenSSH的漏洞厂商都没有提供补丁,只能通过禁用scp或者端口的方式临时解决,但是后面使用就不方便了,而且也不安全&#x…

这所211考数一英二,学硕降分33分,十分罕见!

一、学校及专业介绍 合肥工业大学(Hefei University of Technology),简称“合工大”,校本部位于安徽省合肥市,是中华人民共和国教育部直属的全国重点大学,是国家“双一流”建设高校, 国家“211工…

sftp和scp协议,哪个传大文件到服务器传输速率快?

环境: 1.Win scp 6.1.1 2.XFTP 7 3.9.6G压缩文件 4.Centos 7 5.联想E14笔记本Win10 6.HW-S1730S-S48T4S-A交换机 问题描述: sftp和scp协议,哪个传大文件到服务器速度快? 1.SFTP 基于SSH加密传输文件,可靠性高&am…

【论文阅读24】Better Few-Shot Text Classification with Pre-trained Language Model

论文相关 论文标题:Label prompt for multi-label text classification(基于预训练模型对少样本进行文本分类) 发表时间:2021 领域:多标签文本分类 发表期刊:ICANN(顶级会议) 相关代…

基于opencv的几种图像滤波

一、介绍 盒式滤波、均值滤波、高斯滤波、中值滤波、双边滤波、导向滤波。 boxFilter() blur() GaussianBlur() medianBlur() bilateralFilter() 二、代码 #include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> …

JDBC的书写

文章目录 基本概念操作数据库方式一&#xff08;不建议使用这种查询&#xff0c;可以sql注入&#xff09;读取properties文件 事务转账示例 获取id连接池 基本概念 持久化:把数据放在磁盘上&#xff0c;断电后还是有数据。使用execute 执行增删改返回false,查返回true 操作数…

AI 绘画Stable Diffusion 研究(三)sd模型种类介绍及安装使用详解

本文使用工具&#xff0c;作者:秋葉aaaki 免责声明: 工具免费提供 无任何盈利目的 大家好&#xff0c;我是风雨无阻。 今天为大家带来的是 AI 绘画Stable Diffusion 研究&#xff08;三&#xff09;sd模型种类介绍及安装使用详解。 目前&#xff0c;AI 绘画Stable Diffusion的…

css3 hover border 流动效果

/* Hover 边线流动 */.hoverDrawLine {border: 0 !important;position: relative;border-radius: 5px;--border-color: #60daaa; } .hoverDrawLine::before, .hoverDrawLine::after {box-sizing: border-box;content: ;position: absolute;border: 2px solid transparent;borde…

生成对抗网络DCGAN学习实践

在AI内容生成领域&#xff0c;有三种常见的AI模型技术&#xff1a;GAN、VAE、Diffusion。其中&#xff0c;Diffusion是较新的技术&#xff0c;相关资料较为稀缺。VAE通常更多用于压缩任务&#xff0c;而GAN由于其问世较早&#xff0c;相关的开源项目和科普文章也更加全面&#…

【机器学习】Gradient Descent

Gradient Descent for Linear Regression 1、梯度下降2、梯度下降算法的实现(1) 计算梯度(2) 梯度下降(3) 梯度下降的cost与迭代次数(4) 预测 3、绘图4、学习率 首先导入所需的库&#xff1a; import math, copy import numpy as np import matplotlib.pyplot as plt plt.styl…

Devops系统中jira平台迁移

需求:把aws中的devops系统迁移到华为云中,其中主要是jira系统中的数据迁移,主要方法为在华为云中建立一套 与aws相同的devops平台,再把数据库和文件系统中的数据迁移,最后进行测试。 主要涉及到的服务集群CCE、数据库mysql、弹性文件服务SFS、数据复制DRS、弹性负载均衡ELB。 迁…

问道管理:补仓什么意思?怎么补仓可以降低成本?

补仓这个术语我们在理财出资中经常听到&#xff0c;例如基金补仓&#xff0c;股票补仓。那么&#xff0c;补仓什么意思&#xff1f;怎样补仓能够降低成本&#xff1f;问道管理为我们预备了相关内容&#xff0c;以供参阅。 补仓什么意思&#xff1f; 股票补仓是指出资者在某一只…

Debian 12.1 “书虫 “发布,包含 89 个错误修复和 26 个安全更新

导读Debian 项目今天宣布&#xff0c;作为最新 Debian GNU/Linux 12 “书虫 “操作系统系列的首个 ISO 更新&#xff0c;Debian 12.1 正式发布并全面上市。 Debian 12.1 是在 Debian GNU/Linux 12 “书虫 “发布六周后推出的&#xff0c;目的是为那些希望在新硬件上部署操作系统…

Vivado进行自定义IP封装

一. 简介 本篇文章将介绍如何使用Vivado来对上篇文章(FPGA驱动SPI屏幕)中的代码进行一个IP封装&#xff0c;Vivado自带的IP核应该都使用过&#xff0c;非常方便。 这里将其封装成IP核的目的主要是为了后续项目的调用&#xff0c;否则当我新建一个项目的时候&#xff0c;我需要将…

VirtualBox Ubuntu无法安装增强功能以及无法复制粘贴踩坑记录

在VirtualBox安装增强功能想要和主机双向复制粘贴&#xff0c;中间查了很多资料&#xff0c;终于是弄好了。记录一下过程&#xff0c;可能对后来人也有帮助&#xff0c;我把我参考的几篇主要的博客都贴上来了&#xff0c;如果觉得我哪里讲得不清楚的&#xff0c;可以去对应的博…

Shell脚本学习-Shell函数

函数的作用就是将程序里多次被调用的相同代码组合起来&#xff08;函数体&#xff09;&#xff0c;并为其取一个名字&#xff0c;即函数名。其他所有想重复调用这部分代码的地方都只需要调用这个名字就可以了。当需要修改这部分代码时候&#xff0c;只需要修改函数体内的这部分…

【简单认识GFS分布式文件系统】

文章目录 一.GlusterFS 概述1.GlusterFS简介2.特点3.GlusterFS 术语4.模块化堆栈式架构5.GlusterFS 的工作流程6.GlusterFS的卷类型1、**分布式卷&#xff08;Distribute volume&#xff09;**2、条带卷&#xff08;Stripe volume&#xff09;3、复制卷&#xff08;Replica vol…

Web后端基本设计思想

JavaWeb应用的后端一般基于MVC和三层架构思想实现。 MVC是一种设计模式&#xff0c;用于开发用户界面和交互式应用程序。M即Model&#xff0c;业务模型&#xff0c;负责处理应用程序的业务逻辑和数据&#xff1b;V即View&#xff0c;视图&#xff0c;负责给用户展示界面和数据&…

快速创建vue3+vite+ts项目

安装nodejs 创建项目 npm init vitelatest 默认之后回车 选择项目名字my-vue-project 选择vue框架 选择ts 运行项目 cd my-vue-project npm install --registryhttps://registry.npm.taobao.org npm run dev

Vue2 第十二节 Vue组件化编程(一)

1.模块与组件&#xff0c;模块化与组件化概念 2. 非单文件组件 3. 组件编写注意事项 4. 组件的嵌套 一. 模块与组件&#xff0c;模块化与组件化 传统方式编写存在的问题 &#xff08;1&#xff09;依赖关系混乱&#xff0c;不好维护 &#xff08;2&#xff09;代码的复用…