DFSMN
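DFSMN (Deep Feedforward Sequential Memory Network) replaces recurrence with learnable FIR-like memory blocks: each hidden layer keeps a weighted window over past (and optionally future) frames, and consecutive memory blocks are linked by skip connections. Below is a minimal sketch of one such memory block; the module name, the look-back/look-ahead orders (lorder/rorder) and the stride are illustrative assumptions, not the reference implementation.

# Minimal sketch of a DFSMN-style memory block (names and hyper-parameters are assumptions).
import torch
import torch.nn as nn
import torch.nn.functional as F

class FSMNMemoryBlock(nn.Module):
    def __init__(self, hidden_size, lorder=10, rorder=1, stride=1):
        super().__init__()
        self.lorder, self.rorder, self.stride = lorder, rorder, stride
        # Depthwise 1-D convolution acts as a learnable FIR filter over
        # lorder past frames and rorder future frames, applied per channel.
        self.conv = nn.Conv1d(hidden_size, hidden_size,
                              kernel_size=lorder + rorder + 1,
                              dilation=stride, groups=hidden_size, bias=False)

    def forward(self, x, skip=None):
        # x: (batch, time, hidden_size)
        y = x.transpose(1, 2)                              # (B, H, T)
        y = F.pad(y, (self.lorder * self.stride,           # past context
                      self.rorder * self.stride))          # look-ahead context
        y = self.conv(y).transpose(1, 2)                   # back to (B, T, H)
        out = x + y                                        # memory added to the input
        if skip is not None:
            out = out + skip                               # skip connection between memory blocks
        return out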
SAN-M
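SAN-M (memory-equipped self-attention) combines multi-head self-attention with an FSMN-style memory branch: the value projection is additionally filtered by a depthwise convolution and the result is added to the attention output, so each layer captures both global context (attention) and local context (memory). The encoder implemented below covers only the plain self-attention part; a hedged sketch of the memory branch, with an assumed kernel size and naming, is:

# Sketch of the SAN-M memory branch (kernel_size and names are assumptions).
import torch
import torch.nn as nn

class SANMMemory(nn.Module):
    def __init__(self, d_model, kernel_size=11):
        super().__init__()
        # Depthwise convolution over time plays the role of the FSMN memory.
        self.fsmn = nn.Conv1d(d_model, d_model, kernel_size,
                              padding=(kernel_size - 1) // 2,
                              groups=d_model, bias=False)

    def forward(self, v):
        # v: (batch, time, d_model) -- typically the value projection before head splitting
        return self.fsmn(v.transpose(1, 2)).transpose(1, 2)

# Inside a SAN-M layer the two branches would be combined roughly as:
#   attn_out, _ = self_attn(x, x, x, mask)   # global context via multi-head attention
#   out = attn_out + memory(value_proj(x))   # local context via the FSMN memory branch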
Python implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Adds sinusoidal positional encodings to a (batch, time, d_model) input."""

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(torch.log(torch.tensor(10000.0)) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, time, d_model)
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, in_features, out_features, num_heads=4, dropout=0.1):
        super(SelfAttention, self).__init__()
        assert out_features % num_heads == 0, "out_features must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = out_features // num_heads
        self.w_qs = nn.Linear(in_features, out_features, bias=False)
        self.w_ks = nn.Linear(in_features, out_features, bias=False)
        self.w_vs = nn.Linear(in_features, out_features, bias=False)
        self.fc_out = nn.Linear(out_features, out_features, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        # q, k, v: (batch, time, in_features)
        n_heads, d_k = self.num_heads, self.d_k
        q = self.w_qs(q).view(q.size(0), q.size(1), n_heads, d_k).transpose(1, 2)
        k = self.w_ks(k).view(k.size(0), k.size(1), n_heads, d_k).transpose(1, 2)
        v = self.w_vs(v).view(v.size(0), v.size(1), n_heads, d_k).transpose(1, 2)
        # scores: (batch, heads, time, time)
        scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = self.dropout(self.softmax(scores))
        # merge heads back: (batch, time, out_features)
        output = torch.matmul(attn, v).transpose(1, 2).contiguous()
        output = output.view(output.size(0), output.size(1), -1)
        output = self.fc_out(output)
        return output, attn


class SANMEncoderLayer(nn.Module):
    """Pre-norm encoder layer: self-attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, size, self_attn, feed_forward, dropout=0.1):
        super(SANMEncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size)
        self.norm2 = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # self-attention sub-layer with residual connection
        residual = x
        x = self.norm1(x)
        x, _ = self.self_attn(x, x, x, mask)
        x = F.relu(x)
        x = self.dropout(x)
        x = residual + x
        # feed-forward sub-layer with residual connection
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = self.dropout(x)
        x = residual + x
        return x


class SANMEncoder(nn.Module):
    """Stack of encoder layers on top of an input projection and positional encoding."""

    def __init__(self, input_dim, num_layers, size, num_heads, ff_size, dropout=0.1):
        super(SANMEncoder, self).__init__()
        self.input_proj = nn.Linear(input_dim, size)
        self.embedding = PositionalEncoding(size, dropout)
        self.layers = nn.ModuleList([
            SANMEncoderLayer(
                size,
                SelfAttention(size, size, num_heads, dropout),
                nn.Sequential(nn.Linear(size, ff_size), nn.ReLU(), nn.Linear(ff_size, size)),
                dropout,
            )
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        # x: (batch, time, input_dim)
        x = self.input_proj(x)
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x
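A quick usage sketch for the encoder above; the hyper-parameters and tensor sizes are arbitrary and only illustrate the expected shapes.

# Usage sketch: sizes are arbitrary, chosen only to illustrate shapes.
encoder = SANMEncoder(input_dim=80, num_layers=4, size=256, num_heads=4, ff_size=1024)
feats = torch.randn(2, 100, 80)     # (batch, time, feature), e.g. 80-dim fbank frames
mask = torch.ones(2, 1, 1, 100)     # broadcastable to (batch, heads, time, time)
out = encoder(feats, mask)
print(out.shape)                    # torch.Size([2, 100, 256])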