基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

基于torch_dispatch机制生成Megatron-DeepSpeed调用关系图

  • 一.局部效果图
  • 二.运行训练过程,拦截算子,生成调用关系信息
  • 三.可视化,生成SVG图像

想知道Megatron-DeepSpeed训练过程中各模块之间的调用关系。torch_dispatch机制可以拦截算子,inspect又能获取到调用栈(文件,类名,函数,行号).基于这些信息可以生成调用关系,最后用graphviz生成SVG图像。该思路也可以用来画其它pytorch工程的调用关系图

1.为了减少图像宽度,一行显示一级文件路径
2.没有显示具体的ATen算子。因为边太乱

一.局部效果图

在这里插入图片描述

二.运行训练过程,拦截算子,生成调用关系信息

# 前面构建模型的代码省略
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
from dataclasses import dataclass
from typing import Any
import pickle@dataclass
class _ProfilerState:cls: Anyobject: Any = Noneclass TorchDumpDispatchMode(TorchDispatchMode):def __init__(self,parent):super().__init__()self.parent=parent        self.global_index=0        self.nodes=set()self.edges=set()def __del__(self):self.rank = torch.distributed.get_rank()graph={"nodes":self.nodes,"edges":self.edges}with open(f"call_graph_{self.rank}.pkl","wb") as f:pickle.dump(graph,f)def is_keep(self,node):# if node.function.find("wrapper")>=0:#     return False# if node.function.find("_call_impl")>=0:#     return Falsereturn Truedef __torch_dispatch__(self, func, types, args=(), kwargs=None):self.global_index+=1self.rank = torch.distributed.get_rank() func_packet = func._overloadpacket       if kwargs is None:kwargs = {}if self.rank==0:stacks=[i for i in inspect.stack() if self.is_keep(i)]stacks_sz=len(stacks)for idx in range(stacks_sz-1,1,-1):if "self" in stacks[idx].frame.f_locals:class_name = stacks[idx].frame.f_locals["self"].__class__.__name__else:class_name=""this_node=f"{stacks[idx].filename}:[{class_name}]:{stacks[idx].function}"if "self" in stacks[idx-1].frame.f_locals:class_name = stacks[idx-1].frame.f_locals["self"].__class__.__name__else:class_name=""                                    next_node=f"{stacks[idx-1].filename}:[{class_name}]:{stacks[idx-1].function}"self.nodes.add(this_node)self.nodes.add(next_node)self.edges.add(f"{this_node}->{next_node}")# if stacks_sz>1:#     if "self" in stacks[1].frame.f_locals:#         class_name = stacks[1].frame.f_locals["self"].__class__.__name__#     else:#         class_name=""                #     this_node=f"{stacks[1].filename}:[{class_name}]:{stacks[1].function}"#     next_node=f"{func_packet.__name__}"#     self.nodes.add(this_node)   #     self.nodes.add(next_node)            #     self.edges.add(f"{this_node}->{next_node}")ret= func(*args, **kwargs)return retclass TorchDumper:_CURRENT_Dumper = Nonedef __init__(self,schedule: Any):self.p= _ProfilerState(schedule)def __enter__(self):assert TorchDumper._CURRENT_Dumper is NoneTorchDumper._CURRENT_Dumper = selfif self.p.object is None:o = self.p.cls(self)o.__enter__()self.p.object = oelse:self.p.object.step()return selfdef __exit__(self, exc_type, exc_val, exc_tb):TorchDumper._CURRENT_Dumper = Noneif self.p.object is not None:self.p.object.__exit__(exc_type, exc_val, exc_tb)del self.p.object  #序列化保存def main():with TorchDumper(TorchDumpDispatchMode):#训练入口pretrain(train_valid_test_datasets_provider,model_provider,forward_step,extra_args_provider=llama_argument_handler,args_defaults={"tokenizer_type": "GPT2BPETokenizer"},)if __name__ == "__main__":main()

三.可视化,生成SVG图像

