[PyTorch][chapter 54][Variational Auto-Encoder 实战]

前言:

   
 

这里主要实现: Variational Autoencoders (VAEs) 变分自动编码器
其训练效果如下

 

训练的过程中要注意调节forward 中的kle ,调参。

整个工程两个文件:

    vae.py

   main.py

目录:

  1.      vae
  2.       main

一  vae

  文件名: vae.py

   作用:   Variational Autoencoders (VAE)

 训练的过程中加入一些限制,使它的latent space规则一点呢。于是就引入了variational autoencoder(VAE),它被定义为一个有规律地训练以避免过度拟合的Autoencoder,可以确保潜在空间具有良好的属性从而实现内容的生成。
variational autoencoder的架构和Autoencoder差不多,区别在于不再是把输入当作一个点,而是把输入当成一个分布。

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:19:19 2023@author: chengxf2
"""import torch
from torch import nn#ae: AutoEncoderclass VAE(nn.Module):def __init__(self,hidden_size=20):super(VAE, self).__init__()self.encoder = nn.Sequential(nn.Linear(in_features=784, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=hidden_size),nn.ReLU())# hidden [batch_size, 10]h_dim = int(hidden_size/2)self.hDim = h_dimself.decoder = nn.Sequential(nn.Linear(in_features=h_dim, out_features=64),nn.ReLU(),nn.Linear(in_features=64, out_features=128),nn.ReLU(),nn.Linear(in_features=128, out_features=256),nn.ReLU(),nn.Linear(in_features=256, out_features=784),nn.Sigmoid())def forward(self, x):'''param x:[batch, 1,28,28]return '''batchSz= x.size(0)#flattenx = x.view(batchSz, 784)#encoderh= self.encoder(x)#在给定维度上对所给张量进行分块,前一半的神经元看作u, 后一般的神经元看作sigmau, sigma = h.chunk(2,dim=1)#Reparameterize trick:#randn_like:产生一个正太分布 ~ N(0,1)#h.shape [batchSize,self.hDim]h = u+sigma* torch.randn_like(sigma)#kld :1e-8 防止sigma 平方为0kld = 0.5*torch.sum(torch.pow(u,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1)#MSE loss 是平均loss, 所以kld 也要算一个平均值kld = kld/(batchSz*32*32)xHat =   self.decoder(h)#reshapexHat = xHat.view(batchSz,1,28,28)return xHat,kld

二 main

文件名: main.py

作用: 训练,测试数据集

 

# -*- coding: utf-8 -*-
"""
Created on Wed Aug 30 14:24:10 2023@author: chengxf2
"""import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import time
from torch import optim,nn
from vae import VAE
import visdomdef main():batchNum = 32lr = 1e-3epochs = 20device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")torch.manual_seed(1234)viz = visdom.Visdom()viz.line([0],[-1],win='train_loss',opts =dict(title='train acc'))tf= transforms.Compose([ transforms.ToTensor()])mnist_train = datasets.MNIST('mnist',True,transform= tf,download=True)train_data = DataLoader(mnist_train, batch_size=batchNum, shuffle=True)mnist_test = datasets.MNIST('mnist',False,transform= tf,download=True)test_data = DataLoader(mnist_test, batch_size=batchNum, shuffle=True)global_step =0model =VAE().to(device)criteon = nn.MSELoss().to(device) #损失函数optimizer = optim.Adam(model.parameters(),lr=lr) #梯度更新规则print("\n ----main-----")for epoch in range(epochs):start = time.perf_counter()for step ,(x,y) in enumerate(train_data):#[b,1,28,28]x = x.to(device)x_hat,kld = model(x)loss = criteon(x_hat, x)if kld is not None:elbo = -loss -1.0*kldloss = -elbo#backpropoptimizer.zero_grad()loss.backward()optimizer.step()viz.line(Y=[loss.item()],X=[global_step],win='train_loss',update='append')global_step +=1end = time.perf_counter()    interval = int(end - start)print("epoch: %d"%epoch, "\t 训练时间 %d"%interval, '\t 总loss: %4.7f'%loss.item(),"\t KL divergence: %4.7f"%kld.item())x,target = iter(test_data).next()x = x.to(device)with torch.no_grad():x_hat,kld = model(x)tip = 'hat'+str(epoch)viz.images(x,nrow=8, win='x',opts=dict(title='x'))viz.images(x_hat,nrow=8, win='x_hat',opts=dict(title=tip))if __name__ == '__main__':main()

 参考:

 课时118 变分Auto-Encoder实战-2_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/61522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

typora使用

1.主题配置 先打开主题文件夹, 文件–>>偏好设置–>>外观–>>打开主题文件夹 1.1字体 修改字体需要修改css文件,确定当前所用主题,可以在typora菜单点击主题,看看当前勾选的是哪个主题,比如gith…

iPhone 15 Pro与iPhone 13 Pro:最大的预期升级

如果你在2021年首次发布iPhone 13 Pro时就抢到了它,那么你的合同很可能即将到期。虽然距离iPhone 15系列还有几周的时间,但你可能已经在想:是时候把你的旧iPhone升级为iPhone 15 Pro了吗? 我们认为iPhone 13 Pro是你现在能买到的最好的手机之一。但如果你想在2023年晚些时…

微信小程序 趣味学习与益智游戏系统APP

管理员、用户可通过HBuilder系统手机打开系统,注册登录后可进行管理员后端;首页、个人中心、用户管理、学生分类管理、学一学管理、玩一玩管理、听一听管理、试题管理、练一练管理、系统管理、考试管理,用户前端;首页、学一学、玩…

音视频入门基础理论知识

文章目录 前言一、视频1、视频的概念2、常见的视频格式3、视频帧4、帧率5、色彩空间6、采用 YUV 的优势7、RGB 和 YUV 的换算 二、音频1、音频的概念2、采样率和采样位数①、采样率②、采样位数 3、音频编码4、声道数5、码率6、音频格式 三、编码1、为什么要编码2、视频编码①、…

同一台电脑测.Net和Mono平台浮点运算的差异

float speed 0.1f;float distance 2.0f;long needTime (long)(distance / speed);Log.Debug($"needTime{needTime}"); 结果: .Net平台算出20 Mono平台算出19

【传输层】网络基础 -- UDP协议 | TCP协议

再谈端口号端口号范围划分netstatpidof UDPUDP的特点面向数据报UDP的缓冲区 基于UDP的应用层协议 TCP认识TCP协议的报头理解封装解包理解可靠性TCP工作模式16位窗口大小6位标志位URGACKPSHRSTSYNFIN 再谈端口号 端口号(Port)标识了一个主机上进行通信的不同的应用程序 在TCP/I…

力扣92. 局部反转链表

92. 反转链表 II 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5], left 2, right 4 输出&am…

计算机网络 | TCP 三次握手四次挥手 |半关闭连接

本来是不愿意写的&#xff0c;可是在实际场景&#xff0c;对具体的描述标志还是模糊不清&#xff0c;基础不扎实&#xff0c;就得承认&#xff01;&#xff01;&#xff01; TCP 连接建立需要解决三大问题&#xff1a; 知道双方存在约定一些参数&#xff0c;如最大滑动窗口值、…

Ubuntu 22.04安装 —— Win11 22H2

目录 Ubuntu使用下载UbuntuVmware 安装图示安装步骤图示 Ubuntu使用 系统环境&#xff1a; Windows 11 22H2Vmware 17 ProUbutun 22.04.3 Server Ubuntu Server documentation | Ubuntu 下载 Ubuntu 官网下载 建议安装长期支持版本 ——> 可以选择桌面版或服务器版(仅包…

【Leetcode】130.被围绕的区域

一、题目 1、题目描述 给你一个 m x n 的矩阵 board ,由若干字符 X 和 O ,找到所有被 X 围绕的区域,并将这些区域里所有的 O 用 X 填充。 示例1: 输入:board = [[“X”,“X”,“X”,“X”],[“X”,“O”,“O”,“X”],[“X”,“X”,“O”,“X”],[“X”,“O”,“X”,“…

高德地图jsapi报错INVALID_USER_SCODE

看了各种网上方法&#xff0c;还是搞不定。无奈在高德开放平台提了工单。 很快高德的技术人员就给出了答复“您好&#xff0c;您检查一下您的安全密钥是否在key之前&#xff0c;安全密钥设置必须是在JS API 脚本加载之前进行设置&#xff0c;否则设置无效。” 并给出了推荐的…

【数据结构】C语言队列(详解)

前言: &#x1f4a5;&#x1f388;个人主页:​​​​​​Dream_Chaser&#xff5e; &#x1f388;&#x1f4a5; ✨✨专栏:http://t.csdn.cn/oXkBa ⛳⛳本篇内容:c语言数据结构--C语言实现队列 目录 一.队列概念及结构 1.1队列的概念 1.2队列的结构 二.队列的实现 2.1头文…

【python爬虫】6.爬虫实操(带参数请求数据)

文章目录 前言项目&#xff1a;狂热粉丝分析过程什么是带参数请求数据如何带参数请求数据 代码实现被隐藏的歌曲清单什么是Request Headers如何添加Request Headers 复习 前言 先来复习一下上一关的主要知识吧&#xff0c;先热个身。 Network能够记录浏览器的所有请求。我们最…

AZ900备考

文章目录 云服务的概念云服务模型云服务类型消费的模型云服务的好处可靠性和可预测性的优势云中的管理 Azure 体系结构和服务核心结构组件物理基础结构组件 Azure计算和网络服务Azure 存储服务身份认证AD身份认证 Azure 管理和治理成本管理治理合规性的功能和工具管理和部署Azu…

喷泉码浅谈

01、喷泉码简介 喷泉码&#xff08;Fountain Code&#xff09;是一种在无线通信、数据传输和网络编码领域中使用的错误纠正技术。它与传统的纠错码和编码方法有所不同&#xff0c;喷泉码被设计用于在不确定信道条件下的高效数据传输。传统的纠错码&#xff08;如海明码、RS码等…

Linux服务器部署JavaWeb后端项目

适用于&#xff1a;MVVM前后台分离开发、部署、域名配置 前端&#xff1a;Vue 后端&#xff1a;Spring Boot 这篇文章只讲后端部署&#xff0c;前端部署戳这里 目录 Step1&#xff1a;服务器上搭建后端所需环境1、更新服务器软件包2、安装JDK83、安装MySQL4、登录MySQL5、修…

钉钉消息已读、未读咋实现的嘞?

前言 一款app&#xff0c;消息页面有&#xff1a;钱包通知、最近访客等各种通知类别&#xff0c;每个类别可能有新的通知消息&#xff0c;实现已读、未读功能&#xff0c;包括多少个未读&#xff0c;这个是怎么实现的呢&#xff1f;比如用户A访问了用户B的主页&#xff0c;难道…

Mac下Docker Desktop安装命令行工具、开启本地远程访问

Mac系统下&#xff0c;为了方便在terminal和idea里使用docker&#xff0c;需要安装docker命令行工具&#xff0c;和开启Docker Desktop本地远程访问。 具体方法是在设置-高级下&#xff0c; 1.将勾选的User调整为System&#xff0c;这样不用手动配置PATH即可使用docker命令 …

Nuxt 菜鸟入门学习笔记五:CSS 样式

文章目录 本地样式表在组件内导入通过 Nuxt 配置 CSS 属性导入使用字体导入通过 NPM 发布的样式表 外部样式表动态添加样式表【高级】使用 Nitro 插件修改渲染的头部 使用预处理器单文件组件 SFC 样式类和样式绑定使用 v-bind 的动态样式Scoped StylesCSS Modules预处理器支持 …

基于Python+OpenCV智能答题卡识别系统——深度学习和图像识别算法应用(含Python全部工程源码)+训练与测试数据集

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境PyCharm安装OpenCV环境 模块实现1. 信息识别2. Excel导出模块3. 图形用户界面模块4. 手写识别模块 系统测试1. 系统识别准确率2. 系统识别应用 工程源代码下载其它资料下载 前言 本项目基于Python和OpenCV图像处…