keras训练完以后怎么预测_还在使用“龟速”的单显卡训练模型?动动手,让TPU节省你的时间...

点击上方关注,All in AI中国

本文将介绍如何使用Keras和Google CoLaboratory与TPU一起训练LSTM模型,与本地计算机上的GPU相比,这样训练能大大缩短训练时间。

f4dda0965b13cd3fdc03ae60f4d8b13b.png

很长一段时间以来,我都在单张GTX 1070显卡上训练我的模型,它的单精度大约为8.18 TFlops。后来Google的Colab开放了免费的Tesla K80显卡,配备12GB RAM,8.73TFlops。直到最近,Colab的运行时类型选择器中还会弹出带有180 TFlops的Cloud TPU选项。这篇教程将简要介绍如何将现有的Keras模型转换为TPU模型,然后在Colab上训练。与在GTX1070上训练相比,TPU能够加速20倍。

我们将构建一个易于理解,但训练起来非常复杂的Keras模型,这样我们就可以稍微"预热"一下Cloud TPU。在IMDB情感分类任务上训练LSTM模型可能是一个很好的例子,因为相比密集层和卷积层来说,训练LSTM对算力要求更高。

工作流程概述:

  • 使用静态输入batch_size构建用于功能API训练的Keras模型
  • 将Keras模型转换为TPU模型
  • 使用静态batch_size * 8训练TPU模型,并将权重保存到文件
  • 创建一个结构相同,但输入批大小可变的Keras模型,用于推理
  • 加载模型权重
  • 基于推理模型进行预测

在阅读本文的同时,你可以上手试验相应的Colab Jupyter notebook:Keras_LSTM_TPU.ipynb。(https://colab.research.google.com/drive/1QZf1WeX3EQqBLeFeT4utFKBqq-ogG1FN)

首先,按照下图中的说明来激活在Colab运行中的TPU。

80c751cfce9118d7a0f3aea2a65ce646.png

激活TPU

固定输入批尺寸

大多数情况下,CPU和GPU上对输入形状没有限制,但XLA/TPU环境下会强制使用固定的形状和批尺寸。

Can TPU包含8个TPU核心,作为独立的处理单元运行。如果没有使用所有八个核心,那TPU就不会得到充分利用。为了充分提高训练的矢量化速度,相比在单一GPU上训练的同样的模型,我们可以选择较大的批尺寸。总批尺寸大小为1024(每个核心128个)通常是一个很好的起点。

如果你要训练批尺寸较大的型号,请尝试慢慢减小批尺寸,以保证TPU内存放得下,只需确保总批尺寸为64的倍数(每核心批尺寸应该是8的倍数)。

值得一提,在批尺寸较大时,通常可以提高优化器的学习速率,以实现更快的收敛。你可以在本文中找到参考——"Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"。(https://arxiv.org/pdf/1706.02677.pdf)

在Keras中,要定义静态批处理尺寸,我们使用函数API,然后为输入层指定batch_size参数。请注意,模型构建在一个带有batch_size参数的函数中,因此我们之后可以很方便地创建在CPU或GPU上运行的模型,这些模型接受可变批尺寸的输入。

74a8d2e19c8fde1fcafdb4acfac198b3.png

此外,我们在这里使用了tf.train.Optimizer而不是标准的Keras优化器,因为TPU对Keras优化器的支持还处于实验阶段。

将Keras模型转换为TPU模型

tf.contrib.tpu.keras_to_tpu_model函数将tf.keras模型转换为等价的TPU版本。

86c584e222b04e31ba37735e86b9ee24.png

然后,我们使用标准的Keras方法来训练,保存权重并评估模型。请注意,batch_size设置为模型输入batch_size的八倍,因为输入样本在8个TPU核心上均匀分布。

a7f7262a5f5a2903955bccbeb8828d24.png

我做了一个实验,用来比较在Windows PC上运行单个GTX1070和在Colab上运行的TPU之间的训练速度,结果如下:

  • GPU和TPU都将输入批尺寸设为128。
  • GPU:每个历元179秒。20个历元后的验证准确率达到了76.9%,总计3600秒。
  • TPU:每个历元5秒(第一个历元需要49秒)。20个历元后的验证准确率达到了95.2%,总计150秒。
  • 在20个历元之后TPU的验证准确度高于在GPU上的表现,那是因为TPU上同时训练8个批的样本(每个批的大小为128)。

在CPU上进行推理

一旦我们获得了模型权重,我们就可以像往常一样加载它,然后在CPU或GPU等其他设备上进行预测。我们想要推理模型接受可变的输入批大小,这可以使用之前的make_model()函数来实现。

bc4039b346df25e1dd24da3f71df66b2.png

你可以看到推理模型现在可以接受可变输入样本数目,

6bad7d1498641d83b1045fb652c683d4.png

然后,你可以使用标准的fit()、evaluate()函数与推理模型。

结论以及进一步阅读

这篇快速教程向你简要介绍了如何利用Google Colab上的免费Cloud TPU资源更快地训练Keras模型。

云TPU文档:https://cloud.google.com/tpu/docs/

云TPU性能指南:https://cloud.google.com/tpu/docs/performance-guide

云TPU故障排除指南:https://cloud.google.com/tpu/docs/troubleshooting

XLA概述:https://www.tensorflow.org/performance/xla/

5f6982a3dcefc2e09bd707df85b060f5.png

编译出品

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/276445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手把手教你写个小程序定时器管理库

背景凹凸曼是个小程序开发者,他要在小程序实现秒杀倒计时。于是他不假思索,写了以下代码:Page({init: function () {clearInterval(this.timer)this.timer setInterval(() > {// 倒计时计算逻辑console.log(setInterval)})}, })可是&…

[New Portal]Windows Azure Virtual Machine (14) 在本地制作数据文件VHD并上传至Azure(1)

《Windows Azure Platform 系列文章目录》 之前的内容里,我介绍了如何将本地的Server 2012中文版 VHD上传至Windows Azure,并创建基于该Server 2012 VHD的虚拟机。 我们知道,VHD不仅仅可以保存操作系统,而且可以保存数据文件。 如…

python 退出程序_Python:用Ctrl+C解决终止多线程程序的问题!(建议收藏)

前言:今天为大家带来的内容是Python:用CtrlC解决终止多线程程序的问题!文章中的代码具有不错的参考意义,希望在此能够帮助到各位!(多数代码用图片的方式呈现出来,方便各位观看与收藏)出发点:前段时间&#…

若川知乎高赞:有哪些必看的 JS 库?

欢迎星标我的公众号,回复加群,长期交流学习我的知乎回答目前2w阅读量,270赞,现在发到公众号声明原创。必看的js库?只有当前阶段值不值看。我从去年7月起看一些前端库的源码,历时一年才写了八篇《学习源码整…

基于EasyUI的Web应用程序及过去一年的总结

前言 一个多月之前已经提交了离职申请,好在领导都已经批准了,过几天就办理手续了,在此感谢领导的栽培与挽留,感谢各位同事在工作中的给我的帮助,离开这个团队确实有一些不舍,不为别的,只因为这个…

快速使用Vue3最新的15个常用API

之前我写了一篇博客介绍了Vue3的新特性,简单了解了一下Vue3都有哪些特色,并且在文末带大家稍微体验了一下Vue3中 Compsition API 的简单使用上一篇文章地址:紧跟尤大的脚步提前体验Vue3新特性,你不会还没了解过Vue3吧因为这个月的…

超级马里奥代码_任天堂的源码泄露,揭示超级马里奥的前世之生

被黑客盯上的任天堂任天堂遭到了史上最大规模的黑客攻击,Wii 完整源码、设计以及《宝可梦》多部作品的信息遭到泄露,而此次泄露事件的后续影响似乎也爆发了出来。《马里奥赛车》和《超级马里奥世界2》(耀西岛)的早期原型视频,以及《超级马里奥…

漫画 | 前端发展史的江湖恩怨情仇

时间总是过得很快, 似乎快得让人忘记了昨天,前端WEB领域的发展更是如此,转眼间已是近30年,时光荏苒,初心不变,在一代又一代前端人的努力下,前端已经是互联网不可或缺的一部分。然而很多前端打工…

10 个你可能还不知道 VS Code 使用技巧

经常帮一些同学 One-on-One 地解决问题,在看部分同学使用 VS Code 的时候,有些蹩脚,实际上一些有用的技巧能够提高我们的日常工作效率。NO.1一、重构代码VS Code 提供了一些快速重构代码的操作,例如:将一整段代码提取为…

构建安全的Xml Web Service系列之如何察看SoapMessage

上一篇文章地址:构建安全的Xml Web Service系列一之初探使用Soap头 (5-22 12:53) 要分析Xml Web Service的安全性,首先要解决的问题是我们能了解和清楚Soap消息的格式和内容,如果获得不了SoapMessage,分析如何能构建安全Xml w…

前端高效开发必备的 js 库梳理

之前有很多人问学好前端需要学习哪些 js 库, 主流框架应该学 vue 还是 react ? 针对这些问题, 笔者来说说自己的看法和学习总结.首先我觉得在学习任何知识之前必须要有一个明确的学习目标, 知道自己为什么要学它, 而不是看网上说的一股脑的给你灌输各种知识, 让你学习各种库, …

交叉报表crosstab隐藏列名显示_SAP软件 报表查询之 输出格式设置

SAP不仅是功能强大、逻辑严谨的ERP软件,还提供了强大的报表查询功能。SAP的ALV报表展示功能是SAP的一大特点,实现了类似于EXCEL的功能。使用好ALV报表功能可以方便用户从SAP中取到想要的数据,尤其是财务用户。大家在使用SAP报表时&#xff0c…

seo每日一贴_白杨SEO:我看ZAC的外贸SEO应该怎样做?(策略篇)

前言:这是白杨SEO公众号更新第64篇。本该写写头条SEO啥的,最近在师徒培训讲站内SEO时有旁听同学提到后面讲讲谷歌SEO怎么样,因为谷歌全世界搜索市场占有率,所以外贸SEO最主要还是做谷歌SEO。以白杨特意又去了前辈ZAC的SEO每日一贴…

[转]网页栅格系统研究(2):蛋糕的切法

[出自]http://lifesinger.org/blog/2008/10/grid-system-2/首先澄清一个应用场景问题。研究(1)中指出,对于结构复杂的网站,不少设计师们喜欢采用960固定宽度布局。但要注意的是,960并不是万能钥匙,大部分网…

Vue3响应式原理

关注若川视野,回复"pdf" 领取资料,回复"加群",可加群长期交流学习本文结构- 关于Vue3- Vue2响应式原理回顾- Vue3响应式方案- Vue3响应式原理- 手写mini版Vue3响应式本文共计:2349字2图预计阅读时间&#xff…

找准切入点,调试看源码,事半功倍

关注若川视野,回复"pdf" 领取资料,回复"加群",可加群长期交流学习最近写了很多源码分析相关的文章,React、Vue 都有,想把我阅读源码的一些心得分享给大家。React:React 架构的演变 - 从…

Android布局大全

Android的界面是有布局和组件协同完成的,布局好比是建筑里的框架,而组件则相当于建筑里的砖瓦。组件按照布局的要求依次排列,就组成了用户所看见的界面。 所有的布局方式都可以归类为ViewGroup的5个类别,即ViewGroup的5个直接子类…

java实现加减乘除运算符随机生成十道题并判断对错_2020年Java面试题(3年的工作总结),最全的知识点总结...

这份Java面试题整整花了三个月的时间来整理,都是自己再工作中总结出来,记住多少就写多少,希望这份资料可以帮助你们,文末有其余部分资料的领取方式.Redis12道面试题1.什么是Redis?答:Remote Dictionary Ser…

.NET 中的泛型 101

1.1.1 摘要 图1 C# 泛型介绍 在接触泛型之前,我们编程一般都是使用具体类型(char, int, string等)或自定义类型来定义我们变量,如果我们有一个功能很强的接口,而且我们想把它提取或重构成一个通用的接口,使…

年底了,给想进阶的的前端朋友一些福利

2020 年,很多朋友都经历了一段比较艰难的求职季。年末,“就业寒冬”迎来了一丝暖阳,很多中大型互联网公司扩大了未来一年的招聘需求。前不久,字节跳动放出了年末要招 1 万人的消息,腾讯校招规模也将扩张至 5000 人&…