/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.check;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.check.FeatureCache;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * CRF objective function and gradient evaluated from a precomputed {@link FeatureCache}.
 *
 * <p>{@link #apply(double[], double[])} runs a scaled forward-backward pass over every training
 * sequence. It returns {@code (sum_j featureSums[j] * param[j] - sum over sequences of log Z)}
 * divided by the total number of positions. It writes {@code featureSums - expects}, divided by
 * the same count, into the gradient.
 *
 * <p>The cache walks in this class merge the cached feature stream against
 * {@code orderedPotentials}. They therefore assume two things:
 * <ul>
 *   <li>cached features are sorted in the same potential order;</li>
 *   <li>every state potential ({@code n < nStates}) is immediately followed by the transition
 *       potentials that end in that state.</li>
 * </ul>
 * NOTE(review): these invariants are produced by {@link FeatureCache}, which is not visible here.
 *
 * <p>{@code apply} mutates scratch arrays held in instance fields, so one instance must not be
 * used from several threads at once. This implementation predates cache processors;
 * {@link #setCacheProcessor(CacheProcessor)} is unsupported.
 */
public class OldCachedCRFGradient
implements CRFObjectiveFunctionGradient {
    private static final Log log = LogFactory.getLog(OldCachedCRFGradient.class);
    /** Debug-logging flag, sampled once when the instance is created. */
    boolean debug = log.isDebugEnabled();
    /** Forwarded to the {@link FeatureCache} constructor. */
    boolean allPaths;
    /** Per cached feature: the feature id (index into {@code param}, {@code expects}). */
    short[] id;
    /** Per cached feature: the potential it belongs to. */
    byte[] potentialIx;
    /** Per cached feature: the feature value. */
    float[] val;
    /** Number of transitions, i.e. the length of {@code mi} and {@code constMi}. */
    int miLength;
    /** Log of the transition weights contributed by the constant features. */
    double[] constMi;
    /** Exponentiated transition weights for the position currently being processed. */
    double[] mi;
    /** Source state of each transition. */
    short[] transitionFrom;
    /** Target state of each transition (not read by this class). */
    short[] transitionTo;
    /** Order in which potentials are visited by every cache walk. */
    short[] orderedPotentials;
    /** Invalid flags, indexed {@code overallPosition * nPotentials + potential}. */
    boolean[] invalidTransitions;
    /** Total number of positions over all sequences; used to normalize result and gradient. */
    int totalPositions;
    /** Per-feature totals from the cache; the data term of the objective and gradient. */
    double[] featureSums;
    /**
     * Cache offsets per overall position. Position {@code p} uses entries
     * {@code [starts[p], starts[p + 1])}; entries before {@code starts[0]} are the constant
     * features.
     */
    int[] starts;
    /** Sequence boundaries in positions; sequence i has length {@code seqOffsets[i + 1] - seqOffsets[i]}. */
    int[] seqOffsets;
    int nSeqs;
    /** Number of constant-feature cache entries, read starting at index 0. */
    int nConstantFeatures;
    List<? extends TrainingSequence<?>> data;
    ModelManager fm;
    int nFeatures;
    int nStates;
    int nPotentials;
    /** Number of completed {@code apply} calls (used in log output). */
    int iter = 0;
    /** Scaled forward vector for the previous position. */
    double[] prevAlpha;
    /** Scaled forward vector for the current position. */
    double[] alpha;
    /** Normalized backward vector for each position of the current sequence. */
    double[][] betas;
    /** Accumulated log of the normalizers removed from {@code betas[pos]}. */
    double[] betaNorms;
    /** Accumulator for expected feature values. */
    double[] expects;

    double exp(double val1) {
        return Math.exp(val1);
    }

    double log(double val1) {
        return Math.log(val1);
    }

    /**
     * @param allPaths flag forwarded to the {@link FeatureCache} built in
     *     {@link #setTrainingData(ModelManager, List)}
     */
    public OldCachedCRFGradient(boolean allPaths) {
        this.allPaths = allPaths;
    }

    /**
     * Builds the feature cache for {@code data} and allocates all scratch buffers. Must be called
     * before {@link #apply(double[], double[])}, which reads the arrays initialized here.
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.fm = fm;
        this.data = data;
        this.nFeatures = fm.getNumFeatures();
        this.nStates = fm.getNumStates();
        this.nSeqs = data.size();
        this.expects = new double[this.nFeatures];
        this.prevAlpha = new double[this.nStates];
        this.alpha = new double[this.nStates];
        FeatureCache cache = new FeatureCache(fm, data, this.allPaths);
        this.orderedPotentials = cache.orderedPotentials;
        this.id = cache.id;
        this.potentialIx = cache.potentialIx;
        this.val = cache.val;
        this.transitionFrom = cache.transitionFrom;
        this.transitionTo = cache.transitionTo;
        this.miLength = cache.nTransitions;
        this.constMi = new double[this.miLength];
        this.mi = new double[this.miLength];
        this.featureSums = cache.featureSums;
        this.starts = cache.starts;
        this.seqOffsets = cache.seqOffsets;
        this.invalidTransitions = cache.invalidTransitions;
        this.nConstantFeatures = cache.numConstantFeatures;
        this.nPotentials = cache.nPotentials;
        this.totalPositions = cache.totalPositions;
        this.betas = new double[cache.longestSeq][this.nStates];
        this.betaNorms = new double[cache.longestSeq];
    }

    /** No resources to release. */
    @Override
    public void clean() {
    }

    /**
     * Evaluates the objective and its gradient at {@code param}.
     *
     * @param param feature weights, indexed by feature id
     * @param grad output array. It is zero-filled first; entry {@code j < nFeatures} then receives
     *     {@code featureSums[j] - expects[j]}, and every entry is divided by
     *     {@code totalPositions}
     * @return {@code (sum_j featureSums[j] * param[j] - sum over sequences of log Z)} divided by
     *     {@code totalPositions} (logged as "LL")
     */
    @Override
    public double apply(double[] param, double[] grad) {
        Arrays.fill(grad, 0.0);
        double result = 0.0;

        // Evaluate the position-independent (constant) features once. constMi must be zero here
        // because calcMi adds it in, so this call leaves exp(constant features) in mi. Its log is
        // then folded into every per-position calcMi call below.
        Arrays.fill(this.constMi, 0.0);
        this.calcMi(-1, 0, this.starts[0], param);
        for (int t = 0; t < this.miLength; ++t) {
            this.constMi[t] = this.log(this.mi[t]);
        }

        Arrays.fill(this.expects, 0.0);
        int seqStart = 0;
        for (int seq = 0; seq < this.nSeqs; ++seq) {
            int len = this.seqOffsets[seq + 1] - this.seqOffsets[seq];
            if (len == 0) {
                // An empty sequence has Z = 1 (log Z = 0) and contributes no expectations.
                // Without this guard, betas[len - 1] below would index -1.
                continue;
            }

            // Backward pass: betas[pos] is normalized to sum to 1 at every step, and betaNorms[pos]
            // accumulates the log of the normalizers that were divided out.
            Arrays.fill(this.betas[len - 1], 1.0);
            this.betaNorms[len - 1] = 0.0;
            int cacheStop = this.starts[seqStart + len];
            for (int pos = len - 1; pos > 0; --pos) {
                int overallPosition = seqStart + pos;
                int cacheStart = this.starts[overallPosition];
                this.calcMi(overallPosition, cacheStart, cacheStop, param);
                cacheStop = cacheStart;
                this.quickBetaUpdate(this.betas[pos], this.betas[pos - 1]);
                double norm = this.normalizePotential(this.betas[pos - 1]);
                this.betaNorms[pos - 1] = this.betaNorms[pos] + this.log(norm);
            }

            // Forward pass, scaled the same way. log Z is read off at position 0. Every later
            // position must reproduce it, which is checked as a numerical sanity test.
            double logZ = Double.NEGATIVE_INFINITY;
            double alphaNorm = 0.0;
            double prevAlphaNorm = 0.0;
            int cacheStart = this.starts[seqStart];
            for (int pos = 0; pos < len; ++pos) {
                int overallPosition = seqStart + pos;
                double[] beta = this.betas[pos];
                double betaNorm = this.betaNorms[pos];
                cacheStop = this.starts[overallPosition + 1];
                if (pos == 0) {
                    this.calcStartAlpha(overallPosition, cacheStart, cacheStop, param);
                    alphaNorm = this.log(this.normalizePotential(this.alpha));
                    logZ = this.log(ColtUtil.dotProduct(this.alpha, beta)) + betaNorm + alphaNorm;
                } else {
                    this.calcMi(overallPosition, cacheStart, cacheStop, param);
                    this.quickAlphaUpdate(this.prevAlpha, this.alpha);
                    alphaNorm = prevAlphaNorm + this.log(this.normalizePotential(this.alpha));
                    double newZ = this.log(ColtUtil.dotProduct(this.alpha, beta)) + betaNorm + alphaNorm;
                    // The tolerance is relative, but never tighter than 1e-7 absolute. A purely
                    // relative bound is 0 when logZ == 0 (e.g. a single valid path with zero
                    // weights), which made this assertion fail even for identical values.
                    Assert.a(Math.abs(newZ - logZ) <= 1.0E-7 * Math.max(1.0, Math.abs(logZ)), "New Z:", newZ, " Old was: ", logZ);
                }
                // Undo the scaling and divide by Z, so that alpha * beta * nodeNorm (and the
                // analogous edge product) are marginals.
                double nodeNorm = this.exp(alphaNorm + betaNorm - logZ);
                double edgeNorm = this.exp(prevAlphaNorm + betaNorm - logZ);
                this.updateExpectations(overallPosition, pos != 0, cacheStart, cacheStop, nodeNorm, edgeNorm, beta);
                boolean logThisSeq = seq < 2 || seq == this.nSeqs - 1;
                boolean logThisPos = pos < 2 || pos >= len - 2;
                if (this.debug && logThisSeq && logThisPos) {
                    log.debug(String.format("Pos: %d expects: %s alphas: %s (norm %f) betas: %s (norm %f)", pos, ColtUtil.format(this.expects), ColtUtil.format(this.alpha), alphaNorm, ColtUtil.format(beta), betaNorm));
                }
                // The current alpha becomes the previous one; reuse the old buffer.
                double[] swap = this.prevAlpha;
                this.prevAlpha = this.alpha;
                this.alpha = swap;
                prevAlphaNorm = alphaNorm;
                cacheStart = cacheStop;
            }
            result -= logZ;
            seqStart += len;
        }

        for (int j = 0; j < this.nFeatures; ++j) {
            result += this.featureSums[j] * param[j];
            grad[j] = this.featureSums[j] - this.expects[j];
        }
        if (log.isInfoEnabled()) {
            log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad (unnorm): %s", this.iter, this.exp(result / (double)this.totalPositions), result / (double)this.totalPositions, ColtUtil.norm(grad) / (double)this.totalPositions, ColtUtil.format(this.featureSums), ColtUtil.format(this.expects), ColtUtil.format(param), ColtUtil.format(grad)));
        }
        ++this.iter;
        result /= (double) this.totalPositions;
        for (int j = 0; j < grad.length; ++j) {
            grad[j] /= (double) this.totalPositions;
        }
        return result;
    }

    /**
     * Adds the expected feature values at one position to {@code expects}.
     *
     * <p>Each state potential s contributes {@code alpha[s] * beta[s] * nodeNorm}. Each transition
     * t contributes {@code prevAlpha[from(t)] * mi[t] * beta[s] * edgeNorm}, where s is the state
     * potential preceding t in {@code orderedPotentials}. Invalid potentials contribute nothing.
     *
     * @param overallPos position across all sequences (row of {@code invalidTransitions})
     * @param includeEdges whether constant features on transitions are counted. This is false at
     *     the first position of a sequence. Position features are not gated by it;
     *     {@code calcStartAlpha} already requires that none exist on transitions at that position.
     * @param posCurrent first cache entry of this position's features
     * @param posStop end (exclusive) of this position's features
     * @param nodeNorm scale factor for state marginals
     * @param edgeNorm scale factor for transition marginals
     * @param beta normalized backward vector for this position
     */
    void updateExpectations(int overallPos, boolean includeEdges, int posCurrent, int posStop, double nodeNorm, double edgeNorm, double[] beta) {
        // Cursor over the constant features, cache entries [0, nConstantFeatures).
        int constCurrent = 0;
        int constId = -1;
        int constPotential = -1;
        double constVal = Double.NaN;
        if (constCurrent < this.nConstantFeatures) {
            constId = this.id[constCurrent];
            constPotential = this.potentialIx[constCurrent];
            constVal = this.val[constCurrent];
            ++constCurrent;
        }
        // Cursor over this position's features, cache entries [posCurrent, posStop).
        int posId = -1;
        int posPotential = -1;
        double posVal = Double.NaN;
        if (posCurrent < posStop) {
            posId = this.id[posCurrent];
            posPotential = this.potentialIx[posCurrent];
            posVal = this.val[posCurrent];
            ++posCurrent;
        }
        double currentBeta = 0.0;
        int invalidIndex = overallPos * this.nPotentials;
        for (int n : this.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + n];
            double prob = 0.0;
            if (n < this.nStates) {
                // Remember this state's beta for the transitions listed after it.
                currentBeta = beta[n];
                if (!invalid) {
                    prob = this.alpha[n] * currentBeta * nodeNorm;
                }
            } else {
                int trans = n - this.nStates;
                short yprev = this.transitionFrom[trans];
                if (!invalid) {
                    prob = this.prevAlpha[yprev] * this.mi[trans] * currentBeta * edgeNorm;
                }
            }
            while (constPotential == n) {
                if (!invalid && (includeEdges || n < this.nStates)) {
                    this.expects[constId] += prob * constVal;
                }
                if (constCurrent >= this.nConstantFeatures) {
                    break;
                }
                constId = this.id[constCurrent];
                constVal = this.val[constCurrent];
                constPotential = this.potentialIx[constCurrent];
                ++constCurrent;
            }
            // Features of invalid potentials are still consumed (but not counted), as calcMi
            // does. Otherwise the cursor would stall and the final assertion would fail.
            while (posPotential == n) {
                if (!invalid) {
                    this.expects[posId] += prob * posVal;
                }
                if (posCurrent >= posStop) {
                    break;
                }
                posId = this.id[posCurrent];
                posVal = this.val[posCurrent];
                posPotential = this.potentialIx[posCurrent];
                ++posCurrent;
            }
        }
        Assert.a(constCurrent == this.nConstantFeatures);
        Assert.a(posCurrent == posStop);
    }

    /**
     * Fills {@code mi} with the exponentiated transition weights for one position.
     *
     * <p>For each transition t, the stored value is
     * {@code exp(features(t) + features(s) + constMi[t])}, where s is the state potential
     * preceding t in {@code orderedPotentials}. Invalid potentials start at -infinity, so their
     * entries become 0.
     *
     * @param overallPosition position across all sequences, or -1 to ignore the invalid flags
     *     (used for the constant-feature pass)
     * @param current first cache entry to consume
     * @param stop end (exclusive) of the cache entries to consume; all must be consumed
     * @param lambda feature weights
     */
    void calcMi(int overallPosition, int current, int stop, double[] lambda) {
        short cachedPotential = -1;
        double cachedVal = Double.NaN;
        if (current < stop) {
            cachedPotential = this.potentialIx[current];
            cachedVal = (double) this.val[current] * lambda[this.id[current]];
            ++current;
        }
        double nodeVal = Double.NaN;
        int invalidIndex = overallPosition * this.nPotentials;
        for (short potential : this.orderedPotentials) {
            boolean invalid = overallPosition != -1 && this.invalidTransitions[invalidIndex + potential];
            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
            while (cachedPotential == potential) {
                features += cachedVal;
                if (current >= stop) {
                    break;
                }
                cachedVal = (double) this.val[current] * lambda[this.id[current]];
                cachedPotential = this.potentialIx[current];
                ++current;
            }
            if (potential < this.nStates) {
                nodeVal = features;
            } else {
                int transition = potential - this.nStates;
                this.mi[transition] = this.exp(features + nodeVal + this.constMi[transition]);
            }
        }
        Assert.a(current == stop, "Pos: ", overallPosition, " Expected ", stop, " features only found ", current);
    }

    /**
     * Fills {@code alpha} (unnormalized) for the first position of a sequence.
     *
     * <p>Each entry is {@code alpha[s] = exp(constant state features + position state features)};
     * invalid states get 0. Constant features on transitions are consumed and ignored. Position
     * features on transitions are not consumed, so they must not exist here; the final assertion
     * checks this.
     *
     * @param overallPosition position across all sequences
     * @param posCurrent first cache entry of this position's features
     * @param posStop end (exclusive) of this position's features
     * @param lambda feature weights
     */
    void calcStartAlpha(int overallPosition, int posCurrent, int posStop, double[] lambda) {
        int constCurrent = 0;
        short constPotential = -1;
        double constVal = Double.NaN;
        if (constCurrent < this.nConstantFeatures) {
            constPotential = this.potentialIx[constCurrent];
            constVal = (double) this.val[constCurrent] * lambda[this.id[constCurrent]];
            ++constCurrent;
        }
        short posPotential = -1;
        double posVal = Double.NaN;
        if (posCurrent < posStop) {
            posPotential = this.potentialIx[posCurrent];
            posVal = (double) this.val[posCurrent] * lambda[this.id[posCurrent]];
            ++posCurrent;
        }
        int invalidIndex = overallPosition * this.nPotentials;
        for (short potential : this.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + potential];
            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
            while (constPotential == potential) {
                if (potential < this.nStates) {
                    // Keep the side effect out of the assertion call.
                    features += constVal;
                    Assert.a(!Double.isNaN(features));
                }
                if (constCurrent >= this.nConstantFeatures) {
                    break;
                }
                constVal = (double) this.val[constCurrent] * lambda[this.id[constCurrent]];
                constPotential = this.potentialIx[constCurrent];
                ++constCurrent;
            }
            if (potential >= this.nStates) {
                continue;
            }
            while (posPotential == potential) {
                features += posVal;
                if (posCurrent >= posStop) {
                    break;
                }
                posVal = (double) this.val[posCurrent] * lambda[this.id[posCurrent]];
                posPotential = this.potentialIx[posCurrent];
                ++posCurrent;
            }
            this.alpha[potential] = this.exp(features);
        }
        Assert.a(constCurrent == this.nConstantFeatures);
        Assert.a(posCurrent == posStop);
    }

    /**
     * Scales {@code vec} in place so that its entries sum to 1.
     *
     * @return the original sum (a zero sum leaves NaN/infinite entries)
     */
    private double normalizePotential(double[] vec) {
        double norm = 0.0;
        for (double v : vec) {
            norm += v;
        }
        double mult = 1.0 / norm;
        for (int i = 0; i < vec.length; ++i) {
            vec[i] *= mult;
        }
        return norm;
    }

    /**
     * One backward step: {@code newBeta[from(t)] += mi[t] * lastBeta[s]} for each transition t,
     * where s is the state potential preceding t in {@code orderedPotentials}.
     */
    private void quickBetaUpdate(double[] lastBeta, double[] newBeta) {
        Arrays.fill(newBeta, 0.0);
        double nodeVal = 0.0;
        for (short potential : this.orderedPotentials) {
            if (potential < this.nStates) {
                nodeVal = lastBeta[potential];
            } else {
                int trans = potential - this.nStates;
                newBeta[this.transitionFrom[trans]] += this.mi[trans] * nodeVal;
            }
        }
    }

    /**
     * One forward step: {@code newAlpha[s]} becomes the sum of {@code lastAlpha[from(t)] * mi[t]}
     * over the transitions t listed after state s in {@code orderedPotentials}.
     */
    private void quickAlphaUpdate(double[] lastAlpha, double[] newAlpha) {
        double nodeVal = 0.0;
        int lastState = -1;
        for (int n : this.orderedPotentials) {
            if (n < this.nStates) {
                // Flush the sum accumulated for the previous state before starting a new one.
                if (lastState != -1) {
                    newAlpha[lastState] = nodeVal;
                }
                lastState = n;
                nodeVal = 0.0;
                continue;
            }
            int trans = n - this.nStates;
            short from = this.transitionFrom[trans];
            nodeVal += lastAlpha[from] * this.mi[trans];
        }
        if (lastState != -1) {
            newAlpha[lastState] = nodeVal;
        }
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    public void setCacheProcessor(CacheProcessor cacheProcessor) {
        throw new UnsupportedOperationException("This objective function predates cacheProcessors");
    }
}

