【从零开始实现联邦学习】

1. 环境配置如下

  • python3.7
  • pip install torch
  • pip install torchvision

2. 代码如下

原书的代码存在一点bug,现已被作者修复
Client端代码如下

import torch.utils.dataclass Client(object):def __init__(self,conf,model,train_dataset,id=1):self.conf = conf                        # 配置文件self.local_model = model                # 客户端本地模型self.client_id = id                     # 客户端IDself.train_dataset = train_dataset      #客户端本地数据集all_range = list(range(len(self.train_dataset)))data_len = int(len(self.train_dataset)/self.conf['no_models'])indices = all_range[id*data_len:(id+1)*data_len]self.train_loader = torch.utils.data.DataLoader(self.train_dataset,batch_size=conf["batch_size"],sampler=torch.utils.data.sampler.SubsetRandomSampler(indices))def local_train(self, model):for name,param in model.state_dict().items():# 客户端首先用服务器端下发的全局模型覆盖本地模型self.local_model.state_dict()[name].copy_(param.clone())# 定义最优化函数器,用于本地模型训练optimizer = torch.optim.SGD(self.local_model.parameters(),lr=self.conf['lr'],momentum=self.conf['momentum'])# 本地模型训练self.local_model.train()for e in range(self.conf['local_epochs']):for batch_id,batch in enumerate(self.train_loader):data, target = batchif torch.cuda.is_available():self.local_model.cuda()data = data.cuda()target = target.cuda()optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)loss.backward()optimizer.step()print("Epoch %d done." % e)diff = dict()for name,data in  self.local_model.state_dict().items():diff[name] = (data - model.state_dict()[name])return diff

Server端代码如下

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transforms# 服务端
class Server(object):def __init__(self, conf, eval_dataset):self.conf = conf# 服务器端的模型self.global_model = models.get_model(self.conf["model_name"])self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)self.accuracy_history = []  # 保存accuracy的数组self.loss_history = []  # 保存loss的数组# 聚合各个服务器上传的信息def model_aggregate(self, weight_accumulator):# weight_accumulator存储了每一个客户端的上传参数变化值for name,data in self.global_model.state_dict().items():update_per_layer = weight_accumulator[name] * self.conf['lambda']if data.type() != update_per_layer.type():data.add_(update_per_layer.to(torch.int64))else:data.add_(update_per_layer)# 定义模型评估函数def model_eval(self):self.global_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id,batch in  enumerate(self.eval_loader):data,target = batchdataset_size += data.size()[0]if torch.cuda.is_available():self.global_model.cuda()data = data.cuda()target = target.cuda()output = self.global_model(data)# 把损失值聚合起来total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 获取最大的对数概率的索引值pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()# 计算准确率acc = 100.0 * (float(correct) / float(dataset_size))# 计算损失值total_l = total_loss / dataset_size# 将accuracy和loss保存到数组中self.accuracy_history.append(acc)self.loss_history.append(total_l)return acc,total_ldef save_results_to_file(self):# 将accuracy和loss保存到文件中with open("fed_accuracy_history.txt", "w") as f:for acc in self.accuracy_history:f.write("{:.2f}\n".format(acc))with open("fed_loss_history.txt", "w") as f:for loss in self.loss_history:f.write("{:.4f}\n".format(loss))

聚合代码如下

import json
import randomimport torchfrom MyDataLoader import get_dataset
from chapter3.Client import Client
from chapter3.Server import Serverwith open("conf.json",'r') as f:conf = json.load(f)# 接下来分别定义一个服务端对象和多个客户端对象,用来模拟横向联邦训练场景train_datasets,eval_datasets = get_dataset("./data/",conf["type"])
server = Server(conf,eval_datasets)
clients = []
# 创建多个客户端
for c in range(conf["no_models"]):clients.append(Client(conf,server.global_model,train_datasets,c))
# 每一轮迭代,服务端会从当前的客户端集合中随机挑选一部分参与本轮迭代训练,被选中的客户端调用本地训练接口local_train进行本地训练,
# 最后服务器调用模型聚合函数model——aggregate来更新全局模型,代码如下所示:
for e in range(conf["global_epochs"]):# 采样k个客户端参与本轮联邦训练candidates = random.sample(clients,conf['k'])# 初始化weight_accumulator并在GPU上(如果可用)weight_accumulator = {}if torch.cuda.is_available():device = torch.device("cuda:0")else:device = torch.device("cpu")for name,params in server.global_model.state_dict().items():# 在指定设备上创建并初始化weight_accumulator中的张量weight_accumulator[name] = torch.zeros_like(params).to(device)for c in candidates:# 确保本地训练后的模型差异在正确设备上diff = c.local_train(server.global_model)for name,params in server.global_model.state_dict().items():weight_accumulator[name].add_(diff[name])server.model_aggregate(weight_accumulator)acc,loss = server.model_eval()print("Epoch %d ,acc:%f,loss: %f\n" % (e,acc,loss))server.save_results_to_file()

