TVM:使用 Auto-scheduling 来优化算子

TVM:使用 Auto-scheduling 来优化算子

在本教程中,我们将展示 TVM 的 Auto-scheduling 功能如何在无需编写自定义模板的情况下找到最佳 schedule。

与基于模板的 AutoTVM 依赖手动模板定义搜索空间不同,auto-scheduler 不需要任何模板。 用户只需编写计算声明,无需任何调度命令或模板。 auto-scheduler 可以自动生成一个大的搜索空间,并在该空间中找到一个好的 schedule。

我们在本教程中同样使用矩阵乘法作为示例。

import osimport numpy as np
import tvm
from tvm import te, auto_scheduler

定义矩阵乘法

首先,我们定义一个带有偏置的矩阵乘法。 请注意,这使用了 TVM 张量表达式语言中可用的标准操作。 主要区别在于在函数定义的开始使用了 auto_sceduler 装饰器。 该函数应返回输入/输出张量列表。 从这些张量中,自动调度器可以获得整个计算图。

@auto_scheduler.register_workload  # Note the auto_scheduler decorator
def matmul_add(N, L, M, dtype):A = te.placeholder((N, L), name="A", dtype=dtype)B = te.placeholder((L, M), name="B", dtype=dtype)C = te.placeholder((N, M), name="C", dtype=dtype)k = te.reduce_axis((0, L), name="k")matmul = te.compute((N, M),lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),name="matmul",attrs={"layout_free_placeholders": [B]},  # enable automatic layout transform for tensor B)out = te.compute((N, M), lambda i, j: matmul[i, j] + C[i, j], name="out")return [A, B, C, out]

创建搜索任务

定义函数后,我们现在可以创建供 auto_scheduler 搜索的任务。 我们指定此矩阵乘法的特定参数,在本例中为 1024x1024 大小的方阵的乘法。 然后我们创建一个搜索任务,其中 N=L=M=1024 ,数据类型为 ”float32”。

target = tvm.target.Target("llvm")
N = L = M = 1024
task = tvm.auto_scheduler.SearchTask(func=matmul_add, args=(N, L, M, "float32"), target=target)# Inspect the computational graph
print("Computational DAG:")
print(task.compute_dag)

注意:自定义 target 可以提高性能

为了让 TVM 充分利用特定硬件平台,您需要手动指定 CPU 功能。 例如: - 将下面的“llvm”替换为“llvm -mcpu=core-avx2”以启用 AVX2 - 将下面的“llvm”替换为“llvm -mcpu=skylake-avx512”以启用 AVX-512

此处输出:

Computational DAG:
A = PLACEHOLDER [1024, 1024]
B = PLACEHOLDER [1024, 1024]
matmul(i, j) += (A[i, k]*B[k, j])
C = PLACEHOLDER [1024, 1024]
out(i, j) = (matmul[i, j] + C[i, j])

为 Auto-Scheduler 设置参数

接下来,我们为自动调度程序设置参数。

  • num_measure_trials 是我们在搜索过程中可以使用的测量试验次数。 为了快速演示,我们在本教程中仅进行了 10 次试验。 在实践中,1000 是一个很好的搜索收敛值。 您可以根据您的时间预算进行更多试验。

  • 此外,我们使用 RecordToFile 将测量记录记录到文件 matmul.json 中。 测量记录可用于最佳查询历史记录、恢复搜索以及稍后进行更多分析。

  • 有关更多参数,请参阅 auto_scheduler.TuningOptions

log_file = "matmul.json"
tune_option = auto_scheduler.TuningOptions(num_measure_trials=10,measure_callbacks=[auto_scheduler.RecordToFile(log_file)],verbose=2,
)

运行搜索

现在我们准备好所有输入。 很简单,不是吗? 我们可以开始搜索并让自动调度程序发挥它的魔力。 经过一些测量试验后,我们可以从日志文件中加载最佳计划并应用它。

# Run auto-tuning (search)
task.tune(tune_option)
# Apply the best schedule
sch, args = task.apply_best(log_file)

检查优化过的 Schedule

