retinaNet FocalLoss源码详解

  • targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1
## 把正样本所对应的锚框所对应的类别的列置为1
# aim = torch.randint(0, 1, (1, 80))
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0]])
# aim[0, 12] = 1
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          0, 0, 0, 0, 0, 0, 0, 0]])
  • result = torch.where(condition, x, y)
import torch 
condition = torch.tensor([True, False, True, False]) 
x = torch.tensor([1, 2, 3, 4]) 
y = torch.tensor([10, 20, 30, 40])
result = torch.where(condition, x, y) 
print(result)
tensor([ 1, 20,  3, 40])
  • torch.eq(targets, 1.)
targets = torch.tensor([1, 0, 1, 0, 1]) 
torch.eq(targets, 1.)
Out[20]: tensor([ True, False,  True, False,  True])

在这里插入图片描述

retinaNet FocalLoss源码详解

class FocalLoss(nn.Module):#def __init__(self):def forward(self, classifications, regressions, anchors, annotations):alpha = 0.25gamma = 2.0batch_size = classifications.shape[0]classification_losses = []regression_losses = []anchor = anchors[0, :, :]anchor_widths  = anchor[:, 2] - anchor[:, 0]anchor_heights = anchor[:, 3] - anchor[:, 1]anchor_ctr_x   = anchor[:, 0] + 0.5 * anchor_widthsanchor_ctr_y   = anchor[:, 1] + 0.5 * anchor_heightsfor j in range(batch_size):classification = classifications[j, :, :]regression = regressions[j, :, :]bbox_annotation = annotations[j, :, :]bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)if bbox_annotation.shape[0] == 0:if torch.cuda.is_available():alpha_factor = torch.ones(classification.shape).cuda() * alphaalpha_factor = 1. - alpha_factorfocal_weight = classificationfocal_weight = alpha_factor * torch.pow(focal_weight, gamma)bce = -(torch.log(1.0 - classification))# cls_loss = focal_weight * torch.pow(bce, gamma)cls_loss = focal_weight * bceclassification_losses.append(cls_loss.sum())regression_losses.append(torch.tensor(0).float().cuda())else:alpha_factor = torch.ones(classification.shape) * alphaalpha_factor = 1. - alpha_factorfocal_weight = classificationfocal_weight = alpha_factor * torch.pow(focal_weight, gamma)bce = -(torch.log(1.0 - classification))# cls_loss = focal_weight * torch.pow(bce, gamma)cls_loss = focal_weight * bceclassification_losses.append(cls_loss.sum())regression_losses.append(torch.tensor(0).float())continueIoU = calc_iou(anchors[0, :, :], bbox_annotation[:, :4]) # num_anchors x num_annotationsIoU_max, IoU_argmax = torch.max(IoU, dim=1) # num_anchors x 1#Iou_max 每一行的最大值,即锚框与标注框iou的最大值,iou_argmax代表是第几个标注框#import pdb#pdb.set_trace()# compute the loss for classificationtargets = torch.ones(classification.shape) * -1if torch.cuda.is_available():targets = targets.cuda()targets[torch.lt(IoU_max, 0.4), :] = 0#把锚框与目标框iou低于0.4的targets值置为0positive_indices = torch.ge(IoU_max, 0.5)#iou_max 大于0.5的置为Truenum_positive_anchors = positive_indices.sum()assigned_annotations = bbox_annotation[IoU_argmax, :]##assigned_annotations 存放与锚框 iou最大的标注框targets[positive_indices, :] = 0targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1## 把正样本所对应的锚框所对应的类别的列置为1# aim = torch.randint(0, 1, (1, 80))# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,#          0, 0, 0, 0, 0, 0, 0, 0]])# aim[0, 12] = 1# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,#          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,#          0, 0, 0, 0, 0, 0, 0, 0]])if torch.cuda.is_available():alpha_factor = torch.ones(targets.shape).cuda() * alphaelse:alpha_factor = torch.ones(targets.shape) * alphaalpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)#对应为1的位置设置为alpha_factorfocal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)focal_weight = alpha_factor * torch.pow(focal_weight, gamma)bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))#注意这里的log是以e为底#二元交叉熵损失(Binary Cross Entropy, BCE)# cls_loss = focal_weight * torch.pow(bce, gamma)# 80中为一位都使用该公式进行计算cls_loss = focal_weight * bceif torch.cuda.is_available():cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape).cuda())else:cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, torch.zeros(cls_loss.shape))temp = cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0)classification_losses.append(cls_loss.sum()/torch.clamp(num_positive_anchors.float(), min=1.0))# compute the loss for regressionif positive_indices.sum() > 0:assigned_annotations = assigned_annotations[positive_indices, :]anchor_widths_pi = anchor_widths[positive_indices]#把正样本的anchor_widths拿出来#注意理解anchor_widths[positive_indices]anchor_heights_pi = anchor_heights[positive_indices]anchor_ctr_x_pi = anchor_ctr_x[positive_indices]anchor_ctr_y_pi = anchor_ctr_y[positive_indices]gt_widths  = assigned_annotations[:, 2] - assigned_annotations[:, 0]gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]gt_ctr_x   = assigned_annotations[:, 0] + 0.5 * gt_widthsgt_ctr_y   = assigned_annotations[:, 1] + 0.5 * gt_heights# clip widths to 1gt_widths  = torch.clamp(gt_widths, min=1)gt_heights = torch.clamp(gt_heights, min=1)targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi# 整个表达式# targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi# 的意义是计算目标框和锚框在x方向上的相对位移,并将其归一化到锚框的宽度上。## 具体地说:# (gt_ctr_x - anchor_ctr_x_pi)# 计算了真实目标框和锚框在x方向上的中心点的差值。# 除以anchor_widths_pi 将这个差值归一化到锚框的宽度上。targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pitargets_dw = torch.log(gt_widths / anchor_widths_pi)targets_dh = torch.log(gt_heights / anchor_heights_pi)# 表达式gt_widths / anchor_widths_pi# 计算了真实目标框宽度和锚框宽度之间的比例。然后,对这个比例取自然对数(torch.log),得到的结果# targets_dw是对数空间中的相对宽度差。# 这样的计算通常在目标检测任务中用于计算宽度方向上的损失。对数变换有助于处理不同尺度的宽度,因为当宽度差异很大时,# 对数尺度上的差异会变得更加均匀。此外,对数变换还有助于模型更好地学习如何调整锚框的宽度以匹配真实的目标框。# 例如,如果gt_widths是100,而anchor_widths_pi是50,那么gt_widths / anchor_widths_pi将是2,而# torch.log(2)大约是0.693。这意味着目标框的宽度是锚框宽度的两倍,而在对数尺度上,这个差异被表示为大约0.693。targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh))targets = targets.t()if torch.cuda.is_available():targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]]).cuda()else:targets = targets/torch.Tensor([[0.1, 0.1, 0.2, 0.2]])negative_indices = 1 + (~positive_indices)regression_diff = torch.abs(targets - regression[positive_indices, :])regression_loss = torch.where(torch.le(regression_diff, 1.0 / 9.0),0.5 * 9.0 * torch.pow(regression_diff, 2),regression_diff - 0.5 / 9.0)regression_losses.append(regression_loss.mean())else:if torch.cuda.is_available():regression_losses.append(torch.tensor(0).float().cuda())else:regression_losses.append(torch.tensor(0).float())return torch.stack(classification_losses).mean(dim=0, keepdim=True), torch.stack(regression_losses).mean(dim=0, keepdim=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/729016.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

