package mlproject.hmm;

import java.util.Arrays;
import java.util.HashMap;

import java.util.List;

import mlproject.io.InputSequence;

import mlproject.io.TrainingSequence;

import mlproject.phylo.EvolutionaryModel;

/**
 * Emission model that scores a position as the sum of up to four log-probability
 * terms: the per-state nucleotide frequency, the alignment column probability
 * given by an {@link EvolutionaryModel}, the exonerate symbol frequency and the
 * rnaweasel symbol frequency.
 *
 * <p>Which evidence terms are added is selected by the emission model type:
 * <pre>
 *   type   alignment  exonerate  rnaweasel
 *    1        x          x          x
 *    2        x                     x
 *    3                   x          x
 *    4        x          x
 *    5                   x
 *    6        x
 *    7                              x
 *   other  (nucleotide term only)
 * </pre>
 *
 * <p>{@link #train(StateModel, List)} must be called (with the evolutionary
 * models already set) before {@link #emissionLogProb(int, int, InputSequence)}.
 */
public class EvidenceEmissionModel extends EmissionModel {
    /* Input sequence components
     * 0-specie sequence
     * 1-alignment
     * 2-exonerate
     * 3-rnaweasel
     */
    private static final int SEQUENCE_COMPONENT = 0;
    private static final int ALIGNMENT_COMPONENT = 1;
    private static final int EXONERATE_COMPONENT = 2;
    private static final int RNAWEASEL_COMPONENT = 3;

    /* Number of distinct symbols counted for each evidence track. */
    private static final int NUCLEOTIDE_SYMBOLS = 4;
    private static final int EXONERATE_SYMBOLS = 5;
    private static final int RNAWEASEL_SYMBOLS = 3;

    /* Nucleotide code 4 = N; it has no row entry in stateLogProbs. */
    private static final int UNKNOWN_NUCLEOTIDE = 4;

    private int pseudoCount = 1;
    private int emissionModelType = 1;

    /* Hold raw counts while train() is running and log-probabilities afterwards. */
    private double[][] stateLogProbs;
    private double[][] exonerateLogProbs;
    private double[][] rnaweaselLogProbs;
    private int nStates; //number of states
    private StateModel stateModel;

    EvolutionaryModel emodelIntergenic;
    EvolutionaryModel emodelIntronic;
    List<EvolutionaryModel> emodelExonic;

    EvolutionaryModel[] emodelStateMapping; //which emodel to use for each state

    public EvidenceEmissionModel() {
    }

    /**
     * Returns the emission log-probability of position {@code pos} of
     * {@code input} in state {@code state}.
     *
     * <p>The nucleotide term is 0.0 when the nucleotide is N (code 4); the
     * evidence terms enabled by the emission model type are added regardless of
     * the nucleotide. The terms are summed in the order nucleotide, alignment,
     * exonerate, rnaweasel.
     *
     * @param state index of the state
     * @param pos   position within the input sequence
     * @param input input sequence providing the four components listed above
     * @return sum of the enabled log-probability terms
     */
    public double emissionLogProb(int state, int pos, InputSequence input) {
        int nucleotide = (Integer) input.getComponent(SEQUENCE_COMPONENT).getX(pos);

        // The conditional must cover ONLY the nucleotide term. Written inline as
        // "(n==4) ? 0.0 : a + b + c" the additions bind tighter than ?: and an N
        // would discard every evidence term along with the nucleotide term.
        double logProb = (nucleotide == UNKNOWN_NUCLEOTIDE) ? 0.0 : stateLogProbs[state][nucleotide];

        if (includesAlignment(emissionModelType)) {
            // The boolean flag is true for states at or beyond the first minus state.
            logProb += emodelStateMapping[state].getColumnLogProbability(
                    (String) input.getComponent(ALIGNMENT_COMPONENT).getX(pos),
                    state >= stateModel.getFirstMinusState());
        }
        if (includesExonerate(emissionModelType)) {
            logProb += exonerateLogProbs[state][(Integer) input.getComponent(EXONERATE_COMPONENT).getX(pos)];
        }
        if (includesRnaweasel(emissionModelType)) {
            logProb += rnaweaselLogProbs[state][(Integer) input.getComponent(RNAWEASEL_COMPONENT).getX(pos)];
        }
        return logProb;
    }

    /* true for the emission model types that add the alignment term */
    private static boolean includesAlignment(int type) {
        return type == 1 || type == 2 || type == 4 || type == 6;
    }

