问题清除指南|关于num_classes与 BCELoss、BCEWithLogitsLoss 和 CrossEntropyLoss 的关系

前言:关于「 num_classes = 1 」引发的探究。

2024年尾声,学弟问到一个问题:在研究工作 CNNDetection 的github开源代码 networks/trainer.py 文件的 line 27 self.model = resnet50(num_classes=1) 中,变量 num_classes 的值为1,而该工作完成的是一个二分类任务,于是就有了这篇博文的诞生。

首先求助 GPT-4o,获取初步解答如下:
在这里插入图片描述

进一步发现:trainer.py文件中 line 27 ,resnet50使用到的损失函数是BCEWithLogitsLoss(·)


在二分类任务中,num_classes 通常设置为1或2,具体取决于使用的损失函数和模型的输出形式。下面是关于这两种情况的解释:

num_classes = 1

  • 模型输出:模型的最后一层输出一个值(logit),表示样本属于正类的概率。
  • 损失函数:使用带有逻辑回归的二分类交叉熵损失函数 (BCEWithLogitsLoss) 或二元交叉熵损失函数 (BCELoss)。
  • 代码示例:
import torch.nn as nn
import torchvision.models as modelsclass BinaryClassificationResNet50(nn.Module):def __init__(self, pretrained=True):super(BinaryClassificationResNet50, self).__init__()self.model = models.resnet50(pretrained=pretrained)self.model.fc = nn.Linear(self.model.fc.in_features, 1)  # 输出一个值def forward(self, x):return self.model(x)# 初始化模型
model = BinaryClassificationResNet50(pretrained=True)# 打印模型结构
print(model)
# 输出结果
BinaryClassificationResNet50((model): ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): Bottleneck((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer2): Sequential((0): Bottleneck((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer3): Sequential((0): Bottleneck((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(4): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(5): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer4): Sequential((0): Bottleneck((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=1, bias=True))
)

num_classes = 2

  • 模型输出:模型的最后一层输出两个值,分别表示样本属于两个类别的概率(通过softmax函数转换为概率)。
  • 损失函数:使用交叉熵损失函数 (CrossEntropyLoss)。
  • 代码示例:
import torch
import torch.nn as nn
import torchvision.models as modelsclass BinaryClassificationResNet50(nn.Module):def __init__(self, pretrained=True):super(BinaryClassificationResNet50, self).__init__()self.model = models.resnet50(pretrained=pretrained)self.model.fc = nn.Linear(self.model.fc.in_features, 2)  # 输出两个值def forward(self, x):return self.model(x)# 初始化模型
model = BinaryClassificationResNet50(pretrained=True)# 打印模型结构
print(model)
# 输出结果
BinaryClassificationResNet50((model): ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): Bottleneck((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))...(fc): Linear(in_features=2048, out_features=2, bias=True))
)

写在最后:感谢 https://github.com/copilot !!!免费试用 GPT-4o ~


参考资料

  • 【PyTorch】进阶学习:探索BCEWithLogitsLoss的正确使用—二元分类问题中的logits与标签形状问题-CSDN博客
  • Pytorch中BCELoss,BCEWithLogitsLoss和CrossEntropyLoss的区别_bcewithlogitsloss 与交叉熵在 多标签任务中得区别-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/66031.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

grouped.get_group((‘B‘, ‘A‘))选择分组

1. df.groupby([team, df.name.str[0]]) df.groupby([team, df.name.str[0]]) 这一部分代码表示对 DataFrame df 按照 两个条件 进行分组: 按照 team 列(即团队)。按照 name 列的 首字母(df.name.str[0])。 df.name.s…

SQL字符串截取函数——Left()、Right()、Substring()用法详解

SQL字符串截取函数——Left()、Right()、Substring()用法详解 1. LEFT() 函数:从字符串的左侧提取指定长度的子字符串。 LEFT(string, length)string:要操作的字符串。length&#x…

C# 服务调用RFC函数获取物料信息,并输出生成Excel文件

这个例子是C#服务调用RFC函数,获取物料的信息,并生成Excel文件 上接文章:C#服务 文章目录 创建函数创建结构编写源代码创建批处理文件运行结果-成功部署服务器C#代码配置文件注意!! 创建函数 创建结构 编写源代码 创建…

打开idea开发软件停留在加载弹出框页面进不去

问题 idea软件点击打开,软件卡在加载弹框进不去。 解决方法 先进入“任务管理器”停止IDEA的任务进程 2.找到IDEA软件保存的本地数据文件夹 路径都是在C盘下面:路径:C:\Users\你的用户名\AppData\Local\JetBrains 删除目录下的文件夹&…

sqlserver sql转HTMM邮件发送

通过sql的形式,把表内数据通过邮件的形式发送出去 declare title varchar(100) DECLARE stat_date CHAR(10),create_time datetime SET stat_dateCONVERT(char(10),GETDATE(),120) SET create_timeDATEADD(MINUTE,-20,GETDATE()) DECLARE xml NVARCHAR (max) DECLAR…

Linux:各发行版及其包管理工具

相关阅读 Linuxhttps://blog.csdn.net/weixin_45791458/category_12234591.html?spm1001.2014.3001.5482 Debian 包管理工具:dpkg(低级包管理器)、apt(高级包管理器,建立在dpkg基础上)包格式:…

Java项目实战II基于小程序的驾校管理系统(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、核心代码 五、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 随着汽车保有量的不断增长,驾驶培训市场日…

小程序租赁系统开发的优势与应用探索

内容概要 在如今这个数码科技飞速发展的时代,小程序租赁系统开发仿佛是一张神奇的魔法卡,能让租赁体验变得顺畅如丝。想象一下,无论你需要租用什么,从单车到房屋,甚至是派对用品,只需动动手指,…

AAAI2025:这也能融合?巧用多坐标系融合策略,PC-BEV实现点云分割170倍加速,精度显著提升!

引言:本文提出了一种基于鸟瞰图(BEV)空间的激光雷达点云分割方法,该方法通过融合极坐标和笛卡尔分区策略,实现了快速且高效的特征融合。该方法利用固定网格对应关系,避免了传统点云交互中的计算瓶颈&#x…

职场常用Excel基础04-二维表转换

大家好,今天和大家一起分享一下excel的二维表转换相关内容~ 在Excel中,二维表(也称为矩阵或表格)是一种组织数据的方式,其中数据按照行和列的格式进行排列。然而,在实际的数据分析过程中,我们常…

python-redis访问指南

Redis(Remote Dictionary Server)是一种开源的内存数据结构存储,可用作数据库、缓存和消息代理。它功能强大且灵活,可根据需求调整架构和配置,以高性能、简单易用、支持多种数据结构而闻名,广泛应用于各种场…

Px4 V2.4.8飞控Mavlink命令控制说明

首先,可以使用两种方法连接飞控,使用虚拟机(LINUX)或使用地面站(QGC)连接。 在px4的代码文件位置打开命令终端,输入连接命令: ./Tools/mavlink_shell.py 在控制台使用help来获取所有…

MySQL8安装与卸载

1.下载mysql MySQL :: Download MySQL Community Serverhttps://dev.mysql.com/downloads/mysql/ 2.解压mysql安装包 解压到自己定义的目录,这里解压就是安装,解压后的路径不要有空格和中文。 3.配置环境变量 配置环境变量可以方便电脑在任何的路径…

简洁安装配置在Windows环境下使用vscode开发pytorch

简洁安装配置在Windows环境下使用vscode开发pytorch 使用anaconda安装pytorch,通过vscode集成环境开发pytorch 下载 anaconda 下载网址,选择对应系统的版本 https://repo.anaconda.com/archive/ windows可以选择Anaconda3-2024.10-1-Windows-x86_64.e…

使用 Jupyter Notebook:安装与应用指南

文章目录 安装 Jupyter Notebook1. 准备环境2. 安装 Jupyter Notebook3. 启动 Jupyter Notebook4. 选择安装方式(可选) 二、Jupyter Notebook 的基本功能1. 单元格的类型与运行2. 可视化支持3. 内置魔法命令 三、Jupyter Notebook 的实际应用场景1. 数据…

unity学习3:如何从github下载开源的unity项目

目录 1 网上别人提供的一些github的unity项目 2 如何下载github上的开源项目呢? 2.1.0 下载工具 2.1.1 下载方法1 2.1.2 下载方法2(适合内部项目) 2.1.3 第1个项目 和第4项目 的比较 第1个项目 第2个项目 第3个项目 2.1.4 下载方法…

npm install --global windows-build-tools --save 失败

注意以下点 为啥下载windows-build-tools,是因为node-sass4.14.1 一直下载不成功,提示python2 没有安装,最终要安装这个,但是安装这个又失败,主要有以下几个要注意的 1、node 版本 14.21.3 不能太高 2、管理员运行 …

十二、Vue 路由

文章目录 一、简介二、安装与基本配置安装 Vue Router创建路由实例在应用中使用路由实例三、路由组件与视图路由组件的定义与使用四、动态路由动态路由参数的定义与获取动态路由的应用场景五、嵌套路由嵌套路由的概念与配置嵌套路由的应用场景六、路由导航<router - link>…

NLP 中文拼写检测纠正论文-08-Combining ResNet and Transformer

拼写纠正系列 NLP 中文拼写检测实现思路 NLP 中文拼写检测纠正算法整理 NLP 英文拼写算法&#xff0c;如果提升 100W 倍的性能&#xff1f; NLP 中文拼写检测纠正 Paper java 实现中英文拼写检查和错误纠正&#xff1f;可我只会写 CRUD 啊&#xff01; 一个提升英文单词拼…

儿童坐姿矫正器是如何实现语音提示功能?

儿童坐姿不正确&#xff0c;不仅影响他们的体态美观&#xff0c;更关乎其身体健康与成长发育。长期以往&#xff0c;可能会导致脊柱侧弯、近视加深等一系列健康问题。家长应当对此给予足够重视&#xff0c;及时纠正孩子们的坐姿习惯。 为了改善这一状况&#xff0c;可以从这方…