【联邦学习——手动搭建简易联邦学习】

1. 目的

用于记录自己在手写联邦学习相关实验时碰到的一些问题,方便自己进行回顾。

2. 代码

2.1 本地模型计算梯度更新

# 比较训练前后的参数变化
def compare_weights(new_model, old_model):weight_updates = {}for layer_name, params in new_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)return weight_updates

测试代码如下:
有意思的点在于我获得了update = model2-model1
但是我去计算model1+update==model2的时候发现不相等
最后思考了一下可能是在这个计算的过程中存在精度的丢失

import torch
from torch import nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoaderdef weight_init(m):if isinstance(m, nn.Linear):nn.init.xavier_normal_(m.weight)nn.init.constant_(m.bias, 0)# 也可以判断是否为conv2d,使用相应的初始化方式elif isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')# 是否为批归一化层elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if __name__ == '__main__':# 建立model1和model2作为训练前后的模型model1 = models.get_model("resnet18")model1.apply(weight_init)model2 = models.get_model("resnet18")model2.apply(weight_init)weight_updates = compare_weights(model2,model1)# 创建一个临时的模型状态字典,用于存储更新后的参数updated_params = model1.state_dict().copy()for layer_name, update in weight_updates.items():# 确保该层存在于model1中且形状匹配,避免错误if layer_name in updated_params and update.shape == updated_params[layer_name].shape:# 直接相加更新参数updated_params[layer_name] += updateelse:print(f"Warning: Layer {layer_name} not found or shape mismatch, skipping update.")# 将更新后的参数加载回model1model1.load_state_dict(updated_params)for layer_name,params in model1.state_dict().items():ts = params-model2.state_dict().get(layer_name)# 很重要,这里应该是我们在做的时候会有一点模型精度上的损失,所以不能够计算这里等于0if torch.sum(ts).item()>1e-6:print(f"{layer_name}更新后与原有的不匹配,差距为")else:print(f"{layer_name}更新后与原有的匹配")

2.2 客户端代码

import numpy as np
import torch.utils.data
from tqdm import tqdm'''
conf 配置文件
model 模型
train_dataset 数据集
class_ratios 从数据集中筛选出一部分 class_ratios = {0: 0.5, 1: 0.5,..., 8: 0.5, 9: 0.5}
id 客户端的标识
'''class Client(object):def __init__(self,conf,model,device,train_loader,id=1):self.client_id = id                     # 客户端IDself.conf = conf                        # 配置文件self.local_model = model                # 客户端本地模型self.train_loader = train_loader        # 训练数据的迭代器,需要训练的数据已经在里面了self.grad_update = dict()               # 本地训练完之后的梯度更新self.weight = conf['weight']            # 全局模型梯度更新时的权重self.device = device                    # 训练的设备self.local_model.to(self.device)        # 将模型放入训练设备def train(self, model):self._before_train(model)self._local_train()self._after_train(model)def _before_train(self, model):self._load_global_model(model)# 用服务器模型来覆盖本地模型def _load_global_model(self,model):for name,param in model.state_dict().items():# 客户端首先用服务器端下发的全局模型覆盖本地模型self.local_model.state_dict()[name].copy_(param.clone())def _local_train(self):# 定义最优化函数器,用于本地模型训练optimizer = torch.optim.SGD(self.local_model.parameters(),lr=self.conf['lr'],momentum=self.conf['momentum'])# 本地模型训练self.local_model.train()loss = 0for epoch in range(self.conf['local_epochs']):for batch in tqdm(self.train_loader, desc=f"Epoch {epoch + 1}/{self.conf['local_epochs']}"):data, target = batch# 放入相应的设备data = data.to(self.device)target = target.to(self.device)# 梯度清零optimizer.zero_grad()output = self.local_model(data)loss = torch.nn.functional.cross_entropy(output, target)# 反向传播loss.backward()optimizer.step()print(f"Client{self.client_id}----Epoch {epoch} done.Loss {loss}")def _after_train(self,model):self._cal_update_weights(model)def _cal_update_weights(self, old_model):weight_updates = dict()for layer_name, params in self.local_model.state_dict().items():weight_updates[layer_name] = params - old_model.state_dict().get(layer_name)# 更新梯度模型的权重self.grad_update = weight_updates