我们可以在 auto-scheduling 后降低(lower)schedule 以查看 IR。 auto-schduling 程序正确执行优化,包括多级平铺、布局转换、并行化、矢量化、展开和算子融合。

print("Lowered TIR:")
print(tvm.lower(sch, args, simple_mode=True))

此处输出:

Lowered TIR:
primfn(A_1: handle, B_1: handle, C_1: handle, out_1: handle) -> ()attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}buffers = {out: Buffer(out_2: Pointer(float32), float32, [1024, 1024], []),A: Buffer(A_2: Pointer(float32), float32, [1024, 1024], []),C: Buffer(C_2: Pointer(float32), float32, [1024, 1024], []),B: Buffer(B_2: Pointer(float32), float32, [1024, 1024], [])}buffer_map = {A_1: A, B_1: B, C_1: C, out_1: out} {allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [1048576]), storage_scope = global {for (ax0.ax1.fused.ax2.fused: int32, 0, 128) "parallel" {for (ax4: int32, 0, 256) {for (ax6: int32, 0, 4) {for (ax7: int32, 0, 8) {auto_scheduler_layout_transform[((((ax0.ax1.fused.ax2.fused*8192) + (ax4*32)) + (ax6*8)) + ax7)] = (float32*)B_2[((((ax4*4096) + (ax6*1024)) + (ax0.ax1.fused.ax2.fused*8)) + ax7)]}}}}for (i.outer.outer.j.outer.outer.fused: int32, 0, 16384) "parallel" {allocate(matmul: Pointer(global float32x8), float32x8, [4]), storage_scope = global;for (i.outer.inner: int32, 0, 2) {matmul[ramp(0, 1, 8)] = broadcast(0f32, 8)matmul[ramp(8, 1, 8)] = broadcast(0f32, 8)matmul[ramp(16, 1, 8)] = broadcast(0f32, 8)matmul[ramp(24, 1, 8)] = broadcast(0f32, 8)for (k.outer: int32, 0, 256) {for (k.inner: int32, 0, 4) {matmul[ramp(0, 1, 8)] = ((float32x8*)matmul[ramp(0, 1, 8)] + (broadcast((float32*)A_2[((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))matmul[ramp(8, 1, 8)] = ((float32x8*)matmul[ramp(8, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 1024)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))matmul[ramp(16, 1, 8)] = ((float32x8*)matmul[ramp(16, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 2048)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))matmul[ramp(24, 1, 8)] = ((float32x8*)matmul[ramp(24, 1, 8)] + (broadcast((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (k.outer*4)) + k.inner) + 3072)], 8)*(float32x8*)auto_scheduler_layout_transform[ramp((((floormod(i.outer.outer.j.outer.outer.fused, 128)*8192) + (k.outer*32)) + (k.inner*8)), 1, 8)]))}}for (i.inner: int32, 0, 4) {out_2[ramp(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8)), 1, 8)] = ((float32x8*)matmul[ramp((i.inner*8), 1, 8)] + (float32x8*)C_2[ramp(((((floordiv(i.outer.outer.j.outer.outer.fused, 128)*8192) + (i.outer.inner*4096)) + (i.inner*1024)) + (floormod(i.outer.outer.j.outer.outer.fused, 128)*8)), 1, 8)])}}}}
}

检查正确性并评估性能

我们构建二进制文件并检查其正确性和性能。

func = tvm.build(sch, args, target)
a_np = np.random.uniform(size=(N, L)).astype(np.float32)
b_np = np.random.uniform(size=(L, M)).astype(np.float32)
c_np = np.random.uniform(size=(N, M)).astype(np.float32)
out_np = a_np.dot(b_np) + c_npdev = tvm.cpu()
a_tvm = tvm.nd.array(a_np, device=dev)
b_tvm = tvm.nd.array(b_np, device=dev)
c_tvm = tvm.nd.array(c_np, device=dev)
out_tvm = tvm.nd.empty(out_np.shape, device=dev)
func(a_tvm, b_tvm, c_tvm, out_tvm)# Check results
np.testing.assert_allclose(out_np, out_tvm.numpy(), rtol=1e-3)# Evaluate execution time.
evaluator = func.time_evaluator(func.entry_name, dev, min_repeat_ms=500)
print("Execution time of this operator: %.3f ms"% (np.median(evaluator(a_tvm, b_tvm, c_tvm, out_tvm).results) * 1000)
)

此处输出:

Execution time of this operator: 45.418 ms

使用记录文件

在搜索过程中,所有的测量记录都被记录到记录文件“matmul.json”中。 测量记录可用于重新应用搜索结果、恢复搜索和执行其他分析。

这是一个示例,我们从文件加载最佳 schedule,并打印等效的 Python schedule API。 这可用于调试和学习 auto-scheduling 程序的行为。

print("Equivalent python schedule:")
print(task.print_best(log_file))

此处输出:

Equivalent python schedule:
matmul_i, matmul_j, matmul_k = tuple(matmul.op.axis) + tuple(matmul.op.reduce_axis)
out_i, out_j = tuple(out.op.axis) + tuple(out.op.reduce_axis)
matmul_i_o_i, matmul_i_i = s[matmul].split(matmul_i, factor=4)
matmul_i_o_o_i, matmul_i_o_i = s[matmul].split(matmul_i_o_i, factor=1)
matmul_i_o_o_o, matmul_i_o_o_i = s[matmul].split(matmul_i_o_o_i, factor=2)
matmul_j_o_i, matmul_j_i = s[matmul].split(matmul_j, factor=8)
matmul_j_o_o_i, matmul_j_o_i = s[matmul].split(matmul_j_o_i, factor=1)
matmul_j_o_o_o, matmul_j_o_o_i = s[matmul].split(matmul_j_o_o_i, factor=1)
matmul_k_o, matmul_k_i = s[matmul].split(matmul_k, factor=4)
s[matmul].reorder(matmul_i_o_o_o, matmul_j_o_o_o, matmul_i_o_o_i, matmul_j_o_o_i, matmul_k_o, matmul_i_o_i, matmul_j_o_i, matmul_k_i, matmul_i_i, matmul_j_i)
out_i_o_i, out_i_i = s[out].split(out_i, factor=4)
out_i_o_o, out_i_o_i = s[out].split(out_i_o_i, factor=2)
out_j_o_i, out_j_i = s[out].split(out_j, factor=8)
out_j_o_o, out_j_o_i = s[out].split(out_j_o_i, factor=1)
s[out].reorder(out_i_o_o, out_j_o_o, out_i_o_i, out_j_o_i, out_i_i, out_j_i)
s[matmul].compute_at(s[out], out_j_o_i)
out_i_o_o_j_o_o_fused = s[out].fuse(out_i_o_o, out_j_o_o)
s[out].parallel(out_i_o_o_j_o_o_fused)
s[matmul].pragma(matmul_i_o_o_o, "auto_unroll_max_step", 8)
s[matmul].pragma(matmul_i_o_o_o, "unroll_explicit", True)
s[matmul].vectorize(matmul_j_i)
s[out].vectorize(out_j_i)

一个更复杂的例子是恢复搜索。 在这种情况下,我们需要自己创建搜索策略和成本模型,并通过日志文件恢复搜索策略和成本模型的状态。 在下面的示例中,我们恢复状态并再进行 5 次试验。

def resume_search(task, log_file):print("Resume search:")cost_model = auto_scheduler.XGBModel()cost_model.update_from_file(log_file)search_policy = auto_scheduler.SketchPolicy(task, cost_model, init_search_callbacks=[auto_scheduler.PreloadMeasuredStates(log_file)])tune_option = auto_scheduler.TuningOptions(num_measure_trials=5, measure_callbacks=[auto_scheduler.RecordToFile(log_file)])task.tune(tune_option, search_policy=search_policy)resume_search(task, log_file)

此处输出:

Resume search:
/usr/local/lib/python3.6/dist-packages/xgboost/training.py:17: UserWarning: Old style callback is deprecated.  See: https://xgboost.readthedocs.io/en/latest/python/callbacks.htmlwarnings.warn(f'Old style callback is deprecated.  See: {link}', UserWarning)

总结

