【对抗算法复现】CW

首先进行数据的预处理

transform = transforms.Compose([transforms.ToTensor(),  # 将图片转换为Tensor,自动将[0,255]映射到[0,1]transforms.Normalize((0.491,0.482 ,0.446), (0.247 ,0.243 ,0.261))  # 对张量进行标准化,使其范围为[-1,1]
])

CW实现

def cw_l2_attack(model,images,labels,targeted=True,c=0.1,kappa=0,max_iter=1000,learning_rate=0.01):# 计算损失函数,根据模型输出和目标标签计算一个分数,衡量模型输出的误导程度def f(x):# 论文中的 Z(X) 输出 batchsize, num_classesoutputs = model(x)#将标签转换为one-hot编码形式one_hot_labels = torch.eye(len(outputs[0]),device=device)[labels].to(device)# 水平方向最大的取值,忽略索引。意思是,除去真实标签,看看每个 batchsize 中哪个标签的概率最大,取出概率i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)# 选择真实标签的概率j = torch.masked_select(outputs, one_hot_labels.bool())# 如果有攻击目标,虚假概率减去真实概率,if targeted:#使模型对目标错误类别的置信度至少比真实类别高 kappa。return torch.clamp(i - j, min=-kappa)# 没有攻击目标,就让真实的概率小于虚假的概率,逐步降低,也就是最小化这个损失else:#降低模型对真实类别的置信度,使其至少低于虚假类别的概率 -kappa。return torch.clamp(j - i, min=-kappa)w = torch.zeros_like(images, requires_grad=True).to(device)#一个与输入图像相同大小的张量,用于存储对抗扰动,并设置True以便后续梯度下降optimizer = optim.Adam([w], lr=learning_rate)#定义优化器#prev = 1e10for step in range(max_iter):a = 1 / 2 * (nn.Tanh()(w) + 1)#扰动应用到图像上,对抗图像# 第一个目标,对抗样本与原始样本足够接近loss1 = nn.MSELoss(reduction='sum')(a, images)# 第二个目标,误导模型输出loss2 = torch.sum(c * f(a))cost = loss1 + loss2optimizer.zero_grad()cost.backward()optimizer.step()# 早停策略 如果连续迭代中损失没有改善,则提前停止攻击if step % (max_iter // 10) == 0:if cost > prev:print('Attack Stopped due to CONVERGENCE....')return aprev = costattack_images = 1 / 2 * (nn.Tanh()(w) + 1)#最终的对抗图像return attack_images

3.导入模型

print('load model')
model = ResNet50()
pth_file = '../checkpoint/resnet50_ckpt.pth'
d = torch.load(pth_file)['net']
d = OrderedDict([(k[7:], v) for (k, v) in d.items()])
model.load_state_dict(d)
model.to(device)
model.eval()

完整代码

