package org.deft.repository.ast;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

import org.deft.repository.ast.transform.ITreeNodeFilter;

/**
 * The TreeWalker allows to iterate over a tree of TreeNodes. It can be used
 * like a normal iterator using method nextNode(). Additionally it is possible
 * to use more specific navigation commands such as get the first child or get
 * the parent. While this can also be done directly on the tree, the
 * TreeWalker has one big advantage: it can filter nodes. If a filter is applied
 * all filtered nodes are ignored during traversal.
 * 
 * @author abartho
 */
public class TreeWalker implements Iterator<TreeNode> {
	
	private TreeNode root;
	private ITreeNodeFilter filter;
	private TreeNode current;
	private boolean rootIsLeaf = false;
	
	private Map<TreeNode, List<TreeNode>> childrenCache 
			= new HashMap<TreeNode, List<TreeNode>>();
	
	private class AcceptAllFilter implements ITreeNodeFilter {
		public boolean accept(TreeNode node) {
			return true;
		}
	}
	
	public TreeWalker(TreeNode root) {
		this(root, null);
	}
	
	public TreeWalker(TreeNode root, ITreeNodeFilter filter) {
		if (root == null) {
			throw new IllegalArgumentException("root must not be null");
		} 
		if (root instanceof TreeNodeRoot) {
			this.root = root.getFirstChild(); 
		} else {
			this.root = root;
		}
		this.current = root;
		if (filter == null) {
			this.filter = new AcceptAllFilter();
		} else {
			this.filter = filter;
		}
		rootIsLeaf = (children(root).isEmpty());
	}
	



	public TreeNode getCurrentNode() {
		return current;
	}
	
	public void setCurrentNode(TreeNode currentNode) {
		this.current = currentNode;
	}

	public ITreeNodeFilter getFilter() {
		return filter;
	}

	public TreeNode getRoot() {
		return root;
	}

	public TreeNode firstChild() {
		TreeNode tn = findFirstChild();
		if (tn != null) {
			current = tn;
		}
		return tn;
	}
	
	private TreeNode findFirstChild() {
		List<TreeNode> children = children();
		if (children.isEmpty()) {
			return null;
		} else {
			return children.get(0);
		}

	}
	
	public TreeNode lastChild() {
		TreeNode tn = findLastChild();
		if (tn != null) {
			current = tn;
		}
		return tn;
	}
	
	private TreeNode findLastChild() {
		List<TreeNode> children = filterList(current.getChildren());
		if (children.isEmpty()) {
			return null;
		} else {
			int posInList = posInList(children);
			return children.get(posInList);
		}
	}
	
	public TreeNode nextNode() {
		TreeNode tn = findNextNode(current);
		if (tn != null) {
			current = tn;
		}
		return tn;
	}

	private TreeNode findNextNode(TreeNode start) {
		//if the tree consists only of a root node
		//(i.e. root has no children), then don't try further
		//and return null
		if (rootIsLeaf) {
			return null;
		}
		
		//depth first search, so first look for a child
		TreeNode firstChild = findFirstChild();	
		if (firstChild != null) { 	//if there was a first child, return it
			return firstChild;
		} else {					//if there was no first child
			//try the next sibling
			TreeNode nextSibling = findNextSibling(start);
			if (nextSibling != null) {	//if there was a next sibling, return it
				return nextSibling;
			} else {					//if there was no next sibling
				//we have completed a subtree, go up until there is a next sibling
				//or we reach the root
				TreeNode parent = start;
				TreeNode tn = null;
				boolean isLastInList = false;
				boolean reachedRoot = false;
				do {
					tn = parent;
					parent = findParentNode(tn);
					List<TreeNode> list = filterList(parent.getChildren());
					isLastInList = isLastInList(tn, list);
					reachedRoot = parent == root;
					
				} while (isLastInList && !reachedRoot);

				if (!isLastInList) {	//if we found an unvisited sibling, return it
					TreeNode ret = findNextSibling(tn);
					return ret;
				} else {			//if we reached the root, everything has been found
					return null;
				}
			}
		}
	}

	
	public TreeNode nextSibling() {
		TreeNode tn = findNextSibling(current);
		if (tn != null) {
			current = tn;
		}
		return tn;
	}
	
