package mlproject.io;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;

import java.io.IOException;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;

import mlproject.hmm.EvidenceStateModel;
import mlproject.hmm.StateModel;

import mlproject.util.Util;

/**
 * Produces {@link TrainingSequence}s by pairing each sequence read from an {@link InputReader}
 * with a per-position state labelling derived from the CDS records of a GTF file.
 * <p>
 * Component 0 of each input sequence is treated as the nucleotide sequence (encoded
 * A=0, C=1, G=2, T=3). Components 2 and 3 are evidence tracks (called "exonerate" and
 * "weasel" here) whose agreement with the annotation is tallied and reported by
 * {@link #summarize()}. State numbers are indices into the transition matrix of the
 * {@link StateModel} given to the constructor; their meaning is defined by that model.
 */
public class GTFStateInputReaderEvidence137 {

    /**
     * One CDS record of the GTF file. Records are ordered by start coordinate so the
     * coding segments of a sequence can be walked from left to right.
     */
    private static class GTFElement implements Comparable<GTFElement> {
        /** 1-based, inclusive coordinates from GTF columns 4 and 5. */
        public int start,end;
        /** Strand from GTF column 7; {@code gene} holds the sequence name from column 1. */
        public String strand,gene;

        public int compareTo(GTFElement other) {
            // Must return 0 for equal starts: a comparator that never returns 0 violates the
            // Comparable contract and can make Collections.sort throw.
            if(start < other.start) return -1;
            if(start > other.start) return 1;
            return 0;
        }
    }
    
    /** CDS records keyed by sequence name, each list sorted by start coordinate. */
    private final HashMap<String,ArrayList<GTFElement>> cdsMap = new HashMap<String,ArrayList<GTFElement>>();
    private final InputReader inputReader;
    private final StateModel stateModel;
    
    // Confusion counts indexed by TP=0, FP=1, TN=2, FN=3 (see incrementStats)
    private final int[] exonerateCodingStats = new int[4];
    private final int[] exonerateIntronStats = new int[4];
    private final int[] weaselStats = new int[4];
    
    private int plusCount;
    private int minusCount;
    
    // Per-instance like the other statistics, so separate readers do not pool their totals
    private int totalIntronLength;
    private int intronCount;
    private int totalExonLength;
    private int exonCount;
    
    /**
     * Creates a reader and eagerly parses the CDS records of {@code gtfFile}.
     *
     * @param inputReader source of the input sequences
     * @param gtfFile path of the GTF annotation file
     * @param stateModel model whose transition matrix is used to validate labellings
     * @throws FileNotFoundException if {@code gtfFile} cannot be opened
     * @throws IOException if reading fails or a line is malformed
     */
    public GTFStateInputReaderEvidence137(InputReader inputReader, String gtfFile, StateModel stateModel) throws FileNotFoundException, IOException {
        this.inputReader = inputReader;
        this.stateModel = stateModel;
        parseGTF(gtfFile);
    }
    
    /**
     * Reads the next sequence, updates the evidence statistics, and labels it with states
     * derived from the GTF annotation.
     *
     * @return the labelled sequence, marked invalid if no labelling could be built or if the
     *         labelling contains a transition the state model forbids; {@code null} when the
     *         input reader returns {@code null}
     * @throws IOException propagated from reading the input
     */
    public TrainingSequence readNextTrainingSequence() throws IOException {
        InputSequence in = inputReader.readNextSequence();
        if(in == null) return null;
        incrementStats(in);
        int[] states = readStates(in.getComponent(0));
        TrainingSequence result = new TrainingSequence(in,states);
        // One invalid transition makes the whole sequence invalid; later valid transitions
        // must not reset the flag. Every offending transition is still reported.
        boolean valid = states != null;
        if(states != null) {
            for(int i=1;i<states.length;i++){
                if(!stateModel.getTransitionMatrix()[states[i-1]][states[i]]){
                    valid = false;
                    System.out.println("Invalid sequence:"+in.getName()+" transition:"+states[i-1]+"-"+states[i]);
                }
            }
        }
        result.setValid(valid);
        return result;
    }
    
