BCELoss,BCEWithLogitsLoss和CrossEntropyLoss

目录

二分类

1. BCELoss

2. BCEWithLogitsLoss

多分类

1. CrossEntropyLoss

 举例


二分类

两个损失:BCELoss,BCEWithLogitsLoss

1. BCELoss

输入:([B,C], [B,C]),代表(prediction,target)的维度,其中,B是Batchsize,C为样本的class,即样本的类别数。

输出:一个标量

等价于:BCELoss + sigmoid

import torch
from torch import nninput = torch.randn(3) # (3,1) 随机生成一个输入,没有被sigmoid。
print(input)
print(input.shape)
target=torch.Tensor([0., 1., 1.])
loss1=nn.BCELoss()
print("BCELoss:",loss1(torch.sigmoid(input), target))#需要sigmod输出:
BCELoss: tensor(1.0053)

2. BCEWithLogitsLoss

输入:([B,C], [B,C]),输出:一个标量

import torch
from torch import nninput = torch.randn(3) # (3,1) 随机生成一个输入,没有被sigmoid。
print(input)
print(input.shape)
target=torch.Tensor([0., 1., 1.])
loss2=nn.BCEWithLogitsLoss()
print("BCEWithLogitsLoss:",loss2(input,target))#不需要sigmoid输出:
BCEWithLogitsLoss: tensor(1.0053)

多分类

1. CrossEntropyLoss

输入:([B,C], [B]) 输出:一个标量(这个minibatch的mean/sum的loss)

