控制流
只要对tensorflow有一点了解,都应该知道graph是tensorflow最基本的一个结构。Tensorflow的所有计算都是以图作为依据的。图的nodes表示一些基本的数学运算,比如加法,卷积,pool等。Node使用protoBuf来进行描述,包括node的名字,op,input等,详细可以参看tensorflow中的node_def.proto文件。Node对应的op使用C++来进行实现。图中的边表示了数据流动的方向以及节点之间的依赖关系。比如A->B就表示B必须在A执行完之后才能够执行。以下是inception网络的图结构。
当了解了tensorflow的一些基本op之后,我们会存在这样的疑问。对于需要分支跳转,循环的部分,tensorflow是如何实现的。比如tf.cond,tf.while_loop这些语句在底层是如何表示的呢?tensorflow定义一些基本的控制原语,通过一定的组合可以完成高层次控制语言的实现,比如a=op?C:D这样的语句。
tensorflow控制流的设计原则是通过引入最少的控制模块,利用这些控制模块可以表示很多复杂应用广泛的控制过程。这些控制模块还能够适应并发,分布式运算,同时能够实现自动微分。在tensorflow,一个计算节点在执行帧(execution frame,类比进程的栈帧)里执行。控制流原语负责创建和管理执行。直观地理解,TF运行时建立一个个执行帧,在执行帧里执行所有属于这个执行帧的计算节点。执行帧可以嵌套(父子关系)。来自不同执行帧且没有依赖关系的计算节点可以并行计算。这里介绍5种最基本的控制原语。
1 switch
依据控制条件p,选择性将输入数据d传播到两个输出端。
2 merge
Merge算子将一个可用输入传给输出,只要有任意一个输入可用,switch就可以执行。
3 enter
Enter算子依据执行帧唯一标识名称将输入传递到相应执行帧。Enter算子用于将一个tensor从一个执行帧传递到子执行帧。
4 exit
Exit算子用于将子执行帧的数据传递父执行帧。
5 nextIteration
netIteration算子可以将其输入传递到当前执行帧的下一个iteration。Tensorflow的runtime可以随时跟踪执行帧中的iteration。任何一个op都有一个唯一的iteration ID进行标识。
现在我们来看这几种原子指令是如何实现条件判断和循环的。
Tensorflow中条件判断cond(pre, fn1, fn2)实现的伪代码如下:
首先创建一个条件控制context,这个context会调用两个不同的计算图。使用哪个计算图由条件pre来决定。最后将调用两个计算图的结果通过merge节点输出到下一个计算图。使用merge节点是为了保证只要有一个图有了结果就可以马上输送到下一个节点进行后续计算。用图描述如下:
对于循环语句,tensorflow中使用一下伪代码来完成:
首先创建一个循环控制context。然后创建一个enter和merge节点来导入循环体变量。使用enter节点是通过帧名识别这个循环体从而去执行。Merge是将循环变量传递给判断条件图,进行循环判定。加入的switch节点用于对循环条件判断的结果进行计算图选择。循环体内部计算结果需要进行多次循环,所以进入了nextIteration节点。Switch的false输出用于终止循环,所以进入exit节点将最终结果输出。
有了这些控制节点,tensorflow就可以将一个图分割成多个子图,并部署到多个硬件执行设备上。在两个子图分割处,添加send和receive节点用于不同设备之间数据通信。Tensorflow对节点如何分配没有限制,只要这个节点可以在这个设备上执行,就可以分配。如果没有这些控制节点,那么一幅图中的一个节点就只能执行一次,有了这些控制节点,计算图就能够有更多计算方式。一个节点可以循环执行多次,还可以被分配到不同设备执行。
Tensorflow可以支持自动微分。当用户建立了计算图和定义了loss函数后,tensorflow会根据计算图的结构建立反向传播图。给定一个计算节点,可以通过映射到计算公式方式进而求取微分。从而能够找出其反向传播的节点的表示。对于控制节点来说,enter的反向传播节点是exit,switch的反向传播节点是merge(对于cond来说),或者nextIteration+merge(对于while_loop来说)。Merge的反向传播节点是switch。nextIteration的反向传播节点是identity。Enter的反向传播节点是exit。有了这些对应关系,就可以自动来推断反向传播图。从而求取梯度了。而且可以在多个设备上进行计算分配。
比如对于cond条件判断,如果其不是loop中的条件判断,那么其正向传播图和反向传播图的映射关系为:
优化器
优化器是在原始计算图基础上进行优化,提高计算在硬件上的效率。优化主要有几个目标:简化图结构,降低最大的硬件存储使用率,进行硬件友好的图转化。图优化方法有很多,有些和硬件无关,有些和硬件的具体实现细节相关。高层次优化是对图进行一定简化,它对硬件是透明的。通过简化可以去除一些冗余计算。比如常数折叠,多余控制节点去除等。还有一些利用结合律,分配律等对公式进行简化,比如:
1) 图的简化可以删除一些冗余计算,将图用最终等效结果替换。比如一个建立tensor的过程:
将tensor的shape创建和数据创建合并,直接用常数替换。这样就去除了shape创建过程。
2) 常数折叠可以将两个以上常数用一个常数替代,需要优化器进行一些计算。比如:
3) 代数优化利用算术的性质进行一定转化。比如:
addN相当于硬件上可支持的一个并行计算单元,可以一次计算多个输入。所以可以将连续的三个加法用一个并行加法替换。
第二个利用了算术的分配律和结合律将三个具有相同乘数提取出来。最后一个对逻辑进行了等效转化,从而减少了计算节点。
这个matrix+scalar的时候需要对scalar先进行广播,然后再加。转化后减少了广播次数。
这两个消除了冗余计算。
4) op融合将多个计算节点融合为一个节点来计算。这个是和硬件有关的,比如一个硬件计算单元可以完成conv+batch_norm,那么就可以实现这样的计算融合,就不需要单独多出来一个计算单元。常用的op融合有:
5) 存储优化的目的是为了降低对片外的访问频率,这样能够提高数据运算效率,减少等待数据加载时间。
往期文章
1 GPU,多核CPU和AI芯片的存储结构
2 让推荐系统更快:Nvidia Merlin架构