paddle2.3-基于联邦学习实现FedAVg算法-CNN

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict

class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)

6. 模型训练

class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k]   weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

Round: 10... 	Average Loss: [0.024]
Round: 20... 	Average Loss: [0.015]
Round: 30... 	Average Loss: [0.008]
Round: 40... 	Average Loss: [0.003]
Round: 50... 	Average Loss: [0.004]
Round: 60... 	Average Loss: [0.002]
Round: 70... 	Average Loss: [0.002]
Round: 80... 	Average Loss: [0.002]
Round: 90... 	Average Loss: [0.001]
Round: 100... 	Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/92615.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自己动手写编译器:实现命令行模块

在前面一系列章节中,我们完成了词法解析的各种算法。包括解析正则表达式字符串,构建 NFA 状态就,从 NFA 转换为 DFA 状态机,最后实现状态机最小化,接下来我们注重词法解析模块的工程化实现,也就是我们将所有…

【信创】麒麟v10(arm)-mysql8-mongo-redis-oceanbase

Win10/Win11 借助qume模拟器安装arm64麒麟v10 前言 近两年的国产化进程一直在推进,基于arm架构的国产系统也在积极发展,这里记录一下基于麒麟v10arm版安装常见数据库的方案。 麒麟软件介绍: 银河麒麟高级服务器操作系统V10 - 国产操作系统、银河麒麟、中…

树概念及结构

.1树的概念 树是一种非线性的数据结构,它是由n(n>0)个有限结点组成一个具有层次关系的集合。把它叫做树是因 为它看起来像一棵倒挂的树,也就是说它是根朝上,而叶朝下的。 有一个特殊的结点,称为根结点&a…

springcloud:四、nacos介绍+启动+服务分级存储模型/集群+NacosRule负载均衡

nacos介绍 nacos是阿里巴巴提供的SpringCloud的一个组件,算是eureka的替代品。 nacos启动 安装过程这里不再赘述,相关安装或启动的问题可以见我的另一篇博客: http://t.csdn.cn/tcQ76 单价模式启动命令:进入bin目录&#xff0…

14:00面试测试岗,14:06就出来了,问的问题有点变态。。。

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到9月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%,…

Kotlin异常处理runCatching,getOrNull,onFailure,onSuccess(1)

Kotlin异常处理runCatching&#xff0c;getOrNull&#xff0c;onFailure&#xff0c;onSuccess&#xff08;1&#xff09; fun main(args: Array<String>) {var s1 runCatching {1 / 1}.getOrNull()println(s1) //s11&#xff0c;打印1println("-")var s2 ru…

Academic accumulation|英文文献速读

一、英文文献速读法 &#xff08;一&#xff09;明确目的 建议大家阅读一篇论文之前先问一下自己是出于怎样的目的来阅读这篇文章&#xff0c;是为了找选题方向、学某个问题的研究设计、学某种研究方法、学文章写作还是别的。不同的阅读目的会导致不同的关注重点&#xff0c;例…

嵌入式学习笔记(41)SD卡启动详解

内存和外存的区别&#xff1a;一般是把这种RAM(random access memory,随机访问存储器&#xff0c;特点是任意字节读写&#xff0c;掉电丢失)叫内存&#xff0c;把ROM&#xff08;read only memory&#xff0c;只读存储器&#xff0c;类似于Flash SD卡之类的&#xff0c;用来存储…

建站软件WordPress和phpcms体验

一、网站程序 什么是网站程序? 网站程序是由程序员编写的一个网站安装包,程序是网站内容的载体。 常见的网站程序有: dedecms , phpcms ,帝国cms ,米拓cms , WordPress , discuz , ECShop ,shopex , z-blog等,根据不同类型的网站我们来选择不同的网站程序。 比如说搭建一个…

【生物信息学】基因差异分析Deg(数据读取、数据处理、差异分析、结果可视化)

