1. 首页
  2. Python

统计学习方法3:KNN分类,KD树最近邻搜索Python实现

代码借鉴自此链接

适用问题:多类分类,回归(回归暂时没有使用)

模型特点:不具有显示的学习过程,利用训练数据对特征向量空间进行划分。

基本要素:1.K值的选择、2.距离度量(一般用欧式距离)、3.分类决策规则(如多数表决)

K值     近似误差   估计误差     缺点                  特点 小         小         大      容易发生过拟合         模型复杂 大         大         小      不相似也对预测起作用    模型简单

在应用中,K值一般取一个较小的数值。

通常采用交叉验证法来选取最优的k值

附上完整代码(可运行):

#author:胤 #time:2019/3/15  19:10  '''代码来自:https://blog.csdn.net/tudaodiaozhale/article/details/77327003''' # --*-- coding:utf-8 --*-- import numpy as np   class Node:  # 结点     def __init__(self, data, lchild=None, rchild=None):         self.data = data         self.lchild = lchild         self.rchild = rchild   class KdTree:  # kd树     def __init__(self):         self.kdTree = None      def create(self, dataSet, depth):  # 创建kd树,返回根结点  参数1:数据集   参数2:树的深度(同时决定:排序依据的维度)         if (len(dataSet) > 0):             m, n = np.shape(dataSet)  # 求出样本行,列             midIndex = int(m / 2)  # 中间数的索引位置             axis = depth % n  # 判断以哪个轴划分数据             sortedDataSet = self.sort(dataSet, axis)  # 进行排序             node = Node(sortedDataSet[midIndex])  # 将节点数据域设置为中位数,具体参考下书本             # 将两边的数切割成两个集合,然后递归调用create方法             leftDataSet = sortedDataSet[: midIndex]  # 将中位数的左边创建2改副本             rightDataSet = sortedDataSet[midIndex + 1:]             print(leftDataSet)             print(rightDataSet)             node.lchild = self.create(leftDataSet, depth + 1)  # 将中位数左边样本传入来递归创建树             node.rchild = self.create(rightDataSet, depth + 1)             return node         else:             return None      def sort(self, dataSet, axis):  # 采用冒泡排序,利用aixs作为轴进行划分         sortDataSet = dataSet[:]  # 由于不能破坏原样本,此处建立一个副本         m, n = np.shape(sortDataSet)         for i in range(m):             for j in range(0, m - i - 1):                 if (sortDataSet[j][axis] > sortDataSet[j + 1][axis]):  #把数值大的一个一个向下沉                     temp = sortDataSet[j]                     sortDataSet[j] = sortDataSet[j + 1]                     sortDataSet[j + 1] = temp         print(sortDataSet)         return sortDataSet      def preOrder(self, node):  # 前序遍历         if node != None:             print("tttt->%s" % node.data)             self.preOrder(node.lchild)             self.preOrder(node.rchild)      def search(self, tree, x):  # 搜索         self.nearestPoint = None  # 保存最近的点         self.nearestValue = 0  # 保存最近的值          def travel(node, depth=0):  # 递归搜索             if node != None:  # 递归终止条件                 n = len(x)      # 特征数                 axis = depth % n  # 计算轴                 if x[axis] < node.data[axis]:  # 如果数据小于结点,则往左结点找                     travel(node.lchild, depth + 1)                 else:                     travel(node.rchild, depth + 1)                  # 以下是递归完毕后,往父结点方向回朔,对应算法3.3(3)                 distNodeAndX = self.dist(x, node.data)  # 目标和节点的距离判断                 if (self.nearestPoint == None):  # 确定当前点,更新最近的点和最近的值,对应算法3.3(3)(a)                     self.nearestPoint = node.data                     self.nearestValue = distNodeAndX                 elif (self.nearestValue > distNodeAndX):                     self.nearestPoint = node.data                     self.nearestValue = distNodeAndX                  print(node.data,'t', depth,'t',self.nearestValue,distNodeAndX,'t',node.data[axis], x[axis])                 if (abs(x[axis] - node.data[axis]) <= self.nearestValue):  # 确定是否需要去子节点的区域去找(圆的判断),对应算法3.3(3)(b)                     if x[axis] < node.data[axis]:                         travel(node.rchild, depth + 1)                     else:                         travel(node.lchild, depth + 1)          print('位置t树深度t当前最近距离t当前点与目标点距离t')         travel(tree)        #tree的属性与根结点是一样的,所以可以带入定义方法参数时,里面的node属性         return self.nearestPoint      def dist(self, x1, x2):  # 欧式距离的计算         return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5   if __name__ == '__main__':     dataSet = [[2, 3],                [5, 4],                [9, 6],                [4, 7],                [8, 1],                [7, 2]]     x = [3, 6]     kdtree = KdTree()     tree = kdtree.create(dataSet, 0)   #创建KD树     print('打印先根遍历的输出顺序')     kdtree.preOrder(tree)              #先序遍历测试     print("给定的数据的最近邻",kdtree.search(tree, x))      #在树中搜索给定的 X的最近邻

原文始发于:统计学习方法3:KNN分类,KD树最近邻搜索Python实现

主题测试文章,只做测试使用。发布者:熱鬧獨處,转转请注明出处:http://www.cxybcw.com/11969.html

联系我们

13687733322

在线咨询:点击这里给我发消息

邮件:1877088071@qq.com

工作时间:周一至周五,9:30-18:30,节假日休息

QR code