pytorch为自己的extension backend添加profiler功能
- 1.参考文档
- 2.your-extension-for-pytorch需要增加的代码
- 3.pytorch demo及如何调整chrome trace json文件
- 4.[可视化](https://ui.perfetto.dev/)
本文演示了pytorch如何为自己的extension backend添加profiler功能
背景介绍
- 1.没有CNLight、Profiling AscendCL API、ROC Tracer之类的Profiling功能,无法trace runtime、driver、kernel,也无法获取设备的metrics
- 2.只有event功能,可以统计kernel耗时
- 3.本文只是一种尝试,并不合理.
- 4.torch原生的profiler框架,依赖kineto,kineto目前支持CUPTI和ROC Tracer,如果不修改torch源码,第三方设备不方便使用
- 5.华为、寒武纪、habana都是采用torch.profile的接口形式及at::addThreadLocalCallback功能,但不依赖torch.profiler框架
profiling原始数据都是私有格式,并且修改了TensorBoard的插件,用于可视化
实施步骤
- 1.调用torch::profiler::impl::registerPrivateUse1Methods注册
- 2.因为没有correlation ID去关联host api与kernel,因此export_chrome_trace出来的数据没有kernel信息
- 3.获取prof.profiler.function_events里的数据,通过{ev.name}_{ev.id}_{ev.thread}拼成uuid与上面chrome trace中的events关联
- 4.因为只有一个stream,可以根据Host launch时间、kernel耗时、launch latency(先验),推断出kernel的开始、结束时间,并用flow event进行关联(虽然并不准确)
- 5.最后把kernel event以及flow event追加到chrome trace中
1.参考文档
- ROC Tracer
- CUPTI
- 华为profiler_npu
- Profiling AscendCL API
- 寒武纪profile_mlu
- 寒武纪CNLight
- habana torch
- intel_extension_for_pytorch
- Make the kineto extendable for other runtime than CUDA
- pytorch_open_registration_example
- rename_privateuse1_backend
- Trace Event Format
2.your-extension-for-pytorch需要增加的代码
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>
#include <c10/util/irange.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>using torch::profiler::impl::ProfilerStubs;
using torch::profiler::impl::ProfilerVoidEventStub;namespace torch {
namespace profiler {
namespace impl {struct NPUMethods : public ProfilerStubs {void record(int* device,ProfilerVoidEventStub* event,int64_t* cpu_ns) const override{if (device) {TORCH_CHECK(xpurtGetDevice((uint32_t*)device));}xpurtEvent_t xpurt_event;TORCH_CHECK(xpurtEventCreate(&xpurt_event));*event = std::shared_ptr<void>(xpurt_event, [](xpurtEvent_t ptr) {TORCH_CHECK(xpurtEventDestroy(ptr));});auto xpurt_stream = c10::xpu::getCurrentxpuStream(vastai::get_device());if (cpu_ns) {*cpu_ns = getTime();}TORCH_CHECK(xpurtEventRecord(xpurt_event, xpurt_stream)); } float elapsed(const ProfilerVoidEventStub* event1_,const ProfilerVoidEventStub* event2_) const override{auto event1 = static_cast<xpurtEvent_t>(event1_->get());TORCH_CHECK(xpurtEventSynchronize(event1));auto event2 = static_cast<xpurtEvent_t>(event2_->get());TORCH_CHECK(xpurtEventSynchronize(event2));int64_t time_ms = 0;TORCH_CHECK(xpurtEventElapsedTime(&time_ms, event1, event2));return time_ms*1.0;} void onEachDevice(std::function<void(int)> op) const override{uint32_t device = 0;TORCH_CHECK(xpurtGetDevice(&device));op(device);} void synchronize() const override { } bool enabled() const override {return true;} void mark(const char*name) const override { } void rangePush(const char*name) const override { } void rangePop() const override {}
};struct RegisterNPUMethods {RegisterNPUMethods(){static NPUMethods methods;torch::profiler::impl::registerPrivateUse1Methods(&methods);}
};
RegisterNPUMethods reg;
}}}
3.pytorch demo及如何调整chrome trace json文件
import time
import torchvision.models as models
from torch import nn
import torch.nn.functional as F
import copy
import math
import torch
from torch.profiler import profile
import json
try:
    import tqdm
except ImportError:  # the progress bar is cosmetic; keep working without it
    tqdm = None


def is_valid_kernel(name, duration, valid_kernel_threshold=100):
    """Decide from an op's name and duration whether it is a device kernel.

    Args:
        name: Operator name as it appears in the trace (e.g. ``aten::mm``).
        duration: Measured device time of the operator.
        valid_kernel_threshold: Durations below this value are rejected.

    Returns:
        ``False`` when ``name`` contains one of the known host-only operator
        names or ``duration`` is below the threshold, otherwise ``True``.
    """
    # NOTE(review): matching is by substring, so "aten::t" also rejects names
    # such as "aten::to" or "aten::tanh" -- confirm this is intended.
    invalid_kernels = (
        "aten::view",
        "aten::reshape",
        "aten::t",
        "aten::empty",
        "aten::transpose",
        "aten::as_strided",
        "aten::item",
        "aten::_local_scalar_dense",
        "aten::result_type",
        "aten::_unsafe_view",
        "aten::expand",
    )
    if any(k in name for k in invalid_kernels):
        return False
    if duration < valid_kernel_threshold:
        return False
    return True


def filter_ev(ev):
    """Return True when a trace event carries an ``External id`` argument."""
    return 'args' in ev and "External id" in ev['args']


def get_uuid(ev, tid_map):
    """Build the ``name_externalId_threadIndex`` key for a trace event."""
    return f"{ev['name']}_{ev['args']['External id']}_{tid_map[ev['tid']]}"


def get_valid_kernels(traceEvents, kernel_event, tid_map):
    """Collect the trace events that have a usable device-time record.

    Args:
        traceEvents: List of chrome-trace event dicts.
        kernel_event: Mapping from uuid (see ``get_uuid``) to a dict with at
            least ``kernel_time`` and ``device_memory_usage``.
        tid_map: Mapping from trace ``tid`` to thread index.

    Returns:
        A list of kernel descriptors sorted by ``launch_beg``.
    """
    valid_kernels = []
    # The last positive memory reading is carried over to later kernels.
    device_memory_usage = 0
    for ev in traceEvents:
        if not filter_ev(ev):
            continue
        uuid = get_uuid(ev, tid_map)
        if uuid not in kernel_event:
            continue
        record = kernel_event[uuid]
        duration = record['kernel_time']
        kernel_name = ev['name']
        if record['device_memory_usage'] > 0:
            device_memory_usage = record['device_memory_usage']
        if is_valid_kernel(kernel_name, duration):
            valid_kernels.append({
                "name": kernel_name,
                "launch_beg": ev['ts'],
                "launch_end": ev['ts'] + ev['dur'],
                "kernel_duration": duration,
                "host_pid": ev['pid'],
                "host_tid": ev['tid'],
                "device_memory_usage": device_memory_usage,
                "is_leaf_kernel": False,
            })
    return sorted(valid_kernels, key=lambda x: x['launch_beg'])


def is_leaf_kernel(kernel, valid_kernels):
    """Return True when no other kernel lies strictly inside ``kernel``'s span.

    Entries of ``valid_kernels`` whose ``is_leaf_kernel`` flag is truthy are
    ignored by the scan.
    """
    for k in valid_kernels:
        if k['is_leaf_kernel']:
            continue
        # Another kernel is launched within this kernel's own time span.
        if (k['launch_beg'] > kernel['launch_beg']
                and k['launch_end'] < kernel['launch_end']):
            return False
    return True


def create_tid_map(traceEvents):
    """Map each ``tid`` seen on filtered events to a 1-based index.

    Indices follow the ascending order of the tid values.
    """
    tids = {ev['tid'] for ev in traceEvents if filter_ev(ev)}
    return {tid: index for index, tid in enumerate(sorted(tids), start=1)}


def merge_prof_timeline(prof_json, kernel_event_json, output_json,
                        kernel_launch_latency=0):
    """Append estimated device-kernel events to a chrome trace file.

    Kernels are laid out back to back on one device track: each starts at its
    host launch time or when the previous kernel ends, whichever is later.
    For every kernel a duration event, a memory counter event and a pair of
    flow events (host launch -> device kernel) are appended.

    Args:
        prof_json: Path of the chrome trace exported by the profiler.
        kernel_event_json: Path of the uuid -> kernel record JSON file.
        output_json: Path the merged trace is written to.
        kernel_launch_latency: Gap inserted after each kernel, in the same
            unit as the trace timestamps. Defaults to 0.
    """
    with open(prof_json, 'r', encoding='utf-8') as f:
        prof = json.load(f)
    with open(kernel_event_json, 'r', encoding='utf-8') as f:
        kernel_event = json.load(f)

    traceEvents = prof['traceEvents']
    tid_map = create_tid_map(traceEvents)
    print(tid_map)

    # Collect every candidate kernel.
    valid_kernels = get_valid_kernels(traceEvents, kernel_event, tid_map)
    print(len(valid_kernels))

    # Keep only the kernels that actually run on the device.
    progress = tqdm.tqdm(valid_kernels) if tqdm is not None else valid_kernels
    on_device_kernels = [
        kernel for kernel in progress if is_leaf_kernel(kernel, valid_kernels)
    ]

    kernel_start_offset = 0
    for kernel_index, kernel in enumerate(on_device_kernels):
        name = kernel['name']
        kernel_duration = kernel["kernel_duration"]
        launch_time = kernel["launch_beg"]
        host_pid = kernel['host_pid']
        host_tid = kernel['host_tid']
        device_memory_usage = kernel['device_memory_usage']

        # Start at the launch time for the first kernel, or whenever the
        # kernel queue is idle at launch.
        if kernel_start_offset == 0 or launch_time > kernel_start_offset:
            kernel_start_offset = launch_time

        # Kernel event.
        traceEvents.append({"ph": "X", "cat": "device_kernel", "name": name,
                            "pid": 10, "tid": 10,
                            "ts": kernel_start_offset, "dur": kernel_duration})
        # Memory counter event.
        traceEvents.append({"ph": "C", "cat": "memory", "name": "memory",
                            "pid": 11, "tid": 11, "ts": launch_time,
                            "args": {"value": device_memory_usage}})
        # Flow events linking the host launch to the device kernel.
        traceEvents.append({"ph": "s", "id": kernel_index, "pid": host_pid,
                            "tid": host_tid, "ts": launch_time,
                            "cat": "ac2g", "name": "ac2g"})
        traceEvents.append({"ph": "f", "id": kernel_index, "pid": 10,
                            "tid": 10, "ts": kernel_start_offset,
                            "cat": "ac2g", "name": "ac2g", "bp": "e"})

        kernel_start_offset += kernel_duration + kernel_launch_latency

    # Save the merged result.
    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(prof, f, ensure_ascii=False, indent=4)


def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value, mask=None, dropout=None):
        """Return ``(attention output, attention weights)``.

        Positions where ``mask == 0`` are excluded from the softmax. The
        optional ``dropout`` callable is applied to the attention weights.
        """
        d_k = query.size(-1)
        scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
        if mask is not None:
            # Use the dtype's own minimum: a literal such as -1e20 cannot be
            # represented in float16.
            scores = scores.masked_fill(mask == 0,
                                        torch.finfo(scores.dtype).min)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn @ value, p_attn


