onnxruntime 中的 Gather 算子

上一篇文章中介绍了 Division by Invariant Integers using Multiplication 的原理,很多框架均才用该算法优化除法运算。onnxruntime 是已知实现中最为简洁的,因此本文结合 onnxruntime 的 Gather 实现进行介绍。 Gather 算子是一个索引类算子,kernel 中每个线程计算偏移时使用 fast_divmod 避免除法运算。

注意:ONNX 中的 Gather 功能与 numpy.take 相同,torch.index_select 是其简化版。而 ONNX 中的 GatherElements 与 torch.gather 和 paddle. take_along_axis 相对应。

Gather

Gather
CudaKernel
OpKernel
GatherBase

会话运行时,ExecuteKernel 函数会调用 OpKernel。
CudaKernel 是 CUDA kernel 的基类,提供了 CudaKernel::Compute 函数。
OpKernelInfo 是一个非常轻量级的类,它作为构建 Kernel 实例所需的所有数据的聚合视图。 注意:它不拥有/持有任何对象。

class Gather : public CudaKernel, public GatherBase {public:Gather(const OpKernelInfo& info) : CudaKernel(info), GatherBase(info) {}Status ComputeInternal(OpKernelContext* context) const override;
};

Gather::ComputeInternal

Gather::ComputeInternal
GatherBase::PrepareForCompute
GatherImpl

创建一个 GatherBase::Prepare 结构体,包含了两个输入和一个输出张量的指针。
GatherBase::PrepareForCompute 准备输入输出。输出张量的秩为input_rank - 1 + indices_rank,即将axis参数指定的轴替换为indices张量的形状。
ORT_RETURN_IF_ERROR 在表达式失败时返回错误。
TensorShape::SizeFromDimension 计算从指定维度开始的乘积大小。
axis参数会将输入张量划分为3部分:batch 维度、索引维度、分块维度。
block_size为每个索引对应的分块大小。
N为索引数量。
input_block_size为在输入上的分块大小。
indices_max即索引上限。

Status Gather::ComputeInternal(OpKernelContext* context) const {Prepare p;ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));const TensorShape& input_shape = p.input_tensor->Shape();const int64_t block_size = input_shape.SizeFromDimension(p.axis + 1);size_t N = p.indices_tensor->Shape().Size();const int64_t input_block_size = input_shape.SizeFromDimension(p.axis);const int64_t output_block_size = N * block_size;const int64_t indices_max = input_shape[p.axis];
  const void* input_data = p.input_tensor->DataRaw();const void* indices_data = p.indices_tensor->DataRaw();void* output_data = p.output_tensor->MutableDataRaw();if (p.output_tensor->Shape().Size() == 0) {return Status::OK();}

gsl::narrow可确保无损失转换,并在无法转换时引发gsl::narrowing_error
fast_divmod 即 DivMod,用于快速计算除法。

  const fast_divmod divmod_output_block_size(gsl::narrow_cast<int>(output_block_size));const fast_divmod divmod_block_size(gsl::narrow_cast<int>(block_size));const size_t element_size = p.input_tensor->DataType()->Size();const size_t index_element_size = p.indices_tensor->DataType()->Size();

GatherImpl 函数索仅支持int32_tint64_t引类型。
传入的p.output_tensor->Shape().Size()即输出元素总数。

  // CUDA Kernel implementation supports element sizes of:// int8_t, int16_t, int32_t and int64_t which covers all supported// types since there is no computations necessary just data movementif (p.indices_tensor->IsDataType<int32_t>() ||p.indices_tensor->IsDataType<int64_t>()) {GatherImpl(Stream(context),input_block_size,indices_max,divmod_output_block_size,divmod_block_size,indices_data,index_element_size,input_data,element_size,output_data,p.output_tensor->Shape().Size());return Status::OK();}

ORT_MAKE_STATUS 创建一个 Status 对象。

  return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Type for Tind not supported yet in Gather.");
}

GatherImpl

GatherImpl
_GatherKernel

GridDim 结构体中定义了美剧值。
N为输出元素数量。直接求出所需 threadblock 的数量,没有太多策略。

