MLX vs MPS vs CUDA:苹果新机器学习框架的基准测试

如果你是一个Mac用户和一个深度学习爱好者,你可能希望在某些时候Mac可以处理一些重型模型。苹果刚刚发布了MLX,一个在苹果芯片上高效运行机器学习模型的框架。

最近在PyTorch 1.12中引入MPS后端已经是一个大胆的步骤,但随着MLX的宣布,苹果还想在开源深度学习方面有更大的发展。

在本文中,我们将对这些新方法进行测试,在三种不同的Apple Silicon芯片和两个支持cuda的gpu上和传统CPU后端进行基准测试。

这里把基准测试集中在图卷积网络(GCN)模型上。这个模型主要由线性层组成,所以对于其他的模型也应该得到类似的结果。

创造环境

要为MLX构建环境,我们必须指定是使用i386还是arm架构。使用conda,可以使用:

 CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forgeconda activate mlx

如果检查你的env是否实际使用了arm,下面命令的输出应该是arm,而不是i386(因为我们用的Apple Silicon):

 python -c "import platform; print(platform.processor())"

然后就是使用pip安装MLX:

 pip install mlx

GCN模型

GCN模型是图神经网络(GNN)的一种,它使用邻接矩阵(表示图结构)和节点特征。它通过收集邻近节点的信息来计算节点嵌入。每个节点获得其邻居特征的平均值。这种平均是通过将节点特征与标准化邻接矩阵相乘来完成的,并根据节点度进行调整。为了学习这个过程,特征首先通过线性层投射到嵌入空间中。

我们将使用MLX实现一个GCN层和一个GCN模型:

 import mlx.nn as nnclass GCNLayer(nn.Module):def __init__(self, in_features, out_features, bias=True):super(GCNLayer, self).__init__()self.linear = nn.Linear(in_features, out_features, bias)def __call__(self, x, adj):x = self.linear(x)return adj @ xclass GCN(nn.Module):def __init__(self, x_dim, h_dim, out_dim, nb_layers=2, dropout=0.5, bias=True):super(GCN, self).__init__()layer_sizes = [x_dim] + [h_dim] * nb_layers + [out_dim]self.gcn_layers = [GCNLayer(in_dim, out_dim, bias)for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])]self.dropout = nn.Dropout(p=dropout)def __call__(self, x, adj):for layer in self.gcn_layers[:-1]:x = nn.relu(layer(x, adj))x = self.dropout(x)x = self.gcn_layers[-1](x, adj)return x

可以看到,mlx的模型开发方式与tf2基本一样,都是调用

__call__

进行前向传播,其实torch也一样,只不过它自定义了一个forward函数。

下面就是训练

 gcn = GCN(x_dim=x.shape[-1],h_dim=args.hidden_dim,out_dim=args.nb_classes,nb_layers=args.nb_layers,dropout=args.dropout,bias=args.bias,)mx.eval(gcn.parameters())optimizer = optim.Adam(learning_rate=args.lr)loss_and_grad_fn = nn.value_and_grad(gcn, forward_fn)# Training loopfor epoch in range(args.epochs):# Loss(loss, y_hat), grads = loss_and_grad_fn(gcn, x, adj, y, train_mask, args.weight_decay)optimizer.update(gcn, grads)mx.eval(gcn.parameters(), optimizer.state)# Validationval_loss = loss_fn(y_hat[val_mask], y[val_mask])val_acc = eval_fn(y_hat[val_mask], y[val_mask])

在MLX中,计算是惰性的,这意味着eval()通常用于在更新后实际计算新的模型参数。而另一个关键函数是nn.value_and_grad(),它生成一个计算参数损失的函数。第一个参数是保存当前参数的模型,第二个参数是用于前向传递和损失计算的可调用函数。它返回的函数接受与forward函数相同的参数(在本例中为forward_fn)。我们可以这样定义这个函数:

 def forward_fn(gcn, x, adj, y, train_mask, weight_decay):y_hat = gcn(x, adj)loss = loss_fn(y_hat[train_mask], y[train_mask], weight_decay, gcn.parameters())return loss, y_hat

它仅仅包括计算前向传递和计算损失。Loss_fn()和eval_fn()定义如下:

 def loss_fn(y_hat, y, weight_decay=0.0, parameters=None):l = mx.mean(nn.losses.cross_entropy(y_hat, y))if weight_decay != 0.0:assert parameters != None, "Model parameters missing for L2 reg."l2_reg = sum(mx.sum(p[1] ** 2) for p in tree_flatten(parameters)).sqrt()return l + weight_decay * l2_regreturn ldef eval_fn(x, y):return mx.mean(mx.argmax(x, axis=1) == y)

损失函数是计算预测和标签之间的交叉熵,并包括L2正则化。由于L2正则化还不是内置特性,需要手动实现。

本文的完整代码:https://github.com/TristanBilot/mlx-GCN

可以看到除了一些细节函数调用的差别,基本的训练流程与pytorch和tf都很类似,但是这里的一个很好的事情是消除了显式地将对象分配给特定设备的需要,就像我们在PyTorch中经常使用.cuda()和.to(device)那样。这是因为苹果硅芯片的统一内存架构,所有变量共存于同一空间,也就是说消除了CPU和GPU之间缓慢的数据传输,这样也可以保证不会再出现与设备不匹配相关的烦人的运行时错误。

基准测试

我们将使用MLX与MPS, CPU和GPU设备进行比较。我们的测试平台是一个2层GCN模型,应用于Cora数据集,其中包括2708个节点和5429条边。

对于MLX, MPS和CPU测试,我们对M1 Pro, M2 Ultra和M3 Max进行基准测试。在两款NVIDIA V100 PCIe和V100 NVLINK上进行测试

MPS:比M1 Pro的CPU快2倍以上,在其他两个芯片上,与CPU相比有30-50%的改进。

MLX:比M1 Pro上的MPS快2.34倍。与MPS相比,M2 Ultra的性能提高了24%。在M3 Pro上MPS和MLX之间没有真正的改进。

CUDA V100 PCIe & NVLINK:只有23%和34%的速度比M3 Max与MLX,这里的原因可能是因为我们的模型比较小,所以发挥不出V100和NVLINK的优势(NVLINK主要GPU之间的数据传输大的情况下会有提高)。这也说明了苹果的统一内存架构的确可以消除CPU和GPU之间缓慢的数据传输。

总结

与CPU和MPS相比,MLX可以说是非常大的金币,在小数据量的情况下它甚至接近特斯拉V100的性能。也就是说我们可以使用MLX跑一些不是那么大的模型,比如一些表格数据。

从上面的基准测试也可以看到,现在可以利用苹果芯片的全部力量在本地运行深度学习模型(我一直认为MPS还没发挥苹果的优势,这回MPS已经证明了这一点)。

MLX刚刚发布就已经取得了惊人的影响力,并展示了巨大的潜力。相信未来几年开源社区的进一步增强,可以期待在不久的将来更强大的苹果芯片,将MLX的性能提升到一个全新的水平。

另外也说明了MPS(虽然也发布不久)还是有巨大的发展空间的,毕竟切换框架是一件很麻烦的事情,如果MPS能达到MLX 80%或者90%的速度,我想不会有人去换框架的。

最后说到框架,现在已经有了Pytorch,TF,JAX,现在又多了一个MLX。各种设备、各种后端包括:TPU(pytorch使用的XLA),CUDA,ROCM,现在又多了一个MPS。

https://avoid.overfit.cn/post/eb87d12f29eb4665adb43ad59fd3d64f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/241197.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在Excel中,如何简单快速地删除重复项,这里提供详细步骤

当你在Microsoft Excel中使用电子表格时,意外地复制了行,或者如果你正在制作其他几个电子表格的合成电子表格,你将遇到需要删除的重复行。这可能是一项非常无脑、重复、耗时的任务,但有几个技巧可以让它变得更简单。 删除重复项 …

Android Canvas画布saveLayer与对应restoreToCount,Kotlin

Android Canvas画布saveLayer与对应restoreToCount,Kotlin private fun mydraw() {val originBmp BitmapFactory.decodeResource(resources, R.mipmap.pic).copy(Bitmap.Config.ARGB_8888, true)val newBmp Bitmap.createBitmap(originBmp.width, originBmp.heigh…

【Win10安装Qt6.3】安装教程_保姆级

前言 Windows系统安装Qt4及Qt5.12之前版本和安装Qt.12之后及Qt6方法是不同的 ;因为之前的版本提供的有安装包,直接一路点击Next就Ok了。但Qt5.12版本之后,Qt公司就不再提供安装包了,不论是社区版,专业版等&#xff0c…

并发控制工具类CountDownLatch、CyclicBarrier、Semaphore

并发控制工具类CountDownLatch、CyclicBarrier、Semaphore 1.CountDownLatch 可以使一个或多个线程等待其他线程各自执行完毕后再执行。 CountDownLatch 是多线程控制的一种工具,它被称为 门阀、 计数器或者闭锁。这个工具经常用来用来协调多个线程之间的同步&…

项目从vue2 升级vue3,项目大迁移 ,UI组件库更换

目录 背景描述 开发准备 第一步:升级环境 第二步:划分功能迁移顺序 第三步:详细了解需要迁移的业务页面 第四步:项目的一些配置的准备 详细开发流程 总结/分析: 背景描述 之前的版本:vue 2.6.8 i…

【PHY6222】绑定详解

1.函数详解 bStatus_t GAPBondMgr_SetParameter( uint16 param, uint8 len, void* pValue ) 设置绑定参数。 bStatus_t GAPBondMgr_GetParameter( uint16 param, void* pValue ) 获取绑定参数。 param: GAPBOND_PAIRING_MODE,配对模式,…

【postgres】8、Range 类型

文章目录 8.17 Range 类型8.17.1 内置类型8.17.2 示例8.17.3 开闭区间8.17.4 无穷区间 https://www.postgresql.org/docs/current/rangetypes.html 8.17 Range 类型 Range 类型,可以描述一个数据区间,有明确的子类型,而且子类型应该能被排序…

计算机网络——数据链路层(三)

前言: 前面我们已经对计算机网络的物理层有了一个大概的了解,今天我们学习的是物理层服务的上一层数据链路层,位于物理层和网络层之间。数据链路层在物理层提供的服务的基础上向网络层提供服务,其最基本的服务是将源自物理层来的数据可靠地传…

Mac使用Vmware Fusion虚拟机配置静态ip地址

一、设置虚拟机的网络为NAT 二、修改虚拟机的网络适配器网络 1、查看虚拟机的网卡 cd /etc/sysconfig/network-scripts#有些系统显示的是ens33,ens160等等 #不同的系统和版本,都会有所不同 #Centos8中默认是ens160,在RedHat/Centos7则为ens33 2、查看网…

Java语法---使用sort进行排序

目录 一、升序 二、降序 (1)类实现接口 (2)匿名内部类 三、自定义排序规则 四、集合中的sort排序 (1)升序 (2)降序 (3)自定义排序 一、升序 升序排…

C++内存管理和模板初阶

C/C内存分布 请看代码: int globalVar 1; static int staticGlobalVar 1; void Test() {static int staticVar 1;int localVar 1;int num1[10] { 1, 2, 3, 4 };char char2[] "abcd";const char* pChar3 "abcd";int* ptr1 (int*)mallo…

OpenHarmony之内核层解析~

OpenHarmony简介 技术架构 OpenHarmony整体遵从分层设计,从下向上依次为:内核层、系统服务层、框架层和应用层。系统功能按照“系统 > 子系统 > 组件”逐级展开,在多设备部署场景下,支持根据实际需求裁剪某些非必要的组件…

一周工作问题总结(2023.12.18-2023.12.22)

一周工作问题总结 1. 接口调用频率2. 汉字在数据库中占用字节问题3. Map在循环中修改自己的key与value4. Group BY5.递归同步数据6.代码移动Idea飘红 1. 接口调用频率 供应商给的接口可以每秒调用5-10次,那么我为了保险每秒调用5次,为了防止过度调用接口…

BUG记录——drawio出现“非绘图文件 (error on line 7355 at column 83: AttValue: ‘ expected)”

BUG现象 drawio出现“非绘图文件 (error on line 7355 at column 83: AttValue: ’ expected)”,如下图: 解决办法 这只是我自己摸索到的解决办法并不一定适用于所以人,对我是适用的。 首先用记事本打开损坏的drawio文件,如下 …

mathtype公式章节编号

1. word每章标题后插入章节符 如果插入后显示章节符,需要进行隐藏 开始->样式->MTEquationSection->修改样式->字体,勾选隐藏 2. 设置mathtype公式编号格式 插入编号->格式化->设置格式

什么是动态代理?

目录 一、为什么需要代理? 二、代理长什么样? 三、Java通过什么来保证代理的样子? 四、动态代理实现案例 五、动态代理在SpringBoot中的应用 导入依赖 数据库表设计 OperateLogEntity实体类 OperateLog枚举 RecordLog注解 上下文相…

SpringMVC基础知识(持续更新中~)

笔记: https://gitee.com/zhengguangqq/ssm-md/blob/master/ssm%20md%E6%A0%BC%E5%BC%8F%E7%AC%94%E8%AE%B0/%E4%B8%89%E3%80%81SpringMVC.md 细节补充:

深度学习 | 梯度下降算法及其变体

一、最优化与深度学习 1.1、训练误差与泛化误差 1.2、经验风险 1.3、优化中的挑战 1.3.1、局部最小值 1.3.2、 鞍点 经常是由于模型复杂度过高或者训练样本数据过少造成的 —— Overfitting 1.3.3、悬崖 1.3.4、长期依赖问题 二、损失函数 2.1、损失函数的起源 损失函数(loss…

041_小驰私房菜_MTK平台添加支持通过原生Camera API接口调用UsbCamera

平台:MTK 问题:通过调用Android Camera API去调用UsbCamera,需要做哪些修改? Google官方文档,关于usbcamera的支持: 外接 USB 摄像头 | Android 开源项目 | Android Open Source Project 相关修改内容如下: 一、MTK平台支持通过标准接口打开USB Camera 1)device相…

每日一题——轮转数组

1. 题目描述 给定一个整数数组nums,将数组中的元素向右轮转k个位置,其中k是非负数。 示例1: 输入:nums [1,2,3,4,5,6,7],k 3 输出:[5,6,7,1,2,3,4] 解释: 向右轮转 1步:[7,1,2,3,4,5,6] 向右…