基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

文章目录

  • 数据集介绍:
  • resnet18模型代码
  • 加载数据集(Dataset与Dataloader)
  • 模型训练
  • 训练准确率及损失函数:
  • resnet18交通标志分类源码
  • yolov5检测与识别(交通标志)

本文共包含两部分,
第一部分是用resnet18对交通标志分类,仅仅只是交通标志分类
文末附有yolov5和resnet18结合的源码,yolov5复制检测交通标志位置,然后使用resnet18对交通标志进行分类。

数据集介绍:

本文使用的数据集共有6000多张,共包含58个类别。部分数据集如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

resnet18模型代码

使用pytorch自带的resnet18模型,代码如下:

from torchvision import models
import torch.nn as nn#加载resnet18模型
net=models.resnet18(weights=None)
#因为分类个数为58,所以需要修改模型最后一层全连接层
net.fc=nn.Linear(in_features=512, out_features=58, bias=True)
# print(net)

加载数据集(Dataset与Dataloader)

from torch.utils.data import Dataset,DataLoader
import numpy as np
import cv2
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from PIL import Image
import os
from torchvision import transforms
import torch
import randoma=[]
class Mydata(Dataset):def __init__(self,lines,train=True):super(Mydata, self).__init__()self.lines=linesrandom.shuffle(self.lines)self.train=traindef __len__(self):return len(self.lines)def __getitem__(self, index):txts=self.lines[index].strip().split(';')src_path='pic/'+txts[0]w=int(txts[1])h=int(txts[2])x1=int(txts[3])y1=int(txts[4])x2=int(txts[5])y2=int(txts[6])new_x1=random.randint(0,x1)new_y1=random.randint(0,y1)new_x2=random.randint(x2,w-1)new_y2=random.randint(y2,h-1)lab=int(txts[7])# if lab in a:#     pass# else:a.append(lab)## a.sort()# print(len(a))# print(a)img = Image.open(src_path)img=np.array(img)[...,:3]img=img[new_y1:new_y2,new_x1:new_x2]#数据增强if self.train:img=self.get_random_data(img)else:img = cv2.resize(img, (128, 128))# cv2.imshow('img',img[...,::-1])# cv2.waitKey(0)#归一化img=(img/255.0).astype('float32')img=np.transpose(img,(2,0,1))img=torch.from_numpy(img)return img,labdef get_random_data(self,img):seq = iaa.Sequential([# iaa.Flipud(0.5),  # flip up and down (vertical)# iaa.Fliplr(0.5),  # flip left and right (horizontal)iaa.Multiply((0.8, 1.2)),  # change brightness, doesn't affect BBs(bounding boxes)iaa.GaussianBlur(sigma=(0, 1.0)),  # 标准差为0到3之间的值iaa.Crop(percent=(0, 0.2)),iaa.Affine(translate_px={"x": (0,15), "y": (0,15)},  # 平移scale=(0.8, 1.2),  # 尺度变换rotate=(-20, 20),mode='constant',cval=(125)),iaa.Resize(128)])img= seq(image=img)return img
if __name__ == '__main__':lines=open('data.txt','r').readlines()my=Mydata(lines=lines,train=True)myloader=DataLoader(dataset=my,batch_size=3,shuffle=False)for i,j in myloader:print(i.shape,j.shape)

模型训练

经过60个epoch训练后,模型准确率基本上达到百分百

from mymodel import net
from myDataset import Mydata
import random
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
from tqdm import tqdm
import matplotlib.pylab as pltbatch_size=32
Epoch=60
lr=0.001lines=open('data.txt','r').readlines()
random.shuffle(lines)
val_lines=random.sample(lines,int(len(lines)*0.1))
train_lines=list(set(lines)-set(val_lines))train_data=Mydata(lines=train_lines)
val_data=Mydata(lines=val_lines,train=False)
train_loader=DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
val_loader=DataLoader(dataset=val_data,batch_size=batch_size,shuffle=False)num_train   = len(train_lines)
epoch_step  = num_train // batch_size
BCE_loss     = nn.CrossEntropyLoss()
optimizer  = optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
lr_scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
#获取学习率函数
def get_lr(optimizer):for param_group in optimizer.param_groups:return param_group['lr']
#计算准确率函数
def metric_func(pred,lab):_,index=torch.max(pred,dim=-1)acc=torch.where(index==lab,1.,0.).mean()return acc
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net=net.to(device)
#设置损失函数
loss_fun     = nn.CrossEntropyLoss()if __name__ == '__main__':T_acc=[]V_acc=[]T_loss=[]V_loss=[]# 设置迭代次数200次epoch_step = num_train // batch_sizefor epoch in range(1, Epoch + 1):net.train()total_loss = 0loss_sum = 0.0train_acc_sum=0.0with tqdm(total=epoch_step, desc=f'Epoch {epoch}/{Epoch}', postfix=dict, mininterval=0.3) as pbar:for step, (features, labels) in enumerate(train_loader, 1):features = features.to(device)labels = labels.to(device)batch_size = labels.size()[0]optimizer.zero_grad()predictions = net(features)loss = loss_fun(predictions, labels)loss.backward()optimizer.step()total_loss += losstrain_acc = metric_func(predictions, labels)train_acc_sum+=train_accpbar.set_postfix(**{'loss': total_loss.item() / (step),"acc":train_acc_sum.item()/(step),'lr': get_lr(optimizer)})pbar.update(1)T_acc.append(train_acc_sum.item()/(step))T_loss.append(total_loss.item() / (step))# 验证net.eval()val_acc_sum = 0val_loss_sum=0for val_step, (features, labels) in enumerate(val_loader, 1):with torch.no_grad():features = features.to(device)labels = labels.to(device)predictions = net(features)val_metric = metric_func(predictions, labels)loss=loss_fun(predictions,labels)val_acc_sum += val_metric.item()val_loss_sum+=loss.item()print('val_acc=%.4f' % (val_acc_sum / val_step))V_acc.append(round(val_acc_sum / val_step,2))V_loss.append(val_loss_sum/val_step)# 保存模型if (epoch) % 2 == 0:torch.save(net.state_dict(), 'logs/Epoch%d-Loss%.4f_.pth' % (epoch, total_loss / (epoch_step + 1)))lr_scheduler.step()plt.figure()plt.plot(T_acc,'r')plt.plot(V_acc,'b')plt.title('Training and validation Acc')plt.xlabel("Epochs")plt.ylabel("Acc")plt.legend(["Train_acc", "Val_acc"])# plt.show()plt.savefig("ACC.png")plt.figure()plt.plot(T_loss, 'r')plt.plot(V_loss, 'b')plt.title('Training and validation loss')plt.xlabel("Epochs")plt.ylabel("loss")plt.legend(["Train_loss", "Val_loss"])plt.savefig("LOSS.png")plt.show()

训练准确率及损失函数:

准确率:

在这里插入图片描述
损失函数:
在这里插入图片描述

resnet18交通标志分类源码

(包含训练,预测代码,准确率,损失函结果图像,数据集等):
下载地址:

yolov5检测与识别(交通标志)

前面是使用resnet18网络对交通标志分类,只是单单的分类,无法从一张完整的全局图像中检测交通标志位置。对此,首先使用yolov5从全局图像中检测交通标志的位置,只是检测没有分类,然后再使用前面训练好的resnet18模型对交通标志分类。其效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/15103.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++学习笔记(21)——继承

目录 1. 继承的概念及定义1.1 继承的概念1.2 继承定义1.2.1 定义格式1.2.2 继承关系和访问限定符1.2.3 继承基类成员访问方式的变化 继承的概念总结: 2. 基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数知识点:派生类中6个默认成员函数…

win11 wsl ubuntu24.04

win11 wsl ubuntu24.04 一:开启Hyper-V二:安装wsl三:安装ubuntu24.04三:桥接模式,固定IP四:U盘使用五:wsl 从c盘迁移到其它盘参考资料 一:开启Hyper-V win11家庭版开启hyper-v 桌面…

Pytorch-01 框架简介

智能框架概述 人工智能框架是一种软件工具,用于帮助开发人员构建和训练人工智能模型。这些框架提供了各种功能,如定义神经网络结构、优化算法、自动求导等,使得开发人员可以更轻松地实现各种人工智能任务。通过使用人工智能框架,…

虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本。 模块“Upgrade”启动失败。 未能启动虚拟机。

问题: 虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本。 模块“Upgrade”启动失败。 未能启动虚拟机。 分析: 该虚拟机环境之前使用的VMware版本与你所使用的VMware版本不一致。大概率你使用的是刚从别人电脑里拷过来的虚拟机环境。 解决&…

游戏后台开发技术全面解析

在这个数字时代,游戏产业已经成为全球最受欢迎的娱乐方式之一。从简单的手机游戏到复杂的大型多人在线角色扮演游戏(MMORPG),游戏的世界正变得越来越丰富和多样化。而这一切的背后,都离不开强大的游戏后台技术支持。在…

Java重写

