常用分类损失CE Loss、Focal Loss及GHMC Loss理解与总结

一、CE Loss

定义

交叉熵损失(Cross-Entropy Loss,CE Loss)能够衡量同一个随机变量中的两个不同概率分布的差异程度,当两个概率分布越接近时,交叉熵损失越小,表示模型预测结果越准确。

公式

二分类

二分类的CE Loss公式如下,

其中,M:正样本数量,N:负样本数量,y_{i}:真实值, p_{i}:预测值

多分类

在计算多分类的CE Loss时,首先需要对模型输出结果进行softmax处理。公式如下,

其中, output:模型输出,p:对模型输出进行softmax处理后的值, ​​​​​:真实值的one hot编码​(假设模型在做5分类,如果y_{i}=2,则=[0,0,1,0,0])

代码实现

二分类

import torch
import torch.nn as nn
import mathcriterion = nn.BCELoss()
output = torch.rand(1, requires_grad=True)
label = torch.randint(0, 1, (1,)).float()
loss = criterion(output, label)print("预测值:", output)
print("真实值:", label)
print("nn.BCELoss:", loss)for i in range(label.shape[0]):if label[i] == 0:res = -math.log(1-output[i])elif label[i] == 1:res = -math.log(output[i])
print("自己的计算结果", res)"""
预测值: tensor([0.7359], requires_grad=True)
真实值: tensor([0.])
nn.BCELoss: tensor(1.3315, grad_fn=<BinaryCrossEntropyBackward0>)
自己的计算结果 1.331509556677378
"""

多分类

import torch
import torch.nn as nn
import mathcriterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)print("预测值:", output)
print("真实值:", label)
print("nn.CrossEntropyLoss:", loss)output = torch.softmax(output, dim=1)
print("softmax后的预测值:", output)one_hot = torch.zeros_like(output).scatter_(1, label.view(-1, 1), 1)
print("真实值对应的one_hot编码", one_hot)res = (-torch.log(output) * one_hot).sum()
print("自己的计算结果", res)"""
预测值: tensor([[-0.7459, -0.3963, -1.8046,  0.6815,  0.2965]], requires_grad=True)
真实值: tensor([1])
nn.CrossEntropyLoss: tensor(1.9296, grad_fn=<NllLossBackward0>)
softmax后的预测值: tensor([[0.1024, 0.1452, 0.0355, 0.4266, 0.2903]], grad_fn=<SoftmaxBackward0>)
真实值对应的one_hot编码 tensor([[0., 1., 0., 0., 0.]])
自己的计算结果 tensor(1.9296, grad_fn=<SumBackward0>)
"""

二、Focal Loss

定义

虽然CE Loss能够衡量同一个随机变量中的两个不同概率分布的差异程度,但无法解决以下两个问题:1、正负样本数量不平衡的问题(如centernet的分类分支,它只将目标的中心点作为正样本,而把特征图上的其它像素点作为负样本,可想而知正负样本的数量差距之大);2、无法区分难易样本的问题(易分类的样本的分类错误的损失占了整体损失的绝大部分,并主导梯度)

为了解决以上问题,Focal Loss在CE Loss的基础上改进,引入了:1、正负样本数量调节因子以解决正负样本数量不平衡的问题;2、难易样本分类调节因子以聚焦难分类的样本

公式

二分类

公式如下,

 

​​​​​​​

其中,\alpha:正负样本数量调节因子,\gamma:难易样本分类调节因子

多分类

其中,\alpha _{y_{i}}y_{i}类别的权重

代码实现

二分类

def sigmoid_focal_loss(inputs: torch.Tensor,targets: torch.Tensor,alpha: float = -1,gamma: float = 2,reduction: str = "none",
) -> torch.Tensor:"""Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.Args:inputs: A float tensor of arbitrary shape.The predictions for each example.targets: A float tensor with the same shape as inputs. Stores the binaryclassification label for each element in inputs(0 for the negative class and 1 for the positive class).alpha: (optional) Weighting factor in range (0,1) to balancepositive vs negative examples. Default = -1 (no weighting).gamma: Exponent of the modulating factor (1 - p_t) tobalance easy vs hard examples.reduction: 'none' | 'mean' | 'sum''none': No reduction will be applied to the output.'mean': The output will be averaged.'sum': The output will be summed.Returns:Loss tensor with the reduction option applied."""inputs = inputs.float()targets = targets.float()p = torch.sigmoid(inputs)ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")p_t = p * targets + (1 - p) * (1 - targets)loss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * targets + (1 - alpha) * (1 - targets)loss = alpha_t * lossif reduction == "mean":loss = loss.mean()elif reduction == "sum":loss = loss.sum()return loss

步骤1、首先对输入进行sigmoid处理,

p = torch.sigmoid(inputs)

步骤2、随后求出CE Loss,

ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

步骤3、定义p_{t}^{i},公式为:

p_t = p * targets + (1 - p) * (1 - targets)

步骤4、为CE Loss添加难易样本分类调节因子,

loss = ce_loss * ((1 - p_t) ** gamma)

步骤5、定义\alpha _{t}^{i},公式为:

alpha_t = alpha * targets + (1 - alpha) * (1 - targets)

步骤6、为步骤4的损失添加正负样本数量调节因子,

loss = alpha_t * loss

多分类

def multi_cls_focal_loss(inputs: torch.Tensor,targets: torch.Tensor,alpha: torch.Tensor,gamma: float = 2,reduction: str = "none",
) -> torch.Tensor:inputs = inputs.float()targets = targets.float()ce_loss = nn.CrossEntropyLoss()(inputs, targets, reduction="none")one_hot = torch.zeros_like(inputs).scatter_(1, targets.view(-1, 1), 1)p_t = inputs * one_hotloss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * one_hotloss = alpha_t * lossreturn loss

三、GHMC Loss

定义

Focal Loss在CE Loss的基础上改进后,解决了正负样本不平衡以及无法区分难易样本的问题,但也会过分关注难分类的样本(离群点),导致模型学歪。为了解决这个问题,GHMC(Gradient Harmonizing Mechanism-C)定义了梯度模长,该梯度模长正比于分类的难易程度,目的是让模型不要关注那些容易学的样本,也不要关注那些特别难分的样本

公式

1、定义梯度模长

二分类的CE Loss公式如下,

假设x是模型的输出,假设p=sigmoid(x),求损失对x的偏导,

因此,定义梯度模长如下,

其中, p:预测值,p^{\ast }:真实值

梯度模长与样本数量的关系如下,

2、定义梯度密度(单位梯度模长g上的样本数量

  

其中,g_{k}:第k个样本的梯度模长,\delta _{\varepsilon }(g_{k},g)g_{k}(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})范围内的样本数量,l_{\varepsilon }(g):区间(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})的长度

3、定义梯度密度协调参数(gradient density harmonizing parameter)

其中,N:样本总数

 4、定义GHMC Loss

 

代码实现

def _expand_binary_labels(labels, label_weights, label_channels):bin_labels = labels.new_full((labels.size(0), label_channels), 0)inds = torch.nonzero(labels >= 1).squeeze()if inds.numel() > 0:bin_labels[inds, labels[inds] - 1] = 1bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)return bin_labels, bin_label_weightsclass GHMC(nn.Module):def __init__(self,bins=10,momentum=0,use_sigmoid=True,loss_weight=1.0):super(GHMC, self).__init__()self.bins = binsself.momentum = momentumself.edges = [float(x) / bins for x in range(bins+1)]self.edges[-1] += 1e-6if momentum > 0:self.acc_sum = [0.0 for _ in range(bins)]self.use_sigmoid = use_sigmoidself.loss_weight = loss_weightdef forward(self, pred, target, label_weight, *args, **kwargs):""" Args:pred [batch_num, class_num]:The direct prediction of classification fc layer.target [batch_num, class_num]:Binary class target for each sample.label_weight [batch_num, class_num]:the value is 1 if the sample is valid and 0 if ignored."""if not self.use_sigmoid:raise NotImplementedError# the target should be binary class labelif pred.dim() != target.dim():target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1))target, label_weight = target.float(), label_weight.float()edges = self.edgesmmt = self.momentumweights = torch.zeros_like(pred)# 计算梯度模长g = torch.abs(pred.sigmoid().detach() - target)valid = label_weight > 0tot = max(valid.float().sum().item(), 1.0)# 设置有效区间个数n = 0for i in range(self.bins):inds = (g >= edges[i]) & (g < edges[i+1]) & validnum_in_bin = inds.sum().item()if num_in_bin > 0:if mmt > 0:self.acc_sum[i] = mmt * self.acc_sum[i] \+ (1 - mmt) * num_in_binweights[inds] = tot / self.acc_sum[i]else:weights[inds] = tot / num_in_binn += 1if n > 0:weights = weights / nloss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / totreturn loss * self.loss_weight

步骤一、将梯度模长划分为bins(默认为10)个区域,

self.edges = [float(x) / bins for x in range(bins+1)]
"""
[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000, 1.0000]
"""

步骤二、计算梯度模长

g = torch.abs(pred.sigmoid().detach() - target)

