Transformer中的类别嵌入

类别嵌入

self.class_embedding = nn.Parameter(scale * torch.randn(width))

这一行代码的作用是在 VisionTransformer 类中创建并初始化一个类别嵌入向量(class embedding vector),用于表示输入序列的类别信息。

详细解释

类别嵌入

在 Transformer 模型中,类别嵌入(class embedding)是一种特殊的嵌入向量,通常用于表示整个输入序列的全局信息。它在视觉变压器(Vision Transformer)模型中起到了类似于 [CLS](分类)标记在 BERT 中的作用。

代码解释
self.class_embedding = nn.Parameter(scale * torch.randn(width))
1. torch.randn(width)
  • 作用:生成一个形状为 (width,) 的张量,其元素从标准正态分布(均值为0,标准差为1)中随机采样。
  • 示例:假设 width = 768,那么 torch.randn(width) 将生成一个包含 768 个元素的一维张量,每个元素为从标准正态分布中采样的随机数。
2. scale *
  • 作用:对生成的随机张量进行缩放。
  • scale 的定义:在代码中,scale 被定义为 width ** -0.5,即 1 / sqrt(width)。这个缩放因子通常用于标准化初始化值,使其具有适当的尺度。
  • 示例:假设 width = 768,那么 scale = 768 ** -0.5 = 1 / sqrt(768)
3. nn.Parameter(...)
  • 作用:将张量封装为一个 nn.Parameter 对象,这意味着这个张量将被视为模型的参数,会在训练过程中进行优化。
  • nn.Parameter 的使用nn.Parametertorch.Tensor 的子类,当其被赋值给 nn.Module(如 nn.Linearnn.Conv2d 等)的属性时,会被自动添加到模块的参数列表中,并在调用 model.parameters() 时返回。

综合解释

self.class_embedding = nn.Parameter(scale * torch.randn(width))
  • 生成随机张量:首先,通过 torch.randn(width) 生成一个形状为 (width,) 的随机张量,其元素从标准正态分布中采样。
  • 缩放张量:然后,将这个张量乘以 scale(即 width ** -0.5),以标准化其值。
  • 创建参数:最后,将这个缩放后的张量封装为一个 nn.Parameter 对象,并赋值给 self.class_embedding

作用

  • 类别嵌入self.class_embedding 是一个可训练的参数,用于表示输入序列的类别信息。
  • 模型训练:在模型训练过程中,self.class_embedding 会被优化,以适应训练数据,帮助模型更好地捕捉全局信息。

示例

假设我们有一个 width = 768VisionTransformer 模型实例:

import torch
import torch.nn as nnclass VisionTransformer(nn.Module):def __init__(self, width):super().__init__()scale = width ** -0.5self.class_embedding = nn.Parameter(scale * torch.randn(width))# 创建一个 width 为 768 的 VisionTransformer 实例
vit = VisionTransformer(768)# 打印 class_embedding 的形状和部分值
print(vit.class_embedding.shape)  # 输出: torch.Size([768])
print(vit.class_embedding[:5])  # 输出: 类别嵌入的前 5 个元素的值

这段代码会输出 class_embedding 的形状(应为 [768])和前 5 个元素的值。

总结

这一行代码 self.class_embedding = nn.Parameter(scale * torch.randn(width))VisionTransformer 类中创建并初始化了一个类别嵌入向量。该向量是一个可训练的参数,用于表示输入序列的全局信息,并在模型训练过程中进行优化。

为什么它用于表示输入序列的全局信息?

类别嵌入(class embedding)用于表示输入序列的全局信息是基于其设计和应用目的,尤其是在 Transformer 模型中的使用方法。让我们详细解释为什么类别嵌入在 Vision Transformer 中起到表示全局信息的作用。

类别嵌入的作用和背景

在 Transformer 模型中,类别嵌入(class embedding 或者 class token)是一种特殊的嵌入向量,通常添加到输入序列的开头。它的设计灵感来自于 BERT 模型中的 [CLS] token,用于捕获整个输入序列的全局信息,并且常用于分类任务。

Vision Transformer (ViT) 中的类别嵌入

在 Vision Transformer 模型中,输入图像首先被划分成若干个固定大小的图像补丁(patches),然后这些补丁被展平并嵌入到一个高维特征空间中。这些补丁嵌入被视为一个序列,类似于自然语言处理任务中的单词序列。

具体步骤
  1. 图像划分为补丁

    • 原始图像被划分为若干个不重叠的图像补丁。
    • 每个补丁被展平成一个向量,并通过线性变换(卷积操作)嵌入到高维空间。
  2. 添加类别嵌入

    • 类别嵌入(class embedding)被添加到补丁序列的开头。这个类别嵌入是一个可训练的参数,用于捕获整个序列的全局信息。
    • 序列的第一个位置(位置 0)被保留给类别嵌入,后续位置由图像补丁嵌入填充。
  3. 位置嵌入

    • 位置嵌入(positional embedding)被加到每个补丁嵌入和类别嵌入上,用于保留序列中的位置信息。
  4. 通过 Transformer 模块

    • 这个序列被传递给 Transformer 模块进行处理。Transformer 模块会对整个序列(包括类别嵌入和补丁嵌入)进行多层注意力和前馈神经网络的计算。
  5. 提取全局信息

    • 在处理完成后,类别嵌入会包含整个输入序列的全局信息,因为它通过多头自注意力机制与所有补丁嵌入进行了交互。
    • 通常在分类任务中,类别嵌入被传递给一个分类头(classification head)以进行最终的分类预测。

代码解释

在你的 VisionTransformer 代码中,这一行定义了类别嵌入:

self.class_embedding = nn.Parameter(scale * torch.randn(width))

在前向传播中,类别嵌入被添加到补丁序列的开头:

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
具体步骤解释
  1. 初始化类别嵌入

    • self.class_embedding 是一个形状为 (width,) 的可训练参数,用于表示类别嵌入。
  2. 添加到序列开头

    • 创建一个与输入张量 x 同 dtype 和设备的全零张量,其形状为 [batch_size, 1, width]
    • 将类别嵌入加到这个全零张量上,并与输入张量 x 在第一个维度上进行拼接。
    • 这样,类别嵌入被添加到每个输入样本的序列开头,形成一个新的张量,其形状为 [batch_size, grid ** 2 + 1, width]
  3. 通过 Transformer 模块

    • 类别嵌入与图像补丁嵌入一起传递给 Transformer 模块,进行多层自注意力计算。
    • Transformer 模块会将类别嵌入与所有补丁嵌入交互,从而使类别嵌入逐步聚合整个序列的信息。

总结

类别嵌入(class embedding)在 Vision Transformer 中用于表示输入序列的全局信息,因为它在前向传播中通过自注意力机制与所有图像补丁嵌入进行交互,从而捕获整个序列的全局特征。在分类任务中,这个类别嵌入最终用于做出分类决策,因此它被设计为包含全局信息。

添加类别嵌入

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]

这行代码的作用是将类别嵌入(class embedding)添加到输入张量的序列开头。具体步骤如下:

  1. 创建一个类别嵌入向量并调整其数据类型和设备

    • self.class_embedding.to(x.dtype) 将类别嵌入向量的类型转换为输入张量 x 的数据类型。
    • torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) 创建一个形状为 [batch_size, 1, width] 的全零张量,其数据类型和设备与输入张量 x 一致。
    • self.class_embedding.to(x.dtype) + torch.zeros(...) 创建一个新的张量,其第一维度是类别嵌入,形状为 [batch_size, 1, width]
  2. 将类别嵌入添加到输入张量 x 的序列开头

    • x 是输入张量,其形状为 [batch_size, num_patches, width]
    • torch.cat([...], dim=1) 在第一个维度(即序列维度)上拼接类别嵌入和输入张量 x
  3. 更新后的张量形状

    • 拼接后的张量形状为 [batch_size, num_patches + 1, width],其中 num_patches + 1 表示原始的 num_patches 加上一个类别嵌入。

代码详解

让我们逐步分解这行代码:

1. 调整类别嵌入的数据类型
self.class_embedding.to(x.dtype)
  • self.class_embedding 是一个形状为 [width] 的向量。
  • .to(x.dtype) 将其数据类型转换为输入张量 x 的数据类型。
2. 创建一个全零张量
torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
  • x.shape[0] 是批次大小(batch size)。
  • 1 是类别嵌入的位置。
  • x.shape[-1] 是特征宽度(width)。
  • dtype=x.dtype 设置全零张量的数据类型与输入张量 x 一致。
  • device=x.device 设置全零张量的设备与输入张量 x 一致。
3. 添加类别嵌入
self.class_embedding.to(x.dtype) + torch.zeros(...)
  • 将类别嵌入向量与全零张量相加,形成一个形状为 [batch_size, 1, width] 的张量。
4. 拼接张量
torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)
  • 使用 torch.cat 在第一个维度(序列维度)上拼接类别嵌入和输入张量 x
  • 最终形状为 [batch_size, num_patches + 1, width]

示例代码

以下是一个完整的示例,展示了上述步骤:

import torch
import torch.nn as nnclass VisionTransformer(nn.Module):def __init__(self, width):super().__init__()scale = width ** -0.5self.class_embedding = nn.Parameter(scale * torch.randn(width))def forward(self, x: torch.Tensor):# 模拟输入张量 x,假设其形状为 [batch_size, num_patches, width]batch_size = x.shape[0]num_patches = x.shape[1]width = x.shape[2]# 添加类别嵌入x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(batch_size, 1, width, dtype=x.dtype, device=x.device), x], dim=1)return x# 创建 VisionTransformer 实例
width = 768
vit = VisionTransformer(width)# 创建一个模拟输入张量 x
batch_size = 2
num_patches = 196
x = torch.randn(batch_size, num_patches, width)# 调用 forward 方法
output = vit(x)# 输出张量形状
print(output.shape)  # 输出: torch.Size([2, 197, 768])

总结

这行代码通过在序列开头添加一个类别嵌入向量,将输入张量的形状从 [batch_size, num_patches, width] 扩展到 [batch_size, num_patches + 1, width]。这个类别嵌入向量在 Transformer 模型中用于捕获整个序列的全局信息,通常用于分类任务的最终决策。

什么是一个序列?

在计算机科学和数据处理的上下文中,一个序列(sequence)通常是指一个有序的元素集合,这些元素按照一定的顺序排列,可以是数字、字符、图像补丁等。在不同的应用场景中,序列的具体形式和内容可能有所不同,但它们都有一个共同点,即元素之间存在顺序关系。

序列的概念

一般概念
  • 序列 是一个元素的有序集合,元素可以是任何类型的数据(如数字、字符、图像块等)。
  • 顺序关系 是指序列中的元素按照特定的顺序排列,顺序信息通常是重要的,影响对序列的处理和理解。
具体示例
  1. 数值序列:如 1, 2, 3, 4, 5。这些数字按照从小到大的顺序排列。
  2. 字符序列:如 "hello"。字符 'h', 'e', 'l', 'l', 'o' 按照它们在字符串中的顺序排列。
  3. 图像补丁序列:如将一张图像划分为多个小的图像块,这些块按照它们在图像中的位置顺序排列。

在深度学习中的序列

在深度学习和神经网络中,序列数据是非常常见的输入类型。以下是几个典型的例子:

自然语言处理(NLP)

在 NLP 中,输入通常是单词序列或字符序列。比如:

  • 单词序列:一个句子 “The cat sat on the mat” 可以被表示为一个单词序列 ["The", "cat", "sat", "on", "the", "mat"]
  • 字符序列:同一个句子可以被表示为一个字符序列 ["T", "h", "e", " ", "c", "a", "t", " ", "s", "a", "t", " ", "o", "n", " ", "t", "h", "e", " ", "m", "a", "t"]
时间序列数据

在时间序列分析中,数据通常是按时间顺序排列的观测值。例如:

  • 股票价格序列:记录某只股票在不同时间点的价格。
  • 传感器数据序列:记录传感器在不同时间点的读数。
图像处理

在图像处理任务中,图像可以被划分为多个小块,每个小块作为序列中的一个元素进行处理。这种方法在 Vision Transformer (ViT) 中得到了应用。

Vision Transformer (ViT) 中的序列

在 Vision Transformer 中,图像被划分为固定大小的补丁(patches),这些补丁被视为一个序列来处理。具体步骤如下:

  1. 图像划分为补丁

    • 将输入图像划分为固定大小的非重叠补丁。例如,一个 224x224 的图像可以划分为 16x16 的补丁,总共有 (224/16)^2 = 196 个补丁。
  2. 将补丁展平并嵌入

    • 每个补丁被展平成一个向量,并通过线性变换嵌入到一个高维特征空间(例如,768 维)。
  3. 形成补丁序列

    • 将所有补丁嵌入向量连接起来,形成一个补丁嵌入序列,其形状为 [batch_size, num_patches, embedding_dim]
  4. 添加类别嵌入和位置嵌入

    • 类别嵌入(class embedding)被添加到序列的开头,用于捕获整个图像的全局信息。
    • 位置嵌入(positional embedding)被加到每个补丁嵌入上,以保留序列中的位置信息。

示例代码