数据集的加载

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transformsdef get_dataset(dir, name):if name == 'mnist':train_dataset = datasets.MINST(dir, train=True, download=True,transform=transforms.ToTensor())eval_dataset = datasets.MINST(dir, train=False, transform=transforms.ToTensor())elif name=='cifar':transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914,0.4822,0.4465),(0.2023,0.1994,0.2010)),])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),])train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)return train_dataset, eval_dataset

配置文件如下

{"model_name" : "resnet18","// comment1": "客户端的个数","no_models" : 10,"type" : "cifar","global_epochs" : 10,"local_epochs" : 3,"// comment2": "每一轮中挑选的机器数","k" : 6,"batch_size" : 32,"lr" : 0.001,"momentum" : 0.0001,"lambda" : 0.1
}

3.结果如下

可以看到联邦学习的效果还是不如集中式学习,也有可能是因为我迭代的轮次不够。
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/862443.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么不再推荐使用 VRTK 4?

引言 VRTK (Virtual Reality Toolkit) 发布于2016年,初期受到了广大开发者的欢迎并被广泛采用。但是随着 VR 开发生态的发展,这款工具逐渐失去了最初的光芒。本文试图通过几个维度的分析,解释为什么目前不推荐使用 VRTK 进行开发的理由&…

Eigen中关于四元数的常用操作

四元数(Quaternion)是一种数学工具,广泛用于计算机图形学、机器人学和物理模拟中,特别适合处理三维旋转。Eigen库是一个高性能的C数学库,提供了丰富的线性代数功能,其中就包括对四元数的支持。 1. 为什么选…

mklink

文章目录 mklink概述笔记备注END mklink 概述 看一个开源工程中,有一个.bat脚本,用来建立符号链接。 用的是mklink, 试试,比快捷方式好用。 笔记 测试环境 win10x64-22H2 准备测试用的文件 D:\my_tmp\dir1\readme.txt mklink的命令行帮助…

Windows平台使用S3Browser连接兼容的对象存储

本文记录了在Windows平台使用S3Browser连接兼容的对象存储的过程 一、安装S3Browser 1、下载 S3Browser官网:https://s3browser.com/ 直接下载:https://s3browser.com/download/s3browser-11-6-7.exe 2、安装 3、同意授权后确定安装目录 4、勾选立即…

第7章 Redis的噩梦:阻塞

文章目录 前言1 发现阻塞2.内在原因2.1API或数据结构使用不合理2.1.1如何发现慢查询2.1.2.如何发现大对象 2.2 CPU饱和2.3 持久化阻塞2.3.1fork阻塞2.3.2.AOF刷盘阻塞2.3.3.HugePage写操作阻塞 3 外在原因3.1CPU竞争3.2 内存交换 前言 Redis是典型的单线程架构,所有…

Studying-代码随想录训练营day23| 39.组合总和、40.组合总和II、131.分割回文串

第23天,回溯part02,回溯两个题型组合,切割(ง •_•)ง💪 目录 39.组合总和 40.组合总和II 131.分割回文串 总结 39.组合总和 文档讲解:代码随想录组合总和 视频讲解:手撕组合总和 题目:…

【Qt】信号和槽机制

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; &#x1f525;c系列专栏&#xff1a;C/C零基础到精通 &#x1f525; 给大…

WINDOWS+PHP+Mysql+Apache环境中部署SQLi-Labs、XSS-Labs、UPload-Labs、DVWA、pikachu等靶场环境

web渗透测试学习&#xff0c;需要自己搭建一些靶场&#xff0c;本人主要介绍在WINDOWSPHPMysqlApache环境中部署SQLi-Labs、XSS-Labs、UPload-Labs、DVWA、pikachu等靶场环境。以下是靶场代码下载的链接&#xff1a; pikachu靶场代码 链接&#xff1a;https://pan.baidu.com/s…

废品回收小程序开发:提高废品回收效率

当下&#xff0c;废品回收已经成为了热门行业&#xff0c;家家户户几乎都会进行废品回收&#xff0c;无论是废纸盒还是塑料瓶等&#xff0c;都会送到废品回收站。不过&#xff0c;随着互联网的快速发展&#xff0c;传统的回收模式出现了大量的局限性&#xff0c;已经不能满足大…

探索Android架构设计

