F.binary_cross_entropy、nn.BCELoss、nn.BCEWithLogitsLoss与F.kl_div函数详细解读

提示:有关loss损失函数详细解读,并附源码!!!

文章目录

  • 前言
  • 一、F.binary_cross_entropy()函数解读
    • 1.函数表达
    • 2.函数运用
  • 二、nn.BCELoss()函数解读
    • 1.函数表达
    • 2.函数运用
  • 三、nn.BCEWithLogitsLoss()函数解读
    • 1.函数表达
    • 2.函数运用(logit探索)
    • 3.函数运用(pred探索)
  • 四、F.kl_div()函数解读


前言

最近我在构建蒸馏相关模型,我重温了一下交叉熵相关内容,也使用pytorch相关函数接口调用,我将对F.binary_cross_entropy()、nn.BCELoss()与nn.BCEWithLogitsLoss()函数做一个说明,同时也简单介绍相对熵的蒸馏F.kl_div()函数做一个介绍。

一、F.binary_cross_entropy()函数解读

1.函数表达

F.binary_cross_entropy(input: Tensor,  # 预测输入target: Tensor, # 标签weight: Optional[Tensor] = None, # 权重可选项size_average: Optional[bool] = None,  # 可选项,快被弃用了reduce: Optional[bool] = None,reduction: str = "mean",  # 默认均值或求和等形式
) -> Tensor:

该函数实际是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def binary_cross_entropy():input = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# s = nn.Sigmoid()# pred = s(input)target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])weight = torch.tensor([[0.1, 0.9, 0.1],[0.1, 0.1, 0.9]])output_weight = F.binary_cross_entropy(input, target,weight=weight)  # input取值范围[0,1]output = F.binary_cross_entropy(input, target)  # input取值范围[0,1]print('预测数据:',input)print('标签数据:',target)print('\nbinary_cross_entropy-有权重:{}\t无权重:{}\n'.format(output_weight, output))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])binary_cross_entropy-有权重:0.12723307311534882	无权重:0.5912299752235413

二、nn.BCELoss()函数解读

1.函数表达

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction='mean')

参数说明:
weight :用于样本加权的权重张量。如果给定,则必须是一维张量,大小等于输入张量的大小。默认值为 None。
reduction :指定如何计算损失值。可选值为 ‘none’、‘mean’ 或 ‘sum’。默认值为 ‘mean’。

此为类,是对F.binary_cross_entropy()函数的调用,也是交叉熵运算方式,其中input、target与权重有相同维度(batch,),其中表示可以是任何维度。同时,input为模型预测其每个元素取值范围在[0,1]间。

2.函数运用

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bceloss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])# pred = s(pred)  # 一般会经过sigmoid或softmax方式将其预测转为[0,1]范围的值target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:',pred)print('标签数据:',target)print('\nbceloss:{}\n'.format(b))

结果如下:

预测数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.5912299752235413

可以看出该函数与上面无权重运行结果一致,实际是对上一个函数进行了类包装,其计算方式和上面函数完全一样。

三、nn.BCEWithLogitsLoss()函数解读

1.函数表达

torch.nn.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)

参数说明:
weight:用于对每个样本的损失值进行加权。默认值为 None。
reduction:指定如何对每个 batch 的损失值进行降维。可选值为 ‘none’、‘mean’ 和 ‘sum’。默认值为 ‘mean’。
pos_weight:用于对正样本的损失值进行加权。可以用于处理样本不平衡的问题。例如,如果正样本比负样本少很多,可以设置 pos_weight 为一个较大的值,以提高正样本的权重。默认值为 None。

2.函数运用(logit探索)

假设输入input经过sigmoid或softmax等方式将其值转为[0,1]范围预测,target为one-hot标签(也可是教师的软标签形式),其应用代码如下:

import torch
import torch.nn.functional as F
def bce_logit_loss():s = nn.Sigmoid()  # 输出是pred = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错pred = s(pred)# nn.BCELoss输入的pred与target的形状必须相同,实际是交叉熵计算,target没有限制bce = nn.BCELoss(reduction='mean')  # size_average参数将被遗弃,使用reduction决定后续操作,有mean sumb = bce(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print('\nbceloss:{}\t bce_with_logit:{} \n'.format(b, b_logit))

结果如下:

