大模型自定义算子优化方案学习笔记:CUDA算子定义、算子编译、正反向梯度实现

01算子优化的意义

随着大模型应用的普及以及算力紧缺,下一步对于计算性能的追求一定是技术的核心方向。因为目前大模型的计算逻辑是由一个个独立的算子或者说OP正反向求导实现的,底层往往调用的是GPU提供的CUDA的驱动程序。如果不能对于整个计算过程学习并了解,对于性能优化领域无非是隔靴搔痒,今天也是抽一点时间读了下网上的一些文档和CUDA的文档,整理了学习材料。

首先说下为什么要自定义算子,无非是两个原因,

(1)TF、PyTorch等提供的原生算子不满足需求,需要通过底层接口,比如CUDA层更灵活的实现个性化算子

(2)目前提供的算子性能不足,没有很好的利用到GPU的并行计算优势,有优化空间

接着性能优化的问题说,因为GPU与CPU最大的区别是计算单元占据了绝大部分的体积(图中绿色部分),而控制单元较少。

自定义手写算子可以更好地利用绿色的计算单元的并行化优势。大的思路是Grid包含Block,Block包含Thread。于是首先算子需要把计算逻辑拆分成Thread,让程序可以并行化的运行起来,然后有机的管理各个Block的执行节奏,解决好异步和同步问题,就可以让芯片的计算效率最大化。

Grid of Thread Blocks

02实现流程

整个自定义算子优化其实可以分为CUDA算子定义、算子编译、正方向梯度实现几个步骤。

1、CUDA算子定义

需要定义以下几个文件:

(1)function.cpp:python层和CUDA层的衔接,实现手写的C++ CUDA算子可以被python代码调用

(2)function.h:CUDA函数声明文件

(3)function.cu:CUDA函数的逻辑实现

这里比较核心的文件就是.cu文件,构建的时候主要做两个事:一个是建设Kernel函数,因为只有Kernel函数是在GPU端执行,执行完之后要将控制权给到控制函数,这里要控制好异步、同步的问题。二是在kernel函数中需要通过循环函数定义每个Thread以及每个Block的工作,真正让计算并行化

.cpp文件可以通过pybind函数实现,这个函数主要解决的是C++代码和Python绑定的问题,项目地址:GitHub - pybind/pybind11: Seamless operability between C++11 and Python

2.编译和执行

import torch
from torch.utils.cpp_extension import load
cuda_module = load(name="function",extra_include_paths=["include"],sources=["function.cpp", "function.cu"],verbose=True)

接着就是执行过程中的编译,通过load函数底层会执行JIT(Just in time)的动态编译模式调用.cpp和.cu文件,底层运行的是Ninjia编译器,通过调用nvcc编译.so文件

[1/2] nvcc -c function.cu -o function.cuda.o
[2/3] c++ -c function.cpp -o function.o
[3/3] c++ function.o function.cuda.o -shared -o function.so

3.正反向梯度实现

以上两步实现了自定义算子的逻辑,可以通过手写CUDA算子并在python框架层调用,如果要满足建模需求,还需要实现正方向梯度。具体的做法是在建模的函数中实现forward和backward函数,定义导数作为输出。

以上大概就是手写算子优化的简单流程,仅当学习笔记。

参考:

【1】熬了几个通宵,我写了份CUDA新手入门代码 - 知乎

【2】CUDA C++ Programming Guide

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/228454.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

货物数据处理pandas版

1求和 from openpyxl import load_workbook import pandas as pddef print_hi(name):# Use a breakpoint in the code line below to debug your script.print(fHi, {name}) # Press CtrlF8 to toggle the breakpoint.# Press the green button in the gutter to run the scr…

【C++学习————引用】

【C学习——————引用】 欢迎阅读新一期的c模块————引用 ✒️个人主页:-Joker- 🏷️专栏:C 📜代码仓库:c_code 🌹🌹欢迎大佬们的阅读和三连关注,顺着评论回访🌹&a…

对可恢复的情况使用受检异常

在Java中,受检异常(Checked Exception)通常用于表示程序能够预期并且可能进行恢复的异常情况。这类异常是在编译时由编译器强制进行处理的,使得程序员必须显式处理这些异常,或者在方法签名中使用 throws 关键字声明。 …

react之项目打包,本地预览,路由懒加载,打包体积分析以及如何配置CDN

react之项目打包,本地预览,路由懒加载,打包体积分析以及如何配置CDN 一、项目打包二、项目本地预览三、路由懒加载四、打包体积分析五、配置CDN 一、项目打包 执行命令 npm run build根目录下生成的build文件夹 及时打包后的文件 二、项目本地预览 1.全局安装本地服务包 npm…

【Linux】介绍:进程退出、进程等待、进程程序替换

目录 一、进程退出 _exit函数 exit函数 _exit()与exit比较 return退出 二、进程等待 wait方法 waitpid方法 三、进程程序替换 替换函数 函数解释 命名理解 使用举例 一、进程退出 正常终止(可以通过 echo $? 查看进程退出码):1.…

Ubuntu22.04切换用户