void GatherImpl(cudaStream_t stream,const int64_t input_block_size,const int64_t indices_max,const fast_divmod& output_block_size,const fast_divmod& block_size,const void* indices_data,size_t index_element_size,const void* input_data,size_t element_size,void* output_data,const size_t N) {int blocksPerGrid = (int)(ceil(static_cast<float>(N) / GridDim::maxThreadsPerBlock));

ToCudaType 模板类将类型枚举转换为数据类型。
根据元素大小调用 _GatherKernel 模板函数,这样减少了实例化类型。

  switch (element_size) {case sizeof(int8_t): {using CudaType = typename ToCudaType<int8_t>::MappedType;_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);} break;case sizeof(int16_t): {using CudaType = typename ToCudaType<int16_t>::MappedType;_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);} break;case sizeof(int32_t): {using CudaType = typename ToCudaType<int32_t>::MappedType;_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);} break;case sizeof(int64_t): {using CudaType = typename ToCudaType<int64_t>::MappedType;_GatherKernel<<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, stream>>>(input_block_size, indices_max, output_block_size, block_size, indices_data, index_element_size,reinterpret_cast<const CudaType*>(input_data), reinterpret_cast<CudaType*>(output_data), (CUDA_LONG)N);} break;default:ORT_THROW("Unsupported element size by the Gather CUDA kernel");}
}

_GatherKernel

_GatherKernel
GetIndexValue

CALCULATE_ELEMENTWISE_INDEX_OR_EXIT 计算元素索引,并在超出范围时返回。

template <typename T>
__global__ void _GatherKernel(const int64_t input_block_size,const int64_t indices_max,const fast_divmod output_block_size,const fast_divmod block_size,const void* indices_data,const size_t index_element_size,const T* input_data,T* output_data,const CUDA_LONG N) {CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N);CUDA_LONG input_index = 0;

线程号id除以output_block_size得到输出元素所对应的输入分块索引input_block_index和输入分块内的偏移block_offset
根据block_offset计算对应的indices张量的索引以及分块内元素偏移。
GetIndexValue 取出indices张量的值。相比 TensorFlow 的 gather_functor_gpu.cu.h 没有进行向量化访存优化。
idx支持负数。索引值超出范围时赋零。

  int input_block_index, block_offset;output_block_size.divmod(id, input_block_index, block_offset);int indices_index, offset;block_size.divmod(block_offset, indices_index, offset);int64_t idx = GetIndexValue(indices_data, index_element_size, indices_index);idx = idx < 0 ? idx + indices_max : idx;if (idx < 0 || idx >= indices_max) {output_data[id] = 0;return;}

三部分相加得到输入张量索引。

  input_index = input_block_index * input_block_size + idx * block_size.d_ + offset;output_data[id] = input_data[input_index];
}

GetIndexValue

index_data指针转为相应类型,然后返回偏移位置的值。

__host__ __device__ inline int64_t GetIndexValue(const void* index_data, size_t index_element_size, size_t offset) {switch (index_element_size) {case sizeof(int32_t):return *(reinterpret_cast<const int32_t*>(index_data) + offset);break;case sizeof(int64_t):return *(reinterpret_cast<const int64_t*>(index_data) + offset);break;default:break;}// What is a sensible thing to do here?assert(false);return std::numeric_limits<int64_t>::max();
}

DivMod

除法取余实现基于 Division by Invariant Integers using Multiplication 中的 Figure 4.1。

在这里插入图片描述

// The code below is based on section 4 Unsigned division of paper https://gmplib.org/~tege/divcnst-pldi94.pdf
// In current ORT, fast_divmod is used for calculating the position of a element in tensor,
// so unsigned integer division from the paper is good enough for ORT. The advantage is that div is very simple,
// then GPU compiler can do loop unroll easilly when divmod is called in a loop.
template <>
struct DivMod<int> {DivMod(int d = 1) {d_ = d == 0 ? 1 : d;ORT_ENFORCE(d_ >= 1 && d_ <= static_cast<uint32_t>(std::numeric_limits<int>::max()));

l_ ℓ = ⌈ log ⁡ 2 x ⌉ \ell = \lceil \log_2 x \rceil =log2x

    for (l_ = 0; l_ < 32; l_++)if ((1U << l_) >= d_) break;

m m ′ = ⌊ 2 N ∗ ( 2 ℓ − d ) / d ⌋ + 1 m' = \lfloor 2^N ∗ (2^\ell − d)/d\rfloor + 1 m=2N(2d)/d+1

