【ECCV2022】Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation

Swin-Unet: Unet-like Pure Transformer for Medical Image Segmentation

论文:https://arxiv.org/abs/2105.05537

代码:https://github.com/HuCaoFighting/Swin-Unet

解读:Swin-UNet:基于纯 Transformer 结构的语义分割网络 - 知乎 (zhihu.com)

介绍

文如题名,本文使用纯Transformer构建Unet网络,用于医学图像分割。

本文用Swin Transformer替换Unet全部结构,构建出Swin-UNet。

方法

Swin-UNet 

Swin-UNet由Encoder、Bottleneck、Decoder和跳跃连接组成。

  • Encoder,输入图像先进行patch partition,减小长宽,再经过linear embedding和两个Swin Transformer block,再通过patch merging进行下采样。经过多次,每次按2倍进行下采样。
  • Bottleneck,用了两个连续的Swin Transformer block.
  • Decoder。主要由patch expanding来实现上采样,作为一个完全对称的网络结构,解码器也是每次扩大2倍进行上采样,核心模块由Swin Transformer block和patch expanding组成。

 Swin transformer block

一个Swin Transformer block由一个W-MSA和一个SW-MSA组成。 

使用带有偏移窗口的分层Swin Transformer作为编码器来提取上下文特征。 

并设计了具有补丁扩展层的基于对称Swin Transformer的解码器来执行上采样操作,以恢复特征图的空间分辨率。

关键代码

Swin transformer block

# https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.pyclass SwinTransformerBlock(nn.Module):r""" Swin Transformer Block.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resulotion.num_heads (int): Number of attention heads.window_size (int): Window size.shift_size (int): Shift size for SW-MSA.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioif min(self.input_resolution) <= self.window_size:# if window size is larger than input resolution, we don't partition windowsself.shift_size = 0self.window_size = min(self.input_resolution)assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)if self.shift_size > 0:# calculate attention mask for SW-MSAH, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))else:attn_mask = Noneself.register_buffer("attn_mask", attn_mask)def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"def flops(self):flops = 0H, W = self.input_resolution# norm1flops += self.dim * H * W# W-MSA/SW-MSAnW = H * W / self.window_size / self.window_sizeflops += nW * self.attn.flops(self.window_size * self.window_size)# mlpflops += 2 * H * W * self.dim * self.dim * self.mlp_ratio# norm2flops += self.dim * H * Wreturn flops

Swin-UNet

