TansUNet代码理解

首先通过论文中所给的图片了解网络的整体架构:
在这里插入图片描述

vit_seg_modeling部分

模块引入和定义相关量:

# coding=utf-8
# __future__ 在老版本的Python代码中兼顾新特性的一种方法
from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport copy
import logging
import mathfrom os.path import join as pjoinimport torch
import torch.nn as nn
import numpy as npfrom torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2logger = logging.getLogger(__name__)ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"# 获取超参
CONFIGS = {'ViT-B_16': configs.get_b16_config(),'ViT-B_32': configs.get_b32_config(),'ViT-L_16': configs.get_l16_config(),'ViT-L_32': configs.get_l32_config(),'ViT-H_14': configs.get_h14_config(),'R50-ViT-B_16': configs.get_r50_b16_config(),'R50-ViT-L_16': configs.get_r50_l16_config(),'testing': configs.get_testing(),
}

工具函数的定义:
np2th用于将numpy格式的数据改为tensor。

def np2th(weights, conv=False):"""Possibly convert HWIO to OIHW."""if conv:weights = weights.transpose([3, 2, 0, 1])return torch.from_numpy(weights)

swish时由谷歌团队提出来的激活函数,他们实验表明,在一些具有挑战性的数据集上,它的效果比relu更好。

def swish(x):return x * torch.sigmoid(x)ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}

采用自顶向下的结构来理解代码
VisionTransformer就是模型的整个结构,其中调用了Transformer,DecoderCup,SegmentationHead,load_from用于加载训练好的参数。

class VisionTransformer(nn.Module):def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):super(VisionTransformer, self).__init__()self.num_classes = num_classesself.zero_head = zero_headself.classifier = config.classifierself.transformer = Transformer(config, img_size, vis)self.decoder = DecoderCup(config)self.segmentation_head = SegmentationHead(in_channels=config['decoder_channels'][-1],out_channels=config['n_classes'],kernel_size=3,)self.config = configdef forward(self, x):if x.size()[1] == 1:x = x.repeat(1, 3, 1, 1)x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)x = self.decoder(x, features)logits = self.segmentation_head(x)return logitsdef load_from(self, weights):# with torch.no_grad()将所有require_grad临时设置为False,这样可以只更新变量的值with torch.no_grad():res_weight = weightsself.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])posemb_new = self.transformer.embeddings.position_embeddingsif posemb.size() == posemb_new.size():self.transformer.embeddings.position_embeddings.copy_(posemb)elif posemb.size()[1] - 1 == posemb_new.size()[1]:posemb = posemb[:, 1:]self.transformer.embeddings.position_embeddings.copy_(posemb)else:logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))ntok_new = posemb_new.size(1)if self.classifier == "seg":_, posemb_grid = posemb[:, :1], posemb[0, 1:]gs_old = int(np.sqrt(len(posemb_grid)))gs_new = int(np.sqrt(ntok_new))print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)zoom = (gs_new / gs_old, gs_new / gs_old, 1)posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2npposemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)posemb = posemb_gridself.transformer.embeddings.position_embeddings.copy_(np2th(posemb))# Encoder wholefor bname, block in self.transformer.encoder.named_children():for uname, unit in block.named_children():unit.load_from(weights, n_block=uname)if self.transformer.embeddings.hybrid:self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))# .view(-1)将tensor展开为一维张量,但不改变该对象本身的形状gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():for uname, unit in block.named_children():unit.load_from(res_weight, n_block=bname, n_unit=uname)

接下来是Transformer的代码:
Transformer包括了Embeddings和Encoder:

class Transformer(nn.Module):def __init__(self, config, img_size, vis):super(Transformer, self).__init__()self.embeddings = Embeddings(config, img_size=img_size)self.encoder = Encoder(config, vis)def forward(self, input_ids):embedding_output, features = self.embeddings(input_ids)encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)return encoded, attn_weights, features

