基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

文章目录

  • 数据集介绍:
  • resnet18模型代码
  • 加载数据集(Dataset与Dataloader)
  • 模型训练
  • 训练准确率及损失函数:
  • resnet18交通标志分类源码
  • yolov5检测与识别(交通标志)

本文共包含两部分,
第一部分是用resnet18对交通标志分类,仅仅只是交通标志分类
文末附有yolov5和resnet18结合的源码,yolov5复制检测交通标志位置,然后使用resnet18对交通标志进行分类。

数据集介绍:

本文使用的数据集共有6000多张,共包含58个类别。部分数据集如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

resnet18模型代码

使用pytorch自带的resnet18模型,代码如下:

from torchvision import models
import torch.nn as nn#加载resnet18模型
net=models.resnet18(weights=None)
#因为分类个数为58,所以需要修改模型最后一层全连接层
net.fc=nn.Linear(in_features=512, out_features=58, bias=True)
# print(net)

加载数据集(Dataset与Dataloader)

from torch.utils.data import Dataset,DataLoader
import numpy as np
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from PIL import Image
import os
from torchvision import transforms
import torch
import randoma=[]
class Mydata(Dataset):def __init__(self,lines,train=True):super(Mydata, self).__init__()self.lines=linesrandom.shuffle(self.lines)self.train=traindef __len__(self):return len(self.lines)def __getitem__(self, index):txts=self.lines[index].strip().split(';')src_path='pic/'+txts[0]w=int(txts[1])h=int(txts[2])x1=int(txts[3])y1=int(txts[4])x2=int(txts[5])y2=int(txts[6])new_x1=random.randint(0,x1)new_y1=random.randint(0,y1)new_x2=random.randint(x2,w-1)new_y2=random.randint(y2,h-1)lab=int(txts[7])# if lab in a:#     pass# else:a.append(lab)## a.sort()# print(len(a))# print(a)img = Image.open(src_path)img=np.array(img)[...,:3]img=img[new_y1:new_y2,new_x1:new_x2]#数据增强if self.train:img=self.get_random_data(img)else:img = cv2.resize(img, (128, 128))# cv2.imshow('img',img[...,::-1])# cv2.waitKey(0)#归一化img=(img/255.0).astype('float32')img=np.transpose(img,(2,0,1))img=torch.from_numpy(img)return img,labdef get_random_data(self,img):seq = iaa.Sequential([# iaa.Flipud(0.5),  # flip up and down (vertical)# iaa.Fliplr(0.5),  # flip left and right (horizontal)iaa.Multiply((0.8, 1.2)),  # change brightness, doesn't affect BBs(bounding boxes)iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值iaa.Crop(percent=(0, 0.2)),iaa.Affine(translate_px={"x": (0,15), "y": (0,15)},  # 平移scale=(0.8, 1.2),  # 尺度变换rotate=(-20, 20),mode='constant',cval=(125)),iaa.Resize(128)])img= seq(image=img)return img
if __name__ == '__main__':lines=open('data.txt','r').readlines()my=Mydata(lines=lines,train=True)myloader=DataLoader(dataset=my,batch_size=3,shuffle=False)for i,j in myloader:print(i.shape,j.shape)

模型训练

经过60个epoch训练后,模型准确率基本上达到百分百

from mymodel import net
from myDataset import Mydata
import random
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
import matplotlib.pylab as pltbatch_size=32
Epoch=60
lr=0.001lines=open('data.txt','r').readlines()
random.shuffle(lines)
val_lines=random.sample(lines,int(len(lines)*0.1))
train_lines=list(set(lines)-set(val_lines))train_data=Mydata(lines=train_lines)
val_data=Mydata(lines=val_lines,train=False)
train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False)num_train   = len(train_lines)
epoch_step  = num_train // batch_size
BCE_loss     = nn.CrossEntropyLoss()
optimizer  = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#获取学习率函数
def get_lr(optimizer):for param_group in optimizer.param_groups:return param_group['lr']
#计算准确率函数
def metric_func(pred,lab):_,index=torch.max(pred,dim=-1)acc=torch.where(index==lab,1.,0.).mean()return acc
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net=net.to(device)
#设置损失函数
loss_fun     = nn.CrossEntropyLoss()if __name__ == '__main__':T_acc=[]V_acc=[]T_loss=[]V_loss=[]# 设置迭代次数200次epoch_step = num_train // batch_sizefor epoch in range(1, Epoch + 1):net.train()total_loss = 0loss_sum = 0.0train_acc_sum=0.0with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:for step, (features, labels) in enumerate(train_loader, 1):features = features.to(device)labels = labels.to(device)batch_size = labels.size()[0]optimizer.zero_grad()predictions = net(features)loss = loss_fun(predictions, labels)loss.backward()optimizer.step()total_loss += losstrain_acc = metric_func(predictions, labels)train_acc_sum+=train_accpbar.set_postfix(**{'loss': total_loss.item() / (step),"acc":train_acc_sum.item()/(step),'lr': get_lr(optimizer)})pbar.update(1)T_acc.append(train_acc_sum.item()/(step))T_loss.append(total_loss.item() / (step))# 验证net.eval()val_acc_sum = 0val_loss_sum=0for val_step, (features, labels) in enumerate(val_loader, 1):with torch.no_grad():features = features.to(device)labels = labels.to(device)predictions = net(features)val_metric = metric_func(predictions, labels)loss=loss_fun(predictions,labels)val_acc_sum += val_metric.item()val_loss_sum+=loss.item()print('val_acc=%.4f' % (val_acc_sum / val_step))V_acc.append(round(val_acc_sum / val_step,2))V_loss.append(val_loss_sum/val_step)# 保存模型if (epoch) % 2 == 0:torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f_.pth' % (epoch, total_loss / (epoch_step + 1)))lr_scheduler.step()plt.figure()plt.plot(T_acc,'r')plt.plot(V_acc,'b')plt.title('Training and validation Acc')plt.xlabel("Epochs")plt.ylabel("Acc")plt.legend(["Train_acc", "Val_acc"])# plt.show()plt.savefig("ACC.png")plt.figure()plt.plot(T_loss, 'r')plt.plot(V_loss, 'b')plt.title('Training and validation loss')plt.xlabel("Epochs")plt.ylabel("loss")plt.legend(["Train_loss", "Val_loss"])plt.savefig("LOSS.png")plt.show()

训练准确率及损失函数:

准确率:

在这里插入图片描述
损失函数:
在这里插入图片描述

resnet18交通标志分类源码

(包含训练,预测代码,准确率,损失函结果图像,数据集等):
下载地址:

yolov5检测与识别(交通标志)

前面是使用resnet18网络对交通标志分类,只是单单的分类,无法从一张完整的全局图像中检测交通标志位置。对此,首先使用yolov5从全局图像中检测交通标志的位置,只是检测没有分类,然后再使用前面训练好的resnet18模型对交通标志分类。其效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/15103.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

回溯算法06(总结+leetcode332,51,37)

参考资料: https://programmercarl.com/%E5%9B%9E%E6%BA%AF%E6%80%BB%E7%BB%93.html 力扣这三题暂时不在本篇笔记中贴代码了,有兴趣的可参考332.重新安排形成、N皇后、解数独 总结: 画树形图分析题目 用途:回溯算法是用 递归实现…

C++学习笔记(21)——继承

目录 1. 继承的概念及定义1.1 继承的概念1.2 继承定义1.2.1 定义格式1.2.2 继承关系和访问限定符1.2.3 继承基类成员访问方式的变化 继承的概念总结: 2. 基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数知识点:派生类中6个默认成员函数…

win11 wsl ubuntu24.04

win11 wsl ubuntu24.04 一:开启Hyper-V二:安装wsl三:安装ubuntu24.04三:桥接模式,固定IP四:U盘使用五:wsl 从c盘迁移到其它盘参考资料 一:开启Hyper-V win11家庭版开启hyper-v 桌面…

Pytorch-01 框架简介

智能框架概述 人工智能框架是一种软件工具,用于帮助开发人员构建和训练人工智能模型。这些框架提供了各种功能,如定义神经网络结构、优化算法、自动求导等,使得开发人员可以更轻松地实现各种人工智能任务。通过使用人工智能框架,…

LangChain - Tool Calling 工具调用

文章目录 介绍组件1、ChatModel.bind_tools(...)2、AIMessage.tool_calls3、create_tool_calling_agent() 三、LangGraphwith_structured_output 四、结论 本文翻译整理自:Tool Calling with LangChain https://blog.langchain.dev/tool-calling-with-langchain/ TL…

汽车液态锂电池过充时,有怎样的表现,或者对电池有怎样的危害?

标签: 汽车液态锂电池过充的表现与危害; 电池过充; 汽车液态锂电池过充的表现与危害 液态锂电池在过充状态下,会出现一系列不良表现,并且对电池本身以及使用安全造成严重危害。以下是详细的分析: 1. 过充的表现 电压升高:在过充过程中,电池电压会超过其设计的最大电…

【MySQL精通之路】MySQL-环境变量

