1,首先,我们需要知道的是,想要调用预训练的Swin Transformer模型,必须要安装pytorch2,因为pytorch1对应的torchvision中不包含Swin Transformer。
2,pytorch2调用预训练模型时,不建议使用pretrained=True,这个用法即将淘汰,会报警告。最好用如下方式:
from torchvision.models.swin_transformer import swin_b, Swin_B_Weights model = swin_b(weights=Swin_B_Weights.DEFAULT)
这里调用的就是swin_b在imagenet上的预训练模型
3,swin_b的模型结构如下(仅展示到第一个patch merging部分),在绝大部分情况下,我们可能需要的不是整个模型,而是其中的一个模块,比如SwinTransformerBlock。
SwinTransformer((features): Sequential((0): Sequential((0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))(1): Permute()(2): LayerNorm((128,), eps=1e-05, elementwise_affine=True))(1): Sequential((0): SwinTransformerBlock((norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(attn): ShiftedWindowAttention((qkv): Linear(in_features=128, out_features=384, bias=True)(proj): Linear(in_features=128, out_features=128, bias=True))(stochastic_depth): StochasticDepth(p=0.0, mode=row)(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(mlp): MLP((0): Linear(in_features=128, out_features=512, bias=True)(1): GELU(approximate='none')(2): Dropout(p=0.0, inplace=False)(3): Linear(in_features=512, out_features=128, bias=True)(4): Dropout(p=0.0, inplace=False)))(1): SwinTransformerBlock((norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(attn): ShiftedWindowAttention((qkv): Linear(in_features=128, out_features=384, bias=True)(proj): Linear(in_features=128, out_features=128, bias=True))(stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(mlp): MLP((0): Linear(in_features=128, out_features=512, bias=True)(1): GELU(approximate='none')(2): Dropout(p=0.0, inplace=False)(3): Linear(in_features=512, out_features=128, bias=True)(4): Dropout(p=0.0, inplace=False))))(2): PatchMerging((reduction): Linear(in_features=512, out_features=256, bias=False)(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True))
那么如何调用其中的SwinTransformerBlock呢。
由于该模型是个嵌套结构,而不是类似vgg一样简单的结构,所以不能直接用layer0=model.SwinTransformerBlock调用。
因为SwinTransformerBlock是Sequential下的子模块,故正确的调用代码如下:
swinblock = model.features[1][0]
结果如下,调用成功: