联邦学习FedAvg-基于去中心化数据的深度网络高效通信学习

        随着计算机算力的提升,机器学习作为海量数据的分析处理技术,已经广泛服务于人类社会。 然而,机器学习技术的发展过程中面临两大挑战:一是数据安全难以得到保障,隐私泄露问题亟待解决;二是网络安全隔离和行业隐私,不同行业部门之间存在数据壁垒,导致数据形成“孤岛”无法安全共享,而仅凭各部门独立数据训练的机器学习模型性能无法达到全局最优化。为解决上述问题,谷歌提出了联邦学习(FL,federated learning)技术。

        本文主要对联邦学习的开山之作《Communication-Efficient Learning of Deep Networks from Decentralized Data》 进行重点内容的解读与整理总结。

论文链接:Communication-Efficient Learning of Deep Networks from Decentralized Data

源码实现:https://gitcode.net/mirrors/WHDY/fedavg?utm_source=csdn_github_accelerator 

目录

摘要

1. 介绍

1.1 问题来源

1.2 本文贡献

1.3 联邦学习特性

1.4 联邦优化

1.5 相关工作

1.6 联邦学习框架图

2. 算法介绍

2.1 联邦随机梯度下降(FedSGD)

2.2 联邦平均算法(FedAvg)

3. 实验设计与实现

3.1 模型初始化

3.2 数据集的设置

3.2.1 MNIST数据集

3.2.2 莎士比亚作品集

3.3 实验优化

3.3.1 增加并行性

3.3.2 增加客户端计算量

 3.4 探究客户端数据集的过度优化

3.5 CIFAR实验

3.6 大规模LSTM实验

4. 总结展望

 摘要

现代移动设备拥有大量的适合模型学习的数据,基于这些数据训练得到的模型可以极大地提升用户体验。例如,语言模型能提升语音设别的准确率和文本输入的效率,图像模型能自动筛选好的照片。然而,移动设备拥有的丰富的数据经常具有关于用户的敏感的隐私信息且多个移动设备所存储的数据总量很大,这样一来,不适合将各个移动设备的数据上传到数据中心,然后使用传统的方法进行模型训练。作者提出了一个替代方法,这种方法可以基于分布在各个设备上的数据(无需上传到数据中心),然后通过局部计算的更新值进行聚合来学习到一个共享模型。作者定义这种非中心化方法为“联邦学习”。作者针对深度网络的联邦学习任务提出了一种实用方法,这种方法在学习过程中多次对模型进行平均。同时,作者使用了五种不同的模型和四个数据集对这种方法进行了实验验证。实验结果表明,这种方法面对不平衡以及非独立同分布的数据,具有较好的鲁棒性。在这种方法中,通信所产生的资源开销是主要的瓶颈,实验结果表明,与同步随机梯度下降相比,该方法的通信轮次减少了10-100倍。

1 介绍

1.1 问题来源

        移动设备中有大量数据适合机器学习任务,利用这些数据反过来可以改善用户体验。例如图像识别模型可以帮助用户挑选好的照片。但是这些数据具有高度私密性,并且数据量大,所以我们不可能把这些数据拿到云端服务器进行集中训练。论文提出了一种分布式机器学习方法称为联邦学习(Federal Learning),在该框架中,服务器将全局模型下发给客户,客户端利用本地数据集进行训练,并将训练后的权重上传到服务器,从而实现全局模型的更新。

1.2 本文贡献

  • 提出了从分散的存储于各个移动设备的数据中训练模型是一个重要的研究方向
  • 提出了一个简单实用的算法来解决这种在非中心化设置下的学习问题
  • 做了大量实验来评估所提算法

        具体来说,本文介绍了“联邦平均”算法,这种算法融合了客户端上的局部随机梯度下降计算与服务器上的模型平均。作者使用该算法进行了大量实验,结果表明了这种算法对于不平衡且非独立同分布的数据具有很好的鲁棒性,并且使得在非中心存储的数据上进行深度网络训练所需的通信轮次减少了几个数量级。

