详解FedAvg:联邦学习的开山之作

FedAvg:2017年 开山之作

论文地址:https://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf
源码地址:https://github.com/shaoxiongji/federated-learning
针对的问题:移动设备中有大量的数据,但显然我们不能收集这些数据到云端以进行集中训练,所以引入了一种分布式的机器学习方法,即联邦学习Federal Learning。在FL中,server将全局模型下放给各client,client利用本地的数据去训练模型,并将训练后的权重上传到server,从而实现全局模型的更新。
论文贡献

  1. 提出了联邦学习这个研究方向,简单来说就是从分散的存储于各设备的数据中训练模型;
  2. 提出了FedAvg算法;
  3. 通过实验验证了FedAvg的可靠性;

总结一下就是,本文提出了FedAvg算法,这种算法融合了client上的局部随机梯度下降和server上的模型平均。作者用该算法做了不少实验,结果表明FedAvg对于unbalanced且non-iid的数据有很好的鲁棒性,并且使得在非数据中心存储的数据上进行深度网络训练所需的通信轮次减少了好几个数量级。
算法介绍

  1. 联邦随机梯度下降算法FedSGD

设定固定的学习率η,对K个client的数据计算损失梯度:
g k = ▽ F k ( w t ) g_k=\bigtriangledown F_k(w_t) gk=Fk(wt)
server将聚合每个服务器计算的梯度,以此来更新模型参数:
w t + 1 ← w t − η ∑ k = 1 K n k n g k = w t − η ▽ f ( w t ) w_{t+1}\leftarrow w_t-\eta\sum\limits_{k=1}^K\frac{n_k}{n}g_k=w_t-\eta\bigtriangledown f(w_t) wt+1wtηk=1Knnkgk=wtηf(wt)

  1. 联邦平均算法FedAvg:

在client进行局部模型的更新:
w t + 1 k ← w t − η g k w_{t+1}^k\leftarrow w_t-\eta g_k wt+1kwtηgk
server对每个client更新后的权重进行加权平均:
w t + 1 ← ∑ k = 1 K n k n w t + 1 k w_{t+1}\leftarrow \sum_{k=1}^K \frac{n_k}{n}w_{t+1}^k wt+1k=1Knnkwt+1k
注意,在这里每个client可以在本地独立地多次更新本地权重,然后将更好的权重参数发给server进行加权平均。这样做的好处是不用每更新一次就去聚合,这大大减少了通信量。
FedAvg的计算量与3个参数有关:

  • C:每轮训练选择client的比例,每一轮通信时只选择C*K个client;(K为client总数)
  • E:每个client更新本地权重时,在本地数据集上训练E轮;
  • B:client更新权重时,每次梯度下降所使用的数据量,即本地数据集的batch size;

对于一个拥有 n k n_k nk个数据样本的client,每轮通信本地参数的更新次数为:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
所以我们可知,FedSGD只是FedAvg的一个特例,即当参数 E = 1 , B = ∞ E=1,B=\infty E=1B=时,FedAvg等价于FedSGD。注: B = ∞ B=\infty B=意味着batch size大小就是本地数据集大小。
下面为FedAvg的算法流程图:
FedAvg算法流程图
实验设计与实现
Q1:在训练伊始,需不需要对模型进行统一初始化?
image.png
可见,采用不同的初始化参数进行模型平均,模型性能比两个父模型都差(左图);而统一初始化后,对模型的平均可以显著减少整个训练集的loss,模型性能优于两个父模型(右图)。
该结论是实现FL的重要支持,在每一轮通信时,server有必要发布全局模型,使各client采用相同的参数在本地数据集上进行训练,可以有效减少loss。
Q2:数据集怎么设置?
原文中主要研究了MNIST数据集和一个莎士比亚作品集构建的数据集,但我们在这里主要关注MNIST数据集和Cifar-10数据集,这两个数据集也是以后FL领域工作最常用的。
在模型选择方面,作者选择了多层感知机MLP和卷积神经网络CNN。
在数据集划分方面,作者假设有100个client,对于MNIST数据集,进行了iid和non-iid两种划分:

  • MNIST-iid:数据随机打乱分给100个client,每个client得到600个样例;
  • MNIST-non-iid:按数字label将数据集划分为200个大小为300的碎片,每个client两个碎片,意味着每个client至多只能获得两种label的样例;