Embeddings的功能对应于图片中的:
在这里插入图片描述
ResNetV2(这部分的代码放在最后一个部分)对图片通过卷积操作提取特征,然后将提取到的各层特征返回到Embeddings。
拿到ResNetV2返回的特征后,将最后一层的特征分割为多个切片,并将各个切片映射成长度为patch_size*patch_size*channels的向量,并且加上位置序列信息,对应于图片的这个部分:
在这里插入图片描述

class Embeddings(nn.Module):"""Construct the embeddings from patch, position embeddings."""def __init__(self, config, img_size, in_channels=3):super(Embeddings, self).__init__()self.hybrid = Noneself.config = config# 应该是把参数中的img_size,转换为元组形式即:img_size = (value,value)这里的value即为参数的img_size。img_size = _pair(img_size)if config.patches.get("grid") is not None:  # ResNetgrid_size = config.patches["grid"]  # grid 是一个元组,值为:输入图片大小//切片大小patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])self.hybrid = Trueelse:patch_size = _pair(config.patches["size"])n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])self.hybrid = Falseif self.hybrid:self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)in_channels = self.hybrid_model.width * 16# patch_embeddings通过卷积操作将输入转变为(B, hidden_size, n_patches^(1/2), n_patches^(1/2))# hidden_size是一个token(相当于输入的一个词)的长度self.patch_embeddings = Conv2d(in_channels=in_channels,out_channels=config.hidden_size,kernel_size=patch_size,stride=patch_size)# 各个向量的位置序列self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))self.dropout = Dropout(config.transformer["dropout_rate"])def forward(self, x):if self.hybrid:x, features = self.hybrid_model(x)else:features = Nonex = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))x = x.flatten(2)  # 表示从2维开始压缩,得到(B, hidden, n_patches)x = x.transpose(-1, -2)  # 对最后两个维度进行转置(B, n_patches, hidden)embeddings = x + self.position_embeddings  # 加上位置序列embeddings = self.dropout(embeddings)return embeddings, features

Encoder是图像的编码部分,根据num_layers生成多个Block模块

class Encoder(nn.Module):def __init__(self, config, vis):super(Encoder, self).__init__()self.vis = vis# nn.ModuleList()一个module列表,与普通的list相比,它继承了nn.Module的网络模型class,因此可以识别其中的parameters,# 即该列表中记录的module可以被主module识别,但它只是一个list,不会自动实现forward方法。self.layer = nn.ModuleList()self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)for _ in range(config.transformer["num_layers"]):layer = Block(config, vis)self.layer.append(copy.deepcopy(layer))def forward(self, hidden_states):attn_weights = []for layer_block in self.layer:hidden_states, weights = layer_block(hidden_states)if self.vis:attn_weights.append(weights)encoded = self.encoder_norm(hidden_states)return encoded, attn_weights

Block包括了MSA(Multihead Self-Attention)和MSA(Multi-Layer Perceptron)两个结构,对应于图像中的:
在这里插入图片描述

class Block(nn.Module):def __init__(self, config, vis):super(Block, self).__init__()self.hidden_size = config.hidden_sizeself.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)self.ffn = Mlp(config)self.attn = Attention(config, vis)def forward(self, x):h = xx = self.attention_norm(x)x, weights = self.attn(x)x = x + hh = xx = self.ffn_norm(x)x = self.ffn(x)x = x + hreturn x, weightsdef load_from(self, weights, n_block):ROOT = f"Transformer/encoderblock_{n_block}"with torch.no_grad():query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size,self.hidden_size).t()key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size,self.hidden_size).t()out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size,self.hidden_size).t()query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)self.attn.query.weight.copy_(query_weight)self.attn.key.weight.copy_(key_weight)self.attn.value.weight.copy_(value_weight)self.attn.out.weight.copy_(out_weight)self.attn.query.bias.copy_(query_bias)self.attn.key.bias.copy_(key_bias)self.attn.value.bias.copy_(value_bias)self.attn.out.bias.copy_(out_bias)mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()self.ffn.fc1.weight.copy_(mlp_weight_0)self.ffn.fc2.weight.copy_(mlp_weight_1)self.ffn.fc1.bias.copy_(mlp_bias_0)self.ffn.fc2.bias.copy_(mlp_bias_1)self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))

