pytorch搭建squeezenet网络的整套工程,及其转tensorrt进行cuda加速

本来,前辈们用caffe搭建了一个squeezenet的工程,用起来也还行,但考虑到caffe的停更后续转trt应用在工程上时可能会有版本的问题所以搭建了一个pytorch版本的。
以下的环境搭建不再细说,主要就是pyorch,其余的需要什么pip install什么。

网络搭建

squeezenet的网络结构及其具体的参数如下:
在这里插入图片描述
后续对着这张表进行查看每层的输出时偶然发现这张表有问题,一张224×224的图片经过7×7步长为2的卷积层时输出应该是109×109才对,而不是这个111×111。所以此处我猜测要不是卷积核的参数有问题,要不就是这个输出结果有问题。我对了下下面的结果,发现都是从这个111×111的结果得出来的,这个结果没问题;但是我又对了下原有caffe版本的第一个卷积层用的就是这个7×7/2的参数,卷积核也没问题。这就有点矛盾了…这张表出自作者原论文,论文也是发表在顶会上,按道理应该不会有错才对。才疏学浅,希望大家有知道咋回事的能告诉我一声,这里我就还是用这个卷积核的参数了。
在这里插入图片描述
squeezenet有以上三个版本,我对了下发现前辈用的是中间这个带有简单残差的结构,为了进行对比这里也就用这个结构进行搭建了。
如下为网络结构的代码:

import torch
import torch.nn as nnclass Fire(nn.Module):def __init__(self, in_channel, squzee_channel, out_channel):super().__init__()self.squeeze = nn.Sequential(nn.Conv2d(in_channel, squzee_channel, 1),nn.ReLU(inplace=True))self.expand_1x1 = nn.Sequential(nn.Conv2d(squzee_channel, out_channel, 1), nn.ReLU(inplace=True))self.expand_3x3 = nn.Sequential(nn.Conv2d(squzee_channel, out_channel, 3, padding=1),nn.ReLU(inplace=True))def forward(self, x):x = self.squeeze(x)x = torch.cat([self.expand_1x1(x),self.expand_3x3(x)], 1)return xclass SqueezeNet_caffe(nn.Module):"""mobile net with simple bypass"""def __init__(self, class_num=5):super().__init__()self.stem = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=96, kernel_size=7, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, ceil_mode=True))self.fire2 = Fire(96, 16, 64)self.fire3 = Fire(128, 16, 64)self.fire4 = Fire(128, 32, 128)self.fire5 = Fire(256, 32, 128)self.fire6 = Fire(256, 48, 192)self.fire7 = Fire(384, 48, 192)self.fire8 = Fire(384, 64, 256)self.fire9 = Fire(512, 64, 256)self.maxpool = nn.MaxPool2d(3, 2, ceil_mode=True)self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Conv2d(512, class_num, kernel_size=1),   nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1, 1))  )def forward(self, x):x = self.stem(x)f2 = self.fire2(x)f3 = self.fire3(f2) + f2f4 = self.fire4(f3)f4 = self.maxpool(f4)f5 = self.fire5(f4) + f4f6 = self.fire6(f5)f7 = self.fire7(f6) + f6f8 = self.fire8(f7)f8 = self.maxpool(f8)f9 = self.fire9(f8) + f8x = self.classifier(f9)x = x.view(x.size(0), -1)return xdef squeezenet_caffe(class_num=5):return SqueezeNet_caffe(class_num=class_num)

然后其余的整个工程代码就是pytorch搭建dataset、dataloader,每轮的前向、计算loss、反向传播等都是一个差不多的套路,就不在这里码出来了,直接放上链接,大家有需要可以直接下载(里面也集成了其他的分类网络)。

数据处理

dataset我用的是torchvision.datasets.ImageFolder,所以用目录名称作为数据集的label,目录结构如下:
在这里插入图片描述
将每一类的图片都放在对应的目录中,验证集以及测试集的数据集也是按照这样的格式。

