pytorch+CRNN实现

最近接触了一个仪表盘识别的项目,简单调研以后发现可以用CRNN来做。但是手边缺少仪表盘数据集,就先用ICDAR2013试了一下。
在这里插入图片描述
结果遇到了一系列坑。为了不使读者和自己在以后的日子继续遭罪。我把正确的代码发到下面了。
1)超参数请不要调整!!!!CRNN前期训练极其离谱,需要良好的调参,loss才会慢慢下降。
在这里插入图片描述
我给出了一个训练曲线,可以看到确实贼几把怪,七拐八拐的。

2)千万不要用百度开源的那个ctc!!!

网络代码:

#crnn.py
import torch.nn as nn
import torch.nn.functional as Fclass BidirectionalLSTM(nn.Module):# Inputs hidden units Outdef __init__(self, nIn, nHidden, nOut):super(BidirectionalLSTM, self).__init__()self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)self.embedding = nn.Linear(nHidden * 2, nOut)def forward(self, input):recurrent, _ = self.rnn(input)T, b, h = recurrent.size()t_rec = recurrent.view(T * b, h)output = self.embedding(t_rec)  # [T * b, nOut]output = output.view(T, b, -1)return outputclass CRNN(nn.Module):#                   32    1   37     256def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):super(CRNN, self).__init__()assert imgH % 16 == 0, 'imgH has to be a multiple of 16'ks = [3, 3, 3, 3, 3, 3, 2]ps = [1, 1, 1, 1, 1, 1, 0]ss = [1, 1, 1, 1, 1, 1, 1]nm = [64, 128, 256, 256, 512, 512, 512]cnn = nn.Sequential()def convRelu(i, batchNormalization=False):nIn = nc if i == 0 else nm[i - 1]nOut = nm[i]cnn.add_module('conv{0}'.format(i),nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))if batchNormalization:cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))if leakyRelu:cnn.add_module('relu{0}'.format(i),nn.LeakyReLU(0.2, inplace=True))else:cnn.add_module('relu{0}'.format(i), nn.ReLU(True))convRelu(0)cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64convRelu(1)cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32convRelu(2, True)convRelu(3)cnn.add_module('pooling{0}'.format(2),nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16convRelu(4, True)convRelu(5)cnn.add_module('pooling{0}'.format(3),nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16convRelu(6, True)  # 512x1x16self.cnn = cnnself.rnn = nn.Sequential(BidirectionalLSTM(512, nh, nh),BidirectionalLSTM(nh, nh, nclass))def forward(self, input):# conv features#print('---forward propagation---')conv = self.cnn(input)b, c, h, w = conv.size()assert h == 1, "the height of conv must be 1"conv = conv.squeeze(2) # b *512 * widthconv = conv.permute(2, 0, 1)  # [w, b, c]output = F.log_softmax(self.rnn(conv), dim=2)return output

训练:

#train.py
import os
import torch
import cv2
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequenceimport crnn
import time
import re
import matplotlib.pyplot as plt
dic={" ":0,"a":1,"b":2,"c":3,"d":4,"e":5,"f":6,"g":7,"h":8,"i":9,"j":10,"k":11,"l":12,"m":13,"n":14,"o":15,"p":16,"q":17,"r":18,"s":19,"t":20,"u":21,"v":22,"w":23,"x":24,"y":25,"z":26,"A":27,"B":28,"C":29,"D":30,"E":31,"F":32,"G":33,"H":34,"I":35,"J":36,"K":37,"L":38,"M":39,"N":40,"O":41,"P":42,"Q":43,"R":44,"S":45,"T":46,"U":47,"V":48,"W":49,"X":50,"Y":51,"Z":52}STR=" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
n_class=53
label_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task1_GT"
image_sources=r"E:\machine_learning\instrument\icdar_2013\Challenge2_Test_Task12_Images"
use_gpu = True
learning_rate = 0.0001
max_epoch = 100
batch_size = 20
# 调整图像大小和归一化操作
class resizeAndNormalize():def __init__(self, size, interpolation=cv2.INTER_LINEAR):# 注意对于opencv,size的格式是(w,h)self.size = sizeself.interpolation = interpolation# ToTensor属于类  """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.self.toTensor = transforms.ToTensor()def __call__(self, image):# (x,y) 对于opencv来说,图像宽对应x轴,高对应y轴image = cv2.resize(image, self.size, interpolation=self.interpolation)# 转为tensor的数据结构image = self.toTensor(image)# 对图像进行归一化操作#image = image.sub_(0.5).div_(0.5)return imagedef load_data(label_folder,image_folder,label_suffix_name=".txt",image_suffix_name=".jpg"):image_file,label_file,num_file=[],[],[]for parent_folder, _, file_names in os.walk(label_folder):# 遍历当前子文件夹中的所有文件for file_name in file_names:# 只处理图片文件# if file_name.endswith(('jpg', 'jpeg', 'png', 'gif')):#提取jpg、jpeg等格式的文件到指定目录if file_name.endswith((label_suffix_name)):  # 提取json格式的文件到指定目录# 构造源文件路径和目标文件路径a,b=file_name.split("gt_")c,d=b.split(label_suffix_name)image_name=image_folder + "\\" + c + image_suffix_nameif os.path.exists(image_name):label_name = label_folder + "\\" + file_nametxt=open(label_name,'rb')txtl=txt.readlines()for line in range(len(txtl)):image_file.append(image_name)label_file.append(label_name)num_file.append(line)return image_file,label_file,num_filedef zl2lable(zl):label_list=[]for str in zl:label_list.append(dic[str])return label_listclass NewDataSet(Dataset):def __init__(self, label_source,image_source,train=True):super(NewDataSet, self).__init__()self.image_file,self.label_file,self.num_file= load_data(label_source,image_source)def __len__(self):return len(self.image_file)def __getitem__(self, index):txt = open(self.label_file[index], 'rb')img=cv2.imread(self.image_file[index],cv2.IMREAD_GRAYSCALE)wordL = txt.readlines()word=str(wordL[self.num_file[index]])pl = re.findall(r'\d+',word)zl = re.findall(r"[a-zA-Z]+", word)[1]  #1#img tensorx1, y1, x2, y2 = pl[:4]img= img[int(y1):int(y2),int(x1):int(x2), ](height, width)=img.shape# 由于crnn网络输入图像的高为32,故需要resize原始图像的heightsize_height = 32# ratio = 32 / float(height)size_width =100transform = resizeAndNormalize((size_width, size_height))# 图像预处理imageTensor = transform(img)#label tensorl = zl2lable(zl)labelTensor = torch.IntTensor(l)return imageTensor,labelTensorclass CRNNDataSet(Dataset):def __init__(self, imageRoot, labelRoot):self.image_root = imageRootself.image_dict = self.readfile(labelRoot)self.image_name = [fileName for fileName, _ in self.image_dict.items()]def __getitem__(self, index):image_path = os.path.join(self.image_root, self.image_name[index])keys = self.image_dict.get(self.image_name[index])label = [int(x) for x in keys]image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)# if image is None:#     return None,None(height, width) = image.shape# 由于crnn网络输入图像的高为32,故需要resize原始图像的heightsize_height = 32ratio = 32 / float(height)size_width = int(ratio * width)transform = resizeAndNormalize((size_width, size_height))# 图像预处理image = transform(image)# 标签格式转换为IntTensorlabel = torch.IntTensor(label)return image, labeldef __len__(self):return len(self.image_name)def readfile(self, fileName):res = []with open(fileName, 'r') as f:lines = f.readlines()for line in lines:res.append(line.strip())dic = {}total = 0for line in res:part = line.split(' ')# 由于会存在训练过程中取图像的时候图像不存在导致异常,所以在初始化的时候就判断图像是否存在if not os.path.exists(os.path.join(self.image_root, part[0])):print(os.path.join(self.image_root, part[0]))total += 1else:dic[part[0]] = part[1:]print(total)return dictrainData =NewDataSet(label_sources,image_sources)trainLoader = DataLoader(dataset=trainData, batch_size=1, shuffle=True, num_workers=0)# valData = CRNNDataSet(imageRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\images\\",
#                       labelRoot="D:\BaiduNetdiskDownload\Synthetic_Chinese_String_Dataset\lables\data_t.txt")
#
# valLoader = DataLoader(dataset=valData, batch_size=1, shuffle=True, num_workers=1)#
# def decode(preds):
#     pred = []
#     for i in range(len(preds)):
#         if preds[i] != 5989 and ((i == 5989) or (i != 5989 and preds[i] != preds[i - 1])):
#             pred.append(int(preds[i]))
#     return pred
#
#
def toSTR(l):str_l=[]if isinstance(l, int):l=[l]for i in range(len(l)):str_l.append(STR[l[i]])return str_l
def toRES(l):new_l=[]new_str=' 'for i in range(len(l)):if(l[i]==' '):new_str = ' 'continueelif new_str!=l[i]:new_l.append(l[i])new_str=l[i]return new_ldef val(model=torch.load("pytorch-crnn.pth")):# 将模式切换为验证评估模式loss_func = torch.nn.CTCLoss(blank=0, reduction='mean')model.eval()test_n=10for i, (data, label) in enumerate(trainLoader):if(i>test_n):break;output = model(data.cuda())pred_label=output.max(2)[1]input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))target_lengths = torch.IntTensor([label.size(1)] * int(label.size(0)))# forward(self, log_probs, targets, input_lengths, target_lengths)#log_probs = output.log_softmax(2).requires_grad_()targets = label.cuda()loss = loss_func(output.cpu(), targets.cpu(), input_lengths, target_lengths)pred_l=np.array(pred_label.cpu().squeeze()).tolist()label_l=np.array(targets.cpu().squeeze()).tolist()print(i,":",loss,"pred:",toRES(toSTR(pred_l)),"label_l",toSTR(label_l))def train():model = crnn.CRNN(32, 1, n_class, 256)if torch.cuda.is_available() and use_gpu:model.cuda()loss_func = torch.nn.CTCLoss(blank=0,reduction='mean')optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,betas=(0.9, 0.999))lossTotal = 0.0k = 0printInterval = 100start_time = time.time()loss_list=[]total_list=[]for epoch in range(max_epoch):n=0data_list = []label_list = []label_len=[]for i, (data, label) in enumerate(trainLoader):#data_list.append(data)label_list.append(label)label_len.append(label.size(1))n=n+1if n%batch_size!=0:continuek=k+1data=torch.cat(data_list, dim=0)data_list.clear()label = torch.cat(label_list, dim=1).squeeze(0)label_list.clear()target_lengths=torch.tensor(np.array(label_len))label_len.clear()# 开启训练模式model.train()if torch.cuda.is_available and use_gpu:data = data.cuda()loss_func = loss_func.cuda()label = label.cuda()output = model(data)log_probs = output# example 建议使用这样,貌似直接把output送进去loss fun也没发现什么问题#log_probs = output.log_softmax(2).requires_grad_()targets = label.cuda()input_lengths = torch.IntTensor([output.size(0)] * int(output.size(1)))# forward(self, log_probs, targets, input_lengths, target_lengths)#targets =torch.zeros(targets.shape)loss = loss_func(log_probs.cpu(), targets, input_lengths, target_lengths)/batch_sizelossTotal += float(loss)print("epoch:",epoch,"num:",i,"loss:",float(loss))loss_list.append(float(loss))if k % printInterval == 0:print("[%d/%d] [%d/%d] loss:%f" % (epoch, max_epoch, i + 1, len(trainLoader), lossTotal / printInterval))total_list.append( lossTotal / printInterval)lossTotal = 0.0torch.save(model, 'pytorch-crnn.pth')optimizer.zero_grad()loss.backward()optimizer.step()plt.figure()plt.plot(loss_list)plt.savefig("loss.jpg")plt.clf()plt.figure()plt.plot(total_list)plt.savefig("total.jpg")end_time = time.time()print("takes {}s".format((end_time - start_time)))return modelif __name__ == '__main__':train()

测试结果如下:
在这里插入图片描述
最后给一些参考文献:
https://www.cnblogs.com/azheng333/p/7449515.html
https://blog.csdn.net/wzw12315/article/details/106643182

另外给出数据集和我训练好的模型:
链接:https://pan.baidu.com/s/1-jTA22bLKv2ut_1EJ1WMKA?pwd=jvk8
提取码:jvk8

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/2668.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android oom_adj 详细解读

源码基于:Android R 0. 前言 在博文《oom_adj 内存水位算法剖析》一文中详细的分析了lmkd 中针对 oom_adj 内存水位的计算、使用方法,在博文《oom_adj 更新原理(1)》、《oom_adj 更新原理(2)》中对Android 系统中 oom_adj 的更新原理进行了详细的剖析。…

Centos 7 安装 Oracle 11G

Oracle 11G 安装教程 准备环境 p13390677_112040_Linux-x86-64_1of7.zipp13390677_112040_Linux-x86-64_2of7.zipCentos 7- rhel7-英文版的系统–不想换语言的执行(LANGen_US)– 传输 文件到服务器上 创建用户和组 [rootlocalhost ~]# groupadd oracle [rootlocalhost ~]…

Windows11 C盘瘦身

1.符号链接 将大文件夹移动到其他盘,创建成符号链接 2.修改Android Studio路径设置 1.SDK路径 2.Gradle路径 3.模拟器路径 设置环境变量 ANDROID_SDK_HOME

