/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf.test;

import calhoun.analysis.crf.CRFInference;
import calhoun.analysis.crf.Conrad;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.util.AbstractTestCase;
import calhoun.util.Assert;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

/**
 * Tests that load {@link Conrad} models from the semi-Markov test
 * configurations under {@code test/input} and compare them with a model
 * loaded from the "UseBaseClass" configuration, either by Viterbi best
 * scores or by the first trained weight.
 *
 * <p>All file paths are relative, so the tests presumably must be run from
 * the project root (NOTE(review): confirm against the build setup).
 */
public class SemiMarkovTest
extends AbstractTestCase {
    private static final Log log = LogFactory.getLog(SemiMarkovTest.class);

    /** Configuration of the model every other model is compared against. */
    private static final String BASELINE_CONFIG = "test/input/semiMarkovTestModelNoExplicitLengthsUseBaseClass.xml";

    /** Data file used for feature training and full training throughout. */
    private static final String TRIVIAL_DATA = "test/input/zeroOrderTrivial.txt";

    /**
     * Runs the command-line trainer and checks that the message of the
     * exception it raises is the expected "segment too long" text. If no
     * exception is raised the captured message stays {@code null} and the
     * assertion fails.
     */
    public void testSemiMarkovBadTraining() throws Exception {
        String[] args = {
            "train",
            "test/input/zeroOrderLBFGSCachedSemiMarkov.xml",
            "test/input/zeroOrderTest.txt",
            "test/working/zeroLBGFSModelCachedSemiMarkov.ser"
        };
        String observed = null;
        try {
            Conrad.main(args);
        }
        catch (Exception e) {
            observed = e.getMessage();
        }
        assertEquals("Seq #0 Pos 150 Training segment 127 is longer than allowed length 20", observed);
    }

    /**
     * Compares Viterbi best scores of the "NoExplicitLengths" model with the
     * baseline model on two data files.
     */
    public void testSemiCRFViterbiCompareWithBaseClass() throws Exception {
        Conrad candidate = unitWeightModel("test/input/semiMarkovTestModelNoExplicitLengths.xml");
        Conrad baseline = unitWeightModel(BASELINE_CONFIG);
        doViterbiComparison(TRIVIAL_DATA, candidate, baseline);
        doViterbiComparison("test/input/zeroOrderTest.txt", candidate, baseline);
    }

    /**
     * Compares Viterbi best scores of the "NoExplicitLengthFeatures" model
     * with the baseline model.
     */
    public void testSemiCRFViterbiCompareWithBaseClassWithLengths() throws Exception {
        Conrad candidate = unitWeightModel("test/input/semiMarkovTestModelNoExplicitLengthFeatures.xml");
        Conrad baseline = unitWeightModel(BASELINE_CONFIG);
        doViterbiComparison(TRIVIAL_DATA, candidate, baseline);
    }

    /**
     * Compares Viterbi best scores of the "HalfAndHalf" model with the
     * baseline model.
     */
    public void testSemiCRFViterbiCompareWithBaseClassWithFeatures() throws Exception {
        Conrad candidate = unitWeightModel("test/input/semiMarkovTestModelHalfAndHalf.xml");
        Conrad baseline = unitWeightModel(BASELINE_CONFIG);
        doViterbiComparison(TRIVIAL_DATA, candidate, baseline);
    }

    /**
     * Fully trains the baseline model once, then trains three other
     * configurations in turn and checks that each one's first weight agrees
     * with the baseline's first weight within the given tolerance.
     */
    public void testSemiCRFCompareWithBaseClass() throws Exception {
        Conrad baseline = new Conrad(BASELINE_CONFIG);
        baseline.train(TRIVIAL_DATA);
        assertFirstWeightMatches("test/input/semiMarkovTestModelNoExplicitLengths.xml", baseline, 1.0E-4);
        assertFirstWeightMatches("test/input/semiMarkovTestModelNoExplicitLengthFeatures.xml", baseline, 1.0E-4);
        // The last configuration is compared with a looser tolerance.
        assertFirstWeightMatches("test/input/semiMarkovTestModelHalfAndHalf.xml", baseline, 0.001);
    }

    /**
     * Predicts every sequence read from {@code file} with both models and
     * asserts that their best-score arrays agree entry by entry.
     *
     * @param file training data file, read through {@code a}'s input handler
     * @param a    model whose scores are the "actual" side of each assertion;
     *             its score array length must equal its number of states
     * @param b    model whose scores are the "expected" side of each assertion
     */
    void doViterbiComparison(String file, Conrad a, Conrad b) throws Exception {
        List<TrainingSequence<?>> sequences = a.getInputHandler().readTrainingData(file);
        for (TrainingSequence<?> sequence : sequences) {
            CRFInference.InferenceResult resultA = a.predict(sequence.getInputSequence());
            CRFInference.InferenceResult resultB = b.predict(sequence.getInputSequence());
            double[] scoresA = resultA.bestScores;
            double[] scoresB = resultB.bestScores;
            Assert.a(scoresA.length == a.getNumStates());
            for (int state = 0; state < scoresA.length; state++) {
                String label = "State: " + state + " ";
                assertEquals(label, scoresB[state], scoresA[state], 1.0E-4);
            }
        }
    }

    /**
     * Checks that two 2x2 dense matrices with the same single cell set
     * compare equal through {@code assertEquals(Object, Object)}.
     */
    public void testColt() {
        DenseDoubleMatrix2D left = new DenseDoubleMatrix2D(2, 2);
        DenseDoubleMatrix2D right = new DenseDoubleMatrix2D(2, 2);
        left.setQuick(1, 1, 3.0);
        right.setQuick(1, 1, 3.0);
        assertEquals((Object)left, (Object)right);
    }

    /**
     * Builds a model from {@code configFile}, trains its features on the
     * trivial data file and sets its weights to three ones.
     */
    private static Conrad unitWeightModel(String configFile) throws Exception {
        Conrad model = new Conrad(configFile);
        model.trainFeatures(TRIVIAL_DATA);
        // A fresh array per model, so no two models share the same weights instance.
        model.setWeights(new double[]{1.0, 1.0, 1.0});
        return model;
    }

    /**
     * Trains a model built from {@code configFile} on the trivial data file
     * and asserts that its first weight matches the first weight of
     * {@code baseline} within {@code tolerance}.
     */
    private static void assertFirstWeightMatches(String configFile, Conrad baseline, double tolerance) throws Exception {
        Conrad candidate = new Conrad(configFile);
        candidate.train(TRIVIAL_DATA);
        double candidateWeight = candidate.getWeights()[0];
        double baselineWeight = baseline.getWeights()[0];
        assertEquals(candidateWeight, baselineWeight, tolerance);
    }
}