Android 应用架构设计探索&#xff1a;MVC、MVP、MVVM和组件化 MVC、MVP和MVVM是常见的三种架构设计模式&#xff0c;当前MVP和MVVM的使用相对比较广泛&#xff0c;当然MVC也并没有过时之说。而所谓的组件化就是指将应用根据业务需求划分成各个模块来进行开发&#xff0c;每个…

领夹麦克风哪个品牌音质最好?主播一般用什么麦克风?麦克风推荐

在这个充满创意与表达的时代&#xff0c;无线领夹麦克风以其独特的魅力&#xff0c;成为了声音创作者们的得力助手。它小巧便携&#xff0c;功能强大&#xff0c;无论是日常拍摄、直播互动还是专业演出&#xff0c;都能轻松应对&#xff0c;让你的声音随时随地清晰传递。那么&a…

# Kafka_深入探秘者(10):kafka 监控

Kafka_深入探秘者&#xff08;10&#xff09;&#xff1a;kafka 监控 一、kafka JMX 1、JMX &#xff1a;全称 Java Managent Extension 在实现 Kafka 监控系统的过程中&#xff0c;首先我们要知道监控的数据从哪来&#xff0c;Kafka 自身提供的监控指标(包括 broker 和主题的…

管理的核心是管人,管人的核心就是这3条,看懂的是高手

管理的核心是管人&#xff0c;管人的核心就是这3条&#xff0c;看懂的是高手 一&#xff1a;管欲 每个人都有欲望&#xff0c;无可厚非。管理者的任务就是利用欲望&#xff0c;管理欲望&#xff0c;通过欲望来达到管人的目的。 最需要管理的就是以下两种&#xff1a; 1、金…

普乐蛙景区9d电影体验馆商场影院娱乐设备旋转飞行影院

今天与大家聊聊VR娱乐新潮流&#xff0c;我们普乐蛙的新品——旋转飞行影院&#xff01;裸眼7D环幕影院&#xff0c;话不多说上产品&#xff01;我们通过亲身体验来给大家讲讲这款高性价比新品的亮点。 想象一下走上电动伸缩梯&#xff0c;坐进动感舱&#xff0c;舱门缓缓合上&…

点击获取2024SIAL西雅国际食品展上海展后报告

随着2024年SIAL 西雅展&#xff08;上海&#xff09;的圆满落幕&#xff0c;我们不仅见证了一场食品与饮料行业的国际盛会&#xff0c;更是感受到了上海这座城市独有的魅力与活力。在这里&#xff0c;我们回顾了上海展的辉煌成就&#xff0c;同时&#xff0c;我们也满怀期待地展…

第4章 客户端-客户端管理

1. 客户端API 1.1client list client list命令能列出与Redis服务端相连的所有客户端连接信息。 127.0.0.1:6379> client list id254487 addr10.2.xx.234:60240 fd1311 name age8888581 idle8888581 flagsN db0 sub0 psub0 multi-1 qbuf0 qbuf-free0 obl0 oll0 omem0 events…

MySQL数据库基础练习系列——教务管理系统

项目名称与项目简介 教务管理系统是一个旨在帮助学校或教育机构管理教务活动的软件系统。它涵盖了学生信息管理、教师信息管理、课程管理、成绩管理以及相关的报表生成等功能。通过该系统&#xff0c;学校可以更加高效地处理教务数据&#xff0c;提升教学质量和管理水平。 1.…

5款提高工作效率的免费工具推荐

SimpleTex SimpleTex是一款用于创建和编辑LaTeX公式的简单工具。它能够识别图片中的复杂公式并将其转换为可编辑的数据格式。该软件提供了一个直观的界面&#xff0c;用户可以在编辑LaTeX代码的同时实时预览公式的效果&#xff0c;无需额外的编译步骤。此外&#xff0c;SimpleT…

生命在于学习——Python人工智能原理(2.5.1)

五、Python的类与继承 5.1 Python面向对象编程 在现实世界中存在各种不同形态的事物&#xff0c;这些事物之间存在各种各样的联系。在程序中使用对象来映射现实中的事物&#xff0c;使用对象之间的关系描述事物之间的联系&#xff0c;这种思想用在编程中就是面向对象编程。 …

C++笔记:实现一个字符串类(构造函数、拷贝构造函数、拷贝赋值函数)

实现一个字符串类String&#xff0c;为其提供可接受C风格字符串的构造函数、析构函数、拷贝构造函数和拷贝赋值函数。 声明依赖文件 其中ostream库用于打印标准输入输出&#xff0c;cstring库为C风格的字符串库 #include <iostream> #include <cstring> 声明命…