# TriangleMultiplication is a more symmetric and cheaper alternative module to TriangleAttention.
import haiku
import haiku as hk
import jax
import jax.numpy as jnp


def _layer_norm(axis=-1, name='layer_norm'):
  """Returns a LayerNorm module with unit scale / zero offset initialization.

  Args:
    axis: Axis (or axes) to normalize over; also used as the parameter axis.
    name: Haiku module name.

  Returns:
    A `common_modules.LayerNorm` instance.
  """
  return common_modules.LayerNorm(
      axis=axis,
      create_scale=True,
      create_offset=True,
      eps=1e-5,
      use_fast_variance=True,
      scale_init=hk.initializers.Constant(1.),
      offset_init=hk.initializers.Constant(0.),
      param_axis=axis,
      name=name)


class TriangleMultiplication(hk.Module):
  """Triangle multiplication layer ("outgoing" or "incoming").

  Jumper et al. (2021) Suppl. Alg. 11 "TriangleMultiplicationOutgoing"
  Jumper et al. (2021) Suppl. Alg. 12 "TriangleMultiplicationIncoming"
  """

  def __init__(self, config, global_config, name='triangle_multiplication'):
    """Builds the module.

    Args:
      config: Module config; read fields are `fuse_projection_weights`,
        `num_intermediate_channel` and `equation`.
      global_config: Global model config (used for final-layer initializers).
      name: Haiku module name.
    """
    super().__init__(name=name)
    self.config = config
    self.global_config = global_config

  def __call__(self, left_act, left_mask, is_training=True):
    """Builds TriangleMultiplication module.

    Arguments:
      left_act: Pair activations, shape [N_res, N_res, c_z].
      left_mask: Pair mask, shape [N_res, N_res].
      is_training: Whether the module is in training mode (unused).

    Returns:
      Outputs, same shape/type as left_act.
    """
    del is_training

    # Two variants share the same math but differ in parameter layout;
    # the fused variant packs left/right projections into single Linears.
    if self.config.fuse_projection_weights:
      return self._fused_triangle_multiplication(left_act, left_mask)
    else:
      return self._triangle_multiplication(left_act, left_mask)

  # @hk.transparent is a Haiku decorator that makes the wrapped method
  # "transparent": its parameters live in this module's scope rather than
  # in a nested sub-module scope, so helper methods can share parameters.
  @hk.transparent
  def _triangle_multiplication(self, left_act, left_mask):
    """Implementation of TriangleMultiplication used in AF2 and AF-M <2.3."""
    c = self.config
    gc = self.global_config
    mask = left_mask[..., None]

    act = common_modules.LayerNorm(axis=[-1], create_scale=True,
                                   create_offset=True,
                                   name='layer_norm_input')(left_act)
    input_act = act

    left_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='left_projection')
    left_proj_act = mask * left_projection(act)

    right_projection = common_modules.Linear(
        c.num_intermediate_channel,
        name='right_projection')
    right_proj_act = mask * right_projection(act)

    # Gates are initialized to open (bias_init=1.) with final-layer init
    # on the weights so the gate starts close to the identity.
    left_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='left_gate')(act))

    right_gate_values = jax.nn.sigmoid(common_modules.Linear(
        c.num_intermediate_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='right_gate')(act))

    left_proj_act *= left_gate_values
    right_proj_act *= right_gate_values

    # "Outgoing" edges equation: 'ikc,jkc->ijc'
    # "Incoming" edges equation: 'kjc,kic->ijc'
    # Note on the Suppl. Alg. 11 & 12 notation:
    # For the "outgoing" edges, a = left_proj_act and b = right_proj_act
    # For the "incoming" edges, it's swapped:
    #   b = left_proj_act and a = right_proj_act
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = common_modules.LayerNorm(
        axis=[-1],
        create_scale=True,
        create_offset=True,
        name='center_layer_norm')(act)

    output_channel = int(input_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    # Output gate is computed from the (normalized) input activations.
    gate_values = jax.nn.sigmoid(common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(input_act))
    act *= gate_values

    return act

  @hk.transparent
  def _fused_triangle_multiplication(self, left_act, left_mask):
    """TriangleMultiplication with fused projection weights."""
    mask = left_mask[..., None]
    c = self.config
    gc = self.global_config

    left_act = _layer_norm(axis=-1, name='left_norm_input')(left_act)

    # Both left and right projections are fused into projection.
    projection = common_modules.Linear(
        2 * c.num_intermediate_channel, name='projection')
    proj_act = mask * projection(left_act)

    # Both left + right gate are fused into gate_values.
    gate_values = common_modules.Linear(
        2 * c.num_intermediate_channel,
        name='gate',
        bias_init=1.,
        initializer=utils.final_init(gc))(left_act)
    proj_act *= jax.nn.sigmoid(gate_values)

    # Split the fused projection back into left/right halves.
    left_proj_act = proj_act[:, :, :c.num_intermediate_channel]
    right_proj_act = proj_act[:, :, c.num_intermediate_channel:]
    act = jnp.einsum(c.equation, left_proj_act, right_proj_act)

    act = _layer_norm(axis=-1, name='center_norm')(act)

    output_channel = int(left_act.shape[-1])

    act = common_modules.Linear(
        output_channel,
        initializer=utils.final_init(gc),
        name='output_projection')(act)

    gate_values = common_modules.Linear(
        output_channel,
        bias_init=1.,
        initializer=utils.final_init(gc),
        name='gating_linear')(left_act)
    act *= jax.nn.sigmoid(gate_values)

    return act