AIGC笔记--U-ViT的简单代码实现

1--前言

原论文:All are Worth Words: A ViT Backbone for Diffusion Models

完整可debug的代码:

2--结构

3--简单代码

        以视频作为输入,实现上图红色框的计算:

import torch
import torch.nn as nn
from einops import rearrangeclass PatchEmbedding(nn.Module):def __init__(self, in_channels = 3, patch_size = 16, emb_dim = 768):super().__init__()self.patch_size = patch_size# 由于是视频作为输入,因此使用Conv3dself.proj = nn.Conv3d(in_channels = in_channels, out_channels = emb_dim, kernel_size = (1, patch_size, patch_size), stride = (1, patch_size, patch_size))def forward(self, x):# x.shape: [batch_size, channels, frames, height, width]x = self.proj(x) # (batch_size, emd_dim, frames, height // patch_size, width // patch_size)x = rearrange(x, 'b e t h w -> b (t h w) e') # flatten into patchesreturn xclass TransformerEncoder(nn.Module):def __init__(self, emb_dim = 768, num_heads = 8, ff_dim = 1024, dropout = 0.1, skip = False):super().__init__()self.layernorm1 = nn.LayerNorm(emb_dim)self.self_attn = nn.MultiheadAttention(embed_dim = emb_dim, num_heads = num_heads, dropout = dropout)self.dropout1 = nn.Dropout(dropout)self.layernorm2 = nn.LayerNorm(emb_dim)self.ffn = nn.Sequential(nn.Linear(emb_dim, ff_dim),nn.ReLU(),nn.Linear(ff_dim, emb_dim))self.dropout2 = nn.Dropout(dropout)if skip:self.skip_linear = nn.Linear(2 * emb_dim, emb_dim)def forward(self, x, skip = None):# x.shape: [b, thw, c]if(skip != None): # long skip connectionx = self.skip_linear(torch.cat([x, skip], dim=-1)) # [b, thw, 2c] -> [b, thw, c]x2 = self.layernorm1(x) # normx2, _ = self.self_attn(x2, x2, x2) # attentionx = x + self.dropout1(x2) # skipx2 = self.layernorm2(x) # normx2 = self.ffn(x2) # mlpx = x + self.dropout2(x2) # skipreturn xclass UViT(nn.Module):def __init__(self, in_channels = 3, patch_size = 16, emb_dim = 768, depth = 5, num_heads = 8, ff_dim = 1024, dropout = 0.1):super().__init__()self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_dim)self.in_blocks = nn.ModuleList([TransformerEncoder(emb_dim = emb_dim, num_heads = num_heads, ff_dim = ff_dim, dropout = dropout) for _ in range(depth // 2)])self.mid_block = TransformerEncoder(emb_dim = emb_dim, num_heads = num_heads, ff_dim = ff_dim, dropout = dropout)self.out_blocks = nn.ModuleList([TransformerEncoder(emb_dim = emb_dim, num_heads = num_heads, ff_dim = ff_dim, dropout = dropout, skip = True) for _ in range(depth // 2)])self.layernorm = nn.LayerNorm(emb_dim)def forward(self, x): # [b, c, t, h, w]x = self.patch_embed(x) # [b, t * (h//patch) * (w//patch), c']skips = [] # 存储in_blocks的输出, 作为long skip, 本质上是一个栈(后进先出)for blk in self.in_blocks:x = blk(x)skips.append(x)x = self.mid_block(x)for blk in self.out_blocks:x = blk(x, skips.pop()) # 栈的pop,作为long skipx = self.layernorm(x)return xif __name__ == "__main__":batch_size = 2in_channels = 3frames = 8height = 128width = 128video_tensor = torch.randn(batch_size, in_channels, frames, height, width).cuda() # 这里以视频作为输入,原论文是以图片作为输入# depth 必须是奇数,因为Long skip connection的存在model = UViT(in_channels = in_channels, patch_size = 16, emb_dim = 768, depth = 5, num_heads = 8, ff_dim = 1024, dropout = 0.1).cuda()output = model(video_tensor)print("output.shape: ", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/34610.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux-笔记 OverlayFS文件系统入门

目录 前言 主要概念 工作原理 特点特性 1、上下合并 2、同名文件覆盖 3、同名目录合并 4、写时拷贝 实操入门 内核配置 挂载文件系统 验证 1、同名文件覆盖 2、同名目录合并 3、写时拷贝 1)验证新增文件或目录 2)验证修改文件 3&…

昇思25天学习打卡营第3天|张量Tensor

张量Tensor 概念创建张量(4种方式)张量的属性张量索引张量运算Tensor与NumPy转换 概念 张量(Tensor)是一种特殊的数据结构,与数组和矩阵非常相似。张量是MindSpore网络运算中的基本数据结构。 创建张量(4…

MySQL 7种Join的定义图解示范结果(所有join类型)

文章目录 MySQL 7种Join的定义&图解&示范&结果(所有join类型)基本知识笛卡尔积 建表&填充数据1-Join不带条件account筛选 1-Inner Join 内连接不带条件account相同where筛选玩点特殊的 2-Left Join 左连接不带条件account筛选 3-Right J…

安全技术和防火墙(iptables)

安全技术 入侵检测系统:特点是不阻断网络访问,主要是提供报警和事后监督,不主动介入,类似于监控。 入侵防御系统:透明模式工作,对数据包,网络监控,服务攻击,木马&#…

HTTP协议中的各种请求头、请求类型的作用以及用途

目录 一、http协议介绍二、http协议的请求头三、http协议的请求类型四、http协议中的各种请求头、请求类型的作用以及用途 一、http协议介绍 HTTP(HyperText Transfer Protocol,超文本传输协议)是一种用于分布式、协作式和超媒体信息系统的应…

python flask 入门-helloworld

学习视频链接: 01-【前奏】课程介绍_哔哩哔哩_bilibili 1.安装flask pip install flask 踩坑记:本机不要连代理,否则无法install 提示报错valueError: check_hostname requires server_hostname 2.程序编写 在根目录下创建 app.py fr…

React实现二级评论

1. 什么是二级评论 图片来源–blackfrog的掘金文章 口语化的讲当我发布一个评论的时候就是一级评论,当我回复我发布的评论的时候就是二级评论并且将所有回复二级评论的评论也归于二级评论。 2. 二级评论功能的实现逻辑 在这里后端设计了四个接口分别是 获取所有…

jdk1.8升级到jdk11遇到的各种问题

一、第三方依赖使用了BASE64Decoder 如果项目中使用了这个类 sun.misc.BASE64Decoder,就会导致错误,因为再jdk11中,该类已经被删除。 Caused by: java.lang.NoClassDefFoundError: sun/misc/BASE64Encoder 当然这个类也有替换方式&#xf…

第 27 篇 : 搭建maven私服nexus

官网文档 1. 下载应该很慢, 最好是能翻墙 nexus-3.69.0-02-java8-unix.tar.gz 2. 上传到/usr/local/src, 解压及重命名 tar -zxvf nexus-3.69.0-02-java8-unix.tar.gz rm -rf nexus-3.69.0-02-java8-unix.tar.gz mv nexus-3.69.0-02 nexus ls3. 修改配置 cd /usr/local/sr…

作 业 二

cs与msf权限传递 1、进入cs界面,首先来到 Cobalt Strike 目录下,启动 Cobalt Strike 服务端 2、用客户端进 3、建立监听 4、生成脚本文件 5、开启服务,让win_2012 下载木马文件并运行 6、显示已经获取到了win的权限 转到Metasploit Framework 7、进去m…

智慧仓储的秘密武器:数据可视化的应用

智慧仓储中数据可视化是如何应用的?在现代物流和供应链管理中,智慧仓储已成为企业提升效率、降低成本和优化运营的重要手段。而数据可视化作为智慧仓储的重要工具,通过将复杂的数据转化为直观、易理解的图表和图形,极大地提升了仓…

MySQL实训--原神数据库

原神数据库 er图DDL/DML语句查询语句存储过程/触发器 er图 DDL/DML语句 SET NAMES utf8mb4; SET FOREIGN_KEY_CHECKS 0;DROP TABLE IF EXISTS artifacts; CREATE TABLE artifacts (id int NOT NULL AUTO_INCREMENT,artifacts_name varchar(255) CHARACTER SET utf8 COLLATE …

玩机进阶教程----MTK芯片使用Maui META修复基带 改写参数详细教程步骤解析

目前mtk芯片与高通芯片在主流机型 上使用比较普遍。但有时候版本更新或者误檫除分区等等原因会导致手机基带和串码丢失的故障。mtk芯片区别与高通。在早期mtk芯片中可以使用工具SN_Writer_Tool读写参数。但一些新版本机型兼容性不太好。今天使用另外一款工具来演示mtk芯片改写参…

Cesium 基本概念:创建实体和相机控制

基本概念 Entity // 创建一个实体 const entity_1 viewer.entities.add({position: new Cesium.Cartesian3(0, 0, 10000000),point: {pixelSize: 10,color: Cesium.Color.BLUE} });// 通过经纬度创建实体 const position Cesium.Cartesian3.fromDegrees(180.0, 0.0); // 创…

MySQL——自连接及联表查询练习

自连接 自己的表和自己的表连接,核心:一张表拆为两张一样的表即可。 父类: categoryidcategoryName2信息技术3软件开发5美术设计 子类: pidcategoryidcategoryName34数据库28办公信息36web开发57ps技术 子类的pid 父类的cate…

计算机缺失d3dx9_43.dll的多种解决方法,哪种更推荐使用

我在使用计算机时遇到了一个问题,系统提示我丢失了d3dx9_43.dll文件。丢失d3dx9_43.dll文件通常是由于DirectX组件未正确安装或损坏所致,这直接影响到依赖于DirectX的游戏和应用的运行。经过一番搜索和尝试,我找到了多种修复这个问题的方法&a…

2024最新SCI期刊影响因子发布(JCR2023)(含Top100榜单)

Clarivate Analytics(科睿唯安)2024年度《期刊引证报告》(Journal Citation Reports,简称JCR)发布了SCI期刊2023年影响因子(IF)。该指数备受访问学者、联培博士及博士后研究者关注。今天知识人网小编就简要介绍最新SCI…

【STM32 RTC实时时钟如何配置!超详细的解析和超简单的配置,附上寄存器操作】

STM32 里面RTC模块和时钟配置系统(RCC_BDCR寄存器)处于后备区域,即在系统复位或从待机模式唤醒后,RTC的设置和时间维持不变。因为系统对后备寄存器和RTC相关寄存器有写保护,所以如果想要对后备寄存器和RTC进行访问,则需要通过操作…

一文详解:什么是企业邮箱?最全百科

什么是企业邮箱?企业邮箱即绑定企业自有域名作为邮箱后缀的邮箱,是企业用于内部成员沟通和客户沟通的邮箱系统。 一、企业邮箱概念拆解 1.什么是企业邮箱? 企业邮箱即使用企业域名作为后缀的邮箱系统。它不仅提供专业的电子邮件收发功能&a…

【学习】使用PyTorch训练与评估自己的ResNet网络教程

参考:保姆级使用PyTorch训练与评估自己的ResNet网络教程_训练自己的图像分类网络resnet101 pytorch-CSDN博客 项目地址:GitHub - Fafa-DL/Awesome-Backbones: Integrate deep learning models for image classification | Backbone learning/comparison…