基于单片机的盲人导航智能拐杖老人防丢防摔倒发短息定位

功能介绍 以STM32单片机作为主控系统; OLED液晶当前实时距离,安全距离,当前经纬度信息;超声波检测小于设置的安全距离,蜂鸣器报警提示:低于安全距离!超声波检测当前障碍物距离,GPS进…

python发送邮件yagmail库

yagmail库发送邮件简洁,代码量少 import yagmaildef send_yagmail(sender, send_password, addressee, hostsmtp.qq.com, port465):yag yagmail.SMTP(sender, send_password, host, port)img_url https://img2.baidu.com/it/u483398814,2966849709&fm253&…

基于单片机的智能空调系统的设计与实现

功能介绍 以51单片机作为主控系统;LCD1602液晶显示当前水温,定时提醒,水量变化DS18B20检测当前水体温度;水位传感器检测当前水位;继电器驱动加热片进行水温加热;定时提醒喝水,蜂鸣器报警&#x…

LeetCode面试题02.07.链表相交

面试题02.07.链表相交 两种解题思路 面试题02.07.链表相交一、双指针二、哈希集合 一、双指针 这道题简单来说,就是求两个链表交点节点的指针 这里注意:交点不是数值相等,而是指针相等 为了方便举例,假设节点元素数值相等&…

用Python采用Modbus-Tcp的方式读取485电子水尺数据

README.TXT 2023/6/15 V1.0 实现了单个点位数据通信、数据解析、数据存储 2023/6/17 V2.0 实现了多个点位数据通信、数据解析、数据存储 2023/6/19 V2.1 完善log存储,仅保留近3天的log记录,避免不必要的存储;限制log大小,2MB。架…

数字原生时代,奥哲如何让企业都成为“原住民”?

22年前,美国教育学家马克‧普伦斯基(Marc Prensky)出版了《数字原生与数字移民》(Digital Natives, Digital Immigrants)一书,首次提出了“数字原住民”和“数字移民”两大概念,用来定义跨时代的…

