PyTorch | 保存和加载模型教程

点击上方“算法猿的成长”,选择“加为星标”

第一时间关注 AI 和 Python 知识

640?wx_fmt=jpeg

图片来自 Unsplash,作者: Jenny Caywood 

2019 年第 72 篇文章,总第 96 篇文章

总共 7000 字,建议收藏阅读

原题 | SAVING AND LOADING MODELS

作者 | Matthew Inkawhich

原文 | https://pytorch.org/tutorials/beginner/saving_loading_models.html

译者 | kbsc13("算法猿的成长"公众号作者)

声明 | 翻译是出于交流学习的目的,欢迎转载,但请保留本文出于,请勿用作商业或者非法用途

简介

本文主要介绍如何加载和保存 PyTorch 的模型。这里主要有三个核心函数:

  1. torch.save :把序列化的对象保存到硬盘。它利用了 Python 的 pickle 来实现序列化。模型、张量以及字典都可以用该函数进行保存;

  2. torch.load:采用 pickle 将反序列化的对象从存储中加载进来。

  3. torch.nn.Module.load_state_dict:采用一个反序列化的 state_dict加载一个模型的参数字典。

本文主要内容如下:

  • 什么是状态字典(state_dict)?

  • 预测时加载和保存模型

  • 加载和保存一个通用的检查点(Checkpoint)

  • 在同一个文件保存多个模型

  • 采用另一个模型的参数来预热模型(Warmstaring Model)

  • 不同设备下保存和加载模型

1. 什么是状态字典(state_dict)

PyTorch 中,一个模型(torch.nn.Module)的可学习参数(也就是权重和偏置值)是包含在模型参数(model.parameters())中的,一个状态字典就是一个简单的 Python 的字典,其键值对是每个网络层和其对应的参数张量。模型的状态字典只包含带有可学习参数的网络层(比如卷积层、全连接层等)和注册的缓存(batchnorm的 running_mean)。优化器对象(torch.optim)同样也是有一个状态字典,包含的优化器状态的信息以及使用的超参数。

由于状态字典也是 Python 的字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都很容易实现。

下面是一个简单的使用例子,例子来自:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

# Define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# Initialize model
model = TheModelClass()# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

上述代码先是简单定义一个 5 层的 CNN,然后分别打印模型的参数和优化器参数。

输出结果:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

2. 预测时加载和保存模型

加载/保存状态字典(推荐做法)

保存的代码:

torch.save(model.state_dict(), PATH)

加载的代码:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

当需要为预测保存一个模型的时候,只需要保存训练模型的可学习参数即可。采用 torch.save() 来保存模型的状态字典的做法可以更方便加载模型,这也是推荐这种做法的原因。

通常会用 .pt 或者 .pth 后缀来保存模型。

记住

  1. 在进行预测之前,必须调用 model.eval() 方法来将 dropout 和 batch normalization 层设置为验证模型。否则,只会生成前后不一致的预测结果。

  2. load_state_dict() 方法必须传入一个字典对象,而不是对象的保存路径,也就是说必须先反序列化字典对象,然后再调用该方法,也是例子中先采用 torch.load() ,而不是直接 model.load_state_dict(PATH)

加载/保存整个模型

保存

torch.save(model, PATH)

加载

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

保存和加载模型都是采用非常直观的语法并且都只需要几行代码即可实现。这种实现保存模型的做法将是采用 Python 的 pickle 模块来保存整个模型,这种做法的缺点就是序列化后的数据是属于特定的类和指定的字典结构,原因就是 pickle 并没有保存模型类别,而是保存一个包含该类的文件路径,因此,当在其他项目或者在 refactors 后采用都可能出现错误。

3. 加载和保存一个通用的检查点(Checkpoint)

保存的示例代码

torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)

加载的示例代码

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()

当保存一个通用的检查点(checkpoint)时,无论是用于继续训练还是预测,都需要保存更多的信息,不仅仅是 state_dict ,比如说优化器的 state_dict 也是非常重要的,它包含了用于模型训练时需要更新的参数和缓存信息,还可以保存的信息包括 epoch,即中断训练的批次,最后一次的训练 loss,额外的 torch.nn.Embedding 层等等。

