/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.test;

import calhoun.analysis.crf.AbstractFeatureManager;
import calhoun.analysis.crf.Conrad;
import calhoun.analysis.crf.LocalPathSimilarityScore;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.SemiMarkovSetup;
import calhoun.analysis.crf.io.IntInput;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.analysis.crf.scoring.SimScoreMaxStateAgreement;
import calhoun.analysis.crf.solver.CacheProcessorDeluxe;
import calhoun.analysis.crf.solver.MaximumExpectedAccuracySemiMarkovGradient;
import calhoun.analysis.crf.solver.StandardOptimizer;
import calhoun.analysis.crf.solver.check.CachedAOFGradient;
import calhoun.analysis.crf.test.TestFeatureManager;
import calhoun.analysis.crf.test.TestFeatureManager2;
import calhoun.analysis.crf.test.TestFeatureManager3;
import calhoun.util.AbstractTestCase;
import calhoun.util.ColtUtil;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Tests for the maximum-expected-accuracy (AOF) objective.
 *
 * <p>The tests compare weights trained with equivalent regular and semi-Markov
 * Conrad configurations. They also check that semi-Markov AOF gradient
 * evaluations agree across segment-length limits and feature-manager variants,
 * and they compare AOF values and gradients against hand-computed numbers.
 */
public class MaximumExpectedAccuracyTest extends AbstractTestCase {
    private static final Log log = LogFactory.getLog(MaximumExpectedAccuracyTest.class);

    /**
     * Trains on the small aofreg test sequence with a regular and a semi-Markov
     * configuration. Checks that both produce the same weights (tolerance 0.001).
     */
    public void testLittle() throws Exception {
        String input = "test/input/aofreg_test/testSeq.txt";
        String regularConfig = "test/input/aofreg_test/delta_conservation_aofreg.xml";
        String semiMarkovConfig = "test/input/aofreg_test/delta_conservation_aofreg_semi.xml";
        this.doSameWeightsTest(regularConfig, semiMarkovConfig, null, input);
    }

