package mlproject.hmm;

import java.io.BufferedInputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

import java.util.Arrays;

import java.util.List;

import mlproject.io.InputSequence;
import mlproject.io.TrainingSequence;


import mlproject.util.Util;

/**
 * Hidden Markov model whose state topology is taken from a {@link StateModel} and whose
 * emission scores come from an {@link EmissionModel}.
 * <p>
 * Transition log-probabilities are estimated by counting state-to-state transitions in
 * labelled training sequences (plus a pseudocount) and the most probable state path of an
 * input sequence is decoded with the Viterbi algorithm ({@link #infereBestPath(InputSequence)}).
 * <p>
 * NOTE(review): this class is {@code Serializable} without an explicit {@code serialVersionUID},
 * so files written by {@link #write(String)} can only be read back while the computed default
 * UID is unchanged. Adding or changing non-private members changes that UID.
 */
public class EvidenceHMM implements Serializable{
    /* Layout of the tied transitions in trainTransitions: six groups of states spaced 16 apart.
       Private static compile-time constants do not take part in the default serialVersionUID. */
    private static final int TIED_GROUP_COUNT = 6;
    private static final int TIED_GROUP_STRIDE = 16;

    /** Smoothing count added to every allowed transition before normalization. */
    private double transitionPseudoCount = 1;
    /** Factor applied to every transition log-probability when it is copied into backTransitionLogProbs. */
    private double transitionCoeficient = 1;
    
    private StateModel stateModel;
    
    /** transitionMatrix[i][j] is true when a transition from state i to state j is allowed. */
    boolean[][] transitionMatrix;
    
    // NOTE(review): not assigned or read anywhere in this class
    String[] stateNames;
    
    /** backTransitions[i] lists every state j with an allowed transition j -> i, in increasing order of j. */
    int[][] backTransitions;
    /** backTransitionLogProbs[i][k] is the scaled log-probability of backTransitions[i][k] -> i; all zero until train() runs. */
    double[][] backTransitionLogProbs;
    
    
    private EmissionModel emissionModel;
    
    
    /**
     * Creates an untrained model and precomputes, for every state, the list of states that
     * may transition into it.
     *
     * @param stateModel supplies the (assumed square) transition matrix and the special start/end states
     * @param emModel    scores emissions and is trained alongside the transitions
     */
    public EvidenceHMM(StateModel stateModel, EmissionModel emModel) {
        this.stateModel = stateModel;
        emissionModel = emModel;
        
        transitionMatrix = stateModel.getTransitionMatrix();
        
        backTransitions = new int[transitionMatrix.length][];
        backTransitionLogProbs = new double[transitionMatrix.length][];
        
        /* invert the transition matrix: collect the predecessors of each state i */
        for(int i=0;i<transitionMatrix.length;i++){ 
            int count = 0;
            for(int j=0;j<transitionMatrix.length;j++){
                if(transitionMatrix[j][i]) count++;
            }
            backTransitions[i] = new int[count];
            backTransitionLogProbs[i] = new double[count];
            
            count = 0;
            for(int j=0;j<transitionMatrix.length;j++){
                if(transitionMatrix[j][i]){ 
                    backTransitions[i][count] = j;
                    count++;
                }
            }
        }
    }
    
