// KdTree.java (7.38 KB)
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;

import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.Queue;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

/**
 * A set of 2D points stored in a 2d-tree (kd-tree with k = 2).
 *
 * <p>The root splits on x (a vertical line), its children on y (a horizontal line), and so on,
 * alternating by level. Every node remembers the axis-aligned rectangle of the region it
 * subdivides; the root's rectangle is the unit square, so points are expected to lie in
 * [0, 1] x [0, 1]. NOTE(review): inserting points outside the unit square yields inconsistent
 * child rectangles — confirm callers respect this.
 */
public class KdTree {

	private static final boolean VERTICAL = true;
	private static final boolean HORIZONTAL = false;

	private Node root;
	private int size;

	private static class Node {

		private final Point2D p; // the point stored at this node
		private final RectHV rect; // the axis-aligned region this node subdivides
		private Node lb; // the left/bottom subtree
		private Node rt; // the right/top subtree

		Node(Point2D p, RectHV rect) {
			this.p = p;
			this.rect = rect;
		}
	}

	/** Constructs an empty set of points. */
	public KdTree() {
		root = null;
	}

	/** Returns {@code true} if the set contains no points. */
	public boolean isEmpty() {
		return size() == 0;
	}

	/** Returns the number of distinct points in the set. */
	public int size() {
		return size;
	}

	/**
	 * Adds the point to the set if it is not already present.
	 *
	 * @param p the point to add
	 * @throws NullPointerException if {@code p} is null
	 */
	public void insert(Point2D p) {
		if (p == null) {
			throw new NullPointerException();
		}
		root = insert(root, p, new RectHV(0, 0, 1, 1), VERTICAL);
	}

	// Recursively inserts p below x; rect is the region a new node created here would own.
	private Node insert(Node x, Point2D p, RectHV rect, boolean orientation) {
		if (x == null) {
			size++;
			return new Node(p, rect);
		}

		// duplicates are ignored so the set semantics hold
		if (x.p.equals(p)) {
			return x;
		}

		if (isLeftOrBelow(x, p, orientation)) {
			x.lb = insert(x.lb, p, leftBottomRect(x, orientation), !orientation);
		} else {
			x.rt = insert(x.rt, p, rightTopRect(x, orientation), !orientation);
		}
		return x;
	}

	/**
	 * Returns whether the set contains point {@code p}.
	 *
	 * @throws NullPointerException if {@code p} is null
	 */
	public boolean contains(Point2D p) {
		if (p == null) {
			throw new NullPointerException();
		}

		Node x = root;
		boolean orientation = VERTICAL;
		while (x != null) {
			if (x.p.equals(p)) {
				return true;
			}
			x = isLeftOrBelow(x, p, orientation) ? x.lb : x.rt;
			orientation = !orientation;
		}
		return false;
	}

	/**
	 * Draws all points in black, vertical splitting segments in red and horizontal
	 * splitting segments in blue, to standard draw. Does nothing for an empty set.
	 */
	public void draw() {
		draw(root, VERTICAL);
	}

	// Draws the subtree rooted at x; a null subtree draws nothing.
	private void draw(Node x, boolean orientation) {
		if (x == null) {
			return;
		}

		if (orientation == VERTICAL) {
			StdDraw.setPenColor(StdDraw.RED);
			StdDraw.line(x.p.x(), x.rect.ymin(), x.p.x(), x.rect.ymax());
		} else {
			StdDraw.setPenColor(StdDraw.BLUE);
			StdDraw.line(x.rect.xmin(), x.p.y(), x.rect.xmax(), x.p.y());
		}

		draw(x.lb, !orientation);
		draw(x.rt, !orientation);

		// draw the point last so it sits on top of the lines
		StdDraw.setPenColor(StdDraw.BLACK);
		x.p.draw();
	}

	/**
	 * Returns all points of the set that lie inside (or on the boundary of) {@code rect}.
	 *
	 * @throws NullPointerException if {@code rect} is null
	 */
	public Iterable<Point2D> range(RectHV rect) {
		if (rect == null) {
			throw new NullPointerException();
		}
		Queue<Point2D> queue = new Queue<>();
		range(root, rect, queue);
		return queue;
	}