    /**
     * Runs the semi-Markov AOF gradient consistency check on a range of small
     * integer-encoded sequences.
     *
     * <p>NOTE(review): doFuncEvalsSkewed only does work for model numbers 1 and 3.
     * The calls below with model number 0 or 2 are therefore currently no-ops.
     */
    public void testGradEvals() throws Exception {
        this.doFuncEvalsSkewed(0, IntInput.prepareData("0\n0"));
        this.doFuncEvalsSkewed(0, IntInput.prepareData("00\n00"));
        this.doFuncEvalsSkewed(0, IntInput.prepareData("000\n000"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("1\n1"));
        this.doFuncEvalsSkewed(3, IntInput.prepareData("11\n11"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("10\n10"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("111\n111"));
        this.doFuncEvalsSkewed(3, IntInput.prepareData("101\n101"));
        this.doFuncEvalsSkewed(3, IntInput.prepareData("011\n011"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("1111\n1111"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("001111\n001111"));
        this.doFuncEvalsSkewed(2, IntInput.prepareData("00\n00"));
        this.doFuncEvalsSkewed(2, IntInput.prepareData("00\n00\n00\n00"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("00\n00\n0\n0\n"));
        this.doFuncEvalsSkewed(1, IntInput.prepareData("0\n0\n1111\n1111\n"));
        this.doFuncEvalsSkewed(2, IntInput.prepareData("00001010100100111000\n00001010100100111000\n00001010100100111001\n00001010100100111001\n"));
    }

    /**
     * Runs {@link #doFuncEvals} with skewed weights and TestFeatureManager3.
     * Only model numbers 1 and 3 are evaluated; any other value returns immediately.
     */
    void doFuncEvalsSkewed(int mmNum, List<? extends TrainingSequence<?>> data) throws Exception {
        if (mmNum == 1 || mmNum == 3) {
            this.doFuncEvals(mmNum, true, true, data);
        }
    }

    /**
     * Evaluates the semi-Markov AOF objective and gradient three times at the same
     * weights and asserts that all three agree (tolerance 0.001).
     *
     * <ol>
     *   <li>max segment length 1, feature-manager flag {@code false} (baseline);</li>
     *   <li>max segment length 20, flag {@code false};</li>
     *   <li>max segment length 20, flag {@code true}.</li>
     * </ol>
     *
     * @param mmNum model number passed to the test feature manager
     * @param skewedWeights use weights {2.0, 0.5, 1.0} if true, all ones otherwise
     * @param fm3 use TestFeatureManager3 if true, TestFeatureManager2 otherwise
     * @param data training sequences
     */
    void doFuncEvals(int mmNum, boolean skewedWeights, boolean fm3, List<? extends TrainingSequence<?>> data) throws Exception {
        double[] weights = skewedWeights ? new double[]{2.0, 0.5, 1.0} : new double[]{1.0, 1.0, 1.0};

        GradientEvaluation baseline = this.evaluateSemiMarkovGradient(mmNum, fm3, false, (short) 1, weights, data);
        log.info("Grad1: " + ColtUtil.format(baseline.gradient));

        GradientEvaluation longSegments = this.evaluateSemiMarkovGradient(mmNum, fm3, false, (short) 20, weights, data);
        log.info("Grad2: " + ColtUtil.format(longSegments.gradient));
        assertEquals("objective value with max segment length 20", baseline.value, longSegments.value, 0.001);
        this.assertArrayEquals(baseline.gradient, longSegments.gradient, 0.001);

        GradientEvaluation flagged = this.evaluateSemiMarkovGradient(mmNum, fm3, true, (short) 20, weights, data);
        log.info("Grad3: " + ColtUtil.format(flagged.gradient));
        assertEquals("objective value with feature-manager flag set", baseline.value, flagged.value, 0.001);
        this.assertArrayEquals(baseline.gradient, flagged.gradient, 0.001);
    }

    /**
     * Builds and trains a fresh test feature manager and semi-Markov AOF gradient,
     * then evaluates the objective and gradient at {@code weights}.
     *
     * @param mmNum model number passed to the feature manager constructor
     * @param fm3 use TestFeatureManager3 if true, TestFeatureManager2 otherwise
     * @param managerFlag second constructor argument of the feature manager
     *        (its meaning is defined in TestFeatureManager2/3)
     * @param maxLength value stored in both entries of the max-length array given
     *        to SemiMarkovSetup
     * @param weights feature weights at which to evaluate
     * @param data training sequences
     * @return the objective value and a gradient array sized to the manager's feature count
     */
    private GradientEvaluation evaluateSemiMarkovGradient(int mmNum, boolean fm3, boolean managerFlag, short maxLength,
            double[] weights, List<? extends TrainingSequence<?>> data) throws Exception {
        AbstractFeatureManager manager = fm3
                ? new TestFeatureManager3(mmNum, managerFlag)
                : new TestFeatureManager2(mmNum, managerFlag);
        // Double cast kept from the decompiled original: the static types
        // AbstractFeatureManager and ModelManager may not be directly castable.
        ModelManager model = (ModelManager) (Object) manager;
        manager.train(0, model, data);

        short[] max = new short[2];
        Arrays.fill(max, maxLength);

        MaximumExpectedAccuracySemiMarkovGradient gradFunc = new MaximumExpectedAccuracySemiMarkovGradient();
        CacheProcessorDeluxe cacheProcessor = new CacheProcessorDeluxe();
        cacheProcessor.setAllPaths(false);
        cacheProcessor.setSemiMarkovSetup(new SemiMarkovSetup(max, true));
        gradFunc.setCacheProcessor(cacheProcessor);
        gradFunc.setTrainingData(model, data);

        double[] gradient = new double[manager.getNumFeatures()];
        double value = gradFunc.apply(weights, gradient);
        return new GradientEvaluation(value, gradient);
    }

    /** Objective value and gradient produced by one gradient evaluation. */
    private static final class GradientEvaluation {
        final double value;
        final double[] gradient;

        GradientEvaluation(double value, double[] gradient) {
            this.value = value;
            this.gradient = gradient;
        }
    }

    /**
     * Trains the gene-caller-local data with two configurations.
     * Asserts the resulting weights agree to 1e-4.
     */
    public void testGeneCallerLocalScore() throws Exception {
        Conrad nodeOnly = new Conrad("test/input/geneCallerLocal/baseline_aof.xml");
        Conrad length = new Conrad("test/input/geneCallerLocal/baseline_aof_length.xml");
        nodeOnly.train("test/input/geneCallerLocal");
        length.train("test/input/geneCallerLocal");
        this.assertArrayEquals(nodeOnly.getWeights(), length.getWeights(), 1.0E-4);
    }

    /**
     * Compares AOF objective values and the first gradient component against
     * hand-computed numbers. Uses several TestFeatureManager models and
     * SimScoreMaxStateAgreement as the similarity score.
     */
    public void testAlternateObjectiveFunction() throws Exception {
        SimScoreMaxStateAgreement s = new SimScoreMaxStateAgreement();
        TestFeatureManager m2 = new TestFeatureManager(2);
        List<? extends TrainingSequence<?>> data = IntInput.prepareData("00\n00");
        this.doAlternateObjectiveFunctionTest(m2, data, s, 0.1428571, -0.0919481);
        List<? extends TrainingSequence<?>> data2 = IntInput.prepareData("00\n11\n00\n00");
        this.doAlternateObjectiveFunctionTest(m2, data2, s, 0.1428571325, -0.091948);
        TestFeatureManager m0 = new TestFeatureManager(0);
        List<? extends TrainingSequence<?>> data3 = IntInput.prepareData("00\n00");
        double t3 = -0.1540327;
        this.doAlternateObjectiveFunctionTest(m0, data3, s, 0.166666665, t3 / 2.0);
        List<? extends TrainingSequence<?>> data4 = IntInput.prepareData("0000\n0000");
        this.doAlternateObjectiveFunctionTest(m0, data4, s, 0.25, 3.0 * t3 / 4.0);
        TestFeatureManager m1 = new TestFeatureManager(1);
        this.doAlternateObjectiveFunctionTest(m1, data3, s, 0.25, 0.0);
        this.doAlternateObjectiveFunctionTest(m1, data4, s, 0.375, 0.0);
    }

    /**
     * Checks that the plain AOF model, the semi-Markov model without length
     * features, and the full semi-Markov model train to the same weights on a
     * tiny crypto example.
     */
    public void testSemiRealExample() throws Exception {
        String input = "test/input/cryptoAOFUnittest/Tiny_1_Train_Test.txt";
        String config1 = "test/input/cryptoAOFUnittest/delta_aof_Model.xml";
        String config2 = "test/input/cryptoAOFUnittest/delta_aof_Model_semi_nolen.xml";
        String config3 = "test/input/cryptoAOFUnittest/delta_aof_Model_semi.xml";
        this.doSameWeightsTest(config1, config2, config3, input);
    }

    /**
     * Trains and tests {@code input} under each configuration and asserts the
     * trained weights agree (tolerance 0.001).
     *
     * @param config1 first configuration; its weights are the reference for config2
     * @param config2 second configuration; its weights are the reference for config3
     * @param config3 optional third configuration; skipped when {@code null}
     * @param input training/test data path
     */
    void doSameWeightsTest(String config1, String config2, String config3, String input) throws Exception {
        double[] regularWeights = this.trainAndTest(config1, input);
        double[] noLenWeights = this.trainAndTest(config2, input);
        this.assertArrayEquals(regularWeights, noLenWeights, 0.001);
        if (config3 != null) {
            double[] semiMarkovWeights = this.trainAndTest(config3, input);
            this.assertArrayEquals(noLenWeights, semiMarkovWeights, 0.001);
        }
    }

    /**
     * Builds a Conrad instance from {@code config} and trains it on {@code input}.
     * Reads the weights, then runs {@code test} on the same input.
     *
     * @return the weights obtained right after training
     */
    private double[] trainAndTest(String config, String input) throws Exception {
        Conrad conrad = new Conrad(config);
        conrad.train(input);
        double[] weights = conrad.getWeights();
        conrad.test(input);
        return weights;
    }

    /**
     * Optimizes a small model, first with CachedAOFGradient and then with the
     * semi-Markov AOF gradient.
     *
     * <p>NOTE(review): there are no assertions. The test only checks that both
     * optimizations complete without throwing.
     */
    public void testCachedAOFGradient() throws Exception {
        List<? extends TrainingSequence<?>> data = IntInput.prepareData("001111\n001111\n001111\n001111\n001111\n001111\n001111\n001111\n");
        TestFeatureManager m = new TestFeatureManager(2);
        StandardOptimizer opt = new StandardOptimizer();
        opt.setStarts(new double[]{0.1, 0.2});
        opt.setRequireConvergence(true);
        opt.setEpsForConvergence(5.0E-7);
        opt.setObjectiveFunction(new CachedAOFGradient());
        opt.optimize(m, data);
        MaximumExpectedAccuracySemiMarkovGradient semiAof = new MaximumExpectedAccuracySemiMarkovGradient();
        semiAof.setCacheProcessor(new CacheProcessorDeluxe());
        opt.setObjectiveFunction(semiAof);
        opt.optimize(m, data);
    }

    /**
     * Evaluates the AOF objective at all-ones weights with two implementations:
     * CachedAOFGradient (all paths) and the semi-Markov AOF gradient.
     *
     * <p>For each, asserts that the value and {@code grad[0]} match the
     * hand-computed numbers (tolerance 0.001). Each evaluation writes to its own
     * gradient array, so the second check cannot be influenced by the first pass.
     * Both gradient objects are also configured with file names (e.g.
     * "scoreAlphaOld.txt") that presumably receive debug dumps. TODO confirm.
     *
     * @param hand_val expected objective value
     * @param hand_grad0 expected first gradient component
     */
    void doAlternateObjectiveFunctionTest(ModelManager m, List<? extends TrainingSequence<?>> data, LocalPathSimilarityScore s, double hand_val, double hand_grad0) {
        double[] weights = new double[m.getNumFeatures()];
        Arrays.fill(weights, 1.0);

        CachedAOFGradient gradFunc = new CachedAOFGradient();
        gradFunc.setAllPaths(true);
        gradFunc.setScoreAlphaFile("scoreAlphaOld.txt");
        gradFunc.setExpectedProductFile("expectedProductOld.txt");
        gradFunc.setLocalPathSimilarityScore(s);
        gradFunc.setTrainingData(m, data);
        double[] grad = new double[m.getNumFeatures()];
        double val = gradFunc.apply(weights, grad);
        log.info("  val = " + val + "   grad[0] = " + grad[0]);
        log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(grad));
        assertEquals(hand_val, val, 0.001);
        assertEquals(hand_grad0, grad[0], 0.001);

        MaximumExpectedAccuracySemiMarkovGradient semiAof = new MaximumExpectedAccuracySemiMarkovGradient();
        semiAof.setCacheProcessor(new CacheProcessorDeluxe());
        semiAof.setTrainingData(m, data);
        semiAof.setMarginalsFile("marginals.txt");
        semiAof.setScoreAlphaFile("scoreAlpha.txt");
        semiAof.setExpectedProductFile("expectedProduct.txt");
        // Fresh output buffer so the result does not depend on whether apply()
        // overwrites or accumulates into the gradient array.
        double[] semiGrad = new double[m.getNumFeatures()];
        double semiVal = semiAof.apply(weights, semiGrad);
        log.info("  val = " + semiVal + "   grad[0] = " + semiGrad[0]);
        log.info("Grad(Cache,Valid Paths): " + ColtUtil.format(semiGrad));
        assertEquals(hand_val, semiVal, 0.001);
        assertEquals(hand_grad0, semiGrad[0], 0.001);
    }
}

