常用分类损失CE Loss、Focal Loss及GHMC Loss理解与总结

一、CE Loss

定义

交叉熵损失(Cross-Entropy Loss,CE Loss)能够衡量同一个随机变量中的两个不同概率分布的差异程度,当两个概率分布越接近时,交叉熵损失越小,表示模型预测结果越准确。

公式

二分类

二分类的CE Loss公式如下,

其中,M:正样本数量,N:负样本数量,y_{i}:真实值, p_{i}:预测值

多分类

在计算多分类的CE Loss时,首先需要对模型输出结果进行softmax处理。公式如下,

其中, output:模型输出,p:对模型输出进行softmax处理后的值, ​​​​​:真实值的one hot编码​(假设模型在做5分类,如果y_{i}=2,则=[0,0,1,0,0])

代码实现

二分类

import torch
import torch.nn as nn
import mathcriterion = nn.BCELoss()
output = torch.rand(1, requires_grad=True)
label = torch.randint(0, 1, (1,)).float()
loss = criterion(output, label)print("预测值:", output)
print("真实值:", label)
print("nn.BCELoss:", loss)for i in range(label.shape[0]):if label[i] == 0:res = -math.log(1-output[i])elif label[i] == 1:res = -math.log(output[i])
print("自己的计算结果", res)"""
预测值: tensor([0.7359], requires_grad=True)
真实值: tensor([0.])
nn.BCELoss: tensor(1.3315, grad_fn=<BinaryCrossEntropyBackward0>)
自己的计算结果 1.331509556677378
"""

多分类

import torch
import torch.nn as nn
import mathcriterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)print("预测值:", output)
print("真实值:", label)
print("nn.CrossEntropyLoss:", loss)output = torch.softmax(output, dim=1)
print("softmax后的预测值:", output)one_hot = torch.zeros_like(output).scatter_(1, label.view(-1, 1), 1)
print("真实值对应的one_hot编码", one_hot)res = (-torch.log(output) * one_hot).sum()
print("自己的计算结果", res)"""
预测值: tensor([[-0.7459, -0.3963, -1.8046,  0.6815,  0.2965]], requires_grad=True)
真实值: tensor([1])
nn.CrossEntropyLoss: tensor(1.9296, grad_fn=<NllLossBackward0>)
softmax后的预测值: tensor([[0.1024, 0.1452, 0.0355, 0.4266, 0.2903]], grad_fn=<SoftmaxBackward0>)
真实值对应的one_hot编码 tensor([[0., 1., 0., 0., 0.]])
自己的计算结果 tensor(1.9296, grad_fn=<SumBackward0>)
"""

二、Focal Loss

定义

虽然CE Loss能够衡量同一个随机变量中的两个不同概率分布的差异程度,但无法解决以下两个问题:1、正负样本数量不平衡的问题(如centernet的分类分支,它只将目标的中心点作为正样本,而把特征图上的其它像素点作为负样本,可想而知正负样本的数量差距之大);2、无法区分难易样本的问题(易分类的样本的分类错误的损失占了整体损失的绝大部分,并主导梯度)

为了解决以上问题,Focal Loss在CE Loss的基础上改进,引入了:1、正负样本数量调节因子以解决正负样本数量不平衡的问题;2、难易样本分类调节因子以聚焦难分类的样本

公式

二分类

公式如下,

 

​​​​​​​

其中,\alpha:正负样本数量调节因子,\gamma:难易样本分类调节因子

多分类

其中,\alpha _{y_{i}}y_{i}类别的权重

代码实现

二分类

def sigmoid_focal_loss(inputs: torch.Tensor,targets: torch.Tensor,alpha: float = -1,gamma: float = 2,reduction: str = "none",
) -> torch.Tensor:"""Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.Args:inputs: A float tensor of arbitrary shape.The predictions for each example.targets: A float tensor with the same shape as inputs. Stores the binaryclassification label for each element in inputs(0 for the negative class and 1 for the positive class).alpha: (optional) Weighting factor in range (0,1) to balancepositive vs negative examples. Default = -1 (no weighting).gamma: Exponent of the modulating factor (1 - p_t) tobalance easy vs hard examples.reduction: 'none' | 'mean' | 'sum''none': No reduction will be applied to the output.'mean': The output will be averaged.'sum': The output will be summed.Returns:Loss tensor with the reduction option applied."""inputs = inputs.float()targets = targets.float()p = torch.sigmoid(inputs)ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")p_t = p * targets + (1 - p) * (1 - targets)loss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * targets + (1 - alpha) * (1 - targets)loss = alpha_t * lossif reduction == "mean":loss = loss.mean()elif reduction == "sum":loss = loss.sum()return loss

步骤1、首先对输入进行sigmoid处理,

p = torch.sigmoid(inputs)

