竞赛 题目:基于深度学习的中文汉字识别 - 深度学习 卷积神经网络 机器视觉 OCR

文章目录

  • 0 简介
  • 1 数据集合
  • 2 网络构建
  • 3 模型训练
  • 4 模型性能评估
  • 5 文字预测
  • 6 最后

0 简介

🔥 优质竞赛项目系列,今天要分享的是

基于深度学习的中文汉字识别

该项目较为新颖,适合作为竞赛课题方向,学长非常推荐!

🧿 更多资料, 项目分享:

https://gitee.com/dancheng-senior/postgraduate

1 数据集合

学长手有3755个汉字(一级字库)的印刷体图像数据集,我们可以利用它们进行接下来的3755个汉字的识别系统的搭建。

用深度学习做文字识别,用的网络当然是CNN,那具体使用哪个经典网络?VGG?RESNET?还是其他?我想了下,越深的网络训练得到的模型应该会更好,但是想到训练的难度以及以后线上部署时预测的速度,我觉得首先建立一个比较浅的网络(基于LeNet的改进)做基本的文字识别,然后再根据项目需求,再尝试其他的网络结构。这次任务所使用的深度学习框架是强大的Tensorflow。

2 网络构建

第一步当然是搭建网络和计算图

其实文字识别就是一个多分类任务,比如这个3755文字识别就是3755个类别的分类任务。我们定义的网络非常简单,基本就是LeNet的改进版,值得注意的是我们加入了batch
normalization。另外我们的损失函数选择sparse_softmax_cross_entropy_with_logits,优化器选择了Adam,学习率设为0.1

#network: conv2d->max_pool2d->conv2d->max_pool2d->conv2d->max_pool2d->conv2d->conv2d->max_pool2d->fully_connected->fully_connecteddef build_graph(top_k):keep_prob = tf.placeholder(dtype=tf.float32, shape=[], name='keep_prob')images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1], name='image_batch')labels = tf.placeholder(dtype=tf.int64, shape=[None], name='label_batch')is_training = tf.placeholder(dtype=tf.bool, shape=[], name='train_flag')with tf.device('/gpu:5'):#给slim.conv2d和slim.fully_connected准备了默认参数:batch_normwith slim.arg_scope([slim.conv2d, slim.fully_connected],normalizer_fn=slim.batch_norm,normalizer_params={'is_training': is_training}):conv3_1 = slim.conv2d(images, 64, [3, 3], 1, padding='SAME', scope='conv3_1')max_pool_1 = slim.max_pool2d(conv3_1, [2, 2], [2, 2], padding='SAME', scope='pool1')conv3_2 = slim.conv2d(max_pool_1, 128, [3, 3], padding='SAME', scope='conv3_2')max_pool_2 = slim.max_pool2d(conv3_2, [2, 2], [2, 2], padding='SAME', scope='pool2')conv3_3 = slim.conv2d(max_pool_2, 256, [3, 3], padding='SAME', scope='conv3_3')max_pool_3 = slim.max_pool2d(conv3_3, [2, 2], [2, 2], padding='SAME', scope='pool3')conv3_4 = slim.conv2d(max_pool_3, 512, [3, 3], padding='SAME', scope='conv3_4')conv3_5 = slim.conv2d(conv3_4, 512, [3, 3], padding='SAME', scope='conv3_5')max_pool_4 = slim.max_pool2d(conv3_5, [2, 2], [2, 2], padding='SAME', scope='pool4')flatten = slim.flatten(max_pool_4)fc1 = slim.fully_connected(slim.dropout(flatten, keep_prob), 1024,activation_fn=tf.nn.relu, scope='fc1')logits = slim.fully_connected(slim.dropout(fc1, keep_prob), FLAGS.charset_size, activation_fn=None,scope='fc2')# 因为我们没有做热编码,所以使用sparse_softmax_cross_entropy_with_logitsloss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), labels), tf.float32))update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)if update_ops:updates = tf.group(*update_ops)loss = control_flow_ops.with_dependencies([updates], loss)global_step = tf.get_variable("step", [], initializer=tf.constant_initializer(0.0), trainable=False)optimizer = tf.train.AdamOptimizer(learning_rate=0.1)train_op = slim.learning.create_train_op(loss, optimizer, global_step=global_step)probabilities = tf.nn.softmax(logits)# 绘制loss accuracy曲线tf.summary.scalar('loss', loss)tf.summary.scalar('accuracy', accuracy)merged_summary_op = tf.summary.merge_all()# 返回top k 个预测结果及其概率;返回top K accuracypredicted_val_top_k, predicted_index_top_k = tf.nn.top_k(probabilities, k=top_k)accuracy_in_top_k = tf.reduce_mean(tf.cast(tf.nn.in_top_k(probabilities, labels, top_k), tf.float32))return {'images': images,'labels': labels,'keep_prob': keep_prob,'top_k': top_k,'global_step': global_step,'train_op': train_op,'loss': loss,'is_training': is_training,'accuracy': accuracy,'accuracy_top_k': accuracy_in_top_k,'merged_summary_op': merged_summary_op,'predicted_distribution': probabilities,'predicted_index_top_k': predicted_index_top_k,'predicted_val_top_k': predicted_val_top_k}

3 模型训练

训练之前我们应设计好数据怎么样才能高效地喂给网络训练。

首先,我们先创建数据流图,这个数据流图由一些流水线的阶段组成,阶段间用队列连接在一起。第一阶段将生成文件名,我们读取这些文件名并且把他们排到文件名队列中。第二阶段从文件中读取数据(使用Reader),产生样本,而且把样本放在一个样本队列中。根据你的设置,实际上也可以拷贝第二阶段的样本,使得他们相互独立,这样就可以从多个文件中并行读取。在第二阶段的最后是一个排队操作,就是入队到队列中去,在下一阶段出队。因为我们是要开始运行这些入队操作的线程,所以我们的训练循环会使得样本队列中的样本不断地出队。

在这里插入图片描述
入队操作都在主线程中进行,Session中可以多个线程一起运行。 在数据输入的应用场景中,入队操作是从硬盘中读取输入,放到内存当中,速度较慢。
使用QueueRunner可以创建一系列新的线程进行入队操作,让主线程继续使用数据。如果在训练神经网络的场景中,就是训练网络和读取数据是异步的,主线程在训练网络,另一个线程在将数据从硬盘读入内存。

# batch的生成
def input_pipeline(self, batch_size, num_epochs=None, aug=False):# numpy array 转 tensorimages_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)# 将image_list ,label_list做一个slice处理input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)labels = input_queue[1]images_content = tf.read_file(input_queue[0])images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)if aug:images = self.data_augmentation(images)new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)images = tf.image.resize_images(images, new_size)image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,min_after_dequeue=10000)# print 'image_batch', image_batch.get_shape()return image_batch, label_batch

训练时数据读取的模式如上面所述,那训练代码则根据该架构设计如下:

def train():print('Begin training')# 填好数据读取的路径train_feeder = DataIterator(data_dir='./dataset/train/')test_feeder = DataIterator(data_dir='./dataset/test/')model_name = 'chinese-rec-model'with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)) as sess:# batch data 获取train_images, train_labels = train_feeder.input_pipeline(batch_size=FLAGS.batch_size, aug=True)test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size)graph = build_graph(top_k=1)  # 训练时top k = 1saver = tf.train.Saver()sess.run(tf.global_variables_initializer())# 设置多线程协调器coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/val')start_step = 0# 可以从某个step下的模型继续训练if FLAGS.restore:ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)if ckpt:saver.restore(sess, ckpt)print("restore from the checkpoint {0}".format(ckpt))start_step += int(ckpt.split('-')[-1])logger.info(':::Training Start:::')try:i = 0while not coord.should_stop():i += 1start_time = time.time()train_images_batch, train_labels_batch = sess.run([train_images, train_labels])feed_dict = {graph['images']: train_images_batch,graph['labels']: train_labels_batch,graph['keep_prob']: 0.8,graph['is_training']: True}_, loss_val, train_summary, step = sess.run([graph['train_op'], graph['loss'], graph['merged_summary_op'], graph['global_step']],feed_dict=feed_dict)train_writer.add_summary(train_summary, step)end_time = time.time()logger.info("the step {0} takes {1} loss {2}".format(step, end_time - start_time, loss_val))if step > FLAGS.max_steps:breakif step % FLAGS.eval_steps == 1:test_images_batch, test_labels_batch = sess.run([test_images, test_labels])feed_dict = {graph['images']: test_images_batch,graph['labels']: test_labels_batch,graph['keep_prob']: 1.0,graph['is_training']: False}accuracy_test, test_summary = sess.run([graph['accuracy'], graph['merged_summary_op']],feed_dict=feed_dict)if step > 300:test_writer.add_summary(test_summary, step)logger.info('===============Eval a batch=======================')logger.info('the step {0} test accuracy: {1}'.format(step, accuracy_test))logger.info('===============Eval a batch=======================')if step % FLAGS.save_steps == 1:logger.info('Save the ckpt of {0}'.format(step))saver.save(sess, os.path.join(FLAGS.checkpoint_dir, model_name),global_step=graph['global_step'])except tf.errors.OutOfRangeError:logger.info('==================Train Finished================')saver.save(sess, os.path.join(FLAGS.checkpoint_dir, model_name), global_step=graph['global_step'])finally:# 达到最大训练迭代数的时候清理关闭线程coord.request_stop()coord.join(threads)

