如何实现TensorFlow自定义算子?

在上一篇文章中 Embedding压缩之基于二进制码的Hash Embedding,提供了二进制码的tensorflow算子源码,那就顺便来讲下tensorflow自定义算子的完整实现过程。

前言

制作过程基于tensorflow官方的custom-op仓库以及官网教程,并且在Ubuntu和MacOS系统通过了测试。

官方提供的案例虽然也涵盖了整个流程,但是它过于简单,自己遇到其他需求的实现可能还得去翻阅资料。而基于上一篇文章的二进制码Hash编码的算子实现,是能够满足大部分自定义需求的,并且经过测试是支持tensorflow1.x和2.x的

文章中的代码只是展示了核心部分,并不是完整代码,全部放出来的话会显示得十分冗长。完整代码可前往下面的任一git仓库:

仅包含tensorflow自定义算子的独立仓库

自定义算子(含其他文章的代码)

目录结构

整个项目的目录结构如下,下面会对每一个文件进行讲述其作用:

├── Makefile
└── tensorflow_binary_code_hash├── BUILD├── __init__.py├── cc│   ├── kernels│   │   ├── binary_code_hash.h│   │   ├── binary_code_hash_kernels.cc│   │   ├── binary_code_hash_kernels.cu.cc│   │   └── binary_code_hash_only_cpu_kernels.cc│   └── ops│       └── binary_code_hash_ops.cc└── python├── __init__.py└── ops├── __init__.py├── binary_code_hash_ops.py└── binary_code_hash_test.py

前置依赖

make

make

g++

g++

cuda

cuda

nvcc

tensorflow

无需源码安装,pip安装的情况下已通过测试。

  1. cuda与tensorflow之间版本已兼容,直接pip安装

  2. cuda与tensorflow之间版本不兼容

    a. 新建Python环境:

    conda create -n <your_env_name> python=<x.x.x> cudatoolkit=<x.x> cudnn -c conda-forge

    b. 现有Python环境:

    conda install cudatoolkit=<x.x> cudnn -c conda-forge -n <your_env_name>

    执行以上步骤后,再进行pip安装

  3. 当然,你仍然可以选择源码编译安装: https://www.tensorflow.org/install/source

Step1. 定义运算接口

对应文件:tensorflow_binary_code_hash/cc/ops/binary_code_hash_ops.cc

这里需要将接口注册到 TensorFlow 系统,通过对 REGISTER_OP 宏的调用来可以定义运算的接口。

你可以在这里定义算子所需要的输入,和设置输出的格式。接口内容如下,主要包括两个部分:

  1. 定义输入。Input部分为输入张量,Attr部分是其他非张量的参数,Output则是输出张量。规定了输入张量hash_id和输出张量bh_id的类型是T,T为32位和64位的整型。strategy参数则是枚举,只能是succession或者skip;
  2. 在Lmabdas函数体里面可以定义输出的shape。
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"using namespace tensorflow;REGISTER_OP("BinaryCodeHash").Attr("T: {int64, int32}").Input("hash_id: T").Attr("length: int").Attr("t: int").Attr("strategy: {'succession', 'skip'}").Output("bh_id: T").SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {// 这里进行输入的校验和指定输出的shapereturn Status::OK();});

比如,输出的shape需要由输入的shape和其他参数决定,而不是官方样例里的输出跟输入的shape一样。

下面的代码则是如何获取参数的值:

int length;
c->GetAttr("length", &length);

再有获取输入的信息和输入的校验,最后指定输出的shape,在这里,可以定义动态shape,即有些维度可以是未知的size,用-1表示

// 获取输入张量的形状,并检验输入的维度数>=1
shape_inference::ShapeHandle input_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
// 获取输入张量的维度数
int input_rank = c->Rank(input_shape);
// 创建新的形状列表
std::vector<shape_inference::DimensionHandle> output_shape;
for (int i = 0; i < input_rank; ++i) {output_shape.push_back(c->Dim(input_shape, i));
}
// 添加一个额外的维度
output_shape.push_back(c->MakeDim(block_num));
// 将output_shape指定为输出张量的形状,则输出比输入多一维,类似于embedding_lookup
c->set_output(0, c->MakeShape(output_shape));

Step2. 实现运算内核

Step2.1 定义计算头文件

对应文件:tensorflow_binary_code_hash/cc/kernels/binary_code_hash.h

这里是C++的头文件,只包括计算逻辑的仿函数(函数对象)BinaryCodeHashFunctor的声明,没有具体实现

包括输入张量in和输出张量out,其他则是一些非张量参数。这里其他参数对于到时cuda运算内核就很重要,因为cuda显存的数据其实都是从内存拷贝过去的,即这些参数对应的实参,因此仿函数的参数要齐全。

#include <string>namespace tensorflow {
namespace functor {template <typename Device, typename T>
struct BinaryCodeHashFunctor {void operator()(const Device& d, int size, const T* in, T* out, int length, int t, bool succession);
};
}  // namespace functor
}  // namespace tensorflow

Step2.2 cpu运算内核

对应文件:tensorflow_binary_code_hash/cc/kernels/binary_code_hash_kernels.cc

这里主要包括三部分:

  1. 计算逻辑的仿函数具体实现
  2. 运算内核的实现类
  3. 内核注册

2.2.1 计算仿函数实现

在这里实现BinaryCodeHashFunctor具体的计算逻辑,输入张量的数据通过指针变量in来访问,然后将计算结果写入到输出张量对应的指针变量out。

这里需要注意的是输入张量和输出张量都是一维的形式,即压平的数据。

// CPU specialization of actual computation.
template <typename T>
struct BinaryCodeHashFunctor<CPUDevice, T> {void operator()(const CPUDevice& d, int size, const T* in, T* out, int length, int t, bool succession) {// 实现自己的计算逻辑}
};

2.2.2 内核实现类

在这里,运算内核实现类需要继承OpKernel,如下面的代码

  • 在构造函数里面,可以对非张量参数进行详细的检验;
  • 在Compute重载函数完成所有计算工作。
#include "binary_code_hash.h"
#include "tensorflow/core/framework/op_kernel.h"// OpKernel definition.
// template parameter <T> is the datatype of the tensors.
template <typename Device, typename T>
class BinaryCodeHashOp : public OpKernel {public:explicit BinaryCodeHashOp(OpKernelConstruction* context) : OpKernel(context) {// 参数校验}void Compute(OpKernelContext* context) override {// 实现自己的内核逻辑}private:int length_;
};

构造函数。下面的代码展示了非张量参数赋值给成员变量、参数的校验。

explicit BinaryCodeHashOp(OpKernelConstruction* context) : OpKernel(context) {OP_REQUIRES_OK(context, context->GetAttr("length", &length_));OP_REQUIRES(context, length_ > 0,errors::InvalidArgument("Need length > 0, got ", length_));
}

Compute函数

Compute函数中访问输入张量内容和输入张量检验。

const Tensor& input_tensor = context->input(0);// 检验输入张量是否为一维向量
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),errors::InvalidArgument("BinaryCodeHash expects a 1-D vector."));

Compute函数中为输出张量分配内存和定义输出的shape,在这里就不能使用动态shape,则所有维度的size都需要是明确的。

Tensor* output_tensor = NULL;
// 输出张量比输入张量多一个维度
tensorflow::TensorShape output_shape = input_tensor.shape();
output_shape.AddDim(block_num);  // Add New dimension
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output_tensor));

最后,Compute函数里面启动计算内核仿函数。这里留意下,这里喂给仿函数的实参,到时是会拷贝到显存的,即上面提到的,这里喂给cpu的数据跟后面喂给cuda的是一样的。

BinaryCodeHashFunctor<Device, T>()(context->eigen_device<Device>(),static_cast<int>(input_tensor.NumElements()),input_tensor.flat<T>().data(),output_tensor->flat<T>().data(),length_, t_, strategy_ == "succession");

2.2.3 内核注册

CPU和CPU内核都需要在这里进行注册。

这里还包括对上面运算接口定义(tensorflow_binary_code_hash/cc/ops/binary_code_hash_ops.cc)中的T进行约束,因为上面Attr中的T不属于算子函数的参数,因此需要在这里进行对应指定int32和int64。

// Register the CPU kernels.
#define REGISTER_CPU(T)                                          \REGISTER_KERNEL_BUILDER(                                       \Name("BinaryCodeHash").Device(DEVICE_CPU).TypeConstraint<T>("T"), \BinaryCodeHashOp<CPUDevice, T>);
REGISTER_CPU(int64);
REGISTER_CPU(int32);
// Register the GPU kernels.
#ifdef GOOGLE_CUDA
#define REGISTER_GPU(T)                                          \extern template struct BinaryCodeHashFunctor<GPUDevice, T>;           \REGISTER_KERNEL_BUILDER(                                       \Name("BinaryCodeHash").Device(DEVICE_GPU).TypeConstraint<T>("T"), \BinaryCodeHashOp<GPUDevice, T>);
REGISTER_GPU(int32);
REGISTER_GPU(int64);

