kNN识别手写图像
示例 :使用k-近邻算法的手写识别系统 (1) 收集数据:提供文本文件。 (2) 准备数据:编写函数classify0(), 将图像格式转换为分类器使用的list格式。 (3) 分析数据:检查数据,确保它符合要求。 (4) 训练算法:此步驟不适用于k-近邻算法。 (5) 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。 (6) 使用算法:本例没有完成此步骤,若你感兴趣可以构建完整的应用程序,从图像中提取数字,并完成数字识别,美国的邮件分拣系统就是一个实际运行的类似系统。
提示
注:由于原本数据集已经在0和1之间,所以不需要转化数字特征值。数据集 (opens new window)
# 代码
from numpy import *
from os import listdir
import operator
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
# 距离计算
'''
tile(A,rep)
功能:重复A的各个维度
参数类型:
A: Array类的都可以
rep:A沿着各个维度重复的次数
'''
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
# numpy中的 axis=0表示列,向下,axis=1表示行,向右
# 在平时使用的sun默认的是axis=0就是普通的相加,当加入axis=1以后就是将一个矩阵的每一行向量相加
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
# argsort函数返回的是数组值从小到大的索引值
sortedDistIndicies = distances.argsort()
classCount = {}
# 选择距离最小的k个点
for i in range(k):
votellabel = labels[sortedDistIndicies[i]]
classCount[votellabel] = classCount.get(votellabel, 0) + 1
# 排序
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def img2vector(filename):
# 将图像矩阵转化为1x1024的向量
returnVect = zeros((1, 1024))
fr = open(filename)
# 循环读出文件的前32行
for i in range(32):
lineStr = fr.readline()
# 将每行的头32个字符值存储在Numpy数组中
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
# 返回数组
return returnVect
def handwritingClassTest():
hwLabels = []
# 获取训练数据集下目录的所有文件名的列表
trainingFileList = listdir('trainingDigits')
# 得到文件数量
m = len(trainingFileList)
# 创建m行1024列的训练矩阵
trainingMat = zeros((m, 1024))
for i in range(m):
# 从文件名解析分类数字
# 解析出0_10.txt
fileNameStr = trainingFileList[i]
# 获得0_10
fileStr = fileNameStr.split('.')[0]
# 获得0
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
# 获得测试数据集下目录的所有文件名的列表
testFileList = listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with:%d,the real answer is:%d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr):
errorCount += 1.0
print("\nthe total number of errors is:%d" % errorCount)
print("\nthe total error rate is :%f" % (errorCount / float(mTest)))
testVector = img2vector('testDigits/0_13.txt')
# X[:, m:n],即取二维数组中的第m到n-1列的所有数据
print("测试输出:\n", testVector[0, 0:31])
handwritingClassTest()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# 运行结果
the classifier came back with:1,the real answer is:1
the classifier came back with:1,the real answer is:1
.....
the classifier came back with:5,the real answer is:5
the classifier came back with:5,the real answer is:5
the classifier came back with:6,the real answer is:6
the classifier came back with:6,the real answer is:6
....
the classifier came back with:9,the real answer is:9
the classifier came back with:9,the real answer is:9
the classifier came back with:9,the real answer is:9
the total number of errors is:11
# 错误率为1.2%
the total error rate is :0.011628
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
编辑 (opens new window)
上次更新: 2022/12/31, 16:52:27
- 01
- SpringCache基本配置类05-16
- 03
- Rpamis-security-原理解析12-13