/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.check;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.check.ArrayFeatureList;
import calhoun.analysis.crf.solver.check.FeatureCalculator;
import calhoun.analysis.crf.solver.check.TransitionInfo;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import calhoun.util.Util;
import cern.colt.function.DoubleFunction;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Blas;
import cern.colt.matrix.linalg.SeqBlas;
import cern.jet.math.Mult;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * CRF objective function that computes the log-likelihood of the training data and its gradient
 * using a forward-backward pass in which the alpha and beta vectors are rescaled to sum to one at
 * every position. The logarithms of the scale factors are tracked separately and added back when
 * the partition function and the marginals are formed.
 *
 * <p>Both the returned value and the gradient are divided by the total number of positions over
 * all training sequences.
 */
public class NormalizedCRFGradient implements CRFObjectiveFunctionGradient {
    private static final Log log = LogFactory.getLog(NormalizedCRFGradient.class);

    /** Snapshot of the logger's debug level, taken when this instance is created. */
    boolean debug = log.isDebugEnabled();
    /** Training sequences supplied through {@link #setTrainingData}. */
    List<? extends TrainingSequence<?>> data;
    /** Model supplied through {@link #setTrainingData}. */
    ModelManager fm;
    int nFeatures;
    int nStates;
    /** Not referenced inside this class. */
    double[] beta = null;
    /** Not referenced inside this class. */
    Blas blas = SeqBlas.seqBlas;
    /** Number of completed {@link #apply} calls; reported in the info-level log line. */
    int iter = 0;
    /** Scratch nStates x nStates matrix holding the exponentiated transition potentials. */
    DoubleMatrix2D mi;
    /** Not referenced inside this class. */
    DoubleMatrix1D ri;
    /** Not referenced inside this class. */
    DoubleMatrix1D temp;
    /** Rescaled forward vector of the previous position. */
    DoubleMatrix1D prevAlpha;
    /** Rescaled forward vector of the current position. */
    DoubleMatrix1D alpha;
    /** Rescaled backward vectors, one per position; grown on demand and reused across sequences. */
    DoubleMatrix1D[] betas;
    /** Accumulated log scale factor of the backward vector at each position. */
    double[] betaNorm;
    /** Expected feature counts of the sequence currently being processed. */
    DoubleMatrix1D expects;
    /** Reusable multiplier whose factor is overwritten by {@link #normalizePotential}. */
    Mult normalizer = Mult.mult(0.0);
    /** Passed to the {@code TransitionInfo} built at the start of every evaluation. */
    boolean allPaths = false;

    /**
     * Stores the model and the training sequences and allocates the work buffers whose size
     * depends on the number of features and states reported by the model. The backward buffers
     * start out empty and are grown by {@link #apply}.
     *
     * @param fm   model that supplies the feature and state counts
     * @param data training sequences to evaluate
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.fm = fm;
        this.data = data;
        nFeatures = fm.getNumFeatures();
        nStates = fm.getNumStates();
        expects = new DenseDoubleMatrix1D(nFeatures);
        mi = new DenseDoubleMatrix2D(nStates, nStates);
        prevAlpha = new DenseDoubleMatrix1D(nStates);
        alpha = new DenseDoubleMatrix1D(nStates);
        betas = new DenseDoubleMatrix1D[0];
        betaNorm = new double[0];
    }

    /**
     * Rescales {@code vec} in place by the reciprocal of the sum of its entries.
     *
     * @param vec vector to rescale; modified in place
     * @return the sum of the entries before rescaling
     */
    double normalizePotential(DoubleMatrix1D vec) {
        final double total = vec.zSum();
        normalizer.multiplicator = 1.0 / total;
        vec.assign(normalizer);
        return total;
    }

    /** Does nothing. */
    @Override
    public void clean() {
    }