Attention对应图中的MSA部分,num_heads即为多头注意力机制的数量,attention_head_size为每个注意力机制的输出大小。Multihead self-attention 就是采用多个注意力机制来预测,但实现时并不是采用循环来实现多次,由于每个注意力机制采用相同的策略,他们只存在学习到的参数的差异,所以可以直接学习一个大的参数矩阵,我的理解如下图所示:
在这里插入图片描述

class Attention(nn.Module):def __init__(self, config, vis):super(Attention, self).__init__()self.vis = visself.num_attention_heads = config.transformer["num_heads"]self.attention_head_size = int(config.hidden_size / self.num_attention_heads)self.all_head_size = self.num_attention_heads * self.attention_head_sizeself.query = Linear(config.hidden_size, self.all_head_size)self.key = Linear(config.hidden_size, self.all_head_size)self.value = Linear(config.hidden_size, self.all_head_size)self.out = Linear(config.hidden_size, config.hidden_size)self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])self.softmax = Softmax(dim=-1)def transpose_for_scores(self, x):# new_x_shape (B, n_patch, num_attention_heads, attention_head_size)new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)# view()方法主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensorx = x.view(*new_x_shape)# permute可以对任意高维矩阵进行转置,transpose只能操作2D矩阵的转置return x.permute(0, 2, 1, 3)  # return (B, num_attention_heads, n_patch, attention_head_size)def forward(self, hidden_states):# hidden_states (B, n_patch, hidden)# mixed_*  (B, n_patch, all_head_size)mixed_query_layer = self.query(hidden_states)mixed_key_layer = self.key(hidden_states)mixed_value_layer = self.value(hidden_states)query_layer = self.transpose_for_scores(mixed_query_layer)key_layer = self.transpose_for_scores(mixed_key_layer)value_layer = self.transpose_for_scores(mixed_value_layer)# torch.matmul矩阵相乘# key_layer.transpose(-1, -2): (B, num_attention_heads, attention_head_size, n_patch)# attention_scores: (B, num_attention_heads, n_patch, n_patch)attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))attention_scores = attention_scores / math.sqrt(self.attention_head_size)attention_probs = self.softmax(attention_scores)weights = attention_probs if self.vis else Noneattention_probs = self.attn_dropout(attention_probs)# context_layer (B, num_attention_heads, n_patch, attention_head_size)context_layer = torch.matmul(attention_probs, value_layer)# context_layer (B, n_patch, num_attention_heads, attention_head_size)# contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形context_layer = context_layer.permute(0, 2, 1, 3).contiguous()# new_context_layer_shape (B, n_patch,all_head_size)new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)context_layer = context_layer.view(*new_context_layer_shape)attention_output = self.out(context_layer)# attention_output (B, n_patch,hidden_size)# 小细节 attention_head_size = int(hidden_size / num_attention_heads),all_head_size = num_attention_heads * attention_head_size# 所以应该满足hidden_size能被num_attention_heads整除attention_output = self.proj_dropout(attention_output)return attention_output, weights

Mlp也就是一个前馈神经网络

class Mlp(nn.Module):"""Multi-Layer Perceptron: 多层感知器"""def __init__(self, config):super(Mlp, self).__init__()self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)self.act_fn = ACT2FN["gelu"]self.dropout = Dropout(config.transformer["dropout_rate"])self._init_weights()def _init_weights(self):# nn.init.xavier_uniform_初始化权重,避免深度神经网络训练过程中的梯度消失和梯度爆炸问题nn.init.xavier_uniform_(self.fc1.weight)nn.init.xavier_uniform_(self.fc2.weight)# nn.init.normal_是正态初始化函数nn.init.normal_(self.fc1.bias, std=1e-6)nn.init.normal_(self.fc2.bias, std=1e-6)def forward(self, x):x = self.fc1(x)x = self.act_fn(x)x = self.dropout(x)x = self.fc2(x)x = self.dropout(x)return x

至此,Transformer所调用的模块结束了。


DecoderCup 对对应图片向上解码的部分:
在这里插入图片描述

在forward函数中的

B, n_patch, hidden = hidden_states.size()  # hidden_states: (B, n_patch, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)  # x: (B, hidden, n_patch)
x = x.contiguous().view(B, hidden, h, w)  # x: (B, hidden, h, w)
x = self.conv_more(x)  # (B, hidden, h, w) ===> (B, 512, h', w')

将Transformer的输出(B, n_patch, hidden),先转化为(B, hidden, h, w),其中 h , w = n _ p a t c h = H 16 = W 16 h,w = \sqrt{n\_patch} = \frac{H}{16}= \frac{W}{16} h,w=n_patch =16H=16W ,即:
在这里插入图片描述
然后通过卷积操作conv_more得到(512, hidden, h, w):
在这里插入图片描述

class DecoderCup(nn.Module):def __init__(self, config):super().__init__()self.config = confighead_channels = 512self.conv_more = Conv2dReLU(config.hidden_size,head_channels,kernel_size=3,padding=1,use_batchnorm=True,)decoder_channels = config.decoder_channels  # decoder_channels (256, 128, 64, 16)in_channels = [head_channels] + list(decoder_channels[:-1])  # in_channels = [512, 256, 128, 64]out_channels = decoder_channels# config.n_skip = 3if self.config.n_skip != 0:skip_channels = self.config.skip_channels  # config.skip_channels = [512, 256, 64, 16]for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skipskip_channels[3 - i] = 0  # ===》skip_channels = [512, 256, 64, 0]else:skip_channels = [0, 0, 0, 0]# in_channels = [512, 256, 128, 64] out_channels = (256, 128, 64, 16)blocks = [DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)]self.blocks = nn.ModuleList(blocks)def forward(self, hidden_states, features=None):B, n_patch, hidden = hidden_states.size()  # hidden_states: (B, n_patch, hidden)h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))x = hidden_states.permute(0, 2, 1)  # x: (B, hidden, n_patch)x = x.contiguous().view(B, hidden, h, w)  # x: (B, hidden, h, w)x = self.conv_more(x)  # (B, hidden, h, w) ===> (B, 512, h, w)for i, decoder_block in enumerate(self.blocks):if features is not None:skip = features[i] if (i < self.config.n_skip) else Noneelse:skip = Nonex = decoder_block(x, skip=skip)return x

DecoderBlock就是逐层向上解码的过程,首先通过插值上采样UpsamplingBilinear2d扩大H和W,随后与对应的feature进行拼接后进行卷积,即:
在这里插入图片描述

class DecoderBlock(nn.Module):def __init__(self,in_channels,out_channels,skip_channels=0,use_batchnorm=True,):super().__init__()self.conv1 = Conv2dReLU(in_channels + skip_channels,out_channels,kernel_size=3,padding=1,use_batchnorm=use_batchnorm,)self.conv2 = Conv2dReLU(out_channels,out_channels,kernel_size=3,padding=1,use_batchnorm=use_batchnorm,)self.up = nn.UpsamplingBilinear2d(scale_factor=2)def forward(self, x, skip=None):x = self.up(x)if skip is not None:x = torch.cat([x, skip], dim=1)x = self.conv1(x)x = self.conv2(x)return x

SegmentationHead对应于图像分割部分:
在这里插入图片描述
nn.Identity()不对输入进行任何操作,常在分类任务中替换最后一层,得到分类前得到的特征,常用于迁移学习,用法举例:

model = models.resnet18()
# replace last linar layer with nn.Identity
model.fc = nn.Identity()# get features for input
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)
> torch.Size([1, 512])

SegmentationHead模块:

class SegmentationHead(nn.Sequential):def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()super().__init__(conv2d, upsampling)

最后是ResNetV2模块,该模块在vit_seg_modeling_resnet_skip文件中,对应图片中的:
在这里插入图片描述
该模块的相关包及其工具函数:

import mathfrom os.path import join as pjoin
from collections import OrderedDictimport torch
import torch.nn as nn
import torch.nn.functional as Fdef np2th(weights, conv=False):"""Possibly convert HWIO to OIHW."""if conv:weights = weights.transpose([3, 2, 0, 1])return torch.from_numpy(weights)class StdConv2d(nn.Conv2d):def forward(self, x):w = self.weightv, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)w = (w - m) / torch.sqrt(v + 1e-5)return F.conv2d(x, w, self.bias, self.stride, self.padding,self.dilation, self.groups)def conv3x3(cin, cout, stride=1, groups=1, bias=False):return StdConv2d(cin, cout, kernel_size=3, stride=stride,padding=1, bias=bias, groups=groups)def conv1x1(cin, cout, stride=1, bias=False):return StdConv2d(cin, cout, kernel_size=1, stride=stride,padding=0, bias=bias)
class ResNetV2(nn.Module):"""Implementation of Pre-activation (v2) ResNet mode."""def __init__(self, block_units, width_factor):super().__init__()width = int(64 * width_factor)self.width = widthself.root = nn.Sequential(OrderedDict([('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),('gn', nn.GroupNorm(32, width, eps=1e-6)),('relu', nn.ReLU(inplace=True)),# ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))]))self.body = nn.Sequential(OrderedDict([('block1', nn.Sequential(OrderedDict([('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +[(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],))),('block2', nn.Sequential(OrderedDict([('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +[(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],))),('block3', nn.Sequential(OrderedDict([('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +[(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],))),]))def forward(self, x):features = []b, c, in_size, _ = x.size()x = self.root(x)features.append(x)x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)for i in range(len(self.body)-1):x = self.body[i](x)right_size = int(in_size / 4 / (i+1))if x.size()[2] != right_size:pad = right_size - x.size()[2]assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]else:feat = xfeatures.append(feat)x = self.body[-1](x)return x, features[::-1]
class PreActBottleneck(nn.Module):"""Pre-activation (v2) bottleneck block."""def __init__(self, cin, cout=None, cmid=None, stride=1):super().__init__()cout = cout or cincmid = cmid or cout//4self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)self.conv1 = conv1x1(cin, cmid, bias=False)self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)self.conv3 = conv1x1(cmid, cout, bias=False)self.relu = nn.ReLU(inplace=True)if (stride != 1 or cin != cout):# Projection also with pre-activation according to paper.self.downsample = conv1x1(cin, cout, stride, bias=False)self.gn_proj = nn.GroupNorm(cout, cout)def forward(self, x):# Residual branchresidual = xif hasattr(self, 'downsample'):residual = self.downsample(x)residual = self.gn_proj(residual)# Unit's branchy = self.relu(self.gn1(self.conv1(x)))y = self.relu(self.gn2(self.conv2(y)))y = self.gn3(self.conv3(y))y = self.relu(residual + y)return ydef load_from(self, weights, n_block, n_unit):conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])self.conv1.weight.copy_(conv1_weight)self.conv2.weight.copy_(conv2_weight)self.conv3.weight.copy_(conv3_weight)self.gn1.weight.copy_(gn1_weight.view(-1))self.gn1.bias.copy_(gn1_bias.view(-1))self.gn2.weight.copy_(gn2_weight.view(-1))self.gn2.bias.copy_(gn2_bias.view(-1))self.gn3.weight.copy_(gn3_weight.view(-1))self.gn3.bias.copy_(gn3_bias.view(-1))if hasattr(self, 'downsample'):proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])self.downsample.weight.copy_(proj_conv_weight)self.gn_proj.weight.copy_(proj_gn_weight.view(-1))self.gn_proj.bias.copy_(proj_gn_bias.view(-1))

由于只有在hybrid模式下才用到这部分的代码,所以目前并没有去了解为什么采用StdConv2d和GroupNorm,后面再去ViT里面找答案吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/36653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新基建助推数字经济,CosmosAI率先布局AI超算租赁新纪元

