bug概述:【踩坑记录📝】Removed shared tensor while saving.
简单来说,这个bug的危害是trainer.save()无法正确存储权重。这篇博文的作者也给出了两种处理方法,但要么要改transformers版本,要么要包裹Trainer类,太麻烦了。
我使用transformers版本4.38.0,deepspeed zero 0/1/2复现了这篇博文所述bug,研究得知bug的具体原因后,给出以下解决思路:
第一步:获得正确的state_dict
trainer调用_save方法时,默认入参state_dict=None,然后_save方法通过self.model.state_dict()获得state_dict,导致报错。原因是此时self.model是通过accelerator加载的,state_dict需要用self.accelerator.get_state_dict(self.model)的方式获得。
解决思路是提前在调用_save之前提前通过self.accelerator.get_state_dict(self.model)把state_dict取出,然后直接带state_dict入参来调用_save方法,实现权重存储。
我在这里借鉴了QWen-VL仓库finetune.py的写法,参考代码如下:
def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str, bias="none"):"""Collects the state dict and dump to disk."""# check if zero3 mode enabledif trainer.args.hf_deepspeed_config.config['zero_optimization']['stage'] == 3:state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()else:state_dict = trainer.accelerator.get_state_dict(trainer.model)if trainer.args.should_save and trainer.args.local_rank == 0:trainer._save(output_dir, state_dict=state_dict)
第二步:顺利把state_dict存在文件里
trainer在初始化时,args.save_safetensors要设置为false,来让_save方法内部调用model.save_pretrained时不存model.safetensors文件,避免这处调用报错。
第三步:顺利加载权重文件
以上两步做好后,可以保证正确的state_dict在pytorch_model.bin里存下来,而不至于丢失。
再次使用这份权重时,为了从pytorch_model.bin而不是model.safetensors加载权重,可以在from_pretrained的时候要设置use_safetensors=False;也可以把加载目录的model.safetensors直接删掉。
做好以上几步就可以完美解决这个bug.