执行以下指令进行模型训练。因为我使用的是TITAN
X,所以感觉训练时间不长,大概1个小时可以训练完毕。训练过程的loss和accuracy变换曲线如下图所示

然后执行指令,设置最大迭代步数为16002,每100步进行一次验证,每500步存储一次模型。

python Chinese_OCR.py --mode=train --max_steps=16002 --eval_steps=100 --save_steps=500

在这里插入图片描述

4 模型性能评估

我们的需要对模模型进行评估,我们需要计算模型的top 1 和top 5的准确率。

执行指令

python Chinese_OCR.py --mode=validation

在这里插入图片描述

def validation():print('Begin validation')test_feeder = DataIterator(data_dir='./dataset/test/')final_predict_val = []final_predict_index = []groundtruth = []with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,allow_soft_placement=True)) as sess:test_images, test_labels = test_feeder.input_pipeline(batch_size=FLAGS.batch_size, num_epochs=1)graph = build_graph(top_k=5)saver = tf.train.Saver()sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())  # initialize test_feeder's inside statecoord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)if ckpt:saver.restore(sess, ckpt)print("restore from the checkpoint {0}".format(ckpt))logger.info(':::Start validation:::')try:i = 0acc_top_1, acc_top_k = 0.0, 0.0while not coord.should_stop():i += 1start_time = time.time()test_images_batch, test_labels_batch = sess.run([test_images, test_labels])feed_dict = {graph['images']: test_images_batch,graph['labels']: test_labels_batch,graph['keep_prob']: 1.0,graph['is_training']: False}batch_labels, probs, indices, acc_1, acc_k = sess.run([graph['labels'],graph['predicted_val_top_k'],graph['predicted_index_top_k'],graph['accuracy'],graph['accuracy_top_k']], feed_dict=feed_dict)final_predict_val += probs.tolist()final_predict_index += indices.tolist()groundtruth += batch_labels.tolist()acc_top_1 += acc_1acc_top_k += acc_kend_time = time.time()logger.info("the batch {0} takes {1} seconds, accuracy = {2}(top_1) {3}(top_k)".format(i, end_time - start_time, acc_1, acc_k))except tf.errors.OutOfRangeError:logger.info('==================Validation Finished================')acc_top_1 = acc_top_1 * FLAGS.batch_size / test_feeder.sizeacc_top_k = acc_top_k * FLAGS.batch_size / test_feeder.sizelogger.info('top 1 accuracy {0} top k accuracy {1}'.format(acc_top_1, acc_top_k))finally:coord.request_stop()coord.join(threads)return {'prob': final_predict_val, 'indices': final_predict_index, 'groundtruth': groundtruth}

5 文字预测

刚刚做的那一步只是使用了我们生成的数据集作为测试集来检验模型性能,这种检验是不大准确的,因为我们日常需要识别的文字样本不会像是自己合成的文字那样的稳定和规则。那我们尝试使用该模型对一些实际场景的文字进行识别,真正考察模型的泛化能力。

首先先编写好预测的代码

