生成式 AI:使用 Pytorch 通过 GAN 生成合成数据

导 读

生成对抗网络(GAN)因其生成图像的能力而变得非常受欢迎,而语言模型(例如 ChatGPT)在各个领域的使用也越来越多。这些 GAN 模型可以说是人工智能/机器学习目前主流的原因;

因为它向每个人(尤其是该领域之外的人)展示了机器学习所具有的巨大潜力。网上已经有很多关于 GAN 模型的资源,但其中大多数都集中在图像生成上。这些图像生成和语言模型需要复杂的空间或时间复杂性,这增加了额外的复杂性,使读者更难理解 GAN 的真正本质。

为了解决这个问题并使 GAN 更容易被更广泛的受众所接受,在本文的 GAN 模型示例中,我们将采取一种不同的、更实用的方法,重点关注生成数学函数的合成数据。

除了出于学习目的的简化之外,合成数据生成本身也变得越来越重要。数据不仅在业务决策中发挥着核心作用,而且数据驱动方法的用途也越来越多,比第一原理模型更受欢迎。

比如天气预报,第一个原理模型包括通过数值求解的纳维-斯托克斯方程的简化版本。然而,深度学习研究中进行天气预报的尝试在捕捉天气模式方面非常成功,并且一旦经过训练,运行起来会更容易、更快。

有需要的朋友关注公众号【小Z的科研日常】,获取更多内容

01、生成模型与判别模型

在机器学习中,理解判别模型和生成模型之间的区别非常重要,因为它们是 GAN 的关键组成部分:

判别模型:

判别模型侧重于将数据分类为预定义的类别,例如将狗和猫的图像分类为各自的类别。这些模型不是捕获整个分布,而是辨别不同类别的边界。它们输出 P(y|x)(类别概率,给定输入数据的 y,x),即它们回答给定数据点属于哪个类别的问题。

生成模型:

生成模型旨在理解数据的底层结构。与区分类别的判别模型不同,生成模型学习数据的整个分布。这些模型输出 p(x|y),即它们回答了给定指定类生成该特定数据点的可能性有多大的问题。

这两个模型之间的相互作用构成了 GAN 的基础。

02、GAN—结构和组件

GAN 的关键组件包括噪声向量、生成器和鉴别器。

生成器:生成真实数据

为了生成合成数据,生成器使用随机噪声向量作为输入。为了欺骗鉴别器,生成器的目的是学习真实数据的分布并生成无法与真实数据区分开的合成数据。这里的一个问题是,对于相同的输入,它总是会产生相同的输出(想象一个图像生成器产生真实的图像,但总是相同的图像,这不是很有用)。随机噪声向量将随机性注入到过程中,从而提供生成的输出的多样性。

鉴别器:辨别真假

鉴别器就像一位受过训练来区分真实数据和虚假数据的艺术评论家。它的作用是仔细检查收到的数据并为工作真实性分配概率分数。如果合成数据看起来与真实数据相似,则鉴别器分配高概率,否则分配低概率分数。

对抗性训练:动态决斗

生成器努力学习生成鉴别器无法与真实数据区分开的合成数据。同时,鉴别器还学习并提高区分真实与合成的能力。这种动态的训练过程促使两个模型提高技能。这两个模型总是相互竞争(因此被称为对抗性),并且通过这种竞争,两个模型都在各自的角色中变得非常出色。

03、Pytorch实现GAN

在此示例中,我们在 pytorch 中实现了一个可以生成合成数据的模型。对于训练,我们有一个具有以下形状的 6 参数数据集(所有参数都绘制为参数 1 的函数)。每个参数都经过精心选择,具有显着不同的分布和形状,以增加数据集的复杂性并模仿真实世界的数据。

定义 GAN 模型组件(生成器和判别器)

import torch
from torch import nn
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn.init as init
import pandas as pd
import numpy as np
from torch.utils.data import Dataset# 定义单块功能
def FC_Layer_blockGen(input_dim, output_dim):single_block = nn.Sequential(nn.Linear(input_dim, output_dim),nn.ReLU())return single_block# 定义 GENERATOR
class Generator(nn.Module):def __init__(self, latent_dim, output_dim):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(latent_dim, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, output_dim),nn.Tanh()  )def forward(self, x):return self.model(x)#定义单个判别块
def FC_Layer_BlockDisc(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim),nn.ReLU(),nn.Dropout(0.4))# 定义判别器class Discriminator(nn.Module):def __init__(self, input_dim):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(input_dim, 512),nn.ReLU(),nn.Dropout(0.4),nn.Linear(512, 512),nn.ReLU(),nn.Dropout(0.4),nn.Linear(512, 256),nn.ReLU(),nn.Dropout(0.4),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):return self.model(x)#定义训练参数
batch_size = 128
num_epochs = 500
lr = 0.0002
num_features = 6
latent_dim = 20# 模型初始化
generator = Generator(noise_dim, num_features)
discriminator = Discriminator(num_features)# 损失函数和优化器
criterion = nn.BCELoss()
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)

模型初始化和数据处理

file_path = 'SamplingData7.xlsx'
data = pd.read_excel(file_path)
X = data.values
X_normalized = torch.FloatTensor((X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0)) * 2 - 1)
real_data = X_normalizedclass MyDataset(Dataset):def __init__(self, dataframe):self.data = dataframe.values.astype(float)self.labels = dataframe.values.astype(float)def __len__(self):return len(self.data)def __getitem__(self, idx):sample = {'input': torch.tensor(self.data[idx]),'label': torch.tensor(self.labels[idx])}return sample# 创建数据集实例
dataset = MyDataset(data)# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)def weights_init(m):if isinstance(m, nn.Linear):init.xavier_uniform_(m.weight)if m.bias is not None:init.constant_(m.bias, 0)pretrained = False
if pretrained:pre_dict = torch.load('pretrained_model.pth')generator.load_state_dict(pre_dict['generator'])discriminator.load_state_dict(pre_dict['discriminator'])
else:# 应用权重初始化generator = generator.apply(weights_init)discriminator = discriminator.apply(weights_init)

模型训练

model_save_freq = 100latent_dim =20
for epoch in range(num_epochs):for batch in dataloader:real_data_batch = batch['input']real_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))disc_optimizer.zero_grad()output_real = discriminator(real_data_batch)loss_real = criterion(output_real, real_labels)loss_real.backward()fake_labels = torch.FloatTensor(np.random.uniform(0, 0.1, (batch_size, 1)))noise = torch.FloatTensor(np.random.normal(0, 1, (batch_size, latent_dim)))generated_data = generator(noise)output_fake = discriminator(generated_data.detach())loss_fake = criterion(output_fake, fake_labels)loss_fake.backward()disc_optimizer.step()valid_labels = torch.FloatTensor(np.random.uniform(0.9, 1.0, (batch_size, 1)))gen_optimizer.zero_grad()output_g = discriminator(generated_data)loss_g = criterion(output_g, valid_labels)loss_g.backward()gen_optimizer.step()print(f"Epoch {epoch}, D Loss Real: {loss_real.item()}, D Loss Fake: {loss_fake.item()}, G Loss: {loss_g.item()}")

模型评估和可视化结果

import seaborn as snssynthetic_data = generator(torch.FloatTensor(np.random.normal(0, 1, (real_data.shape[0], noise_dim))))# 绘制结果
fig, axs = plt.subplots(2, 3, figsize=(12, 8))
fig.suptitle('Real and Synthetic Data Distributions', fontsize=16)for i in range(2):for j in range(3):sns.histplot(synthetic_data[:, i * 3 + j].detach().numpy(), bins=50, alpha=0.5, label='Synthetic Data', ax=axs[i, j], color='blue')sns.histplot(real_data[:, i * 3 + j].numpy(), bins=50, alpha=0.5, label='Real Data', ax=axs[i, j], color='orange')axs[i, j].set_title(f'Parameter {i * 3 + j + 1}', fontsize=12)axs[i, j].set_xlabel('Value')axs[i, j].set_ylabel('Frequency')axs[i, j].legend()plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()#创建 2x3 网格的子绘图
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Comparison of Real and Synthetic Data', fontsize=16)# Define parameter names
param_names = ['Parameter 1', 'Parameter 2', 'Parameter 3', 'Parameter 4', 'Parameter 5', 'Parameter 6']# 各参数的散点图
for i in range(2):for j in range(3):param_index = i * 3 + jsns.scatterplot(real_data[:, 0].numpy(), real_data[:, param_index].numpy(), label='Real Data', alpha=0.5, ax=axs[i, j])sns.scatterplot(synthetic_data[:, 0].detach().numpy(), synthetic_data[:, param_index].detach().numpy(), label='Generated Data', alpha=0.5, ax=axs[i, j])axs[i, j].set_title(param_names[param_index], fontsize=12)axs[i, j].set_xlabel(f'Real Data - {param_names[param_index]}')axs[i, j].set_ylabel(f'Real Data - {param_names[param_index]}')axs[i, j].legend()plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/737982.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows下IntelliJ IDEA远程连接服务器中Hadoop运行WordCount(详细版)

使用IDEA直接运行Hadoop项目,有两种方式,分别是本地式:本地安装HadoopIDEA;远程式:远程部署Hadoop,本地安装IDEA并连接, 本文介绍第二种。 一、安装配置Hadoop (1)虚拟机伪分布式 见上才艺&a…

机器学习-04-分类算法-01决策树

总结 本系列是机器学习课程的系列课程,主要介绍机器学习中分类算法,本篇为分类算法开篇与决策树部分。 本门课程的目标 完成一个特定行业的算法应用全过程: 懂业务会选择合适的算法数据处理算法训练算法调优算法融合 算法评估持续调优工程…

酷开科技发力研发酷开系统,让家庭娱乐生活更加丰富多彩

在这个快节奏的社会,家庭娱乐已成为我们日常生活中不可或缺的一部分,为了给家庭带来更多欢笑与感动,酷开科技发力研发出拥有丰富内容和技术的智能电视操作系统——酷开系统,它集合了电影、电视剧、综艺、游戏、音乐等海量内容&…

Python 导入Excel三维坐标数据 生成三维曲面地形图(面) 4-3、线条平滑曲面(原始颜色)去除无效点

环境和包: 环境 python:python-3.12.0-amd64包: matplotlib 3.8.2 pandas 2.1.4 openpyxl 3.1.2 scipy 1.12.0 代码: import pandas as pd import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from scipy.interpolate import griddata fr…

Ubuntu 14.04:安装PaddlePaddle(Conda安装)

目录 一、PaddlePaddle 概要 二、PaddlePaddle安装要求 三、PaddlePaddle安装 3.1 安装 Anaconda3 3.2 创建Anaconda虚拟环境(python 3.8) 3.3 进入Anaconda虚拟环境 3.4 检测 Anaconda 虚拟环境配置是否符合PaddlePaddle安装要求 3.4.1 确认 py…

Python中的异常处理及最佳实践【第125篇—异常处理】

Python中的异常处理及最佳实践 异常处理是编写健壮、可靠和易于调试的Python代码中不可或缺的一部分。在本文中,我们将深入探讨Python中的异常处理机制,并分享一些最佳实践和代码示例,以帮助您更好地处理错误情况和提高代码的稳定性。 异常…

山姆・阿尔特曼重返OpenAI董事会;Car-GPT:LLMs能否最终实现自动驾驶?

🦉 AI新闻 🚀 山姆・阿尔特曼重返OpenAI董事会 摘要:经历长达数月的审查后,山姆・阿尔特曼已重返OpenAI董事会,并作为返回条件之一,OpenAI还新增了三名外部女性董事会成员。这标志着公司正努力摆脱去年11…

电子价签前景璀璨,汉朔科技革新零售行业的数字化新篇章

新型商超模式数字化“秘密武器”——电子价签 传统纸质价签,只要商品价格、日期等信息发生变化,就必须重新打印进行手动替换。电子价签的应用使传统的人工申请、调价、打印、营业员去货架前端更换等变价流程均可省略,所有门店的价格由后台统…

