文章目录
- 1.1 位置编码
- 1.2 EncoderLayer
- 1.3 Encoder
- 1.4 STNDT
1.1 位置编码
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
import math
# Label value for entries that SpatioTemporalNDT.forward excludes from the loss.
UNMASKED_LABEL = -100


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    Even feature columns hold ``sin(pos * f_k)`` and odd columns ``cos(pos * f_k)``
    with ``f_k = 10000 ** (-2k / d_model)``; dropout is applied after the sum.

    Args:
        trial_length: number of positions in the precomputed table (the maximum
            sequence length ``forward`` can accept).
        d_model: feature size; odd sizes are supported (one fewer cosine column).
        dropout: dropout probability applied to the encoded output.
    """

    def __init__(self, trial_length, d_model, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        steps = torch.arange(0, trial_length, dtype=torch.float).unsqueeze(1)  # (L, 1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = steps * inv_freq  # (L, ceil(d_model / 2))

        table = torch.zeros(trial_length, d_model)
        table[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies receive a cosine column.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Stored as (trial_length, 1, d_model) so it broadcasts over the batch axis.
        self.register_buffer('pe', table.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:x.size(0)])`` for a sequence-first ``x``."""
        return self.dropout(x + self.pe[: x.size(0)])
1.2 EncoderLayer
model.py
核心编码层：在时间自注意力的基础上加入了空间注意力编码。
class STNTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer that adds spatial (cross-sequence) attention to a temporal layer.

    Three pre-norm sub-blocks with residual connections run in order:

    1. Temporal self-attention plus a feed-forward network. These use the
       parent's ``self_attn``, ``norm1/2``, ``linear1/2`` and
       ``dropout/dropout1/dropout2``.
    2. Spatial self-attention over ``spatial_src`` with ``spatial_self_attn``.
       Only its attention weights are kept; its output is discarded.
    3. A mixing step that multiplies those weights into the temporal features,
       followed by a second feed-forward network (the ``ts_*`` modules).

    Args:
        d_model: feature size of ``src`` (the neuron count in ``SpatioTemporalNDT``).
        d_model_s: feature size of ``spatial_src``: the number of time steps
            (e.g. 160), used as the spatial attention's embedding size.
        num_heads: number of heads for both attention modules.
        dim_feedforward: hidden width of both feed-forward sub-blocks.
        dropout: dropout probability for the temporal and mixing sub-blocks.
        activation: activation for the parent layer, reused by the mixing FFN.
    """

    def __init__(self, d_model, d_model_s, num_heads=2, dim_feedforward=128,
                 dropout=0.1, activation='relu'):
        super().__init__(
            d_model,
            nhead=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
        )
        self.num_heads = num_heads
        self.num_input = d_model
        self.d_model_s = d_model_s

        # Spatial attention: the sequence axis is the first axis of spatial_src,
        # and the embedding size is d_model_s.
        self.spatial_self_attn = MultiheadAttention(embed_dim=d_model_s, num_heads=num_heads)
        self.spatial_norm1 = nn.LayerNorm(d_model_s)

        # Spatio-temporal mixing sub-block. Keep this registration order: it fixes
        # parameter order and the RNG sequence used for weight initialisation.
        self.ts_norm1 = nn.LayerNorm(d_model)
        self.ts_norm2 = nn.LayerNorm(d_model)
        self.ts_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ts_linear2 = nn.Linear(dim_feedforward, d_model)
        self.ts_dropout1, self.ts_dropout2, self.ts_dropout3 = (
            nn.Dropout(dropout) for _ in range(3)
        )

    def attend(self, src, context_mask=None, **kwargs):
        """Temporal self-attention.

        Returns the tuple from ``self_attn`` (output, weights) followed by a scalar
        float zero on ``src``'s device, kept as a placeholder third element.
        """
        outputs = tuple(self.self_attn(src, src, src, attn_mask=context_mask, **kwargs))
        return outputs + (torch.tensor(0, device=src.device, dtype=torch.float),)

    def spatial_attend(self, src, context_mask=None, **kwargs):
        r"""Attend over the spatial dimension.

        Args:
            src: spatiotemporal neural population input, sequence-first as expected
                by ``MultiheadAttention`` built without ``batch_first``.
            context_mask: spatial context mask, passed through as ``attn_mask``.

        Returns:
            The (output, weights) tuple from ``spatial_self_attn`` followed by a
            scalar float zero placeholder.
        """
        outputs = tuple(self.spatial_self_attn(src, src, src, attn_mask=context_mask, **kwargs))
        return outputs + (torch.tensor(0, device=src.device, dtype=torch.float),)

    def forward(self, src, spatial_src, src_mask=None, spatial_src_mask=None,
                src_key_padding_mask=None):
        """Run the temporal, spatial and mixing sub-blocks.

        Args:
            src: temporal stream, sequence-first ``(T, batch, d_model)``.
            spatial_src: spatial stream ``(d_model, batch, d_model_s)``. Its first
                axis must equal ``d_model`` for the ``bmm`` in the mixing step.
            src_mask: attention mask for the temporal attention.
            spatial_src_mask: attention mask for the spatial attention.
            src_key_padding_mask: key padding mask for the temporal attention.

        Returns:
            Tensor of shape ``(T, batch, d_model)``.
        """
        # 1) Temporal pre-norm self-attention and feed-forward, each with a residual.
        attn_out = self.attend(
            self.norm1(src), context_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        temporal = src + self.dropout1(attn_out)
        hidden = self.activation(self.linear1(self.norm2(temporal)))
        temporal = temporal + self.dropout2(self.linear2(self.dropout(hidden)))

        # 2) Spatial attention: keep only the attention weights, shape (batch, d_model, d_model).
        neuron_weights = self.spatial_attend(
            self.spatial_norm1(spatial_src), context_mask=spatial_src_mask, key_padding_mask=None
        )[1]

        # 3) Mix temporal features across the d_model axis with those weights.
        #    Permute to (batch, d_model, T), apply bmm, then permute back to (T, batch, d_model).
        mixed = torch.bmm(neuron_weights, self.ts_norm1(temporal).permute(1, 2, 0)).permute(2, 0, 1)
        mixed = temporal + self.ts_dropout1(mixed)
        hidden = self.activation(self.ts_linear1(self.ts_norm2(mixed)))
        return mixed + self.ts_dropout3(self.ts_linear2(self.ts_dropout2(hidden)))
1.3 Encoder
model.py
class STNTransformerEncoder(TransformerEncoder):
    """Stack of encoder layers that take both a temporal and a spatial input.

    The first layer receives the caller's ``spatial_src``. Every later layer
    receives the previous layer's output transposed with ``permute(2, 1, 0)`` as
    its spatial input. An optional final ``norm`` is applied to the last output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__(encoder_layer, num_layers, norm)

    def forward(self, src, spatial_src, mask=None, spatial_mask=None):
        """Apply every layer in turn and return the (optionally normalised) output."""
        spatial_in = spatial_src
        for layer in self.layers:
            src = layer(src, spatial_in, src_mask=mask, spatial_src_mask=spatial_mask)
            # The next layer derives its spatial view from this layer's output.
            spatial_in = src.permute(2, 1, 0)
        return src if self.norm is None else self.norm(src)
1.4 STNDT
model.py
class SpatioTemporalNDT(nn.Module):
    """Spatiotemporal Neural Data Transformer with a masked Poisson reconstruction loss.

    The input of shape ``(batch, trial_length, num_neurons)`` is encoded along two
    views:

    - a temporal stream ``(trial_length, batch, num_neurons)``;
    - a spatial stream ``(num_neurons, batch, trial_length)``.

    Both views get sinusoidal position encodings and pass through
    ``STNTransformerEncoder``. The result is decoded to per-neuron outputs, scored
    with a Poisson NLL on the entries not labelled ``UNMASKED_LABEL``.

    Args:
        trial_length: time steps per trial (feature size of the spatial stream).
        num_neurons: number of neurons (feature size of the temporal stream).
        temperature: stored as ``temperature``; not used by ``forward``.
        c_lambda: stored as ``contrast_lambda``; not used by ``forward``.
        dropout: dropout applied to the encoder output before decoding.
        pos_drop: dropout inside both positional encoders.
        enc_layers: number of stacked encoder layers.
        log_rates: passed as ``log_input`` to the Poisson NLL, i.e. decoder
            outputs are treated as log-rates when True.
        enc_heads: attention heads in each encoder layer.
        enc_dff: feed-forward width of each encoder layer.
        enc_drop: dropout of each encoder layer.
    """

    def __init__(self, trial_length, num_neurons, temperature=0.1, c_lambda=0.3,
                 dropout=0.2, pos_drop=0.1, enc_layers=1, log_rates=True,
                 enc_heads=2, enc_dff=128, enc_drop=0.1) -> None:
        super().__init__()
        # Cached additive temporal attention mask, built lazily by _get_mask.
        self.src_mask = None
        self.num_input = num_neurons
        self.num_spatial_input = trial_length
        self.embedder = nn.Identity()
        self.spatial_embedder = nn.Identity()
        # Each view is multiplied by sqrt(its feature size) before position encoding.
        self.scale = math.sqrt(num_neurons)
        self.spatial_scale = math.sqrt(trial_length)
        self.src_pos_encoder = PositionalEncoding(trial_length, num_neurons, pos_drop)
        self.spatial_pos_encoder = PositionalEncoding(num_neurons, trial_length, pos_drop)
        # NOTE(review): the attributes below are not used by forward(); presumably
        # they serve a contrastive objective implemented elsewhere. Confirm before removing.
        self.projector = nn.Identity()
        self.spatial_projector = nn.Identity()
        self.n_views = 2
        self.temperature = temperature
        self.contrast_lambda = c_lambda
        self.cel = nn.CrossEntropyLoss(reduction='none')
        self.mse = nn.MSELoss(reduction='mean')
        # Module construction order is kept as-is: it fixes parameter order and
        # the RNG sequence used for initialisation.
        encoder_layer = STNTransformerEncoderLayer(
            d_model=self.num_input,
            d_model_s=self.num_spatial_input,
            num_heads=enc_heads,
            dim_feedforward=enc_dff,
            dropout=enc_drop,
        )
        self.transformer_encoder = STNTransformerEncoder(
            encoder_layer, enc_layers, nn.LayerNorm(self.num_input)
        )
        self.rate_dropout = nn.Dropout(dropout)
        self.src_decoder = nn.Linear(num_neurons, self.num_input)
        # Element-wise Poisson NLL; masking and averaging happen in forward().
        self.classifier = nn.PoissonNLLLoss(reduction='none', log_input=log_rates)

    def _get_mask(self, src, do_convert=True):
        """Return the additive temporal attention mask for ``src``.

        Query step ``i`` may attend key step ``j`` when ``i - 79 <= j <= i + 13``.
        Allowed entries are ``0.0`` and blocked entries are ``-inf``.

        The mask is cached in ``self.src_mask``. The cache is reused only if its
        shape matches the current sequence length; otherwise a new mask is built.
        A cached mask of the right shape on another device is moved to ``src``'s
        device. ``src_mask`` is a plain attribute, not a buffer, so ``Module.to()``
        does not move it.

        Args:
            src: temporal-stream tensor whose first dimension is the sequence length.
            do_convert: unused; kept for interface compatibility.
        """
        size = src.size(0)
        cached = self.src_mask
        if cached is not None and cached.shape == (size, size):
            if cached.device != src.device:
                cached = cached.to(src.device)
                self.src_mask = cached
            return cached

        context_forward = 13   # future steps each query may attend to
        context_backward = 79  # past steps each query may attend to
        ones = torch.ones(size, size, device=src.device)
        # triu(diagonal=-k) keeps j >= i - k; transposing gives j <= i + context_forward.
        forward_ok = (torch.triu(ones, diagonal=-context_forward) == 1).transpose(0, 1)
        backward_ok = torch.triu(ones, diagonal=-context_backward) == 1
        mask = binary_mask_to_attn_mask((forward_ok & backward_ok).float())
        self.src_mask = mask
        return mask

    def forward(self, src: torch.Tensor, mask_labels: torch.Tensor):
        """Encode ``src`` and return the masked Poisson NLL and the decoder output.

        Args:
            src: activity of shape ``(batch, trial_length, num_neurons)``, as
                required by the two positional encoders.
            mask_labels: targets with the same shape as the returned rates.
                Entries equal to ``UNMASKED_LABEL`` are excluded from the loss.

        Returns:
            ``(loss, decoder_rates)``:

            - ``loss`` is the mean Poisson NLL over the non-excluded entries. It is
              NaN if every entry equals ``UNMASKED_LABEL``.
            - ``decoder_rates`` has shape ``(batch, trial_length, num_neurons)`` and
              holds log-rates when ``log_rates=True``.
        """
        src = src.float()
        # Spatial view: (num_neurons, batch, trial_length), where neurons form the sequence.
        spatial_src = src.permute(2, 0, 1)
        spatial_src = self.spatial_embedder(spatial_src) * self.spatial_scale
        spatial_src = self.spatial_pos_encoder(spatial_src)
        # Temporal view: (trial_length, batch, num_neurons), where time steps form the sequence.
        src = src.permute(1, 0, 2)
        src = self.embedder(src) * self.scale
        src = self.src_pos_encoder(src)
        src_mask = self._get_mask(src)
        spatial_src_mask = None  # the spatial attention is unmasked
        encoder_output = self.transformer_encoder(src, spatial_src, src_mask, spatial_src_mask)
        encoder_output = self.rate_dropout(encoder_output)
        decoder_output = self.src_decoder(encoder_output)
        decoder_rates = decoder_output.permute(1, 0, 2)  # back to batch-first
        decoder_loss = self.classifier(decoder_rates, mask_labels)
        masked_decoder_loss = decoder_loss[mask_labels != UNMASKED_LABEL]
        masked_decoder_loss = masked_decoder_loss.mean()
        return masked_decoder_loss, decoder_rates


def binary_mask_to_attn_mask(x):
    """Convert a 0/1 mask into an additive attention mask.

    Entries equal to 1 become ``0.0`` (attend) and entries equal to 0 become
    ``-inf`` (blocked). Any other value is returned unchanged as float.
    """
    return x.float().masked_fill(x == 0, float('-inf')).masked_fill(x == 1, float(0.0))
下一篇: https://blog.csdn.net/weixin_46866349/article/details/139906391