【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR

【Python实现机器遗忘算法】复现2023年TNNLS期刊算法UNSIR

在这里插入图片描述

1 算法原理

Tarun A K, Chundawat V S, Mandal M, et al. Fast yet effective machine unlearning[J]. IEEE Transactions on Neural Networks and Learning Systems, 2023.

本文提出了一种名为 UNSIR(Unlearning with Single Pass Impair and Repair) 的机器遗忘框架,用于从深度神经网络中高效地卸载(遗忘)特定类别数据,同时保留模型对其他数据的性能。以下是算法的主要步骤:

1. 零隐私设置(Zero-Glance Privacy Setting)

  • 假设:用户请求从已训练的模型中删除其数据(例如人脸图像),并且模型无法再访问这些数据,即使是为了权重调整。
  • 目标:在不重新训练模型的情况下,使模型忘记特定类别的数据,同时保留对其他数据的性能。

2. 学习误差最大化噪声矩阵(Error-Maximizing Noise Matrix)

  • 初始化:随机初始化噪声矩阵 N,其大小与模型输入相同。

  • 优化目标:通过最大化模型对目标类别的损失函数来优化噪声矩阵 N。具体优化问题为:
    a r g N m i n E ( θ ) = − L ( f , y ) + λ ∥ w n o i s e ∥ argNminE(θ)=−L(f,y)+λ∥wnoise∥ argNminE(θ)=L(f,y)+λwnoise

    其中:

    • L(f,y) 是针对要卸载的类别的分类损失函数。
    • λ∥wnoise∥ 是正则化项,防止噪声值过大。
    • 使用交叉熵损失函数 L 和 L2 归一化。
  • 噪声矩阵的作用:生成的噪声矩阵 N 与要卸载的类别标签相关联,用于在后续步骤中破坏模型对这些类别的记忆。

3. 单次损伤与修复(Single Pass Impair and Repair)

  • 损伤步骤(Impair Step)
    • 操作:将噪声矩阵 N 与保留数据子集Dr结合,训练模型一个周期(epoch)。
    • 目的:通过高学习率(例如 0.02)快速破坏模型对要卸载类别的权重。
    • 结果:模型对要卸载类别的性能显著下降,同时对保留类别的性能也会受到一定影响。
  • 修复步骤(Repair Step)
    • 操作:仅使用保留数据子集 Dr再次训练模型一个周期(epoch),学习率较低(例如 0.01)。
    • 目的:恢复模型对保留类别的性能,同时保持对要卸载类别的遗忘效果。
    • 结果:最终模型在保留数据上保持较高的准确率,而在卸载数据上准确率接近于零。

2 Python代码实现

相关函数

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset,TensorDataset
from torch.amp import autocast, GradScaler  
import numpy as np
import matplotlib.pyplot as plt
import os
import warnings
import random
from copy import deepcopy
random.seed(2024)
torch.manual_seed(2024)
np.random.seed(2024)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Falsewarnings.filterwarnings("ignore")
MODEL_NAMES = "MLP"
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 定义三层全连接网络
class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.fc1 = nn.Linear(28 * 28, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 加载MNIST数据集
def load_MNIST_data(batch_size,forgotten_classes,ratio):transform = transforms.Compose([transforms.ToTensor()])train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)forgotten_train_data,_ = generate_subset_by_ratio(train_data, forgotten_classes,ratio)retain_train_data,_ = generate_subset_by_ratio(train_data, [i for i in range(10) if i not in forgotten_classes])forgotten_train_loader= DataLoader(forgotten_train_data, batch_size=batch_size, shuffle=True)retain_train_loader= DataLoader(retain_train_data, batch_size=batch_size, shuffle=True)return train_loader, test_loader, retain_train_loader, forgotten_train_loader# worker_init_fn 用于初始化每个 worker 的随机种子
def worker_init_fn(worker_id):random.seed(2024 + worker_id)np.random.seed(2024 + worker_id)
def get_transforms():train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 标准化为[-1, 1]])test_transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # 标准化为[-1, 1]])return train_transform, test_transform
# 模型训练函数
def train_model(model, train_loader, criterion, optimizer, scheduler=None,use_fp16 = False):use_fp16 = True# 使用新的初始化方式:torch.amp.GradScaler("cuda")scaler = GradScaler("cuda")  # 用于混合精度训练model.train()running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)# 前向传播with autocast(enabled=use_fp16, device_type="cuda"):  # 更新为使用 "cuda"outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化optimizer.zero_grad()if use_fp16:scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()else:loss.backward()optimizer.step()running_loss += loss.item()if scheduler is not None:# 更新学习率scheduler.step()print(f"Loss: {running_loss/len(train_loader):.4f}")
# 模型评估(计算保留和遗忘类别的准确率)
def test_model(model, test_loader, forgotten_classes=[0]):"""测试模型的性能,计算总准确率、遗忘类别准确率和保留类别准确率。:param model: 要测试的模型:param test_loader: 测试数据加载器:param forgotten_classes: 需要遗忘的类别列表:return: overall_accuracy, forgotten_accuracy, retained_accuracy"""model.eval()correct = 0total = 0forgotten_correct = 0forgotten_total = 0retained_correct = 0retained_total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 计算总的准确率total += labels.size(0)correct += (predicted == labels).sum().item()# 计算遗忘类别的准确率mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=device))forgotten_total += mask_forgotten.sum().item()forgotten_correct += (predicted[mask_forgotten] == labels[mask_forgotten]).sum().item()# 计算保留类别的准确率(除遗忘类别的其他类别)mask_retained = ~mask_forgottenretained_total += mask_retained.sum().item()retained_correct += (predicted[mask_retained] == labels[mask_retained]).sum().item()overall_accuracy = correct / totalforgotten_accuracy = forgotten_correct / forgotten_total if forgotten_total > 0 else 0retained_accuracy = retained_correct / retained_total if retained_total > 0 else 0# return overall_accuracy, forgotten_accuracy, retained_accuracyreturn  round(overall_accuracy, 4), round(forgotten_accuracy, 4), round(retained_accuracy, 4)

主函数

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from models.Base import load_MNIST_data, test_model, load_CIFAR100_data, init_modelclass UNSIRForget:def __init__(self, model):self.model = modelself.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 学习误差最大化噪声矩阵def learn_error_maximizing_noise(self, train_loader, forgotten_classes, lambda_reg=0.01, learning_rate=0.01, num_epochs=5):self.model.eval()# 初始化噪声矩阵 N,大小与输入图像相同(例如28x28图像)noise_matrix = torch.randn(1, 1, 28, 28, device=self.device, requires_grad=True)  # 假设输入是28x28的图像# 优化器用于优化噪声矩阵optimizer = torch.optim.SGD([noise_matrix], lr=learning_rate)noise_data = []noise_labels = []# 生成噪声数据集for epoch in range(num_epochs):total_loss = 0.0for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)# 只对属于遗忘类别的数据进行优化mask_forgotten = torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))noisy_images = images.clone()# 对遗忘类别的图像添加噪声noisy_images[mask_forgotten] += noise_matrix# 保存噪声数据noise_data.append(noisy_images)noise_labels.append(labels)# 前向传播outputs = self.model(noisy_images.view(-1, 28 * 28))  # 假设模型的输入是28x28的图像loss = F.cross_entropy(outputs, labels)# L2 正则化项(噪声矩阵的L2范数)l2_reg = lambda_reg * torch.norm(noise_matrix)# 总损失(包含交叉熵损失和L2正则化)total_loss = loss + l2_reg# 反向传播并更新噪声矩阵optimizer.zero_grad()total_loss.backward()optimizer.step()print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss.item():.4f}")# 返回包含噪声数据和标签的噪声数据集return torch.cat(noise_data), torch.cat(noise_labels), noise_matrix.detach()# 实现机器遗忘(针对特定类别,使用噪声矩阵进行干扰)def unlearn(self, train_loader, forgotten_classes, noise_data, noise_labels, noise_matrix, alpha_impair, alpha_repair, num_epochs=1):# 损伤步骤self.model.train()print("执行损伤中...")for epoch in range(num_epochs):for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)# 仅选择保留类别的数据mask_retained = ~torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))retained_images = images[mask_retained]retained_labels = labels[mask_retained]# 生成新的数据集,将噪声数据添加到保留数据中augmented_images = torch.cat([retained_images, noise_data], dim=0)augmented_labels = torch.cat([retained_labels, noise_labels], dim=0)# 前向传播outputs = self.model(augmented_images.view(-1, 28 * 28))  # 假设模型的输入是28x28的图像loss = F.cross_entropy(outputs, augmented_labels)# 更新模型权重self.model.zero_grad()loss.backward()with torch.no_grad():for param in self.model.parameters():param.data -= alpha_impair * param.grad.data# 修复步骤print("执行修复中...")for epoch in range(num_epochs):for images, labels in train_loader:images, labels = images.to(self.device), labels.to(self.device)# 仅使用保留类别的数据进行修复mask_retained = ~torch.isin(labels, torch.tensor(forgotten_classes, device=self.device))retained_images = images[mask_retained]retained_labels = labels[mask_retained]if retained_images.size(0) == 0:continue# 前向传播和损失计算outputs = self.model(retained_images.view(-1, 28 * 28))loss = F.cross_entropy(outputs, retained_labels)# 更新模型权重self.model.zero_grad()loss.backward()with torch.no_grad():for param in self.model.parameters():param.data -= alpha_repair * param.grad.datareturn self.model# UNSIR算法的主要流程
def unsir_unlearning(model_before, retrain_data, forget_data, all_data, forgotten_classes, lambda_reg=0.01, learning_rate=0.01, alpha_impair=0.5, alpha_repair=0.001, num_epochs=5):"""执行 UNSIR 算法的主要流程,包括学习误差最大化噪声矩阵、损伤、修复步骤,最终返回遗忘后的模型。"""unsir_forgetter = UNSIRForget(model_before)# 计算学习误差最大化噪声矩阵noise_data, noise_labels, noise_matrix = unsir_forgetter.learn_error_maximizing_noise(all_data, forgotten_classes, lambda_reg, learning_rate, num_epochs)# 执行 unlearn(损伤与修复步骤)unlearned_model = unsir_forgetter.unlearn(all_data, forgotten_classes, noise_data, noise_labels, noise_matrix, alpha_impair, alpha_repair, num_epochs)return unlearned_modeldef main():# 超参数设置batch_size = 256forgotten_classes = [0]ratio = 1model_name = "MLP"# 加载数据train_loader, test_loader, retain_loader, forget_loader = load_MNIST_data(batch_size, forgotten_classes, ratio)model_before = init_model(model_name, train_loader)# 在训练之前测试初始模型准确率overall_acc_before, forgotten_acc_before, retained_acc_before = test_model(model_before, test_loader)print("执行 UNSIR 遗忘...")model_after = unsir_unlearning(model_before,retain_loader,forget_loader,train_loader,forgotten_classes,lambda_reg=0.01,learning_rate=0.01,alpha_impair=0.5,alpha_repair=0.001,num_epochs=5,)# 测试遗忘后的模型overall_acc_after, forgotten_acc_after, retained_acc_after = test_model(model_after, test_loader)# 输出遗忘前后的准确率变化print(f"Unlearning前遗忘准确率: {100 * forgotten_acc_before:.2f}%")print(f"Unlearning后遗忘准确率: {100 * forgotten_acc_after:.2f}%")print(f"Unlearning前保留准确率: {100 * retained_acc_before:.2f}%")print(f"Unlearning后保留准确率: {100 * retained_acc_after:.2f}%")if __name__ == "__main__":main()