# coding=utf-8import os
from graphviz import Digraph,Graph
import pickle
import random
from distinctipy import distinctipydef generate_colors(N):'''生成N种有区别度的颜色'''result=[]for red, green, blue in distinctipy.get_colors(N):result.append("#{:02X}{:02X}{:02X}".format(int(red*255), int(green*255), int(blue*255)))return resultdef replace_name(name):'''修改节点名字(缩短,添加换行)'''if name.find("__torch_dispatch__")>=0:return Nonename=name.replace("/home/user/Megatron-DeepSpeed/","")name=name.replace("/home/anaconda3/envs/dev/lib/python3.10/site-packages/","")name=name.replace("/home/user/deepspeed/","")name=name.replace("/home/anaconda3/envs/dev/","")name=name.replace("/",r"\n")name=name.replace(":",r"\n")return name# 1.加载HOOK生成的调用关系文件
rank=0
with open(f"call_graph_{rank}.pkl","rb") as f:data=pickle.load(f)# 2.构建图,设置属性
dot = Digraph()
dot.node_attr = {"shape": "plaintext"}
dot.attr('graph', layout='dot')
dot.graph_attr.update(sep='4.0', ratio='compress')node_desc_id_map={}  #节点名与描述的关系映射表
src_node_color={}    #节点颜色映射表(同一个节点输出的边颜色一样)colors = generate_colors(10)
colors_sz=len(colors)fontsize="16"        #节点字体大小
penwidth="2.0"       #边宽度# 3.添加节点
for idx,v in enumerate(data["nodes"]):v=replace_name(v)if v is None:continuenode_desc_id_map[v]=f"{idx}"if v.find("megatron")>=0:dot.node(f"{idx}",v,style='filled',color='#73FBFD',fontsize=fontsize)elif v.find("deepspeed")>=0:dot.node(f"{idx}",v,style='filled',color='#FA8D89',fontsize=fontsize)else:dot.node(f"{idx}",v,style='filled',color='#C0C0C0',fontsize=fontsize)src_node_color[v]=colors[idx%colors_sz]# 4.添加边
for edge in data["edges"]:from_node,to_node=edge.split("->")from_node=replace_name(from_node)to_node=replace_name(to_node)if all([from_node,to_node]):color=src_node_color[from_node]dot.edge(node_desc_id_map[from_node], node_desc_id_map[to_node],color=color,penwidth=penwidth)# 5.保存SVG
save_path='megatron_deepspeed_callgraph'
dot.render(save_path,format='svg', view=False)# 6.修改背景色为灰色
import xml.etree.ElementTree as ET
svg_tree = ET.parse(f'{save_path}.svg')
root = svg_tree.getroot()
element = root.find(".//{http://www.w3.org/2000/svg}polygon")
element.set('fill', 'gray')
svg_tree.write(f'{save_path}.svg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/9745.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

笔记本电脑怎么查看硬盘型号?无需额外软件,五招让你轻松掌握

随着科技的进步,笔记本电脑已经成为我们日常生活和工作中不可或缺的工具。而在选购或维护笔记本电脑时,了解硬盘的型号和性能是至关重要的。本文以windows10系统为例,将向您介绍几招,帮助您轻松掌握查看笔记本电脑硬盘型号的方法。…

适合年轻人的恋爱交友脱单软件有哪些?中国十大社交软件排行榜分享

交友始祖:Tinder 一直很受欢迎,可以向上扫给 super like (每日有一次免费机会)。如果双方互相 like,代表配对成功,就可以开始聊天。另外,每日有 10 个 top picks 供选择,你可以免费选一位 主力编外&#xf…

博士阶段应该搞什么:-人才引进要求

目录 专利,高水平论文(一作),技能证书,职称,高端竞赛,科研成果奖 济宁学院

Java医院绩效考核系统源码maven+Visual Studio Code一体化人力资源saas平台系统源码

Java医院绩效考核系统源码mavenVisual Studio Code一体化人力资源saas平台系统源码 医院绩效解决方案包括医院绩效管理(BSC)、综合奖金核算(RBRVS),涵盖从绩效方案的咨询与定制、数据采集、绩效考核及反馈、绩效奖金核…

67万英语单词学习词典ACCESS\EXCEL数据库

这似乎是最多记录的英语单词学习词典,包含复数、过去分词等形式的单词。是一个针对想考级的人员辅助背单词学英语必备的数据,具体请自行查阅以下的相关截图。 有了数据才能想方设法做好产品,结合权威的记忆理论,充分调动用户的眼…

4.Spring Security重要接口

当什么都没有配置的时候,账号和密码是由spring security自定义生成的。在实际项目中账号和密码都是从数据库中查询出来的。所以要通过自定义逻辑控制认证逻辑。 UserDetailService 接口 1.创建类继承UsernamePasswordAuthenticationFilter,重写三个方法&#xff1…

Cocos creator实现《战机长空》关卡本地存储功能

Cocos creator实现《战机长空》关卡本地存储功能 Cocos creator在开放小游戏过程中,经常会出现设置关卡,这里记录一下关卡数据本地存储功能。 一、关卡设置数据 假如我们有关卡数据如下, let settings [ { level: 1, // 第1关 score: 0,…

判断大模型微调是否产生灾难性遗忘的实战方案

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

画出入学管理系统的顶层图和1层图

(学校作业) 题目如下: 某培训机构入学管理系统有报名、交费和就读等多项功能,下面是对其各项功能的说明: 1、报名:由报名处负责,需要在学员登记表上进行报名登记,需要查询课…

微软必应bing国内广告开户费用?如何开户投放?

当下搜索引擎广告无疑是企业触达潜在客户、提升品牌曝光度的重要途径之一,微软必应(Bing)作为全球第二大搜索引擎,尽管在国内市场份额上可能不敌某些本土巨头,但其独特的用户群体和国际影响力使其成为众多企业拓展市场…

【数据结构与算法】常见的排序算法

文章目录 排序的概念冒泡排序(Bubble Sort)插入排序(Insert Sort)选择排序(Select Sort)希尔排序(Shell Sort)写法一写法二 快速排序(Quick Sort)hoare版本&a…

前端Vue uView 组件<u-search> 自定义右侧搜索按钮样式

前言 uView 文档的效果不是ui设计的样式 需要重新编辑 原效果 ui设计效果 解决方案 设置里说明的需要传一个样式对象 这个对象 需要写在 script 标签里面 这里需要遵循驼峰命名 比如font-size 改为 fontSize lineHeight和textAlign为水平锤子居中效果 searchStyle: {ba…

Box86源码解读记录

1. 背景说明 Github地址:https://github.com/ptitSeb/box86 官方推荐的视频教程:Box86/Box64视频教程网盘 2. 程序执行主体图 Box86版本: Box86 with Dynarec v0.3.4 主函数会执行一大堆的初始化工作,包括但不限于:BOX上下文 …

【ARMv8/v9 系统寄存器 4 -- ARMv8 通用寄存器详细介绍】

文章目录 ARMv8 通用寄存器通用寄存器X30 寄存器和链接寄存器(LR)程序计数器(PC)ARMv8 X30和PC之间的关系小结 ARMv8 通用寄存器 在ARMv9架构中(这也适用于ARMv8,因为ARMv9是其进化版本)&#…

腾讯云coding代码托管平台配置问题公钥拉取失败提示 Permission denied(publickey)

前言 最近在学校有个课设多人开发一个游戏,要团队协作,选用了腾讯云的coding作为代码管理仓库,但在配置的时候遇到了一些问题,相比于github,发现腾讯的coding更难用,,,这里记录一下…

如何设计与管理一个前端项目

目录 前端项目设计 前端项目搭建 洞察项目瓶颈 方案调研与选型对比 前端项目管理 合理的分工排期 风险把控 及时反馈与复盘 结束语 如果说基础知识的掌握是起跑线,那么使大家之间拉开差距的更多是前端项目开发经验和技能。对于一个项目来说,从框…

【Android Studio】【NCNN】YOLOV8安卓部署

目录 下载Android Studio 克隆安卓项目 关于自训练模型闪退问题 下载Android Studio 下载Android Studio,配置安卓开发环境,这个过程比较漫长。 安装cmake,注意安装的是cmake3.10版本。 根据手机安卓版本选择相应的安卓版本&#xff0c…

彻底解决python的pip install xxx报错(文末附所有依赖文件)

今天安装pip install django又报错了: C:\Users\Administrator>pip install django WARNING: Ignoring invalid distribution -ip (d:\soft\python\python38\lib\site-pac kages) Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Collecting djan…

论文阅读:The Unreasonable Ineffectiveness of the Deeper Layers 层剪枝与模型嫁接的“双生花”

作者实证研究了针对流行的开放式预训练 LLM 系列的简单层修剪策略,发现在不同的 QA 基准上,直到去掉一大部分(最多一半)层(Transformer 架构)后,性能的下降才会降到最低。为了修剪这些模型&…

探索 IPv6 协议:互联网的新一代寻址

目录 一.概述 IPv4 的问题和 IPv6 的新特性 IPv6 协议体系 二.IPv6 寻址架构:巨大的地址空间与灵活的寻址模式 IPv6 寻址概述 地址表示方法 地址前缀与地址类型标识 单播地址 任播地址 多播地址 特殊的 IPv6 地址 IPv6 主机与路由器寻址 地址分配 三.I…