步骤2、随后求出CE Loss,

ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

步骤3、定义p_{t}^{i},公式为:

p_t = p * targets + (1 - p) * (1 - targets)

步骤4、为CE Loss添加难易样本分类调节因子,

loss = ce_loss * ((1 - p_t) ** gamma)

步骤5、定义\alpha _{t}^{i},公式为:

alpha_t = alpha * targets + (1 - alpha) * (1 - targets)

步骤6、为步骤4的损失添加正负样本数量调节因子,

loss = alpha_t * loss

多分类

def multi_cls_focal_loss(inputs: torch.Tensor,targets: torch.Tensor,alpha: torch.Tensor,gamma: float = 2,reduction: str = "none",
) -> torch.Tensor:inputs = inputs.float()targets = targets.float()ce_loss = nn.CrossEntropyLoss()(inputs, targets, reduction="none")one_hot = torch.zeros_like(inputs).scatter_(1, targets.view(-1, 1), 1)p_t = inputs * one_hotloss = ce_loss * ((1 - p_t) ** gamma)if alpha >= 0:alpha_t = alpha * one_hotloss = alpha_t * lossreturn loss

三、GHMC Loss

定义

Focal Loss在CE Loss的基础上改进后,解决了正负样本不平衡以及无法区分难易样本的问题,但也会过分关注难分类的样本(离群点),导致模型学歪。为了解决这个问题,GHMC(Gradient Harmonizing Mechanism-C)定义了梯度模长,该梯度模长正比于分类的难易程度,目的是让模型不要关注那些容易学的样本,也不要关注那些特别难分的样本

公式

1、定义梯度模长

二分类的CE Loss公式如下,

假设x是模型的输出,假设p=sigmoid(x),求损失对x的偏导,

因此,定义梯度模长如下,

其中, p:预测值,p^{\ast }:真实值

梯度模长与样本数量的关系如下,

2、定义梯度密度(单位梯度模长g上的样本数量

  

其中,g_{k}:第k个样本的梯度模长,\delta _{\varepsilon }(g_{k},g)g_{k}(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})范围内的样本数量,l_{\varepsilon }(g):区间(g-\frac{\varepsilon }{2},g+\frac{\varepsilon }{2})的长度

3、定义梯度密度协调参数(gradient density harmonizing parameter)

其中,N:样本总数

 4、定义GHMC Loss

 

代码实现

def _expand_binary_labels(labels, label_weights, label_channels):bin_labels = labels.new_full((labels.size(0), label_channels), 0)inds = torch.nonzero(labels >= 1).squeeze()if inds.numel() > 0:bin_labels[inds, labels[inds] - 1] = 1bin_label_weights = label_weights.view(-1, 1).expand(label_weights.size(0), label_channels)return bin_labels, bin_label_weightsclass GHMC(nn.Module):def __init__(self,bins=10,momentum=0,use_sigmoid=True,loss_weight=1.0):super(GHMC, self).__init__()self.bins = binsself.momentum = momentumself.edges = [float(x) / bins for x in range(bins+1)]self.edges[-1] += 1e-6if momentum > 0:self.acc_sum = [0.0 for _ in range(bins)]self.use_sigmoid = use_sigmoidself.loss_weight = loss_weightdef forward(self, pred, target, label_weight, *args, **kwargs):""" Args:pred [batch_num, class_num]:The direct prediction of classification fc layer.target [batch_num, class_num]:Binary class target for each sample.label_weight [batch_num, class_num]:the value is 1 if the sample is valid and 0 if ignored."""if not self.use_sigmoid:raise NotImplementedError# the target should be binary class labelif pred.dim() != target.dim():target, label_weight = _expand_binary_labels(target, label_weight, pred.size(-1))target, label_weight = target.float(), label_weight.float()edges = self.edgesmmt = self.momentumweights = torch.zeros_like(pred)# 计算梯度模长g = torch.abs(pred.sigmoid().detach() - target)valid = label_weight > 0tot = max(valid.float().sum().item(), 1.0)# 设置有效区间个数n = 0for i in range(self.bins):inds = (g >= edges[i]) & (g < edges[i+1]) & validnum_in_bin = inds.sum().item()if num_in_bin > 0:if mmt > 0:self.acc_sum[i] = mmt * self.acc_sum[i] \+ (1 - mmt) * num_in_binweights[inds] = tot / self.acc_sum[i]else:weights[inds] = tot / num_in_binn += 1if n > 0:weights = weights / nloss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / totreturn loss * self.loss_weight

步骤一、将梯度模长划分为bins(默认为10)个区域,

self.edges = [float(x) / bins for x in range(bins+1)]
"""
[0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 0.5000, 0.6000, 0.7000, 0.8000, 0.9000, 1.0000]
"""

