联邦学习在non-iid数据集上的划分和训练——从零开始实现

虽然网上已经有了很多关于Dirichlet分布进行数据划分的原理和方法介绍,但是整个完整的联邦学习过程还是少有人分享。今天就从零开始实现

加载FashionMNIST数据集

import torch
from torchvision import datasets, transforms# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 加载训练和测试数据集
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

定义Dirichlet分布的划分函数

这里的写法是其中一种,也可以参考其它大神的写法。
具体Dirichlet划分的原理也可以参考下面的博客:
联邦学习:按Dirichlet分布划分Non-IID样本 - orion-orion - 博客园 (cnblogs.com)

import numpy as npdef dirichlet_distribution_noniid(dataset, num_clients, alpha):# 获取每个类的索引class_indices = [[] for _ in range(10)]for idx, (image, label) in enumerate(dataset):class_indices[label].append(idx)# 使用Dirichlet分布进行数据划分client_indices = [[] for _ in range(num_clients)]for class_idx in class_indices:np.random.shuffle(class_idx)proportions = np.random.dirichlet([alpha] * num_clients)proportions = (np.cumsum(proportions) * len(class_idx)).astype(int)[:-1]client_split = np.split(class_idx, proportions)for client_idx, client_split_indices in enumerate(client_split):client_indices[client_idx].extend(client_split_indices)return client_indices

将数据集划分给各客户端

这里的代码操作核心在于,对数据加载器DataLoader中的Subset的理解,这个函数是根据索引将数据集划分为子数据集,以前我知道它是在做什么,但是一直不太明白用法,最终在ChatGPT的帮助下完成了:

num_clients = 10
alpha = 0.5 #non-iid程度的超参数,我喜欢用0.5和0.3
client_indices = dirichlet_distribution_noniid(train_dataset, num_clients, alpha)# 创建客户端数据加载器
from torch.utils.data import DataLoader, Subsetclient_loaders = [DataLoader(Subset(train_dataset, indices), batch_size=32, shuffle=True) for indices in client_indices]

定义模型、训练函数和测试函数

import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as pltclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.flatten(x)x = torch.relu(self.fc1(x))x = self.fc2(x)return xdef train(model, train_loader, criterion, optimizer, device, epochs=5):model.train()model.to(device)for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")def test(model, test_loader, device):model.eval()model.to(device)correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalreturn accuracy

进行训练并记录测试准确度

# 选择设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 创建模型和损失函数
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练和测试数据加载器
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)# 记录每轮测试准确度
test_accuracies = []# 在每个客户端上进行训练并测试
for i, client_loader in enumerate(client_loaders):print(f"Training on client {i+1}")train(model, client_loader, criterion, optimizer, device)accuracy = test(model, test_loader, device)test_accuracies.append(accuracy)print(f"Test Accuracy after client {i+1}: {accuracy:.4f}")# 绘制测试准确度变化图
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_clients + 1), test_accuracies, marker='o')
plt.title('Test Accuracy after Training on Each Client')
plt.xlabel('Client')
plt.ylabel('Test Accuracy')
plt.ylim(0, 1)
plt.grid(True)
plt.show()

一些踩过的坑

Expected more than 1 value per channel when training, got input size torch.Size

解决方案

这里可能是当UE数量让数据集没法整除的时候,出现了多余的batch。
设置 batch_size>1, 且 drop_last=True

 DataLoader(train_set, batch_size=args.train_batch_size,num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler)

RuntimeError: output with shape [1, 28, 28] doesn’t match the broadcast shape [3, 28, 28]

错误是因为图片格式是灰度图只有一个channel,需要变成RGB图才可以,所以需要在对图片的处理transforms里面修改:

transform = transforms.Compose([transforms.ToTensor(),transforms.Lambda(lambda x: x.repeat(3,1,1)),# 增加这一行transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])   

运行结果

将以上的代码拼接起来,就能够正常跑起来,我也已经在自己的电脑上验证过了。
image.png
image.png

当然了,上面画的是一次epoch的各个client的准确度,进行多次epoch的训练可以自己再修改。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/23302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nvidia/算能 +FPGA+AI大算力边缘计算盒子:医疗健康智能服务

北京天星医疗股份有限公司(简称“天星医疗”)作为国产运动医学的领导者,致力于提供运动医学的整体临床解决方案,公司坐落于北京经济技术开发区。应用于肩关节、膝关节、足/踝关节、髋关节、肘关节、手/腕关节的运动医学设备、植入物和手术器械共计300多个…

Python Flask 入门开发

Python基础学习: Pyhton 语法基础Python 变量Python控制流Python 函数与类Python Exception处理Python 文件操作Python 日期与时间Python Socket的使用Python 模块Python 魔法方法与属性 Flask基础学习: Python中如何选择Web开发框架?Pyth…

个人笔记-python生成gif

使用文件的修改时间戳进行排序 import os import re import imageio# 设置图片所在的文件夹路径 folder_path /home/czy/ACode/AMAW_20240219/9.3.x(Discrete_time_marching)/9.3.17.11.1(Disc_concessive_CH_ZJ)/current_figures # 文件夹路径;linux…

电商核心技术系列58:电商平台的智能数据分析与业务洞察

相关系列文章 电商技术揭秘相关系列文章合集(1) 电商技术揭秘相关系列文章合集(2) 电商技术揭秘相关系列文章合集(3) 电商核心技术揭秘56:客户关系管理与忠诚度提升 电商核心技术揭秘57:数…

【Python系列】Python 方法变量参数详解

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

白酒:茅台镇白酒的醇厚口感与细腻层次

茅台镇,中国白酒的璀璨明珠,以其与众不同的自然环境和杰出的酿造技艺,孕育出了无数上好白酒。云仓酒庄豪迈白酒作为茅台镇的杰出品牌,以其醇厚口感和细腻层次,赢得了无数消费者的喜爱。 茅台镇地处赤水河畔&#xff0c…

训练集和测试集的分布一致性分析

规律一致性分析的实际作用   在实际建模过程中,规律一致性分析是非常重要但又经常容易被忽视的一个环节。通过规律一致性分析,我们可以得出非常多的可用于后续指导后续建模的关键性意见。通常我们可以根据规律一致性分析得出以下基本结论: …

ai写作神器app有哪些?好用的智能写作APP推荐

ai写作神器app有哪些?AI写作神器app在现代写作领域正迅速崭露头角,它们不仅极大提升了创作效率,而且通过集成前沿的人工智能技术,为创作者们提供了前所未有的便利。这些app能够智能分析写作需求,快速生成高质量的内容&…

十五、【源码】动态Sql

源码地址:https://github.com/mybatis/mybatis-3/ 仓库地址:https://gitcode.net/qq_42665745/mybatis/-/tree/15-dynamic-sql 动态Sql 解析动态Sql分为两部分 1.解析XML中Sql的时候,要将其解析成不同的SqlNode节点,但是不进行…

Jenkins的jdk和maven配置

目录 传送门前言一、概念二、JDK的配置三、Maven配置四、环境变量配置五、坑 传送门 SpringMVC的源码解析(精品) Spring6的源码解析(精品) SpringBoot3框架(精品) MyBatis框架(精品&#xff09…

「51媒体」媒体发布会如何做媒体邀约

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 媒体发布会的媒体邀约是一个需要精心策划和准备的过程。 策划与准备阶段: 明确目标:明确发布会的目标、核心议题、举办日期、时间和地点。 准备资料&#xff1a…

体验式营销+旅行文化:品牌海外推广的创新之路

在全球化的时代背景下,体验式营销作为一种新兴的营销方式,以其独特的参与性、互动性和情感共鸣,成为品牌与消费者建立深层次联系的有效手段。而将体验式营销与旅行文化相结合,能够为海外消费者提供独特的品牌体验。本文Nox聚星将和…

GD32单片机开发--点亮第一盏灯

知不足而奋进 望远山而前行 目录 系列文章目录 文章目录 前言 目标 内容 开发流程 需求分析 项目新建 代码编写 GPIO初始化 完整代码 程序编译 程序烧录 烧录扩展(熟悉) 官方烧录器烧录(熟悉) 总结 前言 在本次项…

C#操作MySQL从入门到精通(11)——对查询数据使用正则表达式过滤

前言 对于之前提到的使用匹配、比较、通配符等过滤方式能解决大部分的项目问题,但是有时候也会遇到一些比较复杂的过滤需求,这时候就需要正则表达式来实现了,正则表达式使用regexp这个关键字来实现。 本次测试的数据库表的内容如下: 1、基本字符匹配(包含某些字符) 匹…

嵌入式之存储基本知识

系列文章目录 嵌入式之存储基本知识 嵌入式之存储基本知识 系列文章目录一、RAM与ROM二、DRAM和SRAM三、SDRAM(DRAM的一种)四、DDR 一、RAM与ROM RAM(随机存取存储器)和ROM(只读存储器)是两种不同类型的计…

揭秘VVIC API:开启高效数据交互的密钥,你的项目就差这一步

VVIC API接口概述 VVIC API提供了对VVIC服务的数据访问和操作功能。通过此API,开发者可以集成VVIC服务到他们的应用程序中,实现数据同步、用户认证、资源管理等功能。 点击获取key和secret API端点示例 用户认证 方法:POSTURL:/…

Nvidia Jetson/Orin +FPGA+AI大算力边缘计算盒子:无人机自主飞行软件平台

案例简介 北京泛化智能科技有限公司(gi)所主导开发的 Generalized Autonomy Aviation System (GAAS) 是为无人机以及城市空中交通 (UAM, Urban Air Mobility) 所设计的开源无人机自主飞行框架。通过 SLAM、路径规划和 Global Optimization Graph 等功能…

【Linux】(三)—— 文件管理和软件安装

文件管理 Linux的文件管理是系统管理中的核心部分,它涉及到如何组织、访问、修改和保护文件及目录结构。 目录 文件管理基本概念常用命令查看和切换目录创建文件和目录删除文件和目录文件拷贝移动和重命名文件文件查看cat文件查看more查找文件查找文本 数据流和管道…

redsystems教程的基本使用之重置密码(忘记密码解决方法)

前言: 相信很多人都有疑惑,要是我不记得密码怎么办?如果你登录了,点击更改密码后,还是要你填写登录密码才能修改。为了解决这问题,博主通过了钻研成功搞出来了!!!&#…

DS:数与二叉树的相关概念

欢迎来到Harper.Lee的学习世界!博主主页传送门:Harper.Lee的博客主页想要一起进步的uu可以来后台找我哦! 一、树的概念及其结构 1.1 树的概念亲缘关系 树是一种非线性的数据结构,它是由n(n>0)个有限节点…