QueryEncoding
类用于在输入张量 x
上添加一种查询序列的特殊编码。这里的查询编码将第一个序列标记为查询序列,并将其与其他序列区分开。以下是代码中的细节和每一步的作用。
源码:
class QueryEncoding(nn.Module):def __init__(self, d_model):super(QueryEncoding, self).__init__()self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)def forward(self, x):B, N, L, K = x.shapeidx = torch.ones((B, N, L), device=x.device).long()idx[:,0,:] = 0 # first sequence is the queryx = x + self.pe(idx)return x
代码解读:
class QueryEncoding(nn.Module):def __init__(self, d_model):super(QueryEncoding, self).__init__()self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)def forward(self, x):B, N, L, K = x.shapeidx = torch.ones((B, N, L), device=x.device).long()i