By default, the most common way to resume training:
```python
if args.resume:
    checkpoint = torch.load(resume_path, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % resume_path)
    if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'evaluate') and args.evaluate):
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch'] + 1
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        print("With optim & sched!")
    del checkpoint
```
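The save/resume round trip above can be sketched end to end with a toy model. The names here (`model`, `optimizer`, `start_epoch`) are stand-ins for the real training objects, and the checkpoint goes to an in-memory buffer instead of a file:

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-ins for the real model and optimizer.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save a checkpoint in the same layout the resume code expects.
buffer = io.BytesIO()
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': 4}, buffer)
buffer.seek(0)

# Resume: restore model and optimizer, then continue from the next epoch.
checkpoint = torch.load(buffer, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
print(start_epoch)  # 5
```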
When resuming a model, some layers may be missing from the checkpoint, or you may have changed the dimensions of certain layers, causing `load_state_dict()` to fail. The fix is to ignore those layers and not load them.
Case 1:
Your model contains new parameters that the resumed checkpoint does not have.
Simply set the `strict` argument to `False`:
```python
model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
```
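A minimal, self-contained illustration of what `strict=False` does (the two model classes are made up for the demo): parameters the checkpoint lacks are simply skipped, and `load_state_dict` reports them in the returned `missing_keys`:

```python
import torch
import torch.nn as nn

class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)

class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)
        self.new_head = nn.Linear(2, 3)  # layer absent from the checkpoint

ckpt = OldModel().state_dict()
model = NewModel()

# strict=False: keys missing from the checkpoint are skipped, not fatal.
result = model.load_state_dict(ckpt, strict=False)
print(result.missing_keys)     # ['new_head.weight', 'new_head.bias']
print(result.unexpected_keys)  # []
```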
Case 2:
The shapes of some layers in your model have changed, while the resumed checkpoint still holds the original model. Ignore the layers whose dimensions mismatch and load only those whose dimensions agree:
1) Resuming a checkpoint saved under distributed training:
```python
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_state_dict = model.state_dict()
    # Filter out parameters whose shapes do not match
    # (e.g. when training image generation models with a different representation).
    filtered_state_dict = {}
    for k, v in checkpoint['model'].items():
        # Under distributed training or DataParallel, parameter names in the
        # model's state dict usually carry a "module." prefix.
        model_key = 'module.' + k
        if model_key in model_state_dict and model_state_dict[model_key].shape == v.shape:
            filtered_state_dict[model_key[7:]] = v
        else:
            model_shape = model_state_dict[model_key].shape if model_key in model_state_dict else 'not in model'
            print(f"Skipping parameter {k}: checkpoint shape {v.shape} vs model shape {model_shape}")
    # Load the filtered state dict.
    model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
    del checkpoint
```
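The `module.` prefix handling above can be seen in isolation. This sketch fakes a wrapped model's state dict by adding the prefix, then strips it again before loading into the unwrapped model (names are illustrative):

```python
import torch
import torch.nn as nn
from collections import OrderedDict

# Unwrapped model (what `model_without_ddp` refers to above).
model_without_ddp = nn.Linear(4, 2)

# Simulate the state dict of a DDP/DataParallel-wrapped model:
# every parameter name gains a "module." prefix.
wrapped_state = OrderedDict(('module.' + k, v)
                            for k, v in model_without_ddp.state_dict().items())
print(list(wrapped_state))  # ['module.weight', 'module.bias']

# Strip the prefix before loading into the unwrapped model.
stripped = {k[len('module.'):]: v for k, v in wrapped_state.items()}
model_without_ddp.load_state_dict(stripped)
```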
2) Resuming a normally saved checkpoint (this is usually all you need):
```python
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_state_dict = model_without_ddp.state_dict()
    # print(model_state_dict)  # debug: inspect the current keys and shapes
    # Filter out parameters whose shapes do not match
    # (e.g. when training image generation models with a different representation).
    filtered_state_dict = {}
    for k, v in checkpoint['model'].items():
        if k in model_state_dict and model_state_dict[k].shape == v.shape:
            filtered_state_dict[k] = v
        else:
            model_shape = model_state_dict[k].shape if k in model_state_dict else 'not in model'
            print(f"Skipping parameter {k}: checkpoint shape {v.shape} vs model shape {model_shape}")
    # Load the filtered state dict.
    model_without_ddp.load_state_dict(filtered_state_dict, strict=False)
    del checkpoint
```
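The shape-filtering step can be demonstrated end to end with two toy models whose head dimension differs (model classes invented for the demo). Only the layers with matching shapes are loaded; the changed head is skipped and shows up in `missing_keys`:

```python
import torch
import torch.nn as nn

class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)
        self.head = nn.Linear(2, 10)  # old output dimension

class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)
        self.head = nn.Linear(2, 5)   # output dimension changed

ckpt = OldModel().state_dict()
model = NewModel()
model_state_dict = model.state_dict()

# Keep only parameters whose shapes match the current model.
filtered = {k: v for k, v in ckpt.items()
            if k in model_state_dict and model_state_dict[k].shape == v.shape}
result = model.load_state_dict(filtered, strict=False)
print(sorted(filtered))             # ['backbone.bias', 'backbone.weight']
print(sorted(result.missing_keys))  # ['head.bias', 'head.weight']
```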
The optimizer state is not resumed above: when shapes mismatch, restoring the optimizer state is fiddly, and it is of little practical value anyway.