以下是一个简单的 Vision Transformer 示例代码:

import torch
import torch.nn as nnclass VisionTransformer(nn.Module):def __init__(self, patch_size: int, stride_size: int, width: int, num_patches: int, output_dim: int):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False)scale = width ** -0.5self.class_embedding = nn.Parameter(scale * torch.randn(width))self.positional_embedding = nn.Parameter(scale * torch.randn(num_patches + 1, width))self.ln_pre = LayerNorm(width)self.transformer = Transformer(width, layers=12, heads=12)self.ln_post = LayerNorm(width)self.proj = nn.Parameter(scale * torch.randn(width, output_dim))def forward(self, x: torch.Tensor):# 提取补丁嵌入x = self.conv1(x)  # [batch_size, width, num_patches_h, num_patches_w]x = x.flatten(2)   # [batch_size, width, num_patches]x = x.transpose(1, 2)  # [batch_size, num_patches, width]# 添加类别嵌入cls_embed = self.class_embedding.expand(x.shape[0], -1).unsqueeze(1)  # [batch_size, 1, width]x = torch.cat((cls_embed, x), dim=1)  # [batch_size, num_patches + 1, width]# 添加位置嵌入x = x + self.positional_embedding# 前向传播x = self.ln_pre(x)x = x.permute(1, 0, 2)  # 转换为 [sequence_length, batch_size, width]x = self.transformer(x)x = x.permute(1, 0, 2)  # 转换为 [batch_size, sequence_length, width]x = self.ln_post(x[:, 0, :])  # 取出类别嵌入# 投影到输出维度x = x @ self.projreturn x# 模拟输入图像
input_image = torch.randn(1, 3, 224, 224)  # [batch_size, channels, height, width]# 创建 VisionTransformer 实例
vit = VisionTransformer(patch_size=16, stride_size=16, width=768, num_patches=(224 // 16) ** 2, output_dim=1000)# 前向传播
output = vit(input_image)
print(output.shape)  # [1, 1000]

总结

序列是一个有序的元素集合,在深度学习中,序列数据(如文本、时间序列、图像补丁等)可以通过 Transformer 模型进行有效处理。Vision Transformer 通过将图像划分为补丁并视为序列,使得图像处理任务可以直接应用于序列建模技术,从而利用 Transformer 在捕捉长距离依赖关系方面的优势。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/32674.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

昇思25天学习打卡营第4天|数据变换Transforms

学习内容复盘 1.1 数据变换 什么是数据变换、为何要数据变换 通常情况下,直接加载的原始数据并不能直接送入神经网络进行训练,此时我们需要对其进行数据预处理。MindSpore提供不同种类的数据变换(Transforms),配合数…

学习VXLAN -- 报文结构、原理和配置

目录 VXLAN背景什么是VXLANVXLAN的优势VXLAN报文结构一些特定名词BDVBDIFVAPVSIVSI-InterfaceAC VXLAN的实现原理图VXLAN MAC地址表项MAC地址动态学习 VXLAN隧道VXLAN隧道工作模式L2 GatewayIP Gateway VXLAN隧道的建立与关联VXLAN隧道建立的方式VXLAN对到与VXLAN关联的方式 配…

低成本STC32G8K64驱动控制BLDC开源入门学习方案

低成本STC32G8K64驱动控制BLDC开源入门学习方案 ✨采用STC32G8K64单片机,参考梁工的STC32G12K128-LQFP48驱动方案制作,梁工BLDC相关的资料:https://www.stcaimcu.com/forum.php?modviewthread&tid7472&extrapage%3D1,在此…

python tarfile解压失败怎么解决

问题原因 在使用tarfile模块解压一份Linux服务器上的打包文件时,出现了错误提示:IOError:[Errno 22] invalid mode (wb) or filename. 经过检查,发现是因为打包文件中有文件名存在“:”符号,而window下的…

react中如何获取并使用usestate声明的变量的值

1. 函数式更新 当需要根据当前状态来更新状态时,可以使用函数式更新。setState(在类组件中)和setCount(在useState中)都可以接受一个函数作为参数,这个函数接收当前的状态作为参数,并返回新的状…

python rename报错怎么解决

刚接触python,写了一段简单的代码,功能就是重命名一个文件,代码如下: list_1os.listdir(".") for files in list_1:fopen(files)if f.name"01.txt":os.rename(01.txt,001.txt)elif f.name"05.txt":…

【Python机器学习】k均值聚类——k均值的失败案例

k均值可能不总能找到“正确”的簇个数,每个簇仅由其中心定义,这意味着每个簇都是凸形。因此,k均值只能找到相对简单的形状。k均值还假设所有簇在某种程度上具有相同的“直径”,它总是将簇之间的边界刚好画在簇中心的之间位置。有时…

找不到msvcr120.dll怎么办,msvcr120.dll丢失的多种解决方法

msvcr120.dll是微软Visual C 2013的可再发行组件包中的一个文件,它是许多程序运行所必需的。这个文件包含了Visual C库,这些库为使用C编写的软件提供支持。如果你的电脑中缺少msvcr120.dll文件,那么依赖这个文件运行的应用程序可能无法启动或…

WPF文本绑定显示格式StringFormat设置-数值类型处理

绑定显示格式设置 在Textblock等文本控件中,我们经常要绑定一些数据类型,但是我们希望显示的时候能够按照我们想要的格式去显示,比如增加文本前缀,后面加单位,显示百分号等等,这种就需要对绑定格式进行处理…

时序设计中的“打拍”

“打拍”:在数字系统和时序设计中,打拍(Double Flopping / Two-Stage Registering)是指通过两个级联的寄存器(flip-flops)将输入信号同步到系统时钟域内的过程,常用于解决跨时钟域信号的亚稳态问…

智能淘客返利系统架构解析

智能淘客返利系统架构解析 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 随着电子商务行业的迅速发展,淘宝、天猫等电商平台成为了人们购物的主要…

3. kubernetes客户端crictl命令

kubernetes客户端crictl命令 crictl 是一个命令行工具,用于与容器运行时接口(CRI)兼容的容器运行时(如 containerd 和 CRI-O)进行交互。crictl 提供了许多有用的命令来管理容器、镜像和 sandboxes。 官方仓库地址&am…

Rust:Future、async 异步代码机制示例与分析

0. 异步、并发、并行、进程、协程概念梳理 Rust 的异步机制不是多线程或多进程,而是基于协程(或称为轻量级线程、微线程)的模型,这些协程可以在单个线程内并发执行。这种模型允许在单个线程中通过非阻塞的方式处理多个任务&#…

关于微信没有接入鸿蒙NEXT的思考

6月21日,纯血鸿蒙发布,国内的质疑声终于停止,不再被人喊叫换皮 Android 了.就连编程语言都是华为自研的。 可是发布会后微信却成了热点,因为余承东在感谢了一圈互联网企业,如:淘宝、支付宝、美团、京东、抖音、今日头条、钉钉、小红书、微博、B站、高德、WPS等等. 唯独没有感…

CSS基础学习记录(5)

目录 1、CSS语法 2、实例 3、CSS注释 4、id 选择器 5、class 类选择器 6、标签选择器 7、内联选择器 1、CSS语法 CSS 规则由两个主要的部分构成:选择器,以及一条或多条声明: 选择器(Selector)通常是您需要改变样式的 HTML …

Altera不同系列的型号命名规则

Altera芯片型号:10AX07H4F34I3SG 20nm工艺 资源: 大数据 云计算 人工智能 图像处理 MSEL

高级人工智能复习 中科大

参考: 中科大2023春季【高级人工智能】试题回顾 中国科学技术大学《高级人工智能》课程 重要知识点提纲 高级人工智能复习提纲 1.搜索 1.1 搜索问题的概念 搜索问题的五个要素:状态空间、后继函数、初始状态、目标测试和路径耗散。 用状态图描述搜索…

Codeforces Round 953 (Div. 2) A~F

A.Alice and Books(思维) 题意: 爱丽丝有 n n n本书。第 1 1 1本书包含 a 1 a_1 a1​页,第 2 2 2本书包含 a 2 a_2 a2​页, … \ldots …第 n n n本书包含 a n a_n an​页。爱丽丝的操作如下: 她把所有的…

C语言之常用标准库介绍

文章目录 1 标准库1.1 诊断assert.h1.2 字符类别测试ctype.h1.3 错误处理errno.h1.4 整型常量limits.h1.5 地域环境locale.h1.6 数学函数math.h1.7 非局部跳转setjmp.h1.8 可变参数表stdarg.h1.9 公共定义stddef.h1.10 输入输出stdio.h1.11 实用函数stdlib.h1.12 日期与时间函数…

L57---112.路径总和(广搜)---Java版

1.题目描述 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径,这条路径上所有节点值相加等于目标和 targetSum 。如果存在,返回 true ;否则,返回 false 。叶子节点 是指…