mask2former利用不确定性采样点选择提高模型性能

在机器学习和深度学习的训练过程中,不确定性高的点通常代表模型在这些点上的预测不够可靠或有较高的误差。因此,关注这些不确定性高的点,通过计算这些点的损失并进行梯度更新,可以有效地提高模型的整体性能。确定性高的点预测结果已经比较准确,相应地对模型的训练贡献较小,所以可以减少对这些点的关注或完全忽略它们的损失计算。

代码复现参考仓库:https://github.com/NielsRogge/Transformers-Tutorials

在这篇博客中,我们将详细解释 mask2former 中的一段代码,该代码通过不确定性采样点来选择重要点,并探讨其在模型训练中的重要性。mask2former原文描述比较简单,如下:
在这里插入图片描述

代码源自transformers库中的modeling_mask2former.py,主要讲解如下代码:

    def sample_points_using_uncertainty(self,logits: torch.Tensor,uncertainty_function,num_points: int,oversample_ratio: int,importance_sample_ratio: float,) -> torch.Tensor:"""This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. Theuncertainty is calculated for each point using the passed `uncertainty function` that takes points logitprediction as input.Args:logits (`float`):Logit predictions for P points.uncertainty_function:A function that takes logit predictions for P points and returns their uncertainties.num_points (`int`):The number of points P to sample.oversample_ratio (`int`):Oversampling parameter.importance_sample_ratio (`float`):Ratio of points that are sampled via importance sampling.Returns:point_coordinates (`torch.Tensor`):Coordinates for P sampled points."""num_boxes = logits.shape[0]num_points_sampled = int(num_points * oversample_ratio)# Get random point coordinatespoint_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)# Get sampled prediction value for the point coordinatespoint_logits = sample_point(logits, point_coordinates, align_corners=False)# Calculate the uncertainties based on the sampled prediction values of the pointspoint_uncertainties = uncertainty_function(point_logits)#[n1+n2, 1, 37632],理解为,值越大,不确定性越高num_uncertain_points = int(importance_sample_ratio * num_points)#9408num_random_points = num_points - num_uncertain_points#3136idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]#[n1+n2, 9408]这行代码的作用是从每个 num_boxes 的不确定性值中选择 num_uncertain_points 个最大值的索引。这些索引将用于从原始的点坐标张量 point_coordinates 中选择相应的点,这些点将被认为是基于不确定性的重要性采样点。shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)#这两行代码的主要目的是确保在从 point_coordinates 中选择点时,能够正确地访问全局索引,使得每个 box 的采样点能够准确地映射到整个张量中的位置。idx += shift[:, None]point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)#[n1+n2, 9408, 2]if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)return point_coordinates

以下是 sample_points_using_uncertainty 函数的参数解释:

  • logits (torch.Tensor): P 个点的 logit 预测值。
  • uncertainty_function: 一个函数,接受 P 个点的 logit 预测值并返回它们的不确定性。
  • num_points (int): 需要采样的点的数量 P。
  • oversample_ratio (int): 过采样参数,用于增加采样点的数量,以确保能在不确定性采样中选到合适的点。
  • importance_sample_ratio (float): 使用重要性采样选出的点的比例。

函数步骤解释

  1. 计算总采样点数

    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)
    

    num_boxes 是指预测的盒子数量,num_points_sampled 是经过过采样之后的总采样点数。

  2. 生成随机点的坐标

    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    

    在 [0, 1] * [0, 1] 空间内生成随机点的坐标。

  3. 获取这些随机点的预测值

    point_logits = sample_point(logits, point_coordinates, align_corners=False)
    

    对随机点的坐标进行采样,获取它们的预测 logit 值。

  4. 计算这些点的不确定性

    point_uncertainties = uncertainty_function(point_logits)
    

    使用 uncertainty_function 计算这些点的不确定性。

  5. 确定不确定性采样和随机采样的点数

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    

    根据 importance_sample_ratio 确定通过不确定性采样的点数 num_uncertain_points,以及剩余的随机采样点数 num_random_points

  6. 选择不确定性最高的点

    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
    

    使用 torch.topk 函数选择每个盒子中不确定性最高的 num_uncertain_points 个点,并获取它们的坐标。

  7. 添加随机点

    if num_random_points > 0:point_coordinates = torch.cat([point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],dim=1,)
    

    如果需要添加随机采样点,将它们与不确定性采样点合并。

  8. 返回采样点的坐标

    return point_coordinates
    

    最终返回所有采样点的坐标。

