/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.solver.semimarkov;

import calhoun.analysis.crf.CRFObjectiveFunctionGradient;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.solver.CacheProcessor;
import calhoun.analysis.crf.solver.LogFiles;
import calhoun.analysis.crf.solver.LookbackBuffer;
import calhoun.analysis.crf.solver.RecyclingBuffer;
import calhoun.analysis.crf.solver.semimarkov.AlphaLengthFeatureProcessor;
import calhoun.analysis.crf.solver.semimarkov.BetaLengthFeatureProcessor;
import calhoun.util.Assert;
import calhoun.util.ColtUtil;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Objective function and gradient for maximum-likelihood training of a semi-Markov CRF.
 *
 * <p>{@link #apply} runs, for every training sequence, an alpha pass
 * ({@link AlphaLengthFeatureProcessor#computeAlpha}) followed by a beta/expectation pass
 * ({@link BetaLengthFeatureProcessor#computeBetasAndExpectations}). It returns the training-path
 * score (feature sums dotted with the weights) minus the summed log partition functions, together
 * with its gradient (feature sums minus accumulated expectations). Both are divided by the total
 * number of training positions.
 *
 * <p>To avoid floating-point under/overflow, scaled quantities are kept as a pair
 * {@code (v, k)}, where the stored value {@code v} with norm {@code k} represents
 * {@code v * exp(NORM_FACTOR * k)} (see {@link #normalize} and {@link #renormalize}).
 *
 * <p>{@link #setCacheProcessor} must be called before {@link #setTrainingData}. Instances mutate
 * their fields on every call and use no synchronization.
 */
public class CleanMaximumLikelihoodSemiMarkovGradient
implements CRFObjectiveFunctionGradient {
    static final Log log = LogFactory.getLog(CleanMaximumLikelihoodSemiMarkovGradient.class);
    /** Debug flag captured once at class-load time; gates extra sanity checks in {@link #normalize}. */
    public static final boolean debug = log.isDebugEnabled();
    public static final double ASSERTION_TOLERANCE = 1.0E-4;
    /** Exponent step of the scaling scheme: one unit of "norm" equals a factor of exp(NORM_FACTOR). */
    public static final int NORM_FACTOR = 50;
    /** Lower bound of the range {@link #normalize} leaves unscaled. */
    public static final double NORM_MIN = Math.exp(-NORM_FACTOR);
    /** Upper bound of the range {@link #normalize} leaves unscaled. */
    public static final double NORM_MAX = Math.exp(NORM_FACTOR);
    final LogFiles logs = new LogFiles();
    CacheProcessor.SolverSetup modelInfo;
    CacheProcessor cacheProcessor;
    CacheProcessor.FeatureEvaluation[] evals;
    CacheProcessor.LengthFeatureEvaluation[][] lengthEvals;
    boolean[] invalidTransitions;
    short maxLookback;
    CacheProcessor.StatePotentials[] statesWithLookback;
    CacheProcessor.StatePotentials[] statesWithoutLookback;
    /** Number of completed {@link #apply} calls. */
    int iter = 0;
    /** Per-position alpha vectors ({@code longestSeq x nStates}), stored scaled by {@link #alphaNorms}. */
    double[][] alphas;
    /** Norm of each row of {@link #alphas}. */
    int[] alphaNorms;
    double[] starterAlpha;
    int nSemiMarkovStates;
    RecyclingBuffer<LookbackBuffer> lookbackBuffer;
    LookbackBuffer nextBuffer;
    /** Current weight vector; set at the start of each {@link #apply} call. */
    double[] lambda;
    /** Log partition function of the sequence most recently processed by {@link #alphaAndBetaPass}. */
    double logZ;
    /** Norm chosen for Z: Z == (1 / zInv) * exp(NORM_FACTOR * zNorm). */
    int zNorm;
    /** exp(NORM_FACTOR * zNorm - logZ), i.e. the reciprocal of Z expressed at norm {@link #zNorm}. */
    double zInv;
    /** Feature expectations, zeroed each {@link #apply} call (presumably filled by the beta pass — confirm). */
    double[] expects;
    AlphaLengthFeatureProcessor alphaProcessor;
    BetaLengthFeatureProcessor betaProcessor;
    /** Feature sums last returned by the cache processor; written in apply, not read in this class. */
    private double[] featureSums;

    /**
     * Binds this gradient to a training set and allocates all work buffers.
     *
     * <p>Requires {@link #setCacheProcessor} to have been called first.
     *
     * @param fm the model whose features are evaluated
     * @param data the training sequences
     */
    @Override
    public void setTrainingData(ModelManager fm, List<? extends TrainingSequence<?>> data) {
        this.cacheProcessor.setTrainingData(fm, data);
        this.modelInfo = this.cacheProcessor.getSolverSetup();
        Assert.a(this.modelInfo.maxStateLengths != null, "Maximum state lengths not set.");
        Assert.a(this.modelInfo.maxStateLengths.length == this.modelInfo.nStates,
                "Maximum state lengths array was length (" + this.modelInfo.maxStateLengths.length
                        + ").  Must have one entry for each state (" + this.modelInfo.nStates + ")");
        this.evals = this.cacheProcessor.getFeatureEvaluations();
        this.lengthEvals = this.cacheProcessor.getLengthFeatureEvaluations();
        this.invalidTransitions = this.cacheProcessor.getInvalidTransitions();
        this.maxLookback = this.modelInfo.maxLookback;
        this.statesWithLookback = this.modelInfo.statesWithLookback;
        this.statesWithoutLookback = this.modelInfo.statesWithoutLookback;
        this.nSemiMarkovStates = this.modelInfo.statesWithLookback.length;
        this.alphas = new double[this.modelInfo.longestSeq][this.modelInfo.nStates];
        this.alphaNorms = new int[this.modelInfo.longestSeq];
        this.expects = new double[this.modelInfo.nFeatures];
        // NOTE(review): maxLookback + 3 slots presumably gives slack beyond the lookback window
        // for the processors' recycling scheme — confirm against the alpha/beta processors.
        int bufferCount = this.maxLookback + 3;
        LookbackBuffer[] bufferContents = new LookbackBuffer[bufferCount];
        for (int i = 0; i < bufferCount; ++i) {
            bufferContents[i] = new LookbackBuffer(this.modelInfo.nStates, this.modelInfo.nTransitions);
        }
        this.lookbackBuffer = new RecyclingBuffer<LookbackBuffer>(bufferContents);
        this.nextBuffer = new LookbackBuffer(this.modelInfo.nStates, this.modelInfo.nTransitions);
        this.alphaProcessor = new AlphaLengthFeatureProcessor(this);
        this.betaProcessor = new BetaLengthFeatureProcessor(this);
        this.starterAlpha = new double[this.modelInfo.nStates];
    }

    /**
     * Evaluates the per-position log-likelihood and its gradient at {@code param}.
     *
     * @param param the feature weights
     * @param grad output array; overwritten with the gradient (divided by total positions)
     * @return (training-path score - sum of log Z over all sequences) / total positions
     */
    @Override
    public double apply(double[] param, double[] grad) {
        if (log.isDebugEnabled()) {
            log.debug(String.format("Beginning It: %d Weights: %s", this.iter, ColtUtil.format(param)));
        }
        this.logs.open();
        this.lambda = param;
        Arrays.fill(grad, 0.0);
        double totalZ = 0.0;
        double result = 0.0;
        try {
            Arrays.fill(this.expects, 0.0);
            for (int seq = 0; seq < this.modelInfo.nSeqs; ++seq) {
                int len = this.modelInfo.seqOffsets[seq + 1] - this.modelInfo.seqOffsets[seq];
                this.alphaAndBetaPass(seq, len);
                totalZ += this.logZ;
            }
            double[] sums = this.cacheProcessor.getFeatureSums();
            this.featureSums = sums;
            for (int j = 0; j < this.modelInfo.nFeatures; ++j) {
                result += sums[j] * param[j];
                grad[j] = sums[j] - this.expects[j];
            }
            if (log.isDebugEnabled()) {
                log.debug("Path Value: " + result + " Norm: " + totalZ);
            }
            result -= totalZ;
            if (log.isInfoEnabled()) {
                log.info(String.format("It: %d L=%e, LL=%f, norm(grad): %f Sums: %s Expects: %s Weights: %s Grad: %s",
                        this.iter, exp(result), result, ColtUtil.norm(grad), ColtUtil.format(sums),
                        ColtUtil.format(this.expects), ColtUtil.format(param), ColtUtil.format(grad)));
            }
            Assert.a(exp(result) <= 1.0, "Likelihood is greater than 1.");
            // Report a per-position value so the objective's scale does not depend on data size.
            double positions = (double) this.modelInfo.totalPositions;
            result /= positions;
            for (int i = 0; i < grad.length; ++i) {
                grad[i] /= positions;
            }
            ++this.iter;
        } finally {
            this.logs.close();
        }
        return result;
    }

    /** Does nothing; this implementation holds no resources that need releasing. */
    @Override
    public void clean() {
    }

    /**
     * Runs the alpha pass for one sequence, derives its log partition function, then runs the
     * beta/expectation pass.
     *
     * <p>Sets {@link #logZ}, {@link #zNorm} and {@link #zInv}. Assumes {@code len >= 1}
     * (a zero-length sequence would index {@code alphas[-1]}).
     *
     * @param i sequence index
     * @param len sequence length
     */
    void alphaAndBetaPass(int i, int len) {
        this.alphaProcessor.computeAlpha(i, len);
        double sum = 0.0;
        for (double val : this.alphas[len - 1]) {
            sum += val;
        }
        // Undo the scaling of the last alpha row to obtain log Z.
        this.logZ = log(sum) + (double) NORM_FACTOR * this.alphaNorms[len - 1];
        this.zNorm = (int) this.logZ / NORM_FACTOR;
        this.zInv = exp((double) this.zNorm * NORM_FACTOR - this.logZ);
        if (log.isDebugEnabled()) {
            log.debug("Seq: " + i + " Z: " + printNorm(1.0 / this.zInv, this.zNorm));
        }
        this.betaProcessor.computeBetasAndExpectations(i, len);
    }

    /** Logs (at INFO) the {@code pos} field of every lookback buffer slot, space-separated. */
    void logBuf() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.lookbackBuffer.length; ++i) {
            sb.append(this.lookbackBuffer.get(i).pos).append(' ');
        }
        log.info(sb.toString());
    }

    /** Logs (at INFO) the formatted {@code beta} array of every lookback buffer slot, space-separated. */
    void logBufBeta() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.lookbackBuffer.length; ++i) {
            sb.append(ColtUtil.format(this.lookbackBuffer.get(i).beta)).append(' ');
        }
        log.info(sb.toString());
    }

    /**
     * Fills {@code mi} with the log-space transition potentials at {@code miPos} and extends the
     * running self-transition sums.
     *
     * <p>For every state whose maximum length exceeds 1,
     * {@code newStable[s] = prevStable[s] + mi[selfTransition(s)]}, where an infinite self-transition
     * contributes nothing. Entries for other states are left untouched. Does nothing when
     * {@code miPos < 0}.
     */
    void cacheMi(int seqNum, double[] mi, double[] prevStable, double[] newStable, int miPos) {
        if (miPos < 0) {
            return;
        }
        this.calcMi(mi, seqNum, miPos, false);
        for (int i = 0; i < this.modelInfo.nStates; ++i) {
            if (this.modelInfo.maxStateLengths[i] <= 1) {
                continue;
            }
            newStable[i] = prevStable[i];
            double trans = mi[this.modelInfo.selfTransitions[i]];
            if (!Double.isInfinite(trans)) {
                newStable[i] += trans;
            }
        }
    }

    /**
     * Computes the transition potentials at one position into {@code mi}.
     *
     * <p>Each transition entry is its own weighted feature sum plus the most recently seen node
     * potential in {@code orderedPotentials}. This presumably relies on every node potential
     * preceding its transitions in that ordering — confirm. Invalid transitions and potentials
     * whose index list ends in {@code Short.MIN_VALUE} get negative infinity.
     *
     * @param doExp if true, store {@code exp(potential)} instead of the log-space value
     */
    void calcMi(double[] mi, int seq, int pos, boolean doExp) {
        this.cacheProcessor.evaluatePosition(seq, pos);
        double nodeVal = Double.NaN;
        int overallPosition = this.modelInfo.seqOffsets[seq] + pos;
        int invalidIndex = overallPosition * this.modelInfo.nPotentials;
        for (short potential : this.modelInfo.orderedPotentials) {
            boolean invalid = this.invalidTransitions[invalidIndex + potential];
            double features = invalid ? Double.NEGATIVE_INFINITY : 0.0;
            CacheProcessor.FeatureEvaluation potEvals = this.evals[potential];
            short[] indices = potEvals.index;
            float[] vals = potEvals.value;
            // Feature index lists are terminated by a negative sentinel.
            int i = 0;
            short index = indices[i];
            while (index >= 0) {
                features += (double) vals[i] * this.lambda[index];
                index = indices[++i];
            }
            // Short.MIN_VALUE as terminator marks the potential as disallowed at this position.
            if (index == Short.MIN_VALUE) {
                features = Double.NEGATIVE_INFINITY;
            }
            if (potential < this.modelInfo.nStates) {
                nodeVal = features;
                continue;
            }
            int transition = potential - this.modelInfo.nStates;
            double val = features + nodeVal;
            if (doExp) {
                val = exp(val);
            }
            mi[transition] = val;
        }
    }

    /**
     * Rescales {@code vec} in place from norm {@code currentNorm} to norm {@code newNorm},
     * multiplying by {@code exp(NORM_FACTOR * (currentNorm - newNorm))}. Zero entries are skipped,
     * so they stay exactly zero even if the factor is infinite.
     */
    static final void renormalize(double[] vec, int currentNorm, int newNorm) {
        // Computed in double so large norm differences cannot overflow int arithmetic.
        double factor = exp(NORM_FACTOR * ((double) currentNorm - newNorm));
        for (int i = 0; i < vec.length; ++i) {
            if (vec[i] != 0.0) {
                vec[i] *= factor;
            }
        }
    }

    /**
     * Rescales {@code vec} in place when its sum falls outside ({@link #NORM_MIN}, {@link #NORM_MAX}).
     *
     * @return the norm {@code k} such that the original vector equals the rescaled one times
     *     {@code exp(NORM_FACTOR * k)}; 0 when no rescaling was needed (including a zero sum)
     */
    static final int normalize(double[] vec) {
        double sum = 0.0;
        for (double val : vec) {
            sum += val;
        }
        if (sum == 0.0 || sum > NORM_MIN && sum < NORM_MAX) {
            return 0;
        }
        if (debug) {
            Assert.a(!Double.isNaN(sum));
        }
        int norm = (int) log(sum) / NORM_FACTOR;
        double scale = exp((double) NORM_FACTOR * norm);
        for (int i = 0; i < vec.length; ++i) {
            vec[i] /= scale;
        }
        return norm;
    }

    /** Delegates to {@link Math#exp}. */
    static final double exp(double val) {
        return Math.exp(val);
    }

    /** Delegates to {@link Math#log}. */
    static final double log(double val) {
        return Math.log(val);
    }

    /**
     * Logs (at DEBUG) the training-path score and likelihood of one sequence using the current
     * {@link #logZ}, and asserts that the likelihood does not exceed 1. Does nothing when the cache
     * processor exposes no per-sequence feature sums.
     */
    final void logFeatureSums(int seqNum) {
        double[][] seqFeatureSums = this.cacheProcessor.getSequenceFeatureSums();
        if (seqFeatureSums == null) {
            return;
        }
        double[] sums = seqFeatureSums[seqNum];
        double seqResult = 0.0;
        for (int j = 0; j < this.modelInfo.nFeatures; ++j) {
            seqResult += sums[j] * this.lambda[j];
        }
        double seqLogLikelihood = seqResult - this.logZ;
        if (log.isDebugEnabled()) {
            log.debug(String.format("Seq: %d L: %g LL: %f Training path: %f Z: %f",
                    seqNum, exp(seqLogLikelihood), seqLogLikelihood, seqResult, this.logZ));
        }
        // A likelihood of exactly 1 is legal (e.g. a single admissible path); matches apply().
        Assert.a(exp(seqLogLikelihood) <= 1.0);
    }

    /**
     * Formats the scaled value {@code value * exp(NORM_FACTOR * norm)} as
     * {@code "<mantissa>e<natural exponent>"}, or a marker string for zero/NaN inputs.
     */
    public static final String printNorm(double value, int norm) {
        if (value == 0.0) {
            return "0 (" + norm + ")";
        }
        if (Double.isNaN(value)) {
            return "NaN (" + norm + ")";
        }
        int exponent = (int) log(value);
        double eValue = value / exp(exponent);
        if (Double.isNaN(eValue)) {
            return String.format("NaN(%e n:%d)", value, norm);
        }
        return String.format("%fe%d", eValue, exponent + norm * NORM_FACTOR);
    }

    /** @return the cache processor supplying feature evaluations */
    public CacheProcessor getCacheProcessor() {
        return this.cacheProcessor;
    }

    /** Sets the cache processor; must be called before {@link #setTrainingData}. */
    public void setCacheProcessor(CacheProcessor cacheProcessor) {
        this.cacheProcessor = cacheProcessor;
    }

    public String getAlphaLengthFile() {
        return this.logs.alphaLengthFile;
    }

    public void setAlphaLengthFile(String alphaLengthFile) {
        this.logs.alphaLengthFile = alphaLengthFile;
    }

    public String getAlphaFile() {
        return this.logs.alphaFile;
    }

    public void setAlphaFile(String alphaFile) {
        this.logs.alphaFile = alphaFile;
    }

    public String getExpectFile() {
        return this.logs.expectFile;
    }

    public void setExpectFile(String expectFile) {
        this.logs.expectFile = expectFile;
    }

    public String getExpectLengthFile() {
        return this.logs.expectLengthFile;
    }

    public void setExpectLengthFile(String expectLengthFile) {
        this.logs.expectLengthFile = expectLengthFile;
    }

    public String getNodeMarginalFile() {
        return this.logs.nodeMarginalFile;
    }

    public void setNodeMarginalFile(String nodeMarginalFile) {
        this.logs.nodeMarginalFile = nodeMarginalFile;
    }

    public String getBetaLengthFile() {
        return this.logs.betaLengthFile;
    }

    public void setBetaLengthFile(String betaLengthFile) {
        this.logs.betaLengthFile = betaLengthFile;
    }
}