伦敦, 8月14日 - 在英国伦敦隆重的Raffles OWO举办的欧盟数字超算新时代战略合作签约仪式&#xff0c;CosmosAI、Infinite Money Fund与Internet Research Lab三方强强联手&#xff0c;达成了历史性的合作协议&#xff0c;共同迈向超算租赁新纪元。 ​ 这次跨界的合作昭示了全球…

VR家装提升用户信任度,线上体验家装空间感

近些年&#xff0c;VR家装逐渐被各大装修公司引入&#xff0c;VR全景装修的盛行&#xff0c;大大增加了客户“所见即所得”的沉浸式体验感&#xff0c;不再是传统二维平面的看房模式&#xff0c;而是让客户通过视觉、听觉、交互等功能更加真实的体验家装后的效果。 对于传统家装…

BUUCTF 还原大师 1

题目描述&#xff1a; 我们得到了一串神秘字符串&#xff1a;TASC?O3RJMV?WDJKX?ZM,问号部分是未知大写字母&#xff0c;为了确定这个神秘字符串&#xff0c;我们通过了其他途径获得了这个字串的32位MD5码。但是我们获得它的32位MD5码也是残缺不全&#xff0c;E903???4D…

【Vue3】自动引入插件-`unplugin-auto-import`

Vue3自动引入插件-unplugin-auto-import&#xff0c;不必再手动 import 。 自动导入 api 按需为 Vite, Webpack, Rspack, Rollup 和 esbuild 。支持TypeScript。由unplugin驱动。 插件安装&#xff1a;unplugin-auto-import 配置vite.config.ts&#xff08;配置完后需要重启…

迪瑞克斯拉算法 — 优化

在上一篇迪瑞克斯拉算法中将功能实现了出来&#xff0c;完成了图集中从源点出发获取所有可达的点的最短距离的收集。 但在代码中getMinDistanceAndUnSelectNode()方法的实现并不简洁&#xff0c;每次获取minNode时&#xff0c;都需要遍历整个Map&#xff0c;时间复杂度太高。这…

HTML详解连载(5)

HTML详解连载&#xff08;5&#xff09; 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽行高&#xff1a;设置多行文本的间距属性名属性值行高的测量方法 行高-垂直居中技巧 字体族属性名属性值示例扩展 font 复合属性使用场景复合属性示例注意 文本缩进属性…