【c 语言】算术操作符详解

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:C语言 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步&…

Transformer模型引领NLP革新之路

在不到4 年的时间里,Transformer 模型以其强大的性能和创新的思想,迅速在NLP 社区崭露头角,打破了过去30 年的记录。BERT、T5 和GPT 等模型现在已成为计算机视觉、语音识别、翻译、蛋白质测序、编码等各个领域中新应用的基础构件。因此&#…

SpringMVC | SpringMVC中的 “数据绑定”

目录: “数据绑定” 介绍1.简单数据绑定 :绑定 “默认数据” 类型绑定 “简单数据类型” 类型 (绑定Java“基本数据类型”)绑定 “POJO类型”绑定 “包装 POJO”“自定义数据” 绑定 :Converter (自定义转换器) 作者简介 :一只大皮卡丘&#…

【Linux】Linux上的一些软件安装与环境配置(Centos7配置JDK、Hadoop)

文章目录 安装JDK配置环境变量1. 卸载已安装的JDK查询已安装的 jdk 列表删除已经安装的 jdk 2. 上传安装包3. 创建 /usr/local/java 文件夹4. 将 jdk 压缩包解压到 /usr/local/java 目录下5. 配置 jdk 的环境变量6. 让配置文件生效7. 校验8.拍个快照吧,免得后面哪里…

机器学习概论—正则化

机器学习概论—正则化 在开发机器学习模型的过程中,大家一定遇到过模型在训练集上表现不错,但验证精度或测试精度过低的情况。这种情况在机器学习领域通常被称为过度拟合,这也是机器学习从业者最不希望在他的模型中出现的情况。 在本文中,我们将学习一种称为正则化的方法…

ETAS工具链ISOLAR-AB重要概念,RTE配置,ECU抽取

RTE配置界面,包含ECU抽取关联 首次配置RTE,出现需要勾选的抽取EXTRACT 创建System System制作SWC到ECU的Mapping System制作System Data 的Mapping

如何解决ChatGPT消息发不出问题,GPT消息无法发出去,没有响应的问题

前言 今天工作到一半,登陆ChatGPT想咨询一些代码上的问题,结果发现发不了消息了。 ChatGPT 无法发送消息,但是能查看历史的对话。不过首先可以先打开官方的网站:https://status.openai.com/ 。 查看当前Open AI的状态&#xff0…

【动态规划】代码随想录算法训练营第五十一天 | 309.最佳买卖股票时机含冷冻期, 714.买卖股票的最佳时机含手续费,总结(待补充)

309.最佳买卖股票时机含冷冻期 1、题目链接:. - 力扣(LeetCode) 2、文章讲解:代码随想录 3、题目: 给定一个整数数组,其中第 i 个元素代表了第 i 天的股票价格 。 设计一个算法计算出最大利润。在满足…

Python安装第三方库

前言:大部分时候我们都是使用pip install去安装一些第三方库,但是偶尔也会有部分库无法安装(最典型的就是dlib这个库),需要采取别的方法解决,这里做笔记记录一下。 使用国内镜像源安装 因为pypi的服务器在…

最新android icon和splashScreen适配兼容至2024android

android在12做了splashScreen的变动,即,android12有自带的screenSplash过渡,不论你是否自己有变化,都会插入该动画。 android8做了icon的巨大变动。13做了图标的主题兼容。 一、icon制作 制作 使用android自带的工具&#xff0…

03:HAL---中断

目录 一:中断 1:简历 2:AFIO 3:EXTI 4:NVIC基本结构 5:使用步骤 6:设计中断函数 二:中断的应用 A:对外式红外传感计数器 1:硬件介绍 2:计数代码 B:旋转编码计数器 1:硬件介绍 2:旋转编码器代码 C:按键控制LED D:代码总结 一:中断 1:简历 中断:在主程序…

第二十二周周报

论文研读:Camera Distance-aware Top-down Approach for 3D Multi-person Pose Estimation from a Single RGB Image 粗读10篇文献。 通过图2 我可以知道这个论文大概实现的这个姿态估计效果的方法,首先是把图片输入到DetectNet网络,该网络…