VAE-pytorch代码

 

 

 

import osimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoaderfrom torchvision import transforms, datasets
from torchvision.utils import save_imagefrom tqdm import tqdmclass VAE(nn.Module):  # 定义VAE模型def __init__(self, img_size, latent_dim):  # 初始化方法super(VAE, self).__init__()  # 继承初始化方法self.in_channel, self.img_h, self.img_w = img_size  # 由输入图片形状得到图片通道数C、图片高度H、图片宽度Wself.h = self.img_h // 32  # 经过5次卷积后,最终特征层高度变为原图片高度的1/32self.w = self.img_w // 32  # 经过5次卷积后,最终特征层宽度变为原图片高度的1/32hw = self.h * self.w  # 最终特征层的尺寸hxwself.latent_dim = latent_dim  # 采样变量Z的长度self.hidden_dims = [32, 64, 128, 256, 512]  # 特征层通道数列表# 开始构建编码器Encoderlayers = []  # 用于存放模型结构for hidden_dim in self.hidden_dims:  # 循环特征层通道数列表layers += [nn.Conv2d(self.in_channel, hidden_dim, 3, 2, 1),  # 添加convnn.BatchNorm2d(hidden_dim),  # 添加bnnn.LeakyReLU()]  # 添加leakyreluself.in_channel = hidden_dim  # 将下次循环的输入通道数设为本次循环的输出通道数self.encoder = nn.Sequential(*layers)  # 解码器Encoder模型结构self.fc_mu = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linaer,将特征向量转化为分布均值muself.fc_var = nn.Linear(self.hidden_dims[-1] * hw, self.latent_dim)  # linear,将特征向量转化为分布方差的对数log(var)# 开始构建解码器Decoderlayers = []  # 用于存放模型结构self.decoder_input = nn.Linear(self.latent_dim, self.hidden_dims[-1] * hw)  # linaer,将采样变量Z转化为特征向量self.hidden_dims.reverse()  # 倒序特征层通道数列表for i in range(len(self.hidden_dims) - 1):  # 循环特征层通道数列表layers += [nn.ConvTranspose2d(self.hidden_dims[i], self.hidden_dims[i + 1], 3, 2, 1, 1),  # 添加transconvnn.BatchNorm2d(self.hidden_dims[i + 1]),  # 添加bnnn.LeakyReLU()]  # 添加leakyrelulayers += [nn.ConvTranspose2d(self.hidden_dims[-1], self.hidden_dims[-1], 3, 2, 1, 1),  # 添加transconvnn.BatchNorm2d(self.hidden_dims[-1]),  # 添加bnnn.LeakyReLU(),  # 添加leakyrelunn.Conv2d(self.hidden_dims[-1], img_size[0], 3, 1, 1),  # 添加convnn.Tanh()]  # 添加tanhself.decoder = nn.Sequential(*layers)  # 编码器Decoder模型结构def encode(self, x):  # 定义编码过程result = self.encoder(x)  # Encoder结构,(n,1,32,32)-->(n,512,1,1)result = torch.flatten(result, 1)  # 将特征层转化为特征向量,(n,512,1,1)-->(n,512)mu = self.fc_mu(result)  # 计算分布均值mu,(n,512)-->(n,128)log_var = self.fc_var(result)  # 计算分布方差的对数log(var),(n,512)-->(n,128)return [mu, log_var]  # 返回分布的均值和方差对数def decode(self, z):  # 定义解码过程y = self.decoder_input(z).view(-1, self.hidden_dims[0], self.h,self.w)  # 将采样变量Z转化为特征向量,再转化为特征层,(n,128)-->(n,512)-->(n,512,1,1)y = self.decoder(y)  # decoder结构,(n,512,1,1)-->(n,1,32,32)return y  # 返回生成样本Ydef reparameterize(self, mu, log_var):  # 重参数技巧std = torch.exp(0.5 * log_var)  # 分布标准差stdeps = torch.randn_like(std)  # 从标准正态分布中采样,(n,128)return mu + eps * std  # 返回对应正态分布中的采样值def forward(self, x):  # 前传函数mu, log_var = self.encode(x)  # 经过编码过程,得到分布的均值mu和方差对数log_varz = self.reparameterize(mu, log_var)  # 经过重参数技巧,得到分布采样变量Zy = self.decode(z)  # 经过解码过程,得到生成样本Yreturn [y, x, mu, log_var]  # 返回生成样本Y,输入样本X,分布均值mu,分布方差对数log_vardef sample(self, n, cuda):  # 定义生成过程z = torch.randn(n, self.latent_dim)  # 从标准正态分布中采样得到n个采样变量Z,长度为latent_dimif cuda:  # 如果使用cudaz = z.cuda()  # 将采样变量Z加载到GPUimages = self.decode(z)  # 经过解码过程,得到生成样本Yreturn images  # 返回生成样本Ydef loss_fn(y, x, mu, log_var):  # 定义损失函数recons_loss = F.mse_loss(y, x)  # 重建损失,MSEkld_loss = torch.mean(0.5 * torch.sum(mu ** 2 + torch.exp(log_var) - log_var - 1, 1), 0)  # 分布损失,正态分布与标准正态分布的KL散度return recons_loss + w * kld_loss  # 最终损失由两部分组成,其中分布损失需要乘上一个系数wif __name__ == "__main__":total_epochs = 100  # epochsbatch_size = 64  # batch sizelr = 5e-4  # lrw = 0.00025  # kld_loss的系数wnum_workers = 8  # 数据加载线程数image_size = 32  # 图片尺寸image_channel = 1  # 图片通道latent_dim = 128  # 采样变量Z长度sample_images_dir = "sample_images"  # 生成样本示例存放路径train_dataset_dir = "../dataset/mnist"  # 训练样本存放路径os.makedirs(sample_images_dir, exist_ok=True)  # 创建生成样本示例存放路径os.makedirs(train_dataset_dir, exist_ok=True)  # 创建训练样本存放路径cuda = True if torch.cuda.is_available() else False  # 如果cuda可用,则使用cudaimg_size = (image_channel, image_size, image_size)  # 输入样本形状(1,32,32)vae = VAE(img_size, latent_dim)  # 实例化VAE模型,传入输入样本形状与采样变量长度if cuda:  # 如果使用cudavae = vae.cuda()  # 将模型加载到GPU# dataset and dataloadertransform = transforms.Compose(  # 图片预处理方法[transforms.Resize(image_size),  # 图片resize,(28x28)-->(32,32)transforms.ToTensor(),  # 转化为tensortransforms.Normalize([0.5], [0.5])]  # 标准化)dataloader = DataLoader(  # 定义dataloaderdataset=datasets.MNIST(root=train_dataset_dir,  # 使用mnist数据集,选择数据路径train=True,  # 使用训练集transform=transform,  # 图片预处理download=True),  # 自动下载batch_size=batch_size,  # batch sizenum_workers=num_workers,  # 数据加载线程数shuffle=True  # 打乱数据)# optimizeroptimizer = torch.optim.Adam(vae.parameters(), lr=lr)  # 使用Adam优化器# train loopfor epoch in range(total_epochs):  # 循环epochtotal_loss = 0  # 记录总损失pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{total_epochs}", postfix=dict,miniters=0.3)  # 设置当前epoch显示进度for i, (img, _) in enumerate(dataloader):  # 循环iterif cuda:  # 如果使用cudaimg = img.cuda()  # 将训练数据加载到GPUvae.train()  # 模型开始训练optimizer.zero_grad()  # 模型清零梯度y, x, mu, log_var = vae(img)  # 输入训练样本X,得到生成样本Y,输入样本X,分布均值mu,分布方差对数log_varloss = loss_fn(y, x, mu, log_var)  # 计算lossloss.backward()  # 反向传播,计算当前梯度optimizer.step()  # 根据梯度,更新网络参数total_loss += loss.item()  # 累计losspbar.set_postfix(**{"Loss": loss.item()})  # 显示当前iter的losspbar.update(1)  # 步进长度pbar.close()  # 关闭当前epoch显示进度print("total_loss:%.4f" %(total_loss / len(dataloader)))  # 显示当前epoch训练完成后,模型的总损失vae.eval()  # 模型开始验证sample_images = vae.sample(25, cuda)  # 获得25个生成样本save_image(sample_images.data, "%s/ep%d.png" % (sample_images_dir, (epoch + 1)), nrow=5,normalize=True)  # 保存生成样本示例(5x5)

其中计算KLloss的代码的解释如下:

代码的目标是计算变分自编码器(VAE)中近似后验分布q(z∣x) 和标准正态分布 p(z) 之间的KL散度。KL散度公式的具体计算步骤如下:

1. mu ** 2

计算均值的平方项: μ2 这个项是为了衡量均值偏离零的程度。

2. torch.exp(log_var)

对数方差取指数,以获得实际的方差: exp⁡(log⁡(σ2))=σ2 这个项衡量方差的大小。

3. - log_var

减去对数方差: −log⁡(σ2) 这个项衡量分布的扩展程度。

4. - 1

减去 1,是KL散度公式中的常数项,用于归一化。

将这些项加在一起:

μ2+exp⁡(log⁡(σ2))−log⁡(σ2)−1

5. torch.sum(..., 1)

对所有维度求和,计算单个样本的KL散度: ∑(μ2+σ2−log⁡(σ2)−1) 这一步是将每个样本的所有维度的KL散度加起来。

6. 0.5 * ...

乘以 0.5,因KL散度公式中有系数 0.5: 0.5×∑(μ2+σ2−log⁡(σ2)−1)

7. torch.mean(..., 0)

对所有样本取平均,得到最终的KL散度损失: mean(0.5×∑(μ2+σ2−log⁡(σ2)−1))

