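In the AlphaFold3 protein_dataset module, the _get_masked_sequence method of the ProteinDataset class generates a mask for the residues that need to be predicted. The mask is a binary tensor in which 1 marks the positions to predict and 0 marks everything else. The method selects the residues to mask according to several parameters, including mask_whole_chains, mask_frac, lower_limit, upper_limit, mask_sequential, and force_binding_sites_frac.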
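To make the selection rules concrete, here is a minimal pure-Python sketch of how the mask size could be chosen and a sequential mask laid out. It is not the library's implementation: the function name sketch_sequential_mask, the default limits, and the choice of a uniformly random start position are all assumptions made for illustration. It also uses plain lists instead of tensors and ignores the spherical and binding-site cases.

```python
import random


def sketch_sequential_mask(length, mask_whole_chain=False, mask_frac=None,
                           lower_limit=15, upper_limit=100, rng=None):
    """Return a list of 0/1 flags of size `length`; 1 marks a residue to predict.

    Hypothetical helper: the default limits are placeholders, not values
    taken from the original code.
    """
    rng = rng or random.Random()

    # Case 1: the whole chain is masked.
    if mask_whole_chain:
        return [1] * length

    # Case 2: a fixed fraction of the chain length is masked.
    if mask_frac is not None:
        n_mask = int(mask_frac * length)
    # Case 3: the mask size is drawn uniformly from [lower_limit, upper_limit].
    else:
        n_mask = rng.randint(lower_limit, upper_limit)  # both ends inclusive

    # Never mask more residues than the chain contains.
    n_mask = max(0, min(n_mask, length))

    # Assumption: a contiguous window with a uniformly random start stands in
    # for "masking based on the order in the sequence".
    start = rng.randint(0, length - n_mask)
    return [1 if start <= i < start + n_mask else 0 for i in range(length)]
```

Calling sketch_sequential_mask(8, mask_whole_chain=True) returns a list of eight ones, while passing mask_frac=0.25 for a chain of length 20 marks a contiguous run of five residues at a random offset.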
Source code:
    def _get_masked_sequence(
        self,
        data,
    ):
        """Get the mask for the residues that need to be predicted.

        Depending on the parameters the residues are selected as follows:
        - if `mask_whole_chains` is `True`, the whole chain is masked
        - if `mask_frac` is given, the number of residues to mask is `mask_frac` times the length of the chain,
        - otherwise, the number of residues to mask is sampled uniformly from the range [`lower_limit`, `upper_limit`].

        If `mask_sequential` is `True`, the residues are masked based on the order in the sequence, otherwise a
        spherical mask is applied based on the coordinates.

        If `force_binding_sites_frac` > 0 and `mask_whole_chains` is `False`, in the fraction of cases where a chain
        from a polymer is sampled, the center of the masked region will be forced to be in a binding site.

        Parameters
        ----------
        data : dict
            an entry generated by `ProteinDataset`

        Returns
        -------
        chain_M : torch.Tensor
            a `(B, L)` shaped binary tensor where 1 denotes the part that needs to be predicted and
            0 is everything else

        """
        if "cdr" in data and "cdr_id" in data:
            chain_M = torch.zeros_like(data["cdr"])
            if self.mask_all_cdrs:
                chain_M = data["cdr"] != CDR_REVERSE["-"]
            else:
                chain_M = data["cdr"] == data["cdr_id"]
        else:
            chain_M = torch.zeros_like(data["S"])
            chain_index = data["chain_id"]
            chain_bool = data["chain_encoding_all"] == chain_index
            if self.mask_whole_chains:
                chain_M[chain_bool] = 1
            else:
                chains = torch.unique(data["chain_encoding_all"])
                chain_start = torch.where(chain_bool)[0][0]
                chain = data["X"][chain_bool]
                res_i = None
                interface = []
                non_masked_interface = []
                if len(chains) > 1 and self.force_binding_sites_frac > 0:
                    if random.uniform(0, 1) <= self.force_binding_sites_frac:
                        X_copy = data["X"]
                        i_indices = (chain_bool == 0).nonzero().flatten()  # global