本节列出了MySQL直接或间接使用的环境变量。 其中大部分也可以在本手册的其他地方找到。 命令行上的选项优先于选项文件和环境变量中指定的值,选项文件中的值优先于环境变量中的值。 在许多情况下,最好使用配置文件而不是环境变量来修改MySQL的行为。…

虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本。 模块“Upgrade”启动失败。 未能启动虚拟机。

问题: 虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本。 模块“Upgrade”启动失败。 未能启动虚拟机。 分析: 该虚拟机环境之前使用的VMware版本与你所使用的VMware版本不一致。大概率你使用的是刚从别人电脑里拷过来的虚拟机环境。 解决&…

开发需要知道的敏捷开发理念

敏捷宣言和原则 敏捷软件开发宣言 敏捷软件开发宣言(Agile Manifesto)是敏捷开发方法的核心指导原则,由17位软件开发专家在2001年共同起草。该宣言强调了在软件开发过程中对某些价值观的优先级: 个体和互动高于流程和工具&#…

游戏后台开发技术全面解析

在这个数字时代,游戏产业已经成为全球最受欢迎的娱乐方式之一。从简单的手机游戏到复杂的大型多人在线角色扮演游戏(MMORPG),游戏的世界正变得越来越丰富和多样化。而这一切的背后,都离不开强大的游戏后台技术支持。在…

项目日记(3) boost搜索引擎

目录 1. 准备工作 2. 搜索初始化 3. 搜索部分 4. 对content部分处理 5. 服务器编写 前言: 上次在项目日记(2)写了index索引, 这次就可以进行search搜索了. 不多说快看. 先点个一键三联. 蟹蟹!!! 1. 准备工作 后面需要倒排索引的结构体, 先准备好. words是后面一个文档里面…

三丰云云服务器测评

三丰云是一家知名的云计算服务提供商,提供免费虚拟主机和免费云服务器等多种云计算服务。这些服务深受广大用户的喜爱,因为它们可以帮助用户轻松地搭建网站、应用程序等,同时无需购买昂贵的硬件设备。 对于初学者来说,使用三丰云…

Java重写

方法重写的意义 在java中,子类可以继承父类中的方法,而不需要重新编写相同的方法,但是有时子类并不想原封不动的继承父类方法,需要做一定的修改,这时候就需要使用方法重写 方法重写的定义 在继承的前提下 子类可以根据…

UI面试手册

UI面试手册 薪资:6~9k 高级:8~15k 岗位职责: 负责公司品牌形象及产品相关海报、宣传画等创意设计工作;负责公司日常宣传、营销广告、策划设计制作、产品设计和包装设计等工作;配合国内外广告投放物料设计,按进度要求…

Python使用连接池操作MySQL

测试环境说明:Python版本是 3.8.10 ,DBUtils版本是3.1.0 ,pymysql版本是1.0.3 首先安装指定版本的连接池库DBUtils 、还有pymysql pip install DBUtils3.1.0 pip install pymysql1.0.3创建文件 sqlConfig.py # sqlConfig.pyimport pymysql…

Math类

类 Math 包含执行基本数值运算的方法,例如基本指数、对数、平方根和三角函数。下面是我写代码时用到的一些字段和方法,归纳如下。 字段 修饰符和类型 Field描述static final double Edouble 值比任何其他值都更接近e, 自然对数的底…

YOLOv10论文解读:实时端到端的目标检测模型

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

618购物节快递量激增,EasyCVR视频智能分析助力快递网点智能升级

随着网络618购物节的到来,物流仓储与快递行业也迎来业务量暴增的情况。驿站网点和快递门店作为物流体系的重要组成部分,其安全性和运营效率日益受到关注。为了提升这些场所的安全防范能力和服务水平,实施视频智能监控方案显得尤为重要。 一、…

蓝桥杯嵌入式国赛笔记(2):拓展板按键程序设计

目录 1、前言 2、电路原理 3、代码编写 3.1 读取Btn电压 3.2 检索按键 3.3 main文件编写 3.3.1 进行变量定义 3.3.2 AD_Key函数 3.3.3 LCD函数 3.3.4 main函数 3.3.5 完整代码 4、测试 5、总结 1、前言 本文进行拓展板按键程序设计,拓展板的按键是通…

人生苦短,我学python之数据类型(下)

个人主页:星纭-CSDN博客 系列文章专栏:Python 踏上取经路,比抵达灵山更重要!一起努力一起进步! 目录 一.集合 1.1子集与超集 1.2交集,并集,补集,差集 1.intersection(英文&a…