OCR文本识别模型CRNN

CRNN网络结构

论文地址:https://arxiv.org/pdf/1507.05717
参考:https://blog.csdn.net/xiaosongshine/article/details/112198145

git:https://github.com/shuyeah2356/crnn.pytorch
CRNN文本识别实现端到端的不定长文本识别。
CRNN网络把包含三部分:卷积层(CNN)、循环层(RNN)和转录层(CTC loss)
在这里插入图片描述

1、卷积层::通过深层卷积操作对输入图像做特征提取,得到特征图;
2、循环层: 循环层使用双向LSTM(BLSTM)对特征序列进行预测,对序列中的每一个特征向量进行学习,并输出预测标签(真实值)的分布;
3、转录层:转录层使用CTC loss ,把循环层获取的一系列标签分布转换成最终的标签序列。

对于输入图片:
输入图像为灰度图(单通道);
高度为32,经过卷积处理后高度变为1;
输入图片宽度为100,输入图片大小为(100,32,1)
CNN输出尺寸为(512, 1, 40),卷积操作输出512个特征图,每一个特征图高度为1,宽度为26。

在代码中有判断图片的高度能够被16整除。

assert imgH % 16 == 0

1、CNN

卷积层用来提取文图像的特征,堆叠使用卷积层和最大池化层,特别的,最后两个最大池化层在宽度和高度上的步长是不相等的池化的窗口尺寸是(w,h):(1,2),因为待识别的文本图片多数是高较小而宽较长,使用1×2的池化窗口尽量不丢失在宽度方向的信息。
卷积操作的具体实现代码:

# 输入图片大小为(160,32,1)
assert imgH % 16 == 0, 'imgH has to be a multiple of 16 图片高度必须为16的倍数'# 一共有7次卷积操作ks = [3, 3, 3, 3, 3, 3, 2]  # 卷积层卷积尺寸3表示3x3,2表示2x2ps = [1, 1, 1, 1, 1, 1, 0]  # padding大小ss = [1, 1, 1, 1, 1, 1, 1]  # stride大小nm = [64, 128, 256, 256, 512, 512, 512]  # 卷积核个数,卷积操作输出特征层的通道数cnn = nn.Sequential()def convRelu(i, batchNormalization=False):  # 创建卷积层nIn = nc if i == 0 else nm[i - 1]  # 确定输入channel维度,如果是第一层网络,输入通道数为图片通道数,输入特征层的通道数为上一个特征层的输出通道数nOut = nm[i]  # 确定输出channel维度cnn.add_module('conv{0}'.format(i),nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))  # 添加卷积层# BN层if batchNormalization:cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))# Relu激活层if leakyRelu:cnn.add_module('relu{0}'.format(i),nn.LeakyReLU(0.2, inplace=True))else:cnn.add_module('relu{0}'.format(i), nn.ReLU(True))# 卷积核大小为3×3,s=1,p=1,输出通道数为64,特征层大小为100×32×64convRelu(0)# 经过2×2Maxpooling,宽高减半,特征层大小变为50×16×64cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))# 卷积核大小为3×3,s=1,p=1,输出通道数为128,特征层大小为50×16×128convRelu(1)# 经过2×2Maxpooling,宽高减半,特征层大小变为25×8×128cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))# 卷积核大小为3×3,s=1,p=1,输出通道数为256,特征层大小为25×8×256,卷积后面接BatchNormalizationconvRelu(2, True)# 卷积核大小为3×3,s=1,p=1,输出通道数为256,特征层大小为25×8×256convRelu(3)# 经过MaxPooling,卷积核大小为2×2,在h上stride=2,p=0,s=2,h=(8+0-2)//2+1=4,w上的stride=1,p=1,s=1,w=(25+2-2)//1+1=26通道数不变,26×4×256cnn.add_module('pooling{0}'.format(2),nn.MaxPool2d((2, 2), (2, 1), (0, 1)))    # 参数 (h, w)# 卷积核大小为3×3,s=1,p=1,输出通道数为512,特征层大小为50×16×512,卷积后面接BatchNormalizationconvRelu(4, True)# 卷积核大小为3×3,s=1,p=1,输出通道数为512,特征层大小为26×4×512convRelu(5)# 经过MaxPooling,卷积核大小为2×2,在h上stride=2,p=0,s=2,h=(4+0-2)//2+1=2,w上的stride=1,p=1,s=1,w=(26+2-2)//1+1=27通道数不变,27×2×512cnn.add_module('pooling{0}'.format(3),nn.MaxPool2d((2, 2), (2, 1), (0, 1)))# 卷积核大小为2×2,s=1,p=0,输出通道数为512,特征层大小为26×1×512convRelu(6, True)