class MultiHeadAttention(nn.Module):
    """Multi-head attention with ``h`` heads over a ``d_model``-wide input."""

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        # Three input projections (query, key, value) plus the output one.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()

    def forward(self, query, key, value, mask=None):
        """Project the inputs, attend per head and merge the heads again.

        The attention weights of the last call are kept in ``self.attn``.
        """
        if mask is not None:
            # Add a head dimension so the mask broadcasts over all heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        query = self.linears[0](query).view(nbatches, -1, self.h, self.d_k)
        query = query.transpose(1, 2)
        key = self.linears[1](key).view(nbatches, -1, self.h, self.d_k)
        key = key.transpose(1, 2)
        value = self.linears[2](value).view(nbatches, -1, self.h, self.d_k)
        value = value.transpose(1, 2)

        x, self.attn = self.attention(query, key, value, mask=mask,
                                      dropout=self.dropout)
        x = x.transpose(1, 2).contiguous().view(nbatches, -1,
                                                self.h * self.d_k)
        return self.linears[-1](x)


use_cuda=True
# Prefer the out-of-tree "xpu" extension when it can be loaded; otherwise keep
# use_cuda=True and run the demo on CUDA.
try:
    import torch_xpu
    import torch_xpu.contrib.transfer_to_xpu
    torch.xpu.set_device(0)
    torch.profiler.ProfilerActivity.PrivateUse1 = "xpu"
    use_cuda = False