    /**
     * Evaluates the objective and its gradient for the given weights.
     *
     * <p>For each sequence the method runs a rescaled backward pass, then a rescaled forward pass
     * that accumulates the expected feature counts, asserting at every position that the partition
     * function recomputed there matches the one obtained at position 0. The per-sequence gradient
     * is the difference between the feature sums reported by the calculator and the expected
     * counts.
     *
     * @param param feature weights handed to the feature calculator
     * @param grad  output array; zeroed on entry and filled with the gradient divided by the total
     *              number of positions
     * @return the summed log-likelihood divided by the total number of positions
     */
    @Override
    public double apply(double[] param, double[] grad) {
        final TransitionInfo transitions = new TransitionInfo(fm, allPaths);
        final FeatureCalculator calc = new FeatureCalculator(fm, param, transitions);
        Arrays.fill(grad, 0.0);

        int seqIx = 0;
        double logLikelihood = 0.0;
        final double[] totalFeatureSums = new double[nFeatures];
        final double[] totalExpects = new double[nFeatures];

        for (TrainingSequence<?> seq : data) {
            final int len = seq.length();
            ensureBetaCapacity(len);

            // Backward pass: start from all ones at the last position and walk towards position 0.
            betas[len - 1].assign(1.0);
            betaNorm[len - 1] = 0.0;
            for (int pos = len - 1; pos > 0; --pos) {
                mi.assign(0.0);
                calc.computeMi(seq, pos, mi, null);
                mi.assign(ColtUtil.exp);
                mi.zMult(betas[pos], betas[pos - 1], 1.0, 0.0, false);
                final double scale = normalizePotential(betas[pos - 1]);
                betaNorm[pos - 1] = betaNorm[pos] + Math.log(scale);
            }

            // The backward pass also went through the calculator, so its feature sums are
            // cleared before the forward pass starts.
            calc.resetFeatureSums();
            expects.assign(0.0);
            double logZ = Double.NEGATIVE_INFINITY;
            double alphaNorm = 0.0;
            double prevAlphaNorm = 0.0;

            for (int pos = 0; pos < len; ++pos) {
                if (pos > 0) {
                    // Forward recursion; the trailing 'true' asks zMult for the transposed matrix.
                    mi.assign(0.0);
                    calc.computeMi(seq, pos, mi, alpha);
                    mi.assign(ColtUtil.exp);
                    mi.zMult(prevAlpha, alpha, 1.0, 0.0, true);
                    alphaNorm = prevAlphaNorm + Math.log(normalizePotential(alpha));
                    // Consistency check: every position must reproduce the same log Z.
                    double newZ = Math.log(alpha.zDotProduct(betas[pos])) + betaNorm[pos] + alphaNorm;
                    Assert.a(Math.abs(newZ - logZ) < 1.0E-7 * Math.abs(logZ), "New Z:", newZ, " Old was: ", logZ);
                } else {
                    // First position: alpha comes from the node values alone and fixes log Z.
                    for (int state = 0; state < nStates; ++state) {
                        alpha.setQuick(state, calc.calcNodeValue(seq, pos, state));
                    }
                    alpha.assign(ColtUtil.exp);
                    alphaNorm = Math.log(normalizePotential(alpha));
                    logZ = Math.log(alpha.zDotProduct(betas[0])) + betaNorm[0] + alphaNorm;
                }

                // Undo the rescaling so that the products below are marginal probabilities.
                final ArrayFeatureList results = new ArrayFeatureList(fm);
                final double nodeNorm = Math.exp(alphaNorm + betaNorm[pos] - logZ);
                final double edgeNorm = Math.exp(prevAlphaNorm + betaNorm[pos] - logZ);
                for (int state = 0; state < nStates; ++state) {
                    results.evaluateNode(seq, pos, state);
                    final double nodeWeight = alpha.getQuick(state) * betas[pos].getQuick(state) * nodeNorm;
                    results.updateExpectations(expects, nodeWeight);
                    if (pos > 0) {
                        for (int prevState = 0; prevState < nStates; ++prevState) {
                            if (calc.isValidTransition(prevState, state)) {
                                results.evaluateEdge(seq, pos, prevState, state);
                                final double edgeWeight = prevAlpha.getQuick(prevState) * mi.getQuick(prevState, state) * betas[pos].getQuick(state) * edgeNorm;
                                results.updateExpectations(expects, edgeWeight);
                            }
                        }
                    }
                }

                if (isLoggedPosition(seqIx, pos, len)) {
                    log.debug(String.format("Pos: %d expects: %s alphas: %s betas: %s", pos, ColtUtil.format(expects.toArray()), ColtUtil.format(alpha.toArray()), ColtUtil.format(betas[pos].toArray())));
                }

                // The current alpha becomes the previous one; the old buffer is recycled.
                final DoubleMatrix1D recycled = prevAlpha;
                prevAlpha = alpha;
                alpha = recycled;
                prevAlphaNorm = alphaNorm;
            }

            final double seqLogLikelihood = calc.getWeightedFeatureSum();
            final double[] featureSums = calc.getFeatureSums();
            for (int f = 0; f < nFeatures; ++f) {
                final double expected = expects.getQuick(f);
                grad[f] += featureSums[f] - expected;
                totalFeatureSums[f] += featureSums[f];
                totalExpects[f] += expected;
            }
            logLikelihood += seqLogLikelihood - logZ;

            if (debug) {
                log.debug(String.format("Seq: %d L=%e, LL=%f, Feats=%s, Sum=%e (log=%f), Z=%e (log=%f) Expects=%s, Grad=%s", seqIx, Math.exp(seqLogLikelihood - logZ), seqLogLikelihood - logZ, StringUtils.join(Util.convertDoubleArray(featureSums).iterator(), ','), Math.exp(seqLogLikelihood), seqLogLikelihood, Math.exp(logZ), logZ, ColtUtil.format(expects.toArray()), ColtUtil.format(grad)));
            }
            ++seqIx;
        }

        // The summary line reports the values before they are divided by the position count.
        if (log.isInfoEnabled()) {
            log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s", iter, Math.exp(logLikelihood), logLikelihood, ColtUtil.norm(grad), ColtUtil.format(totalFeatureSums), ColtUtil.format(totalExpects), ColtUtil.format(param), ColtUtil.format(grad)));
        }
        ++iter;

        int totalPositions = 0;
        for (TrainingSequence<?> seq : data) {
            totalPositions += seq.length();
        }
        logLikelihood /= totalPositions;
        for (int f = 0; f < grad.length; ++f) {
            grad[f] /= totalPositions;
        }
        return logLikelihood;
    }

    /**
     * Replaces the backward buffers with freshly allocated ones of length {@code len} when the
     * current ones are shorter; otherwise leaves them untouched.
     */
    private void ensureBetaCapacity(int len) {
        if (betas.length >= len) {
            return;
        }
        betas = new DenseDoubleMatrix1D[len];
        betaNorm = new double[len];
        for (int k = 0; k < betas.length; ++k) {
            betas[k] = new DenseDoubleMatrix1D(nStates);
        }
    }

    /**
     * Tells whether the per-position debug line is written: debug must be on, the sequence must
     * be one of the first two or the last one, and the position must be one of the first two or
     * the last two of the sequence.
     */
    private boolean isLoggedPosition(int seqIx, int pos, int len) {
        return debug
                && (seqIx < 2 || seqIx == data.size() - 1)
                && (pos < 2 || pos >= len - 2);
    }

    /**
     * Always fails; this implementation does not work with cache processors.
     *
     * @param cacheProcessor ignored
     * @throws UnsupportedOperationException always
     */
    public void setCacheProcessor(CacheProcessor cacheProcessor) {
        throw new UnsupportedOperationException("This objective function predates cacheProcessors");
    }

    /** @return the flag handed to {@code TransitionInfo} on every evaluation */
    public boolean isAllPaths() {
        return allPaths;
    }

    /** @param allPaths flag handed to {@code TransitionInfo} on every subsequent evaluation */
    public void setAllPaths(boolean allPaths) {
        this.allPaths = allPaths;
    }
}

