Tensorflow源码解析3 -- TensorFlow核心对象 - Graph

1 Graph概述

计算图Graph是TensorFlow的核心对象,TensorFlow的运行流程基本都是围绕它进行的。包括图的构建、传递、剪枝、按worker分裂、按设备二次分裂、执行、注销等。因此理解计算图Graph对掌握TensorFlow运行尤为关键。
 

2 默认Graph

默认图替换

之前讲解Session的时候就说过,一个Session只能run一个Graph,但一个Graph可以运行在多个Session中。常见情况是,session会运行全局唯一的隐式的默认的Graph,operation也是注册到这个Graph中。

也可以显示创建Graph,并调用as_default()使他替换默认Graph。在该上下文管理器中创建的op都会注册到这个graph中。退出上下文管理器后,则恢复原来的默认graph。一般情况下,我们不用显式创建Graph,使用系统创建的那个默认Graph即可。

print tf.get_default_graph()with tf.Graph().as_default() as g:print tf.get_default_graph() is gprint tf.get_default_graph()print tf.get_default_graph()

输出如下

<tensorflow.python.framework.ops.Graph object at 0x106329fd0>
True
<tensorflow.python.framework.ops.Graph object at 0x18205cc0d0>
<tensorflow.python.framework.ops.Graph object at 0x10d025fd0>

由此可见,在上下文管理器中,当前线程的默认图被替换了,而退出上下文管理后,则恢复为了原来的默认图。

默认图管理

默认graph和默认session一样,也是线程作用域的。当前线程中,永远都有且仅有一个graph为默认图。TensorFlow同样通过栈来管理线程的默认graph。

@tf_export("Graph")
class Graph(object):# 替换线程默认图def as_default(self):return _default_graph_stack.get_controller(self)# 栈式管理,push pop@tf_contextlib.contextmanagerdef get_controller(self, default):try:context.context_stack.push(default.building_function, default.as_default)finally:context.context_stack.pop()

替换默认图采用了堆栈的管理方式,通过push pop操作进行管理。获取默认图的操作如下,通过默认graph栈_default_graph_stack来获取。

@tf_export("get_default_graph")
def get_default_graph():return _default_graph_stack.get_default()

下面来看_default_graph_stack的创建

_default_graph_stack = _DefaultGraphStack()
class _DefaultGraphStack(_DefaultStack):  def __init__(self):# 调用父类来创建super(_DefaultGraphStack, self).__init__()self._global_default_graph = Noneclass _DefaultStack(threading.local):def __init__(self):super(_DefaultStack, self).__init__()self._enforce_nesting = True# 和默认session栈一样,本质上也是一个listself.stack = []

_default_graph_stack的创建如上所示,最终和默认session栈一样,本质上也是一个list。
 

3 前端Graph数据结构

Graph数据结构

理解一个对象,先从它的数据结构开始。我们先来看Python前端中,Graph的数据结构。Graph主要的成员变量是Operation和Tensor。Operation是Graph的节点,它代表了运算算子。Tensor是Graph的边,它代表了运算数据。

@tf_export("Graph")
class Graph(object):def __init__(self):# 加线程锁,使得注册op时,不会有其他线程注册op到graph中,从而保证共享graph是线程安全的self._lock = threading.Lock()# op相关数据。# 为graph的每个op分配一个id,通过id可以快速索引到相关op。故创建了_nodes_by_id字典self._nodes_by_id = dict()  # GUARDED_BY(self._lock)self._next_id_counter = 0  # GUARDED_BY(self._lock)# 同时也可以通过name来快速索引op,故创建了_nodes_by_name字典self._nodes_by_name = dict()  # GUARDED_BY(self._lock)self._version = 0  # GUARDED_BY(self._lock)# tensor相关数据。# 处理tensor的placeholderself._handle_feeders = {}# 处理tensor的read操作self._handle_readers = {}# 处理tensor的move操作self._handle_movers = {}# 处理tensor的delete操作self._handle_deleters = {}

下面看graph如何添加op的,以及保证线程安全的。

  def _add_op(self, op):# graph被设置为final后,就是只读的了,不能添加op了。self._check_not_finalized()# 保证共享graph的线程安全with self._lock:# 将op以id和name分别构建字典,添加到_nodes_by_id和_nodes_by_name字典中,方便后续快速索引self._nodes_by_id[op._id] = opself._nodes_by_name[op.name] = opself._version = max(self._version, op._id)

GraphKeys 图分组

每个Operation节点都有一个特定的标签,从而实现节点的分类。相同标签的节点归为一类,放到同一个Collection中。标签是一个唯一的GraphKey,GraphKey被定义在类GraphKeys中,如下

@tf_export("GraphKeys")
class GraphKeys(object):GLOBAL_VARIABLES = "variables"QUEUE_RUNNERS = "queue_runners"SAVERS = "savers"WEIGHTS = "weights"BIASES = "biases"ACTIVATIONS = "activations"UPDATE_OPS = "update_ops"LOSSES = "losses"TRAIN_OP = "train_op"# 省略其他

name_scope 节点命名空间

使用name_scope对graph中的节点进行层次化管理,上下层之间通过斜杠分隔。

# graph节点命名空间
g = tf.get_default_graph()
with g.name_scope("scope1"):c = tf.constant("hello, world", name="c")print c.op.namewith g.name_scope("scope2"):c = tf.constant("hello, world", name="c")print c.op.name

输出如下

scope1/c
scope1/scope2/c  # 内层的scope会继承外层的,类似于栈,形成层次化管理

 

4 后端Graph数据结构

Graph

先来看graph.h文件中的Graph类的定义,只看关键代码

 class Graph {private:// 所有已知的op计算函数的注册表FunctionLibraryDefinition ops_;// GraphDef版本号const std::unique_ptr<VersionDef> versions_;// 节点node列表,通过id来访问std::vector<Node*> nodes_;// node个数int64 num_nodes_ = 0;// 边edge列表,通过id来访问std::vector<Edge*> edges_;// graph中非空edge的数目int num_edges_ = 0;// 已分配了内存,但还没使用的node和edgestd::vector<Node*> free_nodes_;std::vector<Edge*> free_edges_;}

后端中的Graph主要成员也是节点node和边edge。节点node为计算算子Operation,边为算子所需要的数据,或者代表节点间的依赖关系。这一点和Python中的定义相似。边Edge的持有它的源节点和目标节点的指针,从而将两个节点连接起来。下面看Edge类的定义。

Edge

class Edge {private:Edge() {}friend class EdgeSetTest;friend class Graph;// 源节点, 边的数据就来源于源节点的计算。源节点是边的生产者Node* src_;// 目标节点,边的数据提供给目标节点进行计算。目标节点是边的消费者Node* dst_;// 边id,也就是边的标识符int id_;// 表示当前边为源节点的第src_output_条边。源节点可能会有多条输出边int src_output_;// 表示当前边为目标节点的第dst_input_条边。目标节点可能会有多条输入边。int dst_input_;
};

Edge既可以承载tensor数据,提供给节点Operation进行运算,也可以用来表示节点之间有依赖关系。对于表示节点依赖的边,其src_output_, dst_input_均为-1,此时边不承载任何数据。

下面来看Node类的定义。

Node

class Node {public:// NodeDef,节点算子Operation的信息,比如op分配到哪个设备上了,op的名字等,运行时有可能变化。const NodeDef& def() const;// OpDef, 节点算子Operation的元数据,不会变的。比如Operation的入参列表,出参列表等const OpDef& op_def() const;private:// 输入边,传递数据给节点。可能有多条EdgeSet in_edges_;// 输出边,节点计算后得到的数据。可能有多条EdgeSet out_edges_;
}

节点Node中包含的主要数据有输入边和输出边的集合,从而能够由Node找到跟他关联的所有边。Node中还包含NodeDef和OpDef两个成员。NodeDef表示节点算子的信息,运行时可能会变,创建Node时会new一个NodeDef对象。OpDef表示节点算子的元信息,运行时不会变,创建Node时不需要new OpDef,只需要从OpDef仓库中取出即可。因为元信息是确定的,比如Operation的入参个数等。

由Node和Edge,即可以组成图Graph,通过任何节点和任何边,都可以遍历完整图。Graph执行计算时,按照拓扑结构,依次执行每个Node的op计算,最终即可得到输出结果。入度为0的节点,也就是依赖数据已经准备好的节点,可以并发执行,从而提高运行效率。

系统中存在默认的Graph,初始化Graph时,会添加一个Source节点和Sink节点。Source表示Graph的起始节点,Sink为终止节点。Source的id为0,Sink的id为1,其他节点id均大于1.
 

5 Graph运行时生命周期

Graph是TensorFlow的核心对象,TensorFlow的运行均是围绕Graph进行的。运行时Graph大致经过了以下阶段

  1. 图构建:client端用户将创建的节点注册到Graph中,一般不需要显示创建Graph,使用系统创建的默认的即可。
  2. 图发送:client通过session.run()执行运行时,将构建好的整图序列化为GraphDef后,传递给master
  3. 图剪枝:master先反序列化拿到Graph,然后根据session.run()传递的fetches和feeds列表,反向遍历全图full graph,实施剪枝,得到最小依赖子图。
  4. 图分裂:master将最小子图分裂为多个Graph Partition,并注册到多个worker上。一个worker对应一个Graph Partition。
  5. 图二次分裂:worker根据当前可用硬件资源,如CPU GPU,将Graph Partition按照op算子设备约束规范(例如tf.device(’/cpu:0’),二次分裂到不同设备上。每个计算设备对应一个Graph Partition。
  6. 图运行:对于每一个计算设备,worker依照op在kernel中的实现,完成op的运算。设备间数据通信可以使用send/recv节点,而worker间通信,则使用GRPC或RDMA协议。

这些阶段根据TensorFlow运行时的不同,会进行不同的处理。运行时有两种,本地运行时和分布式运行时。故Graph生命周期到后面分析本地运行时和分布式运行时的时候,再详细讲解。


原文链接
本文为云栖社区原创内容,未经允许不得转载。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/519730.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用AI说再见!“辣眼睛”的买家秀

阿里妹导读&#xff1a;提起买家秀和卖家秀&#xff0c;相信大家脑中会立刻浮现出诸多画面。同一件衣服在不同人、光线、角度下&#xff0c;会呈现完全不同的状态。运营小二需从大量的买家秀中挑选出高质量的图片。如果单纯靠人工来完成&#xff0c;工作量过于巨大。下面&#…

mysql相关

查询指定时间相近的记录SELECT * abs(UNIX_TIMESTAMP(t.create_time)-UNIX_TIMESTAMP(2020-06-04 10:10:39)) as min from t_video_history t WHERE t.ip 10.0.5.124 GROUP BY min asc limit 1查询一个月之前的数据 select * from t_video_history t where date_format(t.cr…

云+X案例展 | 电商零售类:WakeData助力叁拾加数字化变革

本案例由WakeData投递并参与评选&#xff0c;CSDN云计算独家全网首发&#xff1b;更多关于【云X 案例征集】的相关信息&#xff0c;点击了解详情丨挖掘展现更多优秀案例&#xff0c;为不同行业领域带来启迪&#xff0c;进而推动整个“云行业”的健康发展。在新零售时代下&#…

linux环境安装Kafka最新版本 jdk1.8

文章目录一、环境分布二、实战1. kafka下载2. 解压3. 配置4. 编写启动脚本5. 编写关闭脚本6. 赋予脚本可执行权限7. 脚本使用案例一、环境分布 软件版本jdk1.8kafkakafka_2.13-2.5.0 二、实战 kafka官网地址&#xff1a; http://kafka.apache.org/downloads 1. kafka下载 …

BeanUtils对象之间的复制

1、maven依赖<dependency><groupId>commons-beanutils</groupId><artifactId>commons-beanutils</artifactId><version>1.9.4</version> </dependency>2、常用API // 把orig对象copy到dest对象中. public void copyProperties…

基于泛型编程的序列化实现方法

写在前面 序列化是一个转储-恢复的操作过程&#xff0c;即支持将一个对象转储到临时缓冲或者永久文件中和恢复临时缓冲或者永久文件中的内容到一个对象中等操作&#xff0c;其目的是可以在不同的应用程序之间共享和传输数据&#xff0c;以达到跨应用程序、跨语言和跨平台的解耦…

两大硬件设计被OCP官方接受,腾讯成国内互联网公司第一家

刚刚获悉&#xff0c;腾讯在光网络设备和数据中心领域的两大硬件自研设计“OPC-4”和“TMDC”顺利通过OCP&#xff08;Open Compute Project&#xff09;审核并正式接受为官方开源贡献。这是腾讯在硬件领域的开源设计首次被OCP官方正式认可&#xff0c;同时&#xff0c;腾讯也成…

java 集成kafka单机版 适配jdk1.8

文章目录一、环境分布1. 版本声明2. 依赖2. case测试2. case2测试一、环境分布 1. 版本声明 linux服务器软件版本jdk1.8kafkakafka_2.13-2.4.0注&#xff1a;建议版本和应用依赖的客户端版本依赖保持一致&#xff0c;如果需要更高版本&#xff0c;可以尝试 但是有一点&#x…

微服务架构下,解决数据一致性问题的实践

随着业务的快速发展&#xff0c;应用单体架构暴露出代码可维护性差、容错率低、测试难度大和敏捷交付能力差等诸多问题&#xff0c;微服务应运而生。微服务的诞生一方面解决了上述问题&#xff0c;但是另一方面却引入新的问题&#xff0c;其中主要问题之一就是&#xff1a;如何…

2019阿里云开年Hi购季满返活动火热报名中!

2019阿里云云上采购季活动已经于2月25日正式开启&#xff0c;从已开放的活动页面来看&#xff0c;活动分为三个阶段&#xff1a; 2月25日-3月04日的活动报名阶段、3月04日-3月16日的新购满返5折抢购阶段、3月16日-3月31日的续费抽豪礼5折抢购阶段。 整个大促活动包含1个主会场…

mybatis中resultType取出数据顺序不一致解决方法

原来我的查询返回resultType “map” &#xff0c; 也就是这个map&#xff0c;打乱了顺序。因为map并不能保证存入取出数据一致。 解决方法&#xff1a;resultType "map" 改为 resultType"java.util.LinkedHashMap"

2019云计算高光时刻:乱云飞渡 传统IT大溃败

前言&#xff1a;2019年&#xff0c;物理机最后一张王牌也败给了云计算&#xff0c;无论从成本还是性能的角度&#xff0c;都没有不选云计算的理由&#xff0c;这是一个时代的终结。 2019的云计算市场格局&#xff0c;依旧是马太效应凸显、大者恒大的趋势继续&#xff0c;但在…

java 集成 kafka 0.8.2.1 适配jdk1.6

文章目录一、版本说明二、实战2.1. 依赖2.2. 生产者代码2.3. 消费端代码2.4. 测试三、小伙伴疑难解答3.1. 首先新建一个maven项目3.2. 把我的依赖和代码复制过去3.3. 把我写的case调试通3.4. 找到左边External Libraries3.5. jar处理3.6. 打开非maven项目&#xff0c;添加jar3.…

阿里云MWC 2019发布7款重磅产品,助力全球企业迈向智能化

当地时间2月25日&#xff0c;在巴塞罗那举行的MWC 2019上&#xff0c;阿里云面向全球发布了7款重磅产品&#xff0c;涵盖无服务器计算、高性能存储、全球网络、企业级数据库、大数据计算等主要云产品&#xff0c;可满足电子商务、物流、金融科技以及制造等各行业企业的数字化转…

Spring Cloud Alibaba迁移指南(一):一行代码从 Hystrix 迁移到 Sentinel

自 Spring Cloud 官方宣布 Spring Cloud Netflix 进入维护状态后&#xff0c;我们开始制作《Spring Cloud Alibaba迁移指南》系列文章&#xff0c;向开发者提供更多的技术选型方案&#xff0c;并降低迁移过程中的技术难度。 第一篇&#xff0c;我们对Hystrix、Resilience4j 和…

util中注入service

Autowiredprivate GovCustomerService service;private static GovCustomerService govCustomerService;PostConstruct //完成对service的注入public void init() {govCustomerService service;}

linux环境安装 kafka 0.8.2.1 jdk1.6

文章目录一、环境分布二、实战1. kafka下载2. 解压3. 配置4. 编写启动脚本5. 编写关闭脚本6. 赋予脚本可执行权限7. 脚本使用案例三、Config配置四、Consumer配置五、Producer配置很多小伙伴问我&#xff0c;为什么不用最新版本的kafka呢&#xff1f;关于这个问题&#xff0c;都…

元旦限时特惠,耳机、书籍等大降价

戳蓝字“CSDN云计算”关注我们哦&#xff01;今天是12月31日离2020年仅有不到一天的时间你们的2019年目标都实现了吗&#xff1f;在这一年你写了多少行代码改了多少个bug呢&#xff1f;2020年的愿望是否也是希望自己写的代码bug能少一些&#xff1f;小编的2020年希望能买到更多…

深入解读MySQL8.0 新特性 :Crash Safe DDL

前言 在MySQL8.0之前的版本中&#xff0c;由于架构的原因&#xff0c;mysql在server层使用统一的frm文件来存储表元数据信息&#xff0c;这个信息能够被不同的存储引擎识别。而实际上innodb本身也存储有元数据信息。这给ddl带来了一定的挑战&#xff0c;因为这种架构无法做到d…

mysql查询包含字符串(模糊查询)

mysql查询包含字符串更高效率的方法一、LOCATE语句SELECT column from table where locate(‘keyword’, condition)>0二、或是 locate 的別名 positionSELECT column from table where position(‘keyword’ IN condition)三、INSTR语句SELECT column from table where ins…