/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver;

import calhoun.analysis.crf.CRFInference;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.solver.check.FeatureCalculator;
import calhoun.analysis.crf.solver.check.TransitionInfo;
import calhoun.util.ColtUtil;
import calhoun.util.DenseBooleanMatrix2D;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Viterbi decoder for a CRF model: finds the highest-scoring hidden-state path for an input
 * sequence, given a {@link ModelManager} and a feature-weight vector.
 *
 * <p>Each call to {@link #predict} overwrites the dynamic-programming tables kept in this
 * instance. Those tables are exposed through {@link #getBestScore()} and
 * {@link #getBackPointers()}. Both are flat arrays indexed as {@code position * numStates + state}.
 * Because predict mutates instance fields, concurrent calls on one instance would clobber
 * each other's tables.
 */
public class Viterbi
implements CRFInference {
    Log log = LogFactory.getLog(Viterbi.class);
    /** When true, every candidate transition score is written to {@link #log} at debug level. */
    boolean debug = false;
    /** Best path score ending in each (position, state); indexed {@code pos * numStates + state}. */
    private double[] bestScore;
    /**
     * Predecessor state on the best path into each (position, state), same indexing as
     * {@link #bestScore}. The value is -1 when no legal predecessor produced a score above
     * negative infinity. Row 0 is left at 0 and is never read by the traceback.
     */
    private int[] backPointers;
    /** When true, the model's legal-transition matrix is ignored and every transition is allowed. */
    private boolean allPaths;

    /** Returns whether all state transitions are allowed, ignoring the model's legal-transition matrix. */
    public boolean isAllPaths() {
        return this.allPaths;
    }

    /**
     * Sets whether decoding should allow every state transition.
     *
     * @param allPaths true to ignore {@link ModelManager#getLegalTransitions()}
     */
    public void setAllPaths(boolean allPaths) {
        this.allPaths = allPaths;
    }

    /**
     * Computes the highest-scoring state path for {@code seq}.
     *
     * <p>The score of a path is the node value at position 0, plus, for every later position,
     * the node value of the state and the edge value of the incoming transition. These values
     * come from a {@link FeatureCalculator} built with {@code lambda}. Transitions are limited
     * to {@link ModelManager#getLegalTransitions()}. All transitions are allowed instead when
     * {@link #isAllPaths()} is set or the model returns no matrix.
     *
     * @param fm model supplying the state count, features and legal transitions
     * @param seq the input sequence to label; must be non-empty
     * @param lambda feature weights
     * @return the best state for every position in {@code hiddenStates}, and in
     *     {@code bestScores} the best path score ending in each state at the last position
     * @throws IllegalArgumentException if {@code seq} has length 0
     * @throws IllegalStateException if the traceback reaches a state with no recorded
     *     predecessor, i.e. no legal path scoring above negative infinity reaches it
     */
    @Override
    public CRFInference.InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
        int numStates = fm.getNumStates();
        int len = seq.length();
        if (len == 0) {
            // The traceback starts at position len - 1; fail with a clear message instead of
            // an ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Viterbi requires a non-empty input sequence");
        }
        DenseBooleanMatrix2D transitions = legalTransitions(fm, numStates);
        FeatureCalculator calc = new FeatureCalculator(fm, lambda, new TransitionInfo(fm, false));
        fillTrellis(seq, calc, transitions, numStates, len);

        CRFInference.InferenceResult inferenceResult = new CRFInference.InferenceResult();
        inferenceResult.hiddenStates = traceBack(numStates, len);
        inferenceResult.bestScores = new double[numStates];
        System.arraycopy(this.bestScore, numStates * (len - 1), inferenceResult.bestScores, 0, numStates);
        return inferenceResult;
    }

    /**
     * Returns the transition matrix to enforce. This is the model's matrix, or an all-true
     * {@code numStates x numStates} matrix when {@link #allPaths} is set or the model has none.
     */
    private DenseBooleanMatrix2D legalTransitions(ModelManager fm, int numStates) {
        DenseBooleanMatrix2D transitions = fm.getLegalTransitions();
        if (this.allPaths || transitions == null) {
            transitions = new DenseBooleanMatrix2D(numStates, numStates);
            transitions.assign(true);
        }
        return transitions;
    }

    /** Forward pass: allocates and fills {@link #bestScore} and {@link #backPointers} for every position and state. */
    private void fillTrellis(InputSequence<?> seq, FeatureCalculator calc, DenseBooleanMatrix2D transitions, int numStates, int len) {
        this.bestScore = new double[numStates * len];
        this.backPointers = new int[numStates * len];
        for (int pos = 0; pos < len; ++pos) {
            int posIndex = pos * numStates;
            for (int state = 0; state < numStates; ++state) {
                int index = posIndex + state;
                double nodeVal = calc.calcNodeValue(seq, pos, state);
                if (pos == 0) {
                    // No incoming transition at the first position: the score is the node value alone.
                    this.bestScore[index] = nodeVal;
                    continue;
                }
                double max = Double.NEGATIVE_INFINITY;
                int prevState = -1; // stays -1 if no legal predecessor scores above -Infinity
                for (int k = 0; k < numStates; ++k) {
                    if (!transitions.getQuick(k, state)) {
                        continue;
                    }
                    double previous = this.bestScore[posIndex - numStates + k];
                    double edge = calc.calcEdgeValue(seq, pos, k, state);
                    double current = previous + edge + nodeVal;
                    if (this.debug) {
                        this.log.debug(String.format("Pos: %d Trans: %d-%d %.2f (Prev: %.2f + Edge: %.2f + Node: %.2f)", pos, k, state, current, previous, edge, nodeVal));
                    }
                    // Strict '>' keeps the lowest-numbered predecessor on ties and never selects a NaN score.
                    if (current > max) {
                        max = current;
                        prevState = k;
                    }
                }
                this.bestScore[index] = max;
                this.backPointers[index] = prevState;
            }
        }
    }

    /**
     * Picks the final state with {@code ColtUtil.maxInColumn} over the last position's scores,
     * then follows {@link #backPointers} back to position 0.
     *
     * @throws IllegalStateException if a state on the path has no recorded predecessor (-1)
     */
    private int[] traceBack(int numStates, int len) {
        int[] path = new int[len];
        path[len - 1] = ColtUtil.maxInColumn(this.bestScore, numStates, len - 1);
        for (int i = len - 1; i > 0; --i) {
            int prev = this.backPointers[numStates * i + path[i]];
            if (prev < 0) {
                // Without this check, -1 would be used as an offset into the previous row (or
                // returned as the state at position 0), silently yielding an invalid path.
                throw new IllegalStateException(String.format("No legal predecessor for state %d at position %d (no legal path scoring above -Infinity reaches it)", path[i], i));
            }
            path[i - 1] = prev;
        }
        return path;
    }

    /**
     * Returns the back-pointer table from the most recent {@link #predict} call, indexed
     * {@code pos * numStates + state}. This is the live internal array, not a copy.
     */
    public int[] getBackPointers() {
        return this.backPointers;
    }

    /**
     * Replaces the back-pointer table. The array is stored as-is, without copying.
     *
     * @param backPointers the new table
     */
    public void setBackPointers(int[] backPointers) {
        this.backPointers = backPointers;
    }

    /**
     * Returns the best-score table from the most recent {@link #predict} call, indexed
     * {@code pos * numStates + state}. This is the live internal array, not a copy.
     */
    public double[] getBestScore() {
        return this.bestScore;
    }

    /**
     * Replaces the best-score table. The array is stored as-is, without copying.
     *
     * @param bestScore the new table
     */
    public void setBestScore(double[] bestScore) {
        this.bestScore = bestScore;
    }
}