对于Cifar-10数据集,做了iid划分。
Q3:实验咋做的?
作者指出,相比于传统模式下训练模型时计算开销为主通信开销较小的情况,在FL中,通信开销才是大头,因此减少通信开销才是我们需要关注的,作者提出可以通过加大计算以减少训练模型所需的通信轮数。作者提出主要有两种方法:提高并行度、增加每个client的计算量
而FedAvg的计算量在前面我们也给出过,再来看一下:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
提高并行度:固定参数E,对C和B进行讨论。注:此处C=0时,算法也会选择一个client参与,详见上面的算法流程图。
2NN测试集acc 97%,CNN测试集acc 99%所需的通信轮数

  • B = ∞ B=\infty B=时,增加client的比例C,效果提升的优势较小;
  • B = 10 B=10 B=10时,效果显著改善了,特别是在non-iid情况下;
  • B = 10 , C ≥ 10 B=10,C\geq10 B=10,C10时,收敛速度明显改进,当client到一定数量后,收敛速度增加也不明显了;

增加每个client的计算量:根据公式,可以通过增加E或者减小B实现。
对测试集到达期望acc所需的通信轮数

  • 每个通信轮次内增加更多的本地SGD可以显著降低通信成本;
  • 对于unbalanced-non-iid的莎士比亚数据集减少的通信轮数更多,推测可能某些client有相对较大的本地数据集,这种情况下增加了本地训练的价值;

Q4:FedAvg VS FedSGD?
image.png
蓝色实现即为FedSGD。由图可知,FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益。
Q5:加大每个client的计算量会不会导致过拟合?
image.png
加大每个client的计算量(主要体现在加大E),确实可能导致训练损失停滞或发散。所以在实际应用时,在训练后期减少各client的E,或者在loss有震荡的苗头时即刻停止,这样做有助于收敛。
Q6:在Cifar-10数据集上的表现如何?
如下图所示:
image.png
image.png
针对第一张图的一点吐槽,你去拿分布式深度学习去pk单机上的深度学习,去比通信轮数,这不是太不公平了。。。
总结展望
作者证明了FL在实践中是可行的,能够用相对较少的通信轮数训练出高质量的模型。并且提出未来的一个方向就是通过差分隐私、安全多方技术等隐私保护技术去组合FL以提供隐私保护。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/25289.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

minio的一个基础使用案例:用户头像上传

文章目录 一、minio下载安装(Windows)二、案例需求分析三、后端接口开发 一、minio下载安装(Windows) 1. 下载minio服务端和客户端 minio下载地址 2. 手动搭建目录 /minio/binmc.exeminio.exe/data/logs手动创建minio应用程序目…

vivado HW_DEVICE

硬件设备 描述 在Vivado Design Suite的硬件管理器功能中,每个硬件目标都可以 具有一个或多个Xilinx FPGA设备进行编程或用于调试目的。这个 hw_device对象是通过hw_server打开的hw_target上的物理部分。这个 current_hw_device命令指定或返回当前设备。 相关对象 硬…

Linux—小小内核升级

本篇主要是讲述下关于内核的一些基本常识,并记录下内核升级和编译的过程,若有遗漏/有误之处,望各位大佬们指出。 Ⅰ 基本内核常识 常见内核安装包 内核(kernel):这是Linux操作系统的核心部分,它负责管理系统的硬件和…

Vue3学习第二天记录

Vue3学习第二天记录 背景说明截图记录一个简单的JS文件Vue3的watch()函数Vue3的toRef()/toRefs()函数前端数据类型的分类前端写一个对外暴露的函数前端的...语法Vue3中watch()函数的总结Vue3中watchEffect()函数Vue3中watch()函数的坑Vue3中computed()函数 背景 最近在学习尚硅…

Vue2入门(安装Vue、devtools,创建Vue)以及MVVM分层思想

文章目录 1.下载并安装Vue2.使用Vue2.1 创建Vue以及挂载Vue2.2 模板语句的数据来源:data2.3 template配置项详解2.4 Vue实例和容器的关系 3.安装devtools4.MVVM分层思想5.通过vm可以访问哪些属性 1.下载并安装Vue (1)Vue是一个基于JavaScrip…

搭建高可用k8s

高可用只针对于api-server,需要用到nginx keepalived,nginx提供4层负载,keepalived提供vip(虚拟IP) 系统采用openEuler 22.03 LTS 1. 前期准备 因为机器内存只有16G,所有我采用3master 1node 1.1 修改主机配置(所有节…

扩散模型会成为深度学习的下一个前沿领域吗?

文章目录 一、说明二、 第 1 部分:了解扩散模型2.1 什么是扩散模型2.2 正向扩散2.3 反向扩散 三、他们的高成本四、扩散模型的用处五、为什么扩散模型如此出色六、第 2 部分:使用扩散模型生成6.1 用于自然语言处理和 LLM 的文本扩散6.2 音频视频生成6.3 …

下载安装Thonny并烧录MicroPython固件至ESP32

Thonny介绍 一、Thonny的基本特点 面向初学者:Thonny的设计初衷是为了帮助Python初学者更轻松、更快速地入门编程。它提供了直观易懂的用户界面和丰富的功能,降低了编程的门槛。轻量级:作为一款轻量级的IDE,Thonny不会占用过多的…

RDK X3(aarch64) 测试激光雷达思岚A1

0. 环境 - 亚博智能的ROSMASTER-X3 - RDK X3 1.0 0.1 资料 文档资料 https://www.slamtec.com/cn/Support#rplidar-a-series SDK https://github.com/slamtec/rplidar_sdk ROS https://github.com/slamtec/rplidar_ros https://github.com/Slamtec/sllidar_ros2 1. robostu…

1310. 子数组异或查询 异或 前缀和 python

有一个正整数数组 arr,现给你一个对应的查询数组 queries,其中 queries[i] [Li, Ri]。 对于每个查询 i,请你计算从 Li 到 Ri 的 XOR 值(即 arr[Li] xor arr[Li1] xor ... xor arr[Ri])作为本次查询的结果。 并返回一…

《精通ChatGPT:从入门到大师的Prompt指南》附录B:推荐阅读资源

作者:斯图尔特拉塞尔 (Stuart Russell) 和 彼得诺维格 (Peter Norvig) 简介:这本书被誉为人工智能领域的经典教材,内容涵盖了AI的基本原理、算法及其应用。无论是入门者还是专业研究者,都能从中获得启发。 2. 《深度学习》 作者…

【目标跟踪网络训练 Market-1501 数据集】DeepSort 训练自己的跟踪网络模型

前言 Deepsort之所以可以大量避免IDSwitch,是因为Deepsort算法中特征提取网络可以将目标检测框中的特征提取出来并保存,在目标被遮挡后又从新出现后,利用前后的特征对比可以将遮挡的后又出现的目标和遮挡之前的追踪的目标重新找到&#xff0…

企业网页制作

随着互联网的普及,企业网站已成为企业展示自己形象、吸引潜在客户、开拓新市场的重要方式。而企业网页制作则是构建企业网站的基础工作,它的质量和效率对于企业网站的成败至关重要。 首先,企业网页制作需要根据企业的特点和需求进行规划。在网…

前端 移动端 手机调试 (超简单,超有效 !)

背景:webpack工具构建下的vue项目 1. 找出电脑的ipv4地址 2. 替换 host 3. 手机连接电脑热点或者同一个wifi 。浏览器打开链接即可。

Spring运维之业务层测试数据回滚以及设置测试的随机用例

业务层测试数据回滚 我们之前在写dao层 测试的时候 如果执行到这边的代码 会在数据库 里面留下数据 运行一次留一次数据 开发有开发数据库,运行有运行数据库 我们先连数据库 在pom文件里引入mysql的驱动和mybatis-plus的依赖 在数据层写接口 用mybatis-plus进…

openh264 场景变化检测算法源码分析

文件位置 openh264/codec/processing/scenechangedetection/SceneChangeDetection.cppopenh264/codec/processing/scenechangedetection/SceneChangeDetection.h 代码流程 说明: 通过代码流程分析,当METHOD_SCENE_CHANGE_DETECTION_SCREEN场景类型为时…

Linux -- 了解 vim

目录 vim Linux 怎么编写代码? 了解 vim 的模式 什么是命令模式? 命令模式下 vim 的快捷键: 光标定位: 复制粘贴: 删除及撤销: 注释代码: 什么是底行模式? ​编辑 ​编辑…

Java:111-SpringMVC的底层原理(中篇)

这里续写上一章博客(110章博客): 现在我们来学习一下高级的技术,前面的mvc知识,我们基本可以在67章博客及其后面相关的博客可以学习到,现在开始学习精髓: Spring MVC 高级技术: …

Large-Scale LiDAR Consistent Mapping usingHierarchical LiDAR Bundle Adjustment

1. 代码地址 GitHub - hku-mars/HBA: [RAL 2023] A globally consistent LiDAR map optimization module 2. 摘要 重建精确一致的大规模激光雷达点云地图对于机器人应用至关重要。现有的基于位姿图优化的解决方案,尽管它在时间方面是有效的,但不能直接…

ubuntu使用docker安装openwrt

系统:ubuntu24.04 架构:x86 1. 安装docker 1.1 离线安装 docker下载地址 根据系统版本,依次下载最新的三个关于docker的软件包 container.io(注意后缀版本顺序)docker-ce-clidocker-ce 然后再ubuntu系统中依次按顺…