Implementing the KNN Algorithm in Python and Recognizing Handwritten Digits
Published: 2019-06-08


1. Implementing the KNN algorithm in Python

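Input:
   inX: the vector to classify (1xN), compared against the existing data set
   dataSet: the data set of M known vectors (MxN, one vector per row)
   labels: the data set labels (a 1xM vector, one label per row of dataSet)
   k: the number of nearest neighbors that vote (should be odd, which avoids tied votes when there are two classes)
Output: the most common class label among the k nearest neighbors (a classification problem)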
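To make the inputs and the output concrete, here is a minimal sketch of the same idea, run on the four-point toy data set that `createDataSet` builds in the listing. The function name `knn_classify`, the use of `collections.Counter`, and NumPy broadcasting in place of `tile` are choices made for this illustration, not the original code.

```python
from collections import Counter

import numpy as np


def knn_classify(in_x, data_set, labels, k):
    """Return the majority label among the k training rows nearest to in_x."""
    # Euclidean distance from in_x to every row (broadcasting replaces tile()).
    distances = np.sqrt(((data_set - in_x) ** 2).sum(axis=1))
    # Indices of the k smallest distances.
    nearest = distances.argsort()[:k]
    # Majority vote over the labels of those neighbors.
    votes = Counter(labels[i] for i in nearest)
    return votes.most_common(1)[0][0]


# Same toy data as createDataSet() in the listing.
group = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']

print(knn_classify(np.array([0.0, 0.0]), group, labels, 3))  # B
print(knn_classify(np.array([1.0, 1.2]), group, labels, 3))  # A
```

With k = 3, the point (0, 0) has two 'B' neighbors and one 'A' neighbor, so the vote returns 'B'.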

```python
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 16 23:01:54 2017

@author: SimonsZhao
"""
'''
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: data set of M known vectors (MxN, one vector per row)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label
'''
import operator
from os import listdir

from numpy import array, tile, zeros


def classify0(inX, dataSet, labels, k):
    dataSetSize = dataSet.shape[0]
    # Euclidean distance from inX to every row of dataSet
    diffMat = tile(inX, (dataSetSize, 1)) - dataSet
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances ** 0.5
    sortedDistIndicies = distances.argsort()
    # Count the labels of the k nearest neighbors
    classCount = {}
    for i in range(k):
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    # dict.items() replaces the Python 2-only iteritems()
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


def createDataSet():
    group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


def file2matrix(filename):
    # Read the file once and close it when done
    with open(filename) as fr:
        lines = fr.readlines()
    numberOfLines = len(lines)                  # get the number of lines in the file
    returnMat = zeros((numberOfLines, 3))       # prepare matrix to return
    classLabelVector = []                       # prepare labels to return
    for index, line in enumerate(lines):
        listFromLine = line.strip().split('\t')
        returnMat[index, :] = [float(x) for x in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m, 1))
    normDataSet = normDataSet / tile(ranges, (m, 1))   # element-wise divide
    return normDataSet, ranges, minVals


def datingClassTest():
    hoRatio = 0.50      # hold out 50% of the data as the test set
    datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')   # load data set from file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m * hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0
    print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
    print(errorCount)


def img2vector(filename):
    # Flatten a 32x32 text image of '0'/'1' characters into a 1x1024 vector
    returnVect = zeros((1, 1024))
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect


def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           # load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     # take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        # iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     # take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount / float(mTest)))
```
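The `autoNorm` step matters because Euclidean distance is dominated by whichever feature has the largest numeric range. Rescaling every column to [0, 1] with (value - min) / (max - min) gives each feature equal weight. The sketch below shows that formula on a small made-up array. The name `min_max_normalize` and the sample values are assumptions for illustration, and broadcasting stands in for the `tile` calls in the listing.

```python
import numpy as np


def min_max_normalize(data_set):
    """Rescale every column to [0, 1]: (value - column min) / (column range)."""
    min_vals = data_set.min(axis=0)
    ranges = data_set.max(axis=0) - min_vals
    # Broadcasting applies the per-column min and range to every row.
    return (data_set - min_vals) / ranges, ranges, min_vals


# Made-up sample: three rows, two features on very different scales.
sample = np.array([[10.0, 1000.0],
                   [20.0, 3000.0],
                   [30.0, 5000.0]])
norm, ranges, min_vals = min_max_normalize(sample)
print(norm)
```

After normalization both columns run from 0 to 1, so neither dominates the distance. Note that a column whose values are all equal has a range of zero and would cause a division by zero, both here and in `autoNorm`.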

2. Data set (test set and training set)
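Judging from `handwritingClassTest` and `img2vector` in the listing, the code expects two directories, `trainingDigits` and `testDigits`, whose files each hold 32 lines of 32 '0'/'1' characters, with the digit label encoded in the file name before the underscore. The sketch below illustrates both conventions without touching the file system. The helper names, the file name `7_45.txt`, and the generated image are made up for this example.

```python
import numpy as np


def label_from_filename(file_name):
    """'7_45.txt' -> 7: the digit is the part before the underscore."""
    return int(file_name.split('.')[0].split('_')[0])


def text_to_vector(text, size=32):
    """Flatten size lines of size '0'/'1' characters into a (1, size*size) array."""
    rows = text.splitlines()[:size]
    vect = np.zeros((1, size * size))
    for i, row in enumerate(rows):
        for j in range(size):
            vect[0, size * i + j] = int(row[j])
    return vect


# Made-up image: all zeros except a vertical stroke in column 16.
fake_image = "\n".join("0" * 16 + "1" + "0" * 15 for _ in range(32))
vec = text_to_vector(fake_image)
print(label_from_filename("7_45.txt"), vec.shape, int(vec.sum()))  # 7 (1, 1024) 32
```

Each image becomes one 1024-element row, so the training matrix has one row per file and can be passed directly to `classify0` as `dataSet`.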

 

3. KNN test results

 

Reposted from: https://www.cnblogs.com/jackchen-Net/p/6800275.html