上述保存代码就是介绍了如何保存这么多种信息,通过用一个字典来进行组织,然后继续调用 torch.save 方法,一般保存的文件后缀名是 .tar 。

加载代码也如上述代码所示,首先需要初始化模型和优化器,然后加载模型时分别调用 torch.load 加载对应的 state_dict 。然后通过不同的键来获取对应的数值。

加载完后,根据后续步骤,调用 model.eval() 用于预测,model.train() 用于恢复训练。

4. 在同一个文件保存多个模型

保存模型的示例代码

torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),...}, PATH)

加载模型的示例代码

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

当我们希望保存的是一个包含多个网络模型 torch.nn.Modules 的时候,比如 GAN、一个序列化模型,或者多个模型融合,实现的方法其实和保存一个通用的检查点的做法是一样的,同样采用一个字典来保持模型的 state_dict 和对应优化器的 state_dict 。除此之外,还可以继续保存其他相同的信息。

加载模型的示例代码如上述所示,和加载一个通用的检查点也是一样的,同样需要先初始化对应的模型和优化器。同样,保存的模型文件通常是以 .tar 作为后缀名。

5. 采用另一个模型的参数来预热模型(Warmstaring Model)

保存模型的示例代码

torch.save(modelA.state_dict(), PATH)

加载模型的示例代码

modelB = TheModelBClass(*args, **kwargs)
modelB.load_state_dict(torch.load(PATH), strict=False)

在之前迁移学习教程中也介绍了可以通过预训练模型来微调,加快模型训练速度和提高模型的精度。

这种做法通常是加载预训练模型的部分网络参数作为模型的初始化参数,然后可以加快模型的收敛速度。

加载预训练模型的代码如上述所示,其中设置参数 strict=False 表示忽略不匹配的网络层参数,因为通常我们都不会完全采用和预训练模型完全一样的网络,通常输出层的参数就会不一样。

当然,如果希望加载参数名不一样的参数,可以通过修改加载的模型对应的参数名字,这样参数名字匹配了就可以成功加载。

6. 不同设备下保存和加载模型

在GPU上保存模型,在 CPU 上加载模型

保存模型的示例代码

torch.save(model.state_dict(), PATH)

加载模型的示例代码

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

在 CPU 上加载在 GPU 上训练的模型,必须在调用 torch.load() 的时候,设置参数 map_location ,指定采用的设备是 torch.device('cpu'),这个做法会将张量都重新映射到 CPU 上。

在GPU上保存模型,在 GPU 上加载模型

保存模型的示例代码

torch.save(model.state_dict(), PATH)

加载模型的示例代码

device = torch.device('cuda')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH)
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model                     

在 GPU 上训练和加载模型,调用 torch.load() 加载模型后,还需要采用 model.to(torch.device('cuda')),将模型调用到 GPU 上,并且后续输入的张量都需要确保是在 GPU 上使用的,即也需要采用 my_tensor.to(device)

在CPU上保存,在GPU上加载模型

保存模型的示例代码

torch.save(model.state_dict(), PATH)

加载模型的示例代码

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

这次是 CPU 上训练模型,但在 GPU 上加载模型使用,那么就需要通过参数 map_location 指定设备。然后继续记得调用 model.to(torch.device('cuda'))

保存 torch.nn.DataParallel 模型

保存模型的示例代码

torch.save(model.module.state_dict(), PATH)

torch.nn.DataParallel 是用于实现多 GPU 并行的操作,保存模型的时候,是采用 model.module.state_dict()

加载模型的代码也是一样的,采用 torch.load() ,并可以放到指定的 GPU 显卡上。


完整的代码:

https://github.com/pytorch/tutorials/blob/master/beginner_source/saving_loading_models.py

欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!

640?wx_fmt=png

如果觉得不错,在看、转发就是对小编的一个支持!

推荐阅读

  • 快速入门Pytorch(1)--安装、张量以及梯度

  • 快速入门PyTorch(2)--如何构建一个神经网络

  • 快速入门PyTorch(3)--训练一个图片分类器和多 GPUs 训练

  • PyTorch系列 | 快速入门迁移学习

  • PyTorch系列 | 如何加快你的模型训练速度呢?

  • PyTorch 系列 | 数据加载和预处理教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/408540.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