2.3 服务器代码

import torch.utils.data
import torchvision.datasets as datasets
from torchvision import models
from torchvision.transforms import transformsfrom utils.CommonUtils import copy_model_params# 服务端
class Server(object):def __init__(self, conf, eval_dataset, device):self.conf = conf# 全局老模型self.old_model = models.get_model(self.conf["model_name"])# 全局的新模型self.global_model = models.get_model(self.conf["model_name"])# 创建时保持新老模型的参数是一致的copy_model_params(self.old_model,self.global_model)# 根据客户端上传的梯度进行排列组合,用于测量贡献度的模型self.sub_model = models.get_model(self.conf["model_name"])self.eval_loader = torch.utils.data.DataLoader(eval_dataset,batch_size=self.conf["batch_size"],shuffle=True)self.accuracy_history = []  # 保存accuracy的数组self.loss_history = []  # 保存loss的数组self.device = deviceself.old_model.to(device)self.global_model.to(device)self.sub_model.to(device)# 模型重构def model_aggregate(self, clients, target_model):if target_model == self.global_model:print("++++++++全局模型更新++++++++")# 更新一下老模型参数copy_model_params(self.old_model,self.global_model)else:print("========子模型重构========")sum_weight = 0# 计算总的权重for client in clients:sum_weight += client.weight# 将old_model的模型参数赋值给sub_modelcopy_model_params(self.sub_model, self.old_model)# 初始化一个空字典来累积客户端的模型更新aggregated_updates = {}# 遍历每个客户端for client in clients:# 根据客户端的权重比例聚合更新for name, update in client.grad_update.items():if name not in aggregated_updates:aggregated_updates[name] = update * client.weight / sum_weightelse:aggregated_updates[name] += update * client.weight / sum_weight# 应用聚合后的更新到sub_modelfor name, param in target_model.state_dict().items():if name in aggregated_updates:param.copy_(param + aggregated_updates[name])  # 累加更新到当前层参数上# 定义模型评估函数def model_eval(self,target_model):target_model.eval()total_loss = 0.0correct = 0dataset_size = 0for batch_id,batch in enumerate(self.eval_loader):data,target = batchdataset_size += data.size()[0]# 放入和模型对应的设备data = data.to(self.device)target = target.to(self.device)# 模型预测output = target_model(data)# 把损失值聚合起来total_loss += torch.nn.functional.cross_entropy(output,target,reduction='sum').item()# 获取最大的对数概率的索引值pred = output.data.max(1)[1]correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()# 计算准确率acc = 100.0 * (float(correct) / float(dataset_size))# 计算损失值total_l = total_loss / dataset_size# 将accuracy和loss保存到数组中self.accuracy_history.append(acc)self.loss_history.append(total_l)if target_model == self.global_model:print(f"++++++++全局模型评估++++++++acc:{acc}  loss:{total_l}")else:print(f"========子模型评估========acc:{acc}  loss:{total_l}")return acc,total_ldef save_results_to_file(self):# 将accuracy和loss保存到文件中with open("fed_accuracy_history.txt", "w") as f:for acc in self.accuracy_history:f.write("{:.2f}\n".format(acc))with open("fed_loss_history.txt", "w") as f:for loss in self.loss_history:f.write("{:.4f}\n".format(loss))

2.4 Utils

def copy_model_params(target_model, source_model):for name, param in source_model.state_dict().items():target_model.state_dict()[name].copy_(param.clone())

3. 运行测试代码