步骤三、计算落入不同bin区间的梯度模长数量

valid = label_weight > 0
tot = max(valid.float().sum().item(), 1.0)
n = 0
for i in range(self.bins):inds = (g >= edges[i]) & (g < edges[i+1]) & validnum_in_bin = inds.sum().item()if num_in_bin > 0:if mmt > 0:self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_binweights[inds] = tot / self.acc_sum[i]else:weights[inds] = tot / num_in_binn += 1
if n > 0:weights = weights / n

步骤四、计算GHMC Loss

loss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / tot * self.loss_weight

【参考文章】

Focal Loss的理解以及在多分类任务上的使用(Pytorch)_focal loss 多分类_GHZhao_GIS_RS的博客-CSDN博客

focal loss 通俗讲解 - 知乎

Focal Loss损失函数(超级详细的解读)_BigHao688的博客-CSDN博客

5分钟理解Focal Loss与GHM——解决样本不平衡利器 - 知乎 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/51.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Zookeeper 分布式锁案例

一、概述&#xff1a; Zookeeper 是一个开源的分布式协调服务&#xff0c;可以用于维护分布式系统中的一致性、顺序性和命名等。其中&#xff0c;Zookeeper 的分布式锁机制可以用于实现分布式系统中的互斥访问&#xff0c;确保在多个节点上对共享资源进行同步访问。 Zookeepe…

深入理解预训练(pre-learning)、微调(fine-tuning)、迁移学习(transfer learning)三者的联系与区别

1. 什么是预训练和微调 你需要搭建一个网络模型来完成一个特定的图像分类的任务。首先&#xff0c;你需要随机初始化参数&#xff0c;然后开始训练网络&#xff0c;不断调整参数&#xff0c;直到网络的损失越来越小。在训练的过程中&#xff0c;一开始初始化的参数会不断变化。…

小程序:页面跳转闪屏

自己的笔记&#xff0c;随手记录。扛精走开。 1、问题描述 进入页面&#xff0c;是一个组件&#xff0c;通过路由传参判断是由哪个页面进入&#xff0c;不同的页面拿的已选值不一样&#xff0c;需要回显值&#xff0c;在编辑数据。此时会出现一个问题&#xff0c;A页面中进来…

Serverless和EDA是绝配,亚马逊云科技CTO Werner表示需要用开放心态来重新审视架构

前一段有个很火的博客&#xff0c;讲的是一家全球流媒体企业的监测系统从Serverless微服务改成了单体&#xff0c;成本居然降低了90%&#xff01;这一下子可在网上炸锅了&#xff0c;特别是一些看不惯微服务的、单体应用的拥趸&#xff0c;更是坐不住了。但这并不像吃瓜群众看到…

0基础学C#笔记01:数组复制

文章目录 前言一、错误的复制方式二、正确的复制方式1.复制方式 一2.复制方式 二3.复制方式 三4.复制方式 四三、验证总结前言 C# 数组复制的四种方式,简单明了实用 我们先创建一个数组: string[] names = {"张三", "李四",

【ECharts系列】ECharts 图表渲染问题解决方案

1 问题描述 echats 渲染&#xff0c;第一次的时候只出现Y轴数值&#xff0c;不出现X轴数值&#xff0c;切换下页面&#xff0c;X轴数值就能出现。 2 原因分析 如果在使用ECharts渲染时&#xff0c;X轴数值只在切换页面后才出现&#xff0c;可能是因为ECharts在初始化时没有正确…

光速吟唱,Clibor ,批量多次复制依次粘贴工具 快捷输入软件教程

批量多次复制依次粘贴工具 批量复制粘贴工具0.81.exe https://www.aliyundrive.com/s/3sbBaGmHkb8 点击链接保存&#xff0c;或者复制本段内容&#xff0c;打开「阿里云盘」APP &#xff0c;无需下载极速在线查看&#xff0c;视频原画倍速播放。 青县solidworks钣金设计培训 …

Redis进阶 - Redis哨兵

原文首更地址&#xff0c;阅读效果更佳&#xff01; Redis进阶 - Redis哨兵 | CoderMast编程桅杆https://www.codermast.com/database/redis/redis-advance-sentinel.html 思考 slave 节点宕机恢复以后可以找 master 节点同步数据&#xff0c;那么 master 节点宕机怎么办&am…

MFC 利用多态的特性实现子窗口同时存在一个

多个子窗口的类都继承同一父类 CDialogEx 于是在主窗口我声明一个CDialogEx指针 通过判断该指针是否为空 不为空则视为有一子窗口存在 注意这里介绍的是 非模态化窗口的关闭 你可以在任何时候调用DestroyWindow()以达到彻底销毁自身对象的作用。&#xff08;DestroyWindow()的…