Step2.3 cuda运算内核

对应文件:tensorflow_binary_code_hash/cc/kernels/binary_code_hash_kernels.cu.cc

这里需要包括两个东西:

  1. CUDA计算内核
  2. BinaryCodeHashFunctor仿函数的具体实现

2.3.1 CUDA计算内核

这是属于CUDA的核函数,带有声明符号__global__。与前面CPU内核中的计算仿函数类似,输入张量的数据通过指针变量in来访问,然后将计算结果写入到输出张量对应的指针变量out。但不同的是输入张量的访问涉及到CUDA中的grid、block和线程的关系,下面的代码则是简单地实现了所有数据的遍历。

// Define the CUDA kernel.
// Cann't use c++ std.
template <typename T>
__global__ void BinaryCodeHashCudaKernel(const int size, const T* in, T* out, int length, int t, bool succession) {for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;i += blockDim.x * gridDim.x) {// 实现自己的计算逻辑// out[i] = 2 * ldg(in + i);
}

Blocks, Grids, and Threads

2.3.2 CUDA内核仿函数

在这里定义了CUDA计算内核的启动,其实跟上述的CPU内核实现类,即tensorflow_binary_code_hash/cc/kernels/binary_code_hash_kernels.cc中的Compute重载函数。只是不同的是这里不需要获取输入和参数,因为CUDA是直接由CPU内存拷贝过去。

// Define the GPU implementation that launches the CUDA kernel.
template <typename T>
struct BinaryCodeHashFunctor<GPUDevice, T> {void operator()(const GPUDevice& d, int size, const T* in, T* out, int length, int t, bool succession) {// std::cout << "@@@@@@ Runnin CUDA @@@@@@" << std::endl;// Launch the cuda kernel.//// See core/util/cuda_kernel_helper.h for example of computing// block count and thread_per_block count.int block_count = 1024;int thread_per_block = 20;BinaryCodeHashCudaKernel<T><<<block_count, thread_per_block, 0, d.stream()>>>(size, in, out, length, t, succession);}
};

Step3. 编译

对应文件:Makefile

CXX := g++# 待编译的算子源码文件
BINARY_CODE_HASH_SRCS = tensorflow_binary_code_hash/cc/kernels/binary_code_hash_kernels.cc $(wildcard tensorflow_binary_code_hash/cc/kernels/*.h) $(wildcard tensorflow_binary_code_hash/cc/ops/*.cc)# 获取tensorflow的c++源码位置
TF_CFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')
TF_LFLAGS := $(shell $(PYTHON_BIN_PATH) -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_link_flags()))')# 对于新版本的tensorflow, 需要使用新标准, 比如tensorflow2.10则需指定-std=c++17
CFLAGS = ${TF_CFLAGS} -fPIC -O2 -std=c++11
LDFLAGS = -shared ${TF_LFLAGS}# 编译目标so文件位置
BINARY_CODE_HASH_GPU_ONLY_TARGET_LIB = tensorflow_binary_code_hash/python/ops/_binary_code_hash_ops.cu.o
BINARY_CODE_HASH_TARGET_LIB = tensorflow_binary_code_hash/python/ops/_binary_code_hash_ops.so# 编译命令: binary_code_hash op
binary_code_hash_op: $(BINARY_CODE_HASH_TARGET_LIB)
$(BINARY_CODE_HASH_TARGET_LIB): $(BINARY_CODE_HASH_SRCS) $(BINARY_CODE_HASH_GPU_ONLY_TARGET_LIB)$(CXX) $(CFLAGS) -o $@ $^ ${LDFLAGS}  -D GOOGLE_CUDA=1  -I/usr/local/cuda/targets/x86_64-linux/include -L/usr/local/cuda/targets/x86_64-linux/lib -lcudart

执行 make binary_code_hash_op 对算子源文件进行编译,就可以得到相关的so文件, tensorflow_binary_code_hash/python/ops/_binary_code_hash_ops.sotensorflow_binary_code_hash/python/ops/_binary_code_hash_ops.cu.o

Python调用

对应文件:tensorflow_binary_code_hash/python/ops/binary_code_hash_ops.pytensorflow_binary_code_hash/python/ops/binary_code_hash_test.py

经过上一步编译生成了算子的so文件之后,我们就可以在Python中引入自定义的算子函数进行使用。

在这两个Python文件中,包括了算子的调用和算子执行的测试单元。其中最为关键的算子导入代码如下:

from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loaderbinary_code_hash_ops = load_library.load_op_library(resource_loader.get_path_to_datafile('_binary_code_hash_ops.so'))
binary_code_hash = binary_code_hash_ops.binary_code_hash

可以直接使用make执行测试脚本:make binary_code_hash_test。也可以选择进入目录,手动执行Python脚本。

CPU版本

对于没有GPU资源的小伙伴,也提供了纯CPU版本的算子实现。

  • 定义运算接口与GPU版本通用:tensorflow_binary_code_hash/cc/ops/binary_code_hash_ops.cc
  • 实现运算内核则对应文件:tensorflow_binary_code_hash/cc/kernels/binary_code_hash_only_cpu_kernels.cc
  • 其编译命令也包含在Makefile文件中,对应执行:make binary_code_hash_cpu_only
  • 最终生成的so文件则是:tensorflow_binary_code_hash/python/ops/_binary_code_hash_cpu_ops.so

完整代码

仅包含tensorflow自定义算子的独立仓库

自定义算子(含其他文章的代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/231697.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode—11.盛最多水的容器【中等】

2023每日刷题&#xff08;六十三&#xff09; Leetcode—11.盛最多水的容器 实现代码 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) int maxArea(int* height, int heightSize) {int left 0, right heightSize - 1;int m…

知识蒸馏:channel wise知识蒸馏CWD

论文:https://arxiv.org/pdf/2011.13256.pdf 1. 摘要 知识蒸馏用于训练紧凑型(轻量)模型被证明是一种简单、高效的方法, 轻量的学生网络通过教师网络的知识迁移来实现监督学习。大部分的KD方法都是通过algin学生网络和教师网络的归一化的feature map, 最小化feature map上的…

数据分析思维导图

参考&#xff1a; https://zhuanlan.zhihu.com/p/567761684?utm_id0 1、数据分析步骤地图 2、数据分析基础知识地图 3、数据分析技术知识地图 4、数据分析业务流程 5、数据分析师能力体系 6、数据分析思路体系 7、电商数据分析核心主题 8、数据科学技能书知识地图 9、数据挖掘…

MATLAB 点云SVD分解计算平面法向量 (41)

MATLAB 点云SVD分解计算平面法向量 (41) 一、算法介绍二、算法实现一、算法介绍 算法主要是采用SVD分解矩阵的方法,计算平面的法向量。 二、算法实现 % 加载点云数据 ptCloud = pcread(D:\shuju\近似平面点集合2.pcd);% 计算点云质心 centroid = mean(ptCloud.</

TensorFlow 2 和 Keras 之间的区别总结

1、什么是TensorFlow 2 TensorFlow 2是谷歌开源的一款深度学习框架&#xff0c;于2019年发布&#xff0c;并且在同年10月1日发布了TensorFlow 2.0.0正式稳定版。这款框架被很多企业与创业公司广泛用于自动化工作任务和开发新系统。 TensorFlow 2在分布式训练支持、可扩展的生…

python使用ctypes访问Windows原生API

在Windows系统中&#xff0c;C语言编写的动态链接库&#xff08;DLL&#xff09;是一种可由多个程序同时使用的代码和数据共享库。DLL文件包含了一些可以被其他程序调用的函数和数据。这些DLL文件通常与应用程序一起发布&#xff0c;并在需要时被加载到内存中&#xff0c;以便应…

【玩转 TableAgent 数据智能分析】股票交易数据分析+预测

文章目录 一、什么是TableAgent二、TableAgent 的特点三、实践前言四、实践准备4.1 打开官网4.2 注册账号4.3 界面介绍4.4 数据准备 五、确认分析需求六、TableAgent体验七、分析结果解读八、总结&展望 一、什么是TableAgent TableAgent是一款面向企业用户的智能数据分析工…

HTML中边框样式、内外边距、盒子模型尺寸计算(附代码图文示例)【详解】

Hi i,m JinXiang ⭐ 前言 ⭐ 本篇文章主要介绍HTML中边框样式、内外边距、盒子模型尺寸计算以及部分理论知识 &#x1f349;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f4dd;私信必回哟&#x1f601; &#x1f349;博主收将持续更新学习记录获&#xff0c;友友们有任何问…

用23种设计模式打造一个cocos creator的游戏框架----(二十)解析器模式

1、模式标准 模式名称&#xff1a;解析器模式 模式分类&#xff1a;行为型 模式意图&#xff1a;给定一个语言&#xff0c;定义它的文法的一种表示&#xff0c;并定义一个解释器&#xff0c;这个解释器使用该表示来解释语言中的句子。 结构图&#xff1a; 适用于&#xff1…

K8S(十一)—Service详解

目录 Service发布服务&#xff08;服务类型&#xff09;type: ClusterIP选择自己的 IP 地址例子 type: NodePort选择你自己的端口为 type: NodePort 服务自定义 IP 地址配置例子 type: LoadBalancer混合协议类型的负载均衡器禁用负载均衡器节点端口分配设置负载均衡器实现的类别…

java之HikariCP连接池介绍和使用方法 简单易懂!!!

文章目录 一、HikariCP连接池介绍二、导入的jar包三、代码演示配置文件使用配置文件连接运行结果 一、HikariCP连接池介绍 在我们的工作中&#xff0c;免不了要和数据库打交道&#xff0c;而要想和数据库打好交道&#xff0c;选择一款合适的数据库连接池就至关重要&#xff0c…

软件试运行整体方案

一、 试运行目的 &#xff08;一&#xff09; 系统功能、性能与稳定性考核 &#xff08;二&#xff09; 系统在各种环境和工况条件下的工作稳定性和可靠性 &#xff08;三&#xff09; 检验系统实际应用效果和应用功能的完善 &#xff08;四&#xff09; 健全系统运行管理体…

网神防火墙后台用户敏感信息泄露漏洞复现

简介 网神防火墙是一款由中国知名网络安全公司启明星辰开发的防火墙产品。它提供了全面的网络安全防护功能,旨在保护企业网络免受各种网络威胁和攻击。 该产品存在用户账号信息泄露漏洞,通过构造特定数据包,获取防火墙管理员登录的账号密码。 漏洞复现 FOFA语法: body=&…

6TIM定时器

STM32的定时器功能众多&#xff0c;拥有基本定时功能&#xff0c;输出比较功能&#xff08;如产生PWM波等&#xff09;&#xff0c;输入捕获&#xff08;测量方波信号&#xff09;&#xff0c;读取正交编码器的波形。 1.中断原理 TIM定时器的基本功能是对输入的时钟进行计数&…

vue使用xlsx和xlsx-style导出xlsx文件并修改样式

1.下载依赖 npm install xlsx --save npm install file-saver --save npm install xlsx-style --save2.先修改xlsx-style的源码&#xff0c;一旦引入xlsx-style则会报错 在\node_modules\xlsx-style\dist\cpexcel.js 807行 的 var cpt require(’./cpt’ ‘able’); 改成 v…

Python如何画函数图像

1 问题 通过图像可以直观地学习函数变化&#xff0c;在学习函数等方面效果显著。下面我们尝试用Python的2D绘图库matplotlib来绘制函数图像。实现 yx*x 图象。 2 方法 用文字描述解题思路&#xff0c;可配合一些图形以便更好的阐述。解决问题的步骤采用如下方式&#xff1a; …

100GPTS计划-AI写诗PoetofAges

地址 https://chat.openai.com/g/g-Cd5daC0s5-poet-of-ages https://poe.com/PoetofAges 测试 创作一首春天诗歌 创作一首夏天诗歌 创作一首秋天诗歌 创作一首冬天诗歌 微调 诗歌风格 语气&#xff1a;古典 知识库

掌握 Babel:让你的 JavaScript 与时俱进(下)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

蓝桥杯嵌入式——LED

原理&#xff1a;PD2为使能端&#xff0c;高电平使能。使能的时候&#xff0c;给PC8-PC15高电平即可点亮LED。 CUBE里将这些端口设置为GPIO输出模式&#xff0c;将PC8-15初始电平设置为高电平(这样一上电就不会亮),PD2默认低电平&#xff0c;不用管&#xff0c;然后生成代码即可…

算法通关第十九关-青铜挑战理解动态规划

大家好我是苏麟 , 今天聊聊动态规划 . 动态规划是最热门、最重要的算法思想之一&#xff0c;在面试中大量出现&#xff0c;而且题目整体都偏难一些对于大部人来说&#xff0c;最大的问题是不知道动态规划到底是怎么回事。很多人看教程等&#xff0c;都被里面的状态子问题、状态…