DFSMN
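DFSMN (Deep Feedforward Sequential Memory Network) stacks layers whose core is the FSMN memory block: each hidden frame is augmented with a learnable weighted sum of its neighboring past and future frames, and skip connections link the memory blocks of adjacent layers so very deep stacks still train well. The following is only a minimal sketch of that memory block, assuming it is realized as a depthwise 1-D convolution over time; the class name, filter orders, and stride handling are illustrative choices, not a reference implementation.

import torch.nn as nn
import torch.nn.functional as F


class DFSMNMemoryBlock(nn.Module):
    # Hypothetical sketch of one DFSMN memory block: a depthwise 1-D convolution
    # over time covering `left_order` past and `right_order` future frames,
    # plus a skip connection from the previous memory block's output.
    def __init__(self, hidden_dim, left_order=10, right_order=2, stride=1):
        super().__init__()
        self.left_order = left_order
        self.right_order = right_order
        kernel_size = left_order + right_order + 1
        self.conv = nn.Conv1d(
            hidden_dim, hidden_dim, kernel_size,
            dilation=stride, groups=hidden_dim, bias=False,  # depthwise: one filter per channel
        )

    def forward(self, hidden, memory_skip=None):
        # hidden: (batch, seq_len, hidden_dim)
        x = hidden.transpose(1, 2)  # -> (batch, hidden_dim, seq_len)
        d = self.conv.dilation[0]
        # pad asymmetrically so each output frame sees left_order past / right_order future frames
        x = F.pad(x, (self.left_order * d, self.right_order * d))
        memory = self.conv(x).transpose(1, 2)      # same length as the input sequence
        out = hidden + memory                      # memory added back onto the hidden frame
        if memory_skip is not None:
            out = out + memory_skip                # skip connection between memory blocks
        return out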
SAN-M
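SAN-M combines self-attention, which captures global context, with an FSMN-style memory block, which captures local context, inside a single layer: roughly speaking, the memory term is computed from the value sequence and added to the attention output. The sketch below only illustrates that combination under my own assumptions (nn.MultiheadAttention and a depthwise Conv1d stand in for the paper's attention and memory filter; all names are hypothetical), and is not the paper's implementation.

import torch.nn as nn


class MemoryEquippedAttention(nn.Module):
    # Illustrative sketch: standard multi-head self-attention plus an FSMN-style
    # local memory term computed from a value projection with a depthwise Conv1d.
    def __init__(self, d_model, n_heads=4, kernel_size=11, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.memory = nn.Conv1d(d_model, d_model, kernel_size,
                                padding=kernel_size // 2, groups=d_model, bias=False)

    def forward(self, x, key_padding_mask=None):
        # x: (batch, seq_len, d_model)
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)  # global context
        v = self.v_proj(x)
        mem = self.memory(v.transpose(1, 2)).transpose(1, 2)                 # local context over time
        return attn_out + mem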
Python implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    # Fixed sinusoidal positional encoding for batch-first input (batch, seq_len, d_model).
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class SelfAttention(nn.Module):
    # Multi-head scaled dot-product self-attention.
    def __init__(self, in_features, out_features, n_heads=4, dropout=0.1):
        super(SelfAttention, self).__init__()
        assert out_features % n_heads == 0, "out_features must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = out_features // n_heads
        self.w_qs = nn.Linear(in_features, out_features, bias=False)
        self.w_ks = nn.Linear(in_features, out_features, bias=False)
        self.w_vs = nn.Linear(in_features, out_features, bias=False)
        self.fc_out = nn.Linear(out_features, out_features, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        # q, k, v: (batch, seq_len, in_features)
        batch, seq_len = q.size(0), q.size(1)
        q = self.w_qs(q).view(batch, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        k = self.w_ks(k).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.w_vs(v).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        # scores: (batch, n_heads, seq_len, seq_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.d_k ** 0.5
        if mask is not None:
            # mask: broadcastable to (batch, n_heads, seq_len, seq_len); 0 marks padded positions
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(self.softmax(scores))
        output = torch.matmul(attn, v).transpose(1, 2).contiguous()
        output = output.view(batch, seq_len, -1)
        output = self.fc_out(output)
        return output, attn


class SANMEncoderLayer(nn.Module):
    # Pre-norm encoder layer: self-attention block and feed-forward block, each with a residual connection.
    def __init__(self, size, self_attn, feed_forward, dropout=0.1):
        super(SANMEncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size)
        self.norm2 = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x, mask)
        x = F.relu(x)
        x = self.dropout(x)
        x = residual + x

        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = self.dropout(x)
        x = residual + x
        return x


class SANMEncoder(nn.Module):
    # Encoder: input projection + positional encoding + a stack of encoder layers.
    def __init__(self, input_dim, num_layers, size, num_heads, ff_size, dropout=0.1):
        super(SANMEncoder, self).__init__()
        self.input_proj = nn.Linear(input_dim, size)
        self.embedding = PositionalEncoding(size, dropout)
        self.layers = nn.ModuleList([
            SANMEncoderLayer(
                size,
                SelfAttention(size, size, num_heads, dropout),
                nn.Sequential(nn.Linear(size, ff_size), nn.ReLU(), nn.Linear(ff_size, size)),
                dropout,
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        # x: (batch, seq_len, input_dim)
        x = self.embedding(self.input_proj(x))
        for layer in self.layers:
            x = layer(x, mask)
        return x
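A quick smoke test of the encoder defined above; the batch size, sequence length, and feature dimension are arbitrary choices for illustration.

# toy forward pass: 80-dim filterbank-like features, batch of 2, 100 frames
encoder = SANMEncoder(input_dim=80, num_layers=4, size=256, num_heads=4, ff_size=1024)
feats = torch.randn(2, 100, 80)    # (batch, seq_len, input_dim)
mask = torch.ones(2, 1, 1, 100)    # broadcast over heads and query positions; 1 = keep
out = encoder(feats, mask)
print(out.shape)                   # expected: torch.Size([2, 100, 256])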