步骤二、计算梯度模长

g = torch.abs(pred.sigmoid().detach() - target)

步骤三、计算落入不同bin区间的梯度模长数量

valid = label_weight > 0
tot = max(valid.float().sum().item(), 1.0)
n = 0
for i in range(self.bins):inds = (g >= edges[i]) & (g < edges[i+1]) & validnum_in_bin = inds.sum().item()if num_in_bin > 0:if mmt > 0:self.acc_sum[i] = mmt * self.acc_sum[i] + (1 - mmt) * num_in_binweights[inds] = tot / self.acc_sum[i]else:weights[inds] = tot / num_in_binn += 1
if n > 0:weights = weights / n

步骤四、计算GHMC Loss

loss = F.binary_cross_entropy_with_logits(pred, target, weights, reduction='sum') / tot * self.loss_weight

【参考文章】

Focal Loss的理解以及在多分类任务上的使用(Pytorch)_focal loss 多分类_GHZhao_GIS_RS的博客-CSDN博客

focal loss 通俗讲解 - 知乎

Focal Loss损失函数(超级详细的解读)_BigHao688的博客-CSDN博客

5分钟理解Focal Loss与GHM——解决样本不平衡利器 - 知乎 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/51.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解预训练(pre-learning)、微调(fine-tuning)、迁移学习(transfer learning)三者的联系与区别

1. 什么是预训练和微调 你需要搭建一个网络模型来完成一个特定的图像分类的任务。首先&#xff0c;你需要随机初始化参数&#xff0c;然后开始训练网络&#xff0c;不断调整参数&#xff0c;直到网络的损失越来越小。在训练的过程中&#xff0c;一开始初始化的参数会不断变化。…

小程序:页面跳转闪屏

自己的笔记&#xff0c;随手记录。扛精走开。 1、问题描述 进入页面&#xff0c;是一个组件&#xff0c;通过路由传参判断是由哪个页面进入&#xff0c;不同的页面拿的已选值不一样&#xff0c;需要回显值&#xff0c;在编辑数据。此时会出现一个问题&#xff0c;A页面中进来…

Serverless和EDA是绝配,亚马逊云科技CTO Werner表示需要用开放心态来重新审视架构

前一段有个很火的博客&#xff0c;讲的是一家全球流媒体企业的监测系统从Serverless微服务改成了单体&#xff0c;成本居然降低了90%&#xff01;这一下子可在网上炸锅了&#xff0c;特别是一些看不惯微服务的、单体应用的拥趸&#xff0c;更是坐不住了。但这并不像吃瓜群众看到…

【ECharts系列】ECharts 图表渲染问题解决方案

1 问题描述 echats 渲染&#xff0c;第一次的时候只出现Y轴数值&#xff0c;不出现X轴数值&#xff0c;切换下页面&#xff0c;X轴数值就能出现。 2 原因分析 如果在使用ECharts渲染时&#xff0c;X轴数值只在切换页面后才出现&#xff0c;可能是因为ECharts在初始化时没有正确…

光速吟唱,Clibor ,批量多次复制依次粘贴工具 快捷输入软件教程

批量多次复制依次粘贴工具 批量复制粘贴工具0.81.exe https://www.aliyundrive.com/s/3sbBaGmHkb8 点击链接保存&#xff0c;或者复制本段内容&#xff0c;打开「阿里云盘」APP &#xff0c;无需下载极速在线查看&#xff0c;视频原画倍速播放。 青县solidworks钣金设计培训 …

Redis进阶 - Redis哨兵

原文首更地址&#xff0c;阅读效果更佳&#xff01; Redis进阶 - Redis哨兵 | CoderMast编程桅杆https://www.codermast.com/database/redis/redis-advance-sentinel.html 思考 slave 节点宕机恢复以后可以找 master 节点同步数据&#xff0c;那么 master 节点宕机怎么办&am…

[Docker] Docker镜像管理和操作实践(二) 文末送书

前言&#xff1a; Docker镜像是容器化应用程序的打包和分发单元&#xff0c;包含了应用程序及其所有依赖项&#xff0c;实现了应用程序的可移植性和一致性。 文章目录 使用Dockerfile创建自定义镜像实践练手1. 创建基于ubuntu的自定义镜像&#xff0c;并安装nginx2. 配置Redis容…

MySQL数据库高级查询语句

MySQL数据库高级查询语句 一、语句SELECT ----显示表格中一个或数个字段的所有数据记录DISTINCT ----不显示重复的数据记录WHERE ----有条件查询AND OR ----且 或IN ----显示已知的值的数据记录BETWEEN ----显示两个值范围内的数据记录通配符 ----通常通配符都是跟 LIKE 一起使…

【云原生】二进制k8s集群(下)部署高可用master节点

本次部署说明 在上一篇文章中&#xff0c;就已经完成了二进制k8s集群部署的搭建&#xff0c;但是单机master并不适用于企业的实际运用&#xff08;因为单机master中&#xff0c;仅仅只有一台master作为节点服务器的调度指挥&#xff0c;一旦宕机。就意味着整个集群的瘫痪&#…

宝塔Linux面板安装Composer依赖管理工具与PHP依赖包的方法

最近看见腾讯云有一个AI绘画还挺有意思&#xff0c;想搞来写个接口玩 但是Composer一直运行不成功 提示xdebug什么的 最后经过搜索 发现 需要删除你宝塔里所有php中禁用的putenv函数 然后重启php就可以了&#xff01; 然后就可以运行这个命令了 出现这种情况 还需要删除所有…

Linux常用命令——emacs命令

在线Linux命令查询工具 emacs 功能强大的全屏文本编辑器 补充说明 emacs命令是由GNU组织的创始人Richard Stallman开发的一个功能强大的全屏文本编辑器&#xff0c;它支持多种编程语言&#xff0c;具有很多优良的特性。有众多的系统管理员和软件开发者使用emacs。 语法 e…

服务器技术(三)--Nginx

Nginx介绍 Nginx是什么、适用场景 Nginx是一个高性能的HTTP和反向代理服务器&#xff0c;特点是占有内存少&#xff0c;并发能力强&#xff0c;事实上nginx的并发能力确实在同类型的网页服务器中表现较好。 Nginx专为性能优化而开发&#xff0c;性能是其最重要的考量&#xf…

【Nginx】rewrite简单使用

前言 没有对正式的rewrite进行了解&#xff0c;为了能快速了解它是干嘛怎么用&#xff0c;找了一些有例子的博客进行简单学习了一下&#xff1b;由于每次看的间隔有点大&#xff0c;老忘记&#xff0c;这回专门写个超级快速理解的例子。 PS&#xff1a;下面的解释可能会不太对…

LAXCUS分布式操作系统存在的意义和价值

总有一些新用户不能理解LAXCUS分布式操作系统&#xff0c;以及它存在的意义和价值&#xff0c;我这样说吧。 下图是一个图形桌面&#xff08;LAXCUS的图形桌面&#xff0c;不是Windows、也不是Macintosh&#xff09;&#xff0c;在它后面&#xff0c;连着一个计算机集群&#…

数据库技术与应用——目录篇

数据库技术与应用目录 文章目录 第1章 数据库基础知识数据库技术的概念数据管理的发展数据库的体系结构数据库管理系统常用的数据库管理系统介绍 第2章 信息得三种世界与数据模型信息的三种世界及其描述数据模型 第3章 关系模型关系模型的由来关系数据库的结构关系代数关系演算…

Linux命令----modprobe命令详解

【原文链接】Linux命令----modprobe命令详解 一、modprobe命令的作用 加载内核模块&#xff1a; 使用modprobe命令可以加载指定的内核模块到运行中的内核中。加载内核模块可以在运行时添加新的功能、驱动程序或修改内核行为。 解决模块依赖关系&#xff1a; modprobe命令可以…

基于OpenCV的人脸对齐步骤详解及源码实现

目录 1. 前言2. 人脸对齐基本原理与步骤3. 人脸对齐代码实现 1. 前言 在做人脸识别的时候&#xff0c;前期的数据处理过程通常会遇到一个问题&#xff0c;需要将各种人脸从不同尺寸的图像中截取出来&#xff0c;再进行人脸对齐操作&#xff1a;即将人脸截取出来并将倾斜的人脸…

【图像处理】经营您的第一个U-Net以进行图像分割

一、说明 AI厨师们&#xff0c;今天您将学习如何准备计算机视觉中最重要的食谱之一&#xff1a;U-Net。本文将叙述&#xff1a;1语义与实例分割&#xff0c;2 图像分割中还使用了其他损失&#xff0c;例如Jaccard损失&#xff0c;焦点损失&#xff1b;3 如果2D图像分割对您来说…

spring 详解二 IOC(Bean xml配置及DI)

配置列表 Xml配置 功能描述 <bean id"" class""></bean> Bean的id&#xff0c;配置id会转为Bean名称和不配就是全限定类名 <bean name"" ></bean> Bean的别名配置&#xff0c;存储在Factory的aliasMap中通过别名也…

【SpringBoot】SpringBoot的自动配置源码解析

文章目录 1. SpringBoot的自动配置概念2. SpringBoot自动配置的原理3. EnableAutoConfiguration4. 常用的Conditional注解 1. SpringBoot的自动配置概念 SpringBoot相对于SSM来说&#xff0c;主要的优点就是简化了配置&#xff0c;不再需要像SSM哪有写一堆的XML配置&#xff0…