原理:
1.https://www.cnblogs.com/porco/p/4464414.html(里面代码不好修改,不建议直接利用)
2.https://www.cnblogs.com/zfyouxi/p/4795584.html
应用场景:可用于查找距离目标点最近的点,例如:根据经纬度查找最近的地点
java实现:
package main;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;public class KDTreeMain {public static int KDTCount = 0; // 统计在kdt 搜索的时候,计算了和几个点的距离public static void main(String[] args) {/** n >> 2^xn 时, KDTCount才明显 < n* * =========================* n = 40000, xn = 10buld kdt time = 30760.0query kdt time = 1404.0best = 0.3296984744447501KDTCount = 4488query brute time = 3317.0best2 = 0.3296984744447501==========================n = 50000, xn = 10buld kdt time = 49664.0query kdt time = 558.0best = 0.3355435846472523KDTCount = 2056query brute time = 5557.0best2 = 0.3355435846472523==========================n = 50000. xn = 20buld kdt time = 63560.0query kdt time = 15136.0best = 0.8319764077450744KDTCount = 37500query brute time = 5791.0best2 = 0.8319764077450744**/int n = 50000; // 样本点个数int xn = 10; // 样本点维数int deep = 0; // 轴// 随机生成训练样本数据List<Point> pointList = new LinkedList<Point>();for (int i = 0; i < n; i++) {double[] d = new double[xn];for (int j = 0; j < d.length; j++) {d[j] = Math.random();}pointList.add(new Point(d));}// build treeSystem.out.println("beging insert...");double t1 = System.currentTimeMillis();KDTreeMain kdt = new KDTreeMain();Node root = new Node();kdt.insert(root, pointList, deep);double t2 = System.currentTimeMillis();System.out.println("buld kdt time = " + (t2 - t1));// show tree
// char[] path = new char[30];
// int pi = 0;
// showKDTree(root, path, pi);// 目标点double[] f = new double[xn];for (int j = 0; j < f.length; j++) {f[j] = Math.random();}Point p = new Point(f);// KDT搜索double t3 = System.currentTimeMillis();double best = Double.MAX_VALUE;best = query(root, p, best, deep);double t4 = System.currentTimeMillis();System.out.println("\nquery kdt time = " + (t4 - t3));System.out.println("best = " + best);System.out.println("KDTCount = " + KDTCount);// 暴力法double t5 = System.currentTimeMillis();int index = 0;double best2 = Double.MAX_VALUE;for (int i = 0; i < n; i++) {double dist = getDist(p, pointList.get(i));if (dist < best2) {best2 = dist;index = i;}}double t6 = System.currentTimeMillis();System.out.println("\nquery brute time = " + (t6 - t5));System.out.println("best2 = " + best2);// System.out.println("goal point = " + p.x[0] + " , " + p.x[1]);// System.out.println("neast point = " + pointList.get(index).x[0] + " , " + pointList.get(index).x[1]);}// build kdtreeprivate void insert(Node root, List<Point> pointList, int deep) {int mid = pointList.size() / 2;// 排序后拿到中位数Point.deep = deep;Collections.sort(pointList);// 类似快排的方法拿到中位数// getMedian(pointList, 0, pointList.size() - 1, mid, deep);// showList(pointList);// System.out.println("=========================");int pl = mid;int pr = mid;while(pl >= 0 && pointList.get(pl).x[deep] == pointList.get(mid).x[deep]) pl--;while(pr < pointList.size() && pointList.get(pr).x[deep] == pointList.get(mid).x[deep]) pr++;List<Point> pointListLeft = pointList.subList(0, pl + 1);List<Point> pointListMid = pointList.subList(pl + 1, pr);List<Point> pointListRight = pointList.subList(pr, pointList.size());root.pointList = pointListMid;if (pointListLeft.size() > 0) {root.l = new Node();insert(root.l, pointListLeft, (deep + 1) % pointList.get(0).x.length);}if (pointListRight.size() > 0) {root.r = new Node();insert(root.r, pointListRight, (deep + 1) % pointList.get(0).x.length);}}// search the nearest point to p in KDTreeprivate static double query(Node root, 
Point p, double best, int deep) {if (root == null) return Double.MAX_VALUE; double dist; if (root.l == null && root.r == null) { for (int i = 0; i < root.pointList.size(); i++) { KDTCount++; dist = getDist(root.pointList.get(i), p); best = dist < best ? dist : best; } return best; } // left or right if (p.x[deep] <= root.pointList.get(0).x[deep]) { best = query(root.l, p, best, (deep + 1) % p.x.length);} else { best = query(root.r, p, best, (deep + 1) % p.x.length);} // cur for (int i = 0; i < root.pointList.size(); i++) { KDTCount++; dist = getDist(root.pointList.get(i), p); best = dist < best ? dist : best; } // another side if (best >= Math.abs(p.x[deep] - root.pointList.get(0).x[deep])) { double distAnother = Double.MAX_VALUE; if (p.x[deep] <= root.pointList.get(0).x[deep]) { distAnother = query(root.r, p, best, (deep + 1) % p.x.length);} else { distAnother = query(root.l, p, best, (deep + 1) % p.x.length);} if (distAnother < best) { best = distAnother; } } return best; }// print kdtreeprivate static void showKDTree(Node root, char[] path, int pi) {if (root == null) return;System.out.print(pi + "# ");for (int i = 0; i < pi; i++) {System.out.print(path[i] + " ");}// midshowList(root.pointList);// leftpath[pi++] = 'L';showKDTree(root.l, path, pi);pi--;// rightpath[pi++] = 'R';showKDTree(root.r, path, pi);pi--;}// 欧式距离private static double getDist(Point p1, Point p2) {double sum = 0;for (int i = 0; i < p1.x.length; i++) {sum += (p1.x[i] - p2.x[i]) * (p1.x[i] - p2.x[i]);}if (sum == 0) return Double.MAX_VALUE;return Math.sqrt(sum);}// 类似快排的思想拿到中位数,O(n)时间复杂度private void getMedian(List<Point> pointList, int l, int r, int k, int deep) {if (l == r && k == 0) return; int pl = l; int pr = r; double[] tmp = pointList.get(l).x; while (pl < pr) { while (pl < pr && pointList.get(pr).x[deep] > tmp[deep]) pr--; if (pl >= pr) break; pointList.get(pl++).x = pointList.get(pr).x; while (pl < pr && pointList.get(pl).x[deep] < tmp[deep]) pl++; if (pl >= pr) break; 
pointList.get(pr--).x = pointList.get(pl).x;} pointList.get(pl).x = tmp; if(pl - l == k) return; if(pl - l > k) { getMedian(pointList, l, pl - 1, k, deep); } else { getMedian(pointList, pl + 1, r, k - (pl - l + 1), deep); } }// 打印一个点列表private static void showList(List<Point> pointList) {for (int i = 0; i < pointList.size(); i++) {for( int j = 0; j < pointList.get(i).x.length; j++) {System.out.print(pointList.get(i).x[j] + ",");}System.out.print(" / ");}System.out.println();}
}
/** A node of the k-d tree. */
class Node {
    /** Points held by this node; starts out as an empty list. */
    List<Point> pointList = new LinkedList<Point>();

    /** Left child, or {@code null} when there is none. */
    Node l;

    /** Right child, or {@code null} when there is none. */
    Node r;
}
// 数据点
class Point implements Comparable<Point>{public static int deep = 0;double[] x;public Point(double[] d) {x = new double[d.length];for (int i = 0; i < d.length; i++) {x[i] = d[i];}}public int compareTo(Point o) {// return (int)(this.x[deep] == other.x[deep]); 出错,因为x的值在0~1之间,那么int都是0了Point other = (Point)o;if (this.x[deep] == other.x[deep]) return 0;if (this.x[deep] > other.x[deep]) return 1;return -1;}
}