UG NX二次开发(C#)-CAM自定义铣加工的出口环境

文章目录 1、前言2、自定义铣削加工操作3、出错原因4、解决方案4.1 MILL_USER的用户参数4.2 采用自定义铣削的方式生成自定义的dll4.2 配置加工的出口环境4.3 调用dll5、结论1、前言 作为一款大型的CAD/CAM软件, UG NX为我们提供了丰富的加工模板,通过加工模板能直接用于生成…

DTC服务(0x14 0x19 0x85)

DTC相关的服务有ReadDTCInformation (19) service&#xff0c;ControlDTCSetting (85) service和ReadDTCInformation (19) service ReadDTCInformation (19) service 该服务允许客户端从车辆内任意一台服务器或一组服务器中读取驻留在服务器中的诊断故障代码( DTC )信息的状态…

Web菜鸟教程 - Radis实现高性能数据库

Redis是用C语言开发的一个高性能键值对数据库&#xff0c;可用于数据缓存&#xff0c;主要用于处理大量数据的高访问负载。 也就是说&#xff0c;如果你对性能要求不高&#xff0c;不用Radis也是可以的。不过作为最自己写的程序有高要求的程序员&#xff0c;自然是要学一下的&a…

PHP Mysql查询全部全部返回字符串类型

设置pdo属性 $pdo->setAttribute(PDO::ATTR_EMULATE_PREPARES, true);

08-1_Qt 5.9 C++开发指南_QPainter绘图

文章目录 前言1. QPainter 绘图系统1.1 QPainter 与QPaintDevice1.2 paintEvent事件和绘图区1.3 QPainter 绘图的主要属性 2. QPen的主要功能3. QBrush的主要功能4. 渐变填充5. QPainter 绘制基本图形元件5.1 基本图像元件5.2 QpainterPath的使用 前言 本章所介绍内容基本在《…

chatserver服务器开发笔记

chatserver服务器开发笔记 1 chatserver2 开发环境3 编译 1 chatserver 集群聊天服务器和客户端代码&#xff0c;基于muduo、redis、mysql实现。 学习于https://fixbug.ke.qq.com/ 本人已经挂github&#xff1a;https://github.com/ZixinChen-S/chatserver/tree/main 需要该项…

kubernetes pod 资源限制 探针

资源限制 当定义 Pod 时可以选择性地为每个容器设定所需要的资源数量。 最常见的可设定资源是 CPU 和内存大小&#xff0c;以及其他类型的资源。 当为 Pod 中的容器指定了 request 资源时&#xff0c;代表容器运行所需的最小资源量&#xff0c;调度器就使用该信息来决定将 Pod …

Java课题笔记~ JSP开发模型

MVC 1.JSP演化历史 1. 早期只有servlet&#xff0c;只能使用response输出标签数据&#xff0c;非常麻烦 2. 后来有了jsp&#xff0c;简化了Servlet的开发&#xff0c;如果过度使用jsp&#xff0c;在jsp中即写大量的java代码&#xff0c;有写html表&#xff0c;造成难于维护&…

计算机网络实验4:HTTP、DNS协议分析

文章目录 1. 主要教学内容2. HTTP协议3. HTTP分析实验【实验目的】【实验原理】【实验内容】【实验思考】 4. HTTP分析实验可能遇到的问题4.1 捕捉不到http报文4.2 百度是使用HTTPS协议进行传输4.3 Wireshark获得数据太多如何筛选4.4 http报文字段含义不清楚General&#xff08…

Git和GitHub

文章目录 1.Git介绍2. 常用命令3. Git分支操作4. Git团队协作机制5. GitHub操作6. IDEA集成Git7.IDEA操作GitHub8. Gitee 1.Git介绍 Git免费的开源的分布式版本控制系统&#xff0c;可以快速高效从小到大的各种项目 Git易于学习&#xff0c;占地面积小&#xff0c;性能快。它…

【日常积累】RPM包依赖下载及私有yum仓库搭建

概述 某些时候&#xff0c;我们需要下载某个RPM包依赖的依赖。如某些内网环境&#xff0c;就需要自行准备rpm包。可以通过能上互联网的服务器进行相应的rpm包下载&#xff0c;然后在拷贝到相应的服务器安装&#xff0c;或者搭建自己的内容rpm包仓库。 查看*.rpm 包依赖&#…

Flink多流处理之Broadcast(广播变量)

写过Spark批处理的应该都知道,有一个广播变量broadcast这样的一个算子,可以优化我们计算的过程,有效的提高效率;同样在Flink中也有broadcast,简单来说和Spark中的类似,但是有所区别,首先Spark中的broadcast是静态的数据,而Flink中的broadcast是动态的,也就是源源不断的数据流.在…

什么是微服务?

2.微服务的优缺点 优点 单一职责原则每个服务足够内聚&#xff0c;足够小&#xff0c;代码容易理解&#xff0c;这样能聚焦一个指定的业务功能或业务需求&#xff1b;开发简单&#xff0c;开发效率提高&#xff0c;一个服务可能就是专一的只干一件事&#xff1b;微服务能够被小…

命令提示符之操作基础(Windows)

打开命令提示符 方法一 打开指定文件的文件夹&#xff0c;在路径栏里输入“cmd”&#xff0c;回车&#xff0c;就进入控制台了。默认路径就是指定文件夹的路径。 方法二 打开指定的文件夹&#xff0c;按住shift键&#xff0c;在空白处右击&#xff0c;在菜单栏中选择“在此处打…