Loading model parameters without occupying GPU memory
see_memory_usage('message')
# 4. Load the checkpoint parameters
state_dict = torch.load('../train-output/' + args.model_name_or_path.split('/')[-1] + '/unet/diffusion_pytorch_model.bin', map_location='cpu')
# state_dict = torch.load('../train-output/' + args.model_name_or_path.split('/')[-1] + '/unet/diffusion_pytorch_model.bin')
Without map_location='cpu', the loaded parameters end up on the GPU: torch.load by default restores each tensor to the device it was saved from, so a checkpoint saved from CUDA allocates GPU memory the moment it is loaded.
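To confirm the load really stayed off the GPU, you can check the device of the loaded tensors directly. A minimal sketch, using a hypothetical local path in place of the full one above:

import torch

state_dict = torch.load('diffusion_pytorch_model.bin', map_location='cpu')  # hypothetical path
# The state dict values are plain tensors; with map_location='cpu' they should all report device 'cpu'
print(next(iter(state_dict.values())).device)  # -> cpu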
see_memory_usage('message')
You can monitor GPU memory usage by adding the following function:
import gc
import psutil
import torch

def see_memory_usage(message):
    # Python doesn't do real-time garbage collection, so do it explicitly to get a correct RAM report
    gc.collect()
    # Print message except when distributed but not rank 0
    print(message)
    print(f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        CA {round(torch.cuda.memory_reserved() / (1024 * 1024 * 1024), 2)} GB \
        Max_CA {round(torch.cuda.max_memory_reserved() / (1024 * 1024 * 1024))} GB ")
    vm_stats = psutil.virtual_memory()
    used_GB = round((vm_stats.total - vm_stats.available) / (1024**3), 2)
    print(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')
    # Reset the peak counters so the next call reports fresh peak stats
    if hasattr(torch.cuda, 'reset_peak_memory_stats'):
        torch.cuda.reset_peak_memory_stats()
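A typical pattern is to bracket the torch.load call with two see_memory_usage calls so the GPU/CPU deltas are visible; ckpt_path and the message strings below are placeholders:

see_memory_usage('before loading checkpoint')
state_dict = torch.load(ckpt_path, map_location='cpu')  # ckpt_path: hypothetical variable
see_memory_usage('after loading checkpoint')
# With map_location='cpu', MA/Max_MA should stay flat while CPU Virtual Memory usage grows.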