3 总结

当前方法不支持随机样本或类别子集的卸载,这可能违反零隐私假设。

仍属于重新优化的算法,即还需要训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/893852.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

知识管理平台在企业信息化建设中的应用价值与未来展望

内容概要 在当今信息化时代,企业面临着海量信息的挑战,知识管理平台因此应运而生,成为企业提升管理效率和决策能力的关键工具。知识管理平台不仅仅是一个信息存储的工具,它集成了信息共享、协作与创新、决策支持等多项功能。通过…

MiniHack:为强化学习研究提供丰富而复杂的环境

人工智能咨询培训老师叶梓 转载标明出处 想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具(限时免费)。 1小时实战课程,您将学习到如何轻松上手并有效利用 Llama Facto…

从AD的原理图自动提取引脚网络的小工具

这里跟大家分享一个我自己写的小软件,实现从AD的原理图里自动找出网络名称和引脚的对应。存成文本方便后续做表格或是使用简单行列编辑生成引脚约束文件(如.XDC .UCF .TCL等)。 我们在FPGA设计中需要引脚锁定文件,就是指示TOP层…

MySQL分表自动化创建的实现方案(存储过程、事件调度器)

《MySQL 新年度自动分表创建项目方案》 一、项目目的 在数据库应用场景中,随着数据量的不断增长,单表存储数据可能会面临性能瓶颈,例如查询、插入、更新等操作的效率会逐渐降低。分表是一种有效的优化策略,它将数据分散存储在多…

HTML5使用favicon.ico图标

目录 1. 使用favicon.ico图标 1. 使用favicon.ico图标 favicon.ico一般用于作为网站标志,它显示在浏览器的地址栏或者标签上 制作favicon图标 选择一个png转ico的在线网站,这里以https://www.bitbug.net/为例。上传图片,目标尺寸选择48x48&a…

【C++动态规划 网格】2328. 网格图中递增路径的数目|2001

本文涉及知识点 C动态规划 LeetCode2328. 网格图中递增路径的数目 给你一个 m x n 的整数网格图 grid ,你可以从一个格子移动到 4 个方向相邻的任意一个格子。 请你返回在网格图中从 任意 格子出发,达到 任意 格子,且路径中的数字是 严格递…

fatal error C1083: ޷[特殊字符]ļ: openssl/opensslv.h: No such file or directory

一、环境 1. Visual Studio 2017 2. edk2:202305 3. Python:3.11.4 二、 fatal error C1083: ޷򿪰ļ: openssl/opensslv.h: No such file or directory 上图出现这个警告,不用管。 出现Done,说明编译成功。 执行上…

