/*
 * Decompiled with CFR 0.152.
 */
package mlproject.hmm;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import mlproject.hmm.EmissionModel;
import mlproject.hmm.StateModel;
import mlproject.io.InputSequence;
import mlproject.io.TrainingSequence;
import mlproject.util.Util;

/**
 * Hidden Markov model whose allowed transitions come from a {@link StateModel} and whose
 * per-position emission log-probabilities come from an {@link EmissionModel}.
 *
 * <p>Transition probabilities are estimated from labelled training sequences by counting
 * (with an additive pseudo-count per allowed transition). {@link #infereBestPath(InputSequence)}
 * decodes the highest-scoring state path with the Viterbi recurrence. Transition
 * log-probabilities are stored per destination state ("back transitions"), so the decoder only
 * visits the predecessors each state actually allows.
 *
 * <p>NOTE(review): there is no explicit {@code serialVersionUID}, so the default one is computed
 * from the class's declared structure (fields, non-private members, static initializer).
 * Changing those would make previously written model files unreadable by {@link #read(String)}.
 */
public class EvidenceHMM
implements Serializable {
    /** Additive pseudo-count given to every allowed transition before training data is counted. */
    private double transitionPseudoCount = 1.0;
    /** Multiplier applied to every trained transition log-probability. */
    private double transitionCoeficient = 1.0;
    private StateModel stateModel;
    /** {@code transitionMatrix[from][to]} is true when the transition {@code from -> to} is allowed. */
    boolean[][] transitionMatrix;
    /** Not assigned or read anywhere in this class (package-private; other classes may use it). */
    String[] stateNames;
    /** {@code backTransitions[to]} lists, in ascending order, every state allowed to precede {@code to}. */
    int[][] backTransitions;
    /** {@code backTransitionLogProbs[to][k]} is the scaled log-probability of {@code backTransitions[to][k] -> to}. */
    double[][] backTransitionLogProbs;
    private EmissionModel emissionModel;

    /**
     * Builds an untrained model over the transition structure of {@code stateModel}.
     *
     * <p>Until {@link #train(List)} is called, every allowed transition has log-probability
     * 0.0, the default value of the freshly allocated arrays.
     *
     * @param stateModel supplies the allowed-transition matrix and the named state groups
     * @param emModel supplies emission log-probabilities and is trained by {@link #train(List)}
     */
    public EvidenceHMM(StateModel stateModel, EmissionModel emModel) {
        this.stateModel = stateModel;
        this.emissionModel = emModel;
        this.transitionMatrix = stateModel.getTransitionMatrix();
        int nStates = this.transitionMatrix.length;
        this.backTransitions = new int[nStates][];
        this.backTransitionLogProbs = new double[nStates][];
        for (int to = 0; to < nStates; ++to) {
            // First pass sizes the predecessor list; second pass fills it in ascending order.
            int predecessorCount = 0;
            for (int from = 0; from < nStates; ++from) {
                if (this.transitionMatrix[from][to]) {
                    ++predecessorCount;
                }
            }
            int[] predecessors = new int[predecessorCount];
            int k = 0;
            for (int from = 0; from < nStates; ++from) {
                if (this.transitionMatrix[from][to]) {
                    predecessors[k++] = from;
                }
            }
            this.backTransitions[to] = predecessors;
            this.backTransitionLogProbs[to] = new double[predecessorCount];
        }
    }

    /**
     * Loads a model previously saved with {@link #write(String)}.
     *
     * <p>Security: this uses Java native deserialization. Only load files from trusted sources.
     *
     * @param location path of the serialized model file
     * @return the deserialized model
     * @throws IOException if the file cannot be opened or read
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     */
    public static EvidenceHMM read(String location) throws IOException, ClassNotFoundException {
        FileInputStream file = new FileInputStream(location);
        try {
            ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(file));
            return (EvidenceHMM)in.readObject();
        }
        finally {
            // The wrappers hold no OS resources of their own; closing the file releases the
            // descriptor even when the header or the object stream is corrupt.
            file.close();
        }
    }

    /**
     * Saves this model to {@code location} using Java native serialization.
     *
     * @param location path of the file to create or overwrite
     * @throws IOException if the file cannot be written
     * @throws ClassNotFoundException never thrown; kept in the signature so existing callers that
     *     catch it still compile
     */
    public void write(String location) throws IOException, ClassNotFoundException {
        FileOutputStream file = new FileOutputStream(location);
        try {
            ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(file));
            out.writeObject(this);
            // Flushes the object and buffer layers into the file before it is closed.
            out.close();
        }
        finally {
            // No-op after a successful close; releases the descriptor if serialization failed.
            file.close();
        }
    }

    /**
     * Trains transitions from the labelled sequences, then delegates emission training to the
     * emission model.
     *
     * @param trainingSequences labelled sequences; {@code getY(t)} is used as the state index at position t
     */
    public void train(List<TrainingSequence> trainingSequences) {
        this.trainTransitions(trainingSequences);
        this.emissionModel.train(this.stateModel, trainingSequences);
    }

    /**
     * Re-estimates {@link #backTransitionLogProbs} from the training sequences.
     *
     * <p>The steps are:
     * <ol>
     *   <li>count transitions, with pseudo-counts;</li>
     *   <li>pool a fixed set of tied transitions;</li>
     *   <li>normalize each row into log-probabilities;</li>
     *   <li>scale by {@link #transitionCoeficient} and store per destination state.</li>
     * </ol>
     */
    private void trainTransitions(List<TrainingSequence> trainingSequences) {
        double[][] table = this.countTransitions(trainingSequences);
        EvidenceHMM.tieRepeatedTransitions(table);
        this.normalizeRowsToLogProbs(table);
        this.storeBackTransitionLogProbs(table);
    }

    /** Returns a {@code [from][to]} table of pseudo-counts on allowed transitions plus observed counts. */
    private double[][] countTransitions(List<TrainingSequence> trainingSequences) {
        int nStates = this.transitionMatrix.length;
        double[][] counts = new double[nStates][nStates];
        for (int from = 0; from < nStates; ++from) {
            for (int to = 0; to < nStates; ++to) {
                if (this.transitionMatrix[from][to]) {
                    counts[from][to] += this.transitionPseudoCount;
                }
            }
        }
        for (TrainingSequence sequence : trainingSequences) {
            for (int t = 0; t < sequence.length() - 1; ++t) {
                counts[sequence.getY(t)][sequence.getY(t + 1)] += 1.0;
            }
        }
        return counts;
    }

    /**
     * Replaces each of 8 hard-coded transitions with its total over 6 copies spaced 16 states
     * apart, so every copy shares the pooled count.
     *
     * <p>NOTE(review): this presumably ties parameters of 6 structurally identical groups of
     * states in the specific {@link StateModel} this was written for. It requires at least 213
     * states and throws {@link ArrayIndexOutOfBoundsException} otherwise. TODO confirm the layout
     * against StateModel.
     */
    private static void tieRepeatedTransitions(double[][] counts) {
        final int copies = 6;
        final int stride = 16;
        // {from, to} of the first copy of each tied transition. The 48 cells touched are all
        // distinct, so pooling one transition never affects another's total.
        final int[][] tiedTransitions = {
            {17, 17}, {17, 18}, {131, 131}, {131, 132},
            {16, 17}, {16, 18}, {130, 131}, {130, 132}
        };
        for (int[] transition : tiedTransitions) {
            double total = 0.0;
            for (int copy = 0; copy < copies; ++copy) {
                total += counts[transition[0] + copy * stride][transition[1] + copy * stride];
            }
            for (int copy = 0; copy < copies; ++copy) {
                counts[transition[0] + copy * stride][transition[1] + copy * stride] = total;
            }
        }
    }

    /**
     * Converts each allowed entry of the count table, in place, into {@code log(count / rowTotal)}.
     *
     * <p>NOTE(review): rowTotal also includes counts observed on transitions the matrix does not
     * allow. Those entries are left as raw counts and never read afterwards.
     */
    private void normalizeRowsToLogProbs(double[][] counts) {
        int nStates = counts.length;
        for (int from = 0; from < nStates; ++from) {
            double[] row = counts[from];
            double rowTotal = 0.0;
            for (int to = 0; to < nStates; ++to) {
                rowTotal += row[to];
            }
            for (int to = 0; to < nStates; ++to) {
                if (this.transitionMatrix[from][to]) {
                    row[to] = Math.log(row[to] / rowTotal);
                }
            }
        }
    }

    /** Copies allowed {@code [from][to]} log-probabilities, scaled by the coefficient, into the back-transition layout. */
    private void storeBackTransitionLogProbs(double[][] logProbs) {
        int nStates = logProbs.length;
        for (int to = 0; to < nStates; ++to) {
            // Same ascending-predecessor order as backTransitions[to], built in the constructor.
            int k = 0;
            for (int from = 0; from < nStates; ++from) {
                if (this.transitionMatrix[from][to]) {
                    this.backTransitionLogProbs[to][k] = logProbs[from][to] * this.transitionCoeficient;
                    ++k;
                }
            }
        }
    }

    /**
     * Decodes the highest-scoring state path for {@code input} with the Viterbi recurrence.
     *
     * <p>Start and end states:
     * <ul>
     *   <li>Position 0 may only be an intergenic state, the first plus-start-codon state or the
     *       first minus-stop-codon state; each gets an added prior of log(1/3).</li>
     *   <li>The last position may only be an intergenic state, the third plus-stop-codon state or
     *       the third minus-start-codon state.</li>
     * </ul>
     * Ties are broken toward the later candidate.
     *
     * @param input the sequence to label
     * @return one state index per position; empty for an empty input
     * @throws IllegalStateException if no admissible final state exists
     */
    public int[] infereBestPath(InputSequence input) {
        int length = input.length();
        if (length == 0) {
            return new int[0];
        }
        int nStates = this.backTransitions.length;
        double[][] logProbs = new double[length][nStates];
        int[][] solution = new int[length][nStates];
        double[] emissions = new double[nStates];

        // Position 0: only the designated start states are reachable.
        double logStartPrior = Math.log(1.0 / 3.0);
        Arrays.fill(logProbs[0], Double.NEGATIVE_INFINITY);
        for (int state : this.stateModel.getIntergenicStates()) {
            logProbs[0][state] = this.emissionModel.emissionLogProb(state, 0, input) + logStartPrior;
        }
        int plusStart = this.stateModel.getPlusStartCodonStates()[0];
        logProbs[0][plusStart] = this.emissionModel.emissionLogProb(plusStart, 0, input) + logStartPrior;
        int minusStop = this.stateModel.getMinusStopCodonStates()[0];
        logProbs[0][minusStop] = this.emissionModel.emissionLogProb(minusStop, 0, input) + logStartPrior;

        // Forward pass: best predecessor for every state at every later position.
        for (int pos = 1; pos < length; ++pos) {
            for (int state = 0; state < nStates; ++state) {
                emissions[state] = this.emissionModel.emissionLogProb(state, pos, input);
            }
            double[] previousColumn = logProbs[pos - 1];
            for (int to = 0; to < nStates; ++to) {
                int[] predecessors = this.backTransitions[to];
                double[] transitionLogProbs = this.backTransitionLogProbs[to];
                double bestLogProb = Double.NEGATIVE_INFINITY;
                int bestPredecessor = -1;
                for (int k = 0; k < predecessors.length; ++k) {
                    double candidate = previousColumn[predecessors[k]] + transitionLogProbs[k] + emissions[to];
                    // >= so that a predecessor is recorded even when every candidate is -infinity.
                    if (candidate >= bestLogProb) {
                        bestLogProb = candidate;
                        bestPredecessor = predecessors[k];
                    }
                }
                logProbs[pos][to] = bestLogProb;
                solution[pos][to] = bestPredecessor;
            }
        }

        int bestState = this.bestFinalState(logProbs[length - 1]);
        if (bestState < 0) {
            throw new IllegalStateException("No admissible final state for a sequence of length " + length);
        }

        // Backtrace along the stored predecessors.
        int[] path = new int[length];
        path[length - 1] = bestState;
        int previous = solution[length - 1][bestState];
        for (int pos = length - 2; pos >= 0; --pos) {
            path[pos] = previous;
            previous = solution[pos][previous];
        }
        return path;
    }

    /** Returns the admissible end state with the highest score (last one on ties), or -1 if none qualifies. */
    private int bestFinalState(double[] lastColumn) {
        int[] intergenic = this.stateModel.getIntergenicStates();
        double bestLogProb = Double.NEGATIVE_INFINITY;
        int bestState = -1;
        for (int state = 0; state < lastColumn.length; ++state) {
            boolean admissible = Util.containsElement(intergenic, state)
                    || state == this.stateModel.getPlusStopCodonStates()[2]
                    || state == this.stateModel.getMinusStartCodonStates()[2];
            if (admissible && lastColumn[state] >= bestLogProb) {
                bestLogProb = lastColumn[state];
                bestState = state;
            }
        }
        return bestState;
    }

    /** Sets the additive pseudo-count used for every allowed transition during training. */
    public void setTransitionPseudoCount(double pseudoCount) {
        this.transitionPseudoCount = pseudoCount;
    }

    /** Returns the additive pseudo-count used for every allowed transition during training. */
    public double getTransitionPseudoCount() {
        return this.transitionPseudoCount;
    }

    /** Sets the factor applied to transition log-probabilities when they are trained. */
    public void setTransitionCoeficient(double transitionCoeficient) {
        this.transitionCoeficient = transitionCoeficient;
    }

    /** Returns the factor applied to transition log-probabilities when they are trained. */
    public double getTransitionCoeficient() {
        return this.transitionCoeficient;
    }
}