在本教程中,我们展示了如何使用 TVM Auto-Scheduler 自动优化矩阵乘法,而无需指定搜索模板。 它结束了一系列从张量表达式 (TE) 语言开始的示例,这些示例演示了 TVM 如何优化计算操作。

Ref:

https://tvm.apache.org/docs/tutorial/auto_scheduler_matmul_x86.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532663.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言—sort函数比较大小的快捷使用--algorithm头文件下

sort函数 一般情况下要将一组数从的大到小排序或从小到大排序&#xff0c;要定义一个新的函数排序。 而我们也可以直接使用在函数下的sort函数&#xff0c;只需加上头文件&#xff1a; #include<algorithm> using namespace std;sort格式&#xff1a;sort(首元素地址&…

散列的使用

散列 散列简单来说&#xff1a;给N个正整数和M个负整数&#xff0c;问这M个数中的每个数是否在N中出现过。 比如&#xff1a;N&#xff1a;{1,2,3,4}&#xff0c;M{2,5,7}&#xff0c;其中M的2在N中出现过 对这个问题最直观的思路是&#xff1a;对M中每个欲查的值x&#xff0…

关于C++中的unordered_map和unordered_set不能直接以pair作为键名的问题

关于C中的unordered_map和unordered_set不能直接以pair作为键名的问题 在 C STL 中&#xff0c;不同于有序的 std::map 和 std::set 是基于红黑树实现的&#xff0c;std::unordered_map 和 std::unordered_set 是基于哈希实现的&#xff0c;在不要求容器内的键有序&#xff0c…

AI编译器与传统编译器的联系与区别

AI编译器与传统编译器的区别与联系 总结整理自知乎问题 针对神经网络的编译器和传统编译器的区别和联系是什么&#xff1f;。 文中提到的答主的知乎主页&#xff1a;金雪锋、杨军、蓝色、SunnyCase、贝壳与知了、工藤福尔摩 笔者本人理解 为了不用直接手写机器码&#xff0…

python学习1:注释\变量类型\转换函数\转义字符\运算符

python基础学习 与大多数语言不同&#xff0c;python最具特色的就是使用缩进来表示代码块&#xff0c;不需要使用大括号 {} 。缩进的空格数是可变的&#xff0c;但是同一个代码块的语句必须包含相同的缩进空格数。 &#xff08;一个tab4个空格&#xff09; Python语言中常见的…

Python、C++ lambda 表达式

Python、C lambda 表达式 lambda函数简介 匿名函数lambda&#xff1a;是指一类无需定义标识符&#xff08;函数名&#xff09;的函数或子程序。所谓匿名函数&#xff0c;通俗地说就是没有名字的函数&#xff0c;lambda函数没有名字&#xff0c;是一种简单的、在同一行中定义函…

python 学习2 /输入/ 输出 /列表 /字典

python基础学习第二天 输入输出 xinput("输入内容") print(x)input输出&#xff1a; eval :去掉字符串外围的引号&#xff0c;按照python的语法执行内容 aeval(12) print(a)eval输出样式&#xff1a; 列表 建立&#xff0c;添加&#xff0c;插入&#xff0c;删去…

Linux、Mac 命令行快捷键

Linux、Mac 命令行快捷键 Linux 命令行编辑快捷键&#xff0c;参考了好多个&#xff0c;应该算是比较全的了&#xff0c;Linux 和 Mac 的都有&#xff0c;笔者本人比较常用的也已经红色标出来了&#xff0c;如有错误或遗漏&#xff0c;欢迎留言指出。 光标移动及编辑&#xff…

Python 命令行传参

Python 命令行传参 说到 python 命令行传参&#xff0c;可能大部分人的第一反应就是用 argparse。的确&#xff0c;argparse 在我们需要指定多个预设的参数&#xff08;如深度学习中指定模型的超参数等&#xff09;时&#xff0c;是非常有用的。但是如果有时我们只需要一个参数…

快速排序 C++

