LMFLOSS:专治解决不平衡医学图像分类的新型混合损失函数 (附代码)

 

论文地址:https://arxiv.org/pdf/2212.12741.pdf

代码地址:https://github.com/SanaNazari/LMFLoss

1.是什么?

LMFLOSS是一种用于不平衡医学图像分类的混合损失函数。它是由Focal Loss和LDAM Loss的线性组合构成的,旨在更好地处理不平衡数据集。Focal Loss通过强调难以分类的样本来提高模型的性能,而LDAM Loss则考虑了数据集的类别分布来调整权重。

2.为什么?

先来简单回顾下,对于类别不均衡问题,以往的方法是如何解决的。大体上主要有两种,即以数据为中心驱动和以算法为中心的解决方案。

数据策略

以数据为中心的类别不均衡解决方法主要有两种:过采样欠采样。过采样试图为少数类别生成人工数据点,而欠采样旨在消除多数类别的样本。

算法策略

算法层面的策略,特别是在深度学习领域,主要侧重于开发损失函数来应对类不平衡问题,而不是直接操纵数据。一种简单的方式便是为每个类别都设置相应的权重,以便与多数类别相比,少数类别样本的错误分类受到更严重的惩罚。另一种方法是为每个训练样本自适应地设置一个唯一的权重,以便硬样本获得更高的权重。

作者便提出了一种称为 Large Margin aware Focal (LMF) Loss 的新型损失函数,以缓解医学成像中的类不平衡问题。该损失函数动态地同时考虑硬样本和类分布。

3.怎么样

3.1 Focal Loss

说到类别不均衡的损失函数,不得不提的便是 Focal Loss。对于分类问题,大家常用的便是交叉熵损失 BCE Loss,该损失函数对所有类别均一视同仁,即赋予同等的权重学习。而 Focal Loss 主要就是交叉熵损失改进的,通过引入 \alpha 和 \gamma 两个调节因子来调整样本数量和样本难易程度,以便模型专注于学习少数类。具体公式如下:

3.2 LDAM Loss

《 Learning imbalanced datasets with label-distribution-aware margin loss 》 这篇文章中提出了另一项减轻类不平衡问题的工作,称为标签分布感知边距(LDAM)损失。作者建议对少数类引入比多数类更强的正则化,以减少它们的泛化误差。如此一来,损失函数保持了模型学习多数类并强调少数类的能力。LDAM 损失侧重于每个类的最小边际和获得每个类和统一标签测试错误,而不是鼓励大多数类训练样本与决策边界的大边距。换句话说,它只会鼓励少数群体获得相对较大的利润。此外,作者提出了用于获得多个类别 1、2、...、k 的类别相关边距的公式: \gamma _{j} = \frac{C}{n_{j}1/4^{}}.

这里 j∈1,...,k 表示特定类,n_{j}表示每个类别的样本数,C为固定的常数。现在,让我们定义出一个样本对 (x,y),x 为样本,y为对应的标签,同时给定一个模型 f。考虑下面这个函数映射:x_{y}=f(x)_{y};我们令 u=e^{z_{y}-p_{y}},这里对于每一个类别j∈1,...,k 都有 p_{j}=\frac{C}{n_{j}^{1/4}}。因此,LDAM 损失便可以定义为:

3.3 LMF Loss

Focal Loss 创建了一种机制,可以更加强调模型难以分类的样本;通常,来自少数群体的样本将属于这一类。相比之下,LDAM Loss 通过考虑数据集的类别分布来判断权重。我们假设与单独使用每个功能相比,同时利用这两个功能可以产生有效的结果。因此,作者提出的 Large Margin aware Focal (LMF) 损失是 Focal 损失和由两个超参数加权的 LDAM 的线性组合,公式如下:

这里,α 和 β 是常数,被认为是可以调整的超参数。 因此,本文提出的损失函数在单个框架中联合优化了两个独立的损失函数。通过反复试验,作者发现将相同的权重分配给两个组件会产生良好的结果。

3.4 代码实现

