/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.check;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.RecyclingBuffer;
import calhoun.analysis.crf.solver.check.FeatureCacheLength;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import calhoun.util.DenseIntMatrix2D;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class CachedSemiCRFGradient
implements CRFObjectiveFunctionGradient {
    private static final Log log = LogFactory.getLog(CachedSemiCRFGradient.class);
    // Sampled once per instance; gates the per-position debug output of the backward pass.
    boolean debug = log.isDebugEnabled();
    // Scale step for the "norm" exponents used throughout: a vector with norm k is stored divided
    // by exp(NORM_FACTOR * k). The arithmetic in this class spells this value as the literal 10.
    static final int NORM_FACTOR = 10;
    // normalize() leaves a vector unscaled while its sum lies strictly between these bounds.
    static final double NORM_MIN = Math.exp(-10.0);
    static final double NORM_MAX = Math.exp(10.0);
    // Flattened feature cache (copied from FeatureCacheLength in setTrainingData). Entry k adds
    // val[k] * lambda[id[k]] to potential potentialIx[k]; see calcMi.
    short[] id;
    byte[] potentialIx;
    float[] val;
    // Transition t runs from state transitionFrom[t] to state transitionTo[t]. Potential indices
    // below nStates are nodes; a potential p >= nStates is transition p - nStates.
    int nTransitions;
    short[] transitionFrom;
    short[] transitionTo;
    // Order in which calcMi, updateExpectations and regularAlphaUpdate visit all potentials.
    short[] orderedPotentials;
    // Indexed [overallPosition * nPotentials + potential]; true marks the potential invalid there.
    boolean[] invalidTransitions;
    // (from, to) -> transition index; used here to look up self-transitions.
    DenseIntMatrix2D transitionIndex;
    // Empirical feature totals: the observed term of the gradient in apply().
    double[] featureSums;
    // Cache entries for overall position p lie in [starts[p], starts[p + 1]); the entries before
    // starts[0] are the constant (position-independent) features.
    int[] starts;
    // Sequence i covers overall positions [seqOffsets[i], seqOffsets[i + 1]).
    int[] seqOffsets;
    int nSeqs;
    int nConstantFeatures;
    // Per-state maximum segment length; states with a value above 1 go through the length
    // (lookback) machinery instead of the ordinary per-position recursion.
    short[] maxStateLengths;
    // Largest lookback in the length cache; the lookback buffer holds maxLookback + 3 slots.
    short maxLookback;
    // Length-feature cache: entries for overall position p lie in
    // [lengthStarts[p], lengthStarts[p + 1]); each has a lookback, a potential, a feature index
    // (-1 for none) and a value. Walked by LengthFeatureProcessor.lengthCache.
    int[] lengthStarts;
    short[] lookbacks;
    byte[] lengthPotentials;
    short[] lengthIndexes;
    float[] lengthVals;
    // Per-state potential lists from the cache, presumably split by whether the state uses
    // length features -- confirm against FeatureCacheLength.
    CacheProcessor.StatePotentials[] statesWithLookback;
    CacheProcessor.StatePotentials[] statesWithoutLookback;
    List<? extends TrainingSequence<?>> data;
    ModelManager fm;
    int nFeatures;
    int nStates;
    int nPotentials;
    // Number of apply() calls so far (reported in the INFO log line).
    int iter = 0;
    // alphas[pos] is the forward vector of the current sequence, stored divided by
    // exp(10 * alphaNorms[pos]).
    double[][] alphas;
    int[] alphaNorms;
    // Reset to zero per sequence in computeAlpha; the backward pass uses starterAlpha[state] in
    // place of a transition score for length groups that have no incoming transition.
    double[] starterAlpha;
    // Recycling buffer of per-position LookbackBuffers, plus the spare buffer that is filled next
    // and then pushed with addFirst (whose return value becomes the new spare).
    RecyclingBuffer<LookbackBuffer> lookbackBuffer;
    LookbackBuffer nextBuffer;
    // Weights of the current apply() call.
    double[] lambda;
    // Transition scores from the constant features only; recomputed at the start of apply() and
    // added into every calcMi result.
    double[] constMi;
    // Log partition function of the current sequence, its truncated norm ((int) logZ / 10) and
    // exp(10 * zNorm - logZ), used to divide scaled products by Z.
    double logZ;
    int zNorm;
    double zInv;
    // Model feature expectations summed over all sequences in apply().
    double[] expects;
    AlphaLengthFeatureProcessor alphaProcessor;
    BetaLengthFeatureProcessor betaProcessor;
    // Forwarded to FeatureCacheLength when the cache is built.
    boolean allPaths;

    /**
     * Natural exponential; the instance-level exponentials in this class are taken through here.
     *
     * @param argVal exponent
     * @return {@code e^argVal}
     */
    double exp(double argVal) {
        final double value = Math.exp(argVal);
        return value;
    }

    /**
     * Natural logarithm; the instance-level logarithms in this class are taken through here.
     *
     * @param argVal argument
     * @return {@code ln(argVal)}
     */
    double log(double argVal) {
        final double value = Math.log(argVal);
        return value;
    }

    /**
     * Creates the gradient; all data-dependent state is built later in setTrainingData.
     *
     * @param maxStateLengths per-state maximum segment length; states with a value above 1 are
     *     handled through the length-feature machinery
     * @param allPaths forwarded to FeatureCacheLength when the cache is built
     */
    public CachedSemiCRFGradient(short[] maxStateLengths, boolean allPaths) {
        this.allPaths = allPaths;
        this.maxStateLengths = maxStateLengths;
    }

    /** Intentionally does nothing for this implementation. */
    @Override
    public void clean() {
        // nothing to do
    }

    /**
     * Builds the feature cache for {@code data} and allocates all per-call working storage.
     *
     * <p>The cache arrays are taken over from a {@link FeatureCacheLength} built with a
     * zero-filled short array of the same length as {@code maxStateLengths}. The alpha table is
     * sized by the longest sequence, and the lookback buffer holds {@code maxLookback + 3} slots
     * plus one spare.
     *
     * @param fm model providing the feature and state counts
     * @param data training sequences
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        // Model dimensions.
        this.fm = fm;
        this.data = data;
        this.nFeatures = fm.getNumFeatures();
        this.nStates = fm.getNumStates();
        this.nSeqs = data.size();

        short[] zeroLengths = new short[this.maxStateLengths.length];
        FeatureCacheLength c = new FeatureCacheLength(fm, data, this.allPaths, this.maxStateLengths, zeroLengths, true);

        // Per-position feature cache.
        this.id = c.id;
        this.potentialIx = c.potentialIx;
        this.val = c.val;
        this.orderedPotentials = c.orderedPotentials;
        this.starts = c.starts;
        this.seqOffsets = c.seqOffsets;
        this.featureSums = c.featureSums;
        this.nConstantFeatures = c.numConstantFeatures;
        this.nPotentials = c.nPotentials;
        this.invalidTransitions = c.invalidTransitions;

        // Transition structure.
        this.nTransitions = c.nTransitions;
        this.transitionFrom = c.transitionFrom;
        this.transitionTo = c.transitionTo;
        this.transitionIndex = c.transitionIndex;

        // Length-feature cache.
        this.maxLookback = c.maxLookback;
        this.lengthStarts = c.lengthStarts;
        this.lookbacks = c.lookbacks;
        this.lengthPotentials = c.lengthPotentials;
        this.lengthIndexes = c.lengthIndexes;
        this.lengthVals = c.lengthVals;
        this.statesWithLookback = c.statesWithLookback;
        this.statesWithoutLookback = c.statesWithoutLookback;

        // Working storage.
        int longest = c.longestSeq;
        this.alphas = new double[longest][this.nStates];
        this.alphaNorms = new int[longest];
        this.constMi = new double[this.nTransitions];
        this.expects = new double[this.nFeatures];

        int slots = this.maxLookback + 3;
        LookbackBuffer[] ring = new LookbackBuffer[slots];
        for (int k = 0; k < slots; k++) {
            ring[k] = new LookbackBuffer();
        }
        this.lookbackBuffer = new RecyclingBuffer<LookbackBuffer>(ring);
        this.nextBuffer = new LookbackBuffer();
        this.alphaProcessor = new AlphaLengthFeatureProcessor();
        this.betaProcessor = new BetaLengthFeatureProcessor();
        this.starterAlpha = new double[this.nStates];
    }

    /**
     * Evaluates the objective and its gradient at {@code param}.
     *
     * <p>The log-likelihood is {@code sum_j featureSums[j] * param[j] - sum_seq logZ(seq)} and the
     * gradient is {@code featureSums - expects}. Both are divided by the total number of positions
     * before returning; the INFO log line reports the values before that division.
     *
     * @param param current weights (kept as {@code lambda} for the duration of the call)
     * @param grad output array, overwritten with the per-position gradient
     * @return the per-position log-likelihood
     */
    @Override
    public double apply(double[] param, double[] grad) {
        this.lambda = param;
        Arrays.fill(grad, 0.0);
        double logLikelihood = 0.0;

        // Constant-feature transition scores, shared by every position of every sequence.
        Arrays.fill(this.constMi, 0.0);
        this.calcMi(this.constMi, -1, 0, this.starts[0], false);
        Arrays.fill(this.expects, 0.0);

        int offset = 0;
        for (int seq = 0; seq < this.nSeqs; seq++) {
            final int seqLen = this.seqOffsets[seq + 1] - this.seqOffsets[seq];
            this.alphaProcessor.computeAlpha(offset, seqLen);

            // Z = sum of the final (scaled) alpha row, rescaled by that row's norm.
            final double[] finalAlpha = this.alphas[seqLen - 1];
            double alphaTotal = 0.0;
            for (int s = 0; s < finalAlpha.length; s++) {
                alphaTotal += finalAlpha[s];
            }
            this.logZ = this.log(alphaTotal) + (double) (10 * this.alphaNorms[seqLen - 1]);
            this.zNorm = (int) this.logZ / 10;
            this.zInv = this.exp((double) (this.zNorm * 10) - this.logZ);

            this.betaProcessor.computeBetasAndExpectations(offset, seqLen);
            logLikelihood -= this.logZ;
            offset += seqLen;
        }

        for (int f = 0; f < this.nFeatures; f++) {
            logLikelihood += this.featureSums[f] * param[f];
            grad[f] = this.featureSums[f] - this.expects[f];
        }
        if (log.isInfoEnabled()) {
            log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s",
                    this.iter, this.exp(logLikelihood), logLikelihood, ColtUtil.norm(grad),
                    ColtUtil.format(this.featureSums), ColtUtil.format(this.expects),
                    ColtUtil.format(param), ColtUtil.format(grad)));
        }
        this.iter++;

        final double totalPositions = this.seqOffsets[this.seqOffsets.length - 1];
        logLikelihood /= totalPositions;
        for (int g = 0; g < grad.length; g++) {
            grad[g] /= totalPositions;
        }
        return logLikelihood;
    }

    /**
     * Logs, at INFO level and on one line, the {@code pos} of every slot of the lookback buffer
     * (each followed by a space).
     */
    void logBuf() {
        int slots = this.lookbackBuffer.length;
        // StringBuilder instead of repeated String concatenation inside the loop.
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < slots; ++i) {
            line.append(this.lookbackBuffer.get(i).pos).append(' ');
        }
        log.info(line.toString());
    }

    /**
     * Logs, at INFO level and on one line, the formatted {@code beta} vector of every slot of the
     * lookback buffer (each followed by a space).
     */
    void logBufBeta() {
        int slots = this.lookbackBuffer.length;
        // StringBuilder instead of repeated String concatenation inside the loop.
        StringBuilder line = new StringBuilder();
        for (int i = 0; i < slots; ++i) {
            line.append(ColtUtil.format(this.lookbackBuffer.get(i).beta)).append(' ');
        }
        log.info(line.toString());
    }

    /**
     * Computes the log-space Mi for one position and extends the stable-state sums.
     *
     * <p>Does nothing when {@code miPos} is negative. Otherwise fills {@code mi} via calcMi for
     * overall position {@code seqStart + miPos} from cache entries {@code [cacheStart, cacheStop)}.
     * For every state whose maximum length exceeds 1, it then sets
     * {@code newStable[state] = mi[self-transition] + prevStable[state]}.
     */
    void cacheMi(double[] mi, double[] prevStable, double[] newStable, int seqStart, int miPos, int cacheStart, int cacheStop) {
        if (miPos >= 0) {
            this.calcMi(mi, seqStart + miPos, cacheStart, cacheStop, false);
            for (int state = 0; state < this.nStates; state++) {
                if (this.maxStateLengths[state] > 1) {
                    int selfTransition = this.transitionIndex.getQuick(state, state);
                    newStable[state] = mi[selfTransition] + prevStable[state];
                }
            }
        }
    }

    /**
     * Computes the transition scores for one position into {@code mi}.
     *
     * <p>Walks {@code orderedPotentials} in order while consuming the cached features in
     * {@code [current, stop)}. Each cache entry contributes {@code val * lambda[id]} to the potential
     * it is tagged with. A node potential's total is remembered and added to every following
     * transition potential, together with {@code constMi} for that transition. This assumes the
     * ordering lists each node before its transitions; confirm against FeatureCacheLength.
     *
     * @param mi output array indexed by transition (potential - nStates)
     * @param overallPosition position across all sequences, or -1 for the constant features (which
     *     skips the invalid-transition lookup)
     * @param current first cache entry to consume
     * @param stop one past the last cache entry to consume
     * @param doExp if true, stores exp(score) instead of the log-space score
     */
    void calcMi(double[] mi, int overallPosition, int current, int stop, boolean doExp) {
        // One-entry lookahead into the cache: potential and weighted value of the next entry.
        byte cachedPotential = -1;
        double cachedVal = Double.NaN;
        if (current < stop) {
            cachedPotential = this.potentialIx[current];
            cachedVal = (double)this.val[current] * this.lambda[this.id[current]];
            ++current;
        }
        double nodeVal = Double.NaN;
        int invalidIndex = overallPosition * this.nPotentials;
        // orderedPotentials is a short[]; iterating into a byte variable is a narrowing
        // conversion that does not compile, so the loop variable is a short.
        for (short potential : this.orderedPotentials) {
            boolean invalid = overallPosition != -1 && this.invalidTransitions[invalidIndex + potential];
            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
            // Consume every consecutive cache entry tagged with this potential.
            while (cachedPotential == potential) {
                features += cachedVal;
                if (current >= stop) break;
                cachedVal = (double)this.val[current] * this.lambda[this.id[current]];
                cachedPotential = this.potentialIx[current];
                ++current;
            }
            if (potential < this.nStates) {
                // Node potential: carried into the transition potentials that follow it.
                nodeVal = features;
                continue;
            }
            int transition = potential - this.nStates;
            double score = features + nodeVal + this.constMi[transition];
            mi[transition] = doExp ? this.exp(score) : score;
        }
        Assert.a(current == stop, "Pos: ", overallPosition, " Expected ", stop, " features only found ", current, " Pot: ", cachedPotential, " Val: ", cachedVal);
    }

    /**
     * Rescales {@code vec} in place from scale exponent {@code currentNorm} to {@code newNorm}.
     *
     * <p>Every element is multiplied by {@code exp(10 * (currentNorm - newNorm))}.
     */
    void renormalize(double[] vec, int currentNorm, int newNorm) {
        final double scale = this.exp(10 * (currentNorm - newNorm));
        for (int k = 0; k < vec.length; k++) {
            vec[k] *= scale;
        }
    }

    /**
     * Rescales {@code vec} in place by a power of exp(10) when its sum leaves the
     * (NORM_MIN, NORM_MAX) range.
     *
     * <p>The vector is left untouched, and 0 returned, when its sum is exactly 0 or lies strictly
     * inside the range. Otherwise {@code norm = (int) ln(sum) / 10} is computed, every element is
     * divided by {@code exp(10 * norm)}, and {@code norm} is returned. A NaN sum fails the assertion.
     *
     * @return the scale exponent that was divided out
     */
    int normalize(double[] vec) {
        double total = 0.0;
        for (double v : vec) {
            total += v;
        }
        if (total == 0.0) {
            return 0;
        }
        Assert.a(!Double.isNaN(total));
        boolean inRange = total > NORM_MIN && total < NORM_MAX;
        if (inRange) {
            return 0;
        }
        int norm = (int) this.log(total) / 10;
        double divisor = this.exp(10 * norm);
        for (int k = 0; k < vec.length; k++) {
            vec[k] /= divisor;
        }
        return norm;
    }

    /**
     * Replays the length-feature cache for one position and reports it in groups.
     *
     * <p>Consecutive cache entries that share a lookback and a potential form one group. Their
     * weighted values ({@code lengthVals[k] * lambda[lengthIndexes[k]]}, skipping index -1) are
     * summed in float precision. A node group that is followed, at the same lookback, by a
     * transition group into the same node is not reported on its own. Instead, its entries become
     * the node range of that transition group, and its sum seeds the transition group's sum.
     * Subclasses receive every reported group through {@link #potential}.
     */
    abstract class LengthFeatureProcessor {
        LengthFeatureProcessor() {
        }

        /**
         * Called once per reported group.
         *
         * @param var1 lookback of the group
         * @param var2 state the group ends in
         * @param var3 transition index (potential - nStates), or -1 for a node-only group
         * @param var4 summed weighted value of the group (for a transition group, including the
         *     node prefix's sum)
         * @param var5 first cache entry of the node range, or -1 when there is no node prefix
         * @param var6 end (exclusive) of the node range, or -1 when there is no node prefix
         * @param var7 first cache entry of the transition range, or -1 for a node-only group
         * @param var8 end (exclusive) of the transition range, or -1 for a node-only group
         */
        abstract void potential(short var1, byte var2, int var3, float var4, int var5, int var6, int var7, int var8);

        /**
         * Walks the length-cache entries from {@code lengthStarts[overallPosition]} up to (but not
         * including) {@code lengthStarts[overallPosition + 1]} and reports each group through
         * {@link #potential}.
         *
         * @param overallPosition position index across all sequences
         */
        void lengthCache(int overallPosition) {
            int lengthCacheStart;
            // Key of the group being accumulated (-1 = none yet).
            short lastLookback = -1;
            byte lastPotential = -1;
            // Node the current group ends in, the node prefix's sum, and the running group sum.
            byte toNode = -1;
            float nodeValue = 0.0f;
            float potentialValue = 0.0f;
            int lengthCacheStop = CachedSemiCRFGradient.this.lengthStarts[overallPosition + 1];
            // Entry ranges of the node prefix and of the transition group (-1 = not set).
            int nodeStart = -1;
            int nodeStop = -1;
            int edgeStart = -1;
            for (lengthCacheStart = CachedSemiCRFGradient.this.lengthStarts[overallPosition]; lengthCacheStart < lengthCacheStop; ++lengthCacheStart) {
                short lookback = CachedSemiCRFGradient.this.lookbacks[lengthCacheStart];
                short index = CachedSemiCRFGradient.this.lengthIndexes[lengthCacheStart];
                byte lengthPotential = CachedSemiCRFGradient.this.lengthPotentials[lengthCacheStart];
                // A change of (lookback, potential) closes the previous group.
                if (lookback != lastLookback || lengthPotential != lastPotential) {
                    if (lengthPotential < CachedSemiCRFGradient.this.nStates) {
                        // New node group: flush whatever was open, then start fresh.
                        if (lastPotential != -1) {
                            if (lastPotential < CachedSemiCRFGradient.this.nStates) {
                                this.potential(lastLookback, lastPotential, -1, potentialValue, nodeStart, lengthCacheStart, -1, -1);
                            } else {
                                this.potential(lastLookback, toNode, lastPotential - CachedSemiCRFGradient.this.nStates, potentialValue, nodeStart, nodeStop, edgeStart, lengthCacheStart);
                            }
                        }
                        toNode = lengthPotential;
                        nodeValue = 0.0f;
                        potentialValue = 0.0f;
                        nodeStart = lengthCacheStart;
                        nodeStop = -1;
                    } else {
                        // New transition group into `node`.
                        byte node = (byte)CachedSemiCRFGradient.this.transitionTo[lengthPotential - CachedSemiCRFGradient.this.nStates];
                        if (lastPotential == -1) {
                            toNode = node;
                        } else if (lastPotential < CachedSemiCRFGradient.this.nStates) {
                            if (lookback == lastLookback && node == toNode) {
                                // The preceding node group becomes this transition's node prefix.
                                nodeStop = lengthCacheStart;
                                nodeValue = potentialValue;
                            } else {
                                // Unrelated node group: report it on its own.
                                this.potential(lastLookback, toNode, -1, potentialValue, nodeStart, lengthCacheStart, -1, -1);
                                nodeValue = 0.0f;
                                nodeStart = -1;
                                nodeStop = -1;
                                toNode = node;
                            }
                        } else {
                            // Previous group was a transition group: report it.
                            int lastTrans = lastPotential - CachedSemiCRFGradient.this.nStates;
                            byte lastNode = (byte)CachedSemiCRFGradient.this.transitionTo[lastPotential - CachedSemiCRFGradient.this.nStates];
                            this.potential(lastLookback, lastNode, lastTrans, potentialValue, nodeStart, nodeStop, edgeStart, lengthCacheStart);
                            // NOTE(review): lastNode is the target of the previous transition group,
                            // which toNode normally already equals, so this test seems to reduce to a
                            // lookback change. A new transition group at the same lookback into a
                            // different state would then keep the previous node prefix. Comparing
                            // against `node` may have been intended -- confirm against the cache
                            // layout before changing.
                            if (lastNode != toNode || lookback != lastLookback) {
                                nodeValue = 0.0f;
                                nodeStart = -1;
                                nodeStop = -1;
                                toNode = node;
                            }
                        }
                        edgeStart = lengthCacheStart;
                        // A transition group's sum starts from its node prefix's sum.
                        potentialValue = nodeValue;
                    }
                    lastLookback = lookback;
                    lastPotential = lengthPotential;
                }
                // Index -1 marks an entry that carries no weighted feature.
                if (index == -1) continue;
                potentialValue = (float)((double)potentialValue + (double)CachedSemiCRFGradient.this.lengthVals[lengthCacheStart] * CachedSemiCRFGradient.this.lambda[index]);
            }
            // Flush the last open group.
            if (lastPotential >= CachedSemiCRFGradient.this.nStates) {
                this.potential(lastLookback, toNode, lastPotential - CachedSemiCRFGradient.this.nStates, potentialValue, nodeStart, nodeStop, edgeStart, lengthCacheStart);
            } else if (lastPotential != -1) {
                this.potential(lastLookback, lastPotential, -1, potentialValue, nodeStart, lengthCacheStart, -1, -1);
            }
        }
    }

    /**
     * Backward pass of the gradient computation for one sequence.
     *
     * <p>Walks positions from the end of the sequence to the start and keeps scaled beta vectors
     * in the shared lookback buffer. It turns alphas and betas into node and edge marginals and
     * adds marginal-weighted feature values into {@code expects}. Length (lookback) features are
     * visited through {@code lengthCache}, which calls back into {@link #potential}.
     */
    class BetaLengthFeatureProcessor
    extends LengthFeatureProcessor {
        // Position (within the sequence) whose length cache is being replayed by potential().
        int lengthPos;
        // stableState of the lookback buffer at lengthPos.
        double[] lengthStable;
        // Next position whose Mi still has to be cached; counts down from len - 1.
        int miPos;
        // stableState of the most recently cached lookback buffer.
        double[] stableState;
        // Scaled beta at lengthPos and its scale exponent (true value = beta * exp(10 * betaNorm)).
        double[] beta;
        int betaNorm;
        // Lookback buffers of the current position and of the position after it.
        LookbackBuffer posLookback;
        LookbackBuffer prevLookback;
        // Node marginals being assembled for the position being finalized, and the buffer that is
        // swapped in as nodeProb for the next (earlier) position.
        double[] nodeProb;
        double[] newNodeProb;
        // Edge marginals written by regularBetaUpdate, indexed by transition.
        double[] edgeProb;

        BetaLengthFeatureProcessor() {
            this.nodeProb = new double[CachedSemiCRFGradient.this.nStates];
            this.newNodeProb = new double[CachedSemiCRFGradient.this.nStates];
            this.edgeProb = new double[CachedSemiCRFGradient.this.nTransitions];
        }

        /**
         * Runs the backward pass over one sequence and accumulates its expectations.
         *
         * <p>Requires alphas, alphaNorms, zNorm and zInv for this sequence to be current.
         *
         * @param seqStart overall index of the sequence's first position
         * @param len number of positions in the sequence
         */
        void computeBetasAndExpectations(int seqStart, int len) {
            // Mi is cached down to lastInitPos, which starts maxLookback + 1 positions below
            // len - 1 and is decremented together with pos.
            int lastInitPos = len - 2 - CachedSemiCRFGradient.this.maxLookback;
            // Cache entries are consumed backwards, so the stop starts at the end of the sequence.
            int cacheStop = CachedSemiCRFGradient.this.starts[seqStart + len];
            this.posLookback = CachedSemiCRFGradient.this.nextBuffer;
            this.prevLookback = null;
            Arrays.fill(this.nodeProb, 0.0);
            Arrays.fill(this.newNodeProb, 0.0);
            this.miPos = len - 1;
            CachedSemiCRFGradient.this.nextBuffer.clear();
            for (int pos = len - 1; pos >= 0; --pos) {
                // Fill buffers down to lastInitPos. The buffer labelled miPos holds the Mi of
                // position miPos + 1; the last position only gets a zeroed stableState.
                while (this.miPos >= 0 && this.miPos >= lastInitPos) {
                    if (this.miPos == len - 1) {
                        Arrays.fill(CachedSemiCRFGradient.this.nextBuffer.stableState, 0.0);
                    } else {
                        int cacheStart = CachedSemiCRFGradient.this.starts[seqStart + this.miPos + 1];
                        CachedSemiCRFGradient.this.nextBuffer.clear();
                        CachedSemiCRFGradient.this.cacheMi(CachedSemiCRFGradient.this.nextBuffer.mi, this.stableState, CachedSemiCRFGradient.this.nextBuffer.stableState, seqStart, this.miPos + 1, cacheStart, cacheStop);
                        cacheStop = cacheStart;
                    }
                    CachedSemiCRFGradient.this.nextBuffer.pos = this.miPos--;
                    this.stableState = CachedSemiCRFGradient.this.nextBuffer.stableState;
                    CachedSemiCRFGradient.this.nextBuffer = CachedSemiCRFGradient.this.lookbackBuffer.addFirst(CachedSemiCRFGradient.this.nextBuffer);
                    CachedSemiCRFGradient.this.nextBuffer.clear();
                }
                // Slot 0 is position miPos + 1, so pos lives at slot pos - miPos - 1.
                this.posLookback = CachedSemiCRFGradient.this.lookbackBuffer.get(pos - this.miPos - 1);
                Assert.a(this.posLookback.pos == pos, "Wrong lookback buffer: was ", this.posLookback.pos, " should be ", pos);
                if (this.prevLookback == null) {
                    // Last position: beta = 1 and the node marginals are alpha / Z.
                    Assert.a(pos == len - 1);
                    Arrays.fill(this.posLookback.beta, 1.0);
                    this.posLookback.betaNorm = 0;
                    double nodeNorm = CachedSemiCRFGradient.this.exp((CachedSemiCRFGradient.this.alphaNorms[pos] - CachedSemiCRFGradient.this.zNorm) * 10) * CachedSemiCRFGradient.this.zInv;
                    for (int i = 0; i < CachedSemiCRFGradient.this.nStates; ++i) {
                        this.nodeProb[i] = nodeNorm * CachedSemiCRFGradient.this.alphas[pos][i];
                    }
                } else {
                    // Propagate beta from pos + 1 into pos for states without lookback. This also
                    // sets their marginals at pos + 1 and the edge marginals into pos + 1.
                    Assert.a(this.prevLookback.pos == pos + 1);
                    this.posLookback.betaNorm = this.regularBetaUpdate(pos + 1, this.posLookback.beta, this.posLookback.betaNorm, this.prevLookback.beta, this.prevLookback.betaNorm, this.prevLookback.transitionProb, this.posLookback.mi);
                    // Replay the length features ending at pos + 1 (see potential()).
                    this.beta = this.prevLookback.beta;
                    this.betaNorm = this.prevLookback.betaNorm;
                    this.lengthStable = this.prevLookback.stableState;
                    this.lengthPos = pos + 1;
                    CachedSemiCRFGradient.this.betaProcessor.lengthCache(seqStart + this.lengthPos);
                    // For each state with lookback, the mass at pos + 1 not accounted for by the
                    // transitions in lb.potentials (recorded in posLookback.transitionProb) is
                    // stored as the self-transition probability. The same subtraction gives the
                    // starting marginals for the next position.
                    System.arraycopy(this.nodeProb, 0, this.newNodeProb, 0, CachedSemiCRFGradient.this.nStates);
                    for (CacheProcessor.StatePotentials lb : CachedSemiCRFGradient.this.statesWithLookback) {
                        byte state = lb.state;
                        int index = CachedSemiCRFGradient.this.transitionIndex.getQuick(state, state);
                        double transProb = this.nodeProb[state];
                        for (byte pot : lb.potentials) {
                            double lbTrans = this.posLookback.transitionProb[pot - CachedSemiCRFGradient.this.nStates];
                            transProb -= lbTrans;
                            byte by = state;
                            this.newNodeProb[by] = this.newNodeProb[by] - lbTrans;
                        }
                        Assert.a(this.posLookback.transitionProb[index] == 0.0);
                        this.posLookback.transitionProb[index] = transProb;
                    }
                }
                // Sanity check: node marginals of the position being finalized sum to 1.
                double sum = 0.0;
                for (double x : this.nodeProb) {
                    sum += x;
                }
                if (Math.abs(1.0 - sum) > 0.001) {
                    Assert.a(false, "Node marginals don't sum to 1: ", ColtUtil.format(this.nodeProb));
                }
                if (this.prevLookback != null) {
                    // Sanity check: length-transition plus ordinary edge marginals sum to 1.
                    sum = 0.0;
                    for (double x : this.posLookback.transitionProb) {
                        sum += x;
                    }
                    for (double x : this.edgeProb) {
                        sum += x;
                    }
                    if (Math.abs(1.0 - sum) > 0.001) {
                        Assert.a(false, "Edge marginals don't sum to 1: ", ColtUtil.format(this.edgeProb), ColtUtil.format(this.posLookback.transitionProb));
                    }
                    this.updateExpectations(seqStart, pos + 1, this.posLookback.transitionProb);
                }
                // newNodeProb becomes the working marginals for the next (earlier) position.
                double[] temp = this.nodeProb;
                this.nodeProb = this.newNodeProb;
                this.newNodeProb = temp;
                if (CachedSemiCRFGradient.this.debug && seqStart == 0 && (pos < 2 || pos >= len - 2)) {
                    log.debug((Object)String.format("Pos: %d expects: %s alphas: %s (norm %d) betas: %s (norm %d) MiPos: %d", pos, ColtUtil.format(CachedSemiCRFGradient.this.expects), ColtUtil.format(CachedSemiCRFGradient.this.alphas[pos]), CachedSemiCRFGradient.this.alphaNorms[pos], ColtUtil.format(this.posLookback.beta), this.posLookback.betaNorm, this.miPos + 1));
                }
                this.prevLookback = this.posLookback;
                --lastInitPos;
            }
            // Position 0: no incoming edges (newBeta is null). Compute the node marginals at 0,
            // replay the length features ending at 0 and add the expectations for position 0.
            this.posLookback.betaNorm = this.regularBetaUpdate(0, null, this.posLookback.betaNorm, this.posLookback.beta, this.posLookback.betaNorm, null, null);
            this.beta = this.posLookback.beta;
            this.betaNorm = this.posLookback.betaNorm;
            this.lengthStable = this.posLookback.stableState;
            this.lengthPos = 0;
            CachedSemiCRFGradient.this.betaProcessor.lengthCache(seqStart);
            this.updateExpectations(seqStart, 0, this.posLookback.transitionProb);
        }

        /**
         * One backward step, for states without lookback, from position {@code pos} to pos - 1.
         *
         * <p>Sets {@code nodeProb} at {@code pos} for those states. When {@code pos > 0} it also adds
         * {@code exp(Mi) * beta} into {@code newBeta} for each source state and fills
         * {@code edgeProb}. {@code newBeta} is first rescaled to oldNorm when
         * {@code |oldNorm| > |newNorm|}; otherwise the difference of norms is folded into each Mi
         * exponent.
         *
         * @param pos position (within the sequence) of {@code oldBeta}
         * @param newBeta beta being accumulated for pos - 1; null only in the final call at pos 0,
         *     where the norms are equal so it is never touched
         * @param newNorm current scale exponent of newBeta
         * @param oldBeta beta at pos
         * @param oldNorm scale exponent of oldBeta
         * @param transitionProb not read by this method
         * @param mi log-space transition scores into pos, indexed by transition (unused at pos 0)
         * @return scale exponent of newBeta after the update, including any normalize() shift
         */
        private int regularBetaUpdate(int pos, double[] newBeta, int newNorm, double[] oldBeta, int oldNorm, double[] transitionProb, double[] mi) {
            int norm = newNorm;
            double normAdjust = 0.0;
            if (Math.abs(oldNorm) > Math.abs(newNorm)) {
                CachedSemiCRFGradient.this.renormalize(newBeta, newNorm, oldNorm);
                norm = oldNorm;
                newNorm = oldNorm;
            } else {
                normAdjust = (oldNorm - newNorm) * 10;
            }
            // Factors that turn scaled alpha * beta products into probabilities (divide by Z).
            double[] nodeAlpha = CachedSemiCRFGradient.this.alphas[pos];
            double nodeNorm = CachedSemiCRFGradient.this.exp((CachedSemiCRFGradient.this.alphaNorms[pos] + oldNorm - CachedSemiCRFGradient.this.zNorm) * 10) * CachedSemiCRFGradient.this.zInv;
            double[] edgeAlpha = null;
            double edgeNorm = Double.NaN;
            if (pos > 0) {
                edgeAlpha = CachedSemiCRFGradient.this.alphas[pos - 1];
                edgeNorm = CachedSemiCRFGradient.this.exp((CachedSemiCRFGradient.this.alphaNorms[pos - 1] + newNorm - CachedSemiCRFGradient.this.zNorm) * 10) * CachedSemiCRFGradient.this.zInv;
            }
            for (CacheProcessor.StatePotentials potentials : CachedSemiCRFGradient.this.statesWithoutLookback) {
                byte node = potentials.state;
                // NOTE(review): nodePotential is accumulated below but never read.
                double nodePotential = 0.0;
                double betaVal = oldBeta[node];
                this.nodeProb[node] = nodeAlpha[node] * betaVal * nodeNorm;
                if (pos <= 0) continue;
                byte[] arr$ = potentials.potentials;
                int len$ = arr$.length;
                for (int i$ = 0; i$ < len$; ++i$) {
                    short potential = arr$[i$];
                    int trans = potential - CachedSemiCRFGradient.this.nStates;
                    short from = CachedSemiCRFGradient.this.transitionFrom[trans];
                    double potentialValue = CachedSemiCRFGradient.this.exp(mi[trans] + normAdjust);
                    nodePotential += potentialValue;
                    short s = from;
                    newBeta[s] = newBeta[s] + potentialValue * betaVal;
                    this.edgeProb[trans] = edgeAlpha[from] * potentialValue * betaVal * edgeNorm;
                }
            }
            int ret = norm;
            if (newBeta != null) {
                ret += CachedSemiCRFGradient.this.normalize(newBeta);
            }
            return ret;
        }

        /**
         * Accounts for one group of length features ending at {@code lengthPos} in state
         * {@code toNode}.
         *
         * <p>The group's probability is computed as
         * {@code alpha[prevPos][from] * exp(potentialValue + stable difference + Mi[trans]) * beta[toNode] / Z},
         * where {@code prevPos = lengthPos - lookback - 1}. The stable difference is the stableState
         * of position lengthPos - lookback minus lengthStable. When {@code trans == -1} there is no
         * predecessor: alpha is taken as 1 and {@code starterAlpha[toNode]} replaces the Mi term.
         *
         * <p>The probability is added to {@code nodeProb[toNode]} and, weighted by each entry's
         * value, to the expectations of the node range and the transition range. For a real
         * transition it is also added to the predecessor buffer's {@code transitionProb}, and
         * {@code exp(score) * beta[toNode]} is added into that buffer's beta for the source state.
         * The buffer is rescaled first if needed.
         */
        @Override
        void potential(short lookback, byte toNode, int trans, float potentialValue, int nodeStart, int nodeStop, int edgeStart, int edgeStop) {
            short index;
            int currentFeature;
            double transVal;
            // Buffer slots of the predecessor position (lengthPos - lookback - 1) and of the
            // position after it.
            int bufferPos = this.lengthPos - lookback - 1;
            int lbIndex = bufferPos - this.miPos - 1;
            LookbackBuffer buffer = trans == -1 ? null : CachedSemiCRFGradient.this.lookbackBuffer.get(lbIndex);
            LookbackBuffer stableBuffer = CachedSemiCRFGradient.this.lookbackBuffer.get(lbIndex + 1);
            int prevPos = this.lengthPos - lookback - 1;
            int fromNode = trans == -1 ? -1 : CachedSemiCRFGradient.this.transitionFrom[trans];
            double stableValue = stableBuffer.stableState[toNode] - this.lengthStable[toNode];
            double prevAlpha = 1.0;
            double prevAlphaNorm = 0.0;
            if (trans != -1) {
                Assert.a(this.prevLookback.pos == this.lengthPos, "Expected ", this.lengthPos, " was ", this.prevLookback.pos);
                Assert.a(buffer.pos == this.lengthPos - lookback - 1, "Expected ", this.lengthPos - lookback - 1, " was ", buffer.pos);
                transVal = buffer.mi[trans];
                prevAlpha = CachedSemiCRFGradient.this.alphas[prevPos][fromNode];
                prevAlphaNorm = CachedSemiCRFGradient.this.alphaNorms[prevPos];
            } else {
                transVal = CachedSemiCRFGradient.this.starterAlpha[toNode];
            }
            // Split the log score into a multiple of 10 (norm) and a remainder to keep exp() in range.
            double expVal = (double)potentialValue + stableValue + transVal;
            int norm = (int)expVal / 10;
            expVal -= (double)(norm * 10);
            // Note: norm is increased by betaNorm inside this expression and used afterwards.
            double prob = prevAlpha * this.beta[toNode] * CachedSemiCRFGradient.this.exp(expVal + 10.0 * (prevAlphaNorm + (double)(norm += this.betaNorm) - (double)CachedSemiCRFGradient.this.zNorm)) * CachedSemiCRFGradient.this.zInv;
            Assert.a(!Double.isNaN(prob));
            for (currentFeature = nodeStart; currentFeature < nodeStop; ++currentFeature) {
                index = CachedSemiCRFGradient.this.lengthIndexes[currentFeature];
                if (index == -1) continue;
                short s = index;
                CachedSemiCRFGradient.this.expects[s] = CachedSemiCRFGradient.this.expects[s] + prob * (double)CachedSemiCRFGradient.this.lengthVals[currentFeature];
            }
            for (currentFeature = edgeStart; currentFeature < edgeStop; ++currentFeature) {
                index = CachedSemiCRFGradient.this.lengthIndexes[currentFeature];
                if (index == -1) continue;
                short s = index;
                CachedSemiCRFGradient.this.expects[s] = CachedSemiCRFGradient.this.expects[s] + prob * (double)CachedSemiCRFGradient.this.lengthVals[currentFeature];
            }
            byte by = toNode;
            this.nodeProb[by] = this.nodeProb[by] + prob;
            if (trans != -1) {
                // Bring the predecessor's beta and this contribution to a common scale.
                if (Math.abs(norm) > Math.abs(buffer.betaNorm)) {
                    CachedSemiCRFGradient.this.renormalize(buffer.beta, buffer.betaNorm, norm);
                    buffer.betaNorm = norm;
                } else if (Math.abs(norm) < Math.abs(buffer.betaNorm)) {
                    expVal -= (double)(10 * (buffer.betaNorm - norm));
                }
                double transPotential = CachedSemiCRFGradient.this.exp(expVal);
                int n = trans;
                buffer.transitionProb[n] = buffer.transitionProb[n] + prob;
                double update = transPotential * this.beta[toNode];
                int n2 = fromNode;
                buffer.beta[n2] = buffer.beta[n2] + update;
            }
        }

        /**
         * Adds marginal-weighted feature values at position {@code pos} into {@code expects}.
         *
         * <p>Both the constant features and the position's own cache entries are walked in
         * {@code orderedPotentials} order. Node potentials use {@code nodeProb}. A transition
         * potential uses {@code transitionProb} when the node listed before it has maximum length
         * above 1, and {@code edgeProb} otherwise. Invalid potentials are skipped. At pos 0,
         * constant features on transition potentials are skipped.
         *
         * @param seqStart overall index of the sequence's first position
         * @param pos position within the sequence
         * @param transitionProb length-transition marginals to use for length states
         */
        void updateExpectations(int seqStart, int pos, double[] transitionProb) {
            int overallPosition = seqStart + pos;
            int posCurrent = CachedSemiCRFGradient.this.starts[overallPosition];
            int posStop = CachedSemiCRFGradient.this.starts[overallPosition + 1];
            // Lookahead cursors into the constant-feature and position-feature ranges (-1 = unread).
            int constCurrent = 0;
            int constId = -1;
            short constPotential = -1;
            double constVal = Double.NaN;
            int posId = -1;
            short posPotential = -1;
            double posVal = Double.NaN;
            boolean includeEdges = pos != 0;
            int invalidIndex = overallPosition * CachedSemiCRFGradient.this.nPotentials;
            boolean lengthNode = false;
            block0: for (short potential : CachedSemiCRFGradient.this.orderedPotentials) {
                boolean invalid = CachedSemiCRFGradient.this.invalidTransitions[invalidIndex + potential];
                double prob = Double.NaN;
                if (potential < CachedSemiCRFGradient.this.nStates) {
                    lengthNode = CachedSemiCRFGradient.this.maxStateLengths[potential] > 1;
                    prob = this.nodeProb[potential];
                } else {
                    prob = (lengthNode ? transitionProb : this.edgeProb)[potential - CachedSemiCRFGradient.this.nStates];
                }
                // Consume constant features tagged with this potential.
                while (constPotential == -1 || constPotential == potential) {
                    if (constPotential != -1 && !invalid && (includeEdges || potential < CachedSemiCRFGradient.this.nStates)) {
                        int n = constId;
                        CachedSemiCRFGradient.this.expects[n] = CachedSemiCRFGradient.this.expects[n] + prob * constVal;
                    }
                    if (constCurrent >= CachedSemiCRFGradient.this.nConstantFeatures) break;
                    constId = CachedSemiCRFGradient.this.id[constCurrent];
                    constVal = CachedSemiCRFGradient.this.val[constCurrent];
                    constPotential = CachedSemiCRFGradient.this.potentialIx[constCurrent];
                    ++constCurrent;
                }
                if (invalid) continue;
                // Consume this position's features tagged with this potential.
                while (posId == -1 || posPotential == potential) {
                    if (posId != -1) {
                        int n = posId;
                        CachedSemiCRFGradient.this.expects[n] = CachedSemiCRFGradient.this.expects[n] + prob * posVal;
                    }
                    if (posCurrent >= posStop) continue block0;
                    posId = CachedSemiCRFGradient.this.id[posCurrent];
                    posVal = CachedSemiCRFGradient.this.val[posCurrent];
                    posPotential = CachedSemiCRFGradient.this.potentialIx[posCurrent];
                    ++posCurrent;
                }
            }
            Assert.a(constCurrent == CachedSemiCRFGradient.this.nConstantFeatures);
            Assert.a(posCurrent == posStop);
        }
    }

    class AlphaLengthFeatureProcessor
    extends LengthFeatureProcessor {
        // Position (within the sequence) currently being processed by computeAlpha.
        int pos;
        // Alpha row for pos (aliases alphas[pos] while computeAlpha runs).
        double[] alpha;
        // Cumulative scale exponent of alpha; stored into alphaNorms[pos].
        int alphaNorm;
        // stableState of the lookback buffer most recently pushed by computeAlpha.
        double[] stableState;

        AlphaLengthFeatureProcessor() {
        }

        /**
         * Forward pass over one sequence: fills {@code alphas[0 .. len - 1]} and the matching
         * {@code alphaNorms}.
         *
         * <p>At position 0 it calls calcStartAlpha (defined elsewhere in this file). At later
         * positions it caches Mi into the spare lookback buffer and applies regularAlphaUpdate.
         * In both cases the buffer is then pushed into the lookback buffer, the position's length
         * cache is replayed through lengthCache (see potential()), and alpha is rescaled with its
         * cumulative norm recorded.
         *
         * @param seqStart overall index of the sequence's first position
         * @param len number of positions in the sequence
         */
        void computeAlpha(int seqStart, int len) {
            Arrays.fill(CachedSemiCRFGradient.this.alphaNorms, 0);
            Arrays.fill(CachedSemiCRFGradient.this.starterAlpha, 0.0);
            int cacheStart = CachedSemiCRFGradient.this.starts[seqStart];
            double[] prevAlpha = null;
            this.pos = 0;
            while (this.pos < len) {
                int overallPosition = seqStart + this.pos;
                int cacheStop = CachedSemiCRFGradient.this.starts[overallPosition + 1];
                // Previous position's row (stale and unused at position 0).
                prevAlpha = this.alpha;
                this.alpha = CachedSemiCRFGradient.this.alphas[this.pos];
                Arrays.fill(this.alpha, 0.0);
                if (this.pos == 0) {
                    this.alphaNorm = 0;
                    this.calcStartAlpha(this.alpha, overallPosition, cacheStart, cacheStop);
                    Arrays.fill(CachedSemiCRFGradient.this.nextBuffer.stableState, 0.0);
                } else {
                    CachedSemiCRFGradient.this.cacheMi(CachedSemiCRFGradient.this.nextBuffer.mi, this.stableState, CachedSemiCRFGradient.this.nextBuffer.stableState, seqStart, this.pos, cacheStart, cacheStop);
                    this.regularAlphaUpdate(this.pos, CachedSemiCRFGradient.this.nextBuffer.mi, prevAlpha, this.alpha);
                }
                this.stableState = CachedSemiCRFGradient.this.nextBuffer.stableState;
                // Push the filled buffer; addFirst returns the buffer to fill next.
                CachedSemiCRFGradient.this.nextBuffer = CachedSemiCRFGradient.this.lookbackBuffer.addFirst(CachedSemiCRFGradient.this.nextBuffer);
                CachedSemiCRFGradient.this.alphaProcessor.lengthCache(overallPosition);
                // Rescale and record the cumulative norm for this position.
                this.alphaNorm += CachedSemiCRFGradient.this.normalize(this.alpha);
                CachedSemiCRFGradient.this.alphaNorms[this.pos] = this.alphaNorm;
                cacheStart = cacheStop;
                ++this.pos;
            }
        }

        private void regularAlphaUpdate(int argPos, double[] mi, double[] lastAlpha, double[] newAlpha) {
            double nodeVal = 0.0;
            int lastState = -1;
            boolean lengthNode = false;
            for (int n : CachedSemiCRFGradient.this.orderedPotentials) {
                if (n < CachedSemiCRFGradient.this.nStates) {
                    if (lastState != -1) {
                        newAlpha[lastState] = nodeVal;
                    }
                    lastState = n;
                    nodeVal = 0.0;
                    lengthNode = CachedSemiCRFGradient.this.maxStateLengths[n] > 1;
                    continue;
                }
                if (lengthNode) continue;
                int trans = n - CachedSemiCRFGradient.this.nStates;
                short from = CachedSemiCRFGradient.this.transitionFrom[trans];
                nodeVal += lastAlpha[from] * CachedSemiCRFGradient.this.exp(mi[trans]);
            }
            newAlpha[lastState] = nodeVal;
        }

        @Override
        void potential(short lookback, byte toNode, int trans, float potentialValue, int nodeStart, int nodeStop, int edgeStart, int edgeStop) {
            int prevPos = this.pos - lookback - 1;
            LookbackBuffer buffer = CachedSemiCRFGradient.this.lookbackBuffer.get(lookback);
            double stableValue = this.stableState[toNode] - buffer.stableState[toNode];
            double expVal = (double)potentialValue + stableValue;
            double prevAlpha = 1.0;
            if (prevPos >= 0) {
                short fromNode = CachedSemiCRFGradient.this.transitionFrom[trans];
                expVal += buffer.mi[trans] + (double)(CachedSemiCRFGradient.this.alphaNorms[prevPos] * 10);
                prevAlpha = CachedSemiCRFGradient.this.alphas[prevPos][fromNode];
            } else {
                expVal += CachedSemiCRFGradient.this.starterAlpha[toNode];
            }
            int norm = (int)expVal / 10;
            expVal -= (double)(norm * 10);
            if (Math.abs(norm) > Math.abs(this.alphaNorm)) {
                CachedSemiCRFGradient.this.renormalize(this.alpha, this.alphaNorm, norm);
                this.alphaNorm = norm;
            } else if (Math.abs(norm) < Math.abs(this.alphaNorm)) {
                expVal += (double)(10 * (norm - this.alphaNorm));
            }
            double update = CachedSemiCRFGradient.this.exp(expVal) * prevAlpha;
            byte by = toNode;
            this.alpha[by] = this.alpha[by] + update;
        }

        void calcStartAlpha(double[] alpha1, int overallPosition, int posCurrent, int posStop) {
            int constCurrent = 0;
            short constPotential = -1;
            double constVal = Double.NaN;
            short posPotential = -1;
            double posVal = Double.NaN;
            int invalidIndex = overallPosition * CachedSemiCRFGradient.this.nPotentials;
            for (short potential : CachedSemiCRFGradient.this.orderedPotentials) {
                double features;
                boolean invalid = CachedSemiCRFGradient.this.invalidTransitions[invalidIndex + potential];
                double d = features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
                while (constPotential == -1 || constPotential == potential) {
                    if (constPotential != -1 && potential < CachedSemiCRFGradient.this.nStates) {
                        Assert.a(!Double.isNaN(features += constVal));
                    }
                    if (constCurrent >= CachedSemiCRFGradient.this.nConstantFeatures) break;
                    constVal = (double)CachedSemiCRFGradient.this.val[constCurrent] * CachedSemiCRFGradient.this.lambda[CachedSemiCRFGradient.this.id[constCurrent]];
                    constPotential = CachedSemiCRFGradient.this.potentialIx[constCurrent];
                    ++constCurrent;
                }
                if (potential >= CachedSemiCRFGradient.this.nStates) continue;
                while (posPotential == -1 || posPotential == potential) {
                    if (posPotential != -1) {
                        features += posVal;
                    }
                    if (posCurrent >= posStop) break;
                    posVal = (double)CachedSemiCRFGradient.this.val[posCurrent] * CachedSemiCRFGradient.this.lambda[CachedSemiCRFGradient.this.id[posCurrent]];
                    posPotential = CachedSemiCRFGradient.this.potentialIx[posCurrent];
                    ++posCurrent;
                }
                if (CachedSemiCRFGradient.this.maxStateLengths[potential] > 1) {
                    CachedSemiCRFGradient.this.starterAlpha[potential] = features;
                    continue;
                }
                alpha1[potential] = CachedSemiCRFGradient.this.exp(features);
            }
            Assert.a(constCurrent == CachedSemiCRFGradient.this.nConstantFeatures);
            Assert.a(posCurrent == posStop);
        }
    }

    /**
     * Per-position scratch storage recycled through {@code lookbackBuffer}.
     * {@code mi} and {@code stableState} are written during the alpha pass; {@code beta},
     * {@code betaNorm} and {@code transitionProb} are not touched by the alpha code visible here
     * (presumably used by the backward pass — verify).
     */
    class LookbackBuffer {
        /** Position this buffer describes; reset to {@code -1} by {@link #clear()}. */
        int pos;
        /** Per-potential values, consumed in log space by the alpha pass. */
        double[] mi;
        /** Per-state stable-state values. */
        double[] stableState;
        /** Per-state beta values. */
        double[] beta;
        /** Scaling exponent associated with {@link #beta}. */
        int betaNorm;
        /** Per-transition probabilities. */
        double[] transitionProb;

        LookbackBuffer() {
            final int states = CachedSemiCRFGradient.this.nStates;
            this.mi = new double[CachedSemiCRFGradient.this.nPotentials];
            this.stableState = new double[states];
            this.beta = new double[states];
            this.transitionProb = new double[CachedSemiCRFGradient.this.nTransitions];
        }

        /** Marks the buffer unused and zeroes the beta/transition data (mi and stableState are kept). */
        void clear() {
            pos = -1;
            betaNorm = 0;
            Arrays.fill(beta, 0.0);
            Arrays.fill(transitionProb, 0.0);
        }
    }
}