关键代码解读

1. 偏移量的生成
 shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)

这行代码的目的是为每个 box 生成一个偏移量(shift),用于转换局部索引为全局索引。

  • torch.arange(num_boxes, dtype=torch.long, device=logits.device) 生成一个从 0 到 num_boxes-1 的张量。
  • num_points_sampled 是每个 box 中采样的点的数量。
  • 乘法操作 num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device) 为每个 box 生成一个偏移量。例如,假设 num_points_sampled 为 100,那么生成的偏移量张量为 [0, 100, 200, 300, ...]

这些偏移量将用于将局部索引(即每个 box 内的索引)转换为全局索引(即在整个 point_coordinates 中的索引)。

2. 局部索引转换为全局索引
  idx += shift[:, None]

这行代码将局部索引转换为全局索引。

  • idxtorch.topk 返回的不确定性最高的点的局部索引,形状为 [num_boxes, num_uncertain_points]
  • shift[:, None] 的形状是 [num_boxes, 1],通过这种方式将每个 box 的偏移量广播到与 idx 的形状匹配。

通过将 shift 加到 idx 上,每个 box 的局部索引将变成全局索引。例如,如果第一个 box 的偏移量为 100,那么第一个 box 内的局部索引 [0, 1, 2, ...] 将变为 [100, 101, 102, ...]

总结

通过 sample_points_using_uncertainty 函数,我们可以有效地选择不确定性高的点进行训练,提高模型在这些关键点上的表现,同时减少确定性高的点的计算开销。这种不确定性采样方法结合了重要性采样和随机采样,确保了模型训练的高效性和鲁棒性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/26262.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【python】tkinter GUI开发: 多行文本Text,单选框Radiobutton,复选框Checkbutton,画布canvas的应用实战详解

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

Vue3使用datav3报错问题解决

报错:Failed to resolve entry for package "dataview/datav-vue3". The package may have incorrect main/module/exports specified in its package.json. 修改package.json 修改为 "module": "./es/index.mjs",

细说MCU串口函数及使用printf函数实现串口发送数据的方法

目录 1、硬件及工程 2、串口相关的库函数 (1)串口中断服务函数: (2)串口接收回调函数: (3)串口接收中断配置函数: (4)非中断发送&#xff…

飞速(FS)InfiniBand解决方案构建HPC网络

面对HPC领域的不断发展,未来HPC业务核心在于HPC网络和基础设施。随着高性能计算应用的复杂性和数据量的增长,对弹性、可扩展和高效网络的需求变得日益迫切。HPC网络架构作为HPC系统运行的基础,在数据处理、管理和大规模存储方面至关重要。本文…

什么是 URL 过滤?是如何保障浏览体验的?

互联网是一个无边无际的空间,几乎包含了你能想象到的一切。不幸的是,这意味着也存在着从不合适到非常危险的网站。这就是 URL 过滤可以发挥作用的地方。 一、URL 过滤的含义 我们希望您已经熟悉 URL(统一资源定位器),…

Linux命令详解(2)

文本处理是Linux命令行的重要应用之一。通过一系列强大的命令,用户可以轻松地对文本文件进行编辑、查询和转换。 cat: 这个命令用于查看文件内容。它可以一次性显示整个文件,或者分页显示。此外,cat 还可以用于合并多个文件的内容…

使用winscp 通过中转机器(跳板机、堡垒机)密钥远程连接服务器,保姆级别教程

1.winscp下载地址 winscp下载 2.安装自己选择位置 3.连接服务器 到这里,基本就是没有壁垒机的就可直接连接,传递文件 4.配置中转服务器(壁垒机、跳板机) 选择高级选项 配置utf-8的编码格式 配置中转服务器(壁垒机、跳板机) 设置中专机的密码或者私钥 配置私钥

Day 16:3040. 相同分数的最大操作数目II

Leetcode 相同分数的最大操作数目II 给你一个整数数组 nums ,如果 nums 至少 包含 2 个元素,你可以执行以下操作中的 任意 一个: 选择 nums 中最前面两个元素并且删除它们。选择 nums 中最后两个元素并且删除它们。选择 nums 中第一个和最后一…

Adobe Photoshop cc快速抠图与精致抠图方法

