demo_GAN

# 导入PyTorch库,这是一个用于深度学习的开源库
import torch
# 导入PyTorch的神经网络模块(nn),用于定义神经网络结构
import torch.nn as nn
# 导入PyTorch的函数式模块(functional),提供了一些常用的激活函数和损失函数等
import torch.nn.functional as F
# 导入PyTorch的优化器模块(optim),用于定义优化算法,如梯度下降等
import torch.optim as optim
# 从PyTorch的数据加载器模块中导入DataLoader和TensorDataset类,用于加载和处理数据集
from torch.utils.data import DataLoader, TensorDataset
# 从torchvision库的实用工具模块中导入save_image函数,用于保存生成的图像
from torchvision.utils import save_image
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
# 导入os模块,用于处理文件和目录操作
import os
import matplotlib.pyplot as plt# 自注意力机制模块定义
class SelfAttention(nn.Module):def __init__(self, in_dim):super(SelfAttention, self).__init__()self.query = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)self.key = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, C, width, height = x.size()query = self.query(x).view(batch_size, -1, width * height).permute(0, 2, 1)key = self.key(x).view(batch_size, -1, width * height)energy = torch.bmm(query, key)attention = F.softmax(energy, dim=-1)value = self.value(x).view(batch_size, -1, width * height)out = torch.bmm(value, attention.permute(0, 2, 1))out = out.view(batch_size, C, width, height)out = self.gamma * out + xreturn out# Generator Model定义了一个名为Generator的神经网络模型,它继承自PyTorch框架中的nn.Module类
class Generator(nn.Module):def __init__(self, noise_dim, label_dim):super(Generator, self).__init__()self.label_dim = label_dim# 定义了一个名为self.fc的神经网络层序列,含三个层,输入层:随机噪声和标签,批量归一化层,漏洞型relu层self.fc = nn.Sequential(nn.Linear(noise_dim + label_dim, 1024 * 2 * 2),nn.BatchNorm1d(1024 * 2 * 2),# 这是一个Leaky ReLU激活函数层,它的作用是将负数的输入值乘以一个小的常数(这里是0.2),然后将结果作为输出,在原始数据上进行操作nn.ReLU(inplace=True))# Hidden Layers: Deconv + BN + Leaky ReLUself.deconv_layers = nn.Sequential(nn.ConvTranspose2d(1024, 512, kernel_size=4, stride=2, padding=1),  # 2x2 -> 4x4nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # 4x4 -> 8x8nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64nn.BatchNorm2d(32),nn.ReLU(inplace=True),# 增加一层,扩展到128x128nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 64x64 -> 128x128nn.BatchNorm2d(16),nn.ReLU(inplace=True),nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1),  # 128x128 -> 128x128 (RGB)nn.Tanh()  # 输出层,范围[-1, 1])def forward(self, noise, labels):# 拼接噪声和标签x = torch.cat((noise, labels), dim=1)x = self.fc(x).view(-1, 1024, 2, 2)return self.deconv_layers(x)class Discriminator(nn.Module):def __init__(self, input_channels):super(Discriminator, self).__init__()# 第一层:并行卷积层(3×3和5×5卷积核),后续拼接self.conv1_3x3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=2, padding=1)),nn.ReLU(inplace=True))self.conv1_5x5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=5, stride=2, padding=2)),nn.ReLU(inplace=True),)# (2) conv + BN + leaky Relu (dilation rate 1)self.conv2 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, dilation=1)),nn.ReLU(inplace=True),)# (3) conv + BN + leaky Relu + self-attention mechanismself.conv3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1)),nn.ReLU(inplace=True),SelfAttention(256))# (4) conv + BN + leaky Relu (parallel 3x3, 5x5, and 7x7 kernels)self.conv4_3x3 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1)),nn.ReLU(inplace=True),)self.conv4_5x5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=5, stride=2, padding=2)),nn.ReLU(inplace=True),)self.conv4_7x7 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=7, stride=2, padding=3)),nn.ReLU(inplace=True),)# (5) conv + BN + leaky Relu (dilation rate 3)self.conv5 = nn.Sequential(nn.utils.spectral_norm(nn.Conv2d(in_channels=1536, out_channels=1024, kernel_size=3, stride=1, padding=3, dilation=3)),nn.ReLU(inplace=True),)# (6) 用卷积层替换全连接层,输出1x1特征图,并使用sigmoid激活函数self.fc = nn.utils.spectral_norm(nn.Linear(1024 * 4 * 4, 1))self.sigmoid = nn.Sigmoid()def forward(self, x):x1 = self.conv1_3x3(x)x2 = self.conv1_5x5(x)x = torch.cat((x1, x2), dim=1)x = self.conv2(x)x = self.conv3(x)x1 = self.conv4_3x3(x)x2 = self.conv4_5x5(x)x3 = self.conv4_7x7(x)x = torch.cat((x1, x2, x3), dim=1)x = self.conv5(x)x = nn.AvgPool2d(2)(x)x = x.view(x.size(0), -1)  # Flatten the tensorx = self.fc(x)x = self.sigmoid(x)return x# 设置超参数
noise_dim = 100  # 噪声维度
label_dim = 58  # 标签维度
batch_size =64  # 批大小
learning_rate = 0.0001
num_epochs = 500  # 训练轮数
output_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/MMSGAN"  # 生成图像保存路径# 确保输出目录存在
if not os.path.exists(output_dir):os.makedirs(output_dir)# 创建生成器和判别器
G = Generator(noise_dim=noise_dim, label_dim=label_dim).to('cuda')
D = Discriminator(input_channels=3).to('cuda')# TrafficSignDataset类,用于数据加载
class TrafficSignDataset(Dataset):def __init__(self, root_dir, labels_file, transform=None):self.root_dir = root_dirself.transform = transformself.image_paths = []self.labels = []with open(labels_file, 'r') as f:lines = f.readlines()for line in lines:img_name, label = line.strip().split()img_path = os.path.join(root_dir, img_name)self.image_paths.append(img_path)self.labels.append(int(label))def __len__(self):return len(self.image_paths)def __getitem__(self, idx):img_path = self.image_paths[idx]image = Image.open(img_path).convert('RGB')label = self.labels[idx]if self.transform:image = self.transform(image)return image, label# 损失函数和优化器
criterion = nn.BCELoss()  # 二元交叉熵损失
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate*4,betas=(0.5, 0.999),weight_decay=1e-4)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate, betas=(0.5, 0.999))
# 设置学习率衰减参数
decay = 0.0001
num_epochs = 500# 训练循环
for epoch in range(num_epochs):# ... 训练过程 ...# 更新学习率lr_new_G = learning_rate * 4 / (1 + decay * num_epochs)lr_new_D = learning_rate / (1 + decay * num_epochs)for param_group in optimizer_G.param_groups:param_group['lr'] = lr_new_Gfor param_group in optimizer_D.param_groups:param_group['lr'] = lr_new_D# 定义图像预处理和数据增强
transform = transforms.Compose([transforms.Resize((128, 128)),  # 调整图像大小# 这个操作会将图像数据从0-255的整数值范围(如果是uint8类型)转换为0-1之间的浮点数范围,并且会将图像的形状从(H, W, C)转换为(C, H, W),其中H是高度,W是宽度,C是通道数。这样做是为了符合PyTorch模型的输入要求.transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到 [-1, 1]
])# 创建数据集和数据加载器
root_dir = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct"
labels_file = "D:/PycharmProjects/GAN/traffic sign datasets/CTSDB/resized_images/BorderlineSMOTE -insepct/labels.txt"  # 标签文件路径
dataset = TrafficSignDataset(root_dir=root_dir, labels_file=labels_file, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)# 生成一批随机标签(整数)及其对应的独热编码(one-hot encoding),独热编码提供了一种方便的方式来表示真实标签,使得我们可以使用交叉熵损失等损失函数来计算预测值与真实值之间的差异。
def create_labels(batch_size, label_dim):labels = torch.randint(0, label_dim, (batch_size,))labels_one_hot = torch.zeros(batch_size, label_dim).scatter_(1, labels.view(-1, 1), 1)return labels.to('cuda'), labels_one_hot.to('cuda')def train():torch.cuda.empty_cache()# 初始化空列表,用于存储生成器和判别器的损失值
d_losses = []
g_losses = []# 训练循环
for epoch in range(num_epochs):for i, (real_images, _) in enumerate(dataloader):real_images = real_images.to('cuda')# 1. 训练判别器# 真实数据损失# 这行代码调用了一个名为create_labels的函数,该函数接收两个参数:real_images.size(0)表示真实图像的数量,label_dim表示标签的维度real_labels, real_labels_one_hot = create_labels(real_images.size(0), label_dim)real_outputs = D(real_images)noise_real = torch.rand_like(real_outputs) * -0.1real_loss = criterion(real_outputs, torch.full_like(real_outputs, 0.8) + noise_real)# 生成数据损失noise = torch.randn(real_images.size(0), noise_dim).to('cuda')fake_labels, fake_labels_one_hot = create_labels(real_images.size(0), label_dim)fake_images = G(noise, fake_labels_one_hot)fake_outputs = D(fake_images.detach())# 为假标签加入随机噪声(0, 0.1)noise_fake = torch.rand_like(fake_outputs) * 0.1fake_loss = criterion(fake_outputs, torch.full_like(fake_outputs, 0.2) + noise_fake)# 判别器总损失d_loss = real_loss + fake_lossoptimizer_D.zero_grad()d_loss.backward()optimizer_D.step()# 2. 训练生成器fake_outputs = D(fake_images)g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))optimizer_G.zero_grad()g_loss.backward()optimizer_G.step()# 追加损失值到列表中d_losses.append(d_loss.item())g_losses.append(g_loss.item())# 打印损失值if i % 50 == 0:print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(dataloader)}], "f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")# 每隔一定步保存生成的图像if i % 200 == 0:# 保存每一张生成的图像for idx in range(min(30, fake_images.size(0))):  # 遍历生成的每一张图像save_image(fake_images[idx],os.path.join(output_dir, f"epoch_{epoch + 1}_image_{idx + 1}.png"),normalize=True)  # 保存每一张图像,命名方式包括epoch, step, 和图像编号print("训练完成并保存生成图像。")# 绘制生成器和判别器的损失曲线
plt.figure(figsize=(10, 5))
plt.plot(d_losses, label='Discriminator Loss', color='blue')
plt.plot(g_losses, label='Generator Loss', color='red')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title('Generator and Discriminator Loss During Training')
plt.legend()
plt.grid()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882433.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