都说dlib是人脸识别的神器,那到底能不能识破妖怪的伪装?

作者:盛光晓原文链接:https://blog.csdn.net/esa72ya/article/details/89189987众所周知,dlib是人脸识别的利器,被广泛应用于行为检测、安防工程、表情分析等,甚至还有学术界的前沿老师将这一技术用于上课点名&#xf…

国内有哪些不错的CV(计算机视觉)团队

点击上方“算法猿的成长”,选择“加为星标”第一时间关注 AI 和 Python 知识来源:知乎问题对于初入 CV 领域的同学,如果可以加入一个不错的团队,有好的导师带着,同时还有可以请教的师兄师姐,会加快入门 CV …

数据全裸时代,你的隐私有多容易获取?

大家好我是痴海,一位转型做增长的爬虫师,由于工作的缘故,对于身边许多信息都非常敏感。上个月朋友圈有很多人都在晒四六级成绩,有人欢喜有人忧愁,而我却感受到深深的恐惧。2018 年腾讯手机管家在一个报告中公布了一个数…

单元测试的一些基本概念

我们(程序员)多多少少都写过单元测试,有的可能几年前写的几行代码(比如我), 姑且也算写过吧,但是有些东西还是不是很清楚,比如什么是单元测试?怎么才算是好的单元测试&am…

深度学习领域有哪些瓶颈

来源:知乎问题深度学习是近年来人工智能热潮的原因,它的出现在很多方面都作出了突破,包括在图像、NLP以及语音等领域都有很多问题取得很大的突破,但它目前也存在一些问题和瓶颈需要解决。量子位https://www.zhihu.com/question/40…

cnn调优总结

关注&置顶“算法猿的成长”每日8:30,干货速递!转载自 Charlotte数据挖掘资料来自网上,略有删改针对CNN优化的总结Systematic evaluation of CNN advances on the ImageNet使用没有 batchnorm 的 ELU 非线性或者有 batchnorm 的 ReLU。用类…

Android笔记之自定义Editext

1、重写EdiText类,下面是一个逐条显示下划线的Editext import android.content.Context; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Paint; import android.util.AttributeSet; import android.view.Gravity; import android.w…

程序员到底要不要接外包?

? “沉默王二” ,你值得星标的公众号之前写过一篇文章,题目叫做《窝在二线城市很难受,要杀回一线城市吗》,里面提到程序员接外包这件事,于是很多小伙伴就私下问我:二哥,我也想接外包&#xff0…

Github项目|几行代码即可实现人脸检测、目标检测的开源计算机视觉库

关注&置顶“算法猿的成长”每日8:30,干货速递!2019 年第 73 篇文章,总第 97 篇文章今天介绍一个简单、易用的开源计算机视觉库,名字是 cvlib,其 Github 地址:https://github.com/arunponnusamy/cvlib官…

认识迅雷界面引擎

UI开发的新时代----认识迅雷界面引擎 第一部分:交互开发技术概述软件产品的交互开发一直以来都不是一件令人愉悦的事情。首先,由于每个人编写的第一个图形应用程序就已经使用了一些交互开发技术,而且由于IDE工具的强大,容易总结出交互开发就是…

最棒的Chrome插件去哪找?这里有一份榜单

上个月给大家介绍了重大更新后的 扩展迷 Extfans 网站,当时也说到:可以把它当成是一个 Chrome 商店的镜像版,可以无障碍下载安装 Chrome 扩展。不得不说,在不能正常使用 Chrome 商店的情况下,多亏了扩展迷 Extfans 这样…

【原创】推荐广告入门:DeepCTR-Torch,基于深度学习的CTR预测算法库

在计算广告和推荐系统中,CTR预估一直是一个核心问题。无论在工业界还是学术界都是一个热点研究问题,近年来也有若干相关的算法竞赛陆续举办。本文介绍一个使用PyTorch编写的深度学习的点击率预测算法库DeepCTR-Torch,具有简洁易用、模块化和可…