except Exception as exc:
    # Best-effort fallback: report why the extension is not used instead of
    # hiding the reason (and without swallowing KeyboardInterrupt/SystemExit).
    print(f"torch_xpu is not usable, falling back to CUDA: {exc}")

import os

# Single-process "distributed" setup.
os.environ['LOCAL_RANK'] = "0"
os.environ['RANK'] = "0"
os.environ['WORLD_SIZE'] = "1"
os.environ['MASTER_ADDR'] = "localhost"
os.environ['MASTER_PORT'] = "6006"

import torch.distributed as dist

# 'vccl' is only requested when the extension was loaded; the CUDA fallback
# uses the stock 'nccl' backend.
dist.init_process_group(backend='nccl' if use_cuda else 'vccl')
local_rank = int(os.environ['LOCAL_RANK'])
rank = dist.get_rank()
torch.cuda.set_device(local_rank)
if not dist.is_available() or not dist.is_initialized():
    print("dist init error")

cross_attn = MultiHeadAttention(h=8, d_model=64).half().cuda()
cross_attn.eval()
q1 = torch.ones((1, 50, 64), dtype=torch.float32).half().cuda()
k1 = q1.clone()
v1 = q1.clone()

# One un-profiled run before profiling starts.
out = cross_attn.forward(q1, k1, v1).sum()
torch.cuda.synchronize()