把其他.ui文件拿到我的工程中使用

在Qt工程中使用工程外的ui文件的方式:作为一个类直接使用、包含到自己的类中或继承使用 将ui文件添加到工程中,作为一个类以直接使用 注:这里指使用原本不属于该工程的ui文件第一步:在工程文件.pro中添加UI文件 在.proj文件中添加…

每日一题——第一百一十七题

题目&#xff1a;使用二分查找&#xff0c;查找一个数是否存在于一个升序数组中 #include <stdio.h>int binarySearch(int arr[], int length, int elem);int main() {int arr[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 };int key;int length sizeof(arr) / sizeof(arr[0]);pri…

Palo Alto Networks Expedition 未授权SQL注入漏洞复现(CVE-2024-9465)

0x01 产品介绍&#xff1a; Palo Alto Networks Expedition 是一款强大的工具&#xff0c;帮助用户有效地迁移和优化网络安全策略&#xff0c;提升安全管理的效率和效果。它的自动化功能、策略分析和可视化报告使其在网络安全领域中成为一个重要的解决方案。 0x02 漏洞描述&am…

windows下安装、配置neo4j并服务化启动

第一步&#xff1a;下载Neo4j压缩包 官网下载地址&#xff1a;https://neo4j.com/download-center/ &#xff08;官网下载真的非常慢&#xff0c;而且会自己中断&#xff0c;建议从以下链接下载&#xff09; 百度网盘下载地址&#xff1a;链接&#xff1a;https://pan.baid…