1.3 联邦学习特性

  • 从多个移动设备中存储的真实数据中进行模型训练比从存储在数据中心的数据中进行模型训练更具优势
  • 由于数据具有隐私,且多个移动设备所存储的数据总量很大,因此不适合将其上传至数据中心再进行模型训练
  • 对于监督学习任务,数据中的标签信息可以从用户与应用程序的交互中推断出来

1.4 联邦优化

        传统分布式学习关注点在于如何将一个大型神经网络训练分布式进行,数据仍然可能是在几个大的训练中心存储。而联邦学习更关注数据本身,利用联邦学习保证了数据不出本地,并根据数据的特点,对学习模型进行改进。相比于典型的分布式优化问题,联邦优化具有几个关键特性:

  • Non-IID:数据的特征和分布在不同参与方间存在差异
  • Unbalanced:一些用户会更多地使用服务或应用程序,导致本地训练数据量存在差
  • Massively distributed:参与优化的用户数>>平均每个用户的数据量
  • Limited communication:无法保证客户端和服务器端的高效通信

 本文重点关注优化任务中非独立同分布和不平衡问题,以及通信受限的临界属性。

注:独立同分布假设(IID)

        非凸神经网络的目标函数:

对于一个机器学习的问题来说,有,即用模型参数w预测实例的损失。

        设有K个client,第k个client的数据点为P_{k},对应的数据集数量为n_{k}=\left | P_{k} \right |上式可写为:

P_{k}上的数据集是随机均匀采样的,称IID设置,此时有:

不成立则称Non-IID。 

1.5 相关工作

        相关工作中,2010年通过迭代平均本地训练的模型来对感知机进行分布式训练,2015年研究了语音识别深度神经网络的分布式训练,在2015论文里研究了使用“软”平均的异步训练方法。这些工作都考虑的是数据中心化背景下的分布式训练,没有考虑具有数据不平衡且非独立同分布特点的联邦学习任务。但是它们提供了一种思路,即通过迭代平均本地训练模型的算法来解决联邦学习的问题。与本文的研究动机相似在这篇论文中讨论了保护设备中的用户数据的隐私的优点。而在这篇论文中,作者关注于训练深度网络,强调隐私的重要性以及通过在每一轮通信中仅共享一部分参数,进而降低通信开销;但是,他们也没有考虑数据的不平衡以及非独立同分布性,并且他们的研究工作缺乏实验评估。

1.6 联邦学习框架图

2 算法介绍

2.1 联邦随机梯度下降(FedSGD)

设置固定的学习率η,对K个客户端的数据计算其损失梯度:

中心服务器聚合每个客户端计算的梯度,以此来更新模型参数:

其中,

2.2 联邦平均算法(FedAvg)

在客户端进行局部模型的更新:

中心服务器对每个客户端更新后的参数进行加权平均:

每个客户端可以独立地更新模型参数多次,然后再将更新好的参数发送给中心服务器进行加权平均:

FedAvg的计算量与三个参数有关:

  • C:每轮训练选择客户端的比例
  • E:每个客户端更新参数的循环次数所设计的一个因子
  • B:客户端更新参数时,每次梯度下降所使用的数据量

对于一个拥有n_{k}个数据样本的客户端,每轮本地参数更新的次数为:

注:FedSGD只是FedAvg的一个特例,即当参数E=1,B=∞时,FedAvg等价于FedSGD。
 
FedSGD和FedAvg的关系示意图:
地址:https://blog.csdn.net/biongbiongdou/article/details/104358321

3 实验设计与实现

3.1 模型初始化

实验设置
  • 数据集:MNIST中600个无重复的独立同分布样本
  • E=20; C=1; B=50; 中心服务器聚合一次
  • 不同模型使用不同/相同的初始化模型,并通过θ对两模型参数进行加权求和
       

研究模型平均对模型效果的影响:

        这里有两种情况,一种是不同模型使用不同的初始化模型;一种是不同模型使用相同的初始化模型。并且可以通过参数控制权重比进行模型的加权求和。

        可看到,采用不同的初始化参数进行模型平均后,平均模型的效果变差,模型性能比两个父模型都差;采用相同的初始化参数进行模型平均后,对模型的平均可以显著的减少整个训练集的损失,模型性能优于两个父模型。

        该结论是用于实现联邦学习的重要支撑,在每一轮训练时,server发布全局模型,使各个client采用相同的参数模型进行训练,可以有效的减少训练集的损失。

3.2 数据集的设置

        初步研究包括两个数据集三个模型族,前两个模型用于识别MNIST数据集,后一个用于实现莎士比亚作品集单词预测。

3.2.1 MNIST数据集

2NN:拥有两个隐藏层,每层200个神经元的多层感知机模型,ReLu激活;

CNN:两个卷积核大小为5X5的卷积层(分别是32通道和64通道,每层后都有一个2X2的最大池化层);

IDD:数据随机打乱分给100个客户端,每个客户端600个样例;

Non-IDD:按数字标签将数据集划分为200个大小为300的碎片,每个客户端两个碎片;

  • 3.2.2 莎士比亚作品集

LSTM:将输入字符嵌入到一个已学习的8维空间中,然后通过两个LSTM层处理嵌入的字符,每层256个节点,最后,第二个LSTM层的输出被发送到每一个字符有一个节点的softmax输出层,使用unroll的80个字符长度进行训练;

Unbalanced-Non-IID:每个角色形成一个客户端,共1146个客户端;

Balanced-IID:直接将数据集划分给1146个客户端;

3.3 实验优化

        在数据中心存储的优化中,通信开销相对较小,计算开销占主导地位。而在联邦优化中,任何一个单一设备所具有的数据量较少,且现代移动设备有相对快的处理器所以这里更关注通信开销因此,我们想要使用额外的计算来减少训练模型所需通信的轮次主要有两个方法,分别是提高并行度以及增加每个客户端的计算量。

3.3.1 增加并行性

固定参数E,对C和B进行讨论。

  •  当B=∞时,增加客户端比例,效果提升的优势较小;
  • 当B=10时,有显著改善,特别是在Non-IID情况下;
  • 在B=10,当C≥0.1时,收敛速度有明显改进,当用户达到一定数量时,收敛增加的速度不再明显。

3.3.2 增加客户端计算量

对于增加每个客户端的计算量,可以通过减小B或者增加E来实现。

  • 每轮增加更多的本地SGD更新可以显著降低通信成本;
  • 对于Unbalanced-Non-IDD的莎士比亚数据减少通信轮数倍数更多,推测可能某些客户端有相对较大的本地数据集,使得增加本地训练更有价值;

 将上述实验结果用折线图的形式展示,这里蓝色线表示的是联邦随机梯度下降的结果:

  • FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益; 

 3.4 探究客户端数据集的过度优化

        在E=5以及E=25的设置下,对于大的本地更新次数而言,联邦平均的训练损失会停滞或发散;因此在实际应用时,对于一些模型,在训练后期减少本地训练周期将有助于收敛。 

3.5 CIFAR实验

在CTFAR数据集上进行实验,模型是TensorFlow教程中的模型包括两个卷积层,两个全连接层和一个线性传输层,大约10^6个参数。下表给出了baselineSGD、FedSGD和FedAvg达到三种不同精度目标的通信轮数。

不同学习率下FedSGD和FedAvg的曲线:

3.6 大规模LSTM实验

 为了证明我们的方法对于解决实际问题的有效性,我们进行了一项大规模单词预测任务。

训练集包含来自大型社交网络的100万个公共帖子。我们根据作者对帖子进行分组,总共有超过50个客户端。我们将每个客户的数据集限制为最多5000个单词。模型是一个256节点的LSTM,其词汇量为10000个单词。每个单词的输入和输出嵌入为192维,并与模型共同训练;总共有4950544个参数,使用10个字符的unroll。

对于联邦平均和联邦随机梯度下降的最佳学习率曲线:

  • 相同准确率的情况下,FedAvg的通信轮数更少;测试精度方差更小;
  • E=1比E=5的表现效果更好; 

4 总结展望

         我们的实验表明,联邦学习可以在实践中实现,因为它可以使用相对较少的几轮通信来训练高质量的模型,这一点在各种模型体系结构上得到了证明:一个多层感知器、两个不同的卷积NNs、一个两层LSTM和一个大规模LSTM。虽然联邦学习提供了许多实用的隐私保护,但是通过差分隐私、安全多方计算提供了可以提供更有力的保障,或者他们的组合是未来工作的一个有趣方向。请注意,这两类技术最自然地应用于像FedAvg这样的同步算法。

