Bottleneck
模块
首先定义了一个名为 Bottleneck
的 PyTorch 模块,它是 ResNet 架构中的一个瓶颈块(Bottleneck Block)。瓶颈块是 ResNet 中常用的一种层次结构,用于构建更深的网络。以下是对这段代码的详细注释:
类定义和初始化函数
class Bottleneck(nn.Module):expansion = 4def __init__(self, inplanes, planes, stride=1):super().__init__()
expansion
: 一个类属性,表示通道数扩展倍数。对于标准的ResNet,expansion
通常为 4。__init__
: 初始化函数,用于定义该模块的所有层和参数。inplanes
: 输入通道数。planes
: 输出通道数的一部分。stride
: 卷积层的步幅。
定义卷积层和批归一化层
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)self.bn1 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(planes)self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)self.bn3 = nn.BatchNorm2d(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = Noneself.stride = stride
conv1
: 一个 1x1 的卷积层,用于减少通道数。bn1
: 第一个批归一化层。conv2
: 一个 3x3 的卷积层,用于提取特征。bn2
: 第二个批归一化层。avgpool
: 一个平均池化层,用于在步幅大于1时进行下采样;否则为恒等映射(nn.Identity()
)。conv3
: 一个 1x1 的卷积层,用于恢复通道数。bn3
: 第三个批归一化层。relu
: 一个 ReLU 激活函数。downsample
: 一个下采样层,默认为None
。stride
: 步幅,用于卷积层和池化层。
定义下采样层
if stride > 1 or inplanes != planes * Bottleneck.expansion:self.downsample = nn.Sequential(OrderedDict([("-1", nn.AvgPool2d(stride)),("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),("1", nn.BatchNorm2d(planes * self.expansion))]))
- 如果步幅大于1,或者输入通道数不等于输出通道数(考虑到扩展倍数),则需要进行下采样。
downsample
: 使用nn.Sequential
和OrderedDict
定义下采样层。"-1"
: 一个平均池化层,用于下采样。"0"
: 一个 1x1 的卷积层,用于调整通道数。"1"
: 一个批归一化层。
前向传播函数
def forward(self, x: torch.Tensor):identity = xout = self.relu(self.bn1(self.conv1(x)))out = self.relu(self.bn2(self.conv2(out)))out = self.avgpool(out)out = self.bn3(self.conv3(out))if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out
identity
: 保存输入张量,用于残差连接。- 前向传播步骤:
- 输入
x
经过第一个卷积层conv1
和批归一化层bn1
,然后应用 ReLU 激活函数。 - 输出经过第二个卷积层
conv2
和批归一化层bn2
,然后应用 ReLU 激活函数。 - 输出经过平均池化层
avgpool
(如果步幅大于1)。 - 输出经过第三个卷积层
conv3
和批归一化层bn3
。 - 如果存在下采样层,将输入
x
通过下采样层调整尺寸和通道数。 - 将输出
out
与identity
相加(残差连接)。 - 最后应用 ReLU 激活函数,并返回输出
out
。
- 输入
总结
这个 Bottleneck
类实现了ResNet 的瓶颈块结构,通过三个卷积层和批归一化层提取特征,并通过残差连接将输入直接添加到输出,从而缓解梯度消失问题,提高训练深度神经网络的效果。下采样层将avgpool添加到卷积之前用于在需要时调整输入的尺寸,随后用卷积调整通道数。在后面的ModifiedResNet
会用到该模块。
AttentionPool2d
模块
代码定义了一个名为 AttentionPool2d
的 PyTorch 模块,用于基于注意力机制的二维池化操作。以下是对这段代码的详细注释:
类定义和初始化函数
class AttentionPool2d(nn.Module):def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):super().__init__()
spacial_dim
:输入特征图的空间维度(例如,宽度或高度)。embed_dim
:嵌入维度,即输入特征图的通道数。num_heads
:多头注意力机制的头数。output_dim
:输出维度,如果未指定则默认与嵌入维度相同。
定义位置嵌入和线性投影层
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)self.k_proj = nn.Linear(embed_dim, embed_dim)self.q_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) self.num_heads = num_heads
positional_embedding
:位置嵌入参数,形状为(spacial_dim + 1, embed_dim)
,用于为每个空间位置添加位置信息。k_proj
,q_proj
,v_proj
:分别为键、查询和值的线性投影层。c_proj
:输出的线性投影层。num_heads
:注意力头的数量。
前向传播函数
def forward(self, x):
x
:输入张量,形状为(N, C, H, W)
。
调整输入张量的形状并添加位置嵌入
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
- 将输入张量从
(N, C, H, W)
调整为(N, C, HW)
。 - 转置张量的维度,使其形状变为
(HW, N, C)
,其中HW
是展开后的空间维度。
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
``- 计算 `x` 在第一个维度(batch 维度)的均值,并将其添加到张量的开头。这样可以为注意力机制添加一个全局平均池化的位置。
- 最后的张量形状为 `(HW+1, N, C)`。```pythonx = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
- 将位置嵌入添加到每个位置的特征表示中,以引入位置信息。
- 使用广播机制,将
self.positional_embedding
的形状从(HW+1, C)
扩展到(HW+1, N, C)
。
计算多头自注意力
x, _ = F.multi_head_attention_forward(query=x, key=x, value=x,embed_dim_to_check=x.shape[-1],num_heads=self.num_heads,q_proj_weight=self.q_proj.weight,k_proj_weight=self.k_proj.weight,v_proj_weight=self.v_proj.weight,in_proj_weight=None,in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),bias_k=None,bias_v=None,add_zero_attn=False,dropout_p=0,out_proj_weight=self.c_proj.weight,out_proj_bias=self.c_proj.bias,use_separate_proj_weight=True,training=self.training,need_weights=False)
- 使用 PyTorch 提供的
multi_head_attention_forward
函数计算多头自注意力。 query
,key
,value
:输入张量x
。embed_dim_to_check
:检查的嵌入维度,等于x
的最后一个维度。num_heads
:注意力头的数量。q_proj_weight
,k_proj_weight
,v_proj_weight
:分别为查询、键和值的投影权重。in_proj_weight
:输入投影权重,为None
,因为我们使用了单独的投影权重。in_proj_bias
:输入投影偏置,拼接了查询、键和值的偏置。bias_k
,bias_v
:键和值的偏置,为None
。add_zero_attn
:是否添加全零注意力头,为False
。dropout_p
:dropout 概率,为 0(不使用 dropout)。out_proj_weight
,out_proj_bias
:输出投影的权重和偏置。use_separate_proj_weight
:是否使用单独的投影权重,为True
。training
:指示是否在训练模式下,使用self.training
。need_weights
:是否需要输出注意力权重,为False
。
返回结果
return x
- 返回多头注意力计算后的张量
x
。
总结
AttentionPool2d
类通过多头注意力机制实现了二维池化操作,具体步骤如下:
-
调整输入张量的形状:将输入张量从
(N, C, H, W)
转换为(HW, N, C)
,并添加全局平均池化位置。 -
添加位置嵌入:将位置嵌入添加到每个位置的特征表示中。
-
计算多头自注意力:使用 PyTorch 的
multi_head_attention_forward
函数计算多头自注意力。 -
返回结果:返回计算后的张量。返回的张量形状是 (H*W+1, N, C),其中:
(H*W+1):是原始空间维度加上一个全局平均池化的位置。
N:是批量大小。
C:是嵌入维度(embed_dim)或者输出维度(output_dim)。
对于输入形状为 (32, 2048, 7, 7) 的示例,返回的 x 的形状是 (50, 32, 2048)。
这种注意力池化机制可以更好地捕捉输入特征图中的全局信息,提高模型的表达能力和性能。整个过程是将输入张量进行维度转换和扩展,加上位置嵌入后,通过多头注意力机制进行处理,最终返回包含空间信息和全局信息的输出张量。
ModifiedResNet
函数
定义了一个名为 ModifiedResNet
的 PyTorch 神经网络模块,它是对标准 ResNet 的修改版。主要修改包括:
- 使用 3 个卷积层作为初始的“stem”层,而不是一个,并且使用平均池化替代最大池化。
- 使用反别名的跨步卷积(即在卷积层之前添加平均池化层)。
- 使用 QKV 注意力机制替代最后的平均池化层。
类定义和初始化函数
class ModifiedResNet(nn.Module):"""A ResNet class that is similar to torchvision's but contains the following changes:- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1- The final pooling layer is a QKV attention instead of an average pool"""def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):super().__init__()self.output_dim = output_dimself.input_resolution = input_resolution# the 3-layer stemself.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(width // 2)self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(width // 2)self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)self.bn3 = nn.BatchNorm2d(width)self.avgpool = nn.AvgPool2d(2)self.relu = nn.ReLU(inplace=True)# residual layersself._inplanes = width # this is a *mutable* variable used during constructionself.layer1 = self._make_layer(width, layers[0])self.layer2 = self._make_layer(width * 2, layers[1], stride=2)self.layer3 = self._make_layer(width * 4, layers[2], stride=2)self.layer4 = self._make_layer(width * 8, layers[3], stride=1) embed_dim = width * 32 # the ResNet feature dimensionself.attnpool = AttentionPool2d(input_resolution, embed_dim, heads, output_dim)
初始化部分
layers
: 每个残差层中包含的块数量的列表。output_dim
: 模型的输出维度。heads
: 多头注意力机制的头数量。input_resolution
: 输入图像的分辨率。width
: 初始卷积层的宽度。
stem 部分(初始卷积层)
# the 3-layer stemself.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(width // 2)self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(width // 2)self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)self.bn3 = nn.BatchNorm2d(width)self.avgpool = nn.AvgPool2d(2)self.relu = nn.ReLU(inplace=True)
- 定义了三个卷积层
conv1
,conv2
,conv3
,每个卷积层后接一个批量归一化层bn1
,bn2
,bn3
。 - 使用 ReLU 作为激活函数。
- 最后使用一个平均池化层
avgpool
。
残差层和 _make_layer
方法
# residual layersself._inplanes = width # this is a *mutable* variable used during constructionself.layer1 = self._make_layer(width, layers[0])self.layer2 = self._make_layer(width * 2, layers[1], stride=2)self.layer3 = self._make_layer(width * 4, layers[2], stride=2)self.layer4 = self._make_layer(width * 8, layers[3], stride=1) embed_dim = width * 32 # the ResNet feature dimensionself.attnpool = AttentionPool2d(input_resolution, embed_dim, heads, output_dim)
_make_layer
方法用于构建残差层。- 定义了四个残差层
layer1
,layer2
,layer3
,layer4
,每个残差层包含多个 Bottleneck 块,通道数依次增加。 embed_dim
是最终的嵌入维度,等于初始宽度的32倍。self.attnpool
是一个注意力池化层。
def _make_layer(self, planes, blocks, stride=1):layers = [Bottleneck(self._inplanes, planes, stride)]self._inplanes = planes * Bottleneck.expansionfor _ in range(1, blocks):layers.append(Bottleneck(self._inplanes, planes))return nn.Sequential(*layers)
_make_layer
方法:planes
:每个块的输出通道数。blocks
:块的数量。stride
:步幅。- 使用 Bottleneck 块构建层,每个块的输出通道数是输入通道数的 4 倍(由于
Bottleneck.expansion
)。
前向传播函数
def forward(self, x): def stem(x):for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:x = self.relu(bn(conv(x)))x = self.avgpool(x)return x# 将 x 的数据类型转换为与 conv1.weight 相同的数据类型x = x.type(self.conv1.weight.dtype) # 通过调用 stem 函数,可以简化 forward 方法的主逻辑,避免直接在 forward 方法中书写所有的初始卷积和池化操作。x = stem(x) x = self.layer1(x) x = self.layer2(x) x3 = self.layer3(x) x4 = self.layer4(x3) xproj = self.attnpool(x4) return x3, x4, xproj
forward
方法
-
stem
方法:def stem(x):for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:x = self.relu(bn(conv(x)))x = self.avgpool(x)return x
- 依次通过三个卷积层和批量归一化层,最后通过平均池化层。
-
前向传播步骤:
x = x.type(self.conv1.weight.dtype)
:确保输入的类型与卷积层权重的类型一致。x = stem(x)
:通过初始的 stem 部分。x = self.layer1(x)
:通过第一个残差层。x = self.layer2(x)
:通过第二个残差层。x3 = self.layer3(x)
:通过第三个残差层,输出为x3
。x4 = self.layer4(x3)
:通过第四个残差层,输出为x4
。xproj = self.attnpool(x4)
:通过注意力池化层,输出为xproj
。
-
返回
x3
,x4
,xproj
三个张量,分别表示不同层的输出。
总结
这个 ModifiedResNet
类实现了一个修改版的 ResNet 网络,具有三个初始卷积层、反别名的跨步卷积和一个使用 QKV 注意力机制的最终池化层。通过这种设计,可以更好地提取和处理图像特征,提高模型的表达能力和性能。
LayerNorm
和QuickGELU
函数
以下是对两个 PyTorch 类 LayerNorm
和 QuickGELU
的详细注释。这两个类分别处理 fp16
(半精度浮点数)计算的层归一化和快速近似的 GELU 激活函数。
LayerNorm
类
这个类继承自 PyTorch 的 nn.LayerNorm
,并对其进行了修改,以便在处理 fp16
(半精度浮点数)时能够正常工作。具体实现是将输入数据类型临时转换为 fp32
(单精度浮点数),进行层归一化计算后,再将结果转换回原始数据类型。
class LayerNorm(nn.LayerNorm):"""Subclass torch's LayerNorm to handle fp16."""def forward(self, x: torch.Tensor):orig_type = x.dtype # 保存输入张量的原始数据类型ret = super().forward(x.type(torch.float32)) # 将输入张量转换为float32类型并调用父类的forward方法进行计算return ret.type(orig_type) # 将计算结果转换回原始数据类型
注释解释
orig_type = x.dtype
: 保存输入张量x
的原始数据类型,以便稍后将结果转换回该数据类型。ret = super().forward(x.type(torch.float32))
: 将输入张量x
转换为float32
类型,并调用父类nn.LayerNorm
的forward
方法进行层归一化计算。super().forward(...)
调用父类的forward
方法,进行标准的层归一化计算。
return ret.type(orig_type)
: 将计算结果ret
的数据类型转换回输入张量x
的原始数据类型,并返回结果。
这种方法在使用 fp16
数据类型时,避免了由于数据类型精度较低而可能引起的数值不稳定问题。
QuickGELU
类
这个类实现了一个快速近似的 GELU(Gaussian Error Linear Units)激活函数。GELU 是一种常用的激活函数,在一些深度学习模型(如 Transformer)中表现良好。
class QuickGELU(nn.Module):def forward(self, x: torch.Tensor):return x * torch.sigmoid(1.702 * x)
注释解释
def forward(self, x: torch.Tensor):
: 定义forward
方法,它接受一个输入张量x
并返回激活后的输出。return x * torch.sigmoid(1.702 * x)
: 实现 QuickGELU 激活函数:torch.sigmoid(1.702 * x)
: 计算输入张量x
的缩放版本1.702 * x
的 Sigmoid 值。x * ...
: 将输入张量x
与 Sigmoid 值逐元素相乘,得到 QuickGELU 激活后的输出。
QuickGELU 的数学表达式为:
QuickGELU ( x ) = x ⋅ σ ( 1.702 ⋅ x ) \text{QuickGELU}(x) = x \cdot \sigma(1.702 \cdot x) QuickGELU(x)=x⋅σ(1.702⋅x)
其中 σ \sigma σ 是 Sigmoid 函数。
QuickGELU 的好处
这种近似计算比原始的 GELU 函数计算更快,但效果相似,特别适用于需要高效计算的模型。
示例代码
以下是如何使用这两个类的示例:
import torch# 示例使用 LayerNorm
x = torch.randn(2, 5, dtype=torch.float16) # 使用 fp16 数据类型
layer_norm = LayerNorm(x.size()[1:])
output = layer_norm(x) # 自动处理 fp16 数据类型
print(output)# 示例使用 QuickGELU
x = torch.randn(2, 5)
quick_gelu = QuickGELU()
output = quick_gelu(x) # 计算 QuickGELU 激活
print(output)
总结
LayerNorm
类:处理fp16
数据类型的层归一化,确保计算的数值稳定性。QuickGELU
类:实现快速近似的 GELU 激活函数,提供高效的激活计算。
这两个类通过重载 forward
方法,在处理不同的数据类型和计算需求时提供了便利和性能优化。
ResidualAttentionBlock
模块
以下是对 ResidualAttentionBlock
类的详细注释。这段代码定义了一个带有残差连接的注意力块,该模块包括一个多头注意力机制、一个前馈神经网络和层归一化操作。
ResidualAttentionBlock
类
class ResidualAttentionBlock(nn.Module):def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):super().__init__()self.attn = nn.MultiheadAttention(d_model, n_head)self.ln_1 = LayerNorm(d_model)self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),("gelu", QuickGELU()),("c_proj", nn.Linear(d_model * 4, d_model))]))self.ln_2 = LayerNorm(d_model)self.attn_mask = attn_mask
初始化方法 __init__
d_model
: 输入特征的维度。n_head
: 多头注意力机制的头数量。attn_mask
: 可选的注意力掩码张量,用于屏蔽某些位置的注意力。
初始化各个子模块
-
多头注意力机制:
self.attn = nn.MultiheadAttention(d_model, n_head)
nn.MultiheadAttention
进行多头注意力计算。
-
第一层归一化层:
self.ln_1 = LayerNorm(d_model)
- 自定义的
LayerNorm
层,用于处理fp16
数据类型。
- 自定义的
-
前馈神经网络:
self.mlp = nn.Sequential(OrderedDict([("c_fc", nn.Linear(d_model, d_model * 4)),("gelu", QuickGELU()),("c_proj", nn.Linear(d_model * 4, d_model)) ]))
- 使用
nn.Sequential
和OrderedDict
构建一个前馈神经网络。 - 包括一个线性层
c_fc
,一个快速近似的 GELU 激活函数QuickGELU
,和另一个线性层c_proj
。
- 使用
-
第二层归一化层:
self.ln_2 = LayerNorm(d_model)
-
注意力掩码:
self.attn_mask = attn_mask
注意力方法 attention
def attention(self, x: torch.Tensor):self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else Nonereturn self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
self.attn_mask.to(dtype=x.dtype, device=x.device)
:- 将注意力掩码的 dtype 和 device 与输入张量
x
保持一致。
- 将注意力掩码的 dtype 和 device 与输入张量
self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
:- 调用多头注意力机制,使用
x
作为查询、键和值。 need_weights=False
表示不返回注意力权重。attn_mask=self.attn_mask
传入可选的注意力掩码。- 返回值的第一个元素是计算后的注意力输出。
- 调用多头注意力机制,使用
前向传播方法 forward
def forward(self, x: torch.Tensor):x = x + self.attention(self.ln_1(x))x = x + self.mlp(self.ln_2(x))return x
-
残差连接和第一部分:
x = x + self.attention(self.ln_1(x))
- 输入
x
先经过第一层归一化ln_1
。 - 然后通过
attention
方法进行多头注意力计算。 - 最后将注意力输出与原始输入
x
相加,形成残差连接。
- 输入
-
残差连接和第二部分:
x = x + self.mlp(self.ln_2(x))
- 输入
x
经过第二层归一化ln_2
。 - 然后通过前馈神经网络
mlp
进行计算。 - 最后将前馈神经网络的输出与输入
x
相加,形成残差连接。
- 输入
-
返回值:
return x
- 返回经过两个残差连接后的结果。
总结
ResidualAttentionBlock
类实现了一个带有残差连接的注意力块,包含以下关键部分:
- 多头注意力机制:用于计算输入特征的注意力。
- 层归一化:用于标准化输入特征,缓解训练中的数值不稳定问题。
- 前馈神经网络:包括两个线性层和一个激活函数,用于进一步处理特征。
- 残差连接:将输入特征与经过处理的特征相加,帮助信息在网络中更好地传播。
这种结构在现代深度学习模型(如 Transformer)中非常常见,能够有效处理长序列依赖关系,提高模型的表现力和训练效率。
Transformer模块
以下是对 Transformer
类的详细注释。这段代码实现了一个基于残差注意力块的 Transformer 模块。
Transformer
类
class Transformer(nn.Module):def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):super().__init__()self.width = widthself.layers = layersself.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
初始化方法 __init__
-
参数解释:
width
: 特征的宽度或嵌入维度。layers
: Transformer 中残差注意力块的层数。heads
: 多头注意力机制的头数量。attn_mask
: 注意力掩码,用于屏蔽特定位置的注意力计算。
-
实例属性:
self.width
: 保存特征宽度。self.layers
: 保存层数。self.resblocks
: 包含多个ResidualAttentionBlock
的序列,每个块由width
,heads
和attn_mask
参数定义。
初始化步骤
-
保存输入参数:
self.width = width self.layers = layers
-
定义残差注意力块序列:
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
- 使用列表生成式创建多个
ResidualAttentionBlock
实例。 - 使用
nn.Sequential
将这些块组合成一个顺序模块,按顺序依次执行每个块的前向传播。
- 使用列表生成式创建多个
前向传播方法 forward
def forward(self, x: torch.Tensor):return self.resblocks(x)
前向传播步骤
-
输入参数:
x
: 输入张量,形状为(batch_size, sequence_length, embedding_dim)
。
-
前向传播过程:
return self.resblocks(x)
- 直接将输入
x
传递给self.resblocks
。 self.resblocks
是一个nn.Sequential
模块,它包含多个ResidualAttentionBlock
。- 每个
ResidualAttentionBlock
按顺序处理输入张量x
。 - 最终返回经过所有残差注意力块处理后的张量。
- 直接将输入
示例代码
以下是如何使用这个 Transformer
类的示例:
import torch# 假设输入张量的形状为 (batch_size, sequence_length, embedding_dim)
batch_size = 2
sequence_length = 10
embedding_dim = 64# 创建一个示例输入张量
x = torch.randn(batch_size, sequence_length, embedding_dim)# 定义 Transformer 参数
width = embedding_dim
layers = 6
heads = 8# 创建 Transformer 实例
transformer = Transformer(width, layers, heads)# 调用 Transformer 模型
output = transformer(x)print(output.shape) # 输出的形状应与输入相同
总结
Transformer
类实现了一个多层残差注意力块的序列,每个残差注意力块通过多头注意力机制和前馈神经网络处理输入张量。其主要组成部分包括:
width
: 输入特征的宽度或嵌入维度。layers
: 残差注意力块的层数。heads
: 多头注意力机制的头数量。attn_mask
: 注意力掩码,用于屏蔽特定位置的注意力计算。
前向传播方法将输入张量依次传递给每个残差注意力块,最终返回处理后的输出张量。这种结构在处理序列数据(如自然语言处理、时间序列分析)时非常有效。
VisionTransformer
类
以下是对 VisionTransformer
类的详细注释。这个类实现了一个视觉变压器模型(Vision Transformer),用于图像处理任务。
class VisionTransformer(nn.Module):def __init__(self, h_resolution: int, w_resolution: int, patch_size: int, stride_size: int, width: int, layers: int, heads: int, output_dim: int):super().__init__()self.h_resolution = h_resolutionself.w_resolution = w_resolutionself.output_dim = output_dimself.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False)
初始化方法 __init__
- 参数解释:
h_resolution
: 输入图像的高度。w_resolution
: 输入图像的宽度。patch_size
: 每个补丁的大小。stride_size
: 补丁提取的步幅。width
: 嵌入维度或特征的宽度。layers
: Transformer 模块中的层数。heads
: 多头注意力机制的头数量。output_dim
: 模型的输出维度。
初始化步骤
-
保存输入参数:
self.h_resolution = h_resolution self.w_resolution = w_resolution self.output_dim = output_dim
-
定义卷积层:
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, bias=False)
- 使用卷积层提取图像补丁,并将每个补丁嵌入到
width
维度的特征空间中。
- 使用卷积层提取图像补丁,并将每个补丁嵌入到
-
定义嵌入和归一化层:
scale = width ** -0.5 self.class_embedding = nn.Parameter(scale * torch.randn(width)) self.positional_embedding = nn.Parameter(scale * torch.randn(h_resolution * w_resolution + 1, width)) self.ln_pre = LayerNorm(width)
-
定义 Transformer 模块:
self.transformer = Transformer(width, layers, heads)
-
定义后处理层:
self.ln_post = LayerNorm(width) self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
前向传播方法 forward
def forward(self, x: torch.Tensor, cv_emb = None):x = self.conv1(x) # shape = [*, width, grid, grid]x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]if cv_emb != None: x[:,0] = x[:,0] + cv_embx = x + self.positional_embedding.to(x.dtype)x = self.ln_pre(x)x = x.permute(1, 0, 2) # NLD -> LNDx11 = self.transformer.resblocks[:11](x) x12 = self.transformer.resblocks[11](x11) x11 = x11.permute(1, 0, 2) # LND -> NLD x12 = x12.permute(1, 0, 2) # LND -> NLD x12 = self.ln_post(x12) if self.proj is not None:xproj = x12 @ self.proj return x11, x12, xproj
前向传播步骤
-
卷积层提取补丁:
x = self.conv1(x) # shape = [*, width, grid, grid]
- 输入图像
x
通过卷积层conv1
提取补丁,并转换为具有width
维度的特征表示。
- 输入图像
-
调整形状:
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
- 调整张量形状,将其变为
[batch_size, num_patches, width]
。
- 调整张量形状,将其变为
-
添加类别嵌入:
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
- 在特征表示的开头添加一个类别嵌入。
-
处理可选的嵌入:
if cv_emb != None: x[:,0] = x[:,0] + cv_emb
-
添加位置嵌入和归一化:
x = x + self.positional_embedding.to(x.dtype) x = self.ln_pre(x)
-
转换维度并传递给 Transformer 模块:
x = x.permute(1, 0, 2) # NLD -> LND x11 = self.transformer.resblocks[:11](x) x12 = self.transformer.resblocks[11](x11) x11 = x11.permute(1, 0, 2) # LND -> NLD x12 = x12.permute(1, 0, 2) # LND -> NLD
-
后处理:
x12 = self.ln_post(x12) if self.proj is not None:xproj = x12 @ self.proj
-
返回结果:
return x11, x12, xproj
总结
VisionTransformer
类实现了一个视觉变压器模型,通过卷积层提取图像补丁,并使用 Transformer 模块进行特征处理。整个模型包括卷积层、类别嵌入、位置嵌入、多层 Transformer 模块和后处理层。前向传播过程依次经过这些模块,并返回多个层的输出结果。