1.训练并保存自己的模型
保存的模型格式为:XXX.pth
torch.save(model, "./weight/last.pth")if best_acc <(validation_acc / len_val):torch.save(model, "./weight/best.pth")
2.转化为ONNX格式
2.1环境安装(window10)
pip install onnx
pip install onnxruntime#验证安装配置是否成功
import torch
print('PyTorch 版本', torch.__version__)import onnx
print('ONNX 版本', onnx.__version__)import onnxruntime as ort
print('ONNX Runtime 版本', ort.__version__)
2.2.pth格式转ONNX格式
import torch
from torchvision import models# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)model = torch.load('best.pth')
model = model.eval().to(device)
x = torch.randn(1, 3, 256, 256).to(device) #这里要构造一个数据,保证和自己输入的图片大小一致3*256*256
output = model(x) #output.shape = torch.Size([1, 10]) 这是一个10分类问题#Pytorch模型转ONNX模型
x = torch.randn(1, 3, 256, 256).to(device)with torch.no_grad():torch.onnx.export(model, # 要转换的模型x, # 模型的任意一组输入'best.onnx', # 导出的 ONNX 文件名opset_version=11, # ONNX 算子集版本input_names=['input'], # 输入 Tensor 的名称(自己起名字)output_names=['output'] # 输出 Tensor 的名称(自己起名字)import onnx# 读取 ONNX 模型
onnx_model = onnx.load('resnet18_fruit30.onnx')# 检查模型格式是否正确
onnx.checker.check_model(onnx_model)
print('无报错,onnx模型载入成功')
这是project中就出现了“best.onnx”文件,表示转化ONNX格式成功!
3.可视化实时检测
3.1在PC电脑端查看
3.1.1环境安装(待补充)
pip install onnxruntime
需要提前保存一个类别ID和类别名称对应的文件
3.1.2 摄像头实时捕捉并分类
import onnxruntime
import torch
from torchvision import transforms
import torch.nn.functional as F
import pandas as pd
import numpy as np
from PIL import Image, ImageFont, ImageDraw
import matplotlib.pyplot as plt# 导入中文字体,指定字号
font = ImageFont.truetype('SimHei.ttf', 32)#载入ONNX模型,获取ONNX Runtime推力器
ort_session = onnxruntime.InferenceSession('best.onnx')#载入类别和ID对应字典
idx_to_labels = np.load('idx_to_labels.npy', allow_pickle=True).item()# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(256),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])