[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

   
 

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

 

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

    vae.py

   main.py

目录:

  1.      vae
  2.       main

一  vae

  文件名: vae.py

   作用:   Variational Autoencoders (VAE)

 训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat =   self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter()    interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()

 参考:

 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/61522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

typora使用

1.主题配置 先打开主题文件夹, 文件–>>偏好设置–>>外观–>>打开主题文件夹 1.1字体 修改字体需要修改css文件,确定当前所用主题,可以在typora菜单点击主题,看看当前勾选的是哪个主题,比如gith…

iPhone 15 Pro与iPhone 13 Pro:最大的预期升级

如果你在2021年首次发布iPhone 13 Pro时就抢到了它,那么你的合同很可能即将到期。虽然距离iPhone 15系列还有几周的时间,但你可能已经在想:是时候把你的旧iPhone升级为iPhone 15 Pro了吗? 我们认为iPhone 13 Pro是你现在能买到的最好的手机之一。但如果你想在2023年晚些时…

微信小程序 趣味学习与益智游戏系统APP

管理员、用户可通过HBuilder系统手机打开系统,注册登录后可进行管理员后端;首页、个人中心、用户管理、学生分类管理、学一学管理、玩一玩管理、听一听管理、试题管理、练一练管理、系统管理、考试管理,用户前端;首页、学一学、玩…

音视频入门基础理论知识

文章目录 前言一、视频1、视频的概念2、常见的视频格式3、视频帧4、帧率5、色彩空间6、采用 YUV 的优势7、RGB 和 YUV 的换算 二、音频1、音频的概念2、采样率和采样位数①、采样率②、采样位数 3、音频编码4、声道数5、码率6、音频格式 三、编码1、为什么要编码2、视频编码①、…

同一台电脑测.Net和Mono平台浮点运算的差异

float speed 0.1f;float distance 2.0f;long needTime (long)(distance / speed);Log.Debug($"needTime{needTime}"); 结果: .Net平台算出20 Mono平台算出19

【传输层】网络基础 -- UDP协议 | TCP协议

再谈端口号端口号范围划分netstatpidof UDPUDP的特点面向数据报UDP的缓冲区 基于UDP的应用层协议 TCP认识TCP协议的报头理解封装解包理解可靠性TCP工作模式16位窗口大小6位标志位URGACKPSHRSTSYNFIN 再谈端口号 端口号(Port)标识了一个主机上进行通信的不同的应用程序 在TCP/I…

力扣92. 局部反转链表

92. 反转链表 II 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], left 2, right 4 输出&am…

计算机网络 | TCP 三次握手四次挥手 |半关闭连接

本来是不愿意写的&#xff0c;可是在实际场景&#xff0c;对具体的描述标志还是模糊不清&#xff0c;基础不扎实&#xff0c;就得承认&#xff01;&#xff01;&#xff01; TCP 连接建立需要解决三大问题&#xff1a; 知道双方存在约定一些参数&#xff0c;如最大滑动窗口值、…

Kotlin,解决调用了函数但是函数体内没有执行的问题,什么时候使用invoke

