import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


class PreNorm(nn.Module):
    """Applies LayerNorm before the wrapped function (pre-norm residual style)."""
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Standard transformer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class PPM(nn.Module):
    """Pyramid pooling: adaptive-average-pools the feature map to several grid
    sizes, flattens each grid, and concatenates along the token dimension."""
    def __init__(self, pooling_sizes=(1, 3, 5)):
        super().__init__()
        self.layer = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(output_size=(size, size)) for size in pooling_sizes])

    def forward(self, feat):
        b, c, h, w = feat.shape
        output = [layer(feat).view(b, c, -1) for layer in self.layer]
        output = torch.cat(output, dim=-1)
        return output
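As a quick sanity check (not part of the original listing; it just reuses the PPM class above with illustrative tensor sizes): each pooled grid is flattened and concatenated, so any (b, c, h, w) input is compressed to 1*1 + 3*3 + 5*5 = 35 tokens per channel, independent of the input resolution.

feat = torch.rand(2, 64, 40, 40)
tokens = PPM(pooling_sizes=(1, 3, 5))(feat)
print(tokens.shape)  # torch.Size([2, 64, 35]) -- 35 pooled tokens per channel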
class ESA_layer(nn.Module):
    """Efficient self-attention: queries keep full spatial resolution while
    keys/values are compressed by PPM to a fixed, small number of tokens."""
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.ppm = PPM(pooling_sizes=(1, 3, 5))
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # input x: (b, c, h, w)
        b, c, h, w = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=1)  # each (b, inner_dim, h, w)
        q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)
        k, v = self.ppm(k), self.ppm(v)  # (b, inner_dim, 35) for pooling sizes (1, 3, 5)
        k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)
        v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (b, head, h*w, 35)
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b head n d -> b n (head d)')
        return self.to_out(out)  # (b, h*w, dim)


class ESA_blcok(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.ESAlayer = ESA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

    def forward(self, x):
        b, c, h, w = x.shape
        out = rearrange(x, 'b c h w -> b (h w) c')
        out = self.ESAlayer(x) + out   # attention + residual
        out = self.ff(out) + out       # feed-forward + residual
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        return out + x


def MaskAveragePooling(x, mask):
    """Averages the features of x over the (sigmoid-activated) mask region,
    producing one mask-weighted token per channel: (b, c, 1)."""
    mask = torch.sigmoid(mask)
    b, c, h, w = x.shape
    eps = 0.0005
    x_mask = x * mask
    area = F.avg_pool2d(mask, (h, w)) * h * w + eps        # soft area of the mask
    x_feat = F.avg_pool2d(x_mask, (h, w)) * h * w / area   # masked average feature
    x_feat = x_feat.view(b, c, -1)
    return x_feat
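This is what makes the self-attention "efficient": the queries keep all h*w positions, but the keys/values are the 35 PPM tokens, so the attention map in ESA_layer is (h*w) x 35 rather than (h*w) x (h*w). MaskAveragePooling, by contrast, collapses the whole map to a single mask-weighted token per channel. A minimal shape check with made-up sizes (not from the original listing):

feats = torch.rand(2, 64, 40, 40)     # feature map
logits = torch.rand(2, 1, 40, 40)     # mask logits; sigmoid is applied inside
pooled = MaskAveragePooling(feats, logits)
print(pooled.shape)  # torch.Size([2, 64, 1]) -- one token per channel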
class LCA_layer(nn.Module):
    """Lesion-aware cross-attention: queries come from the feature map, while
    the key/value is a single mask-pooled token from MaskAveragePooling."""
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, kernel_size=1, stride=1, padding=0, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, mask):
        # input x: (b, c, h, w), mask: (b, 1, h, w)
        b, c, h, w = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        q = rearrange(q, 'b (head d) h w -> b head (h w) d', head=self.heads)
        k, v = MaskAveragePooling(k, mask), MaskAveragePooling(v, mask)  # (b, inner_dim, 1)
        k = rearrange(k, 'b (head d) n -> b head n d', head=self.heads)
        v = rearrange(v, 'b (head d) n -> b head n d', head=self.heads)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (b, head, h*w, 1)
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b head n d -> b n (head d)')
        return self.to_out(out)  # (b, h*w, dim)


class LCA_blcok(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.LCAlayer = LCA_layer(dim, heads=heads, dim_head=dim_head, dropout=dropout)
        self.ff = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))

    def forward(self, x, mask):
        b, c, h, w = x.shape
        out = rearrange(x, 'b c h w -> b (h w) c')
        out = self.LCAlayer(x, mask) + out  # cross-attention + residual
        out = self.ff(out) + out            # feed-forward + residual
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        return out
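Note that because MaskAveragePooling yields a single key/value token (n = 1), the softmax in LCA_layer runs over one element and every attention weight is exactly 1: before the output projection and the residual connections in LCA_blcok, every spatial position receives the same mask-pooled value vector. A small check with illustrative sizes (not from the original listing):

x = torch.rand(2, 32, 16, 16)
mask = torch.rand(2, 1, 16, 16)
out = LCA_layer(dim=32)(x, mask)
print(out.shape)                             # torch.Size([2, 256, 32])
print(torch.allclose(out[:, 0], out[:, 1]))  # True: all positions share one vector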
if __name__ == '__main__':
    x = torch.rand((4, 3, 320, 320))
    mask = torch.rand(4, 1, 320, 320)
    lca = LCA_blcok(dim=3)
    esa = ESA_blcok(dim=3)
    print(lca(x, mask).shape)  # torch.Size([4, 3, 320, 320])
    print(esa(x).shape)        # torch.Size([4, 3, 320, 320])
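In practice (for example in a segmentation decoder) these blocks would run on feature maps with more channels and lower resolution than the 3-channel toy input above, with LCA_blcok fed the coarse prediction from an earlier stage; the sizes below are illustrative assumptions, not values from the source:

feat = torch.rand(2, 64, 44, 44)                  # e.g. a downsampled decoder feature map
prev_pred = torch.rand(2, 1, 44, 44)              # coarse mask logits from an earlier stage
print(LCA_blcok(dim=64)(feat, prev_pred).shape)   # torch.Size([2, 64, 44, 44])
print(ESA_blcok(dim=64)(feat).shape)              # torch.Size([2, 64, 44, 44])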