不定长图片验证码训练

基于循环神经网络(下方代码实际使用的是双向GRU,而非LSTM)和CTC Loss训练不定长图片验证码
Github项目地址:https://github.com/JansonJo/captcha_ocr.git

# coding=utf-8
"""
Convert three-channel images to grayscale and use them for training.
"""
import itertools
import os
import re
import random
import string
from collections import Counter
from os.path import join
import yaml
import cv2
import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping, Callback
from keras.layers import Input, Dense, Activation, Dropout, BatchNormalization, Reshape, Lambda
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.merge import add, concatenate
from keras.layers.recurrent import GRU
from keras.models import Model, load_model

# Load the training configuration once, at import time.
# The file is opened in a ``with`` block so the handle is always closed, and
# parsed with ``yaml.safe_load``: ``yaml.load`` without an explicit Loader can
# construct arbitrary Python objects and is rejected by PyYAML >= 6.
with open('./config/config_demo.yaml', 'r', encoding='utf-8') as f:
    cfg = f.read()
cfg_dict = yaml.safe_load(cfg)

# Create the TensorFlow 1.x session used by Keras; ``allow_growth`` makes the
# GPU allocator claim memory on demand instead of reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# config.gpu_options.per_process_gpu_memory_fraction = cfg_dict['System']['GpuMemoryFraction']
session = tf.Session(config=config)
K.set_session(session)

# System config
# Values read from the 'System' section: dataset folders, image geometry,
# model file name, label-extraction regex and the label alphabet.
_system_cfg = cfg_dict['System']
TRAIN_SET_PTAH = _system_cfg['TrainSetPath']  # (sic) name is used throughout the file
# NOTE(review): validation and test both read 'TestSetPath' - confirm this is intended.
VALID_SET_PATH = _system_cfg['TestSetPath']
TEST_SET_PATH = _system_cfg['TestSetPath']
MAX_TEXT_LEN = _system_cfg['MaxTextLenth']
IMG_W = _system_cfg['IMG_W']
IMG_H = _system_cfg['IMG_H']
MODEL_NAME = _system_cfg['ModelName']
LABEL_REGEX = _system_cfg['LabelRegex']
ALPHABET = _system_cfg['Alphabet']

# Values read from the 'NeuralNet' section.
_neural_net_cfg = cfg_dict['NeuralNet']
RNN_SIZE = _neural_net_cfg['RNNSize']
DROPOUT = _neural_net_cfg['Dropout']

# TrainParam config
MONITOR = cfg_dict['TrainParam']['EarlyStoping']['monitor']
PATIENCE = cfg_dict['TrainParam']['EarlyStoping']['patience']
MODE = cfg_dict['TrainParam']['EarlyStoping']['mode']
BASELINE = cfg_dict['TrainParam']['EarlyStoping']['baseline']
EPOCHS = cfg_dict['TrainParam']['Epochs']
BATCH_SIZE = cfg_dict['TrainParam']['BatchSize']
TEST_BATCH_SIZE = cfg_dict['TrainParam']['TestBatchSize']
TEST_SET_NUM = cfg_dict['TrainParam']['TestSetNum']


def get_counter(dirpath):
    """Count how often each label character occurs in the files under ``dirpath``.

    The label of a file is group 1 of ``LABEL_REGEX`` applied to its file name.
    The directory tree is walked recursively. As a side effect the length of
    the longest label found is printed (0 when no labelled file exists).

    :param dirpath: root directory of the data set to scan.
    :return: ``collections.Counter`` mapping each label character to its count.
    """
    label_lengths = []
    char_counts = Counter()
    for _root, _dirs, files in os.walk(dirpath):
        for filename in files:
            m = re.search(LABEL_REGEX, filename, re.M | re.I)
            if m is None:
                # A file whose name carries no label cannot contribute any
                # character; skip it instead of failing on ``None.group``.
                continue
            description = m.group(1)
            label_lengths.append(len(description))
            char_counts.update(description)
    # ``default=0`` keeps an empty directory from raising ValueError.
    print('Max plate length in "%s":' % dirpath, max(label_lengths, default=0))
    return char_counts