组件框架漏洞

一.基础概念 1.组件 定义:组件是软件开发中具有特定功能或特性的可重用部件或模块,能独立使用或集成到更大系统。 类型 前端 UI 组件:像按钮、下拉菜单、导航栏等,负责构建用户界面,提升用户交互体验。例如在电商 AP…

隐藏字符造成的linux命令执行失败(非常难绷)

隐藏字符问题发生情景 事情是这样的,为了方便主机和虚拟机之间数据的传输,我打算建一个共享文件夹。由于我选择的是手动挂载,在VirtualBox 中创建好共享文件夹后,我着手打开Ubuntu,想将这个共享文件夹挂载到我的家目录…

C/C++ 虚函数

虚函数的定义 虚函数是指在基类内部声明的成员函数前面添加关键字 virtual 指明的函数虚函数存在的意义是为了实现多态,让派生类能够重写(override)其基类的成员函数派生类重写基类的虚函数时,可以添加 virtual 关键字,但不是必须这么做虚函…

爬虫基础之爬取某基金网站+数据分析

声明: 本案例仅供学习参考使用,任何不法的活动均与本作者无关 网站:天天基金网(1234567.com.cn) --首批独立基金销售机构-- 东方财富网旗下基金平台! 本案例所需要的模块: 1.requests 2.re(内置) 3.pandas 4.pyecharts 其他均需要 pip install 模块名 爬取步骤: …

RKNN_C++版本-YOLOV5

1.背景 为了实现低延时,所以开始看看C版本的rknn的使用,确实有不足的地方,请指正(代码借鉴了rk官方的仓库文件)。 2.基本的操作流程 1.读取模型初始化 // 设置基本信息 // 在postprocess.h文件中定义,详见…

Learning Vue 读书笔记 Chapter 2

2. Vue 基本工作原理 2.1 Virtual DOM 概念: DOM: DOM以内存中树状数据结构的形式,代表了网页上的HTML(或XML)文档内容。它充当了一个编程接口,将网页与实际的编程代码(如JavaScript)连接起来…

【C++高并发服务器WebServer】-7:共享内存

本文目录 一、共享内存1.1 shmget函数1.2 shmat1.3 shmdt1.4 shmctl1.5 ftok1.6 共享内存和内存映射的关联1.7 小demo 二、共享内存操作命令 一、共享内存 共享内存允许两个或者多个进程共享物理内存的同一块区域(通常被称为段)。由于一个共享内存段会称…

CrypTen——基于pytorch的隐私保护机器学习框架

目录 一、CrypTen概述 二、应用场景 三、CrypTen优势 四、CrypTen技术解析 1.基于pytorch的构建基础 2.核心密码学原语 3.加密模型训练流程 五、传统隐私保护技术与CrypTen的对比 1.传统隐私保护技术介绍 2.CrypTen与传统隐私保护技术的区别 六、CrypTen的环境配置…

ES6 简单练习笔记--变量申明

一、ES5 变量定义 1.在全局作用域中 this 其实就是window对象 <script>console.log(window this) </script>输出结果: true 2.在全局作用域中用var定义一个变量其实就相当于在window上定义了一个属性 例如: var name "孙悟空" 其实就相当于执行了 win…

Arduino大师练成手册 -- 控制 PN532 NFC 模块

要在 Arduino 上控制 PN532 NFC 模块&#xff0c;你可以按照以下步骤进行&#xff1a; 硬件连接 VCC&#xff1a;连接到 Arduino 的 3.3V 引脚。 GND&#xff1a;连接到 Arduino 的 GND 引脚。 SDA&#xff1a;连接到 Arduino 的 SDA 引脚&#xff08;通常是 A4&#xff09…

python——Django 框架

Django 框架 1、简介 Django 是用python语言写的开源web开发框架&#xff0c;并遵循MVC设计。 Django的**主要目的是简便、快速的开发数据库驱动的网站。**它强调代码复用&#xff0c;多个组件可以很方便的以"插件"形式服务于整个框架&#xff0c;Django有许多功能…

大模型正确调用方式

1、ollama 安装 curl -fsSL https://ollama.com/install.sh | sh 如果是AutoDl服务器&#xff0c;可以开启学术加速。 source /etc/network_turbo 本次使用腾讯云Cloud Studio&#xff0c;所以已经安装好了 Ollama 2、启动 ollama run 模型的名字 ollama serve # 开启服务 olla…

CE-PBFT:大规模联盟区块链的高可用一致性算法

摘要 区块链已广泛应用于农产品溯源、供应链管理、物流运输等各个领域。作为联盟区块链不可缺少的组成部分&#xff0c;共识算法保证了网络中每个节点的一致性和可信度。然而&#xff0c;由于通信过程的复杂性&#xff0c;现有的大规模联盟区块链场景中的共识算法存在低系统吞…