一、背景 Photoshop cc绝对是最好用的抠图and修图软件,但是即使最简单的抠图,每次用时都忘记怎么做,然后再去B站搜,非常费时,下面记录一下抠图过程,方便查阅。 一、Adobe Photoshop快速抠图 选择——主体…

大模型基础——从零实现一个Transformer(3)

大模型基础——从零实现一个Transformer(1)-CSDN博客 大模型基础——从零实现一个Transformer(2)-CSDN博客 一、前言 之前两篇文章已经讲了Transformer的Embedding,Tokenizer,Attention,Position Encoding, 本文我们继续了解Transformer中剩下的其他组件. 二、归一化 2.1 L…

C++--DAY7

vector容器 #include <iostream> #include <vector>using namespace std; void printVector(vector<int> &v) {//定义一个迭代器 指针vector<int>::iterator iter;//v.end&#xff08;&#xff09;是最后一个元素的下一个元素地址for(iterv.begin…

申请郑州水污染防治乙级资质,这些材料你需要提前准备

申请郑州水污染防治乙级资质时&#xff0c;你需要提前准备以下材料&#xff0c;以确保申请流程的顺利进行&#xff1a; 一、企业基本材料 企业法人营业执照副本复印件&#xff1a;需加盖企业公章&#xff0c;确保复印件清晰、完整。企业章程文本&#xff1a;提供企业章程的完整…

网络安全在2024好入行吗?

前言 024年的今天&#xff0c;慎重进入网安行业吧&#xff0c;目前来说信息安全方向的就业对于学历的容忍度比软件开发要大得多&#xff0c;还有很多高中被挖过来的大佬。 理由很简单&#xff0c;目前来说&#xff0c;信息安全的圈子人少&#xff0c;985、211院校很多都才建立…

Java课程设计:基于ssm的旅游管理系统系统(内附源码)

文章目录 一、项目介绍二、项目展示三、源码展示四、源码获取 一、项目介绍 2023年处于信息科技高速发展的大背景之下。在今天&#xff0c;缺少手机和电脑几乎已经成为不可能的事情&#xff0c;人们生活中已经难以离开手机和电脑。针对增加的成本管理和操作,各大旅行社非常必要…

鸿蒙开发文件管理:【@ohos.securityLabel (数据标签)】

数据标签 该模块提供文件数据安全等级的相关功能&#xff1a;向应用程序提供查询、设置文件数据安全等级的JS接口。 说明&#xff1a; 本模块首批接口从API version 9开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 导入模块 import security…

iOS ReactiveCocoa MVVM

学习了在MVVM中如何使用RactiveCocoa&#xff0c;简单的写上一个demo。重点在于如何在MVVM各层之间使用RAC的信号来更方便的在各个层之间进行响应式数据交互。 demo需求&#xff1a;一个登录界面(登录界面只有账号和密码都有输入&#xff0c;登录按钮才可以点击操作)&#xff0…

白酒:茅台镇白酒的口感特点与品质保障

当我们谈及中国的白酒&#xff0c;茅台镇无疑是其中璀璨的一颗明珠。而在这片神奇的土地上&#xff0c;云仓酒庄凭借其与众不同的酿造工艺和卓着的品质保障&#xff0c;成为了茅台镇白酒的新一代佼佼者。今天&#xff0c;让我们一起领略云仓酒庄豪迈白酒的口感特点和品质魅力。…

win11 修改hosts提示无权限

win11下hosts的文件路径 C:\Windows\System32\drivers\etc>hosts修改文件后提示无权限。 我做了好几个尝试&#xff0c;都没个啥用~比如&#xff1a;右键 管理员身份运行&#xff0c;在其他版本的windows上可行&#xff0c;但是win11不行&#xff0c;我用的是微软账号登录的…

【Ardiuno】实验使用ESP32单片机实现高级web服务器暂时动态图表功能(图文)

接下来&#xff0c;我们继续实验示例代码中的Wifi“高级web服务器”&#xff0c;配置相关的无线密码后&#xff0c;开始实验 #include <WiFi.h> #include <WiFiClient.h> #include <WebServer.h> #include <ESPmDNS.h>const char *ssid "XIAOFE…

高通Android 12 右边导航栏改成底部显示

最近同事说需要修改右边导航栏到底部&#xff0c;问怎么搞&#xff1f;然后看下源码尝试下。 1、Android 12修改代码路径 frameworks/base/services/core/java/com/android/server/wm/DisplayPolicy.java a/frameworks/base/services/core/java/com/android/server/wm/Display…