Pytorch入门实战 P10-使用pytorch实现车牌识别

目录

前言

一、MyDataset文件

二、完整代码:

三、结果展示:

四、添加accuracy值


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

本周的学习内容是,使用pytorch实现车牌识别。

前言

        之前的案例里面,我们大多是使用的是datasets.ImageFolder函数,直接导入已经分类好的数据集形成Dataset,然后使用DataLoader加载Dataset,但是如果对无法分类的数据集,我们应该如何导入呢。

        这篇文章主要就是介绍通过自定义的一个MyDataset加载车牌数据集并完成车牌识别。

一、MyDataset文件

数据文件是这样的,没有进行分类的。

# 加载数据文件
class MyDataset(data.Dataset):def __init__(self, all_labels, data_paths_str, transform):self.img_labels = all_labels  # 获取标签信息self.img_dir = data_paths_str  # 图像目录路径self.transform = transform   # 目标转换函数def __len__(self):return len(self.img_labels)   # 返回数据集的长度,即标签的数量def __getitem__(self, index):image = Image.open(self.img_dir[index]).convert('RGB')   # 打开指定索引的图像文件,并将其转换为RGB模式label = self.img_labels[index]  # 获取图像对应的标签if self.transform:image = self.transform(image)   # 如果设置了转换函数,则对图像进行转换(如,裁剪、缩放、归一化等)return image, label  # 返回图像和标签

二、完整代码:

import pathlibimport matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils import data
from torchvision import transforms
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib as mpl
mpl.use('Agg')  # 在服务器上运行的时候,打开注释device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)data_dir = './data'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('/')[1].split('_')[1].split('.')[0] for path in data_paths]
# print(classNames)  # '沪G1CE81', '云G86LR6', '鄂U71R9F', '津G467JR'....data_paths_str = [str(path) for path in data_paths]# 数据可视化
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(14,5))
plt.suptitle('data show', fontsize=15)
for i in range(18):plt.subplot(3, 6, i+1)# 显示图片images = plt.imread(data_paths_str[i])plt.imshow(images)
plt.show()# 3、标签数字化
char_enum = ["京","沪","津","渝","冀","晋","蒙","辽","吉","黑","苏","浙","皖","闽","赣","鲁","豫","鄂","湘","粤","桂","琼","川","贵","云","藏","陕","甘","青","宁","新","军","使"]number = [str(i) for i in range(0, 10)]  # 0-9 的数字
alphabet = [chr(i) for i in range(65, 91)]  # A到Z的字母
char_set = char_enum + number + alphabet
char_set_len = len(char_set)
label_name_len = len(classNames[0])# 将字符串数字化
def text2vec(text):vector = np.zeros([label_name_len, char_set_len])for i, c in enumerate(text):idx = char_set.index(c)vector[i][idx] = 1.0return vectorall_labels = [text2vec(i) for i in classNames]# 加载数据文件
class MyDataset(data.Dataset):def __init__(self, all_labels, data_paths_str, transform):self.img_labels = all_labels  # 获取标签信息self.img_dir = data_paths_str  # 图像目录路径self.transform = transform   # 目标转换函数def __len__(self):return len(self.img_labels)def __getitem__(self, index):image = Image.open(self.img_dir[index]).convert('RGB')label = self.img_labels[index]  # 获取图像对应的标签if self.transform:image = self.transform(image)return image, label  # 返回图像和标签total_datadir = './data/'
train_transforms = transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])total_data = MyDataset(all_labels, data_paths_str, train_transforms)
# 划分数据
train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size, test_size])
print(train_size, test_size)  # 10940 2735# 数据加载
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=True)for X, y in test_loader:print('Shape of X [N,C,H,W]:',X.shape)   # ([16, 3, 224, 224])print('Shape of y:', y.shape, y.dtype)    # torch.Size([16, 7, 69]) torch.float64break# 搭建网络模型
class Network_bn(nn.Module):def __init__(self):super(Network_bn, self).__init__()"""nn.Conv2d()函数:第一个参数(in_channels)是输入的channel数量第二个参数(out_channels)是输出的channel数量第三个参数(kernel_size)是卷积核大小第四个参数(stride)是步长,默认为1第五个参数(padding)是填充大小,默认为0"""self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(12)self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)self.bn2 = nn.BatchNorm2d(12)self.pool = nn.MaxPool2d(2, 2)self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn4 = nn.BatchNorm2d(24)self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)self.bn5 = nn.BatchNorm2d(24)self.fc1 = nn.Linear(24 * 50 * 50, label_name_len * char_set_len)self.reshape = Reshape([label_name_len, char_set_len])def forward(self, x):x = F.relu(self.bn1(self.conv1(x)))x = F.relu(self.bn2(self.conv2(x)))x = self.pool(x)x = F.relu(self.bn4(self.conv4(x)))x = F.relu(self.bn5(self.conv5(x)))x = self.pool(x)x = x.view(-1, 24 * 50 * 50)x = self.fc1(x)# 最终reshapex = self.reshape(x)return xclass Reshape(nn.Module):def __init__(self,shape):super(Reshape, self).__init__()self.shape = shapedef forward(self, x):return x.view(x.size(0), *self.shape)model = Network_bn().to(device)
print(model)# 优化器与损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0.0001)
loss_model = nn.CrossEntropyLoss()def test(model, test_loader, loss_model):size = len(test_loader.dataset)num_batches = len(test_loader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in test_loader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_model(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchesprint(f'Avg loss: {test_loss:>8f}\n')return correct, test_lossdef train(model,train_loader, loss_model, optimizer):model = model.to(device)model.train()for i, (images, labels) in enumerate(train_loader, 0):   # 0 是标起始位置的值images = Variable(images.to(device))labels = Variable(labels.to(device))optimizer.zero_grad()outputs = model(images)loss = loss_model(outputs, labels)loss.backward()optimizer.step()if i % 100 == 0:print('[%5d] loss: %.3f' % (i, loss))# 模型的训练
test_acc_list = []
test_loss_list = []
epochs = 30
for t in range(epochs):print(f"Epoch {t+1}\n-----------------------")train(model,train_loader, loss_model,optimizer)test_acc,test_loss = test(model, test_loader, loss_model)test_acc_list.append(test_acc)test_loss_list.append(test_loss)print('Done!!!')# 结果分析
x = [i for i in range(1,31)]
plt.plot(x, test_loss_list, label="Loss", alpha = 0.8)
plt.xlabel('Epoch')
plt.ylabel('Loss')plt.legend()
plt.show()
plt.savefig("/data/jupyter/deep_demo/p10_car_number/resultImg.jpg")  # 保存图片在服务器的位置
plt.show()

三、结果展示:

总结:从刚开始损失为0.077 到,训练30轮后,损失到了0.026。

四、添加accuracy值

需求:对在上面的代码中,对loss进行了统计更新,请补充acc统计更新部分,即获取每一次测试的ACC值。

添加accuracy的运行过程:

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/13168.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国网698.45报文解析工具

本文分享一个698.45协议的报文解析工具,此报文解析工具功能强大,可以解析多种国网数据协议。 下载链接: https://pan.baidu.com/s/1ngbBG-yL8ucRWLDflqzEnQ 提取码: y1de 主要界面如下: 本工具内置698.45数据协议, 即可调用word…

win编写bat脚本启动java服务

新建txt,编写,前台启动,出现cmd黑窗口 echo off start java -jar zhoao1.jar start java -jar zhoao2.jar pause完成后,重命名.bat 1、后台启动,不出现cmd黑窗口,app是窗口名称 echo off start "名…

美团小程序mtgsig1.2逆向

声明 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!wx a15018601872 本文章未…

VMware虚拟机没有网,无法设置网络为桥接状态

今天需要使用Ubuntu18但现有虚拟机是Ubuntu20,由于硬盘空间不够大,所以删除了原来的虚拟机并重新搭建Ubuntu18的环境,然后发现虚拟机没有网络,而我之前的虚拟机这一切都是正常的。 在网络设置里勾选的是桥接模式但无法联网&#x…

Cow Exhibition G的来龙去脉

[USACO03FALL] Cow Exhibition G - 洛谷 曲折经过 爆搜 一开始没什么好的想法&#xff0c;就针对每头奶牛去or不去进行了爆搜。 #include <cstdio> #include <algorithm> using namespace std;#define maxn 405 int iq[maxn], eq[maxn]; int ans; int n;void d…

留学资讯 | 2024英国学生签证申请需要满足哪些条件?

英国移民局于2020年9月10日发布了《移民规则变更声明: HC 707》&#xff0c;对学生签证制度进行了全面改革。该法案于2020年10月5日正式生效。根据此法案&#xff0c;新的学生签证——The Student and Child Student Routes学生和儿童学生路线&#xff0c;将替代原先的Tier 4学…

短视频赛道有哪些:成都鼎茂宏升文化传媒公司

短视频赛道有哪些&#xff1a;探索多元化的内容领域 随着科技的飞速发展和人们生活节奏的加快&#xff0c;短视频已成为现代人生活中不可或缺的一部分。它以其简短、直观、易于分享的特点&#xff0c;迅速占领了各个年龄层和社会群体的心智。然而&#xff0c;短视频的赛道并非…

层次式体系结构概述

1.软件体系结构 软件体系结构可定义为&#xff1a;软件体系结构为软件系统提供了结构、行为和属性的高级抽象&#xff0c;由构成系统的元素描述、这些元素的相互作用、指导元素集成的模式以及这些模式的约束组成。软件体系结构不仅指定了系统的组织结构和拓扑结构&#xff0c;并…

小程序框架是智能融媒体平台构建的最佳线路

过去5年&#xff0c;媒体行业一直都在进行着信息化建设向融媒体平台建设的转变。一些融媒体的建设演变总结如下&#xff1a; 新闻终端的端侧内容矩阵建设&#xff0c;如App新闻端&#xff0c;社交平台上的官方媒体等新闻本地生活双旗舰客户端&#xff0c;兼顾主流媒体核心宣传…

TopOn 正式聚合Kwai 旗下程序化广告平台——Kwai Network

**我们非常高兴的宣布&#xff0c;TopOn SDK 近日已正式聚合Kwai Network。**作为Kwai 旗下的程序化广告平台&#xff0c;Kwai Network 通过优质的变现能力及产品能力&#xff0c;为广大开发者提供高效及时的服务。 TopOn 聚合平台与Kwai Network 正式完成接入后&#xff0c;开…

实战+代码!Selenium + Phantom JS爬取天天基金数据

功能&#xff1a; 通过程序实现从基金列表页&#xff0c;获取指定页数内所有基金的近一周收益率以及每支基金的详情页链接。再进入每支基金的详情页获取其余的基金信息&#xff0c;将所有获取到的基金详细信息按近6月收益率倒序排列写入一个Excel表格。 思路&#xff1a; 1.…

vm 虚拟机 Debian12 开启 root、ssh 登录功能

前言&#xff0c;安装的时候语言就选中文就好了。选择中文&#xff0c;在安装的时候就可以选择国内 163 的源。 开启 ssh 功能 先提权&#xff0c;用 root 账户 su安装 ssh 安装 ssh-server apt install openssh-server启动 ssh systemctl start ssh查看 ssh 状态 systemctl st…

u3d的ab文件注意事项

//----------------LoadAllAB.cs--------------------- using System.Collections;using UnityEngine;namespace System.IO{public class LoadAllAB : MonoBehaviour{ //读取本地string path "Assets/Actors/lznh/ab/animation/t_bl/";// Use this for initiali…

Keil手动安装编译器V5版本

V5编译器下载&#xff1a;免积分下载 新版的keil不会自动帮你安装V5版本的编译器&#xff0c;但是很多教程很多比赛所用单片机都是V5的编译器&#xff0c;所以用来开以前的或者开源的很多东西编译直接一大堆报错。 吐槽说完了接下来教你怎么解决 打开installer&#xff08;在…

搞大事!法国邀请芬兰公司建量子工厂

法国当地时间5月13日&#xff0c;法国总统马克龙宣布启动2024年度“选择法国”&#xff08;Choose France&#xff09;商业峰会。今年峰会召开前&#xff0c;法国赢得了创纪录的150亿欧元外国投资承诺&#xff0c;覆盖从人工智能到制药和能源等领域。 而涉及到量子领域最重磅的…

el-select下拉框 添加 el-checkbox 多选框,支持全选、取消全选

el-select下拉框 添加 el-checkbox 多选框&#xff0c;支持全选、取消全选 前言一、实现思路二、实现代码1.模板代码2. css 样式3.js 代码 DEMO 演示总结 前言 实现效果预览 提示&#xff1a;本内容基于element-ui 组件实现&#xff0c;如果使用其他组件库、可参考下面实现方…

关于GitHub仓库建立及提交问题

文章目录 前言GitHub仓库创建token令牌的获取GitHub克隆到本地GitHub上传文件 前言 为了整一个GitHub仓库然后上传文件&#xff0c;笔者看了不下100篇博客&#xff0c;20段教程&#xff0c;最后在两位大佬的帮助下&#xff0c;才整明白了&#x1f62d; 先提前说一嘴从 2021年8月…

网络爬虫安全:90后小伙,用软件非法搬运他人原创视频被判刑

目录 违法视频搬运软件是网络爬虫 如何发现偷盗视频的爬虫&#xff1f; 拦截违法网络爬虫 央视《今日说法》栏目近日报道了一名程序员开发非法视频搬运软件获利超700多万&#xff0c;最终获刑的案例。 国内某知名短视频平台报警称&#xff0c;有人在网络上售卖一款视频搬运…

刘邦的创业团队是沛县人,朱元璋的则是凤阳;要创业,一个县人才就够了

当人们回顾刘邦和朱元璋的创业经历时&#xff0c;总是会感慨他们起于微末&#xff0c;都创下了偌大王朝&#xff0c;成就无上荣誉。 尤其是我们查阅史书时&#xff0c;发现这二人的崛起班底都是各自的家乡人&#xff0c;例如刘邦的班底就是沛县人&#xff0c;朱元璋的班底是凤…

分布式搜索-elaticsearch基础 概念

什么是elaticsearch: 倒排索引&#xff1a;就是将要查询的内容分成一个个词条&#xff0c;在将词条文档id存入&#xff0c;词条是唯一的。 文档词条总结: mysql和Elasticsearch概念对比: 架构: 基本概念总结: