百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 热门文章 > 正文

opencv手写数字识别:SVM和KNearest

bigegpt 2024-08-19 12:05 7 浏览

先看结果:

svm模型的识别结果:红色为识别错误的

KNearest分类模型的识别结果:红色为错误的


处理流程:

1 数据加载:我们从digits.png里加载一些训练样本。

2 倾斜矫正

3 提取梯度方向直方图hog特征

4 将梯度直方图转换到Hellinger metric

5 使用KNearest分类并测试

6 使用SVM分类并测试


涉及到的知识点有:

1 倾斜矫正:图像距的计算,仿射变换

2 梯度方向直方图是什么?

3 Hellinger距离,hellinger 矩阵

4 KNearest分类模型详细

5 SVM分类模型详细

这节内容非常多,非常扎实,要专注一阵子了。

处理的数据是这样的:



我们先看一下倾斜矫正的内容:

倾斜矫正

倾斜矫正其实最终要求取一个变换矩阵M, 然后用这个矩阵M就可以实现对图像的倾斜矫正。

那么图像这个M如何求取呢?这里就用到了图像的距。

https://www.cnblogs.com/ronny/p/3985810.html

把像素坐标看成是二维的随机变量,那么介于0-1之间图像的灰度值就可以表示一个二维的灰度概率密度函数了。

这样,

而opencv中,

所以,我们就可以用:

63 # 抗扭斜,倾斜纠正
 64 def deskew(img):
 65     m = cv.moments(img)
 66     if abs(m['mu02']) < 1e-2:
 67         return img.copy()
 68     skew = m['mu11']/m['mu02']
 69     M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
 70     img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
 71     return img

mu11/mu02来表示图像的斜切系数,因为图像斜切了,所以原本图像的中心点就移动位置了,所以我们需要将图像的中心点再移动回去,

因此我们就得到了图像的刚体变换矩阵M:

69     M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])

得到这个M,就可以对图片进行倾斜矫正了,如第70行代码,就是对图片用M矩阵进行变换的过程。

关于第二步,我们看看梯度直方图是什么?

梯度直方图全称:图像梯度的直方图。

这里面有两个概念:图像梯度和直方图。

我先看图像的梯度:

把img(i,j)中i,j看称自变量,把i,j对应的img值看成是因变量,

这样图像的梯度是指:[Dimg/Di,Dimg/Dj]

每个像素都可以求到两个方向上的梯度值。

这样我们把这两个方向上的梯度值转化为极坐标中,就表示为[R,theta]

到此,我们求到了图像的梯度。

再看直方图:

对于一幅图像来说,我们先求一下图像的灰度直方图。

灰度直方图:

我们查看图像每个位置上的灰度值,

统计一下灰度值为0的像素个数计为N0,

统计一下灰度值为1的像素个数计为N1,

统计一下灰度值为2的像素个数计为N2,

。。。

统计一下灰度值为254的像素个数计为N254,

统计一下灰度值为255的像素个数计为N255,

这样统计下来,我们一共得到256个值,分别是:N0,N1,N2,N3,…,N255.

这256个值组成一个向量,就是图像的绘制直方图。

为什么叫直方图这么一个名字呢?

因为这256个值通常我们都用直方图来显示。

注意到:N0+N1+N2+…+N255 = N

N 为图片的像素个数。

我们一共统计了256个值,也称为256个bin.

那么我们可以只统计128个bin 么?也是可以的。

N0+N1 构成一个bin

N2+N3构成一个bin

这样,我们就得到了128个bin.

一般情况下, 我们需要对灰度直方图归一化一下,

归一化就是N0/N,N1/N,N2/N,…N256/N,这就是图像归一化直方图。

这里归一化直方图里的每个值N0/N就是为0的像素占比了

这就是直方图:统计个数,或者统计比例。

我们刚才是计算的图像的灰度直方图。

那么现在我们来求图像的梯度直方图。

灰度直方图是统计不同灰度值下的像素个数占比。

梯度直方图就是统计不同梯度值下的像素个数占比。

梯度值我们刚才计算了,极坐标下,我们用半径和角度来表示。

我们会按照角度来区分出18个bin,

然后将落在每个bin内的像素的梯度的半径R值相加。

这样,我们就得到了18个值,组成长度为18的向量。

我们再归一下一下,就得到了梯度的直方图。

为了提取到更加细节的特征,我们将图像分成小区域,在每个小区域上提取归一化的梯度直方图,

再将每个小区域上的梯度直方图拼接起来,形成整个图片的灰度直方图。

然后再看看Hellinger 矩阵是什么?

本例代码中,将得到梯度直方图转化为Hellinger Matrix.

所用代码为:

148         # transform to Hellinger kernel
149         eps = 1e-7
150         hist /= hist.sum() + eps
151         hist = np.sqrt(hist)
152         hist /= norm(hist) + eps

我们看看Hellinger 距离的定义:

参考:https://www.cnblogs.com/wangxiaocvpr/p/5523294.html?ivk_sa=1024320u

这个Hellinger distance也就是海林格距离,是用来评价两个概率分布的相似程度的。越相似,距离就越小。

这里将梯度直方图Hellinger化,相当于求取了和0向量的海林格距离。

对了,这个海林格距离叫巴氏距离,他们是一回事。

参考:https://www.bilibili.com/video/av243968802?ivk_sa=1024320u

然后再看KNearest和SVM

假设待分类的图片特征为hist=[h0,h1,h2,…,h63]

KNearest是一种分类算法

KNearest要对这个特征进行分类

分类的方法:将hist与我们样本中所有hist计算一下欧氏距离,然后看一下最小的K个欧式距离所对应的样本的类别,那个类别对应的样本多,待分类图片就属于哪一类。

KNearest的代码如下:

180     print('training KNearest...')
181     import pdb
182     pdb.set_trace()
183     model = KNearest(k=4)
184     model.train(samples_train, labels_train)
185     vis = evaluate_model(model, digits_test, samples_test, labels_test)
186     cv.imshow('KNearest test', vis)

185行是对KNearest模型准确率的计算。

SVM也是一种分类算法,

SVM要对这个特征进行分类

分类的方法:y = Wxhist + b

这里的W的维度和hist的维度是一样的,都是64维的,是需要事先根据样本计算出来。

bias的维度是一,是需要事先根据样本计算出来。

根据样本计算出来的过程,实际上就是模型的训练过程。

代码如下:

188     print('training SVM...')
189     model = SVM(C=2.67, gamma=5.383)
190     model.train(samples_train, labels_train)
191     vis = evaluate_model(model, digits_test, samples_test, labels_test)
192     cv.imshow('SVM test', vis)
193     print('saving SVM as "digits_svm.dat"...')
194     model.save('digits_svm.dat')

如代码第190行,就是训练svm模型,计算出W和b的过程。

191行是对svm模型准确率的计算。

