本文主要是个人学习记录，具体原理可能讲得不够详细，代码可供参考。
原理参考《数据结构与算法分析：Java语言描述》第356页，以及自顶向下伸展树（top-down splay tree）的原理。
package chapter12; import chapter04.MyCustomException; public class SplayTree<T extends Comparable<? super T>> { private Node<T> root; // 根节点 private final Node<T> nullNode; // 用于表示空节点,不用不直接用null是因为这样在进行展开splay时更加简单 private final Node<T> header = new Node<>(null); // 在splay展开时使用 private static class Node<T> { T element; Node<T> left; Node<T> right; Node(T element) { this.element = element; this.left = null; this.right = null; } Node(T element, Node<T> left, Node<T> right) { this.element = element; this.left = left; this.right = right; } } // ************************************************************************************************** public SplayTree() { nullNode = new Node<>(null); nullNode.left = nullNode.right = nullNode; root = nullNode; } public void makeEmpty() { root = nullNode; } public boolean isEmpty() { return root == nullNode; } public void insert(T element) { if (root == nullNode) { root = new Node<>(element, nullNode, nullNode); } else { root = splay(element, root); if (element.compareTo(root.element) < 0) { Node<T> tmp = new Node<>(element, root.left, root); root.left = nullNode; root = tmp; } else if (element.compareTo(root.element) > 0) { Node<T> tmp = new Node<>(element, root, root.right); root.right = nullNode; root = tmp; } } } public void remove(T element) { if (!contains(element)) { // contains()中调用了splay()函数,会将对应的节点转移到root节点下 return; } Node<T> newRoot; if (root.left == nullNode) { newRoot = root.right; } else { newRoot = root.left; newRoot = splay(element, newRoot); newRoot.right = root.right; } root = newRoot; } public T findMin() { if (isEmpty()) { throw new MyCustomException(); } Node<T> tmp = root; while (tmp.left != nullNode) { tmp = tmp.left; } root = splay(tmp.element, root); return root.element; } public T findMax() { if (isEmpty()) { throw new MyCustomException(); } Node<T> tmp = root; while (tmp.right != nullNode) { tmp = tmp.right; } root = splay(tmp.element, root); return root.element; } public boolean contains(T element) { 
if (isEmpty()) { return false; } root = splay(element, root); return element.compareTo(root.element) == 0; } public void printTree() { if (isEmpty()) System.out.println("Empty tree"); else printTree(root); } // *************************************************************************************************** // 本程序优化了自定向下展开的过程,使之更加简洁。 private Node<T> splay(T element, Node<T> t) { Node<T> leftTreeMax, rightTreeMin; header.left = header.right = nullNode; leftTreeMax = rightTreeMin = header; nullNode.element = element; while (true) { if (element.compareTo(t.element) < 0) { if (element.compareTo(t.left.element) < 0) { // 左一字型旋转 t = rotateWithLeftChild(t); } if (t.left == nullNode) { break; } // 单旋转和之字形旋转都可以只用以下三步解决 rightTreeMin.left = t; rightTreeMin = t; t = t.left; } else if (element.compareTo(t.element) > 0) { if (element.compareTo(t.right.element) > 0) { // 右一字型旋转 t = rotateWithRightChild(t); } if (t.right == nullNode) { break; } // 单旋转和之字形旋转都可以只用以下三步解决 leftTreeMax.right = t; leftTreeMax = t; t = t.right; } else { break; } } // 自顶向下展开的最后整理 leftTreeMax.right = t.left; rightTreeMin.left = t.right; t.left = header.right; t.right = header.left; return t; } private Node<T> rotateWithLeftChild(Node<T> t) { Node<T> tmp = t.left; t.left = tmp.right; tmp.right = t; return tmp; } private Node<T> rotateWithRightChild(Node<T> t) { Node<T> tmp = t.right; t.right = tmp.left; tmp.left = t; return tmp; } private void printTree(Node<T> node) { if (node.left != nullNode) { printTree(node.left); } System.out.println(node.element); if (node.right != nullNode) { printTree(node.right); } } // ***************************************************************************************************** public static void main(String[] args) { SplayTree<String> splayTree = new SplayTree<>(); splayTree.insert("GG"); splayTree.insert("TT"); splayTree.insert("OO"); splayTree.insert("BB"); splayTree.insert("NN"); splayTree.insert("LL"); splayTree.insert("GG"); splayTree.insert("UU"); 
splayTree.printTree(); System.out.println(); splayTree.remove("GG"); splayTree.remove("TT"); splayTree.remove("OO"); splayTree.remove("BB"); splayTree.remove("NN"); splayTree.remove("LL"); splayTree.printTree(); System.out.println(); splayTree.remove("UU"); splayTree.insert("AA"); System.out.println(splayTree.root.element); } }