	private TreeNode findNextSibling(TreeNode start) {
		TreeNode parent = findParentNode(start);
		List<TreeNode> children = children(parent);
		if (children.isEmpty() || isLastInList(start, children)) {
			return null;
		} else {
			int posInList = posInList(start, children);
			return children.get(posInList + 1); 
		}
	}

	public TreeNode parentNode() {
		TreeNode tn = findParentNode(current);
		if (tn != null) {
			current = tn;
		}
		return tn;
	}
	
	private TreeNode findParentNode(TreeNode start) {
//		if (start instanceof TokenNode && ((TokenNode)start).getToken().getText().equals("}")) {
//			int x = 1;
//		}
		if (start == root) {
			return null;
		}

		TreeNode parent = start.getParent();
		//check all ancestors until we find one that the filter accepts
		//or we reach the root
		while (!filter.accept(parent) && parent != root) {
			parent = parent.getParent();
		}
		return parent;
	}
	
	public List<TreeNode> children() {
		return children(current);
	}
	
	private List<TreeNode> children(TreeNode node) {
		if (node == null) {
			return new LinkedList<TreeNode>();
		}
		//this method is called very often with the same node
		//if there are nodes to be filtered the necessary
		//computation time sums up. Therefore cache
		//the children for each node
		if (childrenCache.containsKey(node)) {
			return childrenCache.get(node);
		}
		
		List<TreeNode> unfilteredChildren = node.getChildren();
		List<TreeNode> filteredChildren = new LinkedList<TreeNode>(unfilteredChildren);
		List<Integer> filteredPositions = new LinkedList<Integer>();
		int pos = 0;
		//get the indexes of all filtered nodes 
		for (TreeNode tn : unfilteredChildren) {
			if (!filter.accept(tn)) {
				filteredPositions.add(pos);
			}
			pos++;
		}
		//replace all filtered nodes by their descendant nodes (recursively)
		for (int i = filteredPositions.size() - 1; i >= 0; i--) {
			//reusing pos variable
			pos = filteredPositions.get(i);
			List<TreeNode> c = children(unfilteredChildren.get(pos));
			filteredChildren.remove(pos);
			filteredChildren.addAll(pos, c);
		}
		childrenCache.put(node, filteredChildren);
		return filteredChildren;
	}


//	public TreeNode previousNode() {
//
//	}

//	public TreeNode previousSibling() {
//
//	}


	
	private List<TreeNode> filterList(List<TreeNode> origList) {
		List<TreeNode> filteredList = new LinkedList<TreeNode>();
		for (TreeNode tn : origList) {
			if (filter.accept(tn)) {
				filteredList.add(tn);
			}
		}
		return filteredList;
	}
	
	
	private int posInList(List<TreeNode> list) {
		return posInList(current, list);
	}
	
	private int posInList(TreeNode node, List<TreeNode> list) {
		int pos = -1;
		for (TreeNode n : list) {
			pos++;
			if (n == node) {
				return pos;
			}
		}
		return -1;
	}

	
	private boolean isLastInList(TreeNode node, List<TreeNode> list) {
		return list.size() > 0 && list.get(list.size() - 1) == node;
	}

	//stores the latest TreeNode that the iterator has returned
	private TreeNode iteratorOutput;
	
	
	public boolean hasNext() {
		TreeNode next = findNextNode(current);
		return next != null;
		
		
		//usually current is one TreeNode ahead iteratorOutput
		//but if all nodes have been processed, iteratorOutput
		//takes the old value of current and current stays the same
		//return (current != iteratorOutput);
	}

	public TreeNode next() {
		if (!hasNext()) {
			throw new NoSuchElementException();
		}

		nextNode();
		return current;
	}

	public void remove() {
		throw new UnsupportedOperationException();
	}

}