    /* true for the emission model types that add the exonerate term */
    private static boolean includesExonerate(int type) {
        return type == 1 || type == 3 || type == 4 || type == 5;
    }

    /* true for the emission model types that add the rnaweasel term */
    private static boolean includesRnaweasel(int type) {
        return type == 1 || type == 2 || type == 3 || type == 7;
    }

    /**
     * Estimates the nucleotide, exonerate and rnaweasel emission tables from
     * labelled sequences and maps every state to its evolutionary model.
     *
     * <p>Steps: initialise all counts with the pseudocount, count the symbols
     * observed under each label, pool the counts of joined states, apply the
     * fixed nucleotide constraints of the state model, and finally turn every
     * row into log relative frequencies.
     *
     * @param stateModel        state model supplying the number of states, the
     *                          state groups, joined states and count constraints
     * @param trainingSequences labelled sequences to count emissions from
     */
    public void train(StateModel stateModel, List<TrainingSequence> trainingSequences) {
        this.stateModel = stateModel;
        this.nStates = stateModel.getNStates();
        stateLogProbs = new double[nStates][NUCLEOTIDE_SYMBOLS];
        exonerateLogProbs = new double[nStates][EXONERATE_SYMBOLS];
        rnaweaselLogProbs = new double[nStates][RNAWEASEL_SYMBOLS];

        initializeEvolutionaryModels();

        /* add pseudocounts */
        for (int state = 0; state < nStates; state++) {
            Arrays.fill(stateLogProbs[state], pseudoCount);
            Arrays.fill(exonerateLogProbs[state], pseudoCount);
            Arrays.fill(rnaweaselLogProbs[state], pseudoCount);
        }

        /* count emissions */
        for (TrainingSequence ti : trainingSequences) {
            InputSequence input = ti.getInputSequence();
            // NOTE(review): the last position (length()-1) is never counted;
            // kept as found -- confirm whether this is intentional.
            for (int i = 0; i < ti.length() - 1; i++) {
                int label = ti.getY(i);
                int nucleotide = (Integer) input.getComponent(SEQUENCE_COMPONENT).getX(i);
                // Codes >= 4 (e.g. N) have no column in the nucleotide table.
                if (nucleotide < NUCLEOTIDE_SYMBOLS) {
                    stateLogProbs[label][nucleotide]++;
                }
                exonerateLogProbs[label][(Integer) input.getComponent(EXONERATE_COMPONENT).getX(i)]++;
                rnaweaselLogProbs[label][(Integer) input.getComponent(RNAWEASEL_COMPONENT).getX(i)]++;
            }
        }
        fixJoinedEmissions();
        fixCounts();

        /* calculate logprobs */
        for (int state = 0; state < nStates; state++) {
            countsToLogProbs(stateLogProbs[state]);
            countsToLogProbs(rnaweaselLogProbs[state]);
            countsToLogProbs(exonerateLogProbs[state]);
        }
    }

    /*
     * Replaces the counts in row by log(count / row total), in place.
     * A zero count becomes negative infinity explicitly, which also avoids the
     * NaN that 0/0 would produce when fixCounts() has zeroed a whole row.
     * Exonerate and rnaweasel rows never contain zeros because they start at
     * the pseudocount and are only ever incremented or summed.
     */
    private static void countsToLogProbs(double[] row) {
        double total = 0;
        for (double count : row) {
            total += count;
        }
        for (int j = 0; j < row.length; j++) {
            if (row[j] == 0) {
                row[j] = Double.NEGATIVE_INFINITY;
            } else {
                row[j] = Math.log(row[j] / total);
            }
        }
    }

    /* map each state to one evol. model*/
    private void initializeEvolutionaryModels() {
        emodelStateMapping = new EvolutionaryModel[nStates];
        for (int state : stateModel.getIntergenicStates()) {
            emodelStateMapping[state] = emodelIntergenic;
        }
        /* plus and minus strand states share the same model */
        for (int state : stateModel.getPlusIntronicStates()) {
            emodelStateMapping[state] = emodelIntronic;
        }
        for (int state : stateModel.getMinusIntronicStates()) {
            emodelStateMapping[state] = emodelIntronic;
        }
        for (int state : stateModel.getPlusExon0States()) {
            emodelStateMapping[state] = emodelExonic.get(0);
        }
        for (int state : stateModel.getMinusExon0States()) {
            emodelStateMapping[state] = emodelExonic.get(0);
        }
        for (int state : stateModel.getPlusExon1States()) {
            emodelStateMapping[state] = emodelExonic.get(1);
        }
        for (int state : stateModel.getMinusExon1States()) {
            emodelStateMapping[state] = emodelExonic.get(1);
        }
        for (int state : stateModel.getPlusExon2States()) {
            emodelStateMapping[state] = emodelExonic.get(2);
        }
        for (int state : stateModel.getMinusExon2States()) {
            emodelStateMapping[state] = emodelExonic.get(2);
        }
    }