import pickle
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import os
from tqdm import tqdm
from collections import OrderedDict
from resnet import ResNet50os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
device = torch.device("cuda")
#数据预处理# 图像预处理操作定义
transform = transforms.Compose([transforms.ToTensor(),  # 将图片转换为Tensor,自动将[0,255]映射到[0,1]transforms.Normalize((0.491,0.482 ,0.446), (0.247 ,0.243 ,0.261))  # 对张量进行标准化,使其范围为[-1,1]
])
class CIFAR10Dataset(Dataset):"""CIFAR-10数据集加载类,支持图像转换操作"""def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transformdef __len__(self):"""返回数据集中的图像总数"""return len(self.data)def __getitem__(self, idx):"""获取单个图像及其标签,并应用预定义的转换"""image = self.data[idx]label = self.labels[idx]if self.transform:image = self.transform(image)return image, label#CW aTTACK
targeted=3
def cw_l2_attack(model,images,labels,targeted=True,c=0.1,kappa=0,max_iter=1000,learning_rate=0.01):# 计算损失函数,根据模型输出和目标标签计算一个分数,衡量模型输出的误导程度def f(x):# 论文中的 Z(X) 输出 batchsize, num_classesoutputs = model(x)# batchszie,根据labels 的取值,确定每一行哪一个为 1# >>> a = torch.eye(10)[[2, 3]]# >>> a# tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],# [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]])#将标签转换为one-hot编码形式one_hot_labels = torch.eye(len(outputs[0]),device=device)[labels].to(device)# 水平方向最大的取值,忽略索引。意思是,除去真实标签,看看每个 batchsize 中哪个标签的概率最大,取出概率i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)# 选择真实标签的概率j = torch.masked_select(outputs, one_hot_labels.bool())# 如果有攻击目标,虚假概率减去真实概率,if targeted:#使模型对目标错误类别的置信度至少比真实类别高 kappa。return torch.clamp(i - j, min=-kappa)# 没有攻击目标,就让真实的概率小于虚假的概率,逐步降低,也就是最小化这个损失else:#降低模型对真实类别的置信度,使其至少低于虚假类别的概率 -kappa。return torch.clamp(j - i, min=-kappa)w = torch.zeros_like(images, requires_grad=True).to(device)#一个与输入图像相同大小的张量,用于存储对抗扰动,并设置True以便后续梯度下降optimizer = optim.Adam([w], lr=learning_rate)#定义优化器#prev = 1e10for step in range(max_iter):a = 1 / 2 * (nn.Tanh()(w) + 1)#扰动应用到图像上,对抗图像# 第一个目标,对抗样本与原始样本足够接近loss1 = nn.MSELoss(reduction='sum')(a, images)# 第二个目标,误导模型输出loss2 = torch.sum(c * f(a))cost = loss1 + loss2optimizer.zero_grad()cost.backward()optimizer.step()# 早停策略 如果连续迭代中损失没有改善,则提前停止攻击if step % (max_iter // 10) == 0:if cost > prev:print('Attack Stopped due to CONVERGENCE....')return aprev = costattack_images = 1 / 2 * (nn.Tanh()(w) + 1)#最终的对抗图像return attack_imagesprint('load model')
model = ResNet50()
pth_file = '../checkpoint/resnet50_ckpt.pth'
d = torch.load(pth_file)['net']
d = OrderedDict([(k[7:], v) for (k, v) in d.items()])
model.load_state_dict(d)
model.to(device)
model.eval()# 加载处理好的数据
test_data = np.load('../data/test_data.npy')
test_label = np.load('../data/test_labels.npy')# 实例化数据集
testset = CIFAR10Dataset(data=test_data, labels=test_label, transform=transform)# 创建数据加载器
testloader = DataLoader(testset, batch_size=200, shuffle=False)# CIFAR-10的类别标签
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')save_data=None
print("cw attack...")
for data, label in tqdm(testloader):data, label = data.to(device), label.to(device)adv_data = cw_l2_attack(model=model, images=data, labels=label)if save_data is None:save_data = adv_data.detach_().cpu().numpy()else:save_data = np.concatenate((save_data,adv_data.detach_().cpu().numpy()), axis=0)# 定义保存文件的路径
file_path = 'result/cw2_cifar10_2.npy'# 确保目录存在
if not os.path.exists(os.path.dirname(file_path)):os.makedirs(os.path.dirname(file_path))# 现在可以安全地保存文件
np.save(file_path, save_data)
print('cw2_cifar10_2 has been saved')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/44490.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“十四五”新型基础设施建设

一、基础设施 基础设施包括交通设施、邮电通讯设施、能源动力设施、供水排水设施、环保设施、防卫防灾安全设施等传统基础设施。这些设施共同构成了一个国家或地区正常运转的支柱,确保社会经济活动的正常进行。 交通设施:交通设施是基础设施的重要组成部…

AbyssFish单连通周期边界多孔结构2D软件

软件介绍 AbyssFish单连通周期边界多孔结构2D软件(以下简称软件)可用于生成具备周期性边界条件的单连通域多孔结构PNG图片,软件可设置生成模型的尺寸、孔隙率、孔隙尺寸、孔喉尺寸等参数,并且具备孔隙形态控制功能。 软件生成的…

视频号热门视频数据分析工具,快速查看同行数据创作者必看!

每天排行榜是帮助创作者查看同行数据为自己提供创作灵感,此外每天排行榜热门的视频收集了用户喜欢看的类型。 灵感分类了解当前社会关注的热点内容。该工具通过监测和分析视频号全网舆情,选取热门话题进行排序,形成一个每日热点排行榜。 这…

yolov8 分类太阳能板

原文:yolov8 分类太阳能板 - 知乎 (zhihu.com) 1、数据集 https://github.com/zae-bayern/elpv-dataset​github.com/zae-bayern/elpv-dataset 2、数据分析 import matplotlib.pyplot as plt import ostrain_dir = "./images" valid_extensions=(.jpg, .png, .j…

华为防火墙上的配置(1)

实验拓扑图 实验要求: 1、DMZ区内的服务器,生产区仅能在办公时间内(9:00-18:00)可以访问,办公区的设备全天可以访问 2、生产区不允许访问互联网,办公区和游客区允许访问互联网 3、办公区设备10.0.2.10不…

00:HAL库的认识

一:HAL库 开发现状: 1:下载 网站: https://www.st.com/zh/embedded-software/stm32cube-mcu-mpu-packages.html 去选择我们的系列 我们使用的是STM32F103C8t6的这个 继续一直向下拉点击这个;之后傻瓜步骤直接可以…

最新2023年行政区划、路网、土壤质地矢量数据

行政区划矢量数据是指用矢量格式表示的地理信息系统(GIS)数据,其中包含了行政区域的边界信息,如国家、省份、城市、区县、乡镇甚至村级的界限。这些数据通常以点、线、面的几何图形来表示具体的地理实体,并且每个实体都…

亚马逊erp跟卖采集之关键词采集