【数据结构】_1.集合与复杂度

目录 1. 集合框架 2. 时间复杂度 2.1 时间复杂度和空间复杂度 2.2 时间复杂度的概念 2.3 大O的渐进表示法 2.3.1 精确的时间复杂度表达式 2.3.2 大O渐进表示法的三条规则 2.3.3 时间复杂度的最好、平均与最坏情况 2.4 时间复杂度计算示例 3.空间复杂度 1. 集合框架 …

字节跳动后端面试,笔试部分

var code "7022f444-ded0-477c-9afe-26812ca8e7cb" 背景 笔者在刷B站的时候,看到了一个关于面试的实录,前半段是八股文,后半段是笔试部分,感觉笔试部分的题目还是挺有意思的,特此记录一下。 笔试部分 问…

【多线程例题】顺序打印abc线程

顺序打印-进阶版 方法一:三个线程竞争同一个锁,通过count判断是否打印 方法二:三个线程同时start,分别上锁,从a开始,打印后唤醒b 三个线程分别打印A,B,C 方法一:通过co…

JavaFX中MVC例子理解

JavaFX可以让你使用GUI组件创建桌面应用程序。一个GUI应用程序执行三个任务:接受用户的输入,处理输入,并显示输出。而一个GUI应用程序包含两个 类型的代码: 领域代码。处理特定领域的数据和遵循业务规范。交互代码。处理用户输入…

【Linux】多线程(上)

本文详细介绍了多线程的常见概念 生产者消费者模型将在多线程(下)继续讲解 欢迎大家指正 提起讨论进步啊 目录 多线程的理解 线程的优点 线程的缺点: 线程的用途 线程VS进程 用户级线程库 POSIX线程库 线程创建: 线程…

springboot整合jwt

JWT介绍 JWT是JSON Web Token的缩写,即JSON Web令牌,是一种自包含令牌。 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放标准。 JWT的声明一般被用来在身份提供者和服务提供者间传递被认证的用户身份信息,以便于从资源服务器获…

基于.net6的WPF程序使用SignalR进行通信

之前写的SignalR通信,是基于.net6api,BS和CS进行通信的。 .net6API使用SignalRvue3聊天WPF聊天_signalr wpf_故里2130的博客-CSDN博客 今天写一篇关于CS客户端的SignalR通信,后台服务使用.net6api 。其实和之前写的差不多,主要在…

Ubuntu22.04密码忘记怎么办 Ubuntu重置root密码方法

在Ubuntu 22.04 或其他更高版本上不小心忘记root或其他账户的密码怎么办? 首先uname -r查看当前系统正在使用的内核版本,记下来 前提:是你的本地电脑,有物理访问权限。其他如远程登录的不适用这套改密方法。 通过以下步骤&#…

写字楼/办公楼能源管理系统的具体应用 安科瑞 许敏

0 引言 随着社会的进步,我国经济的快速发展,企业的办公环境和方式发生了巨大的变化,专业的写字楼在各大城市遍布林立。写字楼的出现使得各地企业办公集中化、高效化,然而写字楼物业管理的同步发展对于企业服务来说更是一个很大的…

SciencePub学术 | 区块链类重点SCIEEI征稿中

SciencePub学术 刊源推荐: 区块链类重点SCIE&EI征稿中!信息如下,录满为止: 一、期刊概况: SCI-01 【期刊简介】IF:4.0-4.5,JCR2区,中科院3区; 【检索情况】SCIE&EI双检&…

Mysql+ETLCloud CDC+StarRocks实时数仓同步实战

一、业务需求及其痛点 大型企业需要对各种业务系统中的销售及营销数据进行实时同步分析,例如库存信息、对帐信号、会员信息、广告投放信息,生产进度信息等等,这些统计分析信息可以实时同步到StarRocks中进行分析和统计,StarRocks…