nn.CrossEntropyLoss计算过程: 
input: logits(未经过softmax的模型的"输出”)

  •  softmax(input)
  • -log(softmax(input))
  • 用target做选择提取(关于logsoftmax)· mean

等价于:nn.CrossEntropyLoss = nn.NLLLoss(nn.LogSoftmax)
 

import torch
from torch import nnloss2 = nn.CrossEntropyLoss(reduction="none")
target2 = torch.tensor([0, 1, 2])
predict2 = torch.tensor([[0.9, 0.2, 0.8], [0.5, 0.2, 0.4], [0.4, 0.2, 0.9]])
print(predict2.shape) # torch.Size([3, 3])
print(target2.shape) # torch.Size([3])
print(loss2(predict2, target2))# #结果计算为:
# tensor([0.8761, 1.2729, 0.7434])

 举例

1. BCEWithLogitsLoss计算ACC和Loss:

参考:https://github.com/Loche2/IMDB_RNN/blob/master/training.py

criterion = nn.BCEWithLogitsLoss()
# 计算准确率
def binary_accuracy(predicts, y):rounded_predicts = torch.round(torch.sigmoid(predicts))correct = (rounded_predicts == y).float()accuracy = correct.sum() / len(correct)return accuracy# 训练
def train(model, iterator, optimizer, criterion):model.train()epoch_loss = 0epoch_accuracy = 0for batch in tqdm(iterator, desc=f'Epoch [{epoch + 1}/{EPOCHS}]', delay=0.1):optimizer.zero_grad()predictions = model(batch.text[0]).squeeze(1)loss = criterion(predictions, batch.label)accuracy = binary_accuracy(predictions, batch.label)loss.backward()optimizer.step()epoch_loss += loss.item()epoch_accuracy += accuracy.item()return epoch_loss / len(iterator), epoch_accuracy / len(iterator)

2. 计算ACC和Loss

# 截取情感分析部分代码 criterion = nn.CrossEntropyLoss()total_loss = 0.0correct_predictions = 0total_predictions = 0for batch in train_loader:input_ids = batch['input_ids'].to(device)labels = batch['label'].to(device)optimizer.zero_grad()logits = model(input_ids)loss_sentiment = criterion(logits, labels.long())loss_sentiment.backward()optimizer.step()total_loss += loss_sentiment.item()# get sentiment accuracypredicted_labels = torch.argmax(logits, dim=1)correct_predictions += torch.sum(predicted_labels == labels).item()total_predictions += labels.size(0)accuracy = correct_predictions / total_predictionsloss = total_loss / len(train_loader)

也可以直接看github上别人写的例子:https://github.com/songyouwei/ABSA-PyTorch/blob/master/train.py

参考:

深刻剖析与实战BCELoss详解(主)和BCEWithLogitsLoss(次)以及与普通CrossEntropyLoss的区别(次)-CSDN博客

另外提出一个问题:

二分类必须用BCEWithLogitsLoss吗,也可以用CrossEntropyLoss吧?

(1)如果用CrossEntropyLoss的话,只要让网络的fc层为nn.Linear(hidden_size, 2)就行,这样就和多分类一样算。另外CrossEntropyLoss里面包含了softmax,所以在计算loss的时候也不需要过softmax再算loss.

(2)如果用BCEWithLogitsLoss的话,就按照上面举例中BCEWithLogitsLoss计算Loss,只是如上面代码可是,再计算Acc的时候将predict使用sigimoid缩放到0,1来计算预测正确的个数

注:仅供学习记录,理解或者学习有误请与我联系

参考问题:二分类问题,应该选择sigmoid还是softmax? - 知乎 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/581658.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp项目如何引用安卓原生aar插件(避坑指南三)

官方文档说明:uni小程序SDK 【彩带- 避坑知识点】 如果引用原生aar插件,都配置好之后,云打包,报不包含此插件,除了检查以下步骤流程外,还要检查一下是否上打包的原生插件aar流程有问题。 1.第一步在uniapp项…

关于“Python”的核心知识点整理大全44

目录 ​编辑 15.3.4 模拟多次随机漫步 rw_visual.py 注意 15.3.5 设置随机漫步图的样式 15.3.6 给点着色 rw_visual.py 15.3.7 重新绘制起点和终点 rw_visual.py 15.3.8 隐藏坐标轴 rw_visual.py 15.3.9 增加点数 rw_visual.py 15.3.10 调整尺寸以适合屏幕 rw_vi…

Linux磁盘与文件管理

目录 一、磁盘介绍 1. 磁盘数据结构 2. 磁盘的接口类型 3. 磁盘在Linux上的表现形式 二、磁盘分区与MBR 1. 分区优缺点 2. 分区方式 3. MBR分区 4. GPT分区 三、文件系统 1. 文件系统的组成 2. 默认的文件系统 3. 文件系统的作用 4. 模拟破坏文件与修复文件 4…

C语言二维数值数组常用算法------------(C每日一编程)

--主、次对角线求和 --上、下三角求和 --N*N方阵转置 --杨辉三角 正文开始&#xff1a; 主对角线&#xff1a; 用两个双重for循环 int a[3][3], i, j, s 0; for (i 0; i < 3; i)for (j 0; j < 3; j)if (i j)s s a[i][j]; 次对角线&#xff1a; 用两个双重…

Next Station of Flink CDC

摘要&#xff1a;本文整理自阿里云智能 Flink SQL、Flink CDC 负责人伍翀&#xff08;花名&#xff1a;云邪&#xff09;&#xff0c;在 Flink Forward Asia 2023 主会场的分享。Flink CDC 是一款基于 Flink 打造一系列数据库的连接器。本次分享主要介绍 Flink CDC 开源社区在过…

STM32基础概念

1 什么是STM32 ST 是意法半导体&#xff0c;为公司名称&#xff0c;是SOC厂商。 M 是Microelectronics 的缩写。 32 表示32 位。 STM32 就是指ST 公司开发的32 位微控制器。 2 功能 自带了各种常用通信接口&#xff0c;比如USART、I2C、SPI 等&#xff0c;可接非常多的传感器…

uniapp实现前端银行卡隐藏中间的数字,及隐藏姓名后两位

Vue 实现前端银行卡隐藏中间的数字 主要应用了 filters过滤器 来实现效果 实现效果&#xff0c;如图&#xff1a; <template><div><div style"background-color: #f4f4f4;margin:50px 0 0 460px;width:900px;height:300px;"><p>原来&#…

python之Selenium WebDriver安装与使用

首先把python下载安装后&#xff0c;再添加到环境变量中&#xff0c;再打开控制台输入: pip install selenium 正常情况下是安装好的&#xff0c;检查一下“pip show selenium”命令&#xff0c;出现版本号就说明安装好了。 1&#xff1a;如果出现安装错误&#xff1a; 那就用“…

C++ 返回当前EXE所在的绝对路径和文件夹路径

目录 一、代码示例二、运行结果在代码里打印当前EXE所在的绝对路径和文件夹路径,以便调用该可执行程序时我可以知道当前执行程序的路径,以方便后续我使用别的文件夹和文件。 一、代码示例 #include<iostream> #include<string> #include<Windows.h> using…

EasyExcel简单合并单元格数据工具类

代码&#xff1a; package com.ly.cloud.util;import cn.hutool.core.collection.CollUtil; import com.alibaba.excel.metadata.Head; import com.alibaba.excel.write.merge.AbstractMergeStrategy; import org.apache.poi.ss.usermodel.Cell; import org.apache.poi.ss.use…

Alibaba Cloud Linux 3.2104 LTS 64位镜像兼容CentOS吗?

Alibaba Cloud Linux 3.2104 LTS 64位镜像兼容CentOS吗&#xff1f;完全兼容RHEL/CentOS生态和操作方式。 阿里云Alibaba Cloud Linux 3.2104 LTS 64位镜像是可以选择的&#xff0c;它阿里云打造的Linux服务器操作系统发行版&#xff0c;针对云服务器ECS做了大量深度优化&…

Windows 源码编译 MariaDB

环境 Win11, vs2022, git, cmake, Bison from GnuWin32, perl, Gnu Diff. 默认都安装好。 perl 看之前博客教程。perl Bison from GnuWin32 默认安装到 C:\GnuWin32 Add C:\GnuWin32\bin to your system PATH after installation. 下载mariadb源码 地址&#xff1a;MariaD…

【maven】pom.xml 文件详解

有关 maven 其他配置讲解参考 maven 配置文件 setting.xml 详解 pom.xml 文件是 Maven 项目的核心配置文件&#xff0c;其中包含了项目的元数据、构建配置、依赖管理等信息。以下是一个 pom.xml 文件的主要部分&#xff1a; <?xml version"1.0" encoding"U…

测试:抓包工具

抓包工具是网络安全和软件测试领域中非常重要的工具&#xff0c;它能够帮助用户捕获、分析和修改网络数据包。这些工具对于开发人员、测试人员以及安全研究人员来说都非常实用&#xff0c;因为它们可以用来监测网络流量、定位问题、分析协议以及进行安全评估。 Fiddler Fiddl…

代码随想录 Leetcode27. 移除元素

题目&#xff1a; 代码(首刷看解析 2023年12月28日)&#xff1a; class Solution { public:int removeElement(vector<int>& nums, int val) {int n nums.size();int slowIndex 0;for(int fastIndex 0; fastIndex < n; fastIndex){if(val ! nums[fastIndex])…

电影“AI化”已成定局,华为、小米转战入局又将带来什么?

从华为、Pika、小米等联合打造电影工业化实验室、到Pika爆火&#xff0c;再到国内首部AI全流程制作《愚公移山》开机……业内频繁的新动态似乎都在预示着2023年国内电影开始加速进入新的制片阶段&#xff0c;国内AI电影热潮即将来袭。 此时以华为为首的底层技术科技企业加入赛…

leaflet学习笔记-地图缩略图(鹰眼)的添加(三)

介绍 地图缩略图控件有助于用户了解主窗口显示的地图区域在全球、全国、全省、全市等范围内的相对位置&#xff0c;也称为鹰眼图。Leaflet提供了好几种地图缩略图控件&#xff0c;本文介绍其中一个最常用控件&#xff0c;即插件Leaflet.MiniMap。 依赖添加 这些地图控件都可以…

清除conda和pip缓存的方法

conda 清除conda缓存中的所有包、索引和临时文件&#xff0c; conda clean --all 只清除conda缓存中的包&#xff0c;而不清除索引和临时文件 &#xff0c; conda clean --packages pip 清除pip缓存中的所有包和索引文件&#xff0c; pip cache purge

使用机器学习进行语法错误检测/纠正

francescofranco_39234 一、说明 一般的学习&#xff0c;特别是深度学习&#xff0c;促进了自然语言处理。各种模型使人们能够执行机器翻译、文本摘要和情感分析——仅举几个用例。今天&#xff0c;我们将研究另一个流行的用途&#xff1a;我们将使用Gramformer构建一个用于机器…

时间序列系列03-统计模型

时间序列统计模型是用来描述和预测时间序列数据的数学模型。这些模型通常基于过去的观测值&#xff0c;并假设时间序列的行为是可预测的。以下是一些常见的时间序列统计模型&#xff1a; 1. 自回归移动平均模型&#xff08;ARMA&#xff09;&#xff1a; ARMA 模型是由自回归…