快速排序 C 本文图示借鉴自清华大学邓俊辉老师数据结构课程。 快速排序的思想 快速排序是分治思想的典型应用。该排序算法可以原地实现&#xff0c;即空间复杂度为 O(1)O(1)O(1)&#xff0c;而时间复杂度为 O(nlogn)O(nlogn)O(nlogn) 。 算法将待排序的序列 SSS 分为两个子…

Linux命令行下感叹号的几个用法

Linux命令行下 " ! " 的几个用法 ! 在大多数编程语言中表示取反的意思&#xff0c;但是在命令行中&#xff0c;他还有一些其他的神奇用法。熟练掌握这些用法&#xff0c;可以大大提高我们日常命令行操作的效率。 1 执行历史命令 !! ! 在命令行中可以用来执行历史…

三地址码简介

三地址码简介 三地址码&#xff08;Three Address Code&#xff09;是一种最常用的中间语言&#xff0c;编译器可以通过它来改进代码转换效率。每个三地址码指令&#xff0c;都可以被分解为一个四元组&#xff08;4-tuple&#xff09;的形式&#xff1a;&#xff08;运算符&am…

llvm与gcc

llvm与gcc llvm 是一个编译器&#xff0c;也是一个编译器架构&#xff0c;是一系列编译工具&#xff0c;也是一个编译器工具链&#xff0c;开源 C11 实现。 gcc 相对于 clang 的优势&#xff1a; gcc 支持更过语言前端&#xff0c;如 Java, Ada, FORTRAN, Go等gcc 支持更多地 …

攻防世界web新手区解题 view_source / robots / backup

1**. view_source** 题目描述&#xff1a;X老师让小宁同学查看一个网页的源代码&#xff0c;但小宁同学发现鼠标右键好像不管用了。 f12查看源码即可发现flag 2. robots 题目描述&#xff1a;X老师上课讲了Robots协议&#xff0c;小宁同学却上课打了瞌睡&#xff0c;赶紧来教教…

python参数传递*args和**kwargs

python参数传递*args和**kwargs 和* 实际上真正的Python参数传递语法是 * 和 ** 。*args 和 **kwargs 只是一种约定俗成的编程实践。我们也可以写成 *vars 和 **kvars 。就如同其他常规变量的命名一样&#xff0c; args 和 kwargs 只是一种习惯的名称。 *args 和 **kwargs 一…

听GPT 讲Rust源代码--src/tools(25)

File: rust/src/tools/clippy/clippy_lints/src/methods/suspicious_command_arg_space.rs 在Rust源代码中&#xff0c;suspicious_command_arg_space.rs文件位于clippy_lints工具包的methods目录下&#xff0c;用于实现Clippy lint SUSPICIOUS_COMMAND_ARG_SPACE。 Clippy是Ru…

Java一次编译,到处运行是如何实现的

Java一次编译&#xff0c;到处运行是如何实现的 转自&#xff1a;https://cloud.tencent.com/developer/article/1415194 &#xff08;排版微调&#xff09; JAVA编译运行总览 Java是一种高级语言&#xff0c;要让计算机执行你撰写的Java程序&#xff0c;也得通过编译程序的…

JIT(动态编译)和AOT(静态编译)编译技术比较

JIT&#xff08;动态编译&#xff09;和AOT&#xff08;静态编译&#xff09;编译技术比较 转自&#xff1a;https://www.cnblogs.com/tinytiny/p/3200448.html Java 应用程序的性能经常成为开发社区中的讨论热点。因为该语言的设计初衷是使用解释的方式支持应用程序的可移植…

python解释器

python解释器 计算机编程语言 本部分参考自&#xff1a;https://zhuanlan.zhihu.com/p/141212114 从计算机编程语言说起&#xff0c;它主要分为三类&#xff1a;机器语言、汇编语言、高级语言。 机器语言是一种计算机可以直接识别并执行的二进制指令集。由于其可以直接交给…

编译型语言与解释型语言

编译型语言与解释型语言 首先要说明&#xff0c;编译型语言与解释型语言这种分类方法是不科学的&#xff0c;或者说已经过时了&#xff0c;但是这种称呼大抵还是能够让人明白我们将要讨论的是什么东西。 文中所列参考是笔者认为比较有帮助的一些扩展阅读内容。 首先贴一个很形…