    /**
     * Loads every CDS record of a tab-separated GTF file into {@link #cdsMap}, grouped by
     * sequence name (column 1) and sorted by start coordinate.
     * Blank lines and lines starting with '#' are skipped.
     *
     * @throws IOException if the file cannot be read or a line has too few columns
     */
    private void parseGTF(String gtfFile) throws FileNotFoundException, IOException {
        BufferedReader br = new BufferedReader(new FileReader(gtfFile));
        try {
            String line;
            while((line=br.readLine())!=null){
                if(line.trim().length() == 0 || line.startsWith("#")) continue;
                String[] fields = line.split("\t");
                if(fields.length < 3)
                    throw new IOException("Malformed GTF line in "+gtfFile+": "+line);
                if(!fields[2].equals("CDS")) continue;
                if(fields.length < 7)
                    throw new IOException("Malformed GTF CDS line in "+gtfFile+": "+line);
                ArrayList<GTFElement> list = cdsMap.get(fields[0]);
                if(list == null) {
                    list = new ArrayList<GTFElement>();
                    cdsMap.put(fields[0], list);
                }
                GTFElement el = new GTFElement();
                el.gene = fields[0];
                el.start = Integer.parseInt(fields[3]);
                el.end = Integer.parseInt(fields[4]);
                el.strand = fields[6];
                list.add(el);
            }
        } finally {
            // Close even when parsing fails part-way
            br.close();
        }
        
        for(ArrayList<GTFElement> list:cdsMap.values()){
            Collections.sort(list);
        }
    }

