【Pytorch】cumsum的实现逻辑

本文只记录cumsum的实现逻辑的CUDA部分,也即底层调用了CUDA的什么实现算子。

void launch_cumsum_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(ScalarType::Half, ScalarType::BFloat16,self.scalar_type(), "cumsum_cuda",[&]() {scalar_t init = 0;scan_dim<scalar_t>(self,result,dim,init,std::plus<scalar_t>());});
}

通过定位源码,找到了执行kernel的关键代码,可以看到,此代码内部调用了Pytorch定义的宏,核心调用是pytorch定义的名为scan_dim的模板函数。
该模板函数的定义位于:aten/src/ATen/native/cuda/ScanUtils.cuh
代码如下:

template<typename scalar_t, typename BinaryFunction>
void scan_dim(const TensorBase& self, const TensorBase& result,int64_t dim, scalar_t init, BinaryFunction binary_op) {int ndim = self.dim();auto self_ = self.expect_contiguous();TORCH_INTERNAL_ASSERT(result.is_contiguous());if (self.numel() == self.size(dim)) {cuda::cub::inclusive_scan(self_->const_data_ptr<scalar_t>(), result.mutable_data_ptr<scalar_t>(), binary_op, self.numel());} else if (dim == ndim - 1) {scan_innermost_dim<scalar_t>(*self_, result, init, binary_op);} else {scan_outer_dim<scalar_t>(*self_, result, dim, init, binary_op);}
}

该函数内部最重要的是后面的条件结构,首先如果元素的总数和当前维度的元素个数相同,也即tensor是一维的,直接利用cub的前缀扫描方法,如果元素的总数和当前维度的元素个数不同,又分为最内层的维度,也即最后一维,以及其他情况。

template<typename scalar_t, class BinaryFunction>
__host__ void scan_outer_dim(const TensorBase& self, const TensorBase& result,int dim, scalar_t init, BinaryFunction binary_op) {const int64_t row_size = self.size(dim);auto sizes = self.sizes();// Treat all outer dimensions (i.e. dim_ < dim) as one.const int64_t num_orows = c10::multiply_integers(sizes.begin(), sizes.begin() + dim);// Treat all inner dimensions (i.e. dim > dimension) as one.const int64_t num_irows = c10::multiply_integers(sizes.begin() + dim + 1, sizes.end());dim3 threads(std::min(512, int(num_irows)));int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];dim3 grid(std::min(maxGridDim, num_orows), std::min(maxGridDim, ceil_div(num_irows, int64_t{threads.x})));check_fits_in_unsigned(num_irows, "num_irows");check_fits_in_unsigned(num_orows, "num_orows");check_fits_in_unsigned(row_size, "row_size");tensor_kernel_scan_outer_dim<scalar_t><<<grid, threads, 0, at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),num_orows, num_irows, row_size, init, binary_op);C10_CUDA_KERNEL_LAUNCH_CHECK();
}template <typename scalar_t, class BinaryFunction>
void scan_innermost_dim(const TensorBase& self, const TensorBase& result,scalar_t init, BinaryFunction binary_op) {int64_t ndim = self.dim();// Treat all outer dimensions as a single dimension.int64_t row_size = self.size(ndim - 1);int64_t num_rows = self.numel() / row_size;// assuming max_num_threads per block is 512const uint32_t num_threads = 512;const uint32_t log_num_threads_x = get_log_num_threads_x_inner_scan<uint32_t>(num_rows, row_size);const uint32_t num_threads_x = (1 << log_num_threads_x);const uint32_t num_threads_y = num_threads / num_threads_x;dim3 threads(num_threads_x, num_threads_y);int64_t maxGridDim = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];dim3 grid(std::min(maxGridDim, ceil_div(num_rows, int64_t{threads.y})));check_fits_in_unsigned(num_rows, "Number of rows (self.numel()/self.size(self.dim()-1))");check_fits_in_unsigned(row_size, "row_size");tensor_kernel_scan_innermost_dim<scalar_t><<<grid, threads, num_threads * 2 * sizeof(scalar_t),at::cuda::getCurrentCUDAStream()>>>(result.mutable_data_ptr<scalar_t>(), self.const_data_ptr<scalar_t>(),num_rows, row_size, log_num_threads_x, init, binary_op);C10_CUDA_KERNEL_LAUNCH_CHECK();
}

可以看到Pytorch针对上述两种情况进行了自定义,因为cub的inclusive_scan针对的是一维张量而非多维张量。

在调用核函数前,首先要定义调用核函数的网络结构和线程块结构,pytorch默认的线程块大小是512的,那么如何将512个线程块进行二维切分以满足合适的比例呢,pytorch中的做法是像下面这样:

template <typename integer>
constexpr inline integer get_log_num_threads_x_inner_scan(integer num_rows, integer row_size) {integer log_num_threads_x = 0;integer log_num_threads_y = 0;while (((integer)1 << log_num_threads_x) < row_size) {++log_num_threads_x;}while (((integer)1 << log_num_threads_y) < num_rows) {++log_num_threads_y;}// we want to keep the ratio between the x-threads and y-threads about the same as// the ratio between the row_size and num_rows, but the total number of threads in// a block should be about 512integer diff = log_num_threads_x - log_num_threads_y;// 9 is from log2(512)log_num_threads_x = ((integer)9 + diff) / (integer)2;// I found that in having larger log_num_threads_x can give significant speed up in some cases,// but detrimental in another case, so just keep the lower bound to be log2(16) == 4 to make it// similar to the previous implementation// Keeping the upper bound to be log2(512) == 9 as the maximum number of threads in a block.log_num_threads_x = std::min(std::max((integer)4, log_num_threads_x), (integer)9);return log_num_threads_x;
}

使用对数进行计算是便于计算出的x的结果可以整除,关键点在于最后平衡二者的比例的那行代码。可以预见,在某些情况下由于待处理数据的大小超过512造成线程块不能够完全分配的情况,此时就需要顾及线程块的比例,那么如果两个维度上线程块的对数值分别为x和y,对应的线程数分别为X,Y,也即 X = 2 x X=2^x X=2x。此时X与Y的比例 X / Y X / Y X/Y 的结果也即 2 x − y 2^{x - y} 2xy ,其实也就是 2 d i f f 2 ^ {diff} 2diff。那么如果将x变为(diff+9) / 2, y也就是 (9 - diff) / 2,二者相减也就是diff,因此保证了变换前后的比例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/49544.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Adobe Character Animator (CH) 安装包软件下载

目录 一、软件简介 二、下载与安装 1. 下载 2. 安装 三、注意事项 1. 硬件要求 2. 兼容性 四、功能介绍 1. 实时面部捕捉 2. 实时语音同步 3. 动作捕捉 五、快捷键操作 CH 提供了一系列快捷键以方便用户快速操作。以下是一些常用的快捷键&#xff1a; 一、软件简介…

django电商用户消费数据分析系统-计算机毕业设计源码20891

摘 要 随着电子商务的快速发展&#xff0c;电商平台积累了大量的用户消费数据。为了更好地理解用户行为、优化商品结构和提升用户体验&#xff0c;本文设计并实现了一个基于Django框架的电商用户消费数据分析系统。 该系统包含后台首页、系统用户&#xff08;管理员&#xf…

Ubuntu Grub引导优化

配置文件 sudo vim /etc/default/grub修改参数 引导菜单等待时间 GRUB_TIMEOUT3自动引导上次选择的系统 如果安装了双系统或多系统&#xff0c;可以考虑配置此参数。 # 此参数默认值为0&#xff0c;引导第一个引导项 GRUB_DEFAULTsaved# 此参数默认没有&#xff0c;需要手…

Hive分布式SQL计算平台

Hive分布式SQL计算平台 一、Hive 概述二、Hive架构三、Hive客户端 1、Hive有哪些客户端可以使用2、Hive第三方客户端 四、Hive使用语法 1、数据库操作2、内部表&#xff0c;外部表3、数据的导入与导出4、分区表5、分桶表6、复杂类型操作7、数据抽样8、Virtual Columns 虚拟列9…

构建数字堡垒:面对微软蓝屏事件的反思与前瞻

目录 引言事件回顾问题解析1、技术缺陷2、安全意识不足3、应急响应机制不健全 预防措施1、加强软件测试2、提升安全意识3、建立应急响应机制4、跨行业合作 未来展望1、人工智能与机器学习2、区块链技术3、法规与政策 结语 引言 2024年的微软蓝屏事件&#xff0c;无疑是对全球I…

Samtec技术科普小课堂 | 一文入门射频连接器~

【摘要/前言】 在本文中&#xff0c;我们将回到基础知识&#xff0c;了解一下什么是射频连接器。如果您是信号完整性专家&#xff0c;请点击阅读原文访问我们的网站视频&#xff0c;通过我们的网络研讨会视频了解教科书上可能找不到的知识。 如果您是电气工程领域的新手&#…

PHP 多线程和异步编程的常见陷阱

本文由 ChatMoney团队出品 在PHP开发中&#xff0c;多线程和异步编程是提高应用性能和响应速度的重要手段。然而&#xff0c;这些技术也带来了许多挑战和陷阱&#xff0c;如共享状态冲突、死锁、超时、资源泄漏以及调试困难等。本文将详细探讨这些陷阱&#xff0c;并提供相应的…

pycharm git 新建备忘

git 提交时出现如下错误&#xff1a; Committer identity unknown *** Please tell me who you are. Run git config --global user.email "youexample.com" git config --global user.name "Your Name" to set your accounts default identity. Omit…

SQL中的函数

目录 前言 一、系统内置函数 1、数学函数 2、日期和时间函数 3、聚合函数 4、字符串函数 二、自定义函数 1、标量函数的创建与调用 2、内嵌表值函数的创建与调用 3、多语句表值函数的创建与调用 前言 函数是由一个或多个 T-SQL 语句组成的子程序&#xff0c;可用于封…

浅谈监听器之聚合报告

浅谈监听器之聚合报告 “聚合报告”&#xff08;Aggregate Report&#xff09;是JMeter中最常用且功能强大的监听器之一&#xff0c;它提供了一种简洁而全面的方式来汇总和分析测试结果。本文档旨在深入解析聚合报告的特性和使用方法&#xff0c;帮助用户更好地理解和应用这一…

【Linux】条件变量及生产者消费者模型

为什么要将这两者放在一起进行呢&#xff1f; 主要是因为生产消费与条件变量关系密切&#xff0c;正好相辅相成。 目录 条件变量&#xff1a;条件变量的引出&#xff1a;条件变量的解释与接口&#xff1a;测试代码&#xff1a; 生产者消费者模型&#xff1a;概念&#xff1a;代…

opengaussdb在oepnEuler上安装

安装前提&#xff1a; 软件环境&#xff1a;openEuler 20.03LTS 个人开发者最低配置2核4G&#xff0c;推荐配置4核8G 数据库版本&#xff1a;openGauss-5.0.2-openEuler-64bit-all.tar.gz 数据库下载地址&#xff1a; https://docs-opengauss.osinfra.cn/zh/docs/5.0.0/docs/In…

OpenSNN推文:百度沈抖:深度拥抱人工智能+,加速发展新质生产力,共创智能时代新未来

在中国联通合作伙伴大会上&#xff0c;百度集团执行副总裁、百度智能云事业群总裁沈抖发表了一场题为“深度拥抱人工智能&#xff0c;加快发展新质生产力”的精彩演讲&#xff0c;深入探讨了大模型技术在当前科技浪潮中的核心地位及其对企业生产力的深远影响。 沈抖指出&#…

【LeetCode】86.分割链表

1. 题目 2. 分析 这题没有太大难度&#xff0c;主要是熟悉代码。 3. 代码 # Definition for singly-linked list. # class ListNode: # def __init__(self, val0, nextNone): # self.val val # self.next next class Solution:def partition(self, he…

Linux系统编程-文件系统

目录 什么是Linux文件系统 文件系统的职责 存储介质抽象 inode&#xff1a;文件系统的核心 文件分配策略 目录结构 文件系统布局 日志和恢复机制 目录权限 粘滞位(t位)&#xff1a; 硬链接和符号链接 硬链接的特点&#xff1a; 创建硬链接&#xff1a; 符号链接的…

MySQL补充性文件

数据库专属单词 authentication #身份验证 delimiter #分隔符 character #字符集 collate #整理。 指定字符集的排序规则 unicode #统一码 flush #刷新 privileges #特权 string #串 set #设置 use #使用 zerofill #修饰符。0可以填补输出的值 unsigned #修饰符。无符…

德国云手机:企业移动办公解决方案

在现代商业环境中&#xff0c;移动办公已经成为一种趋势。德国云手机作为一种高效的解决方案&#xff0c;为企业提供了强大的支持。本文将探讨德国云手机如何优化企业的移动办公环境。 一、德国云手机的主要优势 高灵活性 德国云手机具有高度的灵活性&#xff0c;能够根据用户需…

Elasticsearch:Golang ECS 日志记录 - Logrus

ECS 记录器是你最喜欢的日志库的格式化程序/编码器插件。它们可让你轻松地将日志格式化为与 ECS 兼容的 JSON。 编码器以 JSON 格式记录&#xff0c;内部依赖于默认的 logrus.JSONFormatter。它还处理 ECS 错误格式的错误字段记录。 默认情况下&#xff0c;会添加以下字段&am…

【学习笔记】无人机系统(UAS)的连接、识别和跟踪(三)-架构模型和概念

引言 3GPP TS 23.256 技术规范&#xff0c;主要定义了3GPP系统对无人机&#xff08;UAV&#xff09;的连接性、身份识别、跟踪及A2X&#xff08;Aircraft-to-Everything&#xff09;服务的支持。 3GPP TS 23.256 技术规范&#xff1a; 【免费】3GPPTS23.256技术报告-无人机系…

js-toLocaleString()方法的使用(根据本地规则将对象转换为字符串-千分位或百分比等)

1.应用场景 toLocaleString() 方法是 JavaScript 中的一个内置方法&#xff0c;它可以根据本地规则将对象转换为字符串。但主要被用于 Date、Number 和 Array 对象上。 2.具体应用 2.1 date对象中使用 Date 对象上调用 toLocaleString() 方法时&#xff0c;它会根据运行代码的…