demo_GAN

# 导入PyTorch库,这是一个用于深度学习的开源库
import torch
# 导入PyTorch的神经网络模块(nn),用于定义神经网络结构
import torch.nn as nn
# 导入PyTorch的函数式模块(functional),提供了一些常用的激活函数和损失函数等
import torch.nn.functional as F
# 导入PyTorch的优化器模块(optim),用于定义优化算法,如梯度下降等
import torch.optim as optim
# 从PyTorch的数据加载器模块中导入DataLoader和TensorDataset类,用于加载和处理数据集
from torch.utils.data import DataLoader, TensorDataset
# 从torchvision库的实用工具模块中导入save_image函数,用于保存生成的图像
from torchvision.utils import save_image
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
# 导入os模块,用于处理文件和目录操作
import os
import matplotlib.pyplot as plt# 自注意力机制模块定义
class SelfAttention(nn.Module):def __init__(self, in_dim):super(SelfAttention, self).__init__()self.query = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)self.key = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, C, width, height = x.size()query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)key = self.key(x).view(batch_size, -1, width * height)energy = torch.bmm(query, key)attention = F.softmax(energy, dim=-1)value = self.value(x).view(batch_size, -1, width * height)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, C, width, height)out = self.gamma * out + xreturn out# Generator Model定义了一个名为Generator的神经网络模型,它继承自PyTorch框架中的nn.Module类
class Generator(nn.Module):def __init__(self, noise_dim, label_dim):super(Generator, self).__init__()self.label_dim = label_dim# 定义了一个名为self.fc的神经网络层序列,含三个层,输入层:随机噪声和标签,批量归一化层,漏洞型relu层self.fc = nn.Sequential(nn.Linear(noise_dim + label_dim, 1024 * 2 * 2),nn.BatchNorm1d(1024 * 2 * 2),# 这是一个Leaky ReLU激活函数层,它的作用是将负数的输入值乘以一个小的常数(这里是0.2),然后将结果作为输出,在原始数据上进行操作nn.ReLU(inplace=True))# Hidden Layers: Deconv + BN + Leaky ReLUself.deconv_layers = nn.Sequential(nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # 2x2 -> 4x4nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64nn.BatchNorm2d(32),nn.ReLU(inplace=True),# 增加一层,扩展到128x128nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 64x64 -> 128x128nn.BatchNorm2d(16),nn.ReLU(inplace=True),nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1),  # 128x128 -> 128x128 (RGB)nn.Tanh()  # 输出层,范围[-1, 1])def forward(self, noise, labels):# 拼接噪声和标签x = torch.cat((noise, labels), dim=1)x = self.fc(x).view(-1, 1024, 2, 2)return self.deconv_layers(x)class Discriminator(nn.Module):def __init__(self, input_channels):super(Discriminator, self).__init__()# 第一层:并行卷积层(3×3和5×5卷积核),后续拼接self.conv1_3x3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1)),nn.ReLU(inplace=True))self.conv1_5x5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=5, stride=2, padding=2)),nn.ReLU(inplace=True),)# (2) conv + BN + leaky Relu (dilation rate 1)self.conv2 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1)),nn.ReLU(inplace=True),)# (3) conv + BN + leaky Relu + self-attention mechanismself.conv3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)),nn.ReLU(inplace=True),SelfAttention(256))# (4) conv + BN + leaky Relu (parallel 3x3, 5x5, and 7x7 kernels)self.conv4_3x3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)),nn.ReLU(inplace=True),)self.conv4_5x5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)),nn.ReLU(inplace=True),)self.conv4_7x7 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=7, stride=2, padding=3)),nn.ReLU(inplace=True),)# (5) conv + BN + leaky Relu (dilation rate 3)self.conv5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=1536, out_channels=1024, kernel_size=3, stride=1, padding=3, dilation=3)),nn.ReLU(inplace=True),)# (6) 用卷积层替换全连接层,输出1x1特征图,并使用sigmoid激活函数self.fc = nn.utils.spectral_norm(nn.Linear(1024 * 4 * 4, 1))self.sigmoid = nn.Sigmoid()def forward(self, x):x1 = self.conv1_3x3(x)x2 = self.conv1_5x5(x)x = torch.cat((x1, x2), dim=1)x = self.conv2(x)x = self.conv3(x)x1 = self.conv4_3x3(x)x2 = self.conv4_5x5(x)x3 = self.conv4_7x7(x)x = torch.cat((x1, x2, x3), dim=1)x = self.conv5(x)x = nn.AvgPool2d(2)(x)x = x.view(x.size(0), -1)  # Flatten the tensorx = self.fc(x)x = self.sigmoid(x)return x# 设置超参数
noise_dim = 100  # 噪声维度
label_dim = 58  # 标签维度
batch_size =64  # 批大小
learning_rate = 0.0001
num_epochs = 500  # 训练轮数
output_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/MMSGAN"  # 生成图像保存路径# 确保输出目录存在
if not os.path.exists(output_dir):os.makedirs(output_dir)# 创建生成器和判别器
G = Generator(noise_dim=noise_dim, label_dim=label_dim).to('cuda')
D = Discriminator(input_channels=3).to('cuda')# TrafficSignDataset类,用于数据加载
class TrafficSignDataset(Dataset):def __init__(self, root_dir, labels_file, transform=None):self.root_dir = root_dirself.transform = transformself.image_paths = []self.labels = []with open(labels_file, 'r') as f:lines = f.readlines()for line in lines:img_name, label = line.strip().split()img_path = os.path.join(root_dir, img_name)self.image_paths.append(img_path)self.labels.append(int(label))def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 损失函数和优化器
criterion = nn.BCELoss()  # 二元交叉熵损失
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate*4,betas=(0.5, 0.999),weight_decay=1e-4)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# 设置学习率衰减参数
decay = 0.0001
num_epochs = 500# 训练循环
for epoch in range(num_epochs):# ... 训练过程 ...# 更新学习率lr_new_G = learning_rate * 4 / (1 + decay * num_epochs)lr_new_D = learning_rate / (1 + decay * num_epochs)for param_group in optimizer_G.param_groups:param_group['lr'] = lr_new_Gfor param_group in optimizer_D.param_groups:param_group['lr'] = lr_new_D# 定义图像预处理和数据增强
transform = transforms.Compose([transforms.Resize((128, 128)),  # 调整图像大小# 这个操作会将图像数据从0-255的整数值范围(如果是uint8类型)转换为0-1之间的浮点数范围,并且会将图像的形状从(H, W, C)转换为(C, H, W),其中H是高度,W是宽度,C是通道数。这样做是为了符合PyTorch模型的输入要求.transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到 [-1, 1]
])# 创建数据集和数据加载器
root_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct"
labels_file = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct/labels.txt"  # 标签文件路径
dataset = TrafficSignDataset(root_dir=root_dir, labels_file=labels_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 生成一批随机标签(整数)及其对应的独热编码(one-hot encoding),独热编码提供了一种方便的方式来表示真实标签,使得我们可以使用交叉熵损失等损失函数来计算预测值与真实值之间的差异。
def create_labels(batch_size, label_dim):labels = torch.randint(0, label_dim, (batch_size,))labels_one_hot = torch.zeros(batch_size, label_dim).scatter_(1, labels.view(-1, 1), 1)return labels.to('cuda'), labels_one_hot.to('cuda')def train():torch.cuda.empty_cache()# 初始化空列表,用于存储生成器和判别器的损失值
d_losses = []
g_losses = []# 训练循环
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):real_images = real_images.to('cuda')# 1. 训练判别器# 真实数据损失# 这行代码调用了一个名为create_labels的函数,该函数接收两个参数:real_images.size(0)表示真实图像的数量,label_dim表示标签的维度real_labels, real_labels_one_hot = create_labels(real_images.size(0), label_dim)real_outputs = D(real_images)noise_real = torch.rand_like(real_outputs) * -0.1real_loss = criterion(real_outputs, torch.full_like(real_outputs, 0.8) + noise_real)# 生成数据损失noise = torch.randn(real_images.size(0), noise_dim).to('cuda')fake_labels, fake_labels_one_hot = create_labels(real_images.size(0), label_dim)fake_images = G(noise, fake_labels_one_hot)fake_outputs = D(fake_images.detach())# 为假标签加入随机噪声(0, 0.1)noise_fake = torch.rand_like(fake_outputs) * 0.1fake_loss = criterion(fake_outputs, torch.full_like(fake_outputs, 0.2) + noise_fake)# 判别器总损失d_loss = real_loss + fake_lossoptimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 2. 训练生成器fake_outputs = D(fake_images)g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()# 追加损失值到列表中d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 打印损失值if i % 50 == 0:print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")# 每隔一定步保存生成的图像if i % 200 == 0:# 保存每一张生成的图像for idx in range(min(30, fake_images.size(0))):  # 遍历生成的每一张图像save_image(fake_images[idx],os.path.join(output_dir, f"epoch_{epoch + 1}_image_{idx + 1}.png"),normalize=True)  # 保存每一张图像,命名方式包括epoch, step, 和图像编号print("训练完成并保存生成图像。")# 绘制生成器和判别器的损失曲线
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss', color='blue')
plt.plot(g_losses, label='Generator Loss', color='red')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Generator and Discriminator Loss During Training')
plt.legend()
plt.grid()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Palo Alto Networks Expedition 未授权SQL注入漏洞复现(CVE-2024-9465)

