pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item())  # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256))  # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int()  # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/25057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决Windows窗口聚焦问题

情景引入: 在使用副屏显示器写代码,主屏显示器看教程的时候,突然有个知识点卡住了,这个时候你想要按下空格让视频暂停,但是按下后你会发现:视频没有暂停,倒是代码界面多了个空格。。。这就不好玩…

用HTML实现拓扑面,动态4D圆环面,可手动调节,富有创新性的案例。(有源代码)

文章目录 前言一、示例二、目录结构三、index.html(主页面)四、main.js五、Tour4D.js六、swissgl.js七、dat.gui.min.js八、style.css 前言 如果你觉得对代码进行复制粘贴很麻烦的话,你可以直接将资源下载到本地。无需部署,直接可…

如何对stm32查看IO功能。

有些同学对于别人的开发板的资源,或者IO口,或者串口等资源不知道怎么分配。 方法1、看硬石、野火、正点原子的开发板,看下他们的例子,那个资源用什么。自己多看几个原理图,多看几个视频,做一下笔记。以后依…

【面试干货】MySQL 三种锁的级别(表级锁、行级锁和页面锁)

【面试干货】MySQL 三种锁的级别(表级锁、行级锁和页面锁) 1、表级锁2、行级锁3、页面锁4、总结 💖The Begin💖点点关注,收藏不迷路💖 在 MySQL 数据库中,锁是控制并发访问的重要机制&#xff0…

GQA,MLA之外的另一种KV Cache压缩方式:动态内存压缩(DMC)

0x0. 前言 在openreview上看到最近NV的一个KV Cache压缩工作:https://openreview.net/pdf?idtDRYrAkOB7 ,感觉思路还是有一些意思的,所以这里就分享一下。 简单来说就是paper提出通过一种特殊的方式continue train一下原始的大模型&#x…

DS:树与二叉树的相关概念

欢迎来到Harper.Lee的学习世界!博主主页传送门:Harper.Lee的博客主页想要一起进步的uu可以来后台找我哦! 一、树的概念及其结构 1.1 树的概念亲缘关系 树是一种非线性的数据结构,它是由n(n>0)个有限节点…

汇编:数组-寻址取数据

比例因子寻址: 比例因子寻址(也称为比例缩放索引寻址或基址加变址加比例因子寻址)是一种复杂的内存寻址方式,常用于数组和指针操作。它允许通过一个基址寄存器、一个变址寄存器和一个比例因子来计算内存地址。 语法 比例因子寻…

经典文献阅读之--Online Monocular Lane Mapping(使用Catmull-Rom样条曲线完成在线单目车道建图)

0. 简介 对于单目摄像头完成SLAM建图这类操作,对于自动驾驶行业非常重要,《Online Monocular Lane Mapping Using Catmull-Rom Spline》介绍了一种仅依靠单个摄像头和里程计生成基于样条的在线单目车道建图方法。我们提出的技术将车道关联过程建模为一个…

Java 习题集

💖 单选题 💖 填空题 💖 判断题 💖 程序阅读题 1. 读代码写结果 class A {int m 5;void zengA(int x){m m x;}int jianA(int y){return m - y;} }class B extends A {int m 3;int jianA(int z){return super.jianA(z) m;} …

Java Web学习笔记20——Ajax-Axios

Axios: 介绍:Axios对原生的Ajax进行封装,简化书写,快速开发。 官网:https://www.axios-http.cn Axios 入门: {}是Js的对象。 get的请求参数是在URL后面?和相关参数值。 post的请求参数是在请…

Soildworks学习笔记(二)

放样凸台基体: 自动生成连接两个物体两个面的基体: 2.旋转切除: 3.剪切实体: 4.转换实体引用: 将实体的轮廓线转换至当前草图使其成为当前草图的图元,主要用于在同一平面或另一个坐标中制作草图实体或其尺寸的副本。 …

【深度学习】Transformer分类器,CICIDS2017,入侵检测

文章目录 1 前言2 什么是入侵检测系统?3 为什么选择Transformer?4 数据预处理5 模型架构5.1. 输入嵌入层(Input Embedding Layer)5.2. 位置编码层(Positional Encoding Layer)5.3. Transformer编码器层&…

MySQL—多表查询—子查询(介绍)

一、引言 上一篇博客学习完联合查询。 这篇开始,就来到多表查询的最后一种形式语法块——子查询。 (1)概念 SQL 语句中嵌套 SELECT 语句,那么内部的 select 称为嵌套查询,又称子查询。 表现形式 注意: …

复数的概念

1. 虚数单位:i 引入一个新数 ‘i’,i又叫做虚数单位,并规定: 它的平方等于 -1,即 i -1。实数可以与它进行四则运算,并且原有的加,乘运算律依然成立。 2.定义 复数的定义:形如 a…

CTFHUB-SQL注入-字符型注入

目录 查询数据库名 查询数据库中的表名 查询表中数据 总结 此题目和上一题相似,一个是整数型注入,一个是字符型注入。字符型注入就是注入字符串参数,判断回显是否存在注入漏洞。因为上一题使用手工注入查看题目 flag ,这里就不…

GIS数据快捷共享发布工具及操作视频

有网友反映还是不会操作GIS数据快捷共享发布工具(建立自己的地图网站),要我录个视频。 好,那就录一个: GIS数据快捷共享发布工具及操作视频 虽然默认例子是二维的,但这个服务器可以为二维、三维系统发布时间服务。都是…

NASA数据集——SARAL 近实时增值业务地球物理数据记录海面高度异常

SARAL Near-Real-Time Value-added Operational Geophysical Data Record Sea Surface Height Anomaly SARAL 近实时增值业务地球物理数据记录海面高度异常 简介 2020 年 3 月 18 日至今 ALTIKA_SARAL_L2_OST_XOGDR 这些数据是近实时(NRT)&#xff…

SpringCloudAlibaba基础二 Nacos注册中心

一 什么是 Nacos 官方:一个更易于构建云原生应用的动态服务发现(Nacos Discovery )、服务配置(Nacos Config)和服务管理平台。 集 注册中心配置中心服务管理 平台。 Nacos 的关键特性包括: 服务发现和服务健康监测动态配置服务动态 DNS 服务服务及其元数据管理 …

达梦8 探寻达梦排序原理:新排序机制(SORT_FLAG=1)

测试版本:--03134283938-20221019-172201-20018 达梦的排序机制由四个dm.ini参数控制: #maximum sort buffer size in Megabytes ,有效值范围(1~2048) SORT_BUF_SIZE 100 #ma…

mysql数据库打开失败的问题

打不开mysql 1. my.ini未配置路径:找到basedir和datadir,改成如下路径。改完记得去掉井号 !! 复制路径粘贴,记得去掉井号 !! 启动!方式1 发生系统错误5:没有用管理员身…