点击上方关注,All in AI中国
本文将介绍如何使用Keras和Google CoLaboratory与TPU一起训练LSTM模型,与本地计算机上的GPU相比,这样训练能大大缩短训练时间。
很长一段时间以来,我都在单张GTX 1070显卡上训练我的模型,它的单精度大约为8.18 TFlops。后来Google的Colab开放了免费的Tesla K80显卡,配备12GB RAM,8.73TFlops。直到最近,Colab的运行时类型选择器中还会弹出带有180 TFlops的Cloud TPU选项。这篇教程将简要介绍如何将现有的Keras模型转换为TPU模型,然后在Colab上训练。与在GTX1070上训练相比,TPU能够加速20倍。
我们将构建一个易于理解,但训练起来非常复杂的Keras模型,这样我们就可以稍微"预热"一下Cloud TPU。在IMDB情感分类任务上训练LSTM模型可能是一个很好的例子,因为相比密集层和卷积层来说,训练LSTM对算力要求更高。
工作流程概述:
- 使用静态输入batch_size构建用于功能API训练的Keras模型
- 将Keras模型转换为TPU模型
- 使用静态batch_size * 8训练TPU模型,并将权重保存到文件
- 创建一个结构相同,但输入批大小可变的Keras模型,用于推理
- 加载模型权重
- 基于推理模型进行预测
在阅读本文的同时,你可以上手试验相应的Colab Jupyter notebook:Keras_LSTM_TPU.ipynb。(https://colab.research.google.com/drive/1QZf1WeX3EQqBLeFeT4utFKBqq-ogG1FN)
首先,按照下图中的说明来激活在Colab运行中的TPU。
固定输入批尺寸
大多数情况下,CPU和GPU上对输入形状没有限制,但XLA/TPU环境下会强制使用固定的形状和批尺寸。
Can TPU包含8个TPU核心,作为独立的处理单元运行。如果没有使用所有八个核心,那TPU就不会得到充分利用。为了充分提高训练的矢量化速度,相比在单一GPU上训练的同样的模型,我们可以选择较大的批尺寸。总批尺寸大小为1024(每个核心128个)通常是一个很好的起点。
如果你要训练批尺寸较大的型号,请尝试慢慢减小批尺寸,以保证TPU内存放得下,只需确保总批尺寸为64的倍数(每核心批尺寸应该是8的倍数)。
值得一提,在批尺寸较大时,通常可以提高优化器的学习速率,以实现更快的收敛。你可以在本文中找到参考——"Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"。(https://arxiv.org/pdf/1706.02677.pdf)
在Keras中,要定义静态批处理尺寸,我们使用函数API,然后为输入层指定batch_size参数。请注意,模型构建在一个带有batch_size参数的函数中,因此我们之后可以很方便地创建在CPU或GPU上运行的模型,这些模型接受可变批尺寸的输入。
此外,我们在这里使用了tf.train.Optimizer而不是标准的Keras优化器,因为TPU对Keras优化器的支持还处于实验阶段。
将Keras模型转换为TPU模型
tf.contrib.tpu.keras_to_tpu_model函数将tf.keras模型转换为等价的TPU版本。
然后,我们使用标准的Keras方法来训练,保存权重并评估模型。请注意,batch_size设置为模型输入batch_size的八倍,因为输入样本在8个TPU核心上均匀分布。
我做了一个实验,用来比较在Windows PC上运行单个GTX1070和在Colab上运行的TPU之间的训练速度,结果如下:
- GPU和TPU都将输入批尺寸设为128。
- GPU:每个历元179秒。20个历元后的验证准确率达到了76.9%,总计3600秒。
- TPU:每个历元5秒(第一个历元需要49秒)。20个历元后的验证准确率达到了95.2%,总计150秒。
- 在20个历元之后TPU的验证准确度高于在GPU上的表现,那是因为TPU上同时训练8个批的样本(每个批的大小为128)。
在CPU上进行推理
一旦我们获得了模型权重,我们就可以像往常一样加载它,然后在CPU或GPU等其他设备上进行预测。我们想要推理模型接受可变的输入批大小,这可以使用之前的make_model()函数来实现。
你可以看到推理模型现在可以接受可变输入样本数目,
然后,你可以使用标准的fit()、evaluate()函数与推理模型。
结论以及进一步阅读
这篇快速教程向你简要介绍了如何利用Google Colab上的免费Cloud TPU资源更快地训练Keras模型。
云TPU文档:https://cloud.google.com/tpu/docs/
云TPU性能指南:https://cloud.google.com/tpu/docs/performance-guide
云TPU故障排除指南:https://cloud.google.com/tpu/docs/troubleshooting
XLA概述:https://www.tensorflow.org/performance/xla/