pytorch resnet50_PyTorch终于能用上谷歌云TPU,推理性能提升4倍,我们该如何薅羊毛?...

晓查 发自 凹非寺
量子位 报道 | 公众号 QbitAI

Facebook在PyTorch开发者大会上正式推出了PyTorch 1.3,并宣布了对谷歌云TPU的全面支持,而且还可以在Colab中调用云TPU。

之前机器学习开发者虽然也能在Colab中使用PyTorch,但是支持云TPU还是第一次,这也意味着你不需要购买昂贵的GPU,可以在云端训练自己的模型。

而且如果你是谷歌云平台(Google Cloud Platform)的新注册用户,还能获得300美元的免费额度。

c1242243c9ec98059515d92bb6becae5.png

现在PyTorch官方已经在Github上给出示例代码,教你如何免费使用谷歌云TPU训练模型,然后在Colab中进行推理。

训练ResNet-50

PyTorch先介绍了在云TPU设备上训练ResNet-50模型的案例。如果你要用云TPU训练其他的图像分类模型,操作方式也是类似的。

在训练之前,我们先要转到控制台创建一个新的虚拟机实例,指定虚拟机的名称和区域。

8ad7372137278bf369658a6949cc8ec4.png

如果要对Resnet50在真实数据上进行训练,需要选择具有最多CPU数量的机器类型。为了获得最佳效果,请选择n1-highmem-96机器类型。

然后选择Debian GNU/Linux 9 Stretch + PyTorch/XLA启动盘。如果打算用ImageNet真实数据训练,需要至少300GB的磁盘大小。如果使用假数据训练,默认磁盘大小只要20GB。

创建TPU

  1. 转到控制台中创建TPU。

  2. 在“Name”中指定TPU Pod的名称。

  3. 在“Zone”中指定云TPU的区域,确保它与之前创建的虚拟机在同一区域中。

  4. 在“ TPU Type”下,选择TPU类型,为了获得最佳效果,请选择v3-8TPU(8个v3)。

  5. 在“ TPU software version”下,选择最新的稳定版本。

  6. 使用默认网络。

  7. 设置IP地址范围,例如10.240.0.0。

官方建议初次运行时使用假数据进行训练,因为fake_data会自动安装在虚拟机中,并且只需更少的时间和资源。你可以使用conda或Docker进行训练。

在fake_data上测试成功后,可以开始尝试用在ImageNet的这样实际数据上进行训练。

用conda训练:

# Fill in your the name of your VM and the zone.
$ gcloud beta compute  ssh "your-VM-name" --zone "your-zone".
(vm)$ export TPU_IP_ADDRESS=your-ip-address
(vm)$ export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
(vm)$ ulimit -n 10240
(vm)$ conda activate torch-xla-0.5
(torch-xla-0.5)$ python /usr/share/torch-xla-0.5/pytorch/xla/test/test_train_imagenet.py --datadir=~/imagenet --model=resnet50 --num_epochs=90 --num_workers=64 --batch_size=128 --log_steps=200

用Docker训练:

# Fill in your the name of your VM and the zone.
$ gcloud beta compute ssh "your-VM-name" --zone "your-zone".
(vm)$ export TPU_IP_ADDRESS=your-ip-address
(vm)$ docker run --shm-size 128G -v ~/imagenet:/tmp/imagenet -e XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470" gcr.io/tpu-pytorch/xla:r0.5 python3 pytorch/xla/test/test_train_imagenet.py --model=resnet50 --num_epochs=90 --num_workers=64 --log_steps=200 --datadir=/tmp/imagenet

在n1-highmem-96的虚拟机上选用完整v3-8 TPU进行训练,第一个epoch通常需要约20分钟,而随后的epoch通常需要约11分钟。该模型在90个epoch后达到约76%的top-1准确率。

为了避免谷歌云后续进行计费,在训练完成后请记得删除虚拟机和TPU。

性能比GPU提升4倍

训练完成后,我们就可以在Colab中导入自己的模型了。

打开notebook文件,在菜单栏的Runtime中选择Change runtime type,将硬件加速器的类型改成TPU。

2359fe1fec9dc15c8d1c35d2778f1e03.png

先运行下面的代码单元格,确保可以访问Colab上的TPU:

import os
assert os.environ[‘COLAB_TPU_ADDR’], ‘Make sure to select TPU from Edit > Notebook settings > Hardware accelerator’

然后在Colab中安装兼容PyTorch/TPU组件:

DIST_BUCKET="gs://tpu-pytorch/wheels"
TORCH_WHEEL="torch-1.15-cp36-cp36m-linux_x86_64.whl"
TORCH_XLA_WHEEL="torch_xla-1.15-cp36-cp36m-linux_x86_64.whl"
TORCHVISION_WHEEL="torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl"