    /**
     * Builds the state labelling of one nucleotide sequence from its GTF CDS records.
     * <p>
     * GTF coordinates are 1-based and inclusive, while {@code states} is 0-based. A CDS
     * spanning [start, end] therefore covers indices start-1 .. end-1, and the intron that
     * follows it begins at index {@code end}. The strand of the first CDS record decides
     * which intron and codon state numbers are used.
     *
     * @param fasta nucleotide sequence encoded as A=0, C=1, G=2, T=3
     * @return one state per position; or {@code null}, after printing the reason, when the
     *         annotation cannot be labelled (intron shorter than 15 bases, in-frame stop
     *         codon, or an unexpected start/stop codon)
     * @throws RuntimeException if the GTF file has no CDS records for this sequence
     */
    private int[] readStates(InputSequence<Integer> fasta) throws IOException {
        final int A=0;
        final int C=1;
        final int G=2;
        final int T=3;
        
        ArrayList<GTFElement> list = cdsMap.get(fasta.getName());
        if(list == null)
            throw new RuntimeException("Missing gtf:"+fasta.getName());
        int[] states = new int[fasta.length()];
        
        // Number of coding bases before the current CDS; mod 3 it gives the phase of the intron
        int codingLength = 0;
        String strand = list.get(0).strand;
        for(int i=0;i<list.size();i++){
            /* First mark exonic positions with 1 (excluding start and stop codons) and assign
               state numbers to intronic positions. */
            if(i>0){ // there is an intron between CDS i-1 and CDS i
                int prevEnd = list.get(i-1).end;   // 0-based index of the first intron base
                int nextStart = list.get(i).start; // nextStart-2 is the 0-based last intron base
                int length = nextStart - prevEnd - 1;
                // Both strands write 6 + 9 = 15 fixed intron positions and fill the rest in
                // between. A shorter intron would give Arrays.fill a negative range, which throws.
                if(length < 15){
                    System.out.println("Invalid sequence:"+fasta.getName()+" intron length:"+length);
                    return null;
                }
                if(strand.equals("+")){
                    int phase = codingLength % 3;
                    int baseState = -1;
                    // The intron state block also depends on the last coding base(s) before it
                    if(phase == 0){
                        baseState = 12;
                    } else if(phase==1){
                        if(fasta.getX(prevEnd-1) == T){
                            baseState = 28; 
                        } else {
                            baseState = 44;
                        }
                    } else {
                        if(fasta.getX(prevEnd-2) == T){
                            if(fasta.getX(prevEnd-1) == A){
                                baseState = 60;
                            } else {
                                baseState = 76;
                            }
                        } else {
                            baseState = 92;
                        }
                    }
                    // 6 fixed states at the intron start, 9 at its end, one repeated state between
                    for(int pos=0;pos<6;pos++){
                        states[prevEnd+pos] = baseState+pos;
                    }
                    for(int pos=0;pos<9;pos++){
                        states[nextStart-10+pos] = baseState+7+pos;
                    }
                    Arrays.fill(states,prevEnd+6,nextStart-10,baseState+6);
                    
                } else {
                    int phase = (3 - codingLength % 3) % 3;
                    int baseState = -1;
                    if(phase == 0){
                        baseState = 122;
                    } else if(phase==1){
                        if(fasta.getX(prevEnd-2) == C || fasta.getX(prevEnd-2) == T){
                            if(fasta.getX(prevEnd-1) == T){
                                baseState = 138;
                            } else {
                                baseState = 154;
                            }
                        } else {
                            baseState = 170;
                        }
                    } else {
                        if(fasta.getX(prevEnd-2) == C || fasta.getX(prevEnd-2) == T){
                            baseState = 186;
                        } else {
                            baseState = 202;
                        }
                    }
                    // Minus strand mirrors the layout: 9 fixed states first, 6 at the end
                    for(int pos=0;pos<9;pos++){
                        states[prevEnd+pos] = baseState+pos;
                    }
                    for(int pos=0;pos<6;pos++){
                        states[nextStart-7+pos] = baseState+10+pos;
                    }
                    Arrays.fill(states,prevEnd+9,nextStart-7,baseState+9);
                }
            }
            // Temporarily mark coding bases with 1. The plus-strand start codon (first three
            // bases of the first CDS) and the minus-strand start codon (last three bases of
            // the last CDS) are left out here and labelled below.
            if(i==0 && strand.equals("+")){
                Arrays.fill(states,list.get(i).start+2,list.get(i).end,1);
            } else if(i == list.size()-1 && strand.equals("-")){
                Arrays.fill(states,list.get(i).start-1,list.get(i).end-3,1);
            } else {
                Arrays.fill(states,list.get(i).start-1,list.get(i).end,1);
            }
            codingLength += list.get(i).end - list.get(i).start + 1;
        }
        /* Replace the temporary 1-marks with codon states, grouping marked bases into triplets
           from left to right (across introns). */
        int[] codon = new int[3];
        int[] positions = new int[3];
        int codonPosition = 0;
        for(int i=0;i<states.length;i++){
            if(states[i] == 1){
                codon[codonPosition] = fasta.getX(i);
                positions[codonPosition] = i;
                codonPosition = (codonPosition + 1) % 3;
                if(codonPosition == 0){
                    if(strand.equals("+")){
                        if(codon[0] == T){
                            if(codon[1]==A){
                                if(codon[2]==G || codon[2]==A){
                                    System.out.println("Invalid sequence:"+fasta.getName()+" codon TAA/TAG in CDS pos:" + i);
                                    return null;
                                }
                                states[positions[0]] = 4;
                                states[positions[1]] = 6;
                                states[positions[2]] = 9;
                            } else {
                                states[positions[0]] = 4;
                                states[positions[1]] = 7;
                                states[positions[2]] = 10;
                            }
                        } else {
                            states[positions[0]] = 5;
                            states[positions[1]] = 8;
                            states[positions[2]] = 11;
                        }
                    } else {
                        // Minus strand: TTA/CTA on the forward strand are reverse-complemented TAA/TAG
                        if(codon[0] == C || codon[0] == T){
                            if(codon[1]==T){
                                if(codon[2]==A){
                                    System.out.println("Invalid sequence:"+fasta.getName()+" codon TAA/TAG in CDS pos:"+i);
                                    return null;
                                }
                                states[positions[0]] = 114;
                                states[positions[1]] = 116;
                                states[positions[2]] = 119;
                            } else {
                                states[positions[0]] = 114;
                                states[positions[1]] = 117;
                                states[positions[2]] = 120;
                            }
                        } else {
                            states[positions[0]] = 115;
                            states[positions[1]] = 118;
                            states[positions[2]] = 121;
                        }
                    }
                }
            }
        }
        
        /* Assign states to start and stop codons. The first record is handled before the
           last one; for a single-record sequence they are the same element. */
        GTFElement first = list.get(0);
        if(first.strand.equals("+")){
            // Start codon ATG/ATA occupies the first three bases of the first CDS
            if(fasta.getX(first.start-1) != A || fasta.getX(first.start) != T || (fasta.getX(first.start+1) != G && fasta.getX(first.start+1) != A)){
                System.out.println("Invalid sequence:"+fasta.getName()+" invalid plus start codon ATG/A" );
                return null;
            }
            states[first.start-1] = 1;
            states[first.start] = 2;
            states[first.start+1] = 3;
        } else {
            // Stop codon lies in the three bases just left of the first CDS
            if((fasta.getX(first.start-4) != T && fasta.getX(first.start-4) != C)|| fasta.getX(first.start-3) != T || fasta.getX(first.start-2) != A){
                System.out.println("Invalid sequence:"+fasta.getName()+" invalid minus stop TAA/TAG" );
                return null;
            }
            states[first.start-4] = 111;
            states[first.start-3] = 112;
            states[first.start-2] = 113;
        }
        
        GTFElement last = list.get(list.size()-1);
        if(last.strand.equals("+")){
            // Stop codon lies in the three bases just right of the last CDS
            if(fasta.getX(last.end) != T || fasta.getX(last.end+1) != A || (fasta.getX(last.end+2) != G && fasta.getX(last.end+2) != A)){
                System.out.println("Invalid sequence:"+fasta.getName()+" invalid plus stop codon TAA/TAG" );
                return null;
            }
            states[last.end] = 108;
            states[last.end+1] = 109;
            states[last.end+2] = 110;
        } else {
            // Start codon occupies the last three bases of the last CDS
            if((fasta.getX(last.end-3) != C && fasta.getX(last.end-3) != T)|| fasta.getX(last.end-2) != A || fasta.getX(last.end-1) != T){
                System.out.println("Invalid sequence:"+fasta.getName()+" invalid minus start codon ATG/A" );
                return null;
            }
            states[last.end-3] = 218;
            states[last.end-2] = 219;
            states[last.end-1] = 220;
        }
        
        return states;
    }
    
