【读代码】高斯掩模

目录

问题:

主要功能:


问题:

看不懂实现的功能

主要功能:

从输出张量中提取与边界框对应的区域,并计算该区域与高斯核之间的均方误差(MSE)损失

例子

假设我们有以下输入:

  • boxes 是一个包含边界框坐标的张量,形状为 (2, 5),表示有两个边界框,每个边界框有 5 个坐标值
  • output 是一个 4D 张量,形状为 (1, 3, 100, 100),表示一个批次中有 1 个样本,3 个通道,大小为 100x100 的图像
  • sigma 是高斯核的标准差。
  • use_gpu 是一个布尔值,表示是否使用 GPU。
  • Loss 是一个初始为 0 的损失值。
import torch
import torch.nn.functional as F
import numpy as np# 示例输入
boxes = torch.tensor([[0, 10, 20, 30, 40], [0, 50, 60, 70, 80]])  # 两个边界框
output = torch.randn(1, 3, 100, 100)  # 随机生成的输出张量
sigma = 1.0
use_gpu = False
Loss = 0.0# 定义高斯核生成函数
def matlab_style_gauss2D(shape=(3, 3), sigma=0.5):m, n = [(ss - 1.) / 2. for ss in shape]y, x = np.ogrid[-m:m+1, -n:n+1]h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))h[h < np.finfo(h.dtype).eps * h.max()] = 0return h# 处理边界框
if boxes.shape[1] > 1:boxes = boxes.squeeze()  # 移除单维度for tempBoxes in boxes.squeeze():y1 = int(tempBoxes[1])  # 边界框的起始 y 坐标y2 = int(tempBoxes[3])  # 边界框的结束 y 坐标x1 = int(tempBoxes[2])  # 边界框的起始 x 坐标x2 = int(tempBoxes[4])  # 边界框的结束 x 坐标# 从输出中提取与边界框对应的区域out = output[:, :, y1:y2, x1:x2]# 创建高斯核GaussKernel = matlab_style_gauss2D(shape=(out.shape[2], out.shape[3]), sigma=sigma)# 将高斯核转换为 PyTorch 张量GaussKernel = torch.from_numpy(GaussKernel).float()# 如果使用 GPU,将高斯核移动到 GPUif use_gpu:GaussKernel = GaussKernel.cuda()# 计算提取区域和高斯核的 MSE 损失,并累加到总损失Loss += F.mse_loss(out.squeeze(), GaussKernel)print(f"总损失: {Loss}")
  1. 边界框处理

    • boxes 是一个形状为 (2, 5) 的张量,表示有两个边界框,每个边界框有 5 个坐标值。
    • boxes.squeeze() 移除单维度,得到形状为 (2, 5) 的张量。
    • 遍历每个边界框,提取起始和结束的 x 和 y 坐标。
  2. 提取区域

    • 使用提取的坐标从 output 张量中提取对应的区域。
  3. 创建高斯核

    • 使用 matlab_style_gauss2D 函数创建一个与提取区域形状相同的高斯核。
    • 将高斯核转换为 PyTorch 张量。
  4. 计算损失

    • 如果使用 GPU,将高斯核移动到 GPU。
    • 计算提取区域和高斯核之间的均方误差(MSE)损失,并累加到总损失 Loss 中

问题1 为什么 边界框有 5 个坐标值

import torch# 示例边界框张量,形状为 (2, 5)
boxes = torch.tensor([[0, 10, 20, 30, 40],  # 第一个边界框[1, 50, 60, 70, 80]   # 第二个边界框
])# 遍历每个边界框
for tempBoxes in boxes:class_id = tempBoxes[0]  # 类别标签或置信度分数y1 = int(tempBoxes[1])   # 左上角的 y 坐标x1 = int(tempBoxes[2])   # 左上角的 x 坐标y2 = int(tempBoxes[3])   # 右下角的 y 坐标x2 = int(tempBoxes[4])   # 右下角的 x 坐标print(f"类别标签: {class_id}, 左上角: ({x1}, {y1}), 右下角: ({x2}, {y2})")

输出:

类别标签: 0, 左上角: (20, 10), 右下角: (40, 30)
类别标签: 1, 左上角: (60, 50), 右下角: (80, 70)

问题2: 什么叫squeeze()移除单维度 为什么需要squeeze() 操作

squeeze() 是 PyTorch 中的一个方法,用于移除张量中大小为 1 的维度。这个操作在处理数据时非常有用,特别是在某些情况下,数据可能包含不必要的单维度