# https://github.com/HuCaoFighting/Swin-Unet/blob/main/networks/swin_transformer_unet_skip_expand_decoder_sys.pyimport torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple, trunc_normal_class Mlp(nn.Module):def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresself.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.fc2 = nn.Linear(hidden_features, out_features)self.drop = nn.Dropout(drop)def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop(x)x = self.fc2(x)x = self.drop(x)return xdef window_partition(x, window_size):"""Args:x: (B, H, W, C)window_size (int): window sizeReturns:windows: (num_windows*B, window_size, window_size, C)"""B, H, W, C = x.shapex = x.view(B, H // window_size, window_size, W // window_size, window_size, C)windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)return windowsdef window_reverse(windows, window_size, H, W):"""Args:windows: (num_windows*B, window_size, window_size, C)window_size (int): Window sizeH (int): Height of imageW (int): Width of imageReturns:x: (B, H, W, C)"""B = int(windows.shape[0] / (H * W / window_size / window_size))x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)return xclass WindowAttention(nn.Module):r""" Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.Args:dim (int): Number of input channels.window_size (tuple[int]): The height and width of the window.num_heads (int): Number of attention heads.qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if setattn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_size  # Wh, Wwself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5# define a parameter table of relative position biasself.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH# get pair-wise relative position index for each token inside the windowcoords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Wwcoords_flatten = torch.flatten(coords, 1)  # 2, Wh*Wwrelative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Wwrelative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Wwself.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)def forward(self, x, mask=None):"""Args:x: input features with shape of (num_windows*B, N, C)mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None"""B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return xdef extra_repr(self) -> str:return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'def flops(self, N):# calculate flops for 1 window with token length of Nflops = 0# qkv = self.qkv(x)flops += N * self.dim * 3 * self.dim# attn = (q @ k.transpose(-2, -1))flops += self.num_heads * N * (self.dim // self.num_heads) * N#  x = (attn @ v)flops += self.num_heads * N * N * (self.dim // self.num_heads)# x = self.proj(x)flops += N * self.dim * self.dimreturn flopsclass SwinTransformerBlock(nn.Module):r""" Swin Transformer Block.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resulotion.num_heads (int): Number of attention heads.window_size (int): Window size.shift_size (int): Shift size for SW-MSA.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.num_heads = num_headsself.window_size = window_sizeself.shift_size = shift_sizeself.mlp_ratio = mlp_ratioif min(self.input_resolution) <= self.window_size:# if window size is larger than input resolution, we don't partition windowsself.shift_size = 0self.window_size = min(self.input_resolution)assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)if self.shift_size > 0:# calculate attention mask for SW-MSAH, W = self.input_resolutionimg_mask = torch.zeros((1, H, W, 1))  # 1 H W 1h_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))w_slices = (slice(0, -self.window_size),slice(-self.window_size, -self.shift_size),slice(-self.shift_size, None))cnt = 0for h in h_slices:for w in w_slices:img_mask[:, h, w, :] = cntcnt += 1mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1mask_windows = mask_windows.view(-1, self.window_size * self.window_size)attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))else:attn_mask = Noneself.register_buffer("attn_mask", attn_mask)def forward(self, x):H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = xx = self.norm1(x)x = x.view(B, H, W, C)# cyclic shiftif self.shift_size > 0:shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))else:shifted_x = x# partition windowsx_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, Cx_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C# W-MSA/SW-MSAattn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C# merge windowsattn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C# reverse cyclic shiftif self.shift_size > 0:x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))else:x = shifted_xx = x.view(B, H * W, C)# FFNx = shortcut + self.drop_path(x)x = x + self.drop_path(self.mlp(self.norm2(x)))return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"def flops(self):flops = 0H, W = self.input_resolution# norm1flops += self.dim * H * W# W-MSA/SW-MSAnW = H * W / self.window_size / self.window_sizeflops += nW * self.attn.flops(self.window_size * self.window_size)# mlpflops += 2 * H * W * self.dim * self.dim * self.mlp_ratio# norm2flops += self.dim * H * Wreturn flopsclass PatchMerging(nn.Module):r""" Patch Merging Layer.Args:input_resolution (tuple[int]): Resolution of input feature.dim (int): Number of input channels.norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)self.norm = norm_layer(4 * dim)def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionB, L, C = x.shapeassert L == H * W, "input feature has wrong size"assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."x = x.view(B, H, W, C)x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 Cx1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 Cx2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 Cx3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 Cx = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*Cx = x.view(B, -1, 4 * C)  # B H/2*W/2 4*Cx = self.norm(x)x = self.reduction(x)return xdef extra_repr(self) -> str:return f"input_resolution={self.input_resolution}, dim={self.dim}"def flops(self):H, W = self.input_resolutionflops = H * W * self.dimflops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dimreturn flopsclass PatchExpand(nn.Module):def __init__(self, input_resolution, dim, dim_scale=2, norm_layer=nn.LayerNorm):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.expand = nn.Linear(dim, 2*dim, bias=False) if dim_scale==2 else nn.Identity()self.norm = norm_layer(dim // dim_scale)def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionx = self.expand(x)B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//4)x = x.view(B,-1,C//4)x= self.norm(x)return xclass FinalPatchExpand_X4(nn.Module):def __init__(self, input_resolution, dim, dim_scale=4, norm_layer=nn.LayerNorm):super().__init__()self.input_resolution = input_resolutionself.dim = dimself.dim_scale = dim_scaleself.expand = nn.Linear(dim, 16*dim, bias=False)self.output_dim = dim self.norm = norm_layer(self.output_dim)def forward(self, x):"""x: B, H*W, C"""H, W = self.input_resolutionx = self.expand(x)B, L, C = x.shapeassert L == H * W, "input feature has wrong size"x = x.view(B, H, W, C)x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//(self.dim_scale**2))x = x.view(B,-1,self.output_dim)x= self.norm(x)return xclass BasicLayer(nn.Module):""" A basic Swin Transformer layer for one stage.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resolution.depth (int): Number of blocks.num_heads (int): Number of attention heads.window_size (int): Local window size.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNormdownsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: Noneuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: False."""def __init__(self, dim, input_resolution, depth, num_heads, window_size,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.depth = depthself.use_checkpoint = use_checkpoint# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim, input_resolution=input_resolution,num_heads=num_heads, window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop, attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer)for i in range(depth)])# patch merging layerif downsample is not None:self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)else:self.downsample = Nonedef forward(self, x):for blk in self.blocks:if self.use_checkpoint:x = checkpoint.checkpoint(blk, x)else:x = blk(x)if self.downsample is not None:x = self.downsample(x)return xdef extra_repr(self) -> str:return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"def flops(self):flops = 0for blk in self.blocks:flops += blk.flops()if self.downsample is not None:flops += self.downsample.flops()return flopsclass BasicLayer_up(nn.Module):""" A basic Swin Transformer layer for one stage.Args:dim (int): Number of input channels.input_resolution (tuple[int]): Input resolution.depth (int): Number of blocks.num_heads (int): Number of attention heads.window_size (int): Local window size.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.drop (float, optional): Dropout rate. Default: 0.0attn_drop (float, optional): Attention dropout rate. Default: 0.0drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNormupsample (nn.Module | None, optional): upsample layer at the end of the layer. Default: Noneuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: False."""def __init__(self, dim, input_resolution, depth, num_heads, window_size,mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., norm_layer=nn.LayerNorm, upsample=None, use_checkpoint=False):super().__init__()self.dim = dimself.input_resolution = input_resolutionself.depth = depthself.use_checkpoint = use_checkpoint# build blocksself.blocks = nn.ModuleList([SwinTransformerBlock(dim=dim, input_resolution=input_resolution,num_heads=num_heads, window_size=window_size,shift_size=0 if (i % 2 == 0) else window_size // 2,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop, attn_drop=attn_drop,drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,norm_layer=norm_layer)for i in range(depth)])# patch merging layerif upsample is not None:self.upsample = PatchExpand(input_resolution, dim=dim, dim_scale=2, norm_layer=norm_layer)else:self.upsample = Nonedef forward(self, x):for blk in self.blocks:if self.use_checkpoint:x = checkpoint.checkpoint(blk, x)else:x = blk(x)if self.upsample is not None:x = self.upsample(x)return xclass PatchEmbed(nn.Module):r""" Image to Patch EmbeddingArgs:img_size (int): Image size.  Default: 224.patch_size (int): Patch token size. Default: 4.in_chans (int): Number of input image channels. Default: 3.embed_dim (int): Number of linear projection output channels. Default: 96.norm_layer (nn.Module, optional): Normalization layer. Default: None"""def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):super().__init__()img_size = to_2tuple(img_size)patch_size = to_2tuple(patch_size)patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]self.img_size = img_sizeself.patch_size = patch_sizeself.patches_resolution = patches_resolutionself.num_patches = patches_resolution[0] * patches_resolution[1]self.in_chans = in_chansself.embed_dim = embed_dimself.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = Nonedef forward(self, x):B, C, H, W = x.shape# FIXME look at relaxing size constraintsassert H == self.img_size[0] and W == self.img_size[1], \f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw Cif self.norm is not None:x = self.norm(x)return xdef flops(self):Ho, Wo = self.patches_resolutionflops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])if self.norm is not None:flops += Ho * Wo * self.embed_dimreturn flopsclass SwinTransformerSys(nn.Module):r""" Swin TransformerA PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -https://arxiv.org/pdf/2103.14030Args:img_size (int | tuple(int)): Input image size. Default 224patch_size (int | tuple(int)): Patch size. Default: 4in_chans (int): Number of input image channels. Default: 3num_classes (int): Number of classes for classification head. Default: 1000embed_dim (int): Patch embedding dimension. Default: 96depths (tuple(int)): Depth of each Swin Transformer layer.num_heads (tuple(int)): Number of attention heads in different layers.window_size (int): Window size. Default: 7mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Trueqk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: Nonedrop_rate (float): Dropout rate. Default: 0attn_drop_rate (float): Attention dropout rate. Default: 0drop_path_rate (float): Stochastic depth rate. Default: 0.1norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.ape (bool): If True, add absolute position embedding to the patch embedding. Default: Falsepatch_norm (bool): If True, add normalization after patch embedding. Default: Trueuse_checkpoint (bool): Whether to use checkpointing to save memory. Default: False"""def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,embed_dim=96, depths=[2, 2, 2, 2], depths_decoder=[1, 2, 2, 2], num_heads=[3, 6, 12, 24],window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,norm_layer=nn.LayerNorm, ape=False, patch_norm=True,use_checkpoint=False, final_upsample="expand_first", **kwargs):super().__init__()print("SwinTransformerSys expand initial----depths:{};depths_decoder:{};drop_path_rate:{};num_classes:{}".format(depths,depths_decoder,drop_path_rate,num_classes))self.num_classes = num_classesself.num_layers = len(depths)self.embed_dim = embed_dimself.ape = apeself.patch_norm = patch_normself.num_features = int(embed_dim * 2 ** (self.num_layers - 1))self.num_features_up = int(embed_dim * 2)self.mlp_ratio = mlp_ratioself.final_upsample = final_upsample# split image into non-overlapping patchesself.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)num_patches = self.patch_embed.num_patchespatches_resolution = self.patch_embed.patches_resolutionself.patches_resolution = patches_resolution# absolute position embeddingif self.ape:self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))trunc_normal_(self.absolute_pos_embed, std=.02)self.pos_drop = nn.Dropout(p=drop_rate)# stochastic depthdpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule# build encoder and bottleneck layersself.layers = nn.ModuleList()for i_layer in range(self.num_layers):layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),input_resolution=(patches_resolution[0] // (2 ** i_layer),patches_resolution[1] // (2 ** i_layer)),depth=depths[i_layer],num_heads=num_heads[i_layer],window_size=window_size,mlp_ratio=self.mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],norm_layer=norm_layer,downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,use_checkpoint=use_checkpoint)self.layers.append(layer)# build decoder layersself.layers_up = nn.ModuleList()self.concat_back_dim = nn.ModuleList()for i_layer in range(self.num_layers):concat_linear = nn.Linear(2*int(embed_dim*2**(self.num_layers-1-i_layer)),int(embed_dim*2**(self.num_layers-1-i_layer))) if i_layer > 0 else nn.Identity()if i_layer ==0 :layer_up = PatchExpand(input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))), dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)), dim_scale=2, norm_layer=norm_layer)else:layer_up = BasicLayer_up(dim=int(embed_dim * 2 ** (self.num_layers-1-i_layer)),input_resolution=(patches_resolution[0] // (2 ** (self.num_layers-1-i_layer)),patches_resolution[1] // (2 ** (self.num_layers-1-i_layer))),depth=depths[(self.num_layers-1-i_layer)],num_heads=num_heads[(self.num_layers-1-i_layer)],window_size=window_size,mlp_ratio=self.mlp_ratio,qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate,drop_path=dpr[sum(depths[:(self.num_layers-1-i_layer)]):sum(depths[:(self.num_layers-1-i_layer) + 1])],norm_layer=norm_layer,upsample=PatchExpand if (i_layer < self.num_layers - 1) else None,use_checkpoint=use_checkpoint)self.layers_up.append(layer_up)self.concat_back_dim.append(concat_linear)self.norm = norm_layer(self.num_features)self.norm_up= norm_layer(self.embed_dim)if self.final_upsample == "expand_first":print("---final upsample expand_first---")self.up = FinalPatchExpand_X4(input_resolution=(img_size//patch_size,img_size//patch_size),dim_scale=4,dim=embed_dim)self.output = nn.Conv2d(in_channels=embed_dim,out_channels=self.num_classes,kernel_size=1,bias=False)self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)@torch.jit.ignoredef no_weight_decay(self):return {'absolute_pos_embed'}@torch.jit.ignoredef no_weight_decay_keywords(self):return {'relative_position_bias_table'}#Encoder and Bottleneckdef forward_features(self, x):x = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)x_downsample = []for layer in self.layers:x_downsample.append(x)x = layer(x)x = self.norm(x)  # B L Creturn x, x_downsample#Dencoder and Skip connectiondef forward_up_features(self, x, x_downsample):for inx, layer_up in enumerate(self.layers_up):if inx == 0:x = layer_up(x)else:x = torch.cat([x,x_downsample[3-inx]],-1)x = self.concat_back_dim[inx](x)x = layer_up(x)x = self.norm_up(x)  # B L Creturn xdef up_x4(self, x):H, W = self.patches_resolutionB, L, C = x.shapeassert L == H*W, "input features has wrong size"if self.final_upsample=="expand_first":x = self.up(x)x = x.view(B,4*H,4*W,-1)x = x.permute(0,3,1,2) #B,C,H,Wx = self.output(x)return xdef forward(self, x):x, x_downsample = self.forward_features(x)x = self.forward_up_features(x,x_downsample)x = self.up_x4(x)return xdef flops(self):flops = 0flops += self.patch_embed.flops()for i, layer in enumerate(self.layers):flops += layer.flops()flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)flops += self.num_features * self.num_classesreturn flops