def inference(name_list):print('inference')image_set=[]# 对每张图进行尺寸标准化和归一化for image in name_list:temp_image = Image.open(image).convert('L')temp_image = temp_image.resize((FLAGS.image_size, FLAGS.image_size), Image.ANTIALIAS)temp_image = np.asarray(temp_image) / 255.0temp_image = temp_image.reshape([-1, 64, 64, 1])image_set.append(temp_image)# allow_soft_placement 如果你指定的设备不存在,允许TF自动分配设备with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options,allow_soft_placement=True)) as sess:logger.info('========start inference============')# images = tf.placeholder(dtype=tf.float32, shape=[None, 64, 64, 1])# Pass a shadow label 0. This label will not affect the computation graph.graph = build_graph(top_k=3)saver = tf.train.Saver()# 自动获取最后一次保存的模型ckpt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)if ckpt:       saver.restore(sess, ckpt)val_list=[]idx_list=[]# 预测每一张图for item in image_set:temp_image = itempredict_val, predict_index = sess.run([graph['predicted_val_top_k'], graph['predicted_index_top_k']],feed_dict={graph['images']: temp_image,graph['keep_prob']: 1.0,graph['is_training']: False})val_list.append(predict_val)idx_list.append(predict_index)#return predict_val, predict_indexreturn val_list,idx_list

这里需要说明一下,我会把我要识别的文字图像存入一个叫做tmp的文件夹内,里面的图像按照顺序依次编号,我们识别时就从该目录下读取所有图片仅内存进行逐一识别。

# 获待预测图像文件夹内的图像名字
def get_file_list(path):list_name=[]files = os.listdir(path)files.sort()for file in files:file_path = os.path.join(path, file)list_name.append(file_path)return list_name

那我们使用训练好的模型进行汉字预测,观察效果。首先我从一篇论文pdf上用截图工具截取了一段文字,然后使用文字切割算法把文字段落切割为单字,如下图,因为有少量文字切割失败,所以丢弃了一些单字。

从一篇文章中用截图工具截取文字段落。

在这里插入图片描述
切割出来的单字,黑底白字。

在这里插入图片描述

最后将所有的识别文字按顺序组合成段落,可以看出,汉字识别完全正确,说明我们的基于深度学习的OCR系统还是相当给力!

在这里插入图片描述

至此,支持3755个汉字识别的OCR系统已经搭建完毕,经过测试,效果还是很不错。这是一个没有经过太多优化的模型,在模型评估上top
1的正确率达到了99.9%,这是一个相当优秀的效果了,所以说在一些比较理想的环境下的文字识别的效果还是比较给力,但是对于复杂场景的或是一些干扰比较大的文字图像,识别起来的效果可能不会太理想,这就需要针对特定场景做进一步优化。

6 最后

🧿 更多资料, 项目分享:

https://gitee.com/dancheng-senior/postgraduate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/153406.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是媒体见证?媒体宣传有哪些好处?

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 一,什么是媒体见证? 媒体见证是指企业举办活动,发布会,邀请媒体现场采访的一种宣传方式,媒体到场后,对其进行记录…

lenovo联想笔记本ThinkPad P1 Gen5/X1 Extreme Gen5原装出厂Windows11预装OEM系统

链接:https://pan.baidu.com/s/13E97Nwc-0-N7ffPjEeeeOw?pwdep4l 提取码:ep41 原装出厂系统自带所有驱动、出厂主题壁纸、Office办公软件、联想电脑管家等预装程序 所需要工具:32G或以上的U盘 文件格式:ISO 文件大小&#xff…

Java实现俄罗斯方块游戏

俄罗斯方块游戏本身的逻辑: 俄罗斯方块游戏的逻辑是比较简单的。它就类似于堆砌房子一样,各种各样的方地形状是不同的。但是,俄罗斯方块游戏的界面被等均的分为若干行和若干列,因此方块的本质就是占用了多少个单元。 首先来考虑…

解决 Python requests 库中 SSL 错误转换为 Timeouts 问题

解决 Python requests 库中 SSL 错误转换为 Timeouts 问题:理解和处理 SSL 错误的关键 在使用Python的requests库进行HTTPS请求时,可能会遇到SSL错误,这些错误包括但不限于证书不匹配、SSL层出现问题等。如果在requests库中设置verifyFalse&…

《向量数据库指南》——Range Search 使用方法和参数检查

Range Search 使用方法 如需使用 Range Search,只需要修改搜索请求中的搜索参数。接下来我会讲一下的详细使用指南,在指南的最后还提供了 Python 示例代码。 开始前 请确保已安装并运行 Milvus Cloud。请确保已创建 1 个 Collection,并为该 Collection 创建索引。 Ra…

【LeetCode:2216. 美化数组的最少删除数 | 贪心】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

掌握源码,轻松搭建:一站式建站系统源码 附完整搭建步骤与教程

随着互联网的快速发展,网站已成为人们生活中不可或缺的一部分。然而,对于许多初学者或中小企业来说,搭建一个完整的网站系统并非易事。这涉及到前端和后端的开发、数据库管理等多个环节。为了解决这一痛点,我们推出了一站式建站系…