    /**
     * Reads a model previously stored with {@link #write(String)}.
     * <p>
     * SECURITY: this uses Java-native deserialization; only read model files from trusted sources.
     *
     * @param location path of the serialized model
     * @return the deserialized model
     * @throws IOException            if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     */
    public static EvidenceHMM read(String location) throws IOException, ClassNotFoundException {
        FileInputStream file = new FileInputStream(location);
        try {
            ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(file));
            return (EvidenceHMM)in.readObject();
        } finally {
            file.close(); // released even when the header is corrupt or readObject fails
        }
    }
    
    /**
     * Serializes this model to a file, overwriting it if it exists.
     *
     * @param location path of the output file
     * @throws IOException            if the file cannot be written
     * @throws ClassNotFoundException never thrown here; kept so existing callers that catch it still compile
     */
    public void write(String location) throws IOException, ClassNotFoundException {
        FileOutputStream file = new FileOutputStream(location);
        try {
            ObjectOutputStream out = new ObjectOutputStream(file);
            out.writeObject(this);
            out.close(); // flushes buffered object data and closes the file
        } finally {
            file.close(); // no-op after a successful close; releases the handle on failure
        }
    }
    
    /**
     * Trains transition probabilities from the labelled sequences, then delegates emission
     * training to the emission model.
     *
     * @param trainingSequences labelled sequences; getY(i) is used as the state index at position i
     */
    public void train(List<TrainingSequence> trainingSequences){
        trainTransitions(trainingSequences);
        emissionModel.train(stateModel,trainingSequences);
    }

    /**
     * Estimates transition log-probabilities by counting and stores them, scaled by
     * transitionCoeficient, in backTransitionLogProbs.
     */
    private void trainTransitions(List<TrainingSequence> trainingSequences){
        int nStates = transitionMatrix.length;
        double[][] counts = new double[nStates][nStates];
        
        /* every allowed transition starts at the pseudocount */
        for(int i=0;i<nStates;i++){
            for(int j=0;j<nStates;j++){
                if(transitionMatrix[i][j]){
                    counts[i][j] += transitionPseudoCount;
                }
            }
        }
        
        /* count observed label transitions y(i) -> y(i+1) */
        for(TrainingSequence ti:trainingSequences){
            for(int i=0;i<ti.length()-1;i++){
                counts[ti.getY(i)][ti.getY(i+1)]++;
            }
        }
        
        /* Tie the counts of joined intron transitions across the six "+/-" intron groups to
           prevent intron jumping (done here rather than in the state model).
           NOTE(review): the state indices are hard-coded and assume the StateModel layout
           (at least 213 states) - TODO confirm against StateModel. */
        tieTransitionCounts(counts, 17, 17);
        tieTransitionCounts(counts, 17, 18);
        tieTransitionCounts(counts, 131, 131);
        tieTransitionCounts(counts, 131, 132);
        tieTransitionCounts(counts, 16, 17);
        tieTransitionCounts(counts, 16, 18);
        tieTransitionCounts(counts, 130, 131);
        tieTransitionCounts(counts, 130, 132);
        
        /* convert counts to log probs; rowCount sums the whole row, including any counted
           transitions that the transition matrix does not allow */
        for(int i=0;i<nStates;i++){
            double rowCount = 0;
            for(int j=0;j<nStates;j++){
                rowCount += counts[i][j];
            }
            for(int j=0;j<nStates;j++){
                if(transitionMatrix[i][j]){
                    counts[i][j] = Math.log(counts[i][j]/rowCount);
                }
            }
        }
        
        /* copy values to backTransitionLogProbs, in the same order as backTransitions */
        for(int i=0;i<nStates;i++){
            int count = 0;
            for(int j=0;j<nStates;j++){
                if(transitionMatrix[j][i]){
                    backTransitionLogProbs[i][count] = counts[j][i]*transitionCoeficient;
                    count++;
                }
            }
        }
    }
    
    /**
     * Replaces the counts of the transitions (from + g*stride) -> (to + g*stride), for each of
     * the TIED_GROUP_COUNT groups g, with their common sum.
     */
    private static void tieTransitionCounts(double[][] counts, int from, int to){
        double sum = 0;
        for(int g=0;g<TIED_GROUP_COUNT;g++){
            sum += counts[from+g*TIED_GROUP_STRIDE][to+g*TIED_GROUP_STRIDE];
        }
        for(int g=0;g<TIED_GROUP_COUNT;g++){
            counts[from+g*TIED_GROUP_STRIDE][to+g*TIED_GROUP_STRIDE] = sum;
        }
    }
    
    
    /**
     * Viterbi decoding: returns the highest-scoring state path for the input.
     * <p>
     * The path may start only in an intergenic state, the first plus-start-codon state or the
     * first minus-stop-codon state. Position 0 is scored by emission alone. The path may end
     * only in an intergenic state, the third plus-stop-codon state or the third
     * minus-start-codon state. On equal scores the candidate examined last wins (comparisons
     * use {@code >=}).
     *
     * @param input sequence to decode
     * @return res[p] is the decoded state at position p; an empty array for an empty input
     */
    public int[] infereBestPath(InputSequence input){
        /* main variables */
        int length = input.length();
        int nStates = backTransitions.length;
        if(length == 0){
            return new int[0]; // nothing to decode; logProbs[0] below would not exist
        }
        double[][] logProbs = new double[length][nStates]; // best log score ending in state j at pos
        int[][] solution = new int[length][nStates]; // back-pointers; row 0 is unused
        
        /* tmp variables */
        double[] emissions = new double[nStates];
        double tmpLogProb = Double.NEGATIVE_INFINITY;
        double bestLogProb = Double.NEGATIVE_INFINITY;
        int bestState = -1;
        
        /* initialization: only the allowed start states get a finite score */
        Arrays.fill(logProbs[0],Double.NEGATIVE_INFINITY);
        for(int i=0;i<stateModel.getIntergenicStates().length;i++){
            logProbs[0][stateModel.getIntergenicStates()[i]] = emissionModel.emissionLogProb(stateModel.getIntergenicStates()[i],0,input);    
        }
        logProbs[0][stateModel.getPlusStartCodonStates()[0]] = emissionModel.emissionLogProb(stateModel.getPlusStartCodonStates()[0],0,input);
        logProbs[0][stateModel.getMinusStopCodonStates()[0]] = emissionModel.emissionLogProb(stateModel.getMinusStopCodonStates()[0],0,input);
        
        /* recursion */
        for(int pos=1;pos<length;pos++){
            for(int j=0;j<nStates;j++){
                emissions[j] = emissionModel.emissionLogProb(j,pos,input);
            }
            
            for(int j=0;j<nStates;j++){
                bestLogProb = Double.NEGATIVE_INFINITY;
                bestState = -1; // stays -1 only if state j has no predecessors
                
                /* find best previous state */
                for(int k=0;k<backTransitions[j].length;k++){
                    tmpLogProb = logProbs[pos-1][backTransitions[j][k]] + backTransitionLogProbs[j][k] + emissions[j];
                    if(tmpLogProb >= bestLogProb){
                        bestLogProb = tmpLogProb;
                        bestState = backTransitions[j][k];
                    }
                }
                logProbs[pos][j] = bestLogProb;
                solution[pos][j] = bestState;
            }
        }
        
        /* pick the best allowed end state */
        int[] res = new int[length];
        int previous;
        
        bestLogProb = Double.NEGATIVE_INFINITY;
        bestState = -1;
        for(int i=0;i<nStates;i++){
            if(!Util.containsElement(stateModel.getIntergenicStates(),i) && i != stateModel.getPlusStopCodonStates()[2] &&
                i != stateModel.getMinusStartCodonStates()[2]) continue;
            if(logProbs[length-1][i] >= bestLogProb){
                bestLogProb = logProbs[length-1][i];
                bestState = i;
            }
        }
        
        res[length-1] = bestState;
        previous = solution[length-1][bestState];
        
        /* follow the back-pointers down to position 0 */
        for(int i=length-2;i>=0;i--){
            res[i] = previous;
            previous = solution[i][previous];
        }
        
        return res;
    }
    

    /** Sets the smoothing count added to every allowed transition during training. */
    public void setTransitionPseudoCount(double pseudoCount) {
        this.transitionPseudoCount = pseudoCount;
    }

    /** Returns the smoothing count added to every allowed transition during training. */
    public double getTransitionPseudoCount() {
        return transitionPseudoCount;
    }

    /** Sets the factor applied to transition log-probabilities; takes effect at the next train(). */
    public void setTransitionCoeficient(double transitionCoeficient) {
        this.transitionCoeficient = transitionCoeficient;
    }

    /** Returns the factor applied to transition log-probabilities during training. */
    public double getTransitionCoeficient() {
        return transitionCoeficient;
    }
}
