package mlproject.phylo;

import java.util.*;


/**
 * A rooted binary phylogenetic tree parsed from a Newick string.
 *
 * <p>Nodes are numbered in the order they appear in the Newick text, so the root is node 0
 * and every parent has a smaller id than its children. A child id of -1 means "no child";
 * a node with no children is a leaf. Leaves are additionally numbered
 * 0..getNumberOfSpecies()-1 in left-to-right order (the "species index").
 */
public class PhylogeneticTree {
    private ArrayList<PhylogeneticTreeNode> tree; // nodes of the tree, indexed by node id
    private int[] felsensteinOrder, l, r, p;      // evaluation order; left child, right child, parent per node id
    private double[] d;                           // branch length per node id (node default when not given in the input)
    private ArrayList<Integer> leaves;            // node ids of the leaves, left to right
    private HashMap<String, Integer> columnOrder; // species (leaf) name -> species index


    /**
     * Returns the number of leaves (species) in the tree.
     */
    public int getNumberOfSpecies() {
        return leaves.size();
    }

    /**
     * Returns the species index of the leaf with the given name.
     *
     * @param specieName leaf label as written in the Newick string (surrounding whitespace removed)
     * @return position of the leaf in {@link #getLeaves()}, or -1 if no leaf has that name
     */
    public int getSpecieIndex(String specieName) {
        Integer res = columnOrder.get(specieName);
        if (res == null)
            return -1;
        else
            return res;
    }

    /**
     * Parses a Newick string into a list of nodes indexed by node id.
     *
     * <p>Supported syntax:
     * <ul>
     *   <li>nested {@code (x,y)} groups;</li>
     *   <li>optional leaf labels and optional labels on internal nodes (written after {@code ')'});</li>
     *   <li>optional {@code :length} branch lengths.</li>
     * </ul>
     * Whitespace around labels is ignored. Parsing stops at the first {@code ';'}, where the
     * parentheses must be balanced. A string without {@code ';'} is returned as parsed so far,
     * without that final balance check. Quoted labels and {@code [comments]} are not supported.
     *
     * @throws RuntimeException "bad newick input" on structural errors: a node with more than
     *     two children, a ',' or ')' outside any group, unbalanced parentheses at ';', more than
     *     one top-level node, or a misplaced '(' or ':'
     * @throws NumberFormatException if a branch length is not a number
     */
    private static ArrayList<PhylogeneticTreeNode> parseNewickString(String newickString) {
        ArrayList<PhylogeneticTreeNode> result = new ArrayList<PhylogeneticTreeNode>();
        // Nodes still being filled in. After a ',', ')' or ';' only internal nodes remain here;
        // a leaf is pushed at ':' and popped again at the next terminator.
        Deque<PhylogeneticTreeNode> open = new ArrayDeque<PhylogeneticTreeNode>();
        StringBuilder buffer = new StringBuilder();
        boolean lengthPending = false; // a ':' was read: buffer holds the current element's branch length
        boolean subtreeClosed = false; // a ')' was read: the top of 'open' is the element being finished

        for (int ix = 0; ix < newickString.length(); ix++) {
            char c = newickString.charAt(ix);
            switch (c) {
            case '(':
                // A new group may not start after a label/length of another element.
                if (subtreeClosed || lengthPending)
                    throw new RuntimeException("bad newick input");
                open.addLast(newChild(result, open, buffer.toString().trim()));
                buffer.setLength(0);
                break;
            case ':':
                if (lengthPending)
                    throw new RuntimeException("bad newick input");
                String label = buffer.toString().trim();
                if (subtreeClosed) {
                    // Text between ')' and ':' labels the subtree that was just closed.
                    if (label.length() > 0)
                        open.peekLast().n = label;
                } else {
                    // Text before ':' is a leaf name; keep the leaf open until its length is read.
                    open.addLast(newChild(result, open, label));
                }
                lengthPending = true;
                buffer.setLength(0);
                break;
            case ',':
            case ')':
            case ';':
                finishElement(result, open, buffer.toString().trim(), lengthPending, subtreeClosed);
                buffer.setLength(0);
                lengthPending = false;
                subtreeClosed = c == ')';
                if (c == ';') {
                    if (!open.isEmpty())
                        throw new RuntimeException("bad newick input");
                    return result;
                }
                // ',' and ')' are only valid inside a group.
                if (open.isEmpty())
                    throw new RuntimeException("bad newick input");
                break;
            default:
                buffer.append(c);
            }
        }
        return result;
    }

    /**
     * Creates a node named {@code name}, appends it to {@code result} with the next free id and
     * attaches it as the next free child of the innermost open node (if any). Does not push it.
     *
     * @throws RuntimeException "bad newick input" if the parent already has two children or if
     *     a second top-level node would be created
     */
    private static PhylogeneticTreeNode newChild(List<PhylogeneticTreeNode> result,
                                                 Deque<PhylogeneticTreeNode> open, String name) {
        PhylogeneticTreeNode node = new PhylogeneticTreeNode();
        node.i = result.size(); // ids are assigned in creation order
        node.n = name;
        if (open.isEmpty()) {
            if (!result.isEmpty())
                throw new RuntimeException("bad newick input"); // only node 0 may lack a parent
        } else {
            PhylogeneticTreeNode parent = open.peekLast();
            node.p = parent.i;
            if (parent.l == -1)
                parent.l = node.i;
            else if (parent.r == -1)
                parent.r = node.i;
            else
                throw new RuntimeException("bad newick input");
        }
        result.add(node);
        return node;
    }