大家好,今天讲这款erp的跟卖采集关键词采集。 打开erp跟卖功能采集任务,点新增任务站点美国,有5种采集方式:关键词、店铺链接、类目ASIN。 选择关键词采集,这里我选择女童装,选择女童板鞋复制粘贴。页数我…

新书速览|HTML5+CSS3 Web前端开发与实例教程:微课视频版

《HTML5CSS3 Web前端开发与实例教程:微课视频版》 本书内容 《HTML5CSS3 Web前端开发与实例教程:微课视频版》秉承“思政引领,立德树人”的教育理念,自然融入多维度、深层次的思政元素,全面对标企业和行业需求&#x…

Chameleon:动态UI框架使用详解

文章目录 引言Chameleon框架原理核心概念工作流程 基础使用安装与配置创建基础界面 高级使用自定义组件响应式布局数据流与状态管理 结论 引言 Chameleon,作为一种动态UI框架,旨在通过灵活、高效的方式帮助开发者构建跨平台、响应用户交互的图形用户界面…

ant-design-vue表格设置某列标题部分文字颜色

在ant-design-vue的表格组件中&#xff0c;可以通过使用slot自定义列头&#xff08;title&#xff09;的内容来实现部分文字的颜色设置。以下是一个简单的例子&#xff0c;展示如何设置某列标题部分文字颜色为红色&#xff1a; <template><a-table :columns"col…

iwconfig iwpriv学习之路

iwconfig和iwpriv是两个常用的wifi调试工具&#xff0c;最近需要使用这两个工具完成某款wifi芯片的定频测试&#xff0c;俗话说好记性不如烂笔头&#xff0c;于是再此记录下iwconfig和iwpriv的使用方式。 -----再牛逼的梦想&#xff0c;也抵不住傻逼般的坚持&#xff01; ----2…

单向链表队列

实现单向链表队列的&#xff0c;创建&#xff0c;入队&#xff0c;出队&#xff0c;遍历&#xff0c;长度&#xff0c;销毁。 queue.h #ifndef __QUEUE_H__ #define __QUEUE_H__#include <stdio.h> #include <stdlib.h> #include <string.h> #define max 30…

大语言模型里的微调vs RAG vs 模板提示词

文章目录 介绍微调&#xff08;Fine-tuning&#xff09;定义优点&#xff1a;缺点&#xff1a;应用场景&#xff1a;技术细节 检索增强生成&#xff08;RAG&#xff0c;Retrieval-Augmented Generation&#xff09;定义优点&#xff1a;缺点&#xff1a;应用场景&#xff1a;技…

鸿蒙开发:Universal Keystore Kit(密钥管理服务)【密钥派生(ArkTS)】

密钥派生(ArkTS) 以HKDF256密钥为例&#xff0c;完成密钥派生。具体的场景介绍及支持的算法规格。 开发步骤 生成密钥 指定密钥别名。 初始化密钥属性集&#xff0c;可指定参数HUKS_TAG_DERIVED_AGREED_KEY_STORAGE_FLAG&#xff08;可选&#xff09;&#xff0c;用于标识基…

jvm 06 补充 OOM 和具体工具使用

1.OOM 是什么 OOM&#xff0c;全称“Out Of Memory”&#xff0c;翻译成中文就是“内存用完了”&#xff0c;来源于java.lang.OutOfMemoryError。看下关于的官方说明&#xff1a; Thrown when the Java Virtual Machine cannot allocate an object because it is out of memor…

三角函数 积化和差、和差化积公式

积化和差公式 公式1 s i n A ⋅ s i n B − 1 2 [ c o s ( A B ) − c o s ( A − B ) ] \mathrm{sin}A\cdot\mathrm{sin}B-\dfrac{1}{2}[\mathrm{cos}(AB)-\mathrm{cos}(A-B)] sinA⋅sinB−21​[cos(AB)−cos(A−B)]. − 1 2 [ c o s ( A B ) − c o s ( A − B ) ] -\dfra…

电机学-绪论

绪论 电机&#xff1a;根据电磁感应定律和电磁力定律实现机电能量转换和信号传递与转换的电磁机械装置。 电磁感应定律&#xff1a; BiliBili: 法拉第电磁感应定律 BiliBili: 楞次定律 BiliBili: 左手定则、右手定则、右手螺旋定则

数据结构JAVA

1.数据结构之栈和队列 栈结构 先进后出 队列结构 先进先出 队列 2.数据结构之数组和链表 数组结构 查询快、增删慢 队列结构 查询慢、增删快 链表的每一个元素我们叫结点 每一个结点都是独立的对象

对于多个表多个字段进行查询、F12查看网页的返回数据帮助开发、数据库的各种查询方式(多对多、多表查询、子查询等)。

对于多个表多个字段进行查询、F12查看网页的返回数据帮助开发、数据库的各种查询方式&#xff08;多对多、多表查询、子查询等&#xff09;。 一、 前端界面需要展现多个表的其中几个数据的多表查询。1. 三个表查询其中字段返回&#xff1a;&#xff08;用一下sql语句&#xff…