sortablejs拖拽后新增和删除行时顺序错乱

问题描述:如下图所示,使用sortablejs拖拽后,在序号2后新增行会出现新增行跑到第一行的错误顺序。 解决:在进行拖拽后,对表格数据进行清空重新赋值。

一种可度量的测试体系-精准测试

行业现状 软件行业长期存在一个痛点,即测试效果无法度量。通常依赖于测试人员的能力和经验,测试结果往往不可控,极端情况下同一个业务功能,即使是同一个人员在不同的时间段,测试场景和过程也可能不一致,从而…

抖音电商双11官方数据最全汇总!

11月13日,抖音电商数据发布“抖音商城双11好物节”数据报告,展现双11期间平台全域经营情况及大众消费趋势。 报告显示,10月20日至11月11日,抖音电商里的直播间累计直播时长达到5827万小时,挂购物车的短视频播放了1697亿…

第十一篇 基于JSP 技术的网上购书系统——产品类别管理、评论/留言管理、注册用户管理、新闻管理功能实现(网上商城、仿淘宝、当当、亚马逊)

目录 1.产品类别管理 1.1功能说明 1.2界面设计 1.3处理流程 1.4数据来源和算法 1.4.1数据来源 1.4.2 查询条件 1.4.3相关sql实例 2. 评论/留言管理 2.1功能说明 2.2 界面设计 2.3处理流程 2.4数据来源和算法 2.4.1数据来源 2.4.2 查询条件 2.4.3相关sql实例…

vue3 使用simplebar【滚动条】

1.下载simplebar-vue npm install simplebar-vue --save2.引入注册 import simplebar from "simplebar-vue"; import simplebar-vue/dist/simplebar.min.css import simplebar-vue/dist/simplebar-vue.jsvue2的版本基础上 【引入注册】 import simplebar from &qu…

IDEA 搭建 SpringCloud 项目【超详细步骤】

文章目录 一、前言二、项目搭建1. 数据库准备2. 创建父工程3. 创建注册中心4. 服务注册5. 编写业务代码6. 服务拉取 一、前言 所谓微服务,就是要把整个业务模块拆分成多个各司其职的小模块,做到单一职责原则,不会重复开发相同的业务代码&…

数据预处理pandas pd.json_normalize占用内存过大优化

问题描述 从ES下载数据,数据格式为json,然后由pandas进行解析,json中的嵌套字段会进行展开作为列名(由于维度初期无法预测,所以根据数据有啥列就使用啥列,这是最方便的点),变成表格,方面了后续…

电脑开不了机怎么办?三招帮你成功解决!

电脑是我们日常工作和生活的重要工具,但有时候它们也会出现开机问题。当电脑无法启动时,可能会让人感到焦虑,电脑开不了机怎么办?不必担心,通常有多种方法可以解决这些问题。本文将介绍三种常见的方法,以帮…

【广州华锐互动】VR虚拟现实技术助力太空探险:穿越时空,探索宇宙奥秘

随着科技的不断发展,虚拟现实(VR)技术已经逐渐走进我们的生活。在教育领域,VR技术的应用也日益广泛,为学生提供了更加生动、直观的学习体验。本文将以利用VR开展太空探险学习为主题,探讨如何将这一先进技术…

提升办公效率,畅享多功能办公笔记软件Notion for Mac

在现代办公环境中,高效的笔记软件对于提高工作效率至关重要。而Notion for Mac作为一款全能的办公笔记软件,将成为你事业成功的得力助手。 Notion for Mac以其多功能和灵活性而脱颖而出。无论你是需要记录会议笔记、管理项目任务、制定流程指南&#xf…

基于springboot实现冬奥会科普平台系统【项目源码+论文说明】计算机毕业设计

基于SpringBoot实现冬奥会科普平台系统演示 摘要 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理平台应运而生&…

图像的傅里叶变换

目录 ​编辑 傅里叶基础 傅里叶基础numpy实现 逆傅里叶numpy实现 频域的高通滤波 傅里叶OpenCV实现 傅里叶OpenCV逆变换实现 频域的低通滤波 傅里叶变换有什么应用场景 傅里叶变换matlab实现 傅里叶基础 法国数学家吉恩巴普提斯特约瑟夫傅里叶被世人铭记的最大的贡献…