AIGC笔记--CVAE模型的搭建

目录

1--CVAE模型

2--代码实例


1--CVAE模型

简单介绍:

        与VAE类似,只不过模型的输入需要考虑图片和条件(condition)的融合,融合结果通过一个 encoder 映射到标准分布(均值和方差),从映射的标准分布中随机采样一个样本,样本也需要和条件进行融合,最后通过 decoder 重构图片;

        由于模型的输入是图片和条件的融合,因此模型学习了基于条件的图片生成;

        计算源图片和重构图片之间的损失,具体损失函数的推导可以参考:变分自编码器(VAE)

2--代码实例

        下面的 CVAE 中,用了最简单的融合方式(concat)将条件 Y 与输入 X 融合形成X_given_Y,同理条件 Y 与 X_given_Y 融合形成 z_given_Y;

import torch
import torch.nn as nnclass VAE(nn.Module):def __init__(self, in_features, latent_size, y_size=0):super(VAE, self).__init__()self.latent_size = latent_sizeself.encoder_forward = nn.Sequential( # encodernn.Linear(in_features + y_size, in_features),nn.LeakyReLU(),nn.Linear(in_features, in_features),nn.LeakyReLU(),nn.Linear(in_features, self.latent_size * 2))self.decoder_forward = nn.Sequential( # decodernn.Linear(self.latent_size + y_size, in_features),nn.LeakyReLU(),nn.Linear(in_features, in_features),nn.LeakyReLU(),nn.Linear(in_features, in_features),nn.Sigmoid())def encoder(self, X): # encodeout = self.encoder_forward(X) # 这里通过一个encoder生成均值和标准差mu = out[:, :self.latent_size] # 输出的前半部分作为均值log_var = out[:, self.latent_size:] # 后半部分作为标准差return mu, log_vardef decoder(self, z): # decodemu_prime = self.decoder_forward(z)return mu_primedef reparameterization(self, mu, log_var): # reparameterizationepsilon = torch.randn_like(log_var)z = mu + epsilon * torch.sqrt(log_var.exp())return zdef loss(self, X, mu_prime, mu, log_var): # cal lossreconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))return reconstruction_loss + latent_lossdef forward(self, X, *args, **kwargs):mu, log_var = self.encoder(X) # encodez = self.reparameterization(mu, log_var) # generate z by reparameterizationmu_prime = self.decoder(z) # decodereturn mu_prime, mu, log_varclass CVAE(VAE):def __init__(self, in_features, latent_size, y_size):super(CVAE, self).__init__(in_features, latent_size, y_size)def forward(self, X, y = None, *args, **kwargs):y = y.to(next(self.parameters()).device)X_given_Y = torch.cat((X, y.unsqueeze(1)), dim = 1)mu, log_var = self.encoder(X_given_Y)z = self.reparameterization(mu, log_var)z_given_Y = torch.cat((z, y.unsqueeze(1)), dim = 1)mu_prime_given_Y = self.decoder(z_given_Y)return mu_prime_given_Y, mu, log_var

简单的损失计算代码:

def loss(self, X, mu_prime, mu, log_var): # cal lossreconstruction_loss = torch.mean(torch.square(X - mu_prime).sum(dim=1))latent_loss = torch.mean(0.5 * (log_var.exp() + torch.square(mu) - log_var).sum(dim=1))return reconstruction_loss + latent_loss

完整代码参考:liujf69/VAE

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625934.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于ssm百货中心供应链管理系统+jsp论文

摘 要 社会发展日新月异,用计算机应用实现数据管理功能已经算是很完善的了,但是随着移动互联网的到来,处理信息不再受制于地理位置的限制,处理信息及时高效,备受人们的喜爱。本次开发一套百货中心供应链管理系统有管理…

transfomer中Decoder和Encoder的base_layer的源码实现

简介 Encoder和Decoder共同组成transfomer,分别对应图中左右浅绿色框内的部分. Encoder: 目的:将输入的特征图转换为一系列自注意力的输出。 工作原理:首先,通过卷积神经网络(CNN)提取输入图像的特征。然…

构建未来教育:在线培训系统开发的技术探讨

随着远程学习的崛起和数字化教育的普及,在线培训系统的开发成为了现代教育的核心。本文将深入讨论在线培训系统的关键技术要点,涵盖前后端开发、数据库管理、以及安全性和身份验证等关键方面。 前端开发:提供交互性与用户友好体验 在构建在…

02 SpringMVC接收数据之访问路径设置+四种接参方式+@EnableWebMvc

1.1 访问路径设置 RequestMapping注解的作用就是将请求的 URL 地址和处理请求的方式(handler方法)关联起来,建立映射关系。 SpringMVC 接收到指定的请求,就会来找到在映射关系中对应的方法来处理这个请求。 1.1.1 精准路径匹配…

京东ES支持ZSTD压缩算法上线了:高性能,低成本 | 京东云技术团队

1 前言 在《ElasticSearch降本增效常见的方法》一文中曾提到过zstd压缩算法[1],一步一个脚印我们终于在京东ES上线支持了zstd;我觉得促使目标完成主要以下几点原因: Elastic官方原因:zstd压缩算法没有在Elastic官方的开发计划中&…

【Leetcode Sheet】Weekly Practice 24

Leetcode Test 447 回旋镖的数量(1.8) 给定平面上 n 对 互不相同 的点 points ,其中 points[i] [xi, yi] 。回旋镖 是由点 (i, j, k) 表示的元组 ,其中 i 和 j 之间的距离和 i 和 k 之间的欧式距离相等(需要考虑元组的顺序)。 …

最新智能AI系统ChatGPT网站程序源码+详细图文搭建部署教程,Midjourney绘画,GPT语音对话+ChatFile文档对话总结+DALL-E3文生图

一、前言 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Ch…

如何增加服务器的高并发

随着互联网的快速发展和普及,越来越多的应用程序需要支持高并发的请求处理。在这种情况下增加服务器的高并发能力成为了一个热门的话题。下面简单的介绍如果提高服务器的高并发能力。 负载均衡 是把请求分发到多个服务器上,来实现请求的平衡和分担。负…

使用JavaScript实现实时在线协作编辑器:从设计到实现

一、引言 随着Web技术的发展,实现在线协作编辑文档已经成为一种常见的需求。通过在线协作,多位用户可以同时编辑同一个文档,并实时看到其他用户的更改。这样的功能需要复杂的技术实现,包括数据同步、冲突解决和实时通信。本篇博客…

(一)环境部署

Python虚拟环境 安装virtualenv pip install virtualenv 创建环境 virtualenv -p D:\python\python.exe(python解释器目录) env-py3.6(虚拟环境目录,名称随意) 在当前目录下生成env-py3.6目录。 激活环境 ...\env-py3.6\Scripts> .\activate 关闭&#xf…

应用架构演变过程、rpc及Dubbo简介

一、应用架构演变历史: 单一应用架构 -> 垂直应用架构 -> 分布式服务架构 -> 微服务架构。 单一应用架构 当网站流量很小时,只需一个应用,将所有功能都部署在一起,以减少部署节点和成本。 此时,用于简化增删…

STM32 CubeIDE 使用 CMSIS-DAP烧录 (方法2--外部小工具)

前言: 本篇所用方法,需要借助一个外部的工具小软件。 优点:烧录更稳定; 缺点:不能在线仿真调试。 下面链接,是另一种方法:修改CubeIDE调试文件。能在CubeIDE直接烧录、仿真,但不稳定。…

Bazel

简介: Bazel 是 google 研发的一款开源构建和测试工具,也是一种简单、易读的构建工具。 Bazel 支持多种编程语言的项目,并针对多个平台构建输出。 高级构建语言:Bazel 使用一种抽象的、人类可读的语言在高语义级别上描述项目的构建属性。与其…

uniapp 简易自定义日历

1、组件代码 gy-calendar-self.vue <template><view class"calendar"><view class"selsct-date">请选择预约日期</view><!-- 日历头部&#xff0c;显示星期 --><view class"weekdays"><view v-for"…

Linux常用命令大全(三)

系统权限 用户组 1. 创建组groupadd 组名 2. 删除组groupdel 组名 3. 查找系统中的组cat /etc/group | grep -n “组名”说明&#xff1a;系统每个组信息都会被存放在/etc/group的文件中1. 创建用户useradd -g 组名 用户名 2. 设置密码passwd 用户名 3. 查找系统账户说明&am…

蓝桥杯java基础

2. AB问题II 时间限制&#xff1a;1.000S 空间限制&#xff1a;32MB 题目描述 计算ab&#xff0c;但输入方式有所改变。 输入描述 第一行是一个整数N&#xff0c;表示后面会有N行a和b&#xff0c;通过空格隔开。 输出描述 对于输入的每对a和b&#xff0c;你需要在相应的…

openssl快速生成自签名证书

系统&#xff1a;Centos 7.6 确保已安装openssl openssl version生成私钥文件 private.key &#xff08;文件名自定义&#xff09; openssl genpkey -algorithm RSA -out private.key -pkeyopt rsa_keygen_bits:2048-out private.key&#xff1a;生成的私钥文件-algorithm RS…

探索设计模式的魅力:工厂方法模式

工厂方法模式是一种创建型设计模式&#xff0c;它提供了一种创建对象的接口&#xff0c;但将具体实例化对象的工作推迟到子类中完成。这样做的目的是创建对象时不用依赖于具体的类&#xff0c;而是依赖于抽象&#xff0c;这提高了系统的灵活性和可扩展性。 以下是工厂方法模式的…

MySQL 8.0中移除的功能(二)

PROCEDURE ANALYSE()​ 语法已被移除。客户端的 ​--ssl​ 和 ​--ssl-verify-server-cert​ 选项已被移除。使用 ​--ssl-modeREQUIRED​ 代替 ​--ssl1​ 或 ​--enable-ssl​。使用 ​--ssl-modeDISABLED​ 代替 ​--ssl0​、​--skip-ssl​ 或 ​--disable-ssl​。使用 ​-…

chatgpt的基本技术及其原理

ChatGPT是一种基于生成式预训练的语言模型&#xff0c;它的基本技术包括预训练和微调。下面我将为你解释这些技术及其原理。 1. 预训练&#xff08;Pre-training&#xff09;: ChatGPT的预训练阶段是在大规模的文本数据上进行的。模型通过对大量的互联网文本进行自监督学习来学…