【读代码】高斯掩模

目录

问题:

主要功能:


问题:

看不懂实现的功能

主要功能:

从输出张量中提取与边界框对应的区域,并计算该区域与高斯核之间的均方误差(MSE)损失

例子

假设我们有以下输入:

  • boxes 是一个包含边界框坐标的张量,形状为 (2, 5),表示有两个边界框,每个边界框有 5 个坐标值
  • output 是一个 4D 张量,形状为 (1, 3, 100, 100),表示一个批次中有 1 个样本,3 个通道,大小为 100x100 的图像
  • sigma 是高斯核的标准差。
  • use_gpu 是一个布尔值,表示是否使用 GPU。
  • Loss 是一个初始为 0 的损失值。
import torch
import torch.nn.functional as F
import numpy as np# 示例输入
boxes = torch.tensor([[0, 10, 20, 30, 40], [0, 50, 60, 70, 80]])  # 两个边界框
output = torch.randn(1, 3, 100, 100)  # 随机生成的输出张量
sigma = 1.0
use_gpu = False
Loss = 0.0# 定义高斯核生成函数
def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):m, n = [(ss - 1.) / 2. for ss in shape]y, x = np.ogrid[-m:m+1, -n:n+1]h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))h[h < np.finfo(h.dtype).eps * h.max()] = 0return h# 处理边界框
if boxes.shape[1] > 1:boxes = boxes.squeeze()  # 移除单维度for tempBoxes in boxes.squeeze():y1 = int(tempBoxes[1])  # 边界框的起始 y 坐标y2 = int(tempBoxes[3])  # 边界框的结束 y 坐标x1 = int(tempBoxes[2])  # 边界框的起始 x 坐标x2 = int(tempBoxes[4])  # 边界框的结束 x 坐标# 从输出中提取与边界框对应的区域out = output[:, :, y1:y2, x1:x2]# 创建高斯核GaussKernel = matlab_style_gauss2D(shape=(out.shape[2], out.shape[3]), sigma=sigma)# 将高斯核转换为 PyTorch 张量GaussKernel = torch.from_numpy(GaussKernel).float()# 如果使用 GPU,将高斯核移动到 GPUif use_gpu:GaussKernel = GaussKernel.cuda()# 计算提取区域和高斯核的 MSE 损失,并累加到总损失Loss += F.mse_loss(out.squeeze(), GaussKernel)print(f"总损失: {Loss}")
  1. 边界框处理

    • boxes 是一个形状为 (2, 5) 的张量,表示有两个边界框,每个边界框有 5 个坐标值。
    • boxes.squeeze() 移除单维度,得到形状为 (2, 5) 的张量。
    • 遍历每个边界框,提取起始和结束的 x 和 y 坐标。
  2. 提取区域

    • 使用提取的坐标从 output 张量中提取对应的区域。
  3. 创建高斯核

    • 使用 matlab_style_gauss2D 函数创建一个与提取区域形状相同的高斯核。
    • 将高斯核转换为 PyTorch 张量。
  4. 计算损失

    • 如果使用 GPU,将高斯核移动到 GPU。
    • 计算提取区域和高斯核之间的均方误差(MSE)损失,并累加到总损失 Loss 中

问题1 为什么 边界框有 5 个坐标值

import torch# 示例边界框张量,形状为 (2, 5)
boxes = torch.tensor([[0, 10, 20, 30, 40],  # 第一个边界框[1, 50, 60, 70, 80]   # 第二个边界框
])# 遍历每个边界框
for tempBoxes in boxes:class_id = tempBoxes[0]  # 类别标签或置信度分数y1 = int(tempBoxes[1])   # 左上角的 y 坐标x1 = int(tempBoxes[2])   # 左上角的 x 坐标y2 = int(tempBoxes[3])   # 右下角的 y 坐标x2 = int(tempBoxes[4])   # 右下角的 x 坐标print(f"类别标签: {class_id}, 左上角: ({x1}, {y1}), 右下角: ({x2}, {y2})")

输出:

类别标签: 0, 左上角: (20, 10), 右下角: (40, 30)
类别标签: 1, 左上角: (60, 50), 右下角: (80, 70)

问题2: 什么叫squeeze()移除单维度 为什么需要squeeze() 操作

squeeze() 是 PyTorch 中的一个方法,用于移除张量中大小为 1 的维度。这个操作在处理数据时非常有用,特别是在某些情况下,数据可能包含不必要的单维度

例子:假设 boxes 的形状为 (1, N, 5),其中 1 是一个单维度。使用 squeeze() 后,形状会变为 (N, 5),移除了大小为 1 的维度