# -*- coding: utf-8 -*-
"""
Created on Wed May 24 17:03:06 2023@author: Sana
"""import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from ..builder import LOSSESclass FocalLoss(nn.Module):def __init__(self, alpha, gamma=2):super().__init__()self.alpha = alphaself.gamma = gammadef forward(self, output, target):num_classes = output.size(1)assert len(self.alpha) == num_classes, \'Length of weight tensor must match the number of classes'logp = F.cross_entropy(output, target, self.alpha)p = torch.exp(-logp)focal_loss = (1 - p) ** self.gamma * logpreturn torch.mean(focal_loss)class LDAMLoss(nn.Module):def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):"""max_m: The appropriate value for max_m depends on the specific dataset and the severity of the class imbalance.You can start with a small value and gradually increase it to observe the impact on the model's performance.If the model struggles with class separation or experiences underfitting, increasing max_m might help. However,be cautious not to set it too high, as it can cause overfitting or make the model too conservative.s: The choice of s depends on the desired scale of the logits and the specific requirements of your problem.It can be used to adjust the balance between the margin and the original logits. A larger s value amplifiesthe impact of the logits and can be useful when dealing with highly imbalanced datasets.You can experiment with different values of s to find the one that works best for your dataset and model."""super(LDAMLoss, self).__init__()m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))m_list = m_list * (max_m / np.max(m_list))m_list = torch.cuda.FloatTensor(m_list)self.m_list = m_listassert s > 0self.s = sself.weight = weightdef forward(self, x, target):index = torch.zeros_like(x, dtype=torch.uint8)index.scatter_(1, target.data.view(-1, 1), 1)index_float = index.type(torch.cuda.FloatTensor)batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1))batch_m = batch_m.view((-1, 1))x_m = x - batch_moutput = torch.where(index, x_m, x)return F.cross_entropy(self.s * output, target, weight=self.weight)@LOSSES.register_module()
class LMFLoss(nn.Module):def __init__(self, cls_num_list, weight, alpha=1, beta=1, gamma=2, max_m=0.5, s=30):super().__init__()self.focal_loss = FocalLoss(weight, gamma)self.ldam_loss = LDAMLoss(cls_num_list, max_m, weight, s)self.alpha = alphaself.beta = betadef forward(self, output, target):focal_loss_output = self.focal_loss(output, target)ldam_loss_output = self.ldam_loss(output, target)total_loss = self.alpha * focal_loss_output + self.beta * ldam_loss_outputreturn total_loss

参考:Focal Loss 后继之秀 | LMFLOSS:专治解决不平衡医学图像分类的新型混合损失函数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/121307.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringSecurity 认证实战

一. 项目数据准备 1.1 添加依赖 <dependencies><!--spring security--><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-security</artifactId></dependency><!--web起步依赖-…

语雀故障事件——P0级别事故启示录 发生肾么事了? 怎么回事?

前言 最近&#xff0c;阿里系的语雀出了一个大瓜&#xff0c;知名在线文档编辑与协同工具语雀发生故障&#xff0c;崩溃近10小时。。。。最后&#xff0c;官方发布了一则公告&#xff0c;我们一起来看看这篇公告&#xff0c;能不能有所启发。 目录 前言引出一、语雀P0故障回顾…

重复控制器的性能优化

前言 重复控制器在控制系统中是比较优秀的控制器&#xff0c;在整流逆变等周期性输入信号时&#xff0c;会有很好的跟随行&#xff0c;通常可以单独使用&#xff0c;也可以与其他补偿器串联并联使用。 这里我来分析一下重复控制器的重复控制器的应用工况以及其的优缺点。 分析…

Mybatis-Plus(企业实际开发应用)

一、Mybatis-Plus简介 MyBatis-Plus是MyBatis框架的一个增强工具&#xff0c;可以简化持久层代码开发MyBatis-Plus&#xff08;简称 MP&#xff09;是一个 MyBatis 的增强工具&#xff0c;在 MyBatis 的基础上只做增强不做改变&#xff0c;为简化开发、提高效率而生。 官网&a…

Python深度学习实战-基于class类搭建BP神经网络实现分类任务(附源码和实现效果)

实现功能 上篇文章介绍了用Squential搭建BP神经网络&#xff0c;Squential可以搭建出上层输出就是下层输入的顺序神经网络结构&#xff0c;无法搭出一些带有跳连的非顺序网络结构&#xff0c;这个时候我们可以选择类class搭建封装神经网络结构。 第一步&#xff1a;import ten…

基于情感词典的情感分析方法

计算用户情绪强弱性&#xff0c;对于每一个文本都可以得到一个情感分值&#xff0c;以情感分值的正负性表示情感极性&#xff0c;大于0为积极情绪&#xff0c;小于0反之&#xff0c;绝对值越大情绪越强烈。 基于情感词典的情感分析方法主要思路&#xff1a; 1、对文本进行分词…

影响光源的因素

影响光源的因素 对比度 1.对比度 均匀性 2.均匀性 色彩还原性 3.色彩还原性 其他因素&#xff1a; 4. 亮度 &#xff1a; 光源 亮度是光源选择时的重要参考&#xff0c;尽量选择亮度高的光源。 5. 鲁棒性 &#xff1a; 鲁棒性是指光源是否对部件的位置敏感度最小 。 6. 光…

