变分自编码器生成新的手写数字图像

变分自编码器(Variational Autoencoder,VAE)是一种生成模型,通常用于学习数据的潜在表示,并用于生成新的数据样本。它由两部分组成:编码器和解码器。

  1. 编码器(Encoder):接收输入数据,并将其映射到潜在空间中的分布。这意味着编码器将数据转换为均值和方差参数的分布,通常假设为高斯分布。

  2. 解码器(Decoder):接收来自编码器的潜在表示,并将其映射回原始数据空间。解码器尝试从潜在空间中的样本中生成与输入数据尽可能接近的重建数据。

VAE的目标是学习一个能够生成与训练数据类似的数据分布。为了实现这一点,VAE采用了一种被称为变分推断的方法,其中引入了一个额外的损失项,即KL散度,用于度量生成的潜在分布与预先设定的先验分布之间的差异。

VAE将经过神经网络编码后的隐藏层假设为一个标准的高斯分布,然后再从这个分布中采样一个特征,再用这个特征进行解码,期望得到与原始输入相同的结果,损失和AE几乎一样,只是增加编码推断分布与标准高斯分布的KL散度的正则项,显然增加这个正则项的目的就是防止模型退化成普通的AE,因为网络训练时为了尽量减小重构误差,必然使得方差逐渐被降到0,这样便不再会有随机采样噪声,也就变成了普通的AE。

举例来说,假设我们有一组手写数字的图像作为输入数据。我们可以使用VAE来学习手写数字的潜在表示,并用此表示来生成新的手写数字图像。编码器将输入图像转换为潜在空间中的分布,解码器则将从该分布中采样的样本映射回原始图像空间。通过训练编码器和解码器,VAE可以生成与训练数据类似的手写数字图像,同时学习数据的潜在结构。

以下是使用 PyTorch 实现的简单示例代码,演示了如何使用变分自编码器(VAE)学习手写数字的潜在表示,并用此表示来生成新的手写数字图像:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt# 定义变分自编码器模型
class VAE(nn.Module):def __init__(self, input_dim, latent_dim):super(VAE, self).__init__()# 编码器部分self.encoder = nn.Sequential(nn.Linear(input_dim, 512),nn.ReLU(),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, latent_dim * 2)  # 输出均值和方差参数)# 解码器部分self.decoder = nn.Sequential(nn.Linear(latent_dim, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, input_dim),nn.Sigmoid()  # 输出范围在 0 到 1 之间)def reparameterize(self, mu, logvar):std = torch.exp(0.5 * logvar)eps = torch.randn_like(std)return mu + eps * stddef forward(self, x):# 编码z_mu_logvar = self.encoder(x)mu, logvar = torch.chunk(z_mu_logvar, 2, dim=1)# 重参数化z = self.reparameterize(mu, logvar)# 解码x_recon = self.decoder(z)return x_recon, mu, logvar# 计算重构损失和 KL 散度
def loss_function(x_recon, x, mu, logvar):recon_loss = nn.BCELoss(reduction='sum')(x_recon, x)  # 二进制交叉熵损失kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())return recon_loss + kl_divergence# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Lambda(lambda x: x.view(-1))  # 将图像展平成向量
])# 加载 MNIST 数据集
train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)# 初始化模型和优化器
latent_dim = 20
input_dim = 784  # 28x28
model = VAE(input_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=1e-3)# 训练模型
num_epochs = 20
for epoch in range(num_epochs):total_loss = 0for batch_idx, (x, _) in enumerate(train_loader):optimizer.zero_grad()x_recon, mu, logvar = model(x)loss = loss_function(x_recon, x, mu, logvar)loss.backward()optimizer.step()total_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader.dataset)}")# 使用训练好的模型生成新的手写数字图像
with torch.no_grad():z = torch.randn(10, latent_dim)  # 生成 10 个随机潜在向量generated_images = model.decoder(z)generated_images = generated_images.view(-1, 1, 28, 28)  # 将向量转换成图像形状# 可视化生成的图像
fig, axes = plt.subplots(1, 10, figsize=(10, 1))
for i, ax in enumerate(axes):ax.imshow(generated_images[i][0], cmap='gray')ax.axis('off')
plt.show()

这段代码首先定义了一个简单的变分自编码器模型,然后使用 MNIST 数据集训练该模型,最后使用训练好的模型生成新的手写数字图像。

参考 【PyTorch】变分自编码器/Variational Autoencoder(VAE)_variantautoencoder(vae)pytorch-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/801967.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用Echarts词云数据可视化热词表白​​

目录 1、使用前准备 2、准备工作 3、盒子搭建 4、整体展现 1、使用前准备 找到表白对象(重中之重!),不要一见钟情(个人觉得:一见钟情属于见色起意!),因为数据可视化需…

海外仓为何要做仓库管理系统?位像素海外仓系统的仓库管理功能有哪些?

在当今繁荣的跨境电商市场中,海外仓已经成为了许多电商企业的重要选择。但是,海外仓的成功与否并不仅仅取决于其位置和规模,同样重要的是其仓库管理系统的有效性。那么,海外仓为何要做仓库管理系统呢?让我们一起来探讨…

“成像光谱遥感技术中的AI革命:ChatGPT在遥感领域中的应用“

遥感技术主要通过卫星和飞机从远处观察和测量我们的环境,是理解和监测地球物理、化学和生物系统的基石。ChatGPT是由OpenAI开发的最先进的语言模型,在理解和生成人类语言方面表现出了非凡的能力。本文重点介绍ChatGPT在遥感中的应用,人工智能…

可视化大屏的应用(9):智慧旅游和智慧景区