react tab选项卡吸顶实现

react tab选项卡吸顶实现,直接上代码(代码有注释) tsx代码 /* eslint-disable react-hooks/exhaustive-deps */ import React, { useEffect, useState } from "react"; import DocumentTitle from react-document-title import s…

智奇科技工业 Linux 屏更新开机logo

智奇科技工业 Linux 屏更新开机logo 简介制作logo.img文件1、转换格式得到logo.bmp2、使用Linux命令生成img文件 制作rootfs.img文件替换rootfs.img中的logo 生成update.img固件附件 简介 智奇科技的 Linux 屏刷开机logo必须刷img镜像文件,比较复杂。 制作logo.i…

Python教程,python从入门到精通 第1天 温习笔记

1.1 字面量 1.2 注释 1.3 变量 1.4 数据类型 1.5 数据类型转换 1.6 标识符 1.7 运算符 1.8 字符串的三种定义方式 1.9 字符串拼接 1.10 字符串格式化 1.11 掌握格式化字符串的过程中做数字的精度控制 1.12 掌握快速字符串格式化的方式 1.13 字符串格式化-表达式的格…

《MySQL实战45讲》课程大纲

1MySQL实战45讲-01基础架构:一条SQL查询语句是如何执行的?2MySQL实战45讲-02日志系统:一条SQL更新语句是如何执行的?3MySQL实战45讲-03事务隔离:为什么你改了我还看不见?4MySQL实战45讲-04深入浅出索引&…

【C++干货基地】六大默认成员函数: This指针 | 构造函数 | 析构函数

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引入 哈喽各位铁汁们好啊,我是博主鸽芷咕《C干货基地》是由我的襄阳家乡零食基地有感而发,不知道各位的…

Redis冲冲冲——redis数据类型及对应的数据结构

目录 引出redis数据类型及对应的数据结构Redis入门1.Redis是什么?2.Redis里面存Java对象 Redis进阶1.雪崩/ 击穿 / 穿透2.Redis高可用-主从哨兵3.持久化RDB和AOF4.Redis未授权访问漏洞5.Redis里面安装BloomFilte Redis的应用1.验证码2.Redis高并发抢购3.缓存预热用户…

SpringCloud 服务的注册与发现

一、前言 接下来是开展一系列的 SpringCloud 的学习之旅,从传统的模块之间调用,一步步的升级为 SpringCloud 模块之间的调用,此篇文章为第二篇,即使用服务注册和发现的组件,此篇文章会介绍 Eureka、Zookeeper 和 Consu…