0x01 产品介绍: Palo Alto Networks Expedition 是一款强大的工具,帮助用户有效地迁移和优化网络安全策略,提升安全管理的效率和效果。它的自动化功能、策略分析和可视化报告使其在网络安全领域中成为一个重要的解决方案。 0x02 漏洞描述&am…

windows下安装、配置neo4j并服务化启动

第一步:下载Neo4j压缩包 官网下载地址:https://neo4j.com/download-center/ (官网下载真的非常慢,而且会自己中断,建议从以下链接下载) 百度网盘下载地址:链接:https://pan.baid…

周易解读:八卦02,八卦所代表的基本事物

八 卦02 上一节,我是讲完了八卦的卦象的画法的问题。这一节,我来尝试着去讲解八卦所代表的自然事物。 八卦是谁发明的呢?根据《周易说卦传》的说法,八卦是伏羲发明的。伏羲氏仰观天文,俯察地理,从中提取…

项目模块二:日志宏

一、代码展示 二、补充知识 1、LOG(level, format, ...) format 是用于宏识别格式化,类似于 printf("%s", str); 里面的 "%s" ... 不定参,传入宏的参数除了 level, format, 还有不确定个数的参数。 2、红色 \ 由于宏只能写在一…

链上相遇,节点之间的悸动与牵连

