一、引言
拆分学习是 2018 年由 MIT 最先提出的分布式算法。本文结合该领域的相关英文文献,介绍水平拆分学习的基本方法,同时还将对比拆分模型与中心化模型、联邦模型在不同条件下模型效率和准确性。拆分学习作为主流的隐私计算学习范式之一,也被普遍应用于构建隐私保护机器学习算法。
二、基本方法
1.1 核心思想
拆分学习将 NN 模型拆分成两部分,client 利用本地数据计算底层模型,得到隐层并传输给 server,server 继续计算上层模型,如图 1 所示【1】。
图1 拆分学习示意图
针对 client 数据水平切分场景下的拆分学习方法,主要分为三种:Centralized 拆分学习、P2P 拆分学习和 U-shape 拆分学习。
1.2 Centralized 拆分学习
图2 Centralized 拆分学习模型
(1)算法
如图 2 所示【2】,Alice 为 client, Bob 为 server。client 和 server 模型首先进行初始化。
-
client i 从 server 获取 client 端的密态模型参数,解密,更新 client 模型。
-
client i 进行前向传播,计算得到隐层,并将隐层 h 和真实标签 y 传给 server
-
server 得到 client i 的隐层 h 和 y,继续前向传播,得到 label 预测值 y_pred,进而得到 Loss(y, y_pred)。
-
server 进行模型的后向传播,更新 server 端的模型,进而得到 Loss 关于隐层的梯度 G,并将 G 传给 client。
-
client 利用梯度 G 继续后向传播,更新 client 端本地模型,client 将本地模型加密传给 server。
-
剩余参与训练的 client 依次进行步上述步骤。
(2)特点
-
client 异步更新,无法同步更新;
-
client 每次训练前需要从 server 获取密态的 client 模型;
-
server 得到样本的 label 和密态 client 模型(有隐私泄漏的风险)。
1.3 P2P 拆分学习
图3 Peer to peer 拆分学习
(1)算法
如图3所示【2】,client i 进行前向传播,计算得到隐层,并将隐层 h 和真实标签 y 传给server。
-
server 得到 client i的隐层 h 和 y,继续前向传播,得到 label 预测值 y_pred,进而得到 Loss(y, y_pred);
-
server 进行模型的后向传播,更新 server 端的模型,进而得到 Loss 关于隐层的梯度G,并将 G 传给 client;
-
client 利用梯度 G 继续后向传播,更新 client 端本地模型,client 将本地模型传给下一个 client;
-
下一个 client 依次进行上述步骤。
(2)特点
-
client 依次进行训练更新。
-
server 得到样本的 label。
-
lient 每次训练前需要从上一个 client 获取最新的 client 模型(client 掉线问题)。
1.4 U-shape 拆分学习
图4 U-shape 拆分学习
(1)算法
如图 4 所示【1】,模型依次拆成三部分:submodel-1,submodel-2(大部分计算),submodel-3(loss computing),其中 submodel-1 和 submodel-3 在 client 端进行,submodel-2 在 server 端进行。以 U-shape centralized 拆分学习为例:
-
client i 从 server 获取 client 端的密态 submodel-1 和 submodel-3 的模型参数,解密,更新 client 本地模型。
-
client i 进行前向传播,计算得到隐层,并将隐层 h1 传给 server。
-
server 得到 client i 的隐层 h,继续 submodel-2 前向传播,得到隐层 h2,传给 client。
-
client 得到 h1,继续 submodel-3 的前向传播,得到 y_pred,结合 client 的真实 label y 计算得到 loss。
-
client 和 server 进行模型的后向传播,更新模型。
-
client 将本地模型 submodel-1 和 submodel-3 加密传给 server。
-
剩余参与训练的 client 依次进行上述步骤。
(2)特点
- 相比于前两个模型,server 无法得到样本的 label。
三、实验结果
3.1 拆分学习 VS 单机模型
论文【2】中对比了拆分学习和单机模型的 Accuracy,其中在拆分学习中共有 10 个 clients,得到如下表所示的实验结果。
实验结论:拆分学习可以对齐单机模型的 Accuracy【2】。
3.2 拆分学习VS联邦学习
论文【2】中对比了相同 client-side flops 和 communication cost 下拆分学习和联邦学习的 performance。
论文【3】中对比了多 clients 条件和 Non-IID 数据分布下的拆分学习和联邦学习的 performance。
(1)Performance with the same client-side flops
结论:相同计算量的情况下,拆分学习的收敛速度及 Accuracy 优于联邦学习和 Large scale SGD。
(2)Performance with the same communication cost
结论:相同通信量的情况下,拆分学习收敛速度及 Accuracy 优于联邦学习和 Large scale SGD。
(3)Performance with the different clients’ number
结论:当 clients 数量变多时,模型性能有明显的震荡。
(4)Performance in the Non-IID setting
结论:拆分学习在 Non-IID 下性能比联邦学习差,甚至不收敛。
四、Reference
【1】Thapa C, Chamikara M A P, Camtepe S A. Advancements of federated learning towards privacy preservation: from federated learning to split learning[M]//Federated Learning Systems. Springer, Cham, 2021: 79-109.
【2】Gupta O, Raskar R. Distributed learning of deep neural network over multiple agents[J]. Journal of Network and Computer Applications, 2018, 116: 1-8.
【3】Gao Y, Kim M, Abuadbba S, et al. End-to-end evaluation of federated learning and split learning for Internet of Things[J]. arXiv preprint arXiv:2003.13376, 2020.