预测数据: tensor([[0.6225, 0.7311, 0.6900],[0.5498, 0.5987, 0.6457]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bceloss:0.7678468823432922	 bce_with_logit:0.7678468823432922

可以看出,nn.BCELoss只需多一个nn.Sigmoid()得到的结果和nn.BCEWithLogitsLoss是一致的,说明该类只是多了一个logit过程。

3.函数运用(pred探索)

import torch
import torch.nn.functional as F
def bce_logit_loss():pred = torch.tensor([[5, 1, 8.0], [2, 4, 6.0]])target = torch.tensor([[0, 1.0, 0], [0, 0, 1.0]])bce_logit = nn.BCEWithLogitsLoss(reduction='mean')b_logit = bce_logit(pred, target)  # pred元素取值范围是[0,1]之间,否则会报错print('预测数据:', pred)print('标签数据:', target)print(' bce_with_logit:{} \n'.format( b_logit))

结果如下:

预测数据: tensor([[5., 1., 8.],[2., 4., 6.]])
标签数据: tensor([[0., 1., 0.],[0., 0., 1.]])bce_with_logit:3.2446444034576416 

可以看出nn.BCEWithLogitsLoss的输入是可以为实数,它先进行sigmoid处理,将其输入变为[0,1]范围,在进行交叉熵运算,然上面nn.BCELoss与F.binary_cross_entropy则不行。

四、F.kl_div()函数解读

该函数为蒸馏模型使用的函数,我直接给出示列,如下:

def kl_func():logits = torch.tensor([[0.5, 1.0, 0.8], [0.2, 0.4, 0.6]])probs = torch.nn.functional.softmax(logits, dim=1)  # 预测学生模型target_probs = torch.tensor([[0.3, 0.4, 0.3], [0.1, 0.5, 0.4]])  # 教师模型loss = F.kl_div(torch.log(probs), target_probs, reduction='batchmean')print('模型输出数据:', logits)print('预测数据:',probs)print('标签数据:',target_probs)print('\nkl_loss:{}\n'.format(loss))

输出结果:

模型输出数据: tensor([[0.5000, 1.0000, 0.8000],[0.2000, 0.4000, 0.6000]])
预测数据: tensor([[0.2501, 0.4123, 0.3376],[0.2693, 0.3289, 0.4018]])
标签数据: tensor([[0.3000, 0.4000, 0.3000],[0.1000, 0.5000, 0.4000]])kl_loss:0.057796258479356766

参考文章:点击这里

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/144433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue dev-tools的安装

安装 Vue 开发者工具,装插件调试Vue应用 1.通过谷歌应用商店来进行安装(国外网站) 2.极简插件: 搜索 Vue -> 下载解压 -> 浏览器扩展模式打开,开发者模式 -> 将解压的CRX文件拖拽安装 -> 插件详情 &…

vscode Prettier配置

常用配置项: .prettierrc.json 是 Prettier 格式化工具的配置文件 {"printWidth": 200, // 指定行的最大长度"tabWidth": 2, // 指定缩进的空格数"useTabs": false, // 是否使用制表符进行缩进,默认为 false"singl…

华为Matebook X Pro 2022款 i7 集显(MRG-W76)原装出厂Windows11预装系统21H2

下载链接:https://pan.baidu.com/s/12ru9lUeQ7mWd5u1KLCM0Pg?pwdc7pi 提取码:c7pi 原厂系统自带指纹、面部识别、声卡、网卡、显卡等所有驱动、出厂主题壁纸、Office办公软件、华为电脑管家等预装程序,如图 由于时间关系,绝大部分资料没…

照亮夜晚的台灯:户外空间的闪亮之选

户外台灯是家庭和社交空间的重要元素,它们不仅提供照明,还可以为您的户外区域增添美感,以及创造一个温馨的社交氛围。以下是一些关于户外台灯的信息,以帮助您更好地了解它们的多功能性和用途。 1、照明的重要性:户外台…

工作中积累的对K8s的就绪和存活探针的一些认识

首先,我的项目是基于 Spring Boot 2.3.5 的,并依赖 spring-boot-starter-actuator 提供的 endpoints 来实现就绪和存活探针,POM 文件如下图: 下面,再让我们来看下与该项目对应的Deployment的YAML文件,如下…

ES的索引概念

1. 概念:Elasticsearch(ES)是一个开源的全文搜索引擎,可以快速地存储、搜索和分析大量的结构化和非结构化数据。 2. 索引的作用:ES索引是将数据存储在Elasticsearch中的基本方式。它用于存储、搜索、分析和查询数据。…

Mac代码文本编辑器Sublime Text 4

Sublime Text 4 for Mac拥有快速响应的功能,可以快速加载文件和执行命令,并提供多种语言支持,包括C 、Java、Python、HTML、CSS等。此外,该编辑器还支持LaTeX、Markdown、JSON、XML等技术领域。 Sublime Text 4 for Mac的插件丰富…

【纯干货】医疗视觉大模型2023年进展简述|Medical Vision-language Models (VLM)

写在前面——本篇为原创内容,如转载/引用请务必注明出处!!(最后更新于2023年11月16日) 如有错误,欢迎评论区指出!!不胜感激!! 点赞三连谢谢!!! 如有 Medical…

【SA8295P 源码分析 (三)】125 - MAX96712 解串器 start_stream、stop_stream 寄存器配置 过程详细解析

【SA8295P 源码分析】125 - MAX96712 解串器 start_stream、stop_stream 寄存器配置 过程详细解析 一、sensor_detect_device():MAX96712 检测解串器芯片是否存在,获取chip_id、device_revision二、sensor_detect_device_channels() :MAX96712 解串器 寄存器初始化 及 detec…

K8s Pod 创建埋点处理(Mutating Admission Webhook)

写在前面 工作中涉及到相关的知识在实际的生产中,我们可以有需求对 创建的资源做一些类似 埋点 相关的操作,比如添加一些 Pod 创建时的自定义验证逻辑,类似表单提交验证那样,或者希望对创建的资源对象进行加工,在比如给…

一个怪异的笔记本重启死机问题分析

疫情期间买了个国产的海鲅笔记本,八代i5处理器8269u,显卡是集显里面比较牛的一款,iris 655。 当时买这个笔记本的主要原因是当小主机用的,平时接显示器,用来看网页,写代码,偶尔也能移动&#x…

如何分析伦敦金的价格走势预测?

伦敦金作为国际黄金市场的重要指标,其价格走势一直备受投资者关注。但是,黄金市场的价格变化受到多种因素的影响,因此要准确预测伦敦金的价格走势并非易事。在本文中,将介绍一些常用的方法和工具,帮助您分析伦敦金的价…

金融帝国实验室(Capitalism Lab)V10版本即将推出全新公司徽标(2023-11-13)

>〔在即将推出的V10版本中,我们将告别旧的公司徽标,采用全新光鲜亮丽、富有现代气息的设计,与金融帝国实验室(Capitalism Lab)的沉浸式体验完美互补!〕 ————————————— >〔《公司详细信…

ubuntu20源码编译搭建SRS流媒体服务器

第一、下载源码 下载源码,推荐用Ubuntu20: git clone -b develop https://gitee.com/ossrs/srs.git第二、编译 2.1、切换到srs/trunk目录: cd srs/trunk2.2、执行configure脚本 ./configure2.3、执行make命令 make2.4、修改conf/rtmp.c…

【打卡】牛客网:BM54 三数之和

资料&#xff1a; 1. 排序&#xff1a;Sort函数 升序&#xff1a;默认。 降序&#xff1a;加入第三个参数&#xff0c;可以greater<type>()&#xff0c;也可以自己定义 本题中发现&#xff0c;sort居然也可以对vector<vector<int>>排序。 C Sort函数详解_…

Axure9 基本操作(二)

1. 文本框、文本域 文本框&#xff1a;快速实现提示文字与不同类型文字显示的效果。 2. 下拉列表、列表框 下拉列表&#xff1a;快速实现下拉框及默认显示项的效果。 3. 复选框、单选按钮 4.

Mysql JSON 类型 索引查询 操作

JSON 类型操作 String 类型的 JSON 数组建立索引&查询语句 --索引添加 ALTER TABLE table_name ADD INDEX idx_json_cloumn ((cast(json_cloumn->"$[*]" AS CHAR(255) ARRAY))); --查询 explain select * from table_name tcai where JSON_CONTAINS(json_cl…

Linux 本地zabbix结合内网穿透工具实现安全远程访问浏览器

前言 Zabbix是一个基于WEB界面的提供分布式系统监视以及网络监视功能的企业级的开源解决方案。能监视各种网络参数&#xff0c;保证服务器系统的安全运营&#xff1b;并提供灵活的通知机制以让系统管理员快速定位/解决存在的各种问题。 本地zabbix web管理界面限制在只能局域…

关于400G光模块的常见问题解答

最近在后台收到了很多用户咨询关于400G光模块的信息&#xff0c;那400G光模块作为当下主流的光模块类型&#xff0c;有哪些问题是备受关注的呢&#xff1f;下面来看看小易的详细解答&#xff01; 1、什么是400G QSFP-DD光模块&#xff1f; 答&#xff1a;400G光模块是指传输速…

linux下安装python3.8(有坑)

1安装包下载 ###直接官网下载linux版本&#xff0c;找到对应的包 https://www.python.org/downloads/source/2安装包解压 tar -zxvf Python-3.8.0.tgz 3编译安装 1&#xff09;设置安装目录&#xff0c;比如在此创建在 /usr/local/python3 &#xff1a; mkdir -p /usr/loca…