目录 一、实验介绍 二、实验环境 1. 配置虚拟环境 2. 库版本介绍 3. IDE 三、实验内容 0. 导入必要的工具包 1. 定义一些阈值和参数 2. 读取数据 normal_data.csv部分展示 tumor_data.csv部分展示 3. 绘制箱型图 4. 删除表达量低于阈值的基因 5. 计算差异显著的基…

成都瀚网科技:抖音上线地方方言自动翻译功能

为了让很多方言的地域历史、文化、习俗能够以短视频的形式生产、传播和保存&#xff0c;解决方言难以被更多用户阅读和理解的问题&#xff0c;平台正式上线推出当地方言自动翻译功能。创作者可以利用该功能&#xff0c;将多个方言视频“一键”转换为普通话字幕供大众观看。 具体…

【视频去噪】基于全变异正则化最小二乘反卷积是最标准的图像处理、视频去噪研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

Redis是否要分库的实践

Redis的分库其实没有带来任何效率上的提升&#xff0c;只是提供了一个命名空间&#xff0c;而这个命名空间可以完全通过key的设计来避开这个问题。 一个优雅的Redis的key的设计如下

Windows历史版本下载

1、微PE工具箱&#xff08;非广告本人常用&#xff09; 常用安装Windows系统的微PE工具 地址&#xff1a;https://www.wepe.com.cn/download.html 2、Windows系统下载地址&#xff08;非微软官方&#xff09; 地址&#xff1a;MSDN, 我告诉你 - 做一个安静的工具站 下载&…

【嵌入式】使用MultiButton开源库驱动按键并控制多级界面切换

目录 一 背景说明 二 参考资料 三 MultiButton开源库移植 四 设计实现--驱动按键 五 设计实现--界面处理 一 背景说明 需要做一个通过不同按键控制多级界面切换以及界面动作的程序。 查阅相关资料&#xff0c;发现网上大多数的应用都比较繁琐&#xff0c;且对于多级界面的…

并查集LRUCache

文章目录 并查集1.概念2. 实现 LRUCache1. 概念2. 实现使用标准库实现自主实现 并查集 1.概念 并查集是一个类似于森林的数据结构&#xff0c;并、查、集指的是多个不相干的集合直接的合并和查找&#xff0c;并查集使用于N个集合。适用于将多个元素分成多个集合&#xff0c;在…

[FineReport]安装与使用(连接Hive3.1.2)

一、安装(对应hive3.1.2) 注&#xff1a;服务器的和本地的要同时安装。本地是测试环境&#xff0c;服务器的是生产环境 1、服务器安装 1、下载 免费下载FineReport - FineReport报表官网 向下滑找到 2、解压 [rootck1 /home/data_warehouse/software]# tar -zxvf tomcat…

数据挖掘(1)概述

一、数据仓库和数据挖掘概述 1.1 数据仓库的产生 数据仓库与数据挖掘&#xff1a; 数据仓库和联机分析处理技术(存储)。数据挖掘&#xff1a;在大量的数据中心挖掘感兴趣的知识、规则、规律、模式、约束(分析)。数据仓库用于决策分析&#xff1a; 数据仓库&#xff1a;是在数…

机器学习算法基础--K-means应用实战--图像分割

目录 1.项目内容介绍 2.项目关键代码 3.项目效果展示 1.项目内容介绍 本项目是将一张图片进行k-means分类&#xff0c;根据色彩k进行分类&#xff0c;最后比较和原图的效果。 题目还是比较简单的&#xff0c;我们只要通过k-means聚类&#xff0c;一类就是一种色彩得出聚类之…

快速上手kettle(三)壶中可以放些啥?

序言 快速上手kettle开篇中,我们将kettle比作壶,并对这个壶做了简单介绍。 而上一期中我们实现了①将csv文件通过kettle转换成excel文件; ②将excel文件通过kettle写入到MySQL数据库表中 这两个案例。 相信大家跟我一样,对kettle已经有了初步认识,并且对这强大的工具产…