Pytorch框架权重文件转onnx格式
代码案例
import torch
import torchvision.models as modelsmodel = models.resnet50()
model.load_state_dict(torch.load("./model/pytorch-resnet50.pth"))model.eval()
example_input = torch.randn(32, 3, 224, 224) # 根据模型输入要求的形状创建示例输入张量onnx_path = "model.onnx"
torch.onnx.export(model, example_input, onnx_path, opset_version=11)print("模型已成功转换为 ONNX 格式并保存为", onnx_path)
解析
导入所需的库:
import torch
import onnxruntime
加载 PyTorch 模型:
model = YourModel()
model.load_state_dict(torch.load('model.pth'))
model.eval()
将模型导出为 ONNX 格式:
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, 'model.onnx', opset_version=11)
dummy_input 是示例输入,用于确定模型的输入尺寸。这里使用的是基本的图像格式
- 1是一张图片
- 3是三层RGB
- 224*224代表照片的尺寸
验证 ONNX 模型:
onnx_model = onnx.load('model.onnx')
onnx.checker.check_model(onnx_model)
pth 权重文件与 onnx 权重文件的不同
PyTorch (.pth) 权重文件:
- PyTorch 模型权重文件主要用于在 PyTorch 框架中继续训练和推理模型。
- 保留了模型的完整定义,包括网络结构、层参数等信息,方便在 PyTorch 中进行微调和二次训练。
ONNX (.onnx) 权重文件:
- ONNX 格式主要用于模型部署和跨平台运行,如 C++、Java、JavaScript 等环境中。
- ONNX 文件是一种标准的机器学习模型交换格式,可以被不同框架使用。适合在嵌入式设备、移动设备等资源受限的环境中部署。