activities = [torch.profiler.ProfilerActivity.CPU]
if use_cuda:
    activities.append(torch.profiler.ProfilerActivity.CUDA)

with profile(
    activities=activities,
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
    record_shapes=True,
    with_stack=True,
    with_modules=True,
    with_flops=True,
    profile_memory=True,
) as prof:
    for _ in range(10):
        out = cross_attn.forward(q1, k1, v1).sum()
        prof.step()
    torch.cuda.synchronize()

if not use_cuda:
    # The exported chrome trace has no device kernels on this path, so dump
    # the per-op device timings and merge them into the trace afterwards.
    kernel_event = {}
    for ev in prof.profiler.function_events:
        if ev.privateuse1_time > 0:
            # Same "name_id_thread" shape that get_uuid() builds from the trace.
            event_key = f"{ev.name}_{ev.id}_{ev.thread}"
            kernel_event[event_key] = {
                "kernel_time": ev.privateuse1_time,
                "device_memory_usage": ev.privateuse1_memory_usage,
                "start_us": ev.time_range.start,
                "host_dur": ev.time_range.end - ev.time_range.start,
                "thread": ev.thread,
            }
    with open(f"kernel_event_{rank}.json", 'w', encoding='utf-8') as f:
        json.dump(kernel_event, f, ensure_ascii=False, indent=4)
    prof.export_chrome_trace(f"prof_{rank}.json")
    # The merged trace overwrites the exported one.
    merge_prof_timeline(f"prof_{rank}.json", f"kernel_event_{rank}.json",
                        f"prof_{rank}.json")
else:
    prof.export_chrome_trace(f"prof_{q1.device.type}.json")