首页 > 网络 > 云计算 > 正文
python实现knn算法
2017-04-28       个评论      
收藏    我要投稿
python实现knn算法:python如何实现knn算法呢?希望下面的文章对大家有所帮助。
import numpy as np
import operator

def createDataSet():
    group = np.array([[1.0,1.1],[1.0,1.0],[0.0,0.0],[0.0,0.1]])
    labels = ['A','A','B','B']
    return group,labels

#分类算法:inX待分类的点
def classify0(inX,dataSet,labels,k):
    dataSetSize = dataSet.shape[0] #取出行数,为了方便下一步让待分类的点扩充为矩阵
    diffMat = np.tile(inX,(dataSetSize,1)) - dataSet #把点inX复制成dataSetSize行,1列的矩阵
    sqDiffMat = diffMat ** 2
    sqDistances = sqDiffMat.sum(axis = 1)#按行相加(x1-x2)**2 + (y1-y2)**2,因为数据已经是一维的了
    distance = sqDistances ** 0.5
    sortedDistance = distance.argsort()
    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistance[i]]
        classCount[voteLabel] = classCount.get(voteLabel,0) + 1 #默认值为0,取出每个类别的数量
    # 把{"类别":"次数"}变成[('类别','次数')]的格式,然后取次数字段,按降序排列
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0] #取出次数最多的类别 [('类别','次数')]  '类别'

#把文件的数据转换成矩阵的格式
def file2matrix(filename):
    fr = open(filename)
    arrayOfLines = fr.readline()#按行读取
    numOfLines = len(arrayOfLines)#得出总行数
    returnMat = np.zeros((numOfLines,3))#定义一个空的矩阵,numOfLines行,3列
    classLabelVector = []
    index = 0
    for line in arrayOfLines:
        line = line.strip()#去掉换行符/n,空格
        listFromLine = line.split('\t')#每行数据按照\t进行分割
        returnMat[index,:] = listFromLine[0:3]# :代表索引取到末尾,把数据填充到returnMat这个空矩阵
        classLabelVector.append(int(listFromLine[-1]))#因为文本的数据是String类型,所以需要转换
        index += 1 #一条记录加1
    return returnMat, classLabelVector

#对数据进行归一化处理
def autoNorm(dataSet):
    minVals = dataSet.min(0)#表示不同行相比较得出最小,得到的是一行数据
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    m = dataSet.shape[0]
    normDataSet = np.zeros(np.shape(dataSet))#创建一个和dataSet一样的0矩阵
    normDataSet = dataSet - np.tile(minVals,(m,1)) #把最小值的那一行复制成m行,列不变的矩阵,再被dataSet相减
    normDataSet = normDataSet / np.tile(ranges,(m,1))#再除以最大值减去最小值的值
    return normDataSet,ranges,minVals

if __name__ == '__main__':
    group, labels = createDataSet()
    result = classify0([3,0.2],group,labels,3)
    print (result)
点击复制链接 与好友分享!回本站首页
上一篇:scala解析json日志
下一篇:最后一页
相关文章
图文推荐
文章
推荐
点击排行

关于我们 | 联系我们 | 广告服务 | 投资合作 | 版权申明 | 在线帮助 | 网站地图 | 作品发布 | Vip技术培训
版权所有: 红黑联盟--致力于做实用的IT技术学习网站