OpenCV 机器学习库提供了一系列 SVM 函数和类来实现 SVM 模型的训练和预测,方便用户实现自己的 SVM 模型,并应用于分类问题。本文主要介绍使用 openCV 实现手写算式识别的工作原理与实现过程。
目录
1 SVM 模型
1.1 SVM 模型介绍
1.2 SVM 模型原理
2 手写算式识别
2.1 字符识别
2.2 算式识别
1 SVM 模型
1.1 SVM 模型介绍
SVM 是支持向量机(Support Vector Machine)的英文缩写,是统计学习理论中一种重要的分类方法,其早期工作来自前苏联 Vladimir N. Vapnik 和 Alexander Y. Lerner 在1963年发表的研究。
1995年,Corinna Cortes 和 Vapnik 提出了软边距的非线性 SVM 并将其应用于手写字符识别问题,为 SVM 在其他领域的应用提供了参考。
SVM 的优点主要包括:
1)具有较好的可解释性。SVM 的决策函数和支持向量清晰,易于理解。
2)适用性广泛。SVM 能够应用于多种数据类型和领域,如文本分类、图像识别和生物信息学等。
3)鲁棒性强。SVM 对训练数据中的噪声和异常点具有较强的容错能力,能有效处理输入数据中的噪声。
4)适合高维数据。通过核函数,SVM 能够将低维空间的非线性问题映射到高维空间,进行线性划分,从而解决复杂的非线性问题。
5)可控制的过拟合。通过调整正则化参数和松弛变量,SVM 可以控制模型的复杂度,有效避免过拟合问题。
6)避免陷入局部最优解。使用结构风险最小化原则,使得 SVM 能够更好地避免陷入局部最优解,并具有较低的泛化误差。
1.2 SVM 模型原理
在二分类问题中,给定输入数据和学习目标: ,,若存在决策边界(decision boundary)
将样本按类别分开,则称该分类问题是线性可分的(Linear Separable)。
按照统计学习理论,分类器在经过学习新数据时会产生风险,风险的类型分为经验风险和结构风险:
式中 表示分类器,经验风险由损失函数定义,描述了分类器所给出的分类结果的准确程度;结构风险由分类器参数矩阵的范数定义,描述了分类器自身的复杂程度以及稳定程度。
复杂的分类器容易过拟合,因此是不稳定的。通过最小化经验风险和结构风险的线性组合以确定其模型参数:
式中 是正则化参数,当 时,该式被称为 正则化。
对于线性可分问题,SVM 经验风险为 0,SVM 模型简化为最小化结构风险,由于点到超平面的距离反比于 || ω ||,因此模型可解释为最大化样本到超平面的最小距离,
即最优超平面距离给定的每个样本尽可能远。
2 手写算式识别
2.1 字符识别
OpenCV 机器学习库提供了一系列 SVM 函数和类来实现 SVM 模型的训练和预测,可以很方便地实现用户自定义的分类模型。
使用 OpenCV 实现 SVM 模型的基本步骤如下:
(1)创建模型。使用 cv2.ml.SVM_create() 创建 SVM 模型,使用 setKernel() 指定核函数;
(2)初始化模型参数。使用 setC() 和 setGamma() 设置参数的初始值;
(3)模型训练。使用 train() 函数,以及向量化的样本和分类标签,训练模型;
(4)模型评估。使用 predict() 预测新样本,并统计正确率;
(5)模型保存。使用 save() 保存模型,文件格式为 *.dat 。
在手写算式的字符识别中,需要识别数字 0 ~ 9,以及 +,-,×,÷,(,)和 = 共 17 种字符。SVM 模型的输入样本是字符图像向量化的结果,处理步骤包括:
1)图像缩放。将字符图像统一成 28 × 28 大小;
2)颜色反转。使用 cv2.bitwise_not() 函数实现颜色反转,便于后续步骤;
3)去偏斜。使用 cv2.moments() 计算图像的矩,然后使用 cv2.warpAffine() 去偏斜;
4)向量化。将图像按照十字划分成 4 个区域,计算每个区域的方向梯度直方图,拼接成一个向量。
参考链接:OpenCV: OCR of Hand-written Data using SVM
2.2 算式识别
手写算式识别包括 3 个阶段:字符分割、图像预处理和字符识别。字符分割用于提取输入图像中的连续字符,图像预处理用于字符图像的特征化,字符识别用于图像与字符的对应。最后按照顺序拼接识别到的字符,就得到输出表达式。
#-*- Coding: utf-8 -*-import cv2
import numpy as np
import gradio as gr# 加载模型
model = cv2.ml.SVM_load('./svm_data.dat')
chars = '0123456789+-*/()='SZ = 28
bin_n = 16 # Number of binsdef resize(src_img, size):# 获取原图像的宽、高h, w = src_img.shapeif h >= size and w >= size:# 图像缩放dst_img = cv2.resize(src_img, (size, size), interpolation=cv2.INTER_CUBIC)elif h >= size:# 填充左右边缘dst_img = np.zeros(shape=(h, size), dtype=np.uint8)dst_img[:, (size-w)//2:(size-w)//2+w] = src_imgdst_img = cv2.resize(dst_img, (size, size), interpolation=cv2.INTER_CUBIC)elif w >= size:# 填充上下边缘dst_img = np.zeros(shape=(size, w), dtype=np.uint8)dst_img[(size-h)//2:(size-h)//2+h, :] = src_imgdst_img = cv2.resize(dst_img, (size, size), interpolation=cv2.INTER_CUBIC)else:# 填充四周dst_img = np.zeros(shape=(size, size), dtype=np.uint8)dst_img[(size-h)//2:(size-h)//2+h, (size-w)//2:(size-w)//2+w] = src_imgreturn dst_imgdef deskew(src_img):"""Deskew the image using its second order moments"""m = cv2.moments(src_img)if abs(m['mu02']) < 1e-2:return src_img.copy()skew = m['mu11']/m['mu02']M = np.float32([[1, -skew, 0.5*SZ*skew], [0, 1, 0]])dst_img = cv2.warpAffine(src_img, M, (SZ, SZ), cv2.WARP_INVERSE_MAP | cv2.INTER_LINEAR)return dst_imgdef hog(image):gx = cv2.Sobel(image, cv2.CV_32F, 1, 0)gy = cv2.Sobel(image, cv2.CV_32F, 0, 1)mag, ang = cv2.cartToPolar(gx, gy)bins = np.int32(bin_n*ang/(2*np.pi)) # quantizing binvalues in (0, ..., 16)bin_cells = bins[:14,:14], bins[14:,:14], bins[:14,14:], bins[14:,14:]mag_cells = mag[:14,:14], mag[14:,:14], mag[:14,14:], mag[14:,14:]hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]hist = np.hstack(hists) # hist is a 64bit vectorreturn histdef pre_process(src_img):"""图像预处理"""img_resize = resize(src_img, SZ)img_invert = cv2.bitwise_not(img_resize) # 颜色翻转img_deskew = deskew(img_invert)hist = hog(img_deskew)return histdef exprRecognize(src_img, filter_size):"""手写算式识别"""# 灰度图gray = cv2.cvtColor(src_img, cv2.COLOR_BGR2GRAY)# 二值化_, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)binary_inv = cv2.bitwise_not(binary)# 中值滤波filter_size = int(filter_size[0][0]) if filter_size else 3binary_f = cv2.medianBlur(binary_inv, filter_size)# 查找字符区域contours, _ = cv2.findContours(binary_f, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)# 遍历所有区域,寻找最大宽度w_max = 0for cnt in contours:_, _, w, _ = cv2.boundingRect(cnt)if w > w_max:w_max = w# 遍历所有区域,拼接x坐标接近的区域char_dict = {}for cnt in contours:x, y, w, h = cv2.boundingRect(cnt)x_mid = x + w//2 # 计算中点位置if not char_dict.keys() or all(np.abs(z - x_mid) > w_max/1.5 for z in char_dict.keys()):char_dict[x_mid] = cntelse:for z in char_dict.keys():if np.abs(z - x_mid) <= w_max/1.5:char_dict[z] = np.concatenate((char_dict[z], cnt), axis=0) # 拼接两个区域# 按照中点坐标,对字符进行排序char_dict = dict(sorted(char_dict.items(), key=lambda item: item[0]))# 遍历所有区域,提取字符dst_img = []for _, cnt in char_dict.items():x, y, w, h = cv2.boundingRect(cnt)roi = binary[y:y+h, x:x+w]dst_img.append(roi)expr = ''for char in dst_img:hist = pre_process(char)hist = np.array(hist, dtype=np.float32)result = model.predict(hist.reshape(-1, 4*bin_n))[1]expr += chars[int(result[0])]return dst_img, expr, eval(expr.replace('=', ''))if __name__ == "__main__":demo = gr.Interface(fn=exprRecognize,inputs=[gr.Image(label="input image"), gr.Radio(['3x3', '5x5', '7x7'], value='3x3')],outputs=[gr.Gallery(label="charset", columns=[3], object_fit="contain", height="auto"),gr.Text(label="expression"),gr.Text(label="result")],live=True)demo.launch()