import jsonimport torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision import models
from torchvision import transforms, datasets
from client.Client import Client
from server.Server import Serverwith open("../conf/client1.json",'r') as f:conf = json.load(f)with open("../conf/server1.json",'r') as f:serverConf = json.load(f)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.CIFAR10(root='../data/', train=True, download=True,transform=transform)
eval_dataset = datasets.CIFAR10(root='../data/', train=False, download=True,transform=transform)# train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32, num_workers=2)# 计算数据集长度
total_samples = len(train_dataset)
# 确保可以平均分配,否则需要调整逻辑以处理余数
assert total_samples % 2 == 0, "数据集样本数需为偶数以便完全平分"# 分割点
split_point = total_samples // 2# 创建两个子集
train_dataset_first_half = Subset(train_dataset, range(0, split_point))
train_dataset_second_half = Subset(train_dataset, range(split_point, total_samples))# 然后为每个子集创建DataLoader
batch_size = 32train_loader_first_half = DataLoader(train_dataset_first_half, shuffle=True, batch_size=batch_size, num_workers=2)
train_loader_second_half = DataLoader(train_dataset_second_half, shuffle=True, batch_size=batch_size, num_workers=2)# 检查CUDA是否可用
if torch.cuda.is_available():device = torch.device("cuda")  # 如果CUDA可用,选择GPU
else:device = torch.device("cpu")   # 如果CUDA不可用,选择CPUlocal_model = models.get_model("resnet18")
local_model2 = models.get_model("resnet18")server = Server(serverConf,eval_dataset,device)
client1 = Client(conf, local_model, device,train_loader_first_half, 1)
client2 = Client(conf, local_model2, device,train_loader_second_half, 2)
for i in range(2):client1.train(server.global_model)client2.train(server.global_model)server.model_aggregate([client1,client2], server.global_model)server.model_eval(server.global_model)server.model_aggregate([client1], server.sub_model)server.model_eval(server.sub_model)server.model_aggregate([client2], server.sub_model)server.model_eval(server.sub_model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/11806.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python---Pandas万字总结(1)

Pandas基础-1 Pandas 是 一个强大的分析结构化数据的工具集。Pandas 以 NumPy 为基础(实现数据存储和运算),提供了专门用于数据分析的类型、方法和函数,对数据分析和数据挖掘提供了很好的支持;同时 pandas 还可以跟数…

kubeadm 在vubuntu22.04.4 server 上安装kubernetes 1.28.9

一、基础安装(所有节点执行)---------------------------------------- 时间同步 关闭防火墙 sudo ufw disable sudo ufw status关闭交换内存 临时关闭 sudo swapoff -a free -m永久关闭 sudo vim /etc/fstab 注释掉交换内存 转发 IPv4 并让 iptab…

2、MySQL总结

一、基础部分 (一)、概念 关系型数据库 (二)、SQL编写 CRUD(查询、插入、更新、删除) 左右连接、内连接、子查询 (三)、存储过程、存储函数 存储过程和函数(一&am…

云贝教育 |【好课上新】ITSS服务工程师与服务经理认证培训

课程前言 ITSS是中国电子技术标准化研究院推出的,包含“IT 服务工程师”和“IT 服务经理”的系列培训。有效满足GB/T 28827.1 的符合性评估要求和ITSS服务资质升级要求。 IT 服务工程师”结合 IT服务从业人员能力规范和要求,从服务技术、服务技巧和服务…

【Linux】调试器-gdb使用

大家好,我是苏貝,本篇博客带大家了解Linux的编译器-gcc/g,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 目录 1. 背景(A) 看大小(B) 查看ELF格式的文件 2.使用(A) 进入gdb(B) quit/q&#xff…

【码农日常】将mp4转换为逐帧图片

项目场景: 拍摄了一段视频记录设备工作的状态和测量仪器的实时数据。由于测量仪器岁数比较大,不够智能,遂打算将视频转换为逐帧图片进行分析。 网上没找到现成工具,借鉴网上大神的操作方式打算用python写一个工具。 问题描述 用…

Mybatis中collection和association的区别

在MyBatis中,如果我们想对一对一或者一对多的多表进行查询,该如何处理呢? MyBatis提供了下面两个标签来处理一对一、多对一、一对多的映射关系: association: 处理一对一、多对一 collection: 处理一对多 一对一 每个人都有身份…

一、VIsual Studio下的Qt环境配置(Visual Studio 2022 + Qt 5.12.10)

一、下载编译器Visual Studio2022和Qt 5.12.10 Visual Studio 2022 社区版就够学习使用了 Qt5.12.10 安装教程网上搜,一大堆 也很简单,配置直接选默认,路径留意一下即可 二、配置环境 Ⅰ,配置Qt环境变量 系统变量下的Path&a…

【DevOps】深入解析 Docker日志分析和服务故障排除技巧

在今天的云计算和微服务架构中,Docker凭借其轻量级和高效的容器化技术,已成为软件部署不可或缺的一部分。然而,随着应用复杂性的增加,有效的日志管理和故障排除能力成为了开发者和运维人员必须掌握的核心技能。本文将带你深入探索…

AI办公自动化-用kimi自动清理删除重复文件

在kimichat中输入提示词: 你是一个Python编程专家,要完成一个编写Python脚本的任务,具体步骤如下: 1、打开文件夹D:\downloads; 2、哈希值比较比较里面所有的文件,如果文件相同,那么移动多余…

鲲鹏服务器ARM硬件的高字节忽略(Top Byte Ignore,TBI)

HWASan1 和HWASanIO2 等借助ARM的高字节忽略(Top Byte Ignore,TBI)的硬件特性,使用内存标记检测内存错误。TBI是指64位ARM机器中,程序64位地址中最高的字节被硬件忽略,实际的地址空间只有56位。我们在鲲鹏服…

群晖NAS本地搭建Bitwarden密码管理服务并实现远程同步密码托管

文章目录 1. 拉取Bitwarden镜像2. 运行Bitwarden镜像3. 本地访问4. 群晖安装Cpolar5. 配置公网地址6. 公网访问Bitwarden7. 固定公网地址8. 浏览器密码托管设置 Bitwarden是一个密码管理器应用程序,适用于在多个设备和浏览器之间同步密码。自建密码管理软件bitwarde…

华为配置智能无损网络综合

配置智能无损网络综合示例 适用产品和版本 安装了P系列单板的CE16800、CE6866、CE6866K、CE8851-32CQ8DQ-P、CE8851K系列交换机V300R020C00或更高版本。 安装了SAN系列单板的CE16800、CE6860-SAN、CE8850-SAN系列交换机V300R020C10或更高版本。 CE6860-HAM、CE8850-HAM系列交换…

初识FlaskMySQL实现前后端通信 全栈开发之路——后端篇(1)

全栈开发一条龙——前端篇 第一篇:框架确定、ide设置与项目创建 第二篇:介绍项目文件意义、组件结构与导入以及setup的引入。 第三篇:setup语法,设置响应式数据。 第四篇:数据绑定、计算属性和watch监视 第五篇 : 组件…

02-WPF_基础(二)

3、控件学习 控件学习 布局控件: panel、Grid 内容空间:Context 之恶能容纳一个控件或布局控件 代表提内容控件:内容控件可以设置标题 Header 父类:HeaderContextControl。 条目控件:可以显示一列数据&#xf…

CircleCI的原理及应用详解(二)

本系列文章简介: 在当今快速发展的软件开发环境中,如何确保代码质量、提升开发效率以及快速响应市场需求成为了每个开发团队面临的重要挑战。为了解决这些问题,持续集成和持续部署(CI/CD)工具应运而生,它们…

前端面试题大合集4----框架篇(React)

一、React 合成事件 Dom事件流分三个阶段&#xff1a;事件捕获阶段&#xff0c;目标阶段&#xff0c;事件冒泡阶段 React在事件绑定时有一套自身的机制&#xff0c;就是合成事件。如下比较直观&#xff1a; react中事件绑定&#xff1a; <div className"dome" …

如何解决3D模型变黑或贴图不显示的问题---模大狮模型网

在进行3D建模和视觉渲染时&#xff0c;经常会遇到模型表面变黑或贴图不显示的问题&#xff0c;这可能严重影响最终视觉效果的质量。这些问题通常与材质设置、光照配置或文件路径错误有关。本文将探讨几种常见原因及其解决方法&#xff0c;帮助3D艺术家和开发者更有效地处理这些…

Java | Leetcode Java题解之第88题合并两个有序数组

题目&#xff1a; 题解&#xff1a; class Solution {public void merge(int[] nums1, int m, int[] nums2, int n) {int p1 m - 1, p2 n - 1;int tail m n - 1;int cur;while (p1 > 0 || p2 > 0) {if (p1 -1) {cur nums2[p2--];} else if (p2 -1) {cur nums1[p…

Hive表数据优化

Hive表数据优化 1.文件格式 为Hive表中的数据选择一个合适的文件格式&#xff0c;对提高查询性能的提高是十分有益的。 &#xff08;1&#xff09;Text File 文本文件是Hive默认使用的文件格式&#xff0c;文本文件中的一行内容&#xff0c;就对应Hive表中的一行记录。 可…