fun main() {listOf(1,2,3).forEach{ foo(it)} } fun foo(a:Int) {print(a) }这段代码按照代码逻辑来说打印的是 123 但是没有这个打印 把foo函数转成java的代码如下 JvmStaticNotNullpublic static final Function0 foo(final int var0) {return (Function0)(new Function0 ()…

Ubuntu 22.04安装 —— Win11 22H2

目录 Ubuntu使用下载UbuntuVmware 安装图示安装步骤图示 Ubuntu使用 系统环境&#xff1a; Windows 11 22H2Vmware 17 ProUbutun 22.04.3 Server Ubuntu Server documentation | Ubuntu 下载 Ubuntu 官网下载 建议安装长期支持版本 ——> 可以选择桌面版或服务器版(仅包…

UI界面自动化BagePage

常用basepage模块代码 # -*- coding: utf-8 -*- # Desc: UI自动化测试的一些基础浏览器操作方法# 第三方库导入 import time from logging import config import randomimport allure from selenium.webdriver.common.alert import Alert from selenium.webdriver.remote.webe…

【Leetcode】130.被围绕的区域

一、题目 1、题目描述 给你一个 m x n 的矩阵 board ,由若干字符 X 和 O ,找到所有被 X 围绕的区域,并将这些区域里所有的 O 用 X 填充。 示例1: 输入:board = [[“X”,“X”,“X”,“X”],[“X”,“O”,“O”,“X”],[“X”,“X”,“O”,“X”],[“X”,“O”,“X”,“…

高德地图jsapi报错INVALID_USER_SCODE

看了各种网上方法&#xff0c;还是搞不定。无奈在高德开放平台提了工单。 很快高德的技术人员就给出了答复“您好&#xff0c;您检查一下您的安全密钥是否在key之前&#xff0c;安全密钥设置必须是在JS API 脚本加载之前进行设置&#xff0c;否则设置无效。” 并给出了推荐的…

【数据结构】C语言队列(详解)

前言: &#x1f4a5;&#x1f388;个人主页:​​​​​​Dream_Chaser&#xff5e; &#x1f388;&#x1f4a5; ✨✨专栏:http://t.csdn.cn/oXkBa ⛳⛳本篇内容:c语言数据结构--C语言实现队列 目录 一.队列概念及结构 1.1队列的概念 1.2队列的结构 二.队列的实现 2.1头文…

【linux命令讲解大全】035.文件删除命令:rm 和 rmdir 的用法详解

文章目录 rm补充说明语法选项参数实例 rmdir补充说明语法选项参数实例 从零学 python rm 用于删除给定的文件和目录 补充说明 rm命令可以删除一个目录中的一个或多个文件或目录&#xff0c;也可以将某个目录及其下属的所有文件及其子目录均删除掉。对于链接文件&#xff0c;…

【python爬虫】6.爬虫实操(带参数请求数据)

文章目录 前言项目&#xff1a;狂热粉丝分析过程什么是带参数请求数据如何带参数请求数据 代码实现被隐藏的歌曲清单什么是Request Headers如何添加Request Headers 复习 前言 先来复习一下上一关的主要知识吧&#xff0c;先热个身。 Network能够记录浏览器的所有请求。我们最…

React Navigation 使用导航

在 Web 浏览器中&#xff0c;您可以使用锚标记链接到不同的页面。当用户单击链接时&#xff0c;URL 会被推送到浏览器历史记录堆栈中。当用户按下后退按钮时&#xff0c;浏览器会从历史堆栈顶部弹出该项目&#xff0c;因此活动页面现在是以前访问过的页面。React Native 不像 W…

AZ900备考

文章目录 云服务的概念云服务模型云服务类型消费的模型云服务的好处可靠性和可预测性的优势云中的管理 Azure 体系结构和服务核心结构组件物理基础结构组件 Azure计算和网络服务Azure 存储服务身份认证AD身份认证 Azure 管理和治理成本管理治理合规性的功能和工具管理和部署Azu…

代码随想录算法训练营第17期第34天 | 1005. K 次取反后最大化的数组和、134. 加油站、135. 分发糖果

1005. K 次取反后最大化的数组和 这里说一下卡哥和我的区别&#xff0c;基本思路是一样的&#xff0c; 只是卡哥这里只需要一次排序&#xff0c;而我这边排了两次&#xff1b; 卡哥思路&#xff1a; 1.按照绝对值大小从大到小排序 2.从前往后遍历&#xff0c;遇到负数将其转…

喷泉码浅谈

01、喷泉码简介 喷泉码&#xff08;Fountain Code&#xff09;是一种在无线通信、数据传输和网络编码领域中使用的错误纠正技术。它与传统的纠错码和编码方法有所不同&#xff0c;喷泉码被设计用于在不确定信道条件下的高效数据传输。传统的纠错码&#xff08;如海明码、RS码等…