    /** Returns true if 1-based {@code pos} lies strictly between two consecutive CDS records. */
    private boolean isInIntronPosition(int pos, ArrayList<GTFElement> list){
        for(int i=1;i<list.size();i++){
            if(list.get(i-1).end < pos && pos < list.get(i).start) return true;
        }
        return false;
    }
    
    /** Returns true if 1-based {@code pos} lies inside any CDS record (inclusive bounds). */
    private boolean isInExonPosition(int pos, ArrayList<GTFElement> list){
        for(GTFElement el:list){
            if(el.start <= pos && pos <= el.end) return true;
        }
        return false;
    }
    
    /**
     * Accumulates exon/intron length totals and strand counts, and tallies TP/FP/TN/FN of
     * the evidence tracks (components 2 and 3) against the annotation.
     * <p>
     * Track codes as interpreted below: exonerate 1/2 = coding on plus/minus,
     * 3/4 = intron on plus/minus; weasel 1/2 = intron on plus/minus; 0 = no call.
     *
     * @throws RuntimeException if the GTF file has no CDS records for this sequence
     */
    private void incrementStats(InputSequence<Integer> input){
        InputSequence<Integer> exonerate = input.getComponent(2);
        InputSequence<Integer> weasel = input.getComponent(3);
        
        final int TP = 0;
        final int FP = 1;
        final int TN = 2;
        final int FN = 3;
        
        ArrayList<GTFElement> list = cdsMap.get(input.getName());
        if(list == null)
            throw new RuntimeException("Missing gtf:"+input.getName());
        
        for(GTFElement el:list){
            totalExonLength += el.end - el.start + 1;
            exonCount++;
        }
        
        for(int i=1;i<list.size();i++){
            // Intron = gap between consecutive CDS records (same formula as readStates)
            totalIntronLength += list.get(i).start - list.get(i-1).end - 1;
            intronCount++;
        }
        
        int strand = list.get(0).strand.equals("+")?0:1;
        if(strand == 0) plusCount++;
        else minusCount++;
        for(int pos=0;pos<input.length();pos++){
            int exonerateState = exonerate.getX(pos);
            int weaselState = weasel.getX(pos);
            // Coding evidence; positions are converted to 1-based for the GTF lookups
            if(exonerateState == 0 || exonerateState == 3 || exonerateState == 4){
                if(isInExonPosition(pos+1,list)) {
                    exonerateCodingStats[FN]++;
                } else exonerateCodingStats[TN]++;
                
            } else if(exonerateState == 1 || exonerateState == 2){
                if(exonerateState != strand+1){ //different strand
                    exonerateCodingStats[FP]++;
                } else {
                    if(isInExonPosition(pos+1,list)){
                        exonerateCodingStats[TP]++;
                    } else {
                        exonerateCodingStats[FP]++;
                    }
                }
            }
            
            // Intron evidence from the same track
            if(exonerateState == 0 || exonerateState == 1 || exonerateState == 2){
                if(isInIntronPosition(pos+1,list)) {
                    exonerateIntronStats[FN]++;
                } else exonerateIntronStats[TN]++;
                
            } else if(exonerateState == 3 || exonerateState == 4){
                if(exonerateState != strand+3){ //different strand
                    exonerateIntronStats[FP]++;
                } else {
                    if(isInIntronPosition(pos+1,list)){
                        exonerateIntronStats[TP]++;
                    } else {
                        exonerateIntronStats[FP]++;
                    }
                }
            }
            
            // Intron evidence from the weasel track
            if(weaselState == 0){
                if(isInIntronPosition(pos+1,list)) {
                    weaselStats[FN]++;
                } else weaselStats[TN]++;
                
            } else if(weaselState == 1 || weaselState == 2){
                if(weaselState != strand+1){ //different strand
                    weaselStats[FP]++;
                } else {
                    if(isInIntronPosition(pos+1,list)){
                        weaselStats[TP]++;
                    } else {
                        weaselStats[FP]++;
                    }
                }
            }
        }
    }
    
    /**
     * Returns a human-readable report of the statistics gathered so far: average exon and
     * intron lengths, strand counts, and the evidence confusion counts formatted by
     * {@code Util.printStats}. Averages print as NaN when no exons/introns were seen.
     */
    public String summarize(){
        return 
        "---Avg lengths---\n"+
        "Exon:" + (double)totalExonLength/exonCount + " "+ totalExonLength+ "/" + exonCount + "\n" + 
        "Intron:" + (double)totalIntronLength/intronCount + " "+ totalIntronLength+ "/" + intronCount+  "\n" + 
        "---Strand counts---\n"+
        "Plus:" + plusCount + "\n" + 
        "Minus:" + minusCount + "\n" + 
        "---Input evidence stats---\n"+
        "Exonerate coding:" + Util.printStats(exonerateCodingStats) + 
        "Exonerate intron:" + Util.printStats(exonerateIntronStats) + 
        "RNAWeasel:" + Util.printStats(weaselStats); 
    }
}