c_val = get_counter(VALID_SET_PATH)
c_train = get_counter(TRAIN_SET_PTAH)
letters_train = set(c_train.keys())
letters_val = set(c_val.keys())
print('letters_train: %s' % ''.join(sorted(letters_train)))
print('letters_val: %s' % ''.join(sorted(letters_val)))
# Training and validation labels must be drawn from the same character set;
# abort early otherwise.
if letters_train == letters_val:
    print('Letters in train and val do match')
else:
    raise Exception('Letters in train and val don\'t match')
# print(len(letters_train), len(letters_val), len(letters_val | letters_train))
# letters = sorted(list(letters_train))
# letters = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
letters = ALPHABET
if not letters:
    # Fall back to digits + upper-case + lower-case letters when the config
    # leaves the alphabet empty. ``not letters`` also covers ``None``, which
    # is what an empty YAML value loads as (``len(None)`` would raise).
    letters = string.digits + string.ascii_uppercase + string.ascii_lowercase
class_num = len(letters) + 1   # plus 1 for blank
print('Alphabet Letters:', ''.join(letters))


# Input data generator
def labels_to_text(labels):
    """Decode a sequence of class indices into a string.

    Index ``len(letters)`` is the padding/blank class and decodes to nothing;
    every other index is looked up in ``letters``.
    """
    blank = len(letters)
    return ''.join(letters[int(x)] for x in labels if int(x) != blank)


def text_to_labels(text):
    """Encode ``text`` as a list of class indices.

    Characters that are not part of ``letters`` (this includes the ``'_'``
    padding character when it is not in the alphabet) map to ``len(letters)``.
    """
    blank = len(letters)
    labels = []
    for ch in text:
        idx = letters.find(ch)  # single lookup per character
        labels.append(idx if idx > -1 else blank)
    return labels


def is_valid_str(s):
    """Return True when every character of ``s`` belongs to ``letters``."""
    return all(ch in letters for ch in s)


