PositionWiseFeedForward 类的代码
class PositionWiseFeedForward(nn.Module):def __init__(self, num_hidden, num_ff):super(PositionWiseFeedForward, self).__init__()self.W_in = nn.Linear(num_hidden, num_ff, bias=True)self.W_out = nn.Linear(num_ff, num_hidden, bias=True)self.act = torch.nn.GELU()def forward(self, h_V):h = self.act(self.W_in(h_V))h = self.W_out(h)return h
PositionWiseFeedForward
在每个序列位置独立地对每个 token 的表示进行两次线性变换,中间通过一个非线性激活函数,先扩展维度后再还原。它不会引入序列间的信息交互,只会对每个位置的 token 进行单独的处理,因此称为“逐位置前馈网络”。
在 Transformer 网络中,PositionWiseFeedForward
是每层 Transformer 结构的标准组成部分,用于提升模型的表达能力,能够捕捉序列中每个位置的更复杂特征。
DecLayer 类的代码
class DecLayer(nn.Module):def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=