欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/132602155
OpenFold Multimer 在训练过程中加载数据时,需要先将 MSA 与 Template 信息转换成 Feature,再进行训练,因此速度较慢。通过修改数据集类 OpenFoldSingleMultimerDataset 的 __getitem__ 方法,直接加载预处理好的 Feature,可以加速训练过程。
1. 准备训练数据
在训练过程中,需要读取 mmcif_cache.json 文件,数据结构如下:
{"4ewn": {"release_date": "2012-12-05","chain_ids": ["D"],"seqs": ["MLAKRI..."],"no_chains": 1,"resolution": 1.9},"5m9r": {"release_date": "2017-02-22","chain_ids": ["A","B"],"seqs": ["MQDNS...","MQDNS..."],"no_chains": 2,"resolution": 1.44},
#...
}
当前的训练数据格式,例如 train_200_mini.csv,如下:
pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
7m5z,"A,B",3.06,2021-10-06,"LEDVV...,QNKLE...","263,264","protein,protein",[pdb_path]/structures/m5/pdb7m5z.ent.gz
7k05,"A,B",1.85,2021-10-06,"MSFPP...,MSFPP...","200,200","protein,protein",[pdb_path]/structures/k0/pdb7k05.ent.gz
# ...
同时,需要将每个样本的 feature 路径(feature_folder 字段)也加入到 mmcif_cache.json 中,这样训练时可以直接读取预先抽取好的特征文件。特征根目录如下:
[your folder]/multimer_train/features
训练时,直接使用特征文件夹中已经预处理好的特征 features.pkl 即可:
# 单个文件夹内容
chain_id_map.json
features.pkl
sequences.fasta
训练文件的转换命令,如下:
python openfold_scripts/main_mmcif_cache_transfer.py -i data/train_200_mini.csv -f [your folder]/multimer_train/features -o mydata/openfold/mmcif_cache_mini.json
源码如下:
#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/8/31
"""
import argparse
import json
import os
import sys
from pathlib import Path

import pandas as pd

try:
    from tqdm import tqdm
except ImportError:
    # The progress bar is cosmetic; fall back to plain iteration without it.
    def tqdm(iterable, *args, **kwargs):
        """Return *iterable* unchanged when tqdm is not installed."""
        return iterable

# Put the parent of this script's directory on sys.path.
p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)


class MmcifCacheTransfer(object):
    """Convert a training CSV into OpenFold's mmcif_cache.json format."""

    @staticmethod
    def process(input_path, feature_dir, output_path):
        """Write one cache entry per CSV row to *output_path* as JSON.

        The CSV is expected to have the columns
        ``pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath``;
        ``chain_id`` and ``seq`` hold comma-separated per-chain values.

        Args:
            input_path: Path of the training CSV.
            feature_dir: Root directory of the preprocessed features. Each
                entry's folder is ``<feature_dir>/<pdb_id[1:3]>/pdb<pdb_id>_<chains>``.
            output_path: Path of the JSON file to write.

        Raises:
            FileNotFoundError: If *input_path* is not an existing file.
        """
        print(f"[Info] 输入文件: {input_path}")
        print(f"[Info] 特征文件夹: {feature_dir}")
        print(f"[Info] 输出文件: {output_path}")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f"input CSV not found: {input_path}")

        # Read identifier/sequence columns as text so that values such as the
        # PDB ID "1e10" are not parsed as floating-point numbers.
        text_columns = ("pdb_id", "chain_id", "release_date", "seq")
        df = pd.read_csv(input_path, dtype={col: str for col in text_columns})
        print(f"[Info] 输入样本: {len(df)}")

        mmcif_cache_dict = {}
        for _, row in tqdm(df.iterrows(), "[Info] pdb"):
            pdb_id = row["pdb_id"]
            chain_ids = row["chain_id"].split(",")
            feature_folder = os.path.join(
                feature_dir, pdb_id[1:3], f"pdb{pdb_id}_{''.join(chain_ids)}"
            )
            mmcif_cache_dict[pdb_id] = {
                "release_date": str(row["release_date"]),
                "chain_ids": chain_ids,
                "seqs": row["seq"].split(","),
                "no_chains": len(chain_ids),
                "resolution": float(row["resolution"]),
                "feature_folder": feature_folder,
            }

        # Create the destination directory so a fresh output location works.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as fp:
            json.dump(mmcif_cache_dict, fp, indent=4)
        print(f"[Info] 全部处理完成: {output_path}")


def main():
    """Parse the command-line arguments and run the conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        help="the input file path.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-f",
        "--feature-dir",
        help="the preprocess feature dir.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-o",
        "--output-path",
        help="the output file path.",
        type=Path,
        required=True,
    )
    args = parser.parse_args()

    # process() validates that the input file exists.
    mct = MmcifCacheTransfer()
    mct.process(
        str(args.input_path), str(args.feature_dir), str(args.output_path)
    )


if __name__ == '__main__':
    main()
2. 加载训练数据
OpenFold Multimer 的特征读取逻辑,位于 openfold/data/data_modules.py 的 OpenFoldSingleMultimerDataset 类中,即:
# Upstream behaviour: in train/eval mode, locate the structure file for
# mmcif_id and parse it; otherwise build features from a FASTA file.
if self.mode == 'train' or self.mode == 'eval':
    path = os.path.join(self.data_dir, f"{mmcif_id}")
    ext = None
    # Pick the first supported extension for which a file exists on disk.
    for e in self.supported_exts:
        if os.path.exists(path + e):
            ext = e
            break
    if ext is None:
        raise ValueError("Invalid file type")
    # TODO: Add pdb and core exts to data_pipeline for multimer
    path += ext
    if ext == ".cif":
        data = self._parse_mmcif(path, mmcif_id, self.alignment_dir, alignment_index)
    else:
        raise ValueError("Extension branch missing")
else:
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(fasta_path=path,alignment_dir=self.alignment_dir)
修改成直接加载 Feature 的形式,即:
# NOTE: this excerpt uses `pickle`, so data_modules.py must `import pickle`.
if self.mode == 'train' or self.mode == 'eval':
    # Train/eval: load the preprocessed features for this entry instead of
    # parsing the structure file.
    cache_entry = self.mmcif_data_cache[mmcif_id]
    if "feature_folder" not in cache_entry:
        # Same exception type as the bare lookup, but with an actionable message.
        raise KeyError(
            f"mmcif cache entry '{mmcif_id}' has no 'feature_folder' field"
        )
    feat_path = os.path.join(cache_entry["feature_folder"], "features.pkl")
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # feature_folder at feature files you generated yourself.
    with open(feat_path, "rb") as f:
        data = dict(pickle.load(f))
else:
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir,
    )
同时,还需要修改训练数据总数:
def __len__(self):
    """Return the number of entries in ``self.mmcif_data_cache``.

    All samples are provided by ``mmcif_data_cache``, so its size replaces
    the previous ``len(self._chain_ids)``.
    """
    # `len(self.mmcif_data_cache.keys)` would take len() of a bound method
    # and raise TypeError; the mapping itself reports the entry count.
    return len(self.mmcif_data_cache)
3. 配置模型训练
模型训练的参数,如下:
python3 train_openfold.py \--train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \--train_alignment_dir mydata/alignment_dir/ \--train_mmcif_data_cache_path [your folder]/multimer_train/openfold_cache/mmcif_cache_mini.json \--template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \--output_dir mydata/output_dir/ \--max_template_date "2021-10-10" \--config_preset "model_1_multimer_v3" \--template_release_dates_cache_path mmcif_cache.json \--precision bf16 \--gpus 1 \--replace_sampler_ddp=True \--seed 42 \--deepspeed_config_path deepspeed_config.json \--checkpoint_every_epoch \--obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat
模型训练占用显存较多,V100 目前无法支持,因此需要调低 crop_size 与 num_workers,以降低资源占用。配置位于 openfold/config.py 中,即:
# crop_size
elif "multimer" in name:
    c.update(multimer_config_update.copy_and_resolve_references())
    c.data.train.crop_size = 64  # TODO: for testing

# num_workers
"data_module": {
    "use_small_bfd": False,
    "data_loaders": {
        "batch_size": 1,
        # "num_workers": 16,
        "num_workers": 2,  # TODO: for testing
        "pin_memory": True,
    },
},
其中,crop_size = 64 占用显存约是 5141MiB
训练日志,如下:
Epoch 0: 0%| | 0/199 [00:00<?, ?it/s]INFO:openfold/data/data_modules.py:mmcif_id is: 7poc, idx: 148 and has 4 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7u49, idx: 97 and has 3 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7z7h, idx: 114 and has 6 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7nup, idx: 111 and has 4 chains
cum_loss: tensor([84.1698], device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>) losses: {'distogram': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'experimentally_resolved': tensor(0.6914, device='cuda:0'), 'fape': tensor(1.6598, device='cuda:0', dtype=torch.float64), 'plddt_loss': tensor(3.9062, device='cuda:0', dtype=torch.float64), 'masked_msa': tensor(3.0938, device='cuda:0'), 'supervised_chi': tensor(0.7941, device='cuda:0', dtype=torch.float64), 'violation': tensor(3.6495, device='cuda:0'), 'tm': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'chain_center_of_mass': tensor([1.3754], device='cuda:0', dtype=torch.float64), 'unscaled_loss': tensor([10.5212], device='cuda:0', dtype=torch.float64), 'loss': tensor([84.1698], device='cuda:0', dtype=torch.float64)}
Epoch 0: 1%|▉ | 1/199 [02:55<9:38:06, 175.18s/it, loss=84.2, v_num=]