不同设备的请求头信息UserAgent,Headers

一、电脑端 【设备名称】&#xff1a;电脑 Win10 【应用名称】&#xff1a;win10 Edge 【浏览器信息】&#xff1a;名称:(Chrome)&#xff1b;版本:(70.0) 【请求头信息】&#xff1a;Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Ch…

企业如何安全跨国传输30T文件数据

对于一些对数据敏感性比较高的企业&#xff0c;如IT企业和国企等&#xff0c;跨国数据传输是当今企业面临的一个重要挑战&#xff0c;尤其是当数据量达到30T这样的规模时&#xff0c;如何保证数据的速度、安全和合规性&#xff0c;就成为了企业必须考虑的问题。本文将从以下几个…

【Java题】输出基本数据类型的最大值和最小值,以及float和double的正无穷大值和负无穷大值

一&#xff1a;代码 public class Test {public static void main(String[] args) {//输出byte型的最大值与最小值System.out.println(Byte.MAX_VALUE);System.out.println(Byte.MIN_VALUE);//输出short型的最大值与最小值System.out.println(Short.MAX_VALUE);System.out.pri…

RTE(Runtime Environment)

RTE&#xff08;Runtime Environment&#xff09;是一个运行时环境&#xff0c;在这个环境里&#xff0c;你可以实现的功能是&#xff1a; 作为一个缓冲buffer给应用层和BSW层的接口&#xff08;例如COM&#xff09;用来存储数据&#xff0c;也就是说定义一个全局变量供上层和下…

Django实战项目-学习任务系统-任务管理

接着上期代码框架&#xff0c;开发第3个功能&#xff0c;任务管理&#xff0c;再增加一个学习任务表&#xff0c;用来记录发布的学习任务的标题和内容&#xff0c;预计完成天数&#xff0c;奖励积分和任务状态等信息。 第一步&#xff1a;编写第三个功能-任务管理 1&#xff0…

二、BurpSuite Decoder解码器

一、编码解码 解释&#xff1a;BurpSuite 可以用这个模块来轻松进行编码解码&#xff0c;下面是支持的类型 URL HTML Base64 ASCIIhex Hex Octal Binary Gzip 注意&#xff1a;特别注意的是URL编码&#xff0c;一般的在线网站都无法对比如‘abc’的文本编码&#xff0c;burps…

目标检测及锚框、IoU

1. 目标检测 物体检测&#xff08;目标检测&#xff09;是计算机视觉和数字图像处理的热门方向&#xff0c;意在判断一幅图像上是否存在感兴趣物体&#xff0c;并给出物体分类及位置等&#xff08;What and Where&#xff09;。本文主要进行物体检测研究背景、发展脉络、相关算…

禁止chrome浏览器更新方式

1、禁用更新服务 WinR调出运行&#xff0c;输入services.msc&#xff0c;进入服务。 在服务中有两个带有Google Update字样&#xff0c;双击打开后禁用&#xff0c;并把恢复选项设置为无操作。 2、删除计划任务 运行taskschd.msc&#xff0c;打开计划任务程序库&#xff0c;在…

SDRAM学习笔记(MT48LC16M16A2,w9812g6kh)

一、基本知识 SDRAM : 即同步动态随机存储器&#xff08;Synchronous Dynamic Random Access Memory&#xff09;, 同步是指其时钟频率与对应控制器&#xff08;CPU/FPGA&#xff09;的系统时钟频率相同&#xff0c;并且内部命令 的发送与数据传输都是以该时钟为基准&#xff…

【C#】LIMS实验室信息管理系统源码

一、系统概述 LIMS(Laboratory Information Management System)即实验室信息管理系统,是通过对样品检验流程、分析数据及报告、实验室资源和客户信息等要素的综合管理,按照标准化实验室管理规范,建立符合实验室业务流程的质量体系,实现实验室信息化管理。是实验室提高分析水平…

CSS 滚动驱动动画与 @keyframes 新语法

CSS 滚动驱动动画与 keyframes 在 CSS 滚动驱动动画相关的属性出来之后, keyframes 也迎来变化. 以前, keyframes 的值可以是 from, to, 或者百分数. 现在它多了一种属性的值 <timeline-range-name> <percentage> 建议先了解 animation-range 不然你会对 timeli…

[RISC-V]verilog

小明教IC-1天学会verilog(7)_哔哩哔哩_bilibili task不可综合&#xff0c;function可以综合