实验

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/48313.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

并查集及其简单应用

文章目录 一.并查集二.并查集的实现三.并查集的基本应用 一.并查集 并查集的逻辑结构:由多颗不相连通的多叉树构成的森林(一个这样的多叉树就是森林的一个连通分量) 并查集的元素(树节点)用0~9的整数表示,并查集可以表示如下: 并查集的物理存储结构:并查集一般采用顺序结构实…

Qt与电脑管家4

折线图&#xff1a; #ifndef LINE_CHART_H #define LINE_CHART_H#include <QWidget> #include <QPainter> #include "circle.h" class line_chart : public QWidget {Q_OBJECT public:explicit line_chart(QWidget *parent nullptr); protected:void pa…

手机直播源码开发,协议讨论篇(三):RTMP实时消息传输协议

实时消息传输协议RTMP简介 RTMP又称实时消息传输协议&#xff0c;是一种实时通信协议。在当今数字化时代&#xff0c;手机直播源码平台为全球用户进行服务&#xff0c;如何才能增加用户&#xff0c;提升用户黏性&#xff1f;就需要让一对一直播平台能够为用户提供优质的体验。…

【私有GPT】CHATGLM-6B部署教程

【私有GPT】CHATGLM-6B部署教程 CHATGLM-6B是什么&#xff1f; ChatGLM-6B是清华大学知识工程和数据挖掘小组&#xff08;Knowledge Engineering Group (KEG) & Data Mining at Tsinghua University&#xff09;发布的一个开源的对话机器人。根据官方介绍&#xff0c;这是…

打开软件提示mfc100u.dll缺失是什么意思?要怎么处理?

当你打开某个软件或者运行游戏&#xff0c;系统提示mfc100u.dll丢失&#xff0c;此时这个软件或者游戏根本无法运行。其实&#xff0c;mfc100u.dll是动态库文件&#xff0c;它是VS2010编译的软件所产生的&#xff0c;如果电脑运行程序时提示缺少mfc100u.dll文件&#xff0c;程序…

【Linux】网络层协议:IP

我们必须接受批评&#xff0c;因为它可以帮助我们走出自恋的幻象&#xff0c;不至于长久在道德和智识上自我陶醉&#xff0c;在自恋中走向毁灭&#xff0c;事实上我们远比自己想象的更伪善和幽暗。 文章目录 一、IP和TCP之间的关系&#xff08;提供策略 和 提供能力&#xff09…

中英双语对话大语言模型:ChatGLM-6B

介绍 ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型&#xff0c;基于 General Language Model (GLM) 架构&#xff0c;具有 62 亿参数。结合模型量化技术&#xff0c;用户可以在消费级的显卡上进行本地部署&#xff08;INT4 量化级别下最低只需 6GB 显存&#xff09;。…

罗勇军 →《算法竞赛·快冲300题》每日一题:“超级骑士” ← DFS

【题目来源】http://oj.ecustacm.cn/problem.php?id1810http://oj.ecustacm.cn/viewnews.php?id1023https://www.acwing.com/problem/content/3887/【题目描述】 现在在一个无限大的平面上&#xff0c;给你一个超级骑士。 超级骑士有N种走法&#xff0c;请问这个超级骑士能否…

【Liunx】冯诺伊曼体系结构

冯诺伊曼体系结构 我们常见的计算机&#xff0c;如笔记本。我们不常见的计算机&#xff0c;如服务器&#xff0c;大部分都遵守冯诺伊曼体系。 到目前为止&#xff0c;我们所认识的计算机&#xff0c;都是由一个个硬件所组成的。 输入单元&#xff1a;键盘&#xff0c;鼠标&am…

情报与GPT技术大幅降低鱼叉攻击成本

邮件鱼叉攻击&#xff08;spear phishing attack&#xff09;是一种高度定制化的网络诈骗手段&#xff0c;攻击者通常假装是受害人所熟知的公司或组织发送电子邮件&#xff0c;以骗取受害人的个人信息或企业机密。 以往邮件鱼叉攻击需要花费较多的时间去采集情报、深入了解受…

Java【HTTP】什么是 Cookie 和 Session? 如何理解这两种机制的区别和作用?

文章目录 前言一、Cookie1, 什么是 Cookie2, Cookie 从哪里来3, Cookie 到哪里去4, Cookie 有什么用 二、Session1, 什么是 Session2, 理解 Session 三、Cookie 和 Session 的区别总结 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: &#x1f4d5; …

2023国赛数学建模A题B题C题D题资料思路汇总 高教社杯

本次比赛我们将会全程更新思路模型及代码&#xff0c;大家查看文末名片获取 之前国赛相关的资料和助攻可以查看 2022数学建模国赛C题思路分析_2022年数学建模c题思路_UST数模社_的博客-CSDN博客 2022国赛数学建模A题B题C题D题资料思路汇总 高教社杯_2022国赛a题题目_UST数模…

[保研/考研机试] KY212 二叉树遍历 华中科技大学复试上机题 C++实现

题目链接&#xff1a; 二叉树遍历_牛客题霸_牛客网二叉树的前序、中序、后序遍历的定义&#xff1a; 前序遍历&#xff1a;对任一子树&#xff0c;先访问根&#xff0c;然后遍历其左子树&#xff0c;最。题目来自【牛客题霸】https://www.nowcoder.com/share/jump/43719512169…

Apipost数据模型功能详解

在API设计和开发过程中&#xff0c;存在许多瓶颈&#xff0c;其中一个主要问题是在遇到相似数据结构的API时会产生重复性较多的工作&#xff1a;在每个API中都编写相同的数据&#xff0c;这不仅浪费时间和精力&#xff0c;还容易出错并降低API的可维护性。 为了解决这个问题&a…

注册中心/配置管理 —— SpringCloud Consul

Consul 概述 Consul 是一个可以提供服务发现&#xff0c;健康检查&#xff0c;多数据中心&#xff0c;key/Value 存储的分布式服务框架&#xff0c;用于实现分布式系统的发现与配置。Cousul 使用 Go 语言实现&#xff0c;因此天然具有可移植性&#xff0c;安装包仅包含一个可执…

MySql014——分组的GROUP BY子句和排序ORDER BYSELECT子句顺序

前提&#xff1a;使用《MySql006——检索数据&#xff1a;基础select语句》中创建的products表 一、GROUP BY子句基础用法 SELECT vend_id, COUNT(*) AS num_prods FROMstudy.products GROUP BY vend_id;上面的SELECT语句指定了两个列&#xff0c;vend_id包含产品供应商的ID&…

构建系统自动化-autoreconf

autoreconf简介 autoreconf是一个GNU Autotools工具集中的一个命令&#xff0c;用于自动重新生成构建系统的配置脚本和相关文件。 Autotools是一组用于自动化构建系统的工具&#xff0c;包括Autoconf、Automake和Libtool。它们通常用于跨平台的软件项目&#xff0c;以便在不同…

【数据结构与算法】迪杰斯特拉算法

迪杰斯特拉算法 介绍 迪杰斯特拉&#xff08;Dijkstra&#xff09;算法是典型最短路径算法&#xff0c;用于计算一个节点到其他节点的最短路径。它的主要特点是以中心向外层层扩展&#xff08;广度优先搜索思想&#xff09;&#xff0c;直到扩展到终点为止。 算法过程 设置…

离谱的Bug

离谱的 Bug Bug 情况发现 Bug修改 Bug其他感受历史 Bug火星Spirit号Mars Global Surveyor任务 Bug 情况 有一次&#xff0c;我在开发一个网页应用程序时&#xff0c;遇到了一个令人目瞪口呆的Bug。这个Bug出现在一个特定的页面上&#xff0c;当用户点击某个按钮时&#xff0c;…

Redis 十大数据类型

Redis数据类型都有哪些&#xff1f; Redis支持丰富的数据类型&#xff0c;那么具体在Redis7中都有哪些数据类型呢&#xff1f;请看下图&#xff1a; 官网介绍&#xff1a;https://redis.io/docs/data-types/。 其中&#xff0c;String、Hash、List、Set、Sorted Set等类型是大…