# Install Colab TPU compat PyTorch/TPU wheels and dependencies
!pip uninstall -y torch torchvision
!gsutil cp "$DIST_BUCKET/$TORCH_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCH_XLA_WHEEL" .
!gsutil cp "$DIST_BUCKET/$TORCHVISION_WHEEL" .
!pip install "$TORCH_WHEEL"
!pip install "$TORCH_XLA_WHEEL"
!pip install "$TORCHVISION_WHEEL"
!sudo apt-get install libomp5

接下来就可以导入你要训练好的模型和需要进行推理的图片了。

在PyTorch上使用TPU对性能的提升到底有多明显呢?官方选用了v2-8的一个核心,即1/8 TPU的情形,与使用英伟达Tesla K80 GPU进行对比,实测显示推理时间大大缩短,性能约有4倍左右的提升。

1f6a6c16d5fd8f72f6eb6fbf343bd823.png

GitHub地址:
https://github.com/pytorch/xla/tree/master/contrib/colab

作者系网易新闻·网易号“各有态度”签约作者

大会启幕!预见智能科技新未来

量子位MEET 2020智能未来大会启幕,将携手优秀AI企业、杰出科研人员呈现一场高质量行业盛会!详情可点击图片:

d5b5de3ec67071747bf368de9be673db.png

榜单征集!三大奖项,锁定AI Top玩家

2019中国人工智能年度评选启幕,将评选领航企业、商业突破人物、最具创新力产品3大奖项,并于MEET 2020大会揭榜,欢迎优秀的AI公司扫码报名!

4fa41794687f0c53f0c9a408eef94be6.png01a0a8eae01261532e4b5a127d87adf5.png

量子位 QbitAI · 头条号签约作者

վ'ᴗ' ի 追踪AI技术和产品新动态

喜欢就点「好看」吧 ! 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/455710.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

x264里的2pass指的是什么意思? x264源代码分析2.encode()

A:x264里的2pass指的是什么意思?另外stat是什么意思, 比如有个参数--stats <string> Filename for 2 pass stats [/"%s/"]/n", defaults->rc.psz_stat_out );stats在这是什么意思? 2pass是2次编码的意思&#xff0c;stats是统计文档的名称&a…

项目启动居然如此重要!

项目的启动阶段比较短&#xff0c;项目经理往往容易忽视这个阶段&#xff0c;但是&#xff0c;项目的启动却具有着重要的意义。 定基调&#xff1a; 基调包括工作的节奏、团队氛围和沟通风格等。 一首歌的第一句决定了这首歌的基调&#xff0c;如何唱好这第一句就是项目启动所要…

mysql数据库导入导出文件sql文件

window下 1.导出整个数据库 mysqldump -u 用户名 -p 数据库名 > 导出的文件名 mysqldump -u dbuser -p dbname > dbname.sql 2.导出一个表 mysqldump -u 用户名 -p 数据库名 表名> 导出的文件名 mysqldump -u dbuser -p dbname users> dbname_users.sql 3.导出…

Android Studio主题设置、颜色背景配置

2019独角兽企业重金招聘Python工程师标准>>> color-themes 效果展示 打开http://color-themes.com/有很多样式可供选择 1. Monokai Sublime Text 3(color theme) 2. Solarized Light (color theme) 3. Visual Studio 2015 Dark(color theme) 导入方式 下载主…

JavaScript中的函数

js函数 *第一种是使用function语句定义函数 function abc(){alert(abc); }*第二种是在表达式中定义函数 var 函数名 function\(参数1&#xff0c;参数2&#xff0c;…\){函数体};//例如&#xff1a;//定义var add function\(a,b\){return ab;}//调用函数document.write\(a…

x264源代码分析1。fread()

相关说明:1. 使用版本: x264-cvs-2004-05-11 2. 这次的分析基本上已经将代码中最难理解的部分做了阐释,对代码的主线也做了剖析,如果这个主线理解了,就容易设置几个区间,进行分工阅读,将各个区间击破了. 3. 需要学习的知识:a) 编码器的工作流程.b) H.264的码流结构,像x264_sp…

在centos下安装pycrypto报错 RuntimeError: autoconf error

解决&#xff1a;yum -y install gcc File "/usr/lib64/python3.6/distutils/dist.py", line 974, in run_command cmd_obj.run() File "/usr/lib64/python3.6/distutils/command/build.py", line 135, in run self.run_command(cm…

Java多线程实现异步调用

在Java平台,实现异步调用的角色有如下三个角色&#xff1a;调用者、 提货单 、真实数据&#xff0c;一个调用者在调用耗时操作,不能立即返回数据时,先返回一个提货单 .然后在过一断时间后凭提货单来获取真正的数据.去蛋糕店买蛋糕&#xff0c;不需要等蛋糕做出来(假设现做要很长…

sql server 2008 r2卸载重装_免费下载:Intouch软件、Windows操作系统、SQL数据库,VB6.0、C#...

为大家整理了常用的Windows操作系统和安装软件&#xff0c;基本上都是经过我们项目测试OK的版本&#xff0c;以后项目调试就齐全了&#xff0c;不用再“东奔西走”&#xff0c;“小鹿乱撞”了。整理不易&#xff0c;若对您有帮助请关注并转发&#xff0c;以便帮助到更多的人。I…

Android ToolBar 使用完全解析

ToolBar简介 ToolBar是Android 5.0推出的一个新的导航控件用于取代之前的ActionBar&#xff0c;由于其高度的可定制性、灵活性、具有Material Design风格等优点&#xff0c;越来越多的应用也用上了ToolBar&#xff0c;比如常用的知乎软件其顶部导航栏正是使用ToolBar。官方考虑…

【零散积累】传输文件(sz/rz/scp命令)

来自wiki迁移页面路径&#xff1a;刘旺的主页 / 个人零散积累 / 01> 传输文件&#xff08;sz/rz/scp命令&#xff09; 工作中的传输文件会出现在linux之间&#xff0c;或者linux与windows之间。 一、怎么实现linux与windows之间的文件传输&#xff1f; 1.sz和rz是什么 s…

x264_macroblock_cache_load()

功能:完成将已编码数据参数和待编码数据装入到h->mb.cache中,下图是BUF中存储的数据在以MB为单位的时候的存储顺序 x264_macroblock_cache_load( h, i_mb_x, i_mb_y );//是把当前宏块的up宏块和left宏块的intra4x4_pred_mode&#xff0c;non_zero_count加载进来&#xff0c…

U(优)盘安装FreeBSD-9.0+GNOME_lite桌面

贴图在我的主页&#xff1a;http://hi.baidu.com/daodej/item/26313f4fc3db51ef1f19bcc6 修订于&#xff1a;2012/07/04 标题&#xff1a;U(优)盘安装FreeBSD-9.0GNOME_lite桌面&#xff0c;boot0启动XP(Windows)、FreeBSD、Ubuntu(Linux)三系统 【黑括号表示说明&#xff0c;中…

【零散积累】 vim常用操作

类型 操作 含义 删除 dd 删除游标所在的一整行(常用) ndd n为数字。删除光标所在的向下n行&#xff0c;例如20dd则是删除光标所在的向下20行 d1G 删除光标所在到第一行的所有数据 dG 删除光标所在到最后一行的所有数据 d$ 删除光标所在处&#xff0c;到该…

生活中常见物联网实例_物联网网关常见问题解答(一)

1.为什么物联网解决方案需要网关&#xff1f;物联网网关弥合了设备&#xff0c;传感器&#xff0c;设备&#xff0c;系统和云之间的通信鸿沟。通过系统地连接云&#xff0c;物联网网关提供了本地处理和存储&#xff0c;并具有基于传感器输入的数据自主控制现场设备的功能。物联…

predict_16x16[i_mode]( p_dst, i_stride )lowres

h->predict_16x16[i_mode]( p_dst, i_stride ); 计算对应预测模式时的预测采样值。输出放到dst指向的数组中。Pred0ct_16x16是7个元素指向的数组&#xff0c;数组的每个元素是一个指向函数的指针变量&#xff0c;在x264_predict_16x16_init函数初始这个指针数组。7个元素分…

【零散积累】shell脚本学习

来自wiki迁移页面路径&#xff1a;刘旺的主页 / 个人零散积累 / 03> shell脚本学习 case Shell case语句&#xff08;多分支条件判断&#xff09; $( ) Linux—shell中$(( ))、$( )、与${ }的区别 - chengd - 博客园 在bash中&#xff0c;$( )与 &#xff08;反引号&…

mysql 表锁-解锁

遇到问题“”用工具navicat打开一张表的时候&#xff0c;有的时候会发现这张表怎么打不开&#xff0c;关了navicat工具&#xff0c;再打开&#xff0c;也是同样的状态。查看表锁&#xff1a;show OPEN TABLES where In_use > 0;查看是否是表锁住了。-- 查看进程号 show proc…

alsa 测试 linux_Electron 构建步骤 (Linux)

遵循下面的引导&#xff0c;在 Linux 上构建 Electron .PrerequisitesPython 2.7.x. 一些发行版如 CentOS 仍然使用 Python 2.6.x &#xff0c;所以或许需要 check 你的 Python 版本&#xff0c;使用 python -V.Node.js v0.12.x. 有很多方法来安装 Node. 可以从 Node.js下载原文…

JavaScript中的数学对象Math

js数学对象Math //四舍五入 var res Math.round(5.921);//获取最大值 var res Math.max(10,23,523,43,65,46,32,32);//获取最小值 var res Math.min(12312,324,32,42,3,23,412,4332,21,3,-1);//获取绝对值 var res Math.abs(-100);//退一取整 var res Math.floor(1.9);//…