运行命令

训练命令:

python train.py -net squeezenet_caffe -gpu -b 64 -t_data 训练集路径 -v_data 验证集路径 -imgsz 100

-net后面跟着是网络类型,都集成了如下的分类网络:
在这里插入图片描述
如果有n卡则-gpu使用gpu训练,-b是batch size,-imgsz是数据的input尺寸即resize的尺寸。
测试命令:

python test.py -net squeezenet_caffe -weights 训练好的模型路径 -gpu -b 64 -data 测试集路径 -imgsz 100

出现问题

一开始进行训练一切正常,到后面却忽然画风突变:
在这里插入图片描述
loss忽然大幅度上升,acc也同一时刻大幅度下降,然后数值不变呈斜率为0的一条直线。估计是梯度爆炸了(也是到这一步我先从网络结构找原因,对本文的第一张表一层一层对参数和结果才发现表中的问题),网络结构对完没问题,于是打印每个batch的梯度,顺便使用clip进行剪枝限定其最大阈值。

optimizer.zero_grad()
outputs = net(images)
loss = loss_function(outputs, labels)
loss.backward()grad_max = 0
grad_min = 10
for p in net.parameters():# 打印每个梯度的模,发现打印太多了一直刷屏所以改为下面的print最大最小值# print(p.grad.norm())gvalue = p.grad.norm()if gvalue > grad_max:grad_max = gvalueif gvalue < grad_min:grad_min = gvalue
print("grad_max:")
print(grad_max)
print("grad_min:")
print(grad_min)
# 将梯度的模clip到小于10的范围
torch.nn.utils.clip_grad_norm(p,10)optimizer.step()

按道理来说应该会有所改善,但结果是,训练几轮之后依旧出现这个问题。
但是,果然梯度在曲线异常的时候数值也是异常的:
在这里插入图片描述
刚开始正常学习的时候梯度值基本上都在e-1数量级的,曲线异常阶段梯度值都如图所示无限接近0,难怪不学习。
我们此时看一下tensorboard,我将梯度的最大最小值write进去,方便追踪:
在这里插入图片描述

可以发现在突变处梯度值忽然爆炸激增,猜测原因很可能是学习率太大了,动量振动幅度太大了跳出去跳不回来了。查看设置的学习率超参发现初始值果然太大了(0.1),于是改为0.01。再次运行后发现查看其tensorboard:
在这里插入图片描述
这回是正常的了。
但其实我放大查看了梯度爆炸点的梯度值:
在这里插入图片描述

发现其最大值没超过10,所以我上面的clip没起到作用,我如果将阈值改成2,结果如下:
在这里插入图片描述

发现起到了作用,但曲线没那么平滑,可能改成1或者再小一些效果会更好。但我觉得还是直接改学习率一劳永逸比较简单。

Pytorch模型转TensorRT模型

在训练了神经网络之后,TensorRT可以对网络进行压缩、优化以及运行时部署,并且没有框架的开销。TensorRT通过combines
layers,kernel优化选择,以及根据指定的精度执行归一化和转换成最优的matrix math方法,改善网络的延迟、吞吐量以及效率。
总之,通俗来说,就是训练的模型转trt后可以在n卡上高效推理,对于实际工程应用更加有优势。

首先将pth转onnx:

# pth->onnx->trtexec
# (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
import torchvision
import torch,os
from models.squeezenet_caffe import squeezenet_caffebatch_size = 1    # just a random numbercurrent_dir=os.path.dirname(os.path.abspath(__file__)) # 获取当前路径
device = 'cuda' if torch.cuda.is_available() else 'cpu'model = squeezenet_caffe().cuda()model_path='/data/cch/pytorch-cifar100-master/checkpoint/squeezenet_caffe/Monday_04_September_2023_11h_48m_33s/squeezenet_caffe-297-best.pth'  # cloth
state_dict = torch.load(model_path, map_location=device)
print(1)
# mew_state_dict = OrderedDict()
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in state_dict.items() if (k in model_dict and 'fc' not in k)}
model_dict.update(pretrained_dict)
print(2)
model.load_state_dict(model_dict)
model.eval()
print(3)
# output = model(data)# Input to the model
x = torch.randn(batch_size, 3, 100, 100, requires_grad=True)
x = x.cuda()
torch_out = model(x)# Export the model
torch.onnx.export(model,               # model being runx,                         # model input (or a tuple for multiple inputs)"/data/cch/pytorch-cifar100-master/checkpoint/squeezenet_caffe/Monday_04_September_2023_11h_48m_33s/squeezenet_caffe-297-best.onnx",   # where to save the model (can be a file or file-like object)export_params=True,        # store the trained parameter weights inside the model fileopset_version=10,          # the ONNX version to export the model todo_constant_folding=True,  # whether to execute constant folding for optimizationinput_names = ['input'],   # the model's input namesoutput_names = ['output'], # the model's output namesdynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes'output' : {0 : 'batch_size'}})

只需要修改一下输入输出的路径和输入的size即可。
然后是onnx转trt,这里需要自己先安装搭建好tensorrt的环境(环境搭建可能会有点复杂需要编译,有时间单独出一个详细的搭建过程),然后在tensorrt工程下的bin目录下运行命令:

./trtexec --onnx=/data/.../best.onnx --saveEngine=/data.../best.trt --workspace=6000

TensorRT可以提供workspace作为每层网络执行时的临时存储空间,该空间是共享的以减少显存占用(单位是M)。具体的原理可以参考这篇。

前向推理

代码如下:

# 动态推理
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit
import numpy as np
import torchvision.transforms as transforms
from PIL import Imagedef load_engine(engine_path):# TRT_LOGGER = trt.Logger(trt.Logger.WARNING)  # INFOTRT_LOGGER = trt.Logger(trt.Logger.ERROR)print('---')print(trt.Runtime(TRT_LOGGER))print('---')with open(engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:return runtime.deserialize_cuda_engine(f.read())# 2. 读取数据,数据处理为可以和网络结构输入对应起来的的shape,数据可增加预处理
def get_test_transform():return transforms.Compose([transforms.Resize([100, 100]),transforms.ToTensor(),# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),transforms.Normalize(mean=[0.4796262, 0.4549252, 0.43396652], std=[0.27888104, 0.28492442, 0.27168077])])image = Image.open('/data/.../dog.jpg') 
image = get_test_transform()(image)
image = image.unsqueeze_(0) # -> NCHW, 1,3,224,224
print("input img mean {} and std {}".format(image.mean(), image.std()))
image =  np.array(image)path = '/data/.../squeezenet_caffe-297-best.trt'
# 1. 建立模型,构建上下文管理器
engine = load_engine(path)
print(engine)
context = engine.create_execution_context()
context.active_optimization_profile = 0# 3.分配内存空间,并进行数据cpu到gpu的拷贝
# 动态尺寸,每次都要set一下模型输入的shape,0代表的就是输入,输出根据具体的网络结构而定,可以是0,1,2,3...其中的某个头。
context.set_binding_shape(0, image.shape)
d_input = cuda.mem_alloc(image.nbytes)  # 分配输入的内存。
output_shape = context.get_binding_shape(1)
buffer = np.empty(output_shape, dtype=np.float32)
d_output = cuda.mem_alloc(buffer.nbytes)  # 分配输出内存。
cuda.memcpy_htod(d_input, image)
bindings = [d_input, d_output]# 4.进行推理,并将结果从gpu拷贝到cpu。
context.execute_v2(bindings)  # 可异步和同步
cuda.memcpy_dtoh(buffer, d_output)
output = buffer.reshape(output_shape)
y_pred_binary = np.argmax(output, axis=1)
print(y_pred_binary[0])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/73753.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Gradle下载库速度过慢解决办法

最近搞了个Gradle的项目&#xff0c;项目下载依赖库太慢了&#xff0c;于是… Gradle下载库速度过慢的问题可能由多种原因导致&#xff0c;以下是一些可能的解决方案&#xff1a; 1、使用国内镜像站点&#xff1a; 你可以改变Gradle的配置&#xff0c;使用国内的镜像站点来下…

go开发之个人微信的开发

简要描述&#xff1a; 检测好友状态 请求URL&#xff1a; http://域名地址/checkZombie 请求方式&#xff1a; POST 请求头Headers&#xff1a; Content-Type&#xff1a;application/jsonAuthorization&#xff1a;login接口返回 参数&#xff1a; 参数名必选类型说明…

SpringCloudAlibaba Gateway(一)简单集成

SpringCloudAlibaba Gateway(一)简单集成 随着服务模块的增加&#xff0c;一定会产生多个接口地址&#xff0c;那么客户端调用多个接口只能使用多个地址&#xff0c;维护多个地址是很不方便的&#xff0c;这个时候就需要统一服务地址。同时也可以进行统一认证鉴权的需求。那么服…

go语言基础操作---七

socket简单介绍—套接字编程 什么是Socket Socket&#xff0c;英文含义是【插座、插孔】&#xff0c;一般称之为套接字&#xff0c;用于描述IP地址和端口。可以实现不同程序间的数据通信。 Socket起源于Unix&#xff0c;而Unix基本哲学之一就是“一切皆文件”&#xff0c;都可…

[移动通讯]【Carrier Aggregation in LTE】【 Log analysis-2】

前言&#xff1a; 接 [移动通讯]【Carrier Aggregation in LTE】【 Theory Log analysis-1】 这里面 主要讲解一下日志分析 目录&#xff1a; 总体流程 UE Capbaility Information MeasurementReport RRC Connection Reconfiguration RRCConnectionReconfiguration…

开源药店商城系统源码比较:哪个适合你的药品电商业务

在构建药品电商业务时&#xff0c;选择适合的药店商城系统源码是至关重要的决策之一。开源药店商城系统源码提供了快速入门的机会&#xff0c;但在选择之前&#xff0c;您需要仔细考虑您的需求、技术要求和可扩展性。本文将比较几个流行的开源药店商城系统源码&#xff0c;以帮…

LSTM基础

LSTM 视频讲得非常好 https://www.bilibili.com/video/BV1644y1W7sD/?spm_id_from333.788&vd_source3b42b36e44d271f58e90f86679d77db7门的概念 过去&#xff0c;不过去&#xff0c;过去一部分 点乘&#xff0c;0 concatenation&#xff0c;pointwise LSTM RNN 上一…

C/C++之链表的建立

个人主页&#xff1a;点我进入主页 专栏分类&#xff1a;C语言初阶 C语言程序设计————KTV C语言小游戏 C语言进阶 C语言刷题 欢迎大家点赞&#xff0c;评论&#xff0c;收藏。 一起努力&#xff0c;一起奔赴大厂。 目录 1.头插 1.1简介 1.2代码实现头插 …

系统报错“由于找不到msvcp140.dll无法继续执行代码”的处理方法

我在使用电脑时&#xff0c;突然发现了一个错误提示&#xff1a;“无法启动程序&#xff0c;因为找不到msvcp140.dll文件”。这让我非常困惑&#xff0c;因为我确定这个文件应该存在于我的电脑上。但是电脑依然报错“由于找不到msvcp140.dll无法继续执行代码”&#xff0c;这个…

vue仿企微文档给页面加水印(水印内容可自定义,超简单)

1.在src下得到utils里新建一个文件watermark.js /** 水印添加方法 */let setWatermark (str1, str2) > {let id 1.23452384164.123412415if (document.getElementById(id) ! null) {document.body.removeChild(document.getElementById(id))}let can document.createE…

如何使用PyTorch训练LLM

推荐&#xff1a;使用 NSDT场景编辑器 快速搭建3D应用场景 像LangChain这样的库促进了上述端到端AI应用程序的实现。我们的教程介绍 LangChain for Data Engineering & Data Applications 概述了您可以使用 Langchain 做什么&#xff0c;包括 LangChain 解决的问题&#xf…

Visual Stadio使用技巧

C语言调试技巧 Debug 和 Release 的介绍 Debug&#xff1a;通常称为调试版本&#xff0c;它包含调试信息&#xff0c;并且不作任何优化&#xff0c;便于程序员调试&#xff08;可调试&#xff09;。 Release&#xff1a;通常称为发布版本&#xff0c;它往往时进行了各种优化&a…

SpringMVC_执行流程

四、SpringMVC执行流程 1.SpringMVC 常用组件 DispatcherServlet&#xff1a;前端控制器&#xff0c;用于对请求和响应进行统一处理HandlerMapping&#xff1a;处理器映射器&#xff0c;根据 url/method可以去找到具体的 Handler(Controller)Handler:具体处理器&#xff08;程…

【Springcloud】Actuator服务监控

【Springcloud】Actuator服务监控 【一】基本介绍【二】如何使用【三】端点分类【四】整合Admin-Ui【五】客户端配置【六】集成Nacos【七】登录认证【八】实时日志【九】动态日志【十】自定义通知 【一】基本介绍 &#xff08;1&#xff09;什么是服务监控 监视当前系统应用状…

【NLP的python库(02/4) 】:Spacy

一、说明 借助 Spacy&#xff0c;一个复杂的 NLP 库&#xff0c;可以使用用于各种 NLP 任务的不同训练模型。从标记化到词性标记再到实体识别&#xff0c;Spacy 还生成了精心设计的 Python 数据结构和强大的可视化效果。最重要的是&#xff0c;可以加载和微调不同的语言模型以适…

HTTP代理协议原理分析

HTTP代理协议是一种常见的网络协议&#xff0c;它可以在网络中传递HTTP协议的请求和响应。本文将介绍HTTP代理协议的分析和原理&#xff0c;包括HTTP代理的工作流程、HTTP代理的请求和响应格式、HTTP代理的优缺点等方面。 一、HTTP代理的工作流程 HTTP代理的工作流程如下&#…

通过idea实现springboot集成mybatys

概述 使用springboot 集成 mybatys后&#xff0c;通过http请求接口&#xff0c;使得通过http请求可以直接直接操作数据库&#xff1b; 完成后端功能框架&#xff1b;前端是准备上小程序&#xff0c;调用https的请求接口用。简单实现后端框架&#xff1b; 详细 springboot 集…

Elasticsearch,Logstash和Kibana安装部署(ELK Stack)

前言 当今数字化时代&#xff0c;信息的快速增长使得各类组织和企业面临着海量数据的处理和分析挑战。在这样的背景下&#xff0c;ELK Stack&#xff08;Elasticsearch、Logstash 和 Kibana&#xff09;作为一套强大的开源工具组合&#xff0c;成为了解决数据管理、搜索和可视…

linux并发服务器 —— 多线程并发(六)

线程概述 同一个程序中的所有线程均会独立执行相同程序&#xff0c;且共享同一份全局内存区域&#xff1b; 进程是CPU分配资源的最小单位&#xff0c;线程是操作系统调度执行的最小单位&#xff1b; Linux环境下&#xff0c;线程的本质就是进程&#xff1b; ps -Lf pid&…

简单了解ARP协议

目录 一、什么是ARP协议&#xff1f; 二、为什么需要ARP协议&#xff1f; 三、ARP报文格式 四、广播域是什么&#xff1f; 五、ARP缓存表是什么&#xff1f; 六、ARP的类型 6.1 ARP代理 6.2 免费ARP 七、不同网络设备收到ARP广播报文的处理规则 八、ARP工作机制原理 …