对应的网络结构:
在这里插入图片描述

这里卷积操作后的特征图大小为26×1×512

2、RNN

对于卷积操作输出的结果经过处理之后才能输入到RNN中。
卷积操作输出的特征图高度一定为1,代码中也做约束

assert h == 1

将h,w维度合并,合并后的维度变为输入到RNN中的时间不长(time_step),每一个序列的长度为原始特征层的通道数512
输入到LSTM中的特征图大小是多少(面试被问到的问题):
每次输入到LSTM中的特征,时间不长的数量为原始特征的h×w(26),每次输入一个序列,序列长度为原始特征层的通道数512

def forward(self, input):# conv featuresconv = self.cnn(input)b, c, h, w = conv.size()# batch_size,512,1,26assert h == 1, "the height of conv must be 1"# 将宽高维度合并,特征层大小为(batch_size, 512, 26)conv = conv.squeeze(2)# 维度顺序调整(26, batch_size, 512),w作为时间步长(作为LSTM中的一个时间不长time_step)conv = conv.permute(2, 0, 1)  # rnn featuresoutput = self.rnn(conv)# print(output.size())return output

序列是按照列从左到右生成的,每一列包含512为特征,输入到LSTM中的第i个特征是特征图第i列像素的连接。
在这里插入图片描述
对于卷积操作、Maxpooling层和BatchNormalization卷积操作具有平移不变性。每一个从左到右的序列对应原始图像中的一个矩形区域,且顺序是对应的。特征序列中的每一个向量对一个原图中的一个感受野。
![在这里插入图片描述](https://img-blog.csdnimg.cn/direct/03b310b41f4a46e2b4d56423f7cf7c95.png =600x500在这里插入图片描述

将特征序列按照time_step输入到RNN中,RNN隐藏层的神经元数量为256,使用多层的双向LSTM。
输入时间不长的数量为26,经过RNN得到26个特征向量,输出特征向量的类别分类结果。
每一个特征向量对应的是原图中的一个小的矩形区域。RNN来判断这个矩形区域属于哪个字符。
根据输入的特征向量得到所有字符的softmax概率分布,这是一个长度为待识别字符类别总数量的向量,RNN的输出向量作为转录层CTC的输入。
LSTM实现的代码:

class BidirectionalLSTM(nn.Module):def __init__(self, nIn, nHidden, nOut):super(BidirectionalLSTM, self).__init__()self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)     # nIn:输入神经元个数# *2因为使用双向LSTM,将双向的隐藏层单元拼接在一起,两层个256单元的双向LSTMself.embedding = nn.Linear(nHidden * 2, nOut)def forward(self, input):# 经过RNN输出feature map特征结果recurrent, _ = self.rnn(input)T, b, h = recurrent.size()  # T:时间步长,b:batch size,h:hiden unitt_rec = recurrent.view(T * b, h)# 第一次LSTM得到特征层[26×256,256],view成[26, 256, 256]# 第二次LSTM得到特征层[26×256,num_class],view成[26, 256, num_class]output = self.embedding(t_rec)  # [T * b, nOut]output = output.view(T, b, -1)return output
# nh为隐藏层神经节点数,nclass为所有识别字符的类别总数
self.rnn = nn.Sequential(BidirectionalLSTM(512, nh, nh), # 输入的时间步长为512BidirectionalLSTM(nh, nh, nclass))

第一次LSTM得到特征层[26×256,256],view成[26, 256, 256]
第二次LSTM得到特征层[26×256,num_class],view成[26, 256, num_class]

3、转录层CTC(Connectionist Temporal Classification)

转录层将RNN对每个特征向量做的预测转换成标签序列的过程。对于不定长序列的对齐问题。
RNN进行序列分类时,可能会出现一个字被识别多次,需要去除冗余机制。
处理的方法(引入blank机制)这一过程称为解码过程:

  1. 在重复的字符之间增加一个空格‘-’,
  2. 删除连续重复的字符,
  3. 再去掉路径中左右的‘-’字符
    编码过程是由神经网络来实现的。
    文本标签可以有多个不同的字符组合路径得到。

CTC loss如何计算:
在训练阶段根据这些概率分布向量和对应的文本标签计算损失函数。
根据能得到对应标签的所有路径的分数之和类计算损失函数。
每条路径的概率为每一个时间步中对应字符的分数的乘积。CTC损失函数定义为概率的负最大似然函数,为了计算方便对函数取对数。

p(l|y)= ∑ π : B ( π ) p ( π ∣ y ) \sum\limits_{π:B(π)}p(π|y) πB(π)p(πy)

预测过程如何实现
先使用标准的CNN网络提取文本特征;
利用BLSTM将特征向量进行融合,已提取字符序列的上下文特征,得到每列特征的概率分布;
最后通过CTC进行预测得到文本序列。

在训练阶段CRNN将特征图像统一缩放到w×32,而在测试阶段对于输入的图片拉伸会导致识别率降低。CRNN保持输入图像尺寸比例,但是图像的高度h必须统一为32,卷积特征图的尺寸动态决定了LSTM的时序长度(时间步长)。


感谢:
https://blog.csdn.net/xiaosongshine/article/details/112198145
https://github.com/meijieru/crnn.pytorch
https://www.bilibili.com/video/BV1Wy4y1473z?p=2&vd_source=91cfed371d5491e2973d221d250b54ae

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/8232.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

两个手机在一起ip地址一样吗?两个手机是不是两个ip地址

在数字时代的浩瀚海洋中,手机已经成为我们生活中不可或缺的一部分。随着移动互联网的飞速发展,IP地址成为了连接手机与互联网的桥梁。那么,两个手机在一起IP地址一样吗?两个手机是不是两个IP地址?本文将带您一探究竟&a…

微火全域外卖系统是什么?为什么市场占有率这么高?

近日,全域外卖领域又出现了新变动,一个名为微火的品牌凭借着其全域外卖系统,在短短几个月的时间里,就占领了大部分市场。截止发稿日期前,微火全域外卖系统的市场占有率已经超过48%。 据了解,所谓的全域外卖…

微信小程序之搜索框样式(带源码)

一、效果图&#xff1a; 点击搜索框&#xff0c;“请输入搜索内容消失”&#xff0c;可输入关键字 二、代码&#xff1a; 2.1、WXML代码&#xff1a; <!--搜索框部分--><view class"search"><view class"search-btn">&#x1f50d;&l…

数据库复习2

试述SQL的特点 有两个关系 S(A,B,C, D)和 T(C,D,E,F)&#xff0c;写出与下列查询等价的 SQL 表达式: 用SQL语句建立第2章习题6中的4个表&#xff1b;针对建立的4个表用SQL完成第2章习题6中的查询 针对习题4中的4个表试用SQL完成以下各项操作 (1)找出所有供应商的姓名和所在城市…

[图解]SysML和EA建模住宅安全系统-02

1 00:00:00,900 --> 00:00:02,690 这个就是一个块定义图了 2 00:00:03,790 --> 00:00:04,780 简称BDD 3 00:00:05,610 --> 00:00:08,070 实际上就是UML里面的类图 4 00:00:08,080 --> 00:00:09,950 和组件图的一个结合体 5 00:00:13,150 --> 00:00:14,690 我…

WDW-10B微机控制电子万能试验机技术方案

一&#xff0e;设备外观照片&#xff1a; 项目简介&#xff1a; 微机控制电子式万能试验机是专门针对高等院校、各种金属、非金属科研厂家及国家级质检单位而设计的高端微机控制电子式万能试验机、计算机系统通过全数字控制器&#xff0c;经调速系统控制伺服电机转动&#xff…

MT3033 新的表达式

代码&#xff1a; #include <bits/stdc.h> using namespace std; bool is_op(char c) {return c & || c |; } int priority(char op) { // 运算优先级。如果有-*/等别的运算符&#xff0c;则这个函数很有必要if (op & || op |){return 1;}return -1; } voi…

数据链路层(详细版)【01】

数据链路层是在物理层和网络层之间的协议&#xff0c;提供相邻节点的可靠数据传输 一、从体系结构来看数据链路层 数据链路层是为上下两层提供服务或者上下两层向他传送数据&#xff08;服务【垂直】&#xff09;&#xff1b;与其对等层之间用帧进行通信&#xff08;协议【水平…

2024年51cto下载的视频怎么导出

如果你喜欢在51cto上观看各种专业技术视频&#xff0c;那么你可能想将喜欢的视频保存到本地设备中&#xff0c;以便随时随地观看。今天&#xff0c;我们就来探讨一下如何在2024年将51cto下载的视频导出到你的设备中 下载51cto的工具我已经打包好了&#xff0c;有需要的自己下载…

重学java 31.API 2.StringBuilder

总有一天&#xff0c;我不再畏惧任何人的离开 —— 24.5.8 StringBuilder的介绍 1.概述 一个可变的字符序列,此类提供了一个与StringBuffer兼容的一套API&#xff0c;但是不保证同步&#xff08;线程不安全&#xff0c;效率高&#xff…

Qt 6.7 正式发布!

本文翻译自&#xff1a;Qt 6.7 Released! 原文作者&#xff1a;Qt Group研发总监Volker Hilsheimer 在最新发布的Qt 6.7版本中&#xff0c;我们大大小小作出了许多改善&#xff0c;以便您在构建现代应用程序和用户体验时能够享受更多乐趣。 部分新增功能已推出了技术预览版&a…

scikit-learn多因子线性回归预测房价

1.首先是单因子线性回归预测房价 import numpy as np import pandas as pd from matplotlib import pyplot as plt from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_squared_error, r2_score# 1.读取csa房屋数据 path D:/pythonDATA/us…

地道俄语口语,柯桥俄语培训哪家好

1、по-моему 依我看&#xff1b;在我看来 例&#xff1a; По-моему, сегодня будет дождь. 依我看, 今天要下雨。 Сделай по-моему. 按我的办法干吧 2、кажется 似乎是&#xff1b;看起来 例&#xff1a; Парень, …

mvc区域、Html.RenderAction、Html.RenderPartial、 模板、section

根据上图 Html.RenderPartial 与 Html.RenderAction 区别 RenderAction 会把对应的视图结果渲染 RenderPartial 会把html视图直接渲染 模板

The 2021 Sichuan Provincial Collegiate Programming Contest

The 2021 Sichuan Provincial Collegiate Programming Contest The 2021 Sichuan Provincial Collegiate Programming Contest A. Chuanpai 题意&#xff1a;给出总值k&#xff0c;将k分成xyk&#xff0c;x和y均小于7&#xff0c;最多分成多少组。 思路&#xff1a;暴力跑一…

【工具】Office/WPS 插件|AI 赋能自动化生成 PPT 插件测评 —— 必优科技 ChatPPT

本文参加百度的有奖征文活动&#xff0c;更主要的也是借此机会去体验一下 AI 生成 PPT 的产品的现状&#xff0c;因此本文是设身处地从用户的角度去体验、使用这个产品&#xff0c;并反馈最真实的建议和意见&#xff0c;除了明确该产品的优点之外&#xff0c;也发现了不少缺陷和…

实战Java虚拟机-基础篇

JVM的组成 一、自动垃圾回收 1.Java的内存管理 Java中为了简化对象的释放&#xff0c;引入了自动的垃圾回收&#xff08;Garbage Collection简称GC&#xff09;机制。通过垃圾回收器来对不再使用的对象完成自动的回收&#xff0c;垃圾回收器主要负责对堆上的内存进行回收。其…

vue项目基于WebRTC实现一对一音视频通话

效果 前端代码 <template><div class"flex items-center flex-col text-center p-12 h-screen"><div class"relative h-full mb-4 fBox"><video id"localVideo"></video><video id"remoteVideo">…

【MQTT】服务端、客户端工具使用记录

目录 一、服务端 1.1 下载 1.2 相关命令 &#xff08;1&#xff09;启动 &#xff08;2&#xff09;关闭 &#xff08;3&#xff09;修改用户名和密码 1.3 后台管理 &#xff08;1&#xff09;MQTT配置 &#xff08;2&#xff09;集群概览 &#xff08;3&#xff09;…

如何使用SkyWalking收集分析分布式系统的追踪数据

Apache SkyWalking 是一个开源的观测性工具&#xff0c;用于收集、分析和展示分布式系统的追踪数据。SkyWalking 支持多种语言的追踪&#xff0c;包括但不限于 Java、.NET、Node.js 等。以下是使用 SkyWalking 工具实现数据采集的详细步骤&#xff1a; 1. 下载和安装 SkyWalkin…