FFMPEG录屏(17)--- 使用 DwmRegisterThumbnail 捕获指定窗口图像数据

使用 DwmRegisterThumbnail 捕获指定窗口图像数据 在 Windows 平台上&#xff0c;捕获指定窗口的图像数据可以通过多种方法实现&#xff0c;其中一种高效的方法是使用 [DwmRegisterThumbnail] 本文将介绍如何使用 [DwmRegisterThumbnail] 捕获窗口图像数据&#xff0c;并提供一…

Mysql中表字段VARCHAR(N)类型及长度的解释

本文将针对MySQL 中 varchar (N)类型字段的存储方式进行解释&#xff0c;主要是对字符和字节的关系的理解。 1. varchar (N) 中的 N varchar (N) 中的 N 表示字符数&#xff0c;而不是字节数。这意味着 N 表示你可以存储多少个字符。 字符数&#xff1a;指的是字符的个数&…

计算机视觉在疲劳检测中的应用

计算机视觉在疲劳检测中的应用 引言 随着科技的飞速发展&#xff0c;计算机视觉技术已经广泛应用于各个领域&#xff0c;其中疲劳检测是近年来备受关注的一个研究方向。疲劳检测旨在通过计算机视觉技术&#xff0c;实时分析个体的面部特征、动作以及生理信号等&#xff0c;判…

周易解读:八卦02,八卦所代表的基本事物

八 卦02 上一节&#xff0c;我是讲完了八卦的卦象的画法的问题。这一节&#xff0c;我来尝试着去讲解八卦所代表的自然事物。 八卦是谁发明的呢&#xff1f;根据《周易说卦传》的说法&#xff0c;八卦是伏羲发明的。伏羲氏仰观天文&#xff0c;俯察地理&#xff0c;从中提取…

项目模块二:日志宏

一、代码展示 二、补充知识 1、LOG(level, format, ...) format 是用于宏识别格式化&#xff0c;类似于 printf("%s", str); 里面的 "%s" ... 不定参&#xff0c;传入宏的参数除了 level, format, 还有不确定个数的参数。 2、红色 \ 由于宏只能写在一…

PyTorch深度学习入门汇总

PyTorch 是由 Facebook 的人工智能研究小组开发的深度学习框架&#xff0c;可以基于PyTorch开发和训练各种深度学习模型。自 2016 年问世以来&#xff0c;PyTorch 因其灵活性和易用性而受到深度学习从业者的极大关注。 汇总目录 基于conda包的环境创建、激活、管理与删除 Pyt…

链上相遇,节点之间的悸动与牵连

公主请阅 1. 返回倒数第 k 个节点1.1 题目说明1.2 题目分析1.3 解法一代码以及解释1.3 解法二代码以及解释 2.相交链表2.1 题目说明示例 1示例 2示例 3 2.2 题目分析2.3 代码部分2.4 代码分析 1. 返回倒数第 k 个节点 题目传送门 1.1 题目说明 题目名称&#xff1a; 面试题 02…

15分钟学 Go 第 10 天:函数参数和返回值

第10天&#xff1a;函数参数和返回值 目标&#xff1a;理解函数如何传递参数 在Go语言中&#xff0c;函数是程序的基本构建块。了解如何传递参数和返回值是编写高效、可复用代码的重要步骤。本文将详细讲解函数参数的类型、传递方式以及如何处理返回值&#xff0c;辅以代码示…

用C++编写一个简单的游戏引擎:从游戏循环到物理与渲染的全面解析

解锁Python编程的无限可能:《奇妙的Python》带你漫游代码世界 构建一个基础的2D游戏引擎是一项富有挑战性但极具学习价值的任务。本文将通过从零开始的方式,逐步讲解如何使用C++开发一个简单的游戏引擎。内容涵盖了游戏引擎的核心架构设计,包括游戏循环、物理引擎和图形渲染…

DP—子数组,子串系列 第一弹 -最大子数组和 -环形子数组的最大和 力扣

你好&#xff0c;欢迎阅读我的文章~ 个人主页&#xff1a;Mike 所属专栏&#xff1a;动态规划 ​ 53. 最大子数组和 最大子数组和 ​ 分析: 使用动态规划解决 状态表示: 1.以某个位置为结尾 2.以某个位置为起点 这里使用以某个位置为结尾&#xff0c;结合题目要求&#…

MySQL8.0主从同步报ERROR 13121错误解决方法

由于平台虚拟机宿主机迁移&#xff0c;导致一套MySQL主从库从节点故障&#xff0c;从节点服务终止&#xff0c;在服务启动后&#xff0c;恢复从节点同步服务&#xff0c;发现了如下报错&#xff1a; mysql> show slave status\G; *************************** 1. row *****…

GDAL+C#实现矢量多边形转栅格

1. 开发环境测试 参考C#配置GDAL环境&#xff0c;确保GDAL能使用&#xff0c;步骤简述如下&#xff1a; 创建.NET Framework 4.7.2的控制台应用 注意&#xff1a; 项目路径中不要有中文&#xff0c;否则可能报错&#xff1a;can not find proj.db 在NuGet中安装GDAL 3.9.1和G…

OSI参考模型详解:初学者指南与实践案例

OSI参考模型详解&#xff1a;初学者指南与实践案例 OSI&#xff08;Open System Interconnect&#xff09;参考模型是一个由国际标准化组织&#xff08;ISO&#xff09;提出的七层网络分层模型&#xff0c;它为全球所有互联计算机系统提供了一个通用的通信框架&#xff0c;解决…

【Mysql】-锁机制-GAP锁

在 MySQL 的 InnoDB 存储引擎中&#xff0c;Gap 锁&#xff08;间隙锁&#xff09;是一种用于防止幻读的锁机制。幻读是指在一个事务中&#xff0c;多次执行相同的查询&#xff0c;结果集却不同&#xff0c;通常是由于其他事务插入了新的行。为了防止这种情况&#xff0c;InnoD…

无人机之自主飞行关键技术篇

无人机自主飞行指的是无人机利用先进的算法和传感器&#xff0c;实现自我导航、路径规划、环境感知和自动避障等能力。这种飞行模式大大提升了无人机的智能化水平和操作的自动化程度。 一、传感器技术 传感器是无人机实现自主飞行和数据采集的关键组件&#xff0c;主要包括&a…

软考-软件设计师(10)-专业英语词汇汇总与新技术知识点

场景 以下为高频考点、知识点汇总。 软件设计师上午选择题知识点、高频考点、口诀记忆技巧、经典题型汇总: 软考-软件设计师(1)-计算机基础知识点:进制转换、数据编码、内存编址、串并联可靠性、海明校验码、吞吐率、多媒体等: 软考-软件设计师(1)-计算机基础知识点:进制…