为什么这里采用squeeze操作:

boxes 可能包含一个单维度,这会导致遍历和处理数据时出现问题。通过使用 squeeze(),可以确保 boxes 的形状符合预期,从而简化后续的处理

示例代码:

import torch# 示例张量,形状为 (1, 2, 5)
boxes = torch.tensor([[[0, 10, 20, 30, 40],[1, 50, 60, 70, 80]
]])print("原始形状:", boxes.shape)  # 输出: torch.Size([1, 2, 5])# 移除单维度
boxes = boxes.squeeze()print("移除单维度后的形状:", boxes.shape)  # 输出: torch.Size([2, 5])

所以论文里 squeeze的原因

boxes.squeeze() 移除了单维度,使得 boxes 的形状变为 (N, 5),这样可以方便地遍历每个边界框并进行处理

tensor的话 可能至少是三维的(我猜 是这个原因)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/51635.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

我的创作纪念日(一)——Giser?Noder?不如“Computer”

目录 Giser&#xff1f;Noder&#xff1f;不如“Computer” 一、根源&#xff1a;保持学习习惯的刚需 二、机缘&#xff1a;processOn的另类替代 三、日常&#xff1a;对技术栈丰富的思考 四、成就&#xff1a;保持心态健康的活着 五、憧憬&#xff1a;能一直心态健康的活…

前端实现【 批量任务调度管理器 】demo优化

一、前提介绍 我在前文实现过一个【批量任务调度管理器】的 demo&#xff0c;能实现简单的任务批量并发分组&#xff0c;过滤等操作。但是还有很多优化空间&#xff0c;所以查找一些优化的库&#xff0c; 主要想优化两个方面&#xff0c; 上篇提到的&#xff1a; 针对 3&…

CSS技巧专栏:一日一例 14-纯CSS实现模拟水波波动填充按钮特效

CSS技巧专栏:一日一例 14-纯CSS实现模拟水波波动填充按钮特效 大家好,今天介绍一个在网上很常见的模拟水波波动要灌满按钮的动画效果,效果下面图所示。 本例图片 案例分析 我们沿着Z轴从上到下数一下一共有几个层: 文字层:白色文字阴影的黑色文字,当鼠标移动上来时候…

黑马点评--给店铺类型查询添加缓存

controller/ShopTypeController.java /*** 店铺分类查询&#xff0c;用于展示首页头部店铺分类* return*/GetMapping("list")public Result queryTypeList() {return typeService.queryList();} service/IShopTypeService.java Result queryList(); service/impl/S…

fatal: Could not read from remote repository. 解决方法

问题描述&#xff1a; Git : fatal: Could not read from remote repository. Please make sure you have the correct access rights and the repository exists。 解决方法&#xff1a; 当在网上尝试大量方法仍然失败的时候&#xff0c;不妨试试这个方法。 在 github 上&…

探索 Redis 不同集群架构的性能与应用

1. 引言 Redis的集群配置成为了提高数据可靠性和服务可用性的关键。本文将带领大家了解Redis的四种主要集群架构&#xff0c;并重点分析哨兵模式和Redis Cluster架构和优势。 2. Redis的四种集群架构 2.1 单实例Redis 使用单个 Redis 实例提供服务。适用于小规模应用&#…

论文阅读:Deformable DETR: Deformable Transformers for End-to-End Object Detection

论文阅读&#xff1a;Deformable DETR: Deformable Transformers for End-to-End Object Detection Deformable DETR: 基于稀疏空间采样的注意力机制&#xff0c;让DCN与Transformer一起玩&#xff01; - 知乎 (zhihu.com) 【Deformable DETR 论文源码解读】Deformable Trans…

The Llama 3 Herd of Models.Llama 3 模型第1,2,3部分全文

现代人工智能(AI)系统是由基础模型驱动的。本文提出了一套新的基础模型,称为Llama 3。它是一组语言模型,支持多语言、编码、推理和工具使用。我们最大的模型是一个密集的Transformer,具有405B个参数和多达128K个tokens的上下文窗口。本文对Llama 3进行了广泛的实证评价。我们…

【error】AttributeError: module ‘cv2.dnn‘ has no attribute ‘DictValue‘(库冲突)

conda list conda remove opencv pip uninstall opencv-python conda list pip 同时卸载两个库 pip uninstall opencv-contrib-python opencv-python 没有and 直接写库名 module ‘cv2.dnn‘ has no attribute ‘DictValue‘解决办法_module cv2.dnn has no attribute d…

Linux - 环境变量、程序地址空间、进程地址空间及Linux2.6内核进程调度队列

目录 环境变量 基本概念 常见环境变量 查看环境变量的方法 测试PATH 测试HOME 测试SHELL 和环境变量相关的命令 环境变量的组织方式 通过代码获取环境变量 通过系统调用获取环境变量 程序地址空间 进程地址空间 Linux2.6内核进程调度队列 一个CPU拥有一个runqueue 优先级 活…

谈一谈爬虫开发工程师

爬虫就只是抓数据的吗&#xff1f;并不是&#xff0c;爬虫工程师的工作不再仅仅是抓取数据&#xff0c;还需要处理其他各种复杂问题&#xff0c;今天我们就来聊聊爬虫开发工程师。 一、 爬虫开发工程师工作内容 爬虫开发工程师是负责编写和维护网络爬虫程序的专业人员。他们的…

【多模态大模型】 ALBEF in NeurIPS 2021

一、引言 论文&#xff1a; Align before Fuse: Vision and Language Representation Learning with Momentum Distillation 作者&#xff1a; Salesforce Research 代码&#xff1a; ALBEF 特点&#xff1a; 该方法使用ViT进行图像特征提取&#xff0c;提出将BERT分两部分&am…

Cocos Creator2D游戏开发(3)-飞机大战(1)-背景动起来

资源见: https://pan.baidu.com/s/1cryYNdBOry5A4YEEcLwhDQ?pwdzual 步骤 1, 让背景动起来 2, 玩家飞机显现,能操控,能发射子弹 3.敌机出现 4. 碰撞效果(子弹和敌机,敌机和玩家) 5. 积分和游戏结束 6. 游戏存档,对接微信小游戏,保存历史最高分 7. cocos发布到微信小游戏 资源…

探索Python的进度条神器:tqdm

文章目录 探索Python的进度条神器&#xff1a;tqdm一、背二、tqdm简介三、安装tqdm四、tqdm的五个简单使用示例五、tqdm在不同场景下的应用六、常见问题及解决方案七、总结 探索Python的进度条神器&#xff1a;tqdm 一、背 景&#xff1a;为什么选择tqdm&#xff1f; 在Python…

苦学Opencv的第十四天:人脸检测和人脸识别

Python OpenCV入门到精通学习日记&#xff1a;人脸检测和人脸识别 前言 经过了十三天的不懈努力&#xff0c;我们终于也是来到了人脸检测和人脸识别啦&#xff01;相信大家也很激动吧。接下来我们开始吧&#xff01; 人脸识别是基于人的脸部特征信息进行身份识别的一种生物识…

Spring 常用的三种拦截器详解

前言 在开发过程中&#xff0c;我们常常使用到拦截器来处理一些逻辑。最常用的三种拦截器分别是 AOP、 Interceptor 、 Filter&#xff0c;但其实很多人并不知道什么时候用AOP&#xff0c;什么时候用Interceptor&#xff0c;什么时候用Filter&#xff0c;也不知道其拦截顺序&am…

spring —— 事务管理器

事务管理主要针对数据源进行操作&#xff1a;在数据库方面&#xff0c;通过 TransactionManager 事务管理器进行管理&#xff0c;表明一旦出现错误&#xff0c;该数据源的所有数据全部复原。那么数据库如何判断是否发生了错误呢&#xff1f;这就需要在代码方面&#xff0c;通过…

抖音直播弹幕数据逆向:websocket和JS注入

&#x1f50d; 思路与步骤详解 &#x1f575;️‍♂️ 思路介绍 首先&#xff0c;我们通过抓包工具进入的直播间&#xff0c;捕获其网络通信数据&#xff0c;重点关注WebSocket连接。发现直播弹幕数据通过WebSocket传输&#xff0c;这种方式比传统的HTTP更适合实时数据的传输。…

前端基于 axios 实现批量任务调度管理器 demo

一、背景介绍 这是一个基于 axios 实现的批量任务调度管理器的 demo。它使用了axios、promise 等多种技术和原理来实现批量处理多个异步请求&#xff0c;并确保所有请求都能正确处理并报告其状态。 假设有一个场景&#xff1a;有一个任务列表&#xff0c;有单个任务的处理功能…

【Qt】QLCDNumberQProgressBarQCalendarWidget

目录 QLCDNumber 倒计时小程序 相关属性 QProgressBar 进度条小程序 相关设置 QLCDNumber QLCDNumber是Qt框架中用于显示数字或计数值的小部件。通常用于显示整数值&#xff0c;例如时钟、计时器、计数器等 常用属性 属性说明intValueQLCDNumber显示的初始值(int类型)va…