[Docker] Docker镜像管理和操作实践(二) 文末送书

前言&#xff1a; Docker镜像是容器化应用程序的打包和分发单元&#xff0c;包含了应用程序及其所有依赖项&#xff0c;实现了应用程序的可移植性和一致性。 文章目录 使用Dockerfile创建自定义镜像实践练手1. 创建基于ubuntu的自定义镜像&#xff0c;并安装nginx2. 配置Redis容…

MySQL数据库高级查询语句

MySQL数据库高级查询语句 一、语句SELECT ----显示表格中一个或数个字段的所有数据记录DISTINCT ----不显示重复的数据记录WHERE ----有条件查询AND OR ----且 或IN ----显示已知的值的数据记录BETWEEN ----显示两个值范围内的数据记录通配符 ----通常通配符都是跟 LIKE 一起使…

【云原生】二进制k8s集群(下)部署高可用master节点

本次部署说明 在上一篇文章中&#xff0c;就已经完成了二进制k8s集群部署的搭建&#xff0c;但是单机master并不适用于企业的实际运用&#xff08;因为单机master中&#xff0c;仅仅只有一台master作为节点服务器的调度指挥&#xff0c;一旦宕机。就意味着整个集群的瘫痪&#…

input.unsqueeze(0)的作用

input.unsqueeze(0) 是 PyTorch 张量&#xff08;Tensor&#xff09;的方法之一&#xff0c;用于增加张量的维度。具体来说&#xff0c;它会在索引为 0 的位置上插入一个维度。 假设 input 是一个形状为 (n,) 的一维张量&#xff0c;其中 n 是任意长度。调用 unsqueeze(0) 后&…

【Web3】 Web3JS Pay Api

Web3Network.eth.sendSignedTransaction(serializedTx) 参数&#xff1a; from- String|Number&#xff1a;发送帐户的地址。如果未指定&#xff0c;则使用web3.eth.defaultAccount属性。或web3.eth.accounts.wallet中本地的地址。 to- String:(可选&#xff09;消息的目标地址…

宝塔Linux面板安装Composer依赖管理工具与PHP依赖包的方法

最近看见腾讯云有一个AI绘画还挺有意思&#xff0c;想搞来写个接口玩 但是Composer一直运行不成功 提示xdebug什么的 最后经过搜索 发现 需要删除你宝塔里所有php中禁用的putenv函数 然后重启php就可以了&#xff01; 然后就可以运行这个命令了 出现这种情况 还需要删除所有…

Linux常用命令——emacs命令

在线Linux命令查询工具 emacs 功能强大的全屏文本编辑器 补充说明 emacs命令是由GNU组织的创始人Richard Stallman开发的一个功能强大的全屏文本编辑器&#xff0c;它支持多种编程语言&#xff0c;具有很多优良的特性。有众多的系统管理员和软件开发者使用emacs。 语法 e…

ubuntu20.04上linux内核开发环境搭建(qemu+gdb+vscode)

qemugdbvscode环境搭建 1. 环境准备1.1 安装基础软件1.2 安装开发软件1.3 安装qemu 2. 制作rootfs2.1 buildroot制作rootfs2.2 busybox制作rootfs 3. kernel编译及运行3.1 编译内核3.2 启动内核3.3 qemu参数说明 4. vscode 图形化调试环境配置4.1 vscode配置4.2 qemu配置 1. 环…

服务器技术(三)--Nginx

Nginx介绍 Nginx是什么、适用场景 Nginx是一个高性能的HTTP和反向代理服务器&#xff0c;特点是占有内存少&#xff0c;并发能力强&#xff0c;事实上nginx的并发能力确实在同类型的网页服务器中表现较好。 Nginx专为性能优化而开发&#xff0c;性能是其最重要的考量&#xf…

【Nginx】rewrite简单使用

前言 没有对正式的rewrite进行了解&#xff0c;为了能快速了解它是干嘛怎么用&#xff0c;找了一些有例子的博客进行简单学习了一下&#xff1b;由于每次看的间隔有点大&#xff0c;老忘记&#xff0c;这回专门写个超级快速理解的例子。 PS&#xff1a;下面的解释可能会不太对…

SMALE周报_20230714

目录标题 1. 上周回顾2. 本周计划3. 完成情况4. 存在的主要问题5. 下一步工作 1. 上周回顾 总结不确定性在神经网络中的运用。跳转链接 2. 本周计划 通过阅读论文《Semi-Supervised Deep Regression with Uncertainty Consistency and Variational Model Ensembling》&#…