    /**
     * Completes the element terminated by ',', ')' or ';'.
     *
     * @param token         trimmed text read since the last delimiter
     * @param lengthPending whether {@code token} is a branch length (a ':' was read)
     * @param subtreeClosed whether the element is the group closed by the last ')'
     */
    private static void finishElement(List<PhylogeneticTreeNode> result, Deque<PhylogeneticTreeNode> open,
                                      String token, boolean lengthPending, boolean subtreeClosed) {
        if (!lengthPending && !subtreeClosed) {
            // A leaf written without ":length": the token is its name.
            newChild(result, open, token);
            return;
        }
        PhylogeneticTreeNode node = open.removeLast();
        if (lengthPending) {
            if (token.length() > 0)
                node.d = Double.parseDouble(token);
        } else if (token.length() > 0) {
            node.n = token; // label written directly after ')' with no branch length
        }
    }

    /**
     * Parses a Newick tree and precomputes the index arrays used by the Felsenstein computation.
     *
     * @param newickString Newick text; surrounding whitespace is ignored
     * @throws RuntimeException if the input is malformed or empty, a node has exactly one child,
     *     or two leaves share a name
     * @throws NumberFormatException if a branch length is not a number
     */
    public PhylogeneticTree(String newickString) { //parse tree
        tree = parseNewickString(newickString.trim());
        if (tree.isEmpty())
            throw new RuntimeException("bad newick input");
        leaves = new ArrayList<Integer>();
        columnOrder = new HashMap<String, Integer>();

        for (int i = 0; i < tree.size(); i++) {
            PhylogeneticTreeNode node = tree.get(i);
            boolean noLeft = node.l == -1;
            boolean noRight = node.r == -1;
            // Every node must have either zero or two children.
            if (noLeft != noRight)
                throw new RuntimeException("Non binary node:" + node);
            if (noLeft) {
                // A repeated name would make one of the leaves unreachable via getSpecieIndex.
                if (columnOrder.containsKey(node.n))
                    throw new RuntimeException("Duplicate species name:" + node.n);
                leaves.add(i);
                columnOrder.put(node.n, leaves.size() - 1);
            }
        }
        felsensteinOrder = calculateFelsensteinOrder();

        int size = tree.size();
        l = new int[size];
        r = new int[size];
        p = new int[size];
        d = new double[size];
        for (int i = 0; i < size; i++) {
            PhylogeneticTreeNode node = tree.get(i);
            l[i] = node.l;
            r[i] = node.r;
            p[i] = node.p;
            d[i] = node.d;
        }
    }

    /**
     * Computes an order of all node ids in which every child precedes its parent.
     * This is the reverse of a breadth-first traversal from the root, so the root (node 0)
     * comes last.
     */
    private int[] calculateFelsensteinOrder() {
        int[] res = new int[tree.size()];
        int ix = tree.size() - 1;
        res[ix] = 0; // the root is evaluated last
        Deque<Integer> queue = new ArrayDeque<Integer>();
        queue.addLast(0);
        while (!queue.isEmpty()) {
            PhylogeneticTreeNode node = tree.get(queue.removeFirst());
            if (node.l == -1)
                continue; // leaf
            // Fill from the back so deeper levels end up earlier in the array.
            queue.addLast(node.l);
            res[--ix] = node.l;
            queue.addLast(node.r);
            res[--ix] = node.r;
        }
        return res;
    }

    /**
     * Returns the nodes indexed by node id. This is the internal list, not a copy.
     */
    public ArrayList<PhylogeneticTreeNode> getTree() {
        return tree;
    }

    /**
     * Returns the node ids in evaluation order (children before parents). This is the internal
     * array, not a copy.
     */
    public int[] getFelsensteinOrder() {
        return felsensteinOrder;
    }


    /**
     * Returns a fresh array of leaf node ids; position k holds the leaf with species index k.
     */
    public int[] getLeaves() {
        int[] result = new int[leaves.size()];
        for (int i = 0; i < leaves.size(); i++)
            result[i] = leaves.get(i);
        return result;
    }

    /**
     * Replaces the evaluation order. The array is stored as given and is not validated.
     */
    public void setFelsensteinOrder(int[] felsensteinOrder) {
        this.felsensteinOrder = felsensteinOrder;
    }

    /**
     * Returns the left child id per node id (-1 for leaves). Internal array, not a copy.
     */
    public int[] getL() {
        return l;
    }

    /**
     * Returns the right child id per node id (-1 for leaves). Internal array, not a copy.
     */
    public int[] getR() {
        return r;
    }

    /**
     * Returns the parent id per node id. Internal array, not a copy. The root's entry is
     * whatever PhylogeneticTreeNode uses as its default parent value (presumably -1; verify).
     */
    public int[] getP() {
        return p;
    }

    /**
     * Returns the branch length per node id, as read from the Newick string. Internal array,
     * not a copy.
     */
    public double[] getD() {
        return d;
    }
}
