时间嵌入
1 傅里叶时间嵌入
class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Each input value is multiplied by a fixed vector of randomly drawn
    frequencies and mapped through sine and cosine.

    Args:
        embedding_size: Number of random frequencies. The output carries
            ``2 * embedding_size`` features (sine half, then cosine half).
        scale: Multiplier applied to the standard-normal samples that
            form the frequencies.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        # Frequencies are sampled once at construction time and frozen:
        # requires_grad=False keeps them out of gradient updates.
        self.W = nn.Parameter(
            torch.randn(embedding_size) * scale, requires_grad=False
        )

    def forward(self, x):
        """Embed ``x`` as ``[sin(2*pi*x*W), cos(2*pi*x*W)]``.

        Args:
            x: Tensor of values to embed. NOTE(review): the indexing
                ``x[:, None]`` suggests a 1-D tensor of shape ``(batch,)``
                is expected -- confirm against callers.

        Returns:
            Tensor with the sine features followed by the cosine features
            along the last dimension; ``(batch, 2 * embedding_size)`` for
            1-D input.
        """
        # Outer product of inputs and frequencies, scaled to radians.
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
2 正弦位置嵌入（SinusoidalPosEmb）
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    Uses ``dim // 2`` geometrically spaced frequencies running from 1 down
    to 1/10000 and returns the sines followed by the cosines.

    Args:
        dim: Requested embedding width. The output carries
            ``2 * (dim // 2)`` features, so an odd ``dim`` yields
            ``dim - 1`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` as ``[sin(x * f), cos(x * f)]``.

        Args:
            x: Tensor of positions. NOTE(review): the indexing
                ``x[:, None]`` suggests a 1-D tensor of shape ``(batch,)``
                is expected -- confirm against callers.

        Returns:
            Tensor with the sine features followed by the cosine features
            along the last dimension; ``(batch, 2 * (dim // 2))`` for 1-D
            input.
        """
        device = x.device
        half_dim = self.dim // 2
        # Log-spacing step between consecutive frequencies. The divisor is
        # clamped to 1 so that dim == 2 or 3 (half_dim == 1) produces the
        # single frequency 1.0 instead of raising ZeroDivisionError.
        step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        # Outer product of positions and frequencies.
        emb = x[:, None] * freqs[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim=-1)