代码借鉴自此链接
适用问题:多类分类,回归(回归暂时没有使用)
模型特点:不具有显示的学习过程,利用训练数据对特征向量空间进行划分。
基本要素:1.K值的选择、2.距离度量(一般用欧式距离)、3.分类决策规则(如多数表决)
K值 近似误差 估计误差 缺点 特点 小 小 大 容易发生过拟合 模型复杂 大 大 小 不相似也对预测起作用 模型简单
在应用中,K值一般取一个较小的数值。
通常采用交叉验证法来选取最优的k值
附上完整代码(可运行):
#author:胤 #time:2019/3/15 19:10 '''代码来自:https://blog.csdn.net/tudaodiaozhale/article/details/77327003''' # --*-- coding:utf-8 --*-- import numpy as np class Node: # 结点 def __init__(self, data, lchild=None, rchild=None): self.data = data self.lchild = lchild self.rchild = rchild class KdTree: # kd树 def __init__(self): self.kdTree = None def create(self, dataSet, depth): # 创建kd树,返回根结点 参数1:数据集 参数2:树的深度(同时决定:排序依据的维度) if (len(dataSet) > 0): m, n = np.shape(dataSet) # 求出样本行,列 midIndex = int(m / 2) # 中间数的索引位置 axis = depth % n # 判断以哪个轴划分数据 sortedDataSet = self.sort(dataSet, axis) # 进行排序 node = Node(sortedDataSet[midIndex]) # 将节点数据域设置为中位数,具体参考下书本 # 将两边的数切割成两个集合,然后递归调用create方法 leftDataSet = sortedDataSet[: midIndex] # 将中位数的左边创建2改副本 rightDataSet = sortedDataSet[midIndex + 1:] print(leftDataSet) print(rightDataSet) node.lchild = self.create(leftDataSet, depth + 1) # 将中位数左边样本传入来递归创建树 node.rchild = self.create(rightDataSet, depth + 1) return node else: return None def sort(self, dataSet, axis): # 采用冒泡排序,利用aixs作为轴进行划分 sortDataSet = dataSet[:] # 由于不能破坏原样本,此处建立一个副本 m, n = np.shape(sortDataSet) for i in range(m): for j in range(0, m - i - 1): if (sortDataSet[j][axis] > sortDataSet[j + 1][axis]): #把数值大的一个一个向下沉 temp = sortDataSet[j] sortDataSet[j] = sortDataSet[j + 1] sortDataSet[j + 1] = temp print(sortDataSet) return sortDataSet def preOrder(self, node): # 前序遍历 if node != None: print("tttt->%s" % node.data) self.preOrder(node.lchild) self.preOrder(node.rchild) def search(self, tree, x): # 搜索 self.nearestPoint = None # 保存最近的点 self.nearestValue = 0 # 保存最近的值 def travel(node, depth=0): # 递归搜索 if node != None: # 递归终止条件 n = len(x) # 特征数 axis = depth % n # 计算轴 if x[axis] < node.data[axis]: # 如果数据小于结点,则往左结点找 travel(node.lchild, depth + 1) else: travel(node.rchild, depth + 1) # 以下是递归完毕后,往父结点方向回朔,对应算法3.3(3) distNodeAndX = self.dist(x, node.data) # 目标和节点的距离判断 if (self.nearestPoint == None): # 确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a) self.nearestPoint = node.data self.nearestValue = distNodeAndX elif (self.nearestValue > distNodeAndX): self.nearestPoint = node.data self.nearestValue = distNodeAndX print(node.data,'t', depth,'t',self.nearestValue,distNodeAndX,'t',node.data[axis], x[axis]) if (abs(x[axis] - node.data[axis]) <= self.nearestValue): # 确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b) if x[axis] < node.data[axis]: travel(node.rchild, depth + 1) else: travel(node.lchild, depth + 1) print('位置t树深度t当前最近距离t当前点与目标点距离t') travel(tree) #tree的属性与根结点是一样的,所以可以带入定义方法参数时,里面的node属性 return self.nearestPoint def dist(self, x1, x2): # 欧式距离的计算 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5 if __name__ == '__main__': dataSet = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]] x = [3, 6] kdtree = KdTree() tree = kdtree.create(dataSet, 0) #创建KD树 print('打印先根遍历的输出顺序') kdtree.preOrder(tree) #先序遍历测试 print("给定的数据的最近邻",kdtree.search(tree, x)) #在树中搜索给定的 X的最近邻
原文始发于:统计学习方法3:KNN分类,KD树最近邻搜索Python实现
主题测试文章,只做测试使用。发布者:熱鬧獨處,转转请注明出处:http://www.cxybcw.com/11969.html