/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver;

import calhoun.analysis.crf.CRFInference;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.SemiMarkovSetup;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.solver.SemiMarkovViterbi;
import calhoun.analysis.crf.solver.check.FeatureCalculator;
import calhoun.analysis.crf.solver.check.TransitionInfo;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import calhoun.util.DenseIntMatrix2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Semi-Markov Viterbi decoding.
 *
 * <p>For every position and state, {@link #predict} considers segments ending at that position
 * whose length ranges from 1 up to the state's configured maximum. Each candidate is scored with
 * the incoming transition, node-length and edge-length features, which are evaluated through a
 * {@link FeatureCalculator}. Self-transition scores inside a segment come from running prefix
 * sums (the "stable" buffers) rather than being re-summed for every candidate length.
 *
 * <p>NOTE(review): the name suggests an uncached variant of {@link SemiMarkovViterbi}; confirm.
 *
 * <p>Not thread-safe: {@link #predict} keeps per-call state in instance fields.
 */
public class SemiMarkovViterbiNoCache
implements CRFInference {
    /** Logger for this class (it was previously registered under SemiMarkovViterbi by mistake). */
    private static final Log log = LogFactory.getLog(SemiMarkovViterbiNoCache.class);
    boolean debug = log.isDebugEnabled();
    /** Forwarded to {@link LengthTransitionInfo} when the transition table is built. */
    private boolean allPaths;
    /** Best score of a path whose last segment ends at {@code pos} in {@code state}, indexed {@code pos * nStates + state}. */
    private double[] bestScore;
    /**
     * Predecessor state of the best segment ending at {@code pos * nStates + state}:
     * -1 when that segment starts at position 0, -2 when no valid segment was found.
     */
    private int[] backPointers;
    /** Not read or written within this class. */
    DenseIntMatrix2D backLengths;
    /** Maximum segment length per state; all ones corresponds to a standard (non semi-Markov) Viterbi search. */
    short[] maxStateLengths;
    /** Forwarded to {@link LengthTransitionInfo}. */
    boolean ignoreSemiMarkovSelfTransitions;
    /** Largest entry of {@link #maxStateLengths} (at least 1); bounds the size of the history buffers. */
    int maxLookback = 1;
    /** Number of model states, captured at the start of {@link #predict}. */
    int nStates;
    /** Transition table built for the current {@link #predict} call. */
    TransitionInfo transitions;
    /** Transition index of {@code i -> i} for each state; -1 is treated as "no self-transition". */
    int[] selfTransitions;

    /** @return whether the all-paths flag is passed to the transition table */
    public boolean isAllPaths() {
        return this.allPaths;
    }

    /** @param allPaths flag forwarded to {@link LengthTransitionInfo} on the next {@link #predict} */
    public void setAllPaths(boolean allPaths) {
        this.allPaths = allPaths;
    }

    /**
     * Copies the per-state maximum segment lengths and the self-transition flag from {@code setup}.
     *
     * @param setup semi-Markov configuration; its length array is stored by reference, not copied
     */
    public void setSemiMarkovSetup(SemiMarkovSetup setup) {
        this.maxStateLengths = setup.getMaxLengths();
        this.ignoreSemiMarkovSelfTransitions = setup.isIgnoreSemiMarkovSelfTransitions();
    }

    /**
     * Finds the highest-scoring labelling of {@code seq}.
     *
     * <p>If no segment lengths were configured, every state gets a maximum length of 1, and that
     * choice is kept on this instance for later calls.
     *
     * @param fm     model supplying the number of states and the features
     * @param seq    sequence to label
     * @param lambda feature weights
     * @return one label per position, plus the best score of each state at the last position
     */
    @Override
    public CRFInference.InferenceResult predict(ModelManager fm, InputSequence<?> seq, double[] lambda) {
        this.nStates = fm.getNumStates();
        int len = seq.length();
        if (this.maxStateLengths == null) {
            log.info("No state lengths set - standard viterbi search.");
            this.maxStateLengths = new short[this.nStates];
            Arrays.fill(this.maxStateLengths, (short)1);
        }
        // Recompute from scratch so an earlier call with longer states does not keep the buffers oversized.
        this.maxLookback = 1;
        for (short maxLength : this.maxStateLengths) {
            this.maxLookback = Math.max(maxLength, this.maxLookback);
        }
        this.transitions = new LengthTransitionInfo(fm, this.allPaths, this.maxStateLengths, this.ignoreSemiMarkovSelfTransitions);
        int[][] statePotentials = this.getStatePotentials(this.transitions);
        this.selfTransitions = new int[this.nStates];
        for (int i = 0; i < this.nStates; ++i) {
            this.selfTransitions[i] = this.transitions.transitionIndex.getQuick(i, i);
        }
        FeatureCalculator calc = new FeatureCalculator(fm, lambda, this.transitions);
        // Newest-first histories: mis holds the transition scores of recent positions; stableStates
        // holds the running self-transition sums, one array per recent position.
        LinkedList<double[]> mis = new LinkedList<double[]>();
        LinkedList<double[]> stableStates = new LinkedList<double[]>();
        double[] nextMi = new double[this.transitions.nTransitions];
        double[] Ri = new double[this.nStates];
        this.bestScore = new double[len * this.nStates];
        this.backPointers = new int[len * this.nStates];
        int[] localBackLengths = new int[len * this.nStates];
        for (int pos = 0; pos < len; ++pos) {
            calc.computeSparseMi(seq, pos, nextMi, Ri);
            if (pos == 0) {
                // The first position's Ri seeds the running sums; later Ri arrays are only scratch.
                stableStates.add(Ri);
                Ri = new double[this.nStates];
            } else {
                this.updateStableBuffer(stableStates, nextMi);
                nextMi = this.updateMiBuffer(mis, nextMi);
            }
            double[] latestStable = stableStates.getFirst();
            for (int state = 0; state < this.nStates; ++state) {
                int[] transitionPotentials = statePotentials[state];
                int lookbackSize = this.maxStateLengths[state];
                double max = Double.NEGATIVE_INFINITY;
                int bestLookback = -1;
                int bestPrevState = -2;
                Iterator<double[]> miIter = mis.iterator();
                Iterator<double[]> stableIter = stableStates.iterator();
                // A segment with lookback k covers positions pos-k .. pos (length k + 1).
                for (int lookback = 0; lookback < lookbackSize; ++lookback) {
                    int startPos = pos - lookback;
                    if (startPos == 0) {
                        // The segment starts the sequence: no predecessor and no incoming transition.
                        Assert.a(mis.size() == lookback, "More Mi matrices in history than there are previous positions in the sequence.");
                        double current = latestStable[state];
                        calc.result.evaluateNodeLength(seq, pos, lookback + 1, state);
                        current += calc.calcRet(false);
                        if (current > max) {
                            max = current;
                            bestLookback = lookback;
                            bestPrevState = -1;
                        }
                        break;
                    }
                    double nodeLengthCost = Double.NaN;
                    double[] lookbackMi = miIter.next();
                    double[] lookbackStable = stableIter.next();
                    // Self-transition contribution for positions startPos+1 .. pos. A single-position
                    // segment has none; forcing 0 avoids inf - inf = NaN when the prefix sum is infinite.
                    // NOTE(review): an infinite prefix (e.g. from the pos-0 Ri) still yields NaN for lookback > 0.
                    double stable = lookback == 0 ? 0.0 : latestStable[state] - lookbackStable[state];
                    for (int transition : transitionPotentials) {
                        int prevState = this.transitions.transitionFrom[transition];
                        // Semi-Markov states extend through segment length, not through explicit self-transitions.
                        if (prevState == state && lookbackSize > 1) continue;
                        double transitionCost = lookbackMi[transition];
                        if (Double.isInfinite(transitionCost)) continue;
                        if (Double.isNaN(nodeLengthCost)) {
                            // Evaluated lazily: only once a usable transition exists for this length.
                            nodeLengthCost = calc.calcNodeLengthValue(seq, pos, lookback + 1, state);
                        }
                        double current = this.bestScore[this.nStates * (startPos - 1) + prevState];
                        if (current == Double.NEGATIVE_INFINITY) continue;
                        calc.result.evaluateEdgeLength(seq, pos, lookback + 1, prevState, state);
                        double lengthCost = nodeLengthCost + calc.calcRet(false);
                        current += transitionCost + stable + lengthCost;
                        if (current > max) {
                            max = current;
                            bestLookback = lookback;
                            bestPrevState = prevState;
                        }
                    }
                }
                int index = pos * this.nStates + state;
                this.bestScore[index] = max;
                this.backPointers[index] = bestPrevState;
                localBackLengths[index] = bestLookback + 1;
            }
        }
        int[] ret = this.traceback(len, localBackLengths);
        CRFInference.InferenceResult inferenceResult = new CRFInference.InferenceResult();
        inferenceResult.hiddenStates = ret;
        inferenceResult.bestScores = new double[this.nStates];
        System.arraycopy(this.bestScore, this.nStates * (len - 1), inferenceResult.bestScores, 0, this.nStates);
        return inferenceResult;
    }

    /**
     * Follows the back-pointers from the best final state to recover one label per position.
     *
     * @param len         sequence length
     * @param backLengths segment length stored for each {@code pos * nStates + state}
     * @return the label of every position
     */
    private int[] traceback(int len, int[] backLengths) {
        int[] ret = new int[len];
        int pos = len - 1;
        int state = ColtUtil.maxInColumn(this.bestScore, this.nStates, len - 1);
        Assert.a(state != -2, "No valid paths");
        // An unreachable final state has segment length 0 and back-pointer -2, which would make the
        // loop below index invalid entries or stop making progress.
        Assert.a(state >= 0 && this.bestScore[pos * this.nStates + state] > Double.NEGATIVE_INFINITY, "No valid paths");
        while (pos >= 0) {
            int index = pos * this.nStates + state;
            int stateLen = backLengths[index];
            int prevState = this.backPointers[index];
            Assert.a(stateLen > 0, "No valid paths");
            for (int i = 0; i < stateLen; ++i) {
                ret[pos] = state;
                --pos;
            }
            state = prevState;
        }
        Assert.a(pos == -1);
        return ret;
    }

    /**
     * Pushes the running self-transition sums for the current position onto the front of
     * {@code stableStates}. The oldest array is recycled once more than {@link #maxLookback} are held.
     *
     * <p>For states allowing segments longer than 1 that have a self-transition, the new value is the
     * previous sum plus {@code nextMi[self]}. An infinite self-transition score is not added, so the
     * previous sum is carried over. Every other state is set to 0.
     *
     * @param stableStates newest-first history of running sums; must be non-empty
     * @param nextMi       transition scores for the current position
     */
    void updateStableBuffer(LinkedList<double[]> stableStates, double[] nextMi) {
        double[] stableState = stableStates.size() > this.maxLookback ? stableStates.removeLast() : new double[this.nStates];
        double[] prevState = stableStates.getFirst();
        for (int ix = 0; ix < this.nStates; ++ix) {
            int trans = this.maxStateLengths[ix] > 1 ? this.selfTransitions[ix] : -1;
            if (trans == -1) {
                // A recycled array still holds values from an older position (possibly the pos-0 Ri,
                // which may be infinite); reset so it matches a freshly allocated array.
                stableState[ix] = 0.0;
                continue;
            }
            double selfScore = nextMi[trans];
            stableState[ix] = Double.isInfinite(selfScore) ? prevState[ix] : prevState[ix] + selfScore;
        }
        stableStates.addFirst(stableState);
    }

    /**
     * Pushes {@code nextMi} onto the front of {@code mis} and returns an array for the next
     * position's scores. The oldest entry is recycled once more than {@link #maxLookback} are held.
     *
     * @param mis    newest-first history of transition-score arrays
     * @param nextMi transition scores just computed for the current position
     * @return a scratch array of length {@code transitions.nTransitions} (possibly a recycled one)
     */
    double[] updateMiBuffer(LinkedList<double[]> mis, double[] nextMi) {
        mis.addFirst(nextMi);
        nextMi = mis.size() > this.maxLookback ? mis.removeLast() : new double[this.transitions.nTransitions];
        return nextMi;
    }

    /**
     * Groups the transitions into each state, as listed in {@code transitions.orderedPotentials}.
     * In that list, an entry below {@code nStates} starts a new state's group. Larger entries are
     * transition indices offset by {@code nStates}.
     *
     * @param transitions transition table to read
     * @return per-state array of incoming transition indices; a state with no group gets an empty array
     */
    int[][] getStatePotentials(TransitionInfo transitions) {
        int[][] statePotentials = new int[this.nStates][];
        int currentState = -1;
        List<Integer> currentList = null;
        for (int n : transitions.orderedPotentials) {
            if (n < this.nStates) {
                if (currentState != -1) {
                    statePotentials[currentState] = this.toIntArray(currentList);
                }
                currentState = n;
                currentList = new ArrayList<Integer>();
                continue;
            }
            Assert.a(currentList != null, "Transition potential listed before any state in orderedPotentials.");
            currentList.add(n - this.nStates);
        }
        if (currentState != -1) {
            statePotentials[currentState] = this.toIntArray(currentList);
        }
        // States absent from the list have no incoming transitions; avoid a null array in predict.
        for (int s = 0; s < this.nStates; ++s) {
            if (statePotentials[s] == null) {
                statePotentials[s] = new int[0];
            }
        }
        return statePotentials;
    }

    /**
     * Unboxes {@code list} into a new primitive array.
     *
     * @param list values to copy; must not contain nulls
     * @return an array with the same elements in the same order
     */
    int[] toIntArray(List<Integer> list) {
        int[] ret = new int[list.size()];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = list.get(i);
        }
        return ret;
    }

    /**
     * Transition table that permits a state's self-transition only when the state allows
     * segments longer than 1.
     */
    class LengthTransitionInfo
    extends TransitionInfo {
        /** Per-state maximum segment lengths consulted by {@link #allowSelf}. */
        short[] localMaxStateLengths;

        /**
         * Stores the lengths and the self-transition flag, then builds the table via {@code initTrans}.
         *
         * @param fm                              model whose transitions are enumerated
         * @param allPaths                        flag forwarded to {@code initTrans}
         * @param maxStateLengths                 per-state maximum segment lengths (stored by reference)
         * @param ignoreSemiMarkovSelfTransitions value assigned to the inherited {@code ignoreSemiMarkovSelf}
         */
        LengthTransitionInfo(ModelManager fm, boolean allPaths, short[] maxStateLengths, boolean ignoreSemiMarkovSelfTransitions) {
            this.localMaxStateLengths = maxStateLengths;
            this.ignoreSemiMarkovSelf = ignoreSemiMarkovSelfTransitions;
            this.initTrans(fm, allPaths);
        }

        @Override
        protected boolean allowSelf(int state) {
            return this.localMaxStateLengths[state] > 1;
        }
    }
}

