VAE-pytorch代码

 

 

 

import osimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoaderfrom torchvision import transforms, datasets
from torchvision.utils import save_imagefrom tqdm import tqdmclass VAE(nn.Module):  # 定义VAE模型def __init__(self, img_size, latent_dim):  # 初始化方法super(VAE, self).__init__()  # 继承初始化方法self.in_channel, self.img_h, self.img_w = img_size  # 由输入图片形状得到图片通道数C、图片高度H、图片宽度Wself.h = self.img_h // 32  # 经过5次卷积后,最终特征层高度变为原图片高度的1/32self.w = self.img_w // 32  # 经过5次卷积后,最终特征层宽度变为原图片高度的1/32hw = self.h * self.w  # 最终特征层的尺寸hxwself.latent_dim = latent_dim  # 采样变量Z的长度self.hidden_dims = [32, 64, 128, 256, 512]  # 特征层通道数列表# 开始构建编码器Encoderlayers = []  # 用于存放模型结构for hidden_dim in self.hidden_dims:  # 循环特征层通道数列表layers += [nn.Conv2d(self.in_channel, hidden_dim, 3, 2, 1),  # 添加convnn.BatchNorm2d(hidden_dim),  # 添加bnnn.LeakyReLU()]  # 添加leakyreluself.in_channel = hidden_dim  # 将下次循环的输入通道数设为本次循环的输出通道数self.encoder = nn.Sequential(*layers)  # 解码器Encoder模型结构self.fc_mu = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linaer,将特征向量转化为分布均值muself.fc_var = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linear,将特征向量转化为分布方差的对数log(var)# 开始构建解码器Decoderlayers = []  # 用于存放模型结构self.decoder_input = nn.Linear(self.latent_dim, self.hidden_dims[-1] * hw)  # linaer,将采样变量Z转化为特征向量self.hidden_dims.reverse()  # 倒序特征层通道数列表for i in range(len(self.hidden_dims) - 1):  # 循环特征层通道数列表layers += [nn.ConvTranspose2d(self.hidden_dims[i], self.hidden_dims[i + 1], 3, 2, 1, 1),  # 添加transconvnn.BatchNorm2d(self.hidden_dims[i + 1]),  # 添加bnnn.LeakyReLU()]  # 添加leakyrelulayers += [nn.ConvTranspose2d(self.hidden_dims[-1], self.hidden_dims[-1], 3, 2, 1, 1),  # 添加transconvnn.BatchNorm2d(self.hidden_dims[-1]),  # 添加bnnn.LeakyReLU(),  # 添加leakyrelunn.Conv2d(self.hidden_dims[-1], img_size[0], 3, 1, 1),  # 添加convnn.Tanh()]  # 添加tanhself.decoder = nn.Sequential(*layers)  # 编码器Decoder模型结构def encode(self, x):  # 定义编码过程result = self.encoder(x)  # Encoder结构,(n,1,32,32)-->(n,512,1,1)result = torch.flatten(result, 1)  # 将特征层转化为特征向量,(n,512,1,1)-->(n,512)mu = self.fc_mu(result)  # 计算分布均值mu,(n,512)-->(n,128)log_var = self.fc_var(result)  # 计算分布方差的对数log(var),(n,512)-->(n,128)return [mu, log_var]  # 返回分布的均值和方差对数def decode(self, z):  # 定义解码过程y = self.decoder_input(z).view(-1, self.hidden_dims[0], self.h,self.w)  # 将采样变量Z转化为特征向量,再转化为特征层,(n,128)-->(n,512)-->(n,512,1,1)y = self.decoder(y)  # decoder结构,(n,512,1,1)-->(n,1,32,32)return y  # 返回生成样本Ydef reparameterize(self, mu, log_var):  # 重参数技巧std = torch.exp(0.5 * log_var)  # 分布标准差stdeps = torch.randn_like(std)  # 从标准正态分布中采样,(n,128)return mu + eps * std  # 返回对应正态分布中的采样值def forward(self, x):  # 前传函数mu, log_var = self.encode(x)  # 经过编码过程,得到分布的均值mu和方差对数log_varz = self.reparameterize(mu, log_var)  # 经过重参数技巧,得到分布采样变量Zy = self.decode(z)  # 经过解码过程,得到生成样本Yreturn [y, x, mu, log_var]  # 返回生成样本Y,输入样本X,分布均值mu,分布方差对数log_vardef sample(self, n, cuda):  # 定义生成过程z = torch.randn(n, self.latent_dim)  # 从标准正态分布中采样得到n个采样变量Z,长度为latent_dimif cuda:  # 如果使用cudaz = z.cuda()  # 将采样变量Z加载到GPUimages = self.decode(z)  # 经过解码过程,得到生成样本Yreturn images  # 返回生成样本Ydef loss_fn(y, x, mu, log_var):  # 定义损失函数recons_loss = F.mse_loss(y, x)  # 重建损失,MSEkld_loss = torch.mean(0.5 * torch.sum(mu ** 2 + torch.exp(log_var) - log_var - 1, 1), 0)  # 分布损失,正态分布与标准正态分布的KL散度return recons_loss + w * kld_loss  # 最终损失由两部分组成,其中分布损失需要乘上一个系数wif __name__ == "__main__":total_epochs = 100  # epochsbatch_size = 64  # batch sizelr = 5e-4  # lrw = 0.00025  # kld_loss的系数wnum_workers = 8  # 数据加载线程数image_size = 32  # 图片尺寸image_channel = 1  # 图片通道latent_dim = 128  # 采样变量Z长度sample_images_dir = "sample_images"  # 生成样本示例存放路径train_dataset_dir = "../dataset/mnist"  # 训练样本存放路径os.makedirs(sample_images_dir, exist_ok=True)  # 创建生成样本示例存放路径os.makedirs(train_dataset_dir, exist_ok=True)  # 创建训练样本存放路径cuda = True if torch.cuda.is_available() else False  # 如果cuda可用,则使用cudaimg_size = (image_channel, image_size, image_size)  # 输入样本形状(1,32,32)vae = VAE(img_size, latent_dim)  # 实例化VAE模型,传入输入样本形状与采样变量长度if cuda:  # 如果使用cudavae = vae.cuda()  # 将模型加载到GPU# dataset and dataloadertransform = transforms.Compose(  # 图片预处理方法[transforms.Resize(image_size),  # 图片resize,(28x28)-->(32,32)transforms.ToTensor(),  # 转化为tensortransforms.Normalize([0.5], [0.5])]  # 标准化)dataloader = DataLoader(  # 定义dataloaderdataset=datasets.MNIST(root=train_dataset_dir,  # 使用mnist数据集,选择数据路径train=True,  # 使用训练集transform=transform,  # 图片预处理download=True),  # 自动下载batch_size=batch_size,  # batch sizenum_workers=num_workers,  # 数据加载线程数shuffle=True  # 打乱数据)# optimizeroptimizer = torch.optim.Adam(vae.parameters(), lr=lr)  # 使用Adam优化器# train loopfor epoch in range(total_epochs):  # 循环epochtotal_loss = 0  # 记录总损失pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{total_epochs}", postfix=dict,miniters=0.3)  # 设置当前epoch显示进度for i, (img, _) in enumerate(dataloader):  # 循环iterif cuda:  # 如果使用cudaimg = img.cuda()  # 将训练数据加载到GPUvae.train()  # 模型开始训练optimizer.zero_grad()  # 模型清零梯度y, x, mu, log_var = vae(img)  # 输入训练样本X,得到生成样本Y,输入样本X,分布均值mu,分布方差对数log_varloss = loss_fn(y, x, mu, log_var)  # 计算lossloss.backward()  # 反向传播,计算当前梯度optimizer.step()  # 根据梯度,更新网络参数total_loss += loss.item()  # 累计losspbar.set_postfix(**{"Loss": loss.item()})  # 显示当前iter的losspbar.update(1)  # 步进长度pbar.close()  # 关闭当前epoch显示进度print("total_loss:%.4f" %(total_loss / len(dataloader)))  # 显示当前epoch训练完成后,模型的总损失vae.eval()  # 模型开始验证sample_images = vae.sample(25, cuda)  # 获得25个生成样本save_image(sample_images.data, "%s/ep%d.png" % (sample_images_dir, (epoch + 1)), nrow=5,normalize=True)  # 保存生成样本示例(5x5)

