graphviz官方参考链接:
http://www.graphviz.org/documentation/
https://graphviz.readthedocs.io/en/stable/index.html
文章目录
- 需求描述
- 环境配置
- 实现思路
- 代码实现
需求描述
根据各模块之间的传参关系绘制出数据流,如下图所示:
并且生成对应的graphviz代码:
digraph my_graph {Input [fillcolor=gray70 shape=box style=filled]Output [fillcolor=gray70 shape=box style=filled]NodeANodeBNodeCInput -> NodeA [label=0]Input -> NodeA [label=1]NodeA -> NodeB [label=0]NodeA -> NodeC [label=1]NodeB -> Output [label=0]NodeC -> Output [label=0]
}
环境配置
- 安装Python中需要使用的
graphviz
包:
pip install graphviz
- 安装
graphviz
工具(可选,如果不安装无法直接使用Python的graphviz
包导出图片),例如ubuntu系统安装指令如下,其他系统可参考官方文档https://www.graphviz.org/download/:
sudo apt install graphviz
- VSCODE安装
Graphviz Interactive Preview
插件(可选,如果使用vscode开发建议安装此插件,通过此插件可以直接可视化graphviz代码,并保存图片)
实现思路
实现一个Node基类,所有的模块实现都继承自该基类。再实现一个Message基类,模块之间传递的数据都继承自该基类。然后在数据传递过程中记录流经的每个模块的名称以及数据的传递方向即可绘制出想要的数据流。
代码实现
下面给出了一个简易的实现方式:
import os
from graphviz import Digraph__graph_dict__ = {}class Message:def __init__(self, node_name: str, idx: int):self.node_name = node_nameself.idx = idxclass EdgeInfo:def __init__(self, start_node_name: str, end_node_name: str, label: str) -> None:self.start_node_name = start_node_nameself.end_node_name = end_node_nameself.label = labeldef __str__(self):return f'{self.start_node_name} -> {self.end_node_name} [label="{self.label}"];'class Node:input_num: intoutput_num: intnode_name: strdef __call__(self, *args):global __graph_dict__assert len(args) == self.input_numif self.node_name not in __graph_dict__:__graph_dict__[self.node_name] = []for input_ in args:__graph_dict__[input_.node_name].append(EdgeInfo(input_.node_name,self.node_name,str(input_.idx)))res = tuple(Message(self.node_name, i) for i in range(self.output_num))if self.output_num == 1:return res[0]return resdef export_graphviz(graph, num_input: int, save_path: str):base_name = os.path.basename(save_path)name, _ = base_name.split(".")global __graph_dict____graph_dict__.clear()__graph_dict__.update({"Input": [], "Output": []})# infer and collect flow infoinput_args = tuple(Message("Input", i) for i in range(num_input))outputs = graph(*input_args)for ouput_ in outputs:if ouput_.node_name not in __graph_dict__:__graph_dict__[ouput_.node_name] = []__graph_dict__[ouput_.node_name].append(EdgeInfo(ouput_.node_name,"Output",str(ouput_.idx)))# create graph codedigraph = Digraph(name=name, format="jpg")# add nodeskeys = list(__graph_dict__.keys())for k in keys:if k in ["Input", "Output"]:digraph.node(k, **{"shape": "box", "style": "filled", "fillcolor": "gray70"})else:digraph.node(k)# add edgesfor k in keys:for edge_info in __graph_dict__[k]:digraph.edge(edge_info.start_node_name,edge_info.end_node_name,edge_info.label)# print digraph codeprint(digraph.source)# export gv and jpg filetry:digraph.render(directory=os.path.dirname(save_path))except Exception as e:print(f"export digraph failed, {e}")class NodeA(Node):def __init__(self):self.input_num = 2self.output_num = 2self.node_name = "NodeA"class NodeB(Node):def __init__(self):self.input_num = 1self.output_num = 1self.node_name = "NodeB"class NodeC(Node):def __init__(self):self.input_num = 1self.output_num = 1self.node_name = "NodeC"class Graph:def __init__(self):self.node_a = NodeA()self.node_b = NodeB()self.node_c = NodeC()def __call__(self, x0, x1):y0, y1 = self.node_a(x0, x1)z0 = self.node_b(y0)z1 = self.node_c(y1)return z0, z1if __name__ == "__main__":graph = Graph()export_graphviz(graph, num_input=2, save_path="./my_graph.gv")
执行上述代码后会生成my_graph.gv
以及my_graph.gv.jpg
两个文件(如果没有安装graphviz工具是不会生成的),其中my_graph.gv
是graphviz的代码形式,my_graph.gv.jpg
是可视化后的结果。