PolyGen: An Autoregressive Generative Model of 3D Meshes代码polygen_encoder.py解读

论文:PolyGen: An Autoregressive Generative Model of 3D Meshes

首先阅读transformer铺垫知识《Torch中Transformer的中文注释》。

以下为Encoder部分,很简单,小学生都会:

from typing import Optional
import pdbimport torch
import torch.nn as nn
from torch.nn import (MultiheadAttention,Linear,Dropout,LayerNorm,ReLU,Parameter,TransformerEncoderLayer,
)
import pytorch_lightning as plfrom .utils import embedding_to_paddingclass PolygenEncoderLayer(TransformerEncoderLayer):"""Polygen论文中描述的编码器模块"""def __init__(self,d_model: int = 256,nhead: int = 4,dim_feedforward: int = 1024,dropout: float = 0.2,re_zero: bool = True,) -> None:"""初始化PolygenEncoderLayer参数:d_model: 嵌入向量的大小,即模型的隐藏状态维度。nhead: 多头注意力机制的头数。dim_feedforward: 前馈网络(Feed Forward Network)的中间层维度。dropout: 每个连接层后ReLU激活函数后应用的dropout率,用于防止过拟合。re_zero: 如果为 True,使用零初始化对残差进行Alpha缩放,这是一种正则化技术,旨在改善模型的收敛速度和泛化能力。初始化:self_attn: 多头注意力层,使用 MultiheadAttention 实现。linear1 和 linear2: 两层线性变换,用于前馈网络的构建。dropout: 用于添加dropout的 Dropout 层。norm1 和 norm2: 用于层规范化(Layer Normalization)的 LayerNorm 层。activation: 激活函数,这里使用的是 ReLU。re_zero: 一个布尔变量,指示是否使用ReZero技术。alpha 和 beta: 如果使用ReZero技术,这两个参数用于缩放残差连接的输出,初始值为0。"""super(PolygenEncoderLayer, self).__init__(d_model, nhead, dim_feedforward=dim_feedforward, dropout=dropout)self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)self.linear1 = Linear(d_model, dim_feedforward)self.linear2 = Linear(dim_feedforward, d_model)self.dropout = Dropout(dropout)self.norm1 = LayerNorm(d_model)self.norm2 = LayerNorm(d_model)self.activation = ReLU()self.re_zero = re_zeroself.alpha = Parameter(data=torch.Tensor([0.0]))self.beta = Parameter(data=torch.Tensor([0.0]))def forward(self,src: torch.Tensor,src_mask: Optional[torch.Tensor] = None,src_key_padding_mask: Optional[torch.Tensor] = None,) -> torch.Tensor:"""PolygenEncoderLayer的前向传播方法参数:src: 形状为 [sequence_length, batch_size, embed_size] 的张量。传入TransformerEncoder的输入张量src_mask: 形状为 [sequence_length, sequence_length] 的张量。输入序列的掩码src_key_padding_mask: 形状为 [sequence_length, batch_size] 的张量。告诉注意力机制哪些输入序列的部分应该被忽略,因为它们是填充的返回:src: 形状为 [sequence_length, batch_size, embed_size] 的张量计算流程:自我注意力计算:首先,对输入张量 src 进行层规范化,然后将其作为查询、键和值传入多头注意力层 self_attn。注意力计算中,可以使用 src_mask 和 src_key_padding_mask 控制哪些位置的注意力应该被屏蔽。残差连接与Dropout:如果使用ReZero技术,将注意力层的输出乘以 alpha 参数。应用dropout,并将结果与输入张量 src 进行残差连接。前馈网络计算:再次对 src 进行层规范化,然后通过两层线性变换和激活函数。如果使用ReZero技术,将前馈网络的输出乘以 beta 参数。应用dropout,并将结果与 src 进行残差连接。"""src2 = self.norm1(src)src2 = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]# ReZero is All You Need: Fast Convergence at Large Depth https://arxiv.org/abs/2003.04887# The parameter α is initialized to 0 at the beginning of training,# and the output of the residual block is almost entirely dependent on its input,# thus avoiding the problem of the gradient disappearing or exploding.# As the training progresses, alpha gradually learns the optimal value,# allowing the internal representation of the residual block to gradually influence the final output.if self.re_zero:src2 = src2 * self.alphasrc2 = self.dropout(src2)src = src + src2src2 = self.norm2(src)src2 = self.linear1(src2)src2 = self.linear2(src2)if self.re_zero:src2 = src2 * self.betasrc2 = self.dropout(src2)src = src + src2return srcclass PolygenEncoder(pl.LightningModule):"""A modified version of the traditional Transformer Encoder suited for Polygen input sequences"""def __init__(self,hidden_size: int = 256,fc_size: int = 1024,num_heads: int = 4,layer_norm: bool = True,num_layers: int = 8,dropout_rate: float = 0.2,) -> None:"""Initializes the PolygenEncoderArgs:hidden_size: Size of the embedding vectors.fc_size: Size of the fully connected layer.num_heads: Number of multihead attention heads.layer_norm: Boolean variable that signifies if layer normalization should be used.num_layers: Number of decoder layers in the decoder.dropout_rate: Dropout rate applied immediately after the ReLU in each fully connected layer."""super(PolygenEncoder, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.encoder = nn.TransformerEncoder(PolygenEncoderLayer(d_model=hidden_size,nhead=num_heads,dim_feedforward=fc_size,dropout=dropout_rate,),num_layers=num_layers,)self.norm = LayerNorm(hidden_size)def forward(self, inputs: torch.Tensor) -> torch.Tensor:"""Forward method for the Transformer EncoderArgs:inputs: A Tensor of shape [sequence_length, batch_size, embed_size]. Represents the input sequence.Returns:outputs: A Tensor of shape [sequence_length, batch_size, embed_size]. Represents the result of the TransformerEncoder"""padding_mask = embedding_to_padding(inputs)out = self.encoder(inputs, src_key_padding_mask=padding_mask)return self.norm(out)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/40004.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mxd的地图文件 像百度地图那样在vue页面上展示出来