    uint64_t one = 1;uint64_t m = ((one << 32) * ((one << l_) - d_)) / d_ + 1;M_ = static_cast<uint32_t>(m);// according to paper, the value of m' should fit in a unsigned integer.ORT_ENFORCE(M_ > 0 && M_ == m);}

DivMod::div

t t 1 = M U L U H ( m ′ , n ) t_1 = \mathrm{MULUH}(m', n) t1=MULUH(m,n),使用uint64_t计算避免溢出。
对于 q q q

  • 如果 d = 1 d = 1 d=1,那么 ℓ = 0 \ell = 0 =0,所以 m ′ = 1 m' = 1 m=1 s h 1 = s h 2 = 0 sh_1 = sh_2 = 0 sh1=sh2=0。代码计算 t 1 = ⌊ 1 ∗ n / 2 N ⌋ = 0 t_1 = \lfloor 1 ∗ n/2^N \rfloor = 0 t1=1n/2N=0 q = n q = n q=n
  • d > 1 d > 1 d>1,则 ℓ ≥ 1 \ell≥1 1,故 s h 1 = 1 sh_1 = 1 sh1=1 s h 2 = ℓ − 1 sh_2 =\ell −1 sh2=1
    q = S R L ( t 1 + S R L ( n − t 1 , s h 1 ) , s h 2 ) = S R L ( t 1 + S R L ( n − t 1 , 1 ) , ℓ − 1 ) = ⌊ t 1 + ⌊ ( n − t 1 ) 2 ⌋ 2 ℓ − 1 ⌋ = ⌊ ⌊ 2 ∗ t 1 2 + ( n − t 1 ) 2 ⌋ 2 ℓ − 1 ⌋ (4.5) = ⌊ ⌊ ( t 1 + n ) / 2 ⌋ 2 ℓ − 1 ⌋ = ⌊ t 1 + n 2 ℓ ⌋ \begin{aligned} q &= \mathrm{SRL}(t_1 + \mathrm{SRL}(n − t_1, sh_1), sh_2)\\ &= \mathrm{SRL}(t_1 + \mathrm{SRL}(n − t_1, 1), \ell− 1)\\ &=\lfloor \frac{t_1 + \lfloor \frac{(n − t_1)}{2} \rfloor}{2^{\ell− 1}}\rfloor\\ &=\lfloor \frac{\lfloor \frac{2*t_1}{2} + \frac{(n − t_1)}{2} \rfloor}{2^{\ell− 1}}\rfloor \qquad\text{(4.5)}\\ &=\lfloor \frac{\lfloor(t_1 + n)/2\rfloor}{2^{\ell− 1}} \rfloor\\ &=\lfloor \frac{t_1 + n}{2^{\ell}} \rfloor \end{aligned} q=SRL(t1+SRL(nt1,sh1),sh2)=SRL(t1+SRL(nt1,1),1)=21t1+2(nt1)=2122t1+2(nt1)(4.5)=21⌊(t1+n)/2=2t1+n

__umulhi 计算两个 32 位无符号整数的乘积的最高有效 32 位。

  __host__ __device__ inline int div(int n) const {
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)uint32_t t = __umulhi(M_, n);return (t + n) >> l_;
#else// Using uint64_t for t, then t + n won't overflow.uint64_t t = ((uint64_t)M_ * n) >> 32;return static_cast<int>((t + n) >> l_);
#endif}

DivMod::mod

n m o d d = n − d ∗ ⌊ n / d ⌋ n \enspace \mathrm{mod} \enspace d = n − d ∗ \lfloor n/d \rfloor nmodd=ndn/d

  __host__ __device__ inline int mod(int n) const {return n - div(n) * d_;}

DivMod::divmod

  __host__ __device__ inline void divmod(int n, int& q, int& r) const {q = div(n);r = n - q * d_;}
  uint32_t d_;  // divisoruint32_t M_;  // m' in the paper.uint32_t l_;  // l_ = ceil(log2(d_))
};
  • Gather
  • ONNXRuntime整体概览
  • ONNXRuntime源码之OpKernel注册
  • Ways to specify [[nodiscard]] before C++17
  • microsoft/GSL
  • How to use gsl narrow cast
  • 警告 C26472
  • GSL and C++ Core Guidelines
  • Gather
  • Gather
  • tf.gather
  • torch.gather
  • paddle.gather
  • [菁英计划] 索引取值及gather函数 #36815
  • paddle. take_along_axis
  • torch.gather in pytorch.onnx and onnxruntime #31464
  • Replace torch.gather by other operator?
  • Problem compiling onnx model using GLOW compiler: constant not found
  • pytorch导出onnx的原则-以SwinTransformer和DETR在trt8.0.3.4部署为例
  • GatherElements
  • tf2onnx Gather
  • OpenVINO Gather
  • Pytorch equivalent of numpy.take()
  • torch.index_select
  • wkentaro/pytorch-for-numpy-users
  • Similar operation like numpy.take
  • numpy.take
  • tensorflow/tensorflow/core/kernels/gather_functor.h
  • tensorflow/core/kernels/gather_functor_batched.h
  • abseil中的微操

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/774466.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python(django)之单一接口展示功能前端开发

1、代码 建立apis_manage.html 代码如下&#xff1a; <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><title>测试平台</title> </head> <body role"document"> <nav c…

S7-200 SMART 选型指南及详细技术参数

S7-200 SMART 选型指南 选型指南 硬件能力 功能 CPU外形结构与电源需求计算 直流安装 交流安装 电源需求与计算 S7-200 SMART CPU模块提供5VDC和24VDC电源&#xff1a; CPU有一个内部电源&#xff0c;用于为CPU、扩展模块、信号板提供电源和满足其他24 VDC用户电源需求。请使…

鸿蒙HarmonyOS应用开发之使用Node-API接口进行线程安全开发

场景介绍 napi_create_threadsafe_function是Node-API接口之一&#xff0c;用于创建一个线程安全的JavaScript函数。主要用于在多个线程之间共享和调用&#xff0c;而不会出现竞争条件或死锁。例如以下场景&#xff1a; 异步计算&#xff1a;如果需要进行耗时的计算或IO操作&a…

Scala介绍与环境搭建

Scala环境搭建与介绍 一、Scala环境搭建 1、环境准备与下载 2、验证Scala 3、IDEA新建项目&#xff0c;配置Scala&#xff0c;运行Hello world 二、Scala介绍 1、Scala 简介 2、Scala 概述 一、Scala环境搭建 1、环境准备与下载 JDK1.8 Java Downloads | Oracle 下载需求版本…

如何将python项目转变成deb安装包

先将python项目转变成可执行文件 1. 首先确保你的python项目可以正常执行 2.安装pyinstaller模块&#xff0c;pip install pyinstaller -i Simple Index 3.确定好你的项目的文件入口&#xff0c;也就是运行的文件.py 4. 开始打包成单文件&#xff0c;pyinstaller -F <第…

STM32学习笔记(6_7)- TIM定时器的编码器接口原理

无人问津也好&#xff0c;技不如人也罢&#xff0c;都应静下心来&#xff0c;去做该做的事。 最近在学STM32&#xff0c;所以也开贴记录一下主要内容&#xff0c;省的过目即忘。视频教程为江科大&#xff08;改名江协科技&#xff09;&#xff0c;网站jiangxiekeji.com 现在开…

【Java程序设计】【C00374】基于(JavaWeb)Springboot的社区疫情管理系统(有论文)

TOC 博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业六年&#xff0c;已经做了六年的毕业设计程序开发&#xff0c;开发过上千套毕业设计程序&#xff0c;博客中有上百套程序可供参考&#xff0c;欢迎共同交流学习。 项目简介 项目获取 &#x1f345;文末点击卡片…

教学软件哪个好?这个一站式智慧教学系统值得推荐!

过去培训机构老师授课的场景主要在线下&#xff0c;可以使用大屏幕 PPT 来完成培训的交付&#xff0c;而现在随着数字化基础设施的完善&#xff0c;同时为了尽可能覆盖更多的人&#xff0c;依赖线下的培训场景也逐步转移到线上来完成&#xff0c;因此也对在线教学工具产生了需…

东方博宜 1521. 计算分数加减表达式的值

东方博宜 1521. 计算分数加减表达式的值 #include<iostream> #include<iomanip> using namespace std; int main() {double n ;cin >> n ;double sum ;sum 0.0 ;double j ;j 1.0 ;for (int i 1 ; i < n ; i){sum 1.0 / i * j ; j * -1 ;}cout <…

计算机网络01-20

计算机网络01-20 以下是本文参考的资料 欢迎大家查收原版 本版本仅作个人笔记使用1、OSI 的七层模型分别是&#xff1f;各自的功能是什么&#xff1f;2、说一下一次完整的HTTP请求过程包括哪些内容&#xff1f;孤单小弟 —— HTTP真实地址查询 —— DNS指南好帮手 —— 协议栈可…

Docker进阶:Docker Swarm —弹性伸缩调整服务的副本数量

Docker进阶&#xff1a;Docker Swarm —弹性伸缩调整服务的副本数量 1、 创建一个Nginx服务&#xff08;Manager节点&#xff09;2、查看服务状态&#xff08;Manager节点&#xff09;3、测试访问&#xff08;Worker节点&#xff09;4、查看服务日志&#xff08;Manager节点&am…

详解智慧路灯杆网关的集中供电能力

智慧路灯杆网关是智慧杆物联网系统中不可或缺的设备。智慧杆网关不仅可以作为杆载设备与云平台、设备与设备之间的桥梁&#xff0c;促进数据的无缝传输&#xff0c;而且还能提供高效的能源管理和供电功能。 BMG8200系列交流型智慧路灯杆网关就集成了强大的供电能力&#xff0c;…

浅析扩散模型与图像生成【应用篇】(十三)——PITI

13. Pretraining is All You Need for Image-to-Image Translation 该文提出一种基于预训练扩散模型的图像转换方法&#xff0c;称为PITI。其思想并不复杂&#xff0c;就是借鉴现有视觉和NLP领域中常见的预训练方法&#xff0c;考虑预先在一个大规模的任务无关数据集上对扩散模…

nginx学习记录-反向代理

1. 反向代理 一个简单的反向代理示意图如下&#xff1a; 我们的PC需要访问内网资源时&#xff0c;网关路由不直接将请求转发给内网的应用服务器&#xff0c;而是通过nginx服务器进行代理转发&#xff0c;转发到应用服务器上&#xff0c;应用服务器响应请求后会将响应数据再通过…

AJAX~

概念:AJAX(Asynchronous JavaScript And XML):异步的JavaScript和XML AJAX作用&#xff1a; 1.与服务器进行数据交换:通过AJAX可以给服务器发送请求&#xff0c;并获取服务器响应的数据 使用了AJAX和服务器进行通信&#xff0c;就可以使用HTMLAJAX来替换JSP页面了 2&#xf…

【MATLAB源码-第170期】基于matlab的BP神经网络股票价格预测GUI界面附带详细文档说明。

操作环境&#xff1a; MATLAB 2022a 1、算法描述 基于BP神经网络的股票价格预测是一种利用人工神经网络中的反向传播&#xff08;Backpropagation&#xff0c;简称BP&#xff09;算法来预测股票市场价格变化的技术。这种方法通过模拟人脑的处理方式&#xff0c;尝试捕捉股票…

欧美用户真实反馈!他们为什么选择爱可声助听器?

在竞争激烈的助听器市场上&#xff0c;爱可声助听器在欧美地区赢得了广泛的认可和好评。为什么越来越多的欧美用户选择爱可声助听器呢&#xff1f; 约翰&#xff0c;纽约的退休音乐教师 约翰是一位热爱音乐的退休音乐教师&#xff0c;他的一生都与音乐相伴&#xff0c;从年轻…

常用的AD规则设置

目录 规则编辑器&#xff1a; 间距规则&#xff1a; 线宽规则&#xff1a; 过孔规则&#xff1a; 铺铜设置&#xff1a; 生成制造过孔&#xff1a; 过孔之间间距&#xff1a; 最小阻焊层间距&#xff1a; 丝印到阻焊的距离&#xff1a; 丝印到丝印距离&#xff1a; 走…

01使用调试工具

文章目录 前言一、用openocd打开单片机二、利用4444端口向单片机写入hex文件三、利用3333端口和gdb进行调试四、之前我出的问题总结 前言 之前写了一篇关于在linux下搭建stm32标准库的文章后&#xff0c;有一些小伙伴们还是出现了一些奇奇怪怪的错误&#xff0c;这一篇文章就是…

JDK21|借鉴了近十种语言,String终于变好用了

作者:鱼仔 博客首页: https://codeease.top 公众号:Java鱼仔 前言 要想看官方对于JDK21的更新说明&#xff0c;可以直接跳转到下面这个官方网站中 官网地址为&#xff1a;https://openjdk.org/projects/jdk/21/ JDK21是最新的LTS版本&#xff0c;里面添加了不少新的特性&…