例子:假设 boxes 的形状为 (1, N, 5),其中 1 是一个单维度。使用 squeeze() 后,形状会变为 (N, 5),移除了大小为 1 的维度

为什么这里采用squeeze操作:

boxes 可能包含一个单维度,这会导致遍历和处理数据时出现问题。通过使用 squeeze(),可以确保 boxes 的形状符合预期,从而简化后续的处理

示例代码:

import torch# 示例张量,形状为 (1, 2, 5)
boxes = torch.tensor([[[0, 10, 20, 30, 40],[1, 50, 60, 70, 80]
]])print("原始形状:", boxes.shape)  # 输出: torch.Size([1, 2, 5])# 移除单维度
boxes = boxes.squeeze()print("移除单维度后的形状:", boxes.shape)  # 输出: torch.Size([2, 5])

所以论文里 squeeze的原因

boxes.squeeze() 移除了单维度,使得 boxes 的形状变为 (N, 5),这样可以方便地遍历每个边界框并进行处理

tensor的话 可能至少是三维的(我猜 是这个原因)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/51635.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring容器启动时执行代码(数据预热)

文章目录 静态代码块PostConstructInitialzingBeanCommandLineRunner和ApplicationRunnerServletContextListener执行顺序 在Java项目中&#xff0c;有时我们需要在应用启动时执行一些初始化代码&#xff0c;比如加载配置、初始化数据库连接池、预热数据等。这些操作对于应用的…

我的创作纪念日(一)——Giser?Noder?不如“Computer”

目录 Giser&#xff1f;Noder&#xff1f;不如“Computer” 一、根源&#xff1a;保持学习习惯的刚需 二、机缘&#xff1a;processOn的另类替代 三、日常&#xff1a;对技术栈丰富的思考 四、成就&#xff1a;保持心态健康的活着 五、憧憬&#xff1a;能一直心态健康的活…

前端实现【 批量任务调度管理器 】demo优化

一、前提介绍 我在前文实现过一个【批量任务调度管理器】的 demo&#xff0c;能实现简单的任务批量并发分组&#xff0c;过滤等操作。但是还有很多优化空间&#xff0c;所以查找一些优化的库&#xff0c; 主要想优化两个方面&#xff0c; 上篇提到的&#xff1a; 针对 3&…

CSS技巧专栏:一日一例 14-纯CSS实现模拟水波波动填充按钮特效

CSS技巧专栏:一日一例 14-纯CSS实现模拟水波波动填充按钮特效 大家好,今天介绍一个在网上很常见的模拟水波波动要灌满按钮的动画效果,效果下面图所示。 本例图片 案例分析 我们沿着Z轴从上到下数一下一共有几个层: 文字层:白色文字阴影的黑色文字,当鼠标移动上来时候…

黑马点评--给店铺类型查询添加缓存

controller/ShopTypeController.java /*** 店铺分类查询&#xff0c;用于展示首页头部店铺分类* return*/GetMapping("list")public Result queryTypeList() {return typeService.queryList();} service/IShopTypeService.java Result queryList(); service/impl/S…

fatal: Could not read from remote repository. 解决方法

问题描述&#xff1a; Git : fatal: Could not read from remote repository. Please make sure you have the correct access rights and the repository exists。 解决方法&#xff1a; 当在网上尝试大量方法仍然失败的时候&#xff0c;不妨试试这个方法。 在 github 上&…

java手动编译和运行程序

java手动编译和运行程序 1、无package无依赖jar public class HelloWorld {public static void main(String[] args) {System.out.println("Hello World!");} }$ javac HelloWorld.java $ java -classpath . HelloWorld # 或者 $ java -cp . HelloWorld2、有packag…

探索 Redis 不同集群架构的性能与应用

1. 引言 Redis的集群配置成为了提高数据可靠性和服务可用性的关键。本文将带领大家了解Redis的四种主要集群架构&#xff0c;并重点分析哨兵模式和Redis Cluster架构和优势。 2. Redis的四种集群架构 2.1 单实例Redis 使用单个 Redis 实例提供服务。适用于小规模应用&#…

深度学习在智慧交通中的应用:行人车辆检测与计数系统详解

引言 在现代城市中&#xff0c;行人和车辆的检测与计数对交通管理和城市规划具有重要意义。通过使用深度学习技术&#xff0c;可以实现对行人和车辆的实时检测与计数&#xff0c;提高交通管理的效率。本文将详细介绍如何构建一个基于深度学习的行人车辆检测与计数系统&#xf…