可视化大屏在智慧旅游领域具有多种价值,可以为旅游管理者和游客提供更加便捷、优质的服务和体验。本期大千UI工场带来智慧旅游和智慧景区的可视化大屏界面,供大家欣赏。 可视化大屏在智慧旅游领域的价值如下: 提供全面的信息展示&#xff0…

数据结构:构建完全二叉查找树

文章目录 1、步骤 1: 对给定数组排序2、步骤 2: 递归构建完全二叉查找树3、注意4、在有序数组中寻找根结点位置5、代码实现6、其他方法?基本思路插入操作删除操作特别考虑 对于一个给定序列的二叉查找树,有很多种,但是完全二叉查找树只有一种…

【网站项目】医院核酸检测预约挂号小程序

🙊作者简介:拥有多年开发工作经验,分享技术代码帮助学生学习,独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。🌹赠送计算机毕业设计600个选题excel文件,帮助大学选题。赠送开题报告模板&#xff…

4.1-4.5算法刷题笔记(17道题)

4.1-4.5算法刷题笔记 1. 区间和2. 区间合并3. 用队列实现栈(queueMain queueTemp;)4. 最小栈 1. 单链表模板5. 单链表 2. 双链表模板6. 双链表 3. 模拟栈7. 模拟栈(一个数组即可)8. 表达式求值 4. 队列 tt -1,hh 0;9. 模拟队列 5. 单调栈10. 单调栈 6…

【接口自动化】参数化替换

在做接口测试时,除了测单个接口,还需要进行业务链路间的接口测试 比如[注册-登陆]需要token鉴权的业务流 当我们用使用postman/jmeter等工具时,将注册接口的一些响应信息提取出来,放到登陆接口的请求中,来完成某个业务…

在Gazebo中如何拯救翻车的机器人

Gazebo提供了一些交互工具,允许你直接通过图形界面操作模型: 启动Gazebo后,在右侧工具栏中,你会找到一个可以拖拽物体的图标(通常是一个手掌图标或者类似的)。点击这个图标。 随后,你可以用鼠标…

Linux/Lame

Lame 今天随便乱逛发现这台机器貌似是 HackTheBox 平台的第一台机器,而且我还没做过,从简介上来看的话是一台很简单的机器,快快的玩一下 Enumeration nmap 首先用 nmap 扫描一下常见的端口,发现系统对外开放了 21,22,139,445 端…

面试算法-160-合并两个有序链表

题目 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1: 输入:l1 [1,2,4], l2 [1,3,4] 输出:[1,1,2,3,4,4] 解 class Solution {public ListNode mergeTwoLists(ListNode li…

NineData创始人CEO叶正盛受邀参加『数据技术嘉年华』的技术大会

4月13日,NineData 创始人&CEO叶正盛受邀参加第13届『数据技术嘉年华』的技术大会。将和数据领域的技术爱好者一起相聚,并分享《NineData在10000公里跨云数据库间实时数据复制技术原理与实践》主题内容。 分享嘉宾 叶正盛,NineData CEO …

多线程同步计数器CountDownLatch,CyclicBarrier,Semaphore

系列文章目录 文章目录 系列文章目录前言前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 CountDownLatch CountDownLatch是一个同步工具类,它允许一个或多个线程等…

【学习】软件验收测试,能否选择第三方检测机构进行测试?

随着信息技术的快速发展,软件已经成为各行各业中不可或缺的一部分。为了保证软件的质量和稳定性,验收测试成为了软件开发过程中至关重要的一环。那么,第三方软件测试机构可以做验收测试吗?我们一起来看下今日的分享。 一、验收测…

MySQL操作DML

目录 1.概述 2.插入 3.更新 4.删除 5.查询 6.小结 1.概述 数据库DML是数据库操作语言(Data Manipulation Language)的简称,主要用于对数据库中的数据进行增加、修改、删除等操作。它是SQL语言的一部分,用于实现对数据库中数…

diffusion model(十五) : IP-Adapter技术小结

infopaperhttps://arxiv.org/pdf/2308.06721.pdfcodehttps://github.com/tencent-ailab/IP-Adapterorg.Tencent AI Lab个人博客地址http://myhz0606.com/article/ip_adapter 1 Motivation 为了对文生图diffusion model进行特定概念的定制,常用LoRA[1]、textual in…

Android Studio 生成 keystore 签名文件及打包验证流程

一、创建keystore签名文件 1、在菜单栏中,依次点击 Build - Generate Signed Bundle/Apk...(生成签名) 2、选择 APK 选项,点击按钮 Next 到下一步 3、新建key store秘钥文件,点击按钮 Next 到下一步 4、按如下提示填写信息,点击按…

JAVA POI Excel 使用数组公式 FREQUENCY

平台及依赖 JAVA 17POI版本 <dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactId><version>5.2.5</version></dependency><dependency><groupId>org.apache.poi</groupId><art…

【RISC-V 指令集】RISC-V 向量V扩展指令集介绍(四)- 配置和设置指令(vsetvli/vsetivli/vsetvl)

1. 引言 以下是《riscv-v-spec-1.0.pdf》文档的关键内容&#xff1a; 这是一份关于向量扩展的详细技术文档&#xff0c;内容覆盖了向量指令集的多个关键方面&#xff0c;如向量寄存器状态映射、向量指令格式、向量加载和存储操作、向量内存对齐约束、向量内存一致性模型、向量…

蓝桥杯模拟赛练习题—— 燃烧你的卡路里

目标 请在 js/index.js 和 index.html 文件中补全代码&#xff0c;完成以下目标&#xff1a; 点击“定制方案”按钮后&#xff0c;弹出侧滑页面&#xff0c;所使用的组件为 el-drawer&#xff0c;相关属性如下&#xff1a; 参数说明类型默认值v-model是否显示 Drawerboolean…