我们看下全部的代码:

  1 #!/usr/bin/env python
  2 
  3 '''
  4 SVM and KNearest digit recognition.
  5 
  6 Sample loads a dataset of handwritten digits from 'digits.png'.
  7 Then it trains a SVM and KNearest classifiers on it and evaluates
  8 their accuracy.
  9 
 10 Following preprocessing is applied to the dataset:
 11  - Moment-based image deskew (see deskew())
 12  - Digit images are split into 4 10x10 cells and 16-bin
 13    histogram of oriented gradients is computed for each
 14    cell
 15  - Transform histograms to space with Hellinger metric (see [1] (RootSIFT))
 16 
 17 
 18 [1] R. Arandjelovic, A. Zisserman
 19     "Three things everyone should know to improve object retrieval"
 20     http://www.robots.ox.ac.uk/~vgg/publications/2012/Arandjelovic12/arandjelovic12.pdf
 21 
 22 Usage:
 23    digits.py
 24 '''
 25 
 26 
 27 # Python 2/3 compatibility
 28 from __future__ import print_function
 29 
 30 import numpy as np
 31 import cv2 as cv
 32 
 33 # built-in modules
 34 from multiprocessing.pool import ThreadPool
 35 
 36 from numpy.linalg import norm
 37 
 38 # local modules
 39 from common import clock, mosaic
 40 
 41 
 42 
 43 SZ = 20 # size of each digit is SZ x SZ
 44 CLASS_N = 10
 45 DIGITS_FN = 'digits.png'
 46 
 47 def split2d(img, cell_size, flatten=True):
 48     h, w = img.shape[:2]
 49     sx, sy = cell_size
 50     cells = [np.hsplit(row, w//sx) for row in np.vsplit(img, h//sy)]
 51     cells = np.array(cells)
 52     if flatten:
 53         cells = cells.reshape(-1, sy, sx)
 54     return cells
 55 
 56 def load_digits(fn):
 57     fn = cv.samples.findFile(fn)
 58     print('loading "%s" ...' % fn)
 59     digits_img = cv.imread(fn, cv.IMREAD_GRAYSCALE)
 60     digits = split2d(digits_img, (SZ, SZ))
 61     labels = np.repeat(np.arange(CLASS_N), len(digits)/CLASS_N)
 62     return digits, labels
 63 # 抗扭斜,倾斜纠正
 64 def deskew(img):
 65     m = cv.moments(img)
 66     if abs(m['mu02']) < 1e-2:
 67         return img.copy()
 68     skew = m['mu11']/m['mu02']
 69     M = np.float32([[1, skew, -0.5*SZ*skew], [0, 1, 0]])
 70     img = cv.warpAffine(img, M, (SZ, SZ), flags=cv.WARP_INVERSE_MAP | cv.INTER_LINEAR)
 71     return img
 72 
 73 
 74 class KNearest(object):
 75     def __init__(self, k = 3):
 76         self.k = k
 77         self.model = cv.ml.KNearest_create()
 78 
 79     def train(self, samples, responses):
 80         self.model.train(samples, cv.ml.ROW_SAMPLE, responses)
 81 
 82     def predict(self, samples):
 83         _retval, results, _neigh_resp, _dists = self.model.findNearest(samples, self.k)
 84         return results.ravel()
 85 
 86     def load(self, fn):
 87         self.model = cv.ml.KNearest_load(fn)
 88 
 89     def save(self, fn):
 90         self.model.save(fn)
 91 
 92 class SVM(object):
 93     def __init__(self, C = 1, gamma = 0.5):
 94         self.model = cv.ml.SVM_create()
 95         self.model.setGamma(gamma)
 96         self.model.setC(C)
 97         self.model.setKernel(cv.ml.SVM_RBF)
 98         self.model.setType(cv.ml.SVM_C_SVC)
 99 
100     def train(self, samples, responses):
101         self.model.train(samples, cv.ml.ROW_SAMPLE, responses)
102 
103     def predict(self, samples):
104         return self.model.predict(samples)[1].ravel()
105 
106     def load(self, fn):
107         self.model = cv.ml.SVM_load(fn)
108 
109     def save(self, fn):
110         self.model.save(fn)
111 
112 def evaluate_model(model, digits, samples, labels):
113     resp = model.predict(samples)
114     err = (labels != resp).mean()
115     print('error: %.2f %%' % (err*100))
116 
117     confusion = np.zeros((10, 10), np.int32)
118     for i, j in zip(labels, resp):
119         confusion[i, int(j)] += 1
120     print('confusion matrix:')
121     print(confusion)
122     print()
123 
124     vis = []
125     for img, flag in zip(digits, resp == labels):
126         img = cv.cvtColor(img, cv.COLOR_GRAY2BGR)
127         if not flag:
128             img[...,:2] = 0
129         vis.append(img)
130     return mosaic(25, vis)
131 
132 def preprocess_simple(digits):
133     return np.float32(digits).reshape(-1, SZ*SZ) / 255.0
134 
135 def preprocess_hog(digits):
136     samples = []
137     for img in digits:
138         gx = cv.Sobel(img, cv.CV_32F, 1, 0)
139         gy = cv.Sobel(img, cv.CV_32F, 0, 1)
140         mag, ang = cv.cartToPolar(gx, gy)
141         bin_n = 16
142         bin = np.int32(bin_n*ang/(2*np.pi))
143         bin_cells = bin[:10,:10], bin[10:,:10], bin[:10,10:], bin[10:,10:]
144         mag_cells = mag[:10,:10], mag[10:,:10], mag[:10,10:], mag[10:,10:]
145         hists = [np.bincount(b.ravel(), m.ravel(), bin_n) for b, m in zip(bin_cells, mag_cells)]
146         hist = np.hstack(hists)
147 
148         # transform to Hellinger kernel
149         eps = 1e-7
150         hist /= hist.sum() + eps
151         hist = np.sqrt(hist)
152         hist /= norm(hist) + eps
153 
154         samples.append(hist)
155     return np.float32(samples)
156 
157 
158 if __name__ == '__main__':
159     print(__doc__)
160 
161     digits, labels = load_digits(DIGITS_FN)
162 
163     print('preprocessing...')
164     # shuffle digits
165     rand = np.random.RandomState(321)
166     shuffle = rand.permutation(len(digits))
167     digits, labels = digits[shuffle], labels[shuffle]
168 
169     #倾斜矫正?
170     digits2 = list(map(deskew, digits))
171     samples = preprocess_hog(digits2)
172 
173     train_n = int(0.9*len(samples))
174     cv.imshow('test set', mosaic(25, digits[train_n:]))
175     digits_train, digits_test = np.split(digits2, [train_n])
176     samples_train, samples_test = np.split(samples, [train_n])
177     labels_train, labels_test = np.split(labels, [train_n])
178 
179 
180     print('training KNearest...')
181     import pdb
182     pdb.set_trace()
183     model = KNearest(k=4)
184     model.train(samples_train, labels_train)
185     vis = evaluate_model(model, digits_test, samples_test, labels_test)
186     cv.imshow('KNearest test', vis)
187 
188     print('training SVM...')
189     model = SVM(C=2.67, gamma=5.383)
190     model.train(samples_train, labels_train)
191     vis = evaluate_model(model, digits_test, samples_test, labels_test)
192     cv.imshow('SVM test', vis)
193     print('saving SVM as "digits_svm.dat"...')
194     model.save('digits_svm.dat')
195 
196     cv.waitKey(0)

相关推荐

悠悠万事,吃饭为大(悠悠万事吃饭为大,什么意思)

新媒体编辑:杜岷赵蕾初审:程秀娟审核:汤小俊审签:周星...

高铁扒门事件升级版!婚宴上‘冲喜’老人团:我们抢的是社会资源

凌晨两点改方案时,突然收到婚庆团队发来的视频——胶东某酒店宴会厅,三个穿大红棉袄的中年妇女跟敢死队似的往前冲,眼瞅着就要扑到新娘的高额钻石项链上。要不是门口小伙及时阻拦,这婚礼造型团队熬了三个月的方案...

微服务架构实战:商家管理后台与sso设计,SSO客户端设计

SSO客户端设计下面通过模块merchant-security对SSO客户端安全认证部分的实现进行封装,以便各个接入SSO的客户端应用进行引用。安全认证的项目管理配置SSO客户端安全认证的项目管理使...

还在为 Spring Boot 配置类加载机制困惑?一文为你彻底解惑

在当今微服务架构盛行、项目复杂度不断攀升的开发环境下,SpringBoot作为Java后端开发的主流框架,无疑是我们手中的得力武器。然而,当我们在享受其自动配置带来的便捷时,是否曾被配置类加载...

Seata源码—6.Seata AT模式的数据源代理二

大纲1.Seata的Resource资源接口源码2.Seata数据源连接池代理的实现源码3.Client向Server发起注册RM的源码4.Client向Server注册RM时的交互源码5.数据源连接...

30分钟了解K8S(30分钟了解微积分)

微服务演进方向o面向分布式设计(Distribution):容器、微服务、API驱动的开发;o面向配置设计(Configuration):一个镜像,多个环境配置;o面向韧性设计(Resista...

SpringBoot条件化配置(@Conditional)全面解析与实战指南

一、条件化配置基础概念1.1什么是条件化配置条件化配置是Spring框架提供的一种基于特定条件来决定是否注册Bean或加载配置的机制。在SpringBoot中,这一机制通过@Conditional...

一招解决所有依赖冲突(克服依赖)

背景介绍最近遇到了这样一个问题,我们有一个jar包common-tool,作为基础工具包,被各个项目在引用。突然某一天发现日志很多报错。一看是NoSuchMethodError,意思是Dis...

你读过Mybatis的源码?说说它用到了几种设计模式

学习设计模式时,很多人都有类似的困扰——明明概念背得滚瓜烂熟,一到写代码就完全想不起来怎么用。就像学了一堆游泳技巧,却从没下过水实践,很难真正掌握。其实理解一个知识点,就像看立体模型,单角度观察总...

golang对接阿里云私有Bucket上传图片、授权访问图片

1、为什么要设置私有bucket公共读写:互联网上任何用户都可以对该Bucket内的文件进行访问,并且向该Bucket写入数据。这有可能造成您数据的外泄以及费用激增,若被人恶意写入违法信息还可...

spring中的资源的加载(spring加载原理)

最近在网上看到有人问@ContextConfiguration("classpath:/bean.xml")中除了classpath这种还有其他的写法么,看他的意思是想从本地文件...

Android资源使用(android资源文件)

Android资源管理机制在Android的开发中,需要使用到各式各样的资源,这些资源往往是一些静态资源,比如位图,颜色,布局定义,用户界面使用到的字符串,动画等。这些资源统统放在项目的res/独立子...

如何深度理解mybatis?(如何深度理解康乐服务质量管理的5个维度)

深度自定义mybatis回顾mybatis的操作的核心步骤编写核心类SqlSessionFacotryBuild进行解析配置文件深度分析解析SqlSessionFacotryBuild干的核心工作编写...

@Autowired与@Resource原理知识点详解

springIOCAOP的不多做赘述了,说下IOC:SpringIOC解决的是对象管理和对象依赖的问题,IOC容器可以理解为一个对象工厂,我们都把该对象交给工厂,工厂管理这些对象的创建以及依赖关系...

java的redis连接工具篇(java redis client)

在Java里,有不少用于连接Redis的工具,下面为你介绍一些主流的工具及其特点:JedisJedis是Redis官方推荐的Java连接工具,它提供了全面的Redis命令支持,且...