联邦学习Federated Learning
- 序言
- FL流程细节
- FL代码实现(Pytorch)
- Reference
序言
手机的数据涉及到个人隐私和安全,如果将客户端的数据上传到服务端,终究是很容易泄漏出用户的信息,何况 用户也不愿意把自己的数据交给服务端进行训练。
所以2016 年,谷歌最早提出了联邦学习(Federated Learning),通过让用户数据不出本地,而在设备本地进行模型训练。 目的是保障大数据交换时的信息安全、保护终端数据和个人数据隐私以及合规的前提下,在多参与方或多计算结点之间开展高效率的机器学习。
所以, 联邦学习本质上是一种分布式机器学习技术,或机器学习框架。
目标:联邦学习的目标是在保证数据隐私安全及合法合规的基础上,实现共同建模,提升AI模型的效果。
哈哈哈哈哈哈,一个公式总结就是: 联邦学习=分布式+加密技术
FL流程细节
A节点为用户设备,数据保存其中,并且不上传到其他节点
B节点是来自不同用户设备的模型参数聚合
C节点可以认为是一个统一的中央服务器(或集群),用来统一更新模型
大体流程是:
首先C初始化一个空模型
C将当前的模型参数传输到A
A设备中根据数据,计算模型参数(或梯度等)
A设备将更新后的模型参数(或梯度)上传
B对来自不同设备的模型参数(或梯度)聚合,例如简单取平均值
C更新根据B聚合后的结果,更新模型
回到第2.步,循环往复
再具体的细节设计到两个重要部分
第一: Client Trainer
第二:Server Trainer
其中FederatedAveraging 用来聚合所有客户端上传的参数,FederatedAveraging算法总共有三个基本的参数
- 参数C(0~1) 控制有多少比例的的客户端参与优化
- 参数E控制每一轮多少轮SGD需要在客户端运行
- 参数B是每一轮的Mini-Batch的数目大小
那么假设总共有K个客户端:
- 每一轮都随机选择出max(CK,1)个的客户端
- 对于每个客户端进行Mini-Batch的大小为B,轮数为E的SGD更新
- 对于参数直接进行加权平均(这里的权重是每个客户端的数据相对大小)
之前有其他研究表明,如何直接对参数空间进行加权平均,特别是Non-Convex的问题,会得到任意坏的结果。
而在FL的论文里,作者们对于这样的问题的处理是,让每一轮的各个客户端的起始参数值相同(也就是前一轮的全局参数值)。这一步使得算法效果大幅度提高。👍
FL代码实现(Pytorch)
Check this Github
Reference
- https://arxiv.org/pdf/1602.05629.pdf
- https://augint.tech/index.php?title=Federated_Learning
- https://daiwk.github.io/posts/dl-federated-learning.html