其中计算KLloss的代码的解释如下:

代码的目标是计算变分自编码器(VAE)中近似后验分布q(z∣x) 和标准正态分布 p(z) 之间的KL散度。KL散度公式的具体计算步骤如下:

1. mu ** 2

计算均值的平方项: μ2 这个项是为了衡量均值偏离零的程度。

2. torch.exp(log_var)

对数方差取指数,以获得实际的方差: exp⁡(log⁡(σ2))=σ2 这个项衡量方差的大小。

3. - log_var

减去对数方差: −log⁡(σ2) 这个项衡量分布的扩展程度。

4. - 1

减去 1,是KL散度公式中的常数项,用于归一化。

将这些项加在一起:

μ2+exp⁡(log⁡(σ2))−log⁡(σ2)−1

5. torch.sum(..., 1)

对所有维度求和,计算单个样本的KL散度: ∑(μ2+σ2−log⁡(σ2)−1) 这一步是将每个样本的所有维度的KL散度加起来。

6. 0.5 * ...

乘以 0.5,因KL散度公式中有系数 0.5: 0.5×∑(μ2+σ2−log⁡(σ2)−1)

7. torch.mean(..., 0)

对所有样本取平均,得到最终的KL散度损失: mean(0.5×∑(μ2+σ2−log⁡(σ2)−1))

整个公式的作用是计算出近似后验分布 q(z∣x) 和标准正态分布 p(z) 之间的KL散度,该散度表示了两个分布之间的差异。这种损失通常用于变分自编码器(VAE)训练中,确保生成的潜在变量分布接近标准正态分布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/36087.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

增值电信业务经营许可证详解

最近有很多人问小编,什么是增值电信业务经营许可证?可能大家对这个名词还很陌生,不着急!今天小编就给大家好好讲讲这个证到底是什么?为什么需要办这个这个证?它可以给企业带来什么? 带着疑问我们继续往下看。 什么…

一步步带你解锁Stable Diffusion:老外都眼馋的 SD 中文提示词插件分享

大家好我是极客菌!今天我们继续来分享一个外国人都眼馋的 SD 中文提示词插件。 那我们废话不多说,直接开整。 SD 的插件安装,小伙伴们应该都会了吧,我这里再简单讲下哦,到「扩展」中的「可下载」中点击「加载扩展列表…

LangChain资料总结

1、LangChain介绍 - 莫尔索随笔 2、LangChain 介绍 | LangChain中文网:500页中文文档教程,助力大模型LLM应用开发从入门到精通 3、🌈 Spring AI 语雀 4、AI全栈「AGI」 专栏 语雀 5、GitHub - langchain4j/langchain4j: Java version of LangChain…

国标GB/T 28181详解:国标GBT28181-2022的目录通知流程