	// Collects points under x that fall in rect, skipping subtrees whose region misses rect.
	private void range(Node x, RectHV rect, Queue<Point2D> queue) {
		if (x == null || !x.rect.intersects(rect)) {
			return;
		}
		if (rect.contains(x.p)) {
			queue.enqueue(x.p);
		}
		range(x.lb, rect, queue);
		range(x.rt, rect, queue);
	}

	/**
	 * Returns a point of the set nearest to {@code p} (ties broken arbitrarily),
	 * or {@code null} if the set is empty.
	 *
	 * @throws NullPointerException if {@code p} is null
	 */
	public Point2D nearest(Point2D p) {
		if (p == null) {
			throw new NullPointerException();
		}
		if (root == null) {
			return null;
		}
		return nearest(root, p, root.p, VERTICAL);
	}

	// Returns the closest point to p among champion and all points under x.
	private Point2D nearest(Node x, Point2D p, Point2D champion, boolean orientation) {
		// Prune: every point under x lies in x.rect, so none can beat the champion
		// if the whole region is already at least as far away.
		if (x == null || x.rect.distanceSquaredTo(p) >= champion.distanceSquaredTo(p)) {
			return champion;
		}

		if (x.p.distanceSquaredTo(p) < champion.distanceSquaredTo(p)) {
			champion = x.p;
		}

		// Search the side of the split that contains p first; it is the most likely to
		// tighten the champion and let the other side be pruned.
		boolean pGoesLeft = isLeftOrBelow(x, p, orientation);
		Node near = pGoesLeft ? x.lb : x.rt;
		Node far = pGoesLeft ? x.rt : x.lb;

		champion = nearest(near, p, champion, !orientation);
		champion = nearest(far, p, champion, !orientation);
		return champion;
	}

	// True when p belongs in x's left/bottom subtree: x-coordinate comparison at vertical
	// nodes, y-coordinate comparison at horizontal nodes; ties go right/top.
	private static boolean isLeftOrBelow(Node x, Point2D p, boolean orientation) {
		return orientation == VERTICAL ? p.x() < x.p.x() : p.y() < x.p.y();
	}

	// Region owned by x's left/bottom child: x.rect clipped at x's splitting line.
	private static RectHV leftBottomRect(Node x, boolean orientation) {
		if (orientation == VERTICAL) {
			return new RectHV(x.rect.xmin(), x.rect.ymin(), x.p.x(), x.rect.ymax());
		}
		return new RectHV(x.rect.xmin(), x.rect.ymin(), x.rect.xmax(), x.p.y());
	}

	// Region owned by x's right/top child: x.rect clipped at x's splitting line.
	private static RectHV rightTopRect(Node x, boolean orientation) {
		if (orientation == VERTICAL) {
			return new RectHV(x.p.x(), x.rect.ymin(), x.rect.xmax(), x.rect.ymax());
		}
		return new RectHV(x.rect.xmin(), x.p.y(), x.rect.xmax(), x.rect.ymax());
	}

	/**
	 * Unit-testing entry point: reads whitespace-separated "x y" coordinate pairs, one per
	 * line, from the file named by {@code args[0]} and inserts each one as a point.
	 * Blank lines are skipped.
	 */
	public static void main(String[] args) throws Exception {
		if (args.length == 0) {
			System.out.println("Usage: java KdTree <input-file>");
			return;
		}

		KdTree kdtree = new KdTree();

		// try-with-resources closes the file even if a line fails to parse
		try (BufferedReader reader = new BufferedReader(new FileReader(args[0]))) {
			String line;
			while ((line = reader.readLine()) != null) {
				String trimmed = line.trim();
				if (trimmed.isEmpty()) {
					continue; // e.g. a trailing newline at end of file
				}
				String[] fields = trimmed.split("\\s+");
				double x = Double.parseDouble(fields[0]);
				double y = Double.parseDouble(fields[1]);
				kdtree.insert(new Point2D(x, y));
			}
		} catch (FileNotFoundException e) {
			System.out.println("File not found");
		}
	}
}