/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.check;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.check.ArrayFeatureList;
import calhoun.analysis.crf.solver.check.FeatureCalculator;
import calhoun.analysis.crf.solver.check.TransitionInfo;
import calhoun.util.ColtUtil;
import calhoun.util.Util;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Blas;
import cern.colt.matrix.linalg.SeqBlas;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Unscaled forward-backward evaluation of a CRF objective and its gradient.
 *
 * <p>For each training sequence {@link #apply(double[], double[])} runs a backward pass that
 * fills {@link #betas}, then a forward pass that maintains {@link #alpha}/{@link #prevAlpha}
 * and accumulates un-normalised feature expectations in {@link #expects}. The per-sequence
 * objective is {@code weightedFeatureSum - log(z)} (logged as "LL"), where {@code z} is the
 * sum of the final forward vector. No rescaling is applied to the forward/backward vectors.
 *
 * <p>Instances hold mutable scratch buffers that are overwritten on every call, so a single
 * instance must not be used from several threads at once.
 */
public class BasicCRFGradient
implements CRFObjectiveFunctionGradient {
    private static final Log log = LogFactory.getLog(BasicCRFGradient.class);
    /** Debug-level logging flag, sampled once when the instance is created. */
    boolean debug = log.isDebugEnabled();
    /** Training sequences supplied through {@link #setTrainingData}. */
    List<? extends TrainingSequence<?>> data;
    /** Model supplied through {@link #setTrainingData}. */
    ModelManager fm;
    /** {@code fm.getNumFeatures()} at the time the training data was set. */
    int nFeatures;
    /** {@code fm.getNumStates()} at the time the training data was set. */
    int nStates;
    // beta, blas, ri and temp are not referenced inside this class; they are kept because
    // they are package-visible and may be read elsewhere.
    double[] beta = null;
    Blas blas = SeqBlas.seqBlas;
    /** Number of completed {@link #apply} calls; only used in the info log line. */
    int iter = 0;
    /** nStates x nStates scratch matrix holding the exponentiated transition matrix. */
    DoubleMatrix2D mi;
    DoubleMatrix1D ri;
    DoubleMatrix1D temp;
    /** Forward vector of the previous position (after the swap at the end of each position). */
    DoubleMatrix1D prevAlpha;
    /** Forward vector of the position currently being processed. */
    DoubleMatrix1D alpha;
    /** Backward vectors, one per position; grown on demand and reused across sequences. */
    DoubleMatrix1D[] betas;
    /** Un-normalised feature expectations of the sequence currently being processed. */
    DoubleMatrix1D expects;
    /** Passed to {@link TransitionInfo} on every {@link #apply} call. */
    boolean allPaths = false;

    /**
     * Stores the model and training data and allocates the scratch buffers sized by the
     * model's feature and state counts.
     *
     * @param fm   model manager providing the feature and state counts
     * @param data training sequences iterated by {@link #apply(double[], double[])}
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.fm = fm;
        this.data = data;
        this.nFeatures = fm.getNumFeatures();
        this.nStates = fm.getNumStates();
        this.expects = new DenseDoubleMatrix1D(this.nFeatures);
        this.mi = new DenseDoubleMatrix2D(this.nStates, this.nStates);
        this.prevAlpha = new DenseDoubleMatrix1D(this.nStates);
        this.alpha = new DenseDoubleMatrix1D(this.nStates);
        // Starts empty; apply() grows it to the longest sequence seen so far.
        this.betas = new DenseDoubleMatrix1D[0];
    }

    /** No resources to release. */
    @Override
    public void clean() {
    }

    /**
     * Evaluates the objective and its gradient for the given weights.
     *
     * <p>The returned value and every gradient entry are divided by the total number of
     * positions over all training sequences. Zero-length sequences are skipped; if there are
     * no positions at all, the value 0 and an all-zero gradient are returned instead of NaN.
     *
     * @param param feature weights handed to the {@link FeatureCalculator}
     * @param grad  output array; zeroed first, then filled with the gradient
     * @return the summed per-sequence objective divided by the total number of positions
     */
    @Override
    public double apply(double[] param, double[] grad) {
        TransitionInfo t = new TransitionInfo(this.fm, this.allPaths);
        FeatureCalculator calc = new FeatureCalculator(this.fm, param, t);
        Arrays.fill(grad, 0.0);
        int seqIx = 0;
        int totalPositions = 0;
        double result = 0.0;
        for (TrainingSequence<?> seq : this.data) {
            int len = seq.length();
            if (len <= 0) {
                // An empty sequence has no positions to score; indexing betas[len - 1]
                // below would otherwise fail.
                ++seqIx;
                continue;
            }
            totalPositions += len;
            this.expects.assign(0.0);
            this.computeBackwardVectors(calc, seq, len);
            // Discard whatever the backward pass accumulated in the calculator before the
            // forward pass, whose feature sums are read below.
            calc.resetFeatureSums();
            this.runForwardPass(calc, seq, len);
            // After the final swap prevAlpha holds the forward vector of the last position.
            double z = this.prevAlpha.zSum();
            double seqLogPotential = calc.getWeightedFeatureSum();
            double[] featureSums = calc.getFeatureSums();
            for (int f = 0; f < this.nFeatures; ++f) {
                grad[f] = grad[f] + (featureSums[f] - this.expects.getQuick(f) / z);
            }
            double seqLogLikelihood = seqLogPotential - Math.log(z);
            result += seqLogLikelihood;
            if (this.debug) {
                log.debug(String.format("Seq: %d L=%e, LL=%f, Feats=%s, Sum=%e (log=%f), Z=%e (log=%f) Expects=%s, Grad=%s", seqIx, Math.exp(seqLogLikelihood), seqLogLikelihood, StringUtils.join(Util.convertDoubleArray(featureSums).iterator(), ','), Math.exp(seqLogPotential), seqLogPotential, z, Math.log(z), ColtUtil.format(this.expects.toArray()), ColtUtil.format(grad)));
            }
            ++seqIx;
        }
        // The info line reports the values before normalisation by the position count.
        if (log.isInfoEnabled()) {
            log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Weights: %s Grad: %s", this.iter, Math.exp(result), result, ColtUtil.norm(grad), ColtUtil.format(param), ColtUtil.format(grad)));
        }
        ++this.iter;
        if (totalPositions > 0) {
            double norm = (double)totalPositions;
            result /= norm;
            for (int i = 0; i < grad.length; ++i) {
                grad[i] = grad[i] / norm;
            }
        }
        return result;
    }

    /**
     * Backward pass: fills {@code betas[0..len-1]} for one sequence, growing the buffer first
     * if it is shorter than the sequence.
     */
    private void computeBackwardVectors(FeatureCalculator calc, TrainingSequence<?> seq, int len) {
        if (this.betas.length < len) {
            this.betas = new DenseDoubleMatrix1D[len];
            for (int i = 0; i < this.betas.length; ++i) {
                this.betas[i] = new DenseDoubleMatrix1D(this.nStates);
            }
        }
        this.betas[len - 1].assign(1.0);
        for (int i = len - 1; i > 0; --i) {
            this.mi.assign(0.0);
            // NOTE(review): computeMi presumably writes log-potentials for the transition
            // into position i; confirm against FeatureCalculator.
            calc.computeMi(seq, i, this.mi, null);
            this.mi.assign(ColtUtil.exp);
            // betas[i - 1] = mi * betas[i]
            this.mi.zMult(this.betas[i], this.betas[i - 1], 1.0, 0.0, false);
        }
    }

    /**
     * Forward pass: advances the forward vector over one sequence and adds the un-normalised
     * node and edge expectations to {@link #expects}. On return {@link #prevAlpha} holds the
     * forward vector of the last position.
     */
    private void runForwardPass(FeatureCalculator calc, TrainingSequence<?> seq, int len) {
        for (int pos = 0; pos < len; ++pos) {
            if (pos == 0) {
                for (int state = 0; state < this.nStates; ++state) {
                    double val = calc.calcNodeValue(seq, pos, state);
                    this.alpha.setQuick(state, val);
                }
                this.alpha.assign(ColtUtil.exp);
            } else {
                this.mi.assign(0.0);
                calc.computeMi(seq, pos, this.mi, this.alpha);
                this.mi.assign(ColtUtil.exp);
                // alpha = mi^T * prevAlpha; the zero beta factor overwrites alpha's contents.
                this.mi.zMult(this.prevAlpha, this.alpha, 1.0, 0.0, true);
            }
            ArrayFeatureList results = new ArrayFeatureList(this.fm);
            for (int state = 0; state < this.nStates; ++state) {
                results.evaluateNode(seq, pos, state);
                double backward = this.betas[pos].getQuick(state);
                results.updateExpectations(this.expects, this.alpha.getQuick(state) * backward);
                if (pos > 0) {
                    for (int prevState = 0; prevState < this.nStates; ++prevState) {
                        if (calc.isValidTransition(prevState, state)) {
                            results.evaluateEdge(seq, pos, prevState, state);
                            double mult = this.prevAlpha.getQuick(prevState) * this.mi.getQuick(prevState, state) * this.betas[pos].getQuick(state);
                            results.updateExpectations(this.expects, mult);
                        }
                    }
                }
            }
            // Swap buffers so the vector just computed becomes prevAlpha for the next position.
            DoubleMatrix1D swap = this.prevAlpha;
            this.prevAlpha = this.alpha;
            this.alpha = swap;
        }
    }

    /** @return the flag passed to {@link TransitionInfo} on each {@link #apply} call */
    public boolean isAllPaths() {
        return this.allPaths;
    }

    /** @param allPaths flag passed to {@link TransitionInfo} on each {@link #apply} call */
    public void setAllPaths(boolean allPaths) {
        this.allPaths = allPaths;
    }
}