公主请阅 1. 返回倒数第 k 个节点1.1 题目说明1.2 题目分析1.3 解法一代码以及解释1.3 解法二代码以及解释 2.相交链表2.1 题目说明示例 1示例 2示例 3 2.2 题目分析2.3 代码部分2.4 代码分析 1. 返回倒数第 k 个节点 题目传送门 1.1 题目说明 题目名称: 面试题 02…

15分钟学 Go 第 10 天:函数参数和返回值

第10天:函数参数和返回值 目标:理解函数如何传递参数 在Go语言中,函数是程序的基本构建块。了解如何传递参数和返回值是编写高效、可复用代码的重要步骤。本文将详细讲解函数参数的类型、传递方式以及如何处理返回值,辅以代码示…

DP—子数组,子串系列 第一弹 -最大子数组和 -环形子数组的最大和 力扣

你好,欢迎阅读我的文章~ 个人主页:Mike 所属专栏:动态规划 ​ 53. 最大子数组和 最大子数组和 ​ 分析: 使用动态规划解决 状态表示: 1.以某个位置为结尾 2.以某个位置为起点 这里使用以某个位置为结尾,结合题目要求&#…

MySQL8.0主从同步报ERROR 13121错误解决方法

由于平台虚拟机宿主机迁移,导致一套MySQL主从库从节点故障,从节点服务终止,在服务启动后,恢复从节点同步服务,发现了如下报错: mysql> show slave status\G; *************************** 1. row *****…

GDAL+C#实现矢量多边形转栅格

1. 开发环境测试 参考C#配置GDAL环境,确保GDAL能使用,步骤简述如下: 创建.NET Framework 4.7.2的控制台应用 注意: 项目路径中不要有中文,否则可能报错:can not find proj.db 在NuGet中安装GDAL 3.9.1和G…

无人机之自主飞行关键技术篇

无人机自主飞行指的是无人机利用先进的算法和传感器,实现自我导航、路径规划、环境感知和自动避障等能力。这种飞行模式大大提升了无人机的智能化水平和操作的自动化程度。 一、传感器技术 传感器是无人机实现自主飞行和数据采集的关键组件,主要包括&a…

C语言复习第3章 函数

目录 一、函数介绍1.1 函数是什么1.2 C语言中函数的分类1.3 函数原型1.4 高内聚 低耦合1.5 C语言main函数的位置 二、函数的参数2.1 实参和形参2.2 函数的参数(实参)可以是表达式2.3 传值与传址(swap函数)2.4 明确形参是实参的临时拷贝2.5 void(如果不写函数返回值 默认是int)2…

python 爬虫 入门 三、登录以及代理。

目录 一、登录 (一)、登录4399 1.直接使用Cookie 2.使用账号密码进行登录 可选观看内容,使用python对密码进行加密(无结果代码,只有过程分析) 二、代理 免费代理 后续:协程,…

企业级调度器 LVS

集群和分布式基础知识 系统性能的扩展方式 当一个系统,或一个服务的请求量达到一定的数量级的时候,运行该服务的服务器的性能和资源上限, 很容易成为其性能瓶颈。除了性能问题之外,如果只部署在单台服务器上,在此服务…

gitee建立/取消关联仓库

目录 一、常用指令总结 二、建立关联具体操作 三、取消关联具体操作 一、常用指令总结 首先要选中要关联的文件,右击,选择Git Bash Here。 git remote -v //查看自己的文件有几个关联的仓库git init //初始化文件夹为git可远程建立链接的文件夹…

uniapp uni.uploadFile errMsg: “uploadFile:fail

uniapp 上传后一直显示加载中 1.检查前后端上传有无问题 2.检查失败信息 await uni.uploadFile({url,filePath,name,formData,header,timeout: 30000000, // 自定义上传超时时间fail: async function(err) {$util.hideAll()// 失败// err 返回 {errMsg: "uploadFile:fai…

SpringCloud学习:Openfeign组件实现服务调用和负载均衡

OpenFeign:服务调用与负载均衡(服务端接口) 是什么:通过OpenFeign可以实现服务调用和负载均衡 OpenFeign是一个声明性web服务客户端, 怎么用:服务提供者提取公共接口用FrignClient标注,服务调…

kernel32.dll下载地址:如何安全地恢复系统文件

关于从网络上寻找kernel32.dll的下载地址,这通常不是一个安全的做法,而且可能涉及到多种风险。kernel32.dll是Windows操作系统的核心组件之一,负责内存管理、进程和线程管理以及其他关键系统功能。因为kernel32.dll是系统的基础文件&#xff…

信息安全工程师(57)网络安全漏洞扫描技术与应用

一、网络安全漏洞扫描技术概述 网络安全漏洞扫描技术是一种可以自动检测计算机系统和网络设备中存在的漏洞和弱点的技术。它通过使用特定的方法和工具,模拟攻击者的攻击方式,从而检测存在的漏洞和弱点。这种技术可以帮助组织及时发现并修补漏洞&#xff…

【数据结构与算法】链表(上)

记录自己所学&#xff0c;无详细讲解 无头单链表实现 1.项目目录文件 2.头文件 Slist.h #include <stdio.h> #include <assert.h> #include <stdlib.h> struct Slist {int data;struct Slist* next; }; typedef struct Slist Slist; //初始化 void SlistI…

C++20中头文件ranges的使用

<ranges>是C20中新增加的头文件&#xff0c;提供了一组与范围(ranges)相关的功能&#xff0c;此头文件是ranges库的一部分。包括&#xff1a; 1.concepts: (1).std::ranges::range:指定类型为range&#xff0c;即它提供开始迭代器和结束标记(it provides a begin iterato…