pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item())  # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256))  # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int()  # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/25057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图像特征提取 python

1. 边缘检测 (Edge Detection) 1.1 Sobel 算子 Sobel 算子是一种边缘检测算子,通过计算图像梯度来检测边缘。 import cv2 import numpy as np# 读取图像 image cv2.imread(image.jpg, 0)# 应用 Sobel 算子 sobel_x cv2.Sobel(image, cv2.CV_64F, 1, 0, ksize5)…

解决Windows窗口聚焦问题

情景引入: 在使用副屏显示器写代码,主屏显示器看教程的时候,突然有个知识点卡住了,这个时候你想要按下空格让视频暂停,但是按下后你会发现:视频没有暂停,倒是代码界面多了个空格。。。这就不好玩…

3. 变量的运算

文章目录 3.1 基本语法3.1.1 test条件测试3.1.2 中括号条件测试3.1.3 双中括号条件测试3.1.4 双圆括号 3.2 算术运算3.3 字符串运算符3.4 文件测试运算符3.5 关系运算符3.6 布尔运算符3.7 逻辑运算符 在 Shell 中包含如下的运算: 算术运算字符串运算符文件测试运算符…

用HTML实现拓扑面,动态4D圆环面,可手动调节,富有创新性的案例。(有源代码)

文章目录 前言一、示例二、目录结构三、index.html(主页面)四、main.js五、Tour4D.js六、swissgl.js七、dat.gui.min.js八、style.css 前言 如果你觉得对代码进行复制粘贴很麻烦的话,你可以直接将资源下载到本地。无需部署,直接可…

如何对stm32查看IO功能。

有些同学对于别人的开发板的资源,或者IO口,或者串口等资源不知道怎么分配。 方法1、看硬石、野火、正点原子的开发板,看下他们的例子,那个资源用什么。自己多看几个原理图,多看几个视频,做一下笔记。以后依…

【面试干货】MySQL 三种锁的级别(表级锁、行级锁和页面锁)

【面试干货】MySQL 三种锁的级别(表级锁、行级锁和页面锁) 1、表级锁2、行级锁3、页面锁4、总结 💖The Begin💖点点关注,收藏不迷路💖 在 MySQL 数据库中,锁是控制并发访问的重要机制&#xff0…

Stable Diffusion之最全详解图解

Stable Diffusion之最全详解图解 引言 Stable Diffusion,作为2022年发布的深度学习领域的重大突破,革新了文本到图像生成的边界。这一模型不仅能够根据文本描述精确生成视觉图像,还展示了在图像内补、外补、以及在提示词引导下实现图像转换的…

GQA,MLA之外的另一种KV Cache压缩方式:动态内存压缩(DMC)

0x0. 前言 在openreview上看到最近NV的一个KV Cache压缩工作:https://openreview.net/pdf?idtDRYrAkOB7 ,感觉思路还是有一些意思的,所以这里就分享一下。 简单来说就是paper提出通过一种特殊的方式continue train一下原始的大模型&#x…

DS:树与二叉树的相关概念

欢迎来到Harper.Lee的学习世界!博主主页传送门:Harper.Lee的博客主页想要一起进步的uu可以来后台找我哦! 一、树的概念及其结构 1.1 树的概念亲缘关系 树是一种非线性的数据结构,它是由n(n>0)个有限节点…

汇编:数组-寻址取数据

比例因子寻址: 比例因子寻址(也称为比例缩放索引寻址或基址加变址加比例因子寻址)是一种复杂的内存寻址方式,常用于数组和指针操作。它允许通过一个基址寄存器、一个变址寄存器和一个比例因子来计算内存地址。 语法 比例因子寻…

LeetCode //C - 168. Excel Sheet Column Title

168. Excel Sheet Column Title Given an integer columnNumber, return its corresponding column title as it appears in an Excel sheet. For example: A -> 1 B -> 2 C -> 3 … Z -> 26 AA -> 27 AB -> 28 … Example 1: Input: columnNumber 1 Outp…

经典文献阅读之--Online Monocular Lane Mapping(使用Catmull-Rom样条曲线完成在线单目车道建图)

0. 简介 对于单目摄像头完成SLAM建图这类操作,对于自动驾驶行业非常重要,《Online Monocular Lane Mapping Using Catmull-Rom Spline》介绍了一种仅依靠单个摄像头和里程计生成基于样条的在线单目车道建图方法。我们提出的技术将车道关联过程建模为一个…

Java 习题集

💖 单选题 💖 填空题 💖 判断题 💖 程序阅读题 1. 读代码写结果 class A {int m 5;void zengA(int x){m m x;}int jianA(int y){return m - y;} }class B extends A {int m 3;int jianA(int z){return super.jianA(z) m;} …

Java Web学习笔记20——Ajax-Axios

Axios: 介绍:Axios对原生的Ajax进行封装,简化书写,快速开发。 官网:https://www.axios-http.cn Axios 入门: {}是Js的对象。 get的请求参数是在URL后面?和相关参数值。 post的请求参数是在请…

优化Elasticsearch搜索性能:查询调优与索引设计

在构建一个基于Elasticsearch的搜索解决方案时,性能优化是一个至关重要的环节。无论是处理海量数据,还是满足快速响应的搜索需求,优化Elasticsearch的性能都能显著提高用户体验和系统效率。本文将重点介绍如何通过查询调优和索引设计来优化El…

Soildworks学习笔记(二)

放样凸台基体: 自动生成连接两个物体两个面的基体: 2.旋转切除: 3.剪切实体: 4.转换实体引用: 将实体的轮廓线转换至当前草图使其成为当前草图的图元,主要用于在同一平面或另一个坐标中制作草图实体或其尺寸的副本。 …

【深度学习】Transformer分类器,CICIDS2017,入侵检测

文章目录 1 前言2 什么是入侵检测系统?3 为什么选择Transformer?4 数据预处理5 模型架构5.1. 输入嵌入层(Input Embedding Layer)5.2. 位置编码层(Positional Encoding Layer)5.3. Transformer编码器层&…

通过SSH远程登录华为设备

01 进入系统编辑视图 system-view Enter system view, return user view with return command. 02 创建本地RSA密钥对 [HUAWEI]rsa local-key-pair creat The key name will be:HUAWEI_Host The range of public key size is (2048 ~ 2048). NOTE: Key pair generation will ta…

17、关于加强数据资产管理的指导意见

数据资产,作为经济社会数字化转型进程中的新兴资产类型,正日益成为推动数字中国建设和加快数字经济发展的重要战略资源。为深入贯彻落实党中央决策部署,现就加强数据资产管理提出如下意见。 一、总体要求 (一)指导思想。 以新时代中国特色社会主义思想为指导,全面深入…

CF817F MEX Queries 题解

题目描述 解题思路 考虑分块。发现 l l l 和 r r r 的范围都很大,但是我们只需要知道第一个没有出现的正整数是在哪个位置,然后就能得到答案,所以我们把 l l l、 r r r、 r + 1 r+1 r+1 离散化,然后用一个数组映射回去,我们求得位置后就可以通过这个映射数组求出具体…