kNN算法python实现和简单数字识别


 
kNN算法
算法优缺点:
优点:精度高、对异常值不敏感、无输入数据假定
缺点:时间复杂度和空间复杂度都很高
适用数据范围:数值型和标称型
算法的思路:
KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。 
 
函数解析:
库函数
 
tile()
如tile(A,n)就是将A重复n次
a = np.array([0, 1, 2])
np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])
np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])
b = np.array([[1, 2], [3, 4]])
np.tile(b, 2)
array([[1, 2, 1, 2],[3, 4, 3, 4]])
np.tile(b, (2, 1))
array([[1, 2],[3, 4],[1, 2],[3, 4]])`
自己实现的函数
 
createDataSet()生成测试数组
kNNclassify(inputX, dataSet, labels, k)分类函数
 
inputX 输入的参数
dataSet 训练集
labels 训练集的标号
k 最近邻的数目
复制代码
 1 #coding=utf-8
 2 from numpy import *
 3 import operator
 4 
 5 def createDataSet():
 6     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
 7     labels = ['A','A','B','B']
 8     return group,labels
 9 #inputX表示输入向量(也就是我们要判断它属于哪一类的)
10 #dataSet表示训练样本
11 #label表示训练样本的标签
12 #k是最近邻的参数,选最近k个
13 def kNNclassify(inputX, dataSet, labels, k):
14     dataSetSize = dataSet.shape[0]#计算有几个训练数据
15     #开始计算欧几里得距离
16     diffMat = tile(inputX, (dataSetSize,1)) - dataSet
17     
18     sqDiffMat = diffMat ** 2
19     sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
20     distances = sqDistances ** 0.5
21     #欧几里得距离计算完毕
22     sortedDistance = distances.argsort()
23     classCount = {}
24     for i in xrange(k):
25         voteLabel = labels[sortedDistance[i]]
26         classCount[voteLabel] = classCount.get(voteLabel,0) + 1
27     res = max(classCount)
28     return res
29 
30 def main():
31     group,labels = createDataSet()
32     t = kNNclassify([0,0],group,labels,3)
33     print t
34     
35 if __name__=='__main__':
36     main()
37             
复制代码
 
 
 
kNN应用实例
手写识别系统的实现
数据集:
两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:
 
 
方法:
kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。
速度:
速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)
 
k=3的时候要32s+
复制代码
 1 #coding=utf-8
 2 from numpy import *
 3 import operator
 4 import os
 5 import time
 6 
 7 def createDataSet():
 8     group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
 9     labels = ['A','A','B','B']
10     return group,labels
11 #inputX表示输入向量(也就是我们要判断它属于哪一类的)
12 #dataSet表示训练样本
13 #label表示训练样本的标签
14 #k是最近邻的参数,选最近k个
15 def kNNclassify(inputX, dataSet, labels, k):
16     dataSetSize = dataSet.shape[0]#计算有几个训练数据
17     #开始计算欧几里得距离
18     diffMat = tile(inputX, (dataSetSize,1)) - dataSet
19     #diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet
20     sqDiffMat = diffMat ** 2
21     sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
22     distances = sqDistances ** 0.5
23     #欧几里得距离计算完毕
24     sortedDistance = distances.argsort()
25     classCount = {}
26     for i in xrange(k):
27         voteLabel = labels[sortedDistance[i]]
28         classCount[voteLabel] = classCount.get(voteLabel,0) + 1
29     res = max(classCount)
30     return res
31 
32 def img2vec(filename):
33     returnVec = zeros((1,1024))
34     fr = open(filename)
35     for i in range(32):
36         lineStr = fr.readline()
37         for j in range(32):
38             returnVec[0,32*i+j] = int(lineStr[j])
39     return returnVec
40     
41 def handwritingClassTest(trainingFloder,testFloder,K):
42     hwLabels = []
43     trainingFileList = os.listdir(trainingFloder)
44     m = len(trainingFileList)
45     trainingMat = zeros((m,1024))
46     for i in range(m):
47         fileName = trainingFileList[i]
48         fileStr = fileName.split('.')[0]
49         classNumStr = int(fileStr.split('_')[0])
50         hwLabels.append(classNumStr)
51         trainingMat[i,:] = img2vec(trainingFloder+'/'+fileName)
52     testFileList = os.listdir(testFloder)
53     errorCount = 0.0
54     mTest = len(testFileList)
55     for i in range(mTest):
56         fileName = testFileList[i]
57         fileStr = fileName.split('.')[0]
58         classNumStr = int(fileStr.split('_')[0])
59         vectorUnderTest = img2vec(testFloder+'/'+fileName)
60         classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
61         #print classifierResult,' ',classNumStr
62         if classifierResult != classNumStr:
63             errorCount +=1
64     print 'tatal error ',errorCount
65     print 'error rate',errorCount/mTest
66         
67 def main():
68     t1 = time.clock()
69     handwritingClassTest('trainingDigits','testDigits',3)
70     t2 = time.clock()
71     print 'execute ',t2-t1
72 if __name__=='__main__':
73     main()
74             
 
 

评论关闭