一、只有一个用户时没有切换用户菜单项 1、用户信息 cat /etc/passwd 2、系统菜单 二、添加用户 添加新用户ym,全名yang mi 三、有两个及以上的用户时出现切换用户菜单项 1、用户信息 cat /etc/passwd 2、系统菜单 四、切换用户 1、点击上图中Switch User …

爬虫 scrapy ——scrapy shell调试及下载当当网数据(十一)

目录 一、scrapy shell 1.什么是scrapy shell? 2.安装 ipython 3.使用scrapy shell 二、当当网案例 1.在items.py中定义数据结构 2.在dang.py中解析数据 3.使用pipeline保存 4.多条管道的使用 5.多页下载 参考 一、scrapy shell 1.什么是scrapy shell&am…

【如何提取React项目中的公共模块,多个项目共用】

文章目录 目录 前言 一、创建公共模块 二、初始化公共模块 三、给公共模块添加内容 四、添加对公共模块的依赖 五、使用公共模块里的资源 后记 前言 在工作中经常会遇到这样的需求,有个React项目,代码分为客户端,管理端两份&#xff…

Vue3-21-组件-子组件给父组件发送事件

情景描述 【子组件】中有一个按钮,点击按钮,触发一个事件, 我们希望这个事件的处理逻辑是,给【父组件】发送一条消息过去, 从而实现 【子组件】给【父组件】通信的效果。这个问题的解决就是 “发送事件” 这个操作。 …

arthas获取spring bean

参考文章 arthas获取spring bean 写一个工具Util package com.example.lredisson.util;import org.springframework.beans.BeansException; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContextAware; import o…

HarmonyOS给应用添加消息通知

给您的应用添加通知 通知介绍 通知旨在让用户以合适的方式及时获得有用的新消息,帮助用户高效地处理任务。应用可以通过通知接口发送通知消息,用户可以通过通知栏查看通知内容,也可以点击通知来打开应用,通知主要有以下使用场景…

Cell Systems | 深度学习开启蛋白质设计新时代

今天为大家介绍的是来自Bruno Correia团队的一篇综述。深度学习领域的迅速进步对蛋白质设计产生了显著影响。最近,深度学习方法在蛋白质结构预测方面取得了重大突破,使我们能够得到数百万种蛋白质的高质量模型。结合用于生成建模和序列分析的新型架构&am…

相机倾斜棋盘格标定全记录 vs200+opencv安装

论文参考是这个 Geiger A, Moosmann F, Car , et al. Automatic camera and range sensor calibration using a single shot[C]//Robotics and Automation (ICRA), 2012 IEEE International Conference on. IEEE, 2012: 3936-3943. 代码是这个github 花了一上午配好了c环境。。…

Flink系列之:SQL提示

Flink系列之:SQL提示 一、动态表选项二、语法三、例子四、查询提示五、句法六、加入提示七、播送八、随机散列九、随机合并十、嵌套循环十一、LOOKUP十二、进一步说明十三、故障排除十四、连接提示中的冲突案例十五、什么是查询块 SQL 提示可以与 SQL 语句一起使用来…

Sci. Rep. | 一个对任意分子体系实现准确且高效几何深度学习的通用框架

这篇工作是来自纽约城市大学/康奈尔医学院谢磊团队的一篇论文。作者提出了一个通用框架,PAMNet,可以对任意分子体系实现准确且高效的几何深度学习。在小分子性质、RNA三维结构以及蛋白质-配体结合亲和力的预测任务上,PAMNet在准确性和效率方面…

【ESXi】ESXi 版本回退

目录 8. ESXi 版本回退8.1 版本回退条件与注意事项8.2 版本回退步骤8.3 示例演示(1)准备工作(2)进入DCUI界面(3)按 F11 重启系统引导(4)进入引导选项(5)进入 …

弧形导轨的精度等级

为符合工控自动化生产制造必须,弧形导轨在运输武器装备领域应时而生,并已在电子生产制造、手机上、半导体材料、动力锂电池等领域获得广泛运用。其中,弧形导轨的精度等级是评估其运动精度的重要指标,通常包括制造精度和运行精度两…

Flink系列之:大状态与 Checkpoint 调优

Flink系列之:大状态与 Checkpoint 调优 一、概述二、监控状态和 Checkpoints三、Checkpoint 调优四、RocksDB 调优五、增量 Checkpoint六、RocksDB 或 JVM 堆中的计时器七、RocksDB 内存调优八、容量规划九、压缩十、Task 本地恢复十一、主要(分布式存储…

spring-kakfa依赖管理之org/springframework/kafka/listener/CommonErrorHandler错误

问题: 整个项目使用spring-boot2.6.8版本,使用gradle构建,在common模块指定了implementation org.springframework.kafka:spring-kafka:2.6.8’这个工程也都能运行(这正常发送kafka消息和接收消息),但是执行…

java --- 集合进阶

目录 一、单列集合顶层接口 Collection 1.1 基本方法 1.2 Collection 的遍历方式 二、list集合 1.2 ArrayList Vector 底层结构 1.3 LinkedList ArrayList 和 LinkedList 比较 三、set接口 3.1、Set 接口和常用方法 3.2 HashSet HashSet 底层机制(HashMap…