要在Vue页面上展示MXD地图文件,可以使用一些开源的JavaScript库来实现。以下是一种可能的方法: 1. 安装ArcGIS API for JavaScript:在Vue项目中使用ArcGIS API for JavaScript可以轻松地加载和展示地图。在命令行中运行以下命令来安装该库&a…

HexPlane: A Fast Representation for Dynamic Scenes(总结图)

图1。用于动态三维场景的 Hex刨面。我们没有从深度 MLP 中回归颜色和不透明度,而是通过 HexPlann 显式地计算时空点的特征。配对一个微小的 MLP,它允许以上100倍加速匹配的质量。 图2。方法概述。Hex刨包含六个特征平面,跨越每对坐标轴(例如…

PyTorch计算机视觉实战:目标检测、图像处理与深度学习

本书基于真实数据集,全面系统地阐述现代计算机视觉实用技术、方法和实践,涵盖50多个计算机视觉问题。全书分为四部分:一部分介绍神经网络和PyTorch的基础知识,以及如何使用PyTorch构建并训练神经网络,包括输入数据缩放…

【前端VUE】VUE3第一节—vite创建vue3工程

什么是VUE Vue (发音为 /vjuː/,类似 view) 是一款用于构建用户界面的 JavaScript 框架。它基于标准 HTML、CSS 和 JavaScript 构建,并提供了一套声明式的、组件化的编程模型,帮助你高效地开发用户界面。无论是简单还是复杂的界面&#xff0…

深入了解自动化:聊聊什么项目适合做自动化测试?

自动化测试 什么是自动化测 什么是自动化测试? 随着软件产业的不断发展,市场对软件周期的要求越来越高,于是催生了各种开发模式,如大家熟知的敏捷开发,从而对测试提出了更高的要求。此时,产生了自动化测试…

启航IT之旅:高考假期预习指南

标题:启航IT之旅:高考假期预习指南 随着高考的落幕,许多有志于IT领域的学子们即将踏上新的学习旅程。这个假期,是他们探索IT世界的黄金时期。本文将为准IT新生们提供一份全面的预习指南,帮助他们为未来的学习和职业生…

008 数组队列(lua)

文章目录 初步array.luaarrayqueue.lua 修改(封装)array.luaarrayqueue.lua测试(直接在 arrayqueue.lua 文件的末尾添加) 修改(本身就是动态扩容)array.luaarrayqueue.lua 循环队列LoopQueue.lua 初步 array.lua Java是一种静态类型、面向对象的编程语言…

Linux高并发服务器开发(十)反应堆模型和线程池模型

文章目录 1 epoll反应堆2 线程池流程代码 3 复杂版本线程池代码 1 epoll反应堆 文件描述符 监听事件 回调函数 进行封装 创建socket设置端口复用绑定监听创建epoll树将监听文件描述符lfd上epoll树,对应的事件节点包括:文件描述符,事件epoll…

Taogogo Taocms v3.0.2 远程代码执行漏洞(CVE-2022-25578)

前言 CVE-2022-25578 是一个存在于 Taogogo Taocms v3.0.2 中的代码注入漏洞。此漏洞允许攻击者通过任意编辑 .htaccess 文件来执行代码注入。 漏洞详情 漏洞描述:攻击者可以利用此漏洞上传一个 .htaccess 文件到网站,并在文件中注入恶意代码&#xf…

Memcached缓存键命名规范:最佳实践与技巧

引言 Memcached是一个广泛使用的高性能分布式内存缓存系统,它通过键值对的方式存储数据,以提高数据检索速度。正确的缓存键命名对于维护Memcached缓存的效率和可管理性至关重要。本文将详细介绍Memcached缓存键的命名规范和最佳实践。 Memcached缓存键…

苹果手机怎么刷机?适合小白的刷机办法!

自己的苹果手机用时间长了,有些人想要为自己的手机重新刷新一下,但又不知道怎么刷机。不要慌现在就来给大家详细介绍一下苹果手机怎么刷机,希望可以帮助到大家。 iPhone常见的刷机方式,分为iTunes官方和第三方软件两种刷机方式。 …

【elementui】记录解决el-tree开启show-checkbox后,勾选一个叶结点后会自动折叠的现象

第一种解决方案&#xff1a;设置default-expand-keys的值为当前选中的key值即可 <el-treeref"tree"class"checkboxSelect-wrap":data"treeData"show-checkboxnode-key"id":expand-on-click-node"true":props"defau…

游戏云服务器为什么经常卡顿不流畅?

游戏云服务器经常出现卡顿或不流畅的情况可能是由多种因素造成的。以下是一些常见的原因&#xff1a; 网络问题 - 带宽不足&#xff1a;如果服务器的带宽不足以支持高峰时段的所有玩家同时在线&#xff0c;就会导致数据传输缓慢&#xff0c;引起卡顿。 - 网络延迟&#xff1a;网…

第T3周:天气识别

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 一、前期工作 本文将采用CNN实现多云、下雨、晴、日出四种天气状态的识别。较上篇文章&#xff0c;本文为了增加模型的泛化能力&#xff0c;新增了Dropout层并…

三、文件操作、错误与异常处理等(爬虫及数据可视化)

三、文件操作、错误与异常处理等&#xff08;爬虫及数据可视化&#xff09; 1&#xff0c;文件操作2&#xff0c;错误与异常 1&#xff0c;文件操作 学习文件操作的相关知识&#xff0c;将一些数据存起来&#xff0c;打开、关闭、读取、写入&#xff0c;重命名、删除等操作在o…

拉曼光谱入门:1.光谱的分类与散射光谱发展史

一、光谱是什么&#xff1f; 在一个宁静的午后&#xff0c;年轻的艾萨克牛顿坐在他母亲花园里的一棵苹果树下&#xff0c;手握一块精致的三棱镜。他沉思着光的奥秘&#xff0c;意识到光并非单一的白色&#xff0c;而是一种由多彩色组成的复杂结构。 他决心进行一次实验&#xf…

浅析C++函数重载

浅析C函数重载 C语言和C函数调用的不同 C语言会进行报错 C能成功运行并且自动识别类型 由此可以看出&#xff0c;C在函数调用时进行了调整&#xff0c;使其支持函数重载&#xff0c;那么我们就来看看进行了哪些调整吧&#x1f60e; 分析函数调用 首先我们要知道&#xff0c…

SQL中常用的内置函数

SQL中常用的内置函数 在SQL&#xff08;结构化查询语言&#xff09;中&#xff0c;有许多内置函数可用于各种数据操作和计算。以下是SQL中常用的函数。 一.字符串操作、数值计算、日期处理 COUNT(): 统计行数。 SELECT COUNT(*) FROM employees;SUM(): 计算数值列的总和。 S…

2024企业数据资产化及数据资产入表方案梳理

01 数据资产入表&#xff1a;是一个将组织的各类数据资产进行登记、分类、评估和管理的流程。 数据资产包括&#xff1a;客户信息、交易记录、产品数据、财务数据等。 做个比喻吧&#xff1a;数据资产入表就像是给公司的数据资产做“人口普查”—— ①找出公司有哪些数据找…

macos m2 百度paddleocr文字识别 python

创建了一个虚拟环境&#xff1a;conda create -n orc python3.11.7 进入虚拟环境后执行2条命令 pip install paddleocr -i https://pypi.tuna.tsinghua.edu.cn/simple pip install paddlepaddle -i https://pypi.tuna.tsinghua.edu.cn/simple​ ​ 安装好后&#xff0c;在网…