整个公式的作用是计算出近似后验分布 q(z∣x) 和标准正态分布 p(z) 之间的KL散度,该散度表示了两个分布之间的差异。这种损失通常用于变分自编码器(VAE)训练中,确保生成的潜在变量分布接近标准正态分布。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/36087.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一步步带你解锁Stable Diffusion:老外都眼馋的 SD 中文提示词插件分享

大家好我是极客菌!今天我们继续来分享一个外国人都眼馋的 SD 中文提示词插件。 那我们废话不多说,直接开整。 SD 的插件安装,小伙伴们应该都会了吧,我这里再简单讲下哦,到「扩展」中的「可下载」中点击「加载扩展列表…

国标GB/T 28181详解:国标GBT28181-2022的目录通知流程

目录 一、定义 二、作用 1、实时同步设备目录状态 2、优化资源管理和调度 3、增强系统的可扩展性和灵活性 4、提高系统的可靠性和稳定性 5、支持多级级联和分布式部署 6、便于用户管理和监控 三、基本要求 1、目录通知满足以下基本要求 2、关键要素 (1…

python-立方和不等式

[题目描述] 试求满足下述立方和不等式的 m 的整数解。 1^32^3...m^3≤n。本题算法如下: 对指定的 n,设置求和循环,从 i1 开始,i 递增1取值,把 i^3 (或 i∗i∗i)累加到 s,直至 s>n,脱离循环作…

docker配置redis主从复制

下载redis,复制redis.conf 主节点(6379) 修改redis.conf # bind 127.0.0.1 # 注释掉这里 protected-mode no # 改为no port 6379从节点(6380) 修改redis.conf bind 127.0.0.1 protected-mode no # 改为no port 6380 replicaof 172.17.0.2 6379 # 这里的ip为主节点容器的i…

Oracle数据库使用指南基本概念

学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……) 2、学会Oracle数据库入门到入土用法(创作中……) 3、手把手教你开发炫酷的vbs脚本制作(完善中……) 4、牛逼哄哄的 IDEA编程利器技巧(编写中……) 5、面经吐血整理的 面试技…

介绍几种 MySQL 官方高可用方案

前言: MySQL 官方提供了多种高可用部署方案,从最基础的主从复制到组复制再到 InnoDB Cluster 等等。本篇文章以 MySQL 8.0 版本为准,介绍下不同高可用方案架构原理及使用场景。 1.MySQL Replication MySQL Replication 是官方提供的主从同…

HarmonyOS--数据持久化

用户首选项 场景介绍 1、用户首选项为应用提供Key-Value键值型的数据处理能力,支持应用持久化轻量级数据,并对其修改和查询。当用户希望有一个全局唯一存储的地方,可以采用用户首选项来进行存储。 2、Preferences会将该数据缓存在内存中&a…

模型情景制作-制作一棵树

