1. 用 Python 实现 KNN 算法
输入:inX:待分类的输入向量(1xN),将与已有数据集中的每个样本逐一比较
# -*- coding: utf-8 -*-
"""
kNN: k Nearest Neighbors

classify0 inputs:
    inX:     vector to classify, shape (N,) or (1, N)
    dataSet: m known sample vectors, one per row, shape (m, N)
    labels:  class label of each row of dataSet (length m)
    k:       number of neighbours that vote (an odd number avoids two-way ties)

classify0 output:
    the majority class label among the k nearest neighbours

The code uses print() with a single argument and avoids dict-ordering
assumptions, so it runs under both Python 2.7 and Python 3.

Created on Sun Apr 16 23:01:54 2017

@author: SimonsZhao
"""
from numpy import *
import operator
import os
from collections import Counter
from os import listdir


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote among its k nearest neighbours.

    Distances are Euclidean.

    :param inX: vector to classify, shape (N,) or (1, N).
    :param dataSet: ndarray of known samples, shape (m, N).
    :param labels: sequence of m labels, one per row of ``dataSet``.
    :param k: number of voting neighbours; values above m are clamped to m.
    :returns: the label with the most votes.  On a tie, the winner is the
        label whose first vote came from the neighbour nearest to ``inX``.
    :raises ValueError: if ``k`` < 1 or ``dataSet`` has no rows.
    """
    dataSetSize = dataSet.shape[0]
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    if dataSetSize == 0:
        raise ValueError("dataSet is empty")
    if k > dataSetSize:
        # Previously an IndexError; with fewer than k samples, all of them vote.
        k = dataSetSize

    # Broadcasting subtracts inX from every row (same as tiling it m times).
    diffMat = asarray(inX) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5

    nearestLabels = [labels[i] for i in distances.argsort()[:k]]
    votes = Counter(nearestLabels)
    # Walk the labels nearest-first and switch only on a strictly larger vote
    # count, so ties resolve deterministically toward the closer neighbour
    # (sorting dict items would depend on dict order under Python 2).
    best = nearestLabels[0]
    for label in nearestLabels:
        if votes[label] > votes[best]:
            best = label
    return best


def createDataSet():
    """Return a tiny 2-D toy data set: four points labelled 'A' or 'B'."""
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def file2matrix(filename):
    """Parse a tab-separated text file into a feature matrix and label list.

    On each non-blank line, the first three fields are numeric features and
    the last field is an integer class label.  Blank lines (e.g. a trailing
    newline) are skipped.

    :param filename: path of the file to read.
    :returns: ``(returnMat, classLabelVector)`` where ``returnMat`` is a float
        ndarray of shape (n, 3) and ``classLabelVector`` a list of n ints.
    """
    rows = []
    # Read the file once and close it deterministically; the previous
    # version opened it twice and leaked both handles.
    with open(filename) as fr:
        for line in fr:
            line = line.strip()
            if line:
                rows.append(line.split('\t'))

    returnMat = zeros((len(rows), 3))
    classLabelVector = []
    for index, fields in enumerate(rows):
        returnMat[index, :] = [float(value) for value in fields[0:3]]
        classLabelVector.append(int(fields[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max scale every column of ``dataSet`` into the range [0, 1].

    :param dataSet: 2-D ndarray, one sample per row.
    :returns: ``(normDataSet, ranges, minVals)``, where ``ranges`` and
        ``minVals`` are the per-column (max - min) and min used for scaling.
        Constant columns (range 0) are mapped to 0 instead of NaN; the
        returned ``ranges`` still reports 0 for them.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # Divide constant columns by 1 instead of 0, and force true (float)
    # division so integer input is not floor-divided under Python 2.
    safeRanges = where(ranges == 0, 1, ranges).astype(float)
    normDataSet = (dataSet - minVals) / safeRanges  # element-wise, broadcast
    return normDataSet, ranges, minVals


def datingClassTest(filename='datingTestSet2.txt', hoRatio=0.50, k=3):
    """Hold-out evaluation of classify0 on the dating data set.

    The first ``hoRatio`` fraction of rows (50% by default) is the test set;
    the remaining rows are the training set.  Each prediction and the overall
    error rate are printed.

    :param filename: tab-separated data file readable by file2matrix.
    :param hoRatio: fraction of rows held out for testing.
    :param k: number of neighbours passed to classify0.
    :returns: the error rate as a float.
    :raises ValueError: if the hold-out set would contain no rows.
    """
    datingDataMat, datingLabels = file2matrix(filename)  # load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        raise ValueError("hold-out set is empty (m=%d, hoRatio=%r)" % (m, hoRatio))

    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0

    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is: %f" % errorRate)
    print(errorCount)
    return errorRate


def img2vector(filename):
    """Flatten a 32x32 text image of digit characters into a (1, 1024) array.

    Only the first 32 characters of each of the first 32 lines are read;
    each must be a digit.
    """
    returnVect = zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def _labelFromFilename(fileNameStr):
    """Return the class encoded before the first '_' of a name like '7_23.txt'."""
    fileStr = fileNameStr.split('.')[0]  # take off .txt
    return int(fileStr.split('_')[0])


def _loadDigitDir(dirName):
    """Load every image file in ``dirName``; return (ndarray (m, 1024), labels)."""
    fileList = listdir(dirName)
    digitMat = zeros((len(fileList), 1024))
    digitLabels = []
    for i, fileNameStr in enumerate(fileList):
        digitLabels.append(_labelFromFilename(fileNameStr))
        digitMat[i, :] = img2vector(os.path.join(dirName, fileNameStr))
    return digitMat, digitLabels


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Classify every test digit image against the training images.

    Each prediction, the total number of errors and the error rate are printed.

    :param trainingDir: directory of training images named '<digit>_<n>.txt'.
    :param testDir: directory of test images with the same naming scheme.
    :param k: number of neighbours passed to classify0.
    :returns: the error rate as a float.
    :raises ValueError: if ``testDir`` contains no files.
    """
    trainingMat, hwLabels = _loadDigitDir(trainingDir)  # load the training set
    testMat, testLabels = _loadDigitDir(testDir)
    mTest = len(testLabels)
    if mTest == 0:
        raise ValueError("no test images found in %r" % (testDir,))

    errorCount = 0.0
    for i in range(mTest):  # iterate through the test set
        classNumStr = testLabels[i]
        classifierResult = classify0(testMat[i, :], trainingMat, hwLabels, k)
        print("the classifier came back with: %d, the real answer is: %d"
              % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0

    errorRate = errorCount / float(mTest)
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % errorRate)
    return errorRate
2. 数据集(测试集和训练集)
3. KNN 测试结果