环境音效生成器Moodist

什么是 Moodist ? Moodist 是免费、开源的环境音效生成器。拥有 54 种精选的音效,轻松为专注或放松创建自定义混合音效。无需账户,无需繁琐操作,尽享纯净宁静。探索大自然的宁静和城市的韵律。在 Moodist 中提升你的氛围&#xff…

Node 旧淘宝源 HTTPS 过期处理

今天拉取老项目更新依赖,出现 urlshttps%3A%2F%2Fregistry.npm.taobao.org%2Fegg-logger%2Fdownload%2Fegg-logger-2.6.1.tgz: certificate has expired 类似报错。即使删除 node_modules 重新安装,问题依然无法解决。 一、问题演示 二、原因分析 1、淘…

平台工程指南:从架构构建到职责分工

平台工程只是 DevOps 专业化的另一个术语,还是另有所指?事实可能介于两者之间。DevOps 及其相关的 DevXOps 有着浓厚的文化色彩,以各个团队为中心。不幸的是,在许多地方,DevOps 引发了新的问题,如工具激增和…

【云原生】kubeadm快速搭建K8s集群Kubernetes1.19.0

目录 一、 Kubernetes 的概述 二、服务器配置 2.1 服务器部署规划 2.2服务器初始化配置 三、安装Docker/kubeadm/kubelet【所有节点】 3.1 安装Docker 3.2 添加阿里云YUM软件源 3.3 安装kubeadm,kubelet和kubectl 四、部署Kubernetes Master 五、部署Kube…

网络入侵检测系统之Suricata(十四)--匹配流程

其实规则的匹配流程和加载流程是强相关的,你如何组织规则那么就会采用该种数据结构去匹配,例如你用radix tree组织海量ip规则,那么匹配的时候也是采用bit test确定前缀节点,然后逐一左右子树查询,Suricata也是如此&…

基于Spring Boot的图书个性化推荐系统 ,计算机毕业设计(带源码+论文)

源码获取地址: 码呢-一个专注于技术分享的博客平台一个专注于技术分享的博客平台,大家以共同学习,乐于分享,拥抱开源的价值观进行学习交流http://www.xmbiao.cn/resource-details/1765769136268455938

Doris实战——特步集团零售数据仓库项目实践

目录 一、背景 二、总体架构 三、ETL实践 3.1 批量数据的导入 3.2 实时数据接入 3.3 数据加工 3.4 BI 查询 四、实时需求响应 五、其他经验 5.1 Doris BE内存溢出 5.2 SQL任务超时 5.3 删除语句不支持表达式 5.4 Drop 表闪回 六、未来展望 原文大佬的这篇Doris数…

离散数学——(3)联结词及对应的真值指派,最小全功能联结词集,对偶式,范式,范式存在定理,小项

目录 1.联结词及对应的真值指派 2.最小全功能联结词集 3.对偶式 4.范式 1.析取范式 5.范式存在定理 6.小项 1.联结词及对应的真值指派 2.最小全功能联结词集 3.对偶式 4.范式 1.析取范式 5.范式存在定理 6.小项

hfish蜜罐搭建与使用

本次是对自己在学习蓝队过程中的一次对安全设备 hfish蜜罐的搭建和使用考核记录,距离之前已 经过去很久了,对之前在考核过程中的操作进行回顾和总结. 蜜罐在这里我进行免费分享 hfish-3.1.4-windows-amd64.zip官方版下载丨最新版下载丨绿色版下载丨APP下载-123云…

Visual Studio如何进行类文件的管理(类文件的分离)

大家好: 衷心希望各位点赞。 您的问题请留在评论区,我会及时回答。 一、问题背景 实际开发中,类的声明放在头文件中,给程序员看类的成员和方法。比如:Dog.h(类的声明文件) 类的成员函数的具体…

继承,切片,隐藏

定义: 子类(派生类)继承了父类(基类)的成员函数和成员变量(类层次的复用) 赋值 子类可以赋值给父类。父类不可以直接复制给子类 (不像不同内置类型的赋值要转化为临时变量&#xf…

U盘上文件夹突然空了?掌握3个方法,轻松找回数据!

“我的u盘插上之后,不知道为什么就空了,里面所有的文件都没有了,有什么方法可以找回u盘里丢失的文件吗?” 在日常使用U盘的过程中,我们有时会遇到一个令人头疼的问题,U盘上的文件夹突然空了。这究竟是怎么回…

Windows系统安装MongoDB并结合内网穿透实现公网访问本地数据库

文章目录 前言1. 安装数据库2. 内网穿透2.1 安装cpolar内网穿透2.2 创建隧道映射2.3 测试随机公网地址远程连接 3. 配置固定TCP端口地址3.1 保留一个固定的公网TCP端口地址3.2 配置固定公网TCP端口地址3.3 测试固定地址公网远程访问 前言 MongoDB是一个基于分布式文件存储的数…