    /*
     * fix counts of constrained states:
     * COUNTS_ZERO forces the count of that nucleotide to 0;
     * COUNTS_ALL  forces the count of that nucleotide to 1 and all others to 0.
     * Constraints are applied in nucleotide order, so a later entry of the same
     * row overrides what an earlier one set.
     */
    private void fixCounts() {
        int[][] constraints = stateModel.getFixedStateNucleotideCounts();
        for (int state = 0; state < nStates; state++) {
            for (int nuc = 0; nuc < NUCLEOTIDE_SYMBOLS; nuc++) {
                if (constraints[state][nuc] == StateModel.COUNTS_ZERO) {
                    stateLogProbs[state][nuc] = 0;
                }
                if (constraints[state][nuc] == StateModel.COUNTS_ALL) {
                    Arrays.fill(stateLogProbs[state], 0.0);
                    stateLogProbs[state][nuc] = 1;
                }
            }
        }
    }

    /*
     * fix counts of states that should have same emissions:
     * every state of a joined group receives the element-wise sum of the counts
     * (pseudocounts included) of all states in that group.
     */
    private void fixJoinedEmissions() {
        int[][] joined = stateModel.getJoinedEmissionStates();
        for (int[] group : joined) {
            double[] counts = new double[NUCLEOTIDE_SYMBOLS];
            double[] exonerateCounts = new double[EXONERATE_SYMBOLS];
            double[] weaselCounts = new double[RNAWEASEL_SYMBOLS];
            /* pool the counts of the group */
            for (int state : group) {
                for (int k = 0; k < NUCLEOTIDE_SYMBOLS; k++) {
                    counts[k] += stateLogProbs[state][k];
                }
                for (int k = 0; k < EXONERATE_SYMBOLS; k++) {
                    exonerateCounts[k] += exonerateLogProbs[state][k];
                }
                for (int k = 0; k < RNAWEASEL_SYMBOLS; k++) {
                    weaselCounts[k] += rnaweaselLogProbs[state][k];
                }
            }
            /* write the pooled counts back to every state of the group */
            for (int state : group) {
                for (int k = 0; k < NUCLEOTIDE_SYMBOLS; k++) {
                    stateLogProbs[state][k] = counts[k];
                }
                for (int k = 0; k < EXONERATE_SYMBOLS; k++) {
                    exonerateLogProbs[state][k] = exonerateCounts[k];
                }
                for (int k = 0; k < RNAWEASEL_SYMBOLS; k++) {
                    rnaweaselLogProbs[state][k] = weaselCounts[k];
                }
            }
        }
    }

    /** Sets the evolutionary model assigned to intergenic states by {@code train}. */
    public void setEmodelIntergenic(EvolutionaryModel emodelIntergenic) {
        this.emodelIntergenic = emodelIntergenic;
    }

    public EvolutionaryModel getEmodelIntergenic() {
        return emodelIntergenic;
    }

    /** Sets the evolutionary model assigned to plus and minus intronic states by {@code train}. */
    public void setEmodelIntronic(EvolutionaryModel emodelIntronic) {
        this.emodelIntronic = emodelIntronic;
    }

    public EvolutionaryModel getEmodelIntronic() {
        return emodelIntronic;
    }

    /**
     * Sets the evolutionary models for exonic states; {@code train} reads
     * elements 0, 1 and 2 for the exon 0, exon 1 and exon 2 state groups.
     */
    public void setEmodelExonic(List<EvolutionaryModel> emodelExonic) {
        this.emodelExonic = emodelExonic;
    }

    public List<EvolutionaryModel> getEmodelExonic() {
        return emodelExonic;
    }

    /**
     * Selects which evidence terms {@code emissionLogProb} adds (see the table
     * in the class comment). Values outside 1-7 leave only the nucleotide term.
     */
    public void setEmissionModelType(int emissionModel) {
        this.emissionModelType = emissionModel;
    }

    public int getEmissionModelType() {
        return emissionModelType;
    }
}