目录 一、定义 二、作用 1、实时同步设备目录状态 2、优化资源管理和调度 3、增强系统的可扩展性和灵活性 4、提高系统的可靠性和稳定性 5、支持多级级联和分布式部署 6、便于用户管理和监控 三、基本要求 1、目录通知满足以下基本要求 2、关键要素 (1…

码农在低空经济领域的就业机遇与技能准备

在当今快速发展的经济格局中,低空经济正逐渐崭露头角,成为一个充满潜力和创新的领域。对于广大码农来说,这无疑是一片有待开拓的职业新蓝海。那么,码农们如何在低空经济中找到属于自己的就业机会呢? 首先,…

python-立方和不等式

[题目描述] 试求满足下述立方和不等式的 m 的整数解。 1^32^3...m^3≤n。本题算法如下: 对指定的 n,设置求和循环,从 i1 开始,i 递增1取值,把 i^3 (或 i∗i∗i)累加到 s,直至 s>n,脱离循环作…

docker配置redis主从复制

下载redis,复制redis.conf 主节点(6379) 修改redis.conf # bind 127.0.0.1 # 注释掉这里 protected-mode no # 改为no port 6379从节点(6380) 修改redis.conf bind 127.0.0.1 protected-mode no # 改为no port 6380 replicaof 172.17.0.2 6379 # 这里的ip为主节点容器的i…

Oracle数据库使用指南基本概念

学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把手教你开发炫酷的vbs脚本制作(完善中……) 4、牛逼哄哄的 IDEA编程利器技巧(编写中……) 5、面经吐血整理的 面试技…

代码随想录算法训练营Day36| 62.不同路径 , 63. 不同路径 II

由于我最近临近期末考试所以后面两题就先暂时跳过,但是并不是代表我不写,等我暑假会全部补起来,那么来看今天的第一题 62.不同路径:代码随想录 这道题目就是说让你求出到达终点有几种不同的路径,你只能向下或者向右走&…

介绍几种 MySQL 官方高可用方案

前言: MySQL 官方提供了多种高可用部署方案,从最基础的主从复制到组复制再到 InnoDB Cluster 等等。本篇文章以 MySQL 8.0 版本为准,介绍下不同高可用方案架构原理及使用场景。 1.MySQL Replication MySQL Replication 是官方提供的主从同…

2024.06.17校招 实习 内推 面经

绿*泡*泡VX: neituijunsir 交流*裙 ,内推/实习/校招汇总表格 1、提前批 | 中国电科38所2024年暑期开放日暨2025届提前批招聘正式启动! 提前批 | 中国电科38所2024年暑期开放日暨2025届提前批招聘正式启动! 2、实习 | 舍弗勒实…

HarmonyOS--数据持久化

用户首选项 场景介绍 1、用户首选项为应用提供Key-Value键值型的数据处理能力,支持应用持久化轻量级数据,并对其修改和查询。当用户希望有一个全局唯一存储的地方,可以采用用户首选项来进行存储。 2、Preferences会将该数据缓存在内存中&a…

Java之使用策略模式替代 if-else

在Java中,通常情况下 if-else 语句用于根据不同条件执行不同的逻辑。而策略模式则是一种设计模式,它允许在运行时选择算法的行为。 策略模式的主要思想是将算法封装成独立的对象,使得它们可以相互替换,使得算法的变化独立于使用算…

模型情景制作-制作一棵树

情景模型中,最常用到的也是最能提升情景中生气的就是树。然而,自然的生长和环境的影响使得树的制作变成了考验制作者观察力的一道考题。制作一棵逼真的树,我们可以参考下面的这种方法。 铁丝制树 您需要准备9—12根铁丝,每根的长…

SuperCopy解决文档不能复制问题

有一些文档,我们要使用时,总是面临收费的情况,让我们不能复制,让人头疼不已!!! SuperCopy就可以解决这个问题。 获取SuperCopy步骤 1. 打开浏览器,点击右上角的三个点 2. 找到扩…

老板电器 45 年的烹饪经验,浓缩在这款烹饪大模型中

在科技不断进步的时代,人工智能(AI)迅速成为推动各行各业发展的重要力量。家电行业也不例外,根据 Gartner 的报告预测,到 2024 年,AI 家电市场的规模将达到万亿美元级别。这一预估凸显了智能化在家电行业中…

网络安全 DVWA通关指南 Cross Site Request Forgery (CSRF)

DVWA Cross Site Request Forgery (CSRF) 文章目录 DVWA Cross Site Request Forgery (CSRF)DVWA Low 级别 CSRFDVWA Medium 级别 CSRFDVWA High 级别 CSRFDVWA Impossible 级别 CSRF CSRF是跨站请求伪造攻击,由客户端发起,是由于没有在执行关键操作时&a…

selenium网页自动化使用教程

Selenium 是一个流行的自动化测试工具,它支持多种编程语言,包括 Python。以下是关于 Selenium 安装和使用的一些详细步骤: 安装 Selenium: 确保 Python 环境已经安装在你的计算机上。Selenium 支持 Python 2.7 以及 Python 3.2 …

【黑龙江哪些行业需要做等保?】

黑龙江等保测评是衡量企业网络安全水平的一项主要指标,包括:金融,能源,电信,医疗,教育,交通,制造,电商等。 等保测评是黑龙江省信息化建设的重要组成部分,也…

旅游管理系统源码小程序

便捷旅行,尽在掌握 旅游管理系统是一款基于FastAdminElementUNIAPP开发的多端(微信小程序、公众号、H5)旅游管理系统,拥有丰富的装修组件、多端分享、模板消息、电子合同、旅游攻略、旅游线路及相关保险预订等功能,提…