参考文章:

https://blog.csdn.net/qq_41605740/article/details/124584939?spm=1001.2014.3001.5506

https://blog.csdn.net/weixin_45662974/article/details/119464191?spm=1001.2014.3001.5506 

https://zhuanlan.zhihu.com/p/515756280 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/65109.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 配置网络请求并使用请求轮播图

由于平台的限制,小程序项目中不支持 axios,而且原生的 wx.request() API 功能较为简单,不支持拦截器等全局定制的功能。因此,建议在 uni-app 项目中使用 escook/request-miniprogram 第三方包发起网络数据请求。 官方文档&#xf…

【C++入门】命名空间、缺省参数、函数重载、引用、内联函数

​👻内容专栏: C/C编程 🐨本文概括: C入门学习必备语法 🐼本文作者: 阿四啊 🐸发布时间:2023.9.3 前言 C是在C的基础之上,容纳进去了面向对象编程思想,并增加…

大数据-玩转数据-Flink窗口函数

一、Flink窗口函数 前面指定了窗口的分配器, 接着我们需要来指定如何计算, 这事由window function来负责. 一旦窗口关闭, window function 去计算处理窗口中的每个元素. window function 可以是ReduceFunction,AggregateFunction,or ProcessWindowFunction中的任意一种. Reduc…

打包个七夕exe玩玩

前段时间七夕 当别的哥们都在酒店不要不要的时候 身为程序员的我 还在单位群收到收到 正好后来看到大佬些的这个 https://www.52pojie.cn/thread-1823963-1-1.html 这个贱 我必须要犯,可是我也不能直接给他装个python吧 多麻烦 就这几个弹窗 好low 加上bgm 再打包成…

Nexus仓库介绍以及maven deploy配置

一 、Nexus仓库介绍 首先介绍一下Nexus的四个仓库的结构: maven-central 代理仓库,代理了maven的中央仓库:https://repo1.maven.org/maven2/; maven-public 仓库组,另外三个仓库都归属于这个组,所以我们的…

贝叶斯神经网络 - 捕捉现实世界的不确定性

贝叶斯神经网络 - 捕捉现实世界的不确定性 Bayesian Neural Networks 生活本质上是不确定性和概率性的,贝叶斯神经网络 (BNN) 旨在捕获和量化这种不确定性 在许多现实世界的应用中,仅仅做出预测是不够的;您还想知道您对该预测的信心有多大。例…

第2章 Linux多进程开发 2.18 内存映射

内存映射:可以进行进程间的通信 1.如果对mmap的返回值(ptr)做操作(ptr), munmap是否能够成功? void * ptr mmap(…); ptr; 可以对其进行操作 munmap(ptr, len); // 错误,要保存地址 2.如果open时O_RDONLY, mmap时prot参数指定PROT_READ | PROT_WRITE会怎样? 错…

二进制安全虚拟机Protostar靶场 安装,基础知识讲解,破解STACK ZERO

简介 pwn是ctf比赛的方向之一,也是门槛最高的,学pwn前需要很多知识,这里建议先去在某宝上买一本汇编语言第四版,看完之后学一下python和c语言,python推荐看油管FreeCodeCamp的教程,c语言也是 pwn题目大部…

SpringBoot 使用MyBatis分页插件实现分页功能

SpringBoot 使用MyBatis分页插件实现分页功能 1、集成pagehelper2、配置pagehelper3、编写代码4、分页效果 案例地址&#xff1a; https://gitee.com/vinci99/paging-pagehelper-demo/tree/master 1、集成pagehelper <!-- 集成pagehelper --> <dependency><gr…

“亚马逊云科技创业加速器”首期聚焦AI,促进入营企业业务发展

生成式AI技术飞速发展&#xff0c;颠覆着人们的生活&#xff0c;正在掀起新一轮的科技革命。在生成式AI的浪潮中&#xff0c;亚马逊云科技旨在为中国的优秀初创企业提供全方位支持&#xff0c;助其抢占先机。 在6月底举办的亚马逊云科技中国峰会上&#xff0c;亚马逊云科技联合…

