The AtomAttentionEncoder class in AlphaFold3 handles atom-level representation learning: it embeds per-atom reference features and aggregates them into per-token representations for the rest of the network.
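At a high level, the encoder embeds per-atom reference features, refines them with a windowed AtomTransformer, and pools the atoms into per-token embeddings. Below is a minimal instantiation sketch, not taken from the repository: it assumes the class definition further down is in scope, and `c_token=384` is only an illustrative token width; the remaining arguments are the constructor defaults.

```python
# Minimal instantiation sketch (illustrative, not from the repository).
# Assumes the AtomAttentionEncoder class defined below is in scope.
encoder = AtomAttentionEncoder(
    c_token=384,               # width of the per-token output representation (example value)
    c_atom=128,                # per-atom channels (default)
    c_atompair=16,             # atom-pair channels (default)
    n_queries=32,              # atoms per local attention window (default)
    n_keys=128,                # atoms each query atom attends to (default)
    trunk_conditioning=False,  # set True to condition on trunk single/pair representations
)
# Its forward pass returns an AtomAttentionEncoderOutput NamedTuple holding the
# per-token embedding plus the atom-level skip tensors documented below.
```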
Source code:
```python
from typing import NamedTuple

import torch
from torch import nn

# LinearNoBias and LayerNorm are custom primitives defined elsewhere in the codebase.


class AtomAttentionEncoderOutput(NamedTuple):
    """Structured output class for AtomAttentionEncoder."""
    token_single: torch.Tensor  # (bs, n_tokens, c_token)
    atom_single_skip_repr: torch.Tensor  # (bs, n_atoms, c_atom)
    atom_single_skip_proj: torch.Tensor  # (bs, n_atoms, c_atom)
    atom_pair_skip_repr: torch.Tensor  # (bs, n_atoms // n_queries, n_queries, n_keys, c_atompair)


class AtomAttentionEncoder(nn.Module):
    def __init__(
            self,
            c_token: int,
            c_atom: int = 128,
            c_atompair: int = 16,
            c_trunk_pair: int = 16,
            no_blocks: int = 3,
            no_heads: int = 4,
            dropout=0.0,
            n_queries: int = 32,
            n_keys: int = 128,
            trunk_conditioning: bool = False,
            clear_cache_between_blocks: bool = False
    ):
        """Initialize the AtomAttentionEncoder module.

        Args:
            c_token:
                The number of channels for the token representation.
            c_atom:
                The number of channels for the atom representation. Defaults to 128.
            c_atompair:
                The number of channels for the pair representation. Defaults to 16.
            c_trunk_pair:
                The number of channels for the trunk pair representation. Defaults to 16.
            no_blocks:
                Number of blocks in AtomTransformer. Defaults to 3.
            no_heads:
                Number of parallel attention heads. Note that c_atom will be split across no_heads
                (i.e. each head will have dimension c_atom // no_heads).
            dropout:
                Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
            n_queries:
                The size of the atom window. Defaults to 32.
            n_keys:
                Number of atoms each atom attends to in local sequence space. Defaults to 128.
            trunk_conditioning:
                Whether to condition the atom single and atom-pair representation on the trunk.
                Defaults to False.
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the stack.
                Slows down each block but can reduce fragmentation.
        """
        super().__init__()
        self.no_blocks = no_blocks
        self.c_token = c_token
        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_trunk_pair = c_trunk_pair
        self.no_heads = no_heads
        self.dropout = dropout
        self.n_queries = n_queries
        self.n_keys = n_keys
        self.trunk_conditioning = trunk_conditioning
        self.clear_cache_between_blocks = clear_cache_between_blocks

        # Embedding per-atom metadata, concat(ref_pos, ref_charge, ref_mask, ref_element, ref_atom_name_chars)
        self.linear_atom_embedding = LinearNoBias(3 + 1 + 1 + 4 + 4, c_atom)  # 128, * 64

        # Embedding offsets between atom reference positions
        self.linear_atom_offsets = LinearNoBias(3, c_atompair)
        self.linear_atom_distances = LinearNoBias(1, c_atompair)

        # Embedding the valid mask
        self.linear_mask = LinearNoBias(1, c_atompair)

        if trunk_conditioning:
            self.proj_trunk_single = nn.Sequential(
                LayerNorm(c_token),
                LinearNoBias(c_token, c_atom)
            )
            self.proj_trunk_pair = nn.Sequential(
                LayerNorm(c_trunk_pair),
                LinearNoBias(c_trunk_pair, c_atompair)
            )
            self.linear_noisy_pos = LinearNoBias(3, c_atom)

        # Adding the single conditioning to the pair representation
        self.linear_single_to_pair_row = LinearNoBias(c_atom, c_atompair, init='relu')
        self.linear_single_to_pair_col = LinearNoBias(c_atom, c_atompair, init='relu')

        # Small MLP on the pair activations
        self.pair_mlp = nn
```
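The constructor above only defines the embedding layers. For orientation, here is a rough sketch of how those layers are typically applied at the start of the forward pass, following the corresponding steps of AlphaFold3's Algorithm 5. It is not the repository's actual forward method: the argument names (`ref_feats`, `ref_pos`, `ref_space_uid`) are assumed, and the sketch builds a dense n_atoms × n_atoms pair tensor for clarity, whereas the real code works on windowed n_queries × n_keys blocks, as the atom_pair_skip_repr shape indicates.

```python
def sketch_embed_atom_features(encoder, ref_feats, ref_pos, ref_space_uid):
    """Rough sketch of the initial embedding steps (AF3 Algorithm 5), dense version.

    ref_feats:      (bs, n_atoms, 13) concat of ref_pos, ref_charge, ref_mask,
                    ref_element, ref_atom_name_chars (matching linear_atom_embedding)
    ref_pos:        (bs, n_atoms, 3) reference conformer coordinates
    ref_space_uid:  (bs, n_atoms) id of the reference frame each atom belongs to
    """
    # Per-atom single embedding from the concatenated reference metadata.
    atom_single = encoder.linear_atom_embedding(ref_feats)                # (bs, n_atoms, c_atom)

    # Offsets between reference positions, valid only for atom pairs that
    # share a reference space (same residue / ligand).
    offsets = ref_pos[..., :, None, :] - ref_pos[..., None, :, :]         # (bs, n_atoms, n_atoms, 3)
    valid = (ref_space_uid[..., :, None] == ref_space_uid[..., None, :])
    valid = valid.unsqueeze(-1).to(offsets.dtype)                         # (bs, n_atoms, n_atoms, 1)

    atom_pair = encoder.linear_atom_offsets(offsets) * valid
    # Inverse-square-distance embedding of the same offsets.
    inv_sq_dist = 1.0 / (1.0 + offsets.pow(2).sum(dim=-1, keepdim=True))
    atom_pair = atom_pair + encoder.linear_atom_distances(inv_sq_dist) * valid
    # Embed the validity mask itself.
    atom_pair = atom_pair + encoder.linear_mask(valid) * valid            # (bs, n_atoms, n_atoms, c_atompair)

    return atom_single, atom_pair
```

In the module itself, the single conditioning is subsequently added to the pair representation through linear_single_to_pair_row and linear_single_to_pair_col, followed by the small pair MLP, before the AtomTransformer blocks run.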