情景模型中,最常用到的也是最能提升情景中生气的就是树。然而,自然的生长和环境的影响使得树的制作变成了考验制作者观察力的一道考题。制作一棵逼真的树,我们可以参考下面的这种方法。 铁丝制树 您需要准备9—12根铁丝,每根的长…

SuperCopy解决文档不能复制问题

有一些文档,我们要使用时,总是面临收费的情况,让我们不能复制,让人头疼不已!!! SuperCopy就可以解决这个问题。 获取SuperCopy步骤 1. 打开浏览器,点击右上角的三个点 2. 找到扩…

老板电器 45 年的烹饪经验,浓缩在这款烹饪大模型中

在科技不断进步的时代,人工智能(AI)迅速成为推动各行各业发展的重要力量。家电行业也不例外,根据 Gartner 的报告预测,到 2024 年,AI 家电市场的规模将达到万亿美元级别。这一预估凸显了智能化在家电行业中…

网络安全 DVWA通关指南 Cross Site Request Forgery (CSRF)

DVWA Cross Site Request Forgery (CSRF) 文章目录 DVWA Cross Site Request Forgery (CSRF)DVWA Low 级别 CSRFDVWA Medium 级别 CSRFDVWA High 级别 CSRFDVWA Impossible 级别 CSRF CSRF是跨站请求伪造攻击,由客户端发起,是由于没有在执行关键操作时&a…

【黑龙江哪些行业需要做等保?】

黑龙江等保测评是衡量企业网络安全水平的一项主要指标,包括:金融,能源,电信,医疗,教育,交通,制造,电商等。 等保测评是黑龙江省信息化建设的重要组成部分,也…

旅游管理系统源码小程序

便捷旅行,尽在掌握 旅游管理系统是一款基于FastAdminElementUNIAPP开发的多端(微信小程序、公众号、H5)旅游管理系统,拥有丰富的装修组件、多端分享、模板消息、电子合同、旅游攻略、旅游线路及相关保险预订等功能,提…

1961 Springboot自习室预约系统idea开发mysql数据库web结构java编程计算机网页源码maven项目

一、源码特点 springboot 自习室预约管理系统是一套完善的信息系统,结合springboot框架和bootstrap完成本系统,对理解JSP java编程开发语言有帮助系统采用springboot框架(MVC模式开发),系统具有完整的源代码和数据库…

大型企业组网如何规划网络

大型企业组网是一个复杂的过程,它需要细致的规划和设计,以确保网络能够满足企业的业务需求,同时保证性能、安全性和可扩展性。以下是规划大型企业网络的一些关键步骤和考虑因素: 1. 需求分析 业务需求:与各个业务部门…

EWM学习之旅-1-EWM100

系统学习一个业务模块已经变得越来越重要,开始吧,EWM! EWM的Learning Journey中包括7本 ebook,100/110/115/120/125/130/140,一本一本的啃吧,相信很多内容是重复的。 EWM100很适合初学者,了解概念术语&…

Lesson 40 What are you going to do?

Lesson 40 What are you going to do? 词汇 show v. 展示,展出 搭配:show人东西    show东西to人 口语:Show me your hands! 拿出来! n. 秀,表演 搭配:talk show 脱口秀    show room 展厅 take v…

精益生产推进时如何营造持续变革的氛围?

在快速变化的市场环境中,企业如何保持竞争力?精益生产无疑为众多企业提供了一个强大的战略工具。但是,单纯的引入精益生产理念和方法并不能保证企业的持续成功。关键在于如何营造一种持续变革的氛围,让精益生产成为推动企业不断前…

智慧校园-宿舍管理系统总体概述

在教育信息化的不断推动下,智慧校园宿舍管理系统脱颖而出,它以一种全新的视角和方式,重塑了高校宿舍管理的传统模式。该系统深度融合了云计算、物联网、大数据等先进科技,旨在为学生提供一个既安全可靠又充满便捷与温馨的居住体验…

Node.js全栈指南:认识MIME和HTTP

MIME,全称 “多用途互联网邮件扩展类型”。 这名称相当学术,用人话来说就是: 我们浏览一个网页的时候,之所以能看到 html 文件展示成网页,图片可以正常显示,css 样式能正常影响网页效果,js 脚…