论文阅读:Deformable DETR: Deformable Transformers for End-to-End Object Detection

论文阅读&#xff1a;Deformable DETR: Deformable Transformers for End-to-End Object Detection Deformable DETR: 基于稀疏空间采样的注意力机制&#xff0c;让DCN与Transformer一起玩&#xff01; - 知乎 (zhihu.com) 【Deformable DETR 论文源码解读】Deformable Trans…

The Llama 3 Herd of Models.Llama 3 模型第1,2,3部分全文

现代人工智能(AI)系统是由基础模型驱动的。本文提出了一套新的基础模型,称为Llama 3。它是一组语言模型,支持多语言、编码、推理和工具使用。我们最大的模型是一个密集的Transformer,具有405B个参数和多达128K个tokens的上下文窗口。本文对Llama 3进行了广泛的实证评价。我们…

【error】AttributeError: module ‘cv2.dnn‘ has no attribute ‘DictValue‘(库冲突)

conda list conda remove opencv pip uninstall opencv-python conda list pip 同时卸载两个库 pip uninstall opencv-contrib-python opencv-python 没有and 直接写库名 module ‘cv2.dnn‘ has no attribute ‘DictValue‘解决办法_module cv2.dnn has no attribute d…

实分析与测度论问题的分类

实分析主要研究实数、实数序列、实数极限以及实值函数的分析&#xff0c;而度量空间则是一个具有距离函数的集合&#xff0c;其分类可以从多个角度进行。 实分析 实分析主要关注实数、实数序列、实数极限以及实值函数的分析。它涉及到多个重要的概念和理论&#xff0c;包括但…

Linux - 环境变量、程序地址空间、进程地址空间及Linux2.6内核进程调度队列

目录 环境变量 基本概念 常见环境变量 查看环境变量的方法 测试PATH 测试HOME 测试SHELL 和环境变量相关的命令 环境变量的组织方式 通过代码获取环境变量 通过系统调用获取环境变量 程序地址空间 进程地址空间 Linux2.6内核进程调度队列 一个CPU拥有一个runqueue 优先级 活…

谈一谈爬虫开发工程师

爬虫就只是抓数据的吗&#xff1f;并不是&#xff0c;爬虫工程师的工作不再仅仅是抓取数据&#xff0c;还需要处理其他各种复杂问题&#xff0c;今天我们就来聊聊爬虫开发工程师。 一、 爬虫开发工程师工作内容 爬虫开发工程师是负责编写和维护网络爬虫程序的专业人员。他们的…

Springboot与SpringSecurity使用(2):授权、自定义异常处理

一、用户授权 在SpringSecurity中&#xff0c;会使用默认的FilterSecurityInterceptor来进行权限校验。在FilterSecurityInterceptor中会从SecurityContextHolder获取其中的Authentication&#xff0c;然后获取其中的权限信息。判断当前用户是否拥有访问当前资源所需的权限。Sp…

2024-HW最新漏洞整理及相应解决方案(一)

前言&#xff1a; 漏洞是基于部分安全厂家、软件厂商的公众号或官方网站&#xff0c;以及一些非官方渠道等途径整理的HW安全漏洞情报&#xff0c;情报里附含漏洞详情和解决方案。护网期间我将持续更新分享&#xff0c;希望可以在护网期间帮助到大家 漏洞 用友U8CLOUDv3.6版本以…

c++初阶篇(七):类和对象(日期类)

1.头文件 定义了日期类&#xff0c;给出了类成员变量及成员函数的声明 #pragma once #include<iostream> #include<assert.h> using namespace std; class Date{public:friend ostream& operator<<(ostream& out, const Date& d);friend istre…

计算机网络中的 IPv6 部署与转换

背景介绍 随着互联网的迅速发展&#xff0c;IPv4 地址资源日益枯竭&#xff0c;无法满足未来互联网设备连接的需求。为了解决这一问题&#xff0c;IPv6 应运而生。IPv6&#xff08;互联网协议第六版&#xff09;提供了比 IPv4 更大的地址空间、更好的安全性和扩展性。然而&…

【多模态大模型】 ALBEF in NeurIPS 2021

一、引言 论文&#xff1a; Align before Fuse: Vision and Language Representation Learning with Momentum Distillation 作者&#xff1a; Salesforce Research 代码&#xff1a; ALBEF 特点&#xff1a; 该方法使用ViT进行图像特征提取&#xff0c;提出将BERT分两部分&am…