class TextImageGenerator:
    """Loads a folder of labelled images into memory and yields CTC batches.

    Labels are taken from group 1 of ``LABEL_REGEX`` applied to each file name
    and right-padded with ``'_'`` up to ``max_text_len``.
    """

    def __init__(self,
                 dirpath,
                 tag,
                 img_w, img_h,
                 batch_size,
                 downsample_factor,
                 max_text_len=MAX_TEXT_LEN):
        """Index and load every ``.png``/``.jpg`` image found in ``dirpath``.

        :param dirpath: directory that directly contains the image files.
        :param tag: free-form name of the data set; not used by this class.
        :param img_w: width every image is resized to.
        :param img_h: height every image is resized to.
        :param batch_size: number of samples per yielded batch.
        :param downsample_factor: factor used to derive the CTC input length
            from ``img_w``.
        :param max_text_len: length every label is padded to.
        :raises ValueError: if a label is longer than ``max_text_len``.
        :raises IOError: if an image file cannot be decoded.
        """
        self.img_h = img_h
        self.img_w = img_w
        self.batch_size = batch_size
        self.max_text_len = max_text_len
        self.downsample_factor = downsample_factor

        self.samples = []
        for filename in os.listdir(dirpath):
            _name, ext = os.path.splitext(filename)
            # Compare case-insensitively so '.PNG' / '.JPG' files are not dropped.
            if ext.lower() not in ('.png', '.jpg'):
                continue
            m = re.search(LABEL_REGEX, filename, re.M | re.I)
            if m is None:
                # Without a label the image cannot be used; skip it rather
                # than failing on ``None.group``.
                print('Skip image without label in file name: %s' % filename)
                continue
            description = m.group(1)
            if len(description) > self.max_text_len:
                # Fail here with a clear message; otherwise the batch builder
                # would fail later when storing the over-long label row.
                raise ValueError('Label "%s" of file "%s" is longer than max_text_len=%d'
                                 % (description, filename, self.max_text_len))
            # Pad to the instance's max_text_len (the size of the label rows
            # allocated in next_batch), not to the module-level default.
            description = description + '_' * (self.max_text_len - len(description))
            # if is_valid_str(description):
            #     self.samples.append([img_filepath, description])
            self.samples.append([join(dirpath, filename), description])

        self.n = len(self.samples)
        self.indexes = list(range(self.n))
        self.cur_index = 0

        # build data: one grayscale image of shape (img_h, img_w), scaled to [0, 1], per sample
        self.imgs = np.zeros((self.n, self.img_h, self.img_w))
        self.texts = []
        for i, (img_filepath, text) in enumerate(self.samples):
            img = cv2.imread(img_filepath)
            if img is None:
                # cv2.imread signals failure by returning None.
                raise IOError('Cannot read image file: %s' % img_filepath)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)     # cv2 loads images in BGR order
            img = cv2.resize(img, (self.img_w, self.img_h))
            self.imgs[i, :, :] = img.astype(np.float32) / 255
            self.texts.append(text)

    @staticmethod
    def get_output_size():
        """Return the number of output classes: the alphabet plus one blank."""
        return len(letters) + 1

    def next_sample(self):
        """Return the next ``(image, padded_text)`` pair.

        Samples are served in index order on the first pass; the order is
        reshuffled every time the data set has been exhausted.

        :raises IndexError: if no sample was loaded.
        """
        if self.n == 0:
            raise IndexError('TextImageGenerator holds no samples')
        if self.cur_index >= self.n:
            self.cur_index = 0
            random.shuffle(self.indexes)
        # Read first, then advance, so that sample 0 is not skipped on the
        # very first pass.
        idx = self.indexes[self.cur_index]
        self.cur_index += 1
        return self.imgs[idx], self.texts[idx]

    def next_batch(self):
        """Yield ``(inputs, outputs)`` batches forever.

        ``inputs`` holds 'the_input', 'the_labels', 'input_length' and
        'label_length'; ``outputs`` holds a zero-filled dummy 'ctc' target.
        """
        channels_first = K.image_data_format() == 'channels_first'
        while True:
            # width and height are backwards from typical Keras convention
            # because width is the time dimension when it gets fed into the RNN
            if channels_first:
                X_data = np.ones([self.batch_size, 1, self.img_w, self.img_h])
            else:
                X_data = np.ones([self.batch_size, self.img_w, self.img_h, 1])
            Y_data = np.ones([self.batch_size, self.max_text_len])
            # The CTC loss drops the first two time steps, hence the "- 2".
            input_length = np.ones((self.batch_size, 1)) * (self.img_w // self.downsample_factor - 2)
            label_length = np.zeros((self.batch_size, 1))
            source_str = []

            for i in range(self.batch_size):
                img, text = self.next_sample()
                img = img.T
                img = np.expand_dims(img, 0 if channels_first else -1)
                X_data[i] = img
                Y_data[i] = text_to_labels(text)
                source_str.append(text)
                text = text.replace("_", "")  # important step: padding must not count as label
                label_length[i] = len(text)

            inputs = {
                'the_input': X_data,
                'the_labels': Y_data,
                'input_length': input_length,
                'label_length': label_length,
                # 'source_str': source_str
            }
            outputs = {'ctc': np.zeros([self.batch_size])}
            yield (inputs, outputs)


# Smoke test executed at import time: load the validation set and describe
# the first generated batch.
tiger = TextImageGenerator(VALID_SET_PATH, 'val', IMG_W, IMG_H, 8, 4)

for inp, out in tiger.next_batch():
    print('Text generator output (data which will be fed into the neutral network):')
    print('1) the_input (image)')
    if K.image_data_format() == 'channels_first':
        img = inp['the_input'][0, 0, :, :]
    else:
        img = inp['the_input'][0, :, :, 0]

    # plt.imshow(img.T, cmap='gray')
    # plt.show()
    print('2) the_labels (plate number): %s is encoded as %s' %
          (labels_to_text(inp['the_labels'][0]), list(map(int, inp['the_labels'][0]))))
    # Index down to the scalar: formatting a one-element array with %d relies
    # on an implicit array-to-int conversion.
    print('3) input_length (width of image that is fed to the loss function): %d == %d / 4 - 2' %
          (inp['input_length'][0][0], tiger.img_w))
    print('4) label_length (length of plate number): %d' % inp['label_length'][0][0])
    break

# # Loss and train functions, network architecture
def ctc_lambda_func(args):
    """Compute the CTC batch cost inside a Keras ``Lambda`` layer.

    :param args: ``(y_pred, labels, input_length, label_length)`` tensors.
    :return: the result of ``K.ctc_batch_cost``.
    """
    y_pred, labels, input_length, label_length = args
    # the 2 is critical here since the first couple outputs of the RNN
    # tend to be garbage:
    y_pred = y_pred[:, 2:, :]
    return K.ctc_batch_cost(labels, y_pred, input_length, label_length)


downsample_factor = 4


def train(img_w=IMG_W, img_h=IMG_H, dropout=DROPOUT, batch_size=BATCH_SIZE, rnn_size=RNN_SIZE):
    """Build the CNN + bidirectional GRU + CTC network and train it.

    The best checkpoint is written to ``./tmp/train_<MODEL_NAME>`` (and its
    weights are reloaded from there when that file already exists); after
    training the prediction model is saved to ``./model/<MODEL_NAME>``.

    :param img_w: input image width (time dimension of the RNN).
    :param img_h: input image height.
    :param dropout: dropout rate; a falsy value disables the dropout layers.
    :param batch_size: number of samples per training/validation batch.
    :param rnn_size: number of units in every GRU layer.
    """
    # Network parameters
    conv_filters = 16
    kernel_size = (3, 3)
    pool_size = 2
    time_dense_size = 32

    if K.image_data_format() == 'channels_first':
        input_shape = (1, img_w, img_h)
    else:
        input_shape = (img_w, img_h, 1)

    # Two max-pooling layers shrink the width by pool_size ** 2; evaluate()
    # reads this module-level value.
    global downsample_factor
    downsample_factor = pool_size ** 2
    tiger_train = TextImageGenerator(TRAIN_SET_PTAH, 'train', img_w, img_h, batch_size, downsample_factor)
    tiger_val = TextImageGenerator(VALID_SET_PATH, 'val', img_w, img_h, batch_size, downsample_factor)

    act = 'relu'
    input_data = Input(name='the_input', shape=input_shape, dtype='float32')
    inner = Conv2D(conv_filters, kernel_size, padding='same',
                   activation=None, kernel_initializer='he_normal',
                   name='conv1')(input_data)
    inner = BatchNormalization()(inner)  # add BN
    inner = Activation(act)(inner)
    inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max1')(inner)

    inner = Conv2D(conv_filters, kernel_size, padding='same',
                   activation=None, kernel_initializer='he_normal',
                   name='conv2')(inner)
    inner = BatchNormalization()(inner)  # add BN
    inner = Activation(act)(inner)
    inner = MaxPooling2D(pool_size=(pool_size, pool_size), name='max2')(inner)

    # Collapse height and channels into one feature axis; width stays the time axis.
    conv_to_rnn_dims = (img_w // (pool_size ** 2), (img_h // (pool_size ** 2)) * conv_filters)
    inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)

    # cuts down input size going into RNN:
    inner = Dense(time_dense_size, activation=None, name='dense1')(inner)
    inner = BatchNormalization()(inner)  # add BN
    inner = Activation(act)(inner)
    if dropout:
        inner = Dropout(dropout)(inner)  # guard against overfitting

    # Two layers of bidirectional GRUs
    # GRU seems to work as well, if not better than LSTM:
    gru_1 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru1')(inner)
    gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True,
                 kernel_initializer='he_normal', name='gru1_b')(inner)
    gru1_merged = add([gru_1, gru_1b])
    gru_2 = GRU(rnn_size, return_sequences=True, kernel_initializer='he_normal', name='gru2')(gru1_merged)
    gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True,
                 kernel_initializer='he_normal', name='gru2_b')(gru1_merged)

    inner = concatenate([gru_2, gru_2b])
    if dropout:
        inner = Dropout(dropout)(inner)  # guard against overfitting

    # transforms RNN output to character activations:
    inner = Dense(tiger_train.get_output_size(), kernel_initializer='he_normal',
                  name='dense2')(inner)
    y_pred = Activation('softmax', name='softmax')(inner)
    base_model = Model(inputs=input_data, outputs=y_pred)
    base_model.summary()

    labels = Input(name='the_labels', shape=[tiger_train.max_text_len], dtype='float32')
    input_length = Input(name='input_length', shape=[1], dtype='int64')
    label_length = Input(name='label_length', shape=[1], dtype='int64')
    # Keras doesn't currently support loss funcs with extra parameters
    # so CTC loss is implemented in a lambda layer
    loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')(
        [y_pred, labels, input_length, label_length])

    model = Model(inputs=[input_data, labels, input_length, label_length], outputs=loss_out)

    # the loss calc occurs elsewhere, so use a dummy lambda func for the loss
    model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer='adadelta')

    earlystoping = EarlyStopping(monitor=MONITOR, patience=PATIENCE, verbose=1,
                                 mode=MODE, baseline=BASELINE)
    train_model_path = './tmp/train_' + MODEL_NAME
    final_model_path = './model/' + MODEL_NAME
    # Create both output folders up front: otherwise checkpointing fails at
    # the end of the first improving epoch, and the final save fails only
    # after the whole training run has finished.
    for out_path in (train_model_path, final_model_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    checkpointer = ModelCheckpoint(filepath=train_model_path,
                                   verbose=1,
                                   save_best_only=True)

    # Resume from the last checkpoint when one exists.
    if os.path.exists(train_model_path):
        model.load_weights(train_model_path)
        print('load model weights:%s' % train_model_path)

    evaluator = Evaluate(model)
    # NOTE(review): steps_per_epoch / validation_steps are the number of
    # *samples*, so one epoch draws n batches (n * batch_size samples);
    # confirm whether n // batch_size was intended.
    # NOTE(review): initial_epoch=1 makes Keras start counting at epoch 1,
    # i.e. EPOCHS - 1 epochs are run; confirm whether 0 was intended.
    model.fit_generator(generator=tiger_train.next_batch(),
                        steps_per_epoch=tiger_train.n,
                        epochs=EPOCHS,
                        initial_epoch=1,
                        validation_data=tiger_val.next_batch(),
                        validation_steps=tiger_val.n,
                        callbacks=[checkpointer, earlystoping, evaluator])

    base_model.save(final_model_path)
    print('----train end----')


# For a real OCR application, this should be beam search with a dictionary
# and language model.  For this example, best path is sufficient.
def decode_batch(out):
    """Greedy (best path) CTC decoding of a batch of softmax outputs.

    For every sample the first two time steps are dropped (matching the slice
    applied in the CTC loss), the most probable class is taken per step,
    consecutive repeats are merged and indices outside the alphabet (the
    blank class) are discarded.

    :param out: array indexed as ``[sample, time_step, class]``.
    :return: list with one decoded string per sample.
    """
    decoded = []
    for sample in out:
        best_path = np.argmax(sample[2:], 1)
        collapsed = [k for k, _group in itertools.groupby(best_path)]
        decoded.append(''.join(letters[c] for c in collapsed if c < len(letters)))
    return decoded


class Evaluate(Callback):
    """Keras callback that measures accuracy on the test set after each epoch."""

    def __init__(self, model):
        """Remember ``model`` and start with an empty accuracy history."""
        # Initialise the Callback base class before adding our own state.
        super().__init__()
        self.accs = []  # one ignore-case accuracy value per finished epoch
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate the model and append the accuracy to ``self.accs``."""
        acc = evaluate(self.model)
        self.accs.append(acc)


# Test on validation images
def evaluate(model):
    """Measure recognition accuracy of ``model`` on the test set.

    The test images are loaded from ``TEST_SET_PATH``, predicted in batches of
    ``TEST_BATCH_SIZE`` and decoded with ``decode_batch``. Evaluation stops
    after the first batch that brings the number of checked samples to at
    least ``TEST_SET_NUM``. Every prediction that differs from its label even
    when case is ignored is printed.

    :param model: training model containing the 'the_input' and 'softmax' layers.
    :return: accuracy with upper/lower case ignored.
    """
    global downsample_factor
    test_gen = TextImageGenerator(TEST_SET_PATH, 'test', IMG_W, IMG_H, TEST_BATCH_SIZE, downsample_factor)

    # Prediction-only sub-model: image in, per-time-step softmax out.
    predict_model = Model(inputs=model.get_layer(name='the_input').input,
                          outputs=model.get_layer(name='softmax').output)

    exact_hits = 0
    caseless_hits = 0
    seen = 0
    for batch_inputs, _ in test_gen.next_batch():
        images = batch_inputs['the_input']
        predictions = decode_batch(predict_model.predict(images))
        truths = [labels_to_text(encoded) for encoded in batch_inputs['the_labels']]

        for predicted, expected in zip(predictions, truths):
            seen += 1
            if predicted == expected:
                exact_hits += 1
            if predicted.lower() == expected.lower():
                caseless_hits += 1
            else:
                print('Predict: %s ---> Label: %s' % (predicted, expected))

        # The generator never ends, so stop once enough samples were checked.
        if seen >= TEST_SET_NUM:
            break

    print('---Result---')
    print('Test num: %d, accuracy: %.5f, ignoreCase accuracy: %.5f' % (
        seen, exact_hits / seen, caseless_hits / seen))
    return caseless_hits / seen


if __name__ == '__main__':
    train()

转载于:https://www.cnblogs.com/CoolJayson/p/10602040.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/449444.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[云框架]KONG API Gateway v1.5 -框架说明、快速部署、插件开发

前些天发现了一个巨牛的人工智能学习网站，通俗易懂，风趣幽默，忍不住分享一下给大家。点击跳转到教程。 当前版本采用KONGv0.12.3 当我们决定对应用进行微服务改造时，应用客户端如何与微服务交互的问题也随之而来，毕竟…

真格量化-主力跟买策略

#!/usr/bin/env python # coding:utf-8 from PoboAPI import * import datetime import time import numpy as np import pandas as pd #日线级别 #开始时间,用于初始化一些参数 def OnStart(context):print("I\m starting...")#设定一个全局变量品种,本策略交易50E…

顶级投资者的21条箴言(组图)

每天你都会听见五花八门的投资建议，告诉你应该买入还是卖出。如果这让你感到无所适从，不妨静下心来，听听历史上最成功的投资者的建议。 我们搜集了21位顶尖大牛的投资箴言，以飨读者。 1、George Soros：好的投资总是无…

python 游戏 —— 汉诺塔(Hanoita)

一、汉诺塔问题 1. 问题来源 问题源于印度的一个古老传说，大梵天创造世界的时候做了三根金刚石柱子，在一根柱子上从下往上按照大小顺序摞着64片黄金圆盘。大梵天命令婆罗门把圆盘从下面开始按大小顺序重新摆放在另一根柱子上。并且规定，在小圆…

Base62x比Base64的编码速度更快吗?

现在几乎所有企事业单位、政府机构、军工系统等的IT生产系统都会用到Base64编码，从RSA安全密钥到管理信息系统登录入口回跳，目前越来越多的IT系统研发者开始使用 Base62x 替换 Base64. -Base62x 提供了一种无符号输出的Base64的编码方案，在许…

对Docker常用命令的整理

前些天发现了一个巨牛的人工智能学习网站，通俗易懂，风趣幽默，忍不住分享一下给大家。点击跳转到教程。 查看docker版本信息、 #docker version #docker -v #docker info image镜像操作命令 #docker search image_name //检索image #docker p…

再说千遍万遍,都不如这四句话管用,不服不行!

一、健康是最大的利益    人有时候，真不知要谋求什么？往往把最值得维护和珍贵的东西忽视了，却不知拣了芝麻丢了西瓜。   现在好多人都在透支健康，燃烧生命，经常借口工作忙、应酬多，不注意生活方式…

error: failed to push some refs to 'https://gitee.com/xxx/xxx'

一开始以为是本地版本和线上的差异 果断先直接pull 之后 还是不对,哎 不瞎搞了 搜... 获得消息: git pull --rebase origin master 原来如此:是缺失了文件 转载于:https://www.cnblogs.com/G921123/p/10605956.html

真格量化-历史波动率

#!/usr/bin/env python # coding:utf-8 from PoboAPI import * import datetime import time import numpy as np #日线级别 #开始时间,用于初始化一些参数 def OnStart(context):print("I\m starting...")#设定一个全局变量品种,本策略交易50ETF期权g.code = "…

DevOps团队结构类型汇总:总有一款适合你

前言 组织中任何DevOps工作的主要目标都是改进客户和业务的价值交付，而不是降低成本、提升自动化或者通过配置管理驱动一切；这意味着，为了实现有效的Dev和Ops协同，不同的组织可能需要不同的团队结构。 概述 具体哪种DevOps团队结构…

magic

转载于:https://www.cnblogs.com/P201821430028/p/10611080.html

真格量化-bs套利

#!/usr/bin/env python # coding:utf-8 from PoboAPI import * import datetime import time import numpy as np from copy import *#开始时间,用于初始化一些参数 def OnStart(context) :context.myacc = None#登录交易账号if context.accounts["回测期权"].Login…

人生历练必备的十个心态(图)

成功源自心态，如果为自己镶嵌上雄心、信心、决心、爱心、专心、诚心、耐心、恒心、虚心、静心这十颗心，不断打造自己的心态，你就一定会取得人生的成功! 第一个：雄心 你应该让自己试着从人生的地平线上跃起。 第二个…

【docker】常用docker命令,及一些坑

前些天发现了一个巨牛的人工智能学习网站，通俗易懂，风趣幽默，忍不住分享一下给大家。点击跳转到教程。 查看容器的root用户密码 docker logs <容器名orID> 2>&1 | grep ^User: | tail -n1因为docker容器启动时的root用户的密码…

kubernetes系列10—存储卷详解

kubernetes系列10—存储卷详解 1、认识存储卷 1.1 背景 默认情况下容器中的磁盘文件是非持久化的，容器中的磁盘的生命周期是短暂的，这就带来了一系列的问题：第一，当一个容器损坏之后，kubelet 会重启这个容器…

真格量化-隐含波动率计算

#!/usr/bin/env python # coding:utf-8 from PoboAPI import * import datetime import time import numpy as np from copy import *#开始时间,用于初始化一些参数 def OnStart(context) :context.myacc = None#登录交易账号if context.accounts["回测期权"].Login…

Vue 后台管理

这里是结合vue和element快速成型的一个demo 里面展示了基本的后台管理界面的大体结构和element的基本操作 GitHub的地址：https://github.com/wwwming/adminDemo 转载于:https://www.cnblogs.com/wangming1002/p/10613014.html

生活窍门 这样用钱就会富足

当我终于从恶劣处境中解脱之后，我想买栋房子，然而父亲却丝毫不为我感到兴奋。他说：“在尽一项新的支付义务前，你应该多投资。”那个时候，许多人相信自己的房子是一种投资。我的父亲问我：“如果你买了一栋房…

如何在Kubernetes集群动态使用 NAS 持久卷

1. 介绍： 本文介绍的动态生成NAS存储卷的方案：在一个已有文件系统上，自动生成一个目录，这个目录定义为目标存储卷； 镜像地址：registry.cn-hangzhou.aliyuncs.com/acs/alicloud-nas-controller:v1.11.5.4-43…

Linux查看MySQL版本的四种方法

前些天发现了一个巨牛的人工智能学习网站，通俗易懂，风趣幽默，忍不住分享一下给大家。点击跳转到教程。 1 在终端下执行 mysql -V 2 在help中查找 mysql --help |grep Distrib 3 在mysql 里查看 select version() 4 在mysql 里查看 status…