6. series对象及DataFrame对象知识总结

【目录】 文章目录 6. series对象及DataFrame对象知识总结1. 导入pandas库2. pd.Series创建Series对象2.1 data 列表2.2 data 字典 3. s1.index获取索引4. s1.value获取值5. pd.DataFrame()-创建DataFrame 对象5.1 data 列表5.2 data 嵌套列表5.3 data 字典 6. df[列索引]…

java对象创建的过程

1、检查指令的参数是否能在常量池中定位到一个类的符号引用 2、检查此符号引用代表的类是否已被加载、解析和初始化过。如果没有&#xff0c;就先执行相应的类加载过程 3、类加载检查通过后&#xff0c;接下来虚拟机将为新生对象分配内存。 4、内存分配完成之后&#xff0c;…

一句话画出动漫效果

链接&#xff1a; AI Comic Factory - a Hugging Face Space by jbilcke-hfDiscover amazing ML apps made by the communityhttps://huggingface.co/spaces/jbilcke-hf/ai-comic-factory 选择类型&#xff1a; Japanese 输入提示词&#xff1a; beauty and school love st…

12、监测数据采集物联网应用开发步骤(9.1)

监测数据采集物联网应用开发步骤(8.2) TCP/IP Server开发 在com.zxy.common.Com_Para.py中添加如下内容 #锁机制 lock threading.Lock() #本机服务端端口已被连接客户端socket list dServThreadList {} #作为服务端接收数据拦截器 ServerREFLECT_IN_CLASS "com.plug…

设计模式-装饰模式

文章目录 一、简介二、基本概念三、装饰模式的结构和实现类图解析&#xff1a;装饰器的实现方式继承实现&#xff1a;组合实现&#xff1a;继承和组合对比 四、装饰模式的应用场景五、与其他模式的关系六、总结 一、简介 装饰模式是一种结构型设计模式&#xff0c;它允许动态地…

msvcp120.dll丢失的解决方法?全面解决方法推荐

msvcp120.dll是Windows操作系统中的一个关键组件&#xff0c;如果丢失或损坏&#xff0c;可能会导致系统崩溃或无法正常运行。本文将介绍三种解决msvcp120.dll丢失问题的方法。 随着计算机应用的广泛普及&#xff0c;越来越多的人开始遇到各种电脑问题。其中&#xff0c;msvcp…

FPGA原理与结构——FIFO IP核的使用与测试

一、前言 本文介绍FIFO Generator v13.2 IP核的具体使用与例化&#xff0c;在学习一个IP核的使用之前&#xff0c;首先需要对于IP核的具体参数和原理有一个基本的了解&#xff0c;具体可以参考&#xff1a; FPGA原理与结构——FIFO IP核原理学习https://blog.csdn.net/apple_5…

算法通关村第十二关——字符串反转问题解析

前言 字符串反转是关于字符串算法里的重要问题&#xff0c;虽然不是太难&#xff0c;但需要考虑到一些边界问题。本篇文章就对几道字符串反转题目进行分析。 1.反转字符串 力扣344题&#xff0c;编写一个函数&#xff0c;其作用是将输入的字符串反转过来。输入字符串以字符数…

亚马逊云科技生成式AI技术辅助教学领域,近实时智能应答2D数字人搭建

早在大语言模型如GPT-3.5等的兴起和被日渐广泛的采用之前&#xff0c;教育行业已经在AI辅助教学领域有过各种各样的尝试。在教育行业&#xff0c;人工智能技术的采用帮助教育行业更好地实现教学目标&#xff0c;提高教学质量、学习效率、学习体验、学习成果。例如&#xff0c;人…

应用案例 | 基于三维机器视觉的机器人麻袋拆垛应用解决方案

​Part.1 项目背景 在现代物流和制造行业中&#xff0c;麻袋的拆垛操作是一个重要且频繁的任务。传统的麻袋拆垛工作通常由人工完成&#xff0c;分拣效率较低&#xff0c;人力成本较高&#xff0c;现场麻袋堆叠、变形严重&#xff0c;垛型不规则、不固定&#xff0c;严重影响分…