方法重写的意义 在java中,子类可以继承父类中的方法,而不需要重新编写相同的方法,但是有时子类并不想原封不动的继承父类方法,需要做一定的修改,这时候就需要使用方法重写 方法重写的定义 在继承的前提下 子类可以根据…

Python使用连接池操作MySQL

测试环境说明:Python版本是 3.8.10 ,DBUtils版本是3.1.0 ,pymysql版本是1.0.3 首先安装指定版本的连接池库DBUtils 、还有pymysql pip install DBUtils3.1.0 pip install pymysql1.0.3创建文件 sqlConfig.py # sqlConfig.pyimport pymysql…

YOLOv10论文解读:实时端到端的目标检测模型

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

618购物节快递量激增,EasyCVR视频智能分析助力快递网点智能升级

随着网络618购物节的到来,物流仓储与快递行业也迎来业务量暴增的情况。驿站网点和快递门店作为物流体系的重要组成部分,其安全性和运营效率日益受到关注。为了提升这些场所的安全防范能力和服务水平,实施视频智能监控方案显得尤为重要。 一、…

蓝桥杯嵌入式国赛笔记(2):拓展板按键程序设计

目录 1、前言 2、电路原理 3、代码编写 3.1 读取Btn电压 3.2 检索按键 3.3 main文件编写 3.3.1 进行变量定义 3.3.2 AD_Key函数 3.3.3 LCD函数 3.3.4 main函数 3.3.5 完整代码 4、测试 5、总结 1、前言 本文进行拓展板按键程序设计,拓展板的按键是通…

人生苦短,我学python之数据类型(下)

个人主页:星纭-CSDN博客 系列文章专栏:Python 踏上取经路,比抵达灵山更重要!一起努力一起进步! 目录 一.集合 1.1子集与超集 1.2交集,并集,补集,差集 1.intersection(英文&a…

webman使用summernote富文本编辑器

前言 Summernote富文本编辑器功能强大,可以直接从word直接复制内容过来而不破坏原有的文档格式,非常适合做商品详情等内容的编辑工具。本文将展示如何在php高性能框架webman中使用summernote编辑器。 下载 去Bootstrap 中文网、Summernote、jQuery官网…

【设计模式】JAVA Design Patterns——Converter(转换器模式)

🔍目的 转换器模式的目的是提供相应类型之间双向转换的通用方法,允许进行干净的实现,而类型之间无需相互了解。此外,Converter模式引入了双向集合映射,从而将样板代码减少到最少 🔍解释 真实世界例子 在真实…

低代码开发:拖拽式可视化构建工业物联网系统

什么是低代码? 低代码(Low Code)是一种可视化的软件开发方法,通过最少的手动编码可以更快地交付应用程序。低代码平台的图形用户界面和拖放功能可自动执行开发过程的各个方面,从而消除对传统计算机编程方法的依赖。 什么是低代码平台&#…

Pandas 创建层次化索引

1.创建多层次索引 1.1 隐式构造 最常见的方法是给DataFrame构造函数的index参数传递两个或更多的数组 # 导入pandasimport numpy as npimport pandas as pd​data np.random.randint(0,100,size(6,6))​# 行索引index [ ["1班","1班","1班&qu…

【全网最全】2024电工杯数学建模B题53页成品论文+完整matlab代码+完整python代码+数据预处理+可视化结果等(后续会更新)

您的点赞收藏是我继续更新的最大动力! 一定要点击如下的卡片链接,那是获取资料的入口! 【全网最全】2024电工杯数学建模B题53页成品论文完整matlab、py代码19建模过程代码数据等(后续会更新)「首先来看看目前已有的资…

微软新功能Recall引发隐私担忧,英国数据监管机构展开调查

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

[Spring Cloud] (9)XSS拦截器

文章目录 简述本文涉及代码已开源Fir Cloud 完整项目防XSS攻击必要性:作用: 整体效果后端增加拦截器开关配置pom中增加jsoup依赖添加JSON处理工具类添加xss拦截工具类防XSS-请求拦截器 前端 简述 本文涉及代码已开源 本文网关gateway,微服务…

Visual Studio Code插件

文章目录 工具类AIChinese (Simplified) (简体中文)cmake集Code RunnerGitLens — Git superchargedPath IntellisenseTodo TreeBookmarks (书签)markdownclangd 美化类Output Colorizer (输出窗口彩色)Doxygen Documentation Gen…

安装harbor出现问题: Running 1/1 ✘ Network harbor_harbor Error

安装harbor出现问题: [] Running 1/1 ✘ Network harbor_harbor Error 0.2s failed to create network harbor_harbor: Error response from daemon: Fa…