/*
 * Decompiled with CFR 0.152.
 */
package calhoun.analysis.crf;

import calhoun.analysis.crf.CRFInference;
import calhoun.analysis.crf.CRFTraining;
import calhoun.analysis.crf.ModelManager;
import calhoun.analysis.crf.io.CompositeInput;
import calhoun.analysis.crf.io.InputHandler;
import calhoun.analysis.crf.io.InputSequence;
import calhoun.analysis.crf.io.OutputHandler;
import calhoun.analysis.crf.io.OutputHandlerGeneCallPredict;
import calhoun.analysis.crf.io.OutputHandlerGeneCallStats;
import calhoun.analysis.crf.io.TrainingSequence;
import calhoun.util.ColtUtil;
import calhoun.util.ErrorException;
import calhoun.util.FileUtil;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.xml.XmlBeanDefinitionReader;
import org.springframework.context.ApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.io.Resource;

/**
 * Facade for building, training, testing and running a CRF model, plus a command-line entry point.
 *
 * <p>An instance is configured from a Spring XML bean file. The constructor reads the
 * {@code model}, {@code inputHandler}/{@code inputFormat} and {@code outputHandler} beans.
 * {@link #initSolver()} reads the {@code inference} and {@code optimizer} beans.
 *
 * <p>The raw config bytes, the model, the I/O handlers, the weights and the training time are
 * serialized by {@link #write(String)}. The inference and optimizer objects are {@code transient}.
 * {@link #read(String)} rebuilds them from the stored config.
 */
public class Conrad
implements Serializable {
    private static final long serialVersionUID = -5964610632818921236L;
    private static final Log log = LogFactory.getLog(Conrad.class);

    /** Raw bytes of the Spring XML configuration; kept so transient beans can be rebuilt. */
    byte[] configXml;
    /** Model bean ({@code "model"}) providing features and states. */
    ModelManager model;
    /** Reads training/test data from a location. */
    InputHandler inputHandler;
    /** Receives predicted hidden states during testing. */
    OutputHandler outputHandler;
    /** Inference bean ({@code "inference"}); rebuilt by {@link #initSolver()}. */
    transient CRFInference inference;
    /** Optimizer bean ({@code "optimizer"}); rebuilt by {@link #initSolver()}. */
    transient CRFTraining optimizer;
    /** Trained (or explicitly set) weights; {@code null} until training or {@link #setWeights}. */
    double[] weights = null;
    /** Training time in seconds (feature training resets it; weight training adds to it). */
    double trainingTime = 0.0;

    /**
     * Command-line entry point. Expects exactly four arguments; see {@link #usage()} for the
     * accepted forms. A wrong argument count or an unrecognized mode prints usage and exits.
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 4) {
            Conrad.usage();
        }
        String mode = args[0];
        if (mode.startsWith("train")) {
            Conrad crf;
            if (mode.equals("trainWeights")) {
                // Retrain only the weights of an already serialized model.
                crf = Conrad.read(args[1]);
                // NOTE(review): hard-coded optimizer starting point (35 values). It appears to be
                // tuned for one particular model configuration -- confirm it matches the model
                // being retrained before relying on this mode.
                crf.getOptimizer().setStarts(new double[]{
                    0.563021, -1.299189, 1.156332, 0.3335547, 0.4300941, 0.3290597, 0.5924689,
                    0.3461513, 0.4968438, 0.6110913, 1.095909, 0.8485993, 0.6130822, 0.7956561,
                    0.6113225, 0.926442, 0.7921441, 0.1052026, 0.2956273, 0.2898136, 0.2074534,
                    -0.01352267, 1.385965, 1.0, 1.922028, 1.0, -0.307992, 1.006709, 1.212544,
                    0.7807475, 0.7693724, -0.5903666, 0.05418292, -0.2186725, 0.02534417});
                crf.trainWeights(args[2]);
            } else {
                crf = new Conrad(args[1]);
                if (mode.equals("trainFeatures")) {
                    crf.trainFeatures(args[2]);
                } else {
                    // Any other "train*" mode trains features and then weights.
                    crf.train(args[2]);
                }
            }
            crf.write(args[3]);
        } else if (mode.equals("test")) {
            // read() already rebuilds the inference/optimizer beans, so no extra initSolver().
            Conrad crf = Conrad.read(args[1]);
            log.warn("Weights:" + ColtUtil.format(crf.getWeights()));
            crf.test(args[2], args[3]);
        } else if (mode.equals("predict")) {
            Conrad crf = Conrad.read(args[1]);
            // Prediction has no known answers, so swap in a predict-only output handler.
            OutputHandlerGeneCallPredict predictOutputHandler = new OutputHandlerGeneCallPredict();
            predictOutputHandler.setWriteTrainingData(false);
            predictOutputHandler.setManager(crf.getModel());
            predictOutputHandler.setInputHandler(crf.getInputHandler());
            crf.setOutputHandler(predictOutputHandler);
            crf.testWithoutAnswers(args[2], args[3]);
        } else {
            Conrad.usage();
        }
    }

    /** Creates an unconfigured instance; collaborators must be supplied via setters. */
    public Conrad() {
    }

    /**
     * Creates an instance from a Spring XML configuration file.
     *
     * <p>The {@code model} bean is always used. If the config defines {@code inputFormat} but
     * no {@code inputHandler}, the input format is wrapped in a
     * {@link CompositeInput.LegacyInputHandler}. In that case the {@code outputHandler} bean is
     * used when present; otherwise an {@link OutputHandlerGeneCallStats} is created. Without that
     * legacy setup, the {@code inputHandler} and {@code outputHandler} beans are used directly.
     *
     * @param configFile path to the Spring XML configuration
     * @throws RuntimeException if the configuration file cannot be read
     */
    public Conrad(String configFile) {
        try {
            this.configXml = FileUtil.readFileAsBytes(configFile);
        }
        catch (IOException ex) {
            throw new RuntimeException("Failed to load config file: " + configFile, ex);
        }
        ApplicationContext ctx = this.initSolver();
        this.model = (ModelManager)ctx.getBean("model");
        boolean legacyInput = ctx.containsBean("inputFormat") && !ctx.containsBean("inputHandler");
        if (legacyInput) {
            this.inputHandler = new CompositeInput.LegacyInputHandler(ctx.getBean("inputFormat"));
            if (ctx.containsBean("outputHandler")) {
                this.outputHandler = (OutputHandler)ctx.getBean("outputHandler");
            } else {
                OutputHandlerGeneCallStats legacyOutputHandler = new OutputHandlerGeneCallStats();
                legacyOutputHandler.setWriteTrainingData(true);
                legacyOutputHandler.setManager(this.model);
                legacyOutputHandler.setInputHandler(this.inputHandler);
                this.outputHandler = legacyOutputHandler;
            }
        } else {
            this.inputHandler = (InputHandler)ctx.getBean("inputHandler");
            this.outputHandler = (OutputHandler)ctx.getBean("outputHandler");
        }
    }

    /** Serializes this instance (config, model, handlers, weights) to {@code filename}. */
    public void write(String filename) throws IOException {
        FileUtil.writeObject(filename, this);
    }

    /**
     * Deserializes an instance written by {@link #write(String)} and rebuilds its transient
     * inference/optimizer beans from the stored configuration.
     *
     * @throws ErrorException if the serialized class cannot be found
     */
    public static Conrad read(String filename) throws IOException {
        try {
            Conrad ret = (Conrad)FileUtil.readObject(filename);
            ret.initSolver();
            return ret;
        }
        catch (ClassNotFoundException ex) {
            throw new ErrorException(ex);
        }
    }

    /** Reads labeled data from {@code location}, then trains features and weights. */
    public void train(String location) throws IOException {
        List<? extends TrainingSequence<?>> t = this.inputHandler.readTrainingData(location, false);
        this.train(t);
    }

    /** Trains features and then weights on {@code data}. */
    public void train(List<? extends TrainingSequence<?>> data) {
        this.trainFeatures(data);
        this.trainWeights(data);
    }

    /** Reads labeled data from {@code location} and trains only the model's features. */
    public void trainFeatures(String location) throws IOException {
        List<? extends TrainingSequence<?>> t = this.inputHandler.readTrainingData(location, false);
        this.trainFeatures(t);
    }

    /**
     * Trains the model's features on {@code data} and resets {@link #getTrainingTime()} to the
     * time taken. Feature names are logged at debug level.
     */
    public void trainFeatures(List<? extends TrainingSequence<?>> data) {
        this.print("Training features");
        long start = System.currentTimeMillis();
        this.model.train(0, this.model, data);
        if (log.isDebugEnabled()) {
            log.debug("Features:");
            for (int i = 0; i < this.model.getNumFeatures(); ++i) {
                log.debug(this.model.getFeatureName(i));
            }
        }
        this.trainingTime = (System.currentTimeMillis() - start) / 1000.0;
        this.print("Trained in " + this.trainingTime + " seconds.");
    }

    /** Reads labeled data from {@code location} and trains only the weights. */
    public void trainWeights(String location) throws IOException {
        List<? extends TrainingSequence<?>> t = this.inputHandler.readTrainingData(location, false);
        this.trainWeights(t);
    }

    /**
     * Optimizes the weights on {@code data} using the configured optimizer and adds the time
     * taken to {@link #getTrainingTime()}.
     */
    public void trainWeights(List<? extends TrainingSequence<?>> data) {
        this.print("Training weights");
        long start = System.currentTimeMillis();
        this.weights = this.optimizer.optimize(this.model, data);
        double elapsedSeconds = (System.currentTimeMillis() - start) / 1000.0;
        this.trainingTime += elapsedSeconds;
        this.print("Trained weights in " + elapsedSeconds + " seconds.  " + this.trainingTime + " total.");
    }

    /** Reads labeled data from {@code inputLocation} and tests without an output location. */
    public void test(String inputLocation) throws IOException {
        List<? extends TrainingSequence<?>> t = this.inputHandler.readTrainingData(inputLocation, false);
        this.test(t);
    }

    /** Reads labeled data from {@code inputLocation} and writes test output to {@code outputLocation}. */
    public void test(String inputLocation, String outputLocation) throws IOException {
        List<? extends TrainingSequence<?>> t = this.inputHandler.readTrainingData(inputLocation, false);
        this.test(t, outputLocation);
    }

    /** Tests on {@code data}, passing a {@code null} output location to the output handler. */
    public void test(List<? extends TrainingSequence<?>> data) throws IOException {
        this.test(data, null);
    }

    /**
     * Reads data from {@code inputLocation} with the input handler's second flag set to
     * {@code true}. Presumably this means data without known answers (used by "predict").
     * The results are written to {@code outputLocation}.
     */
    public void testWithoutAnswers(String inputLocation, String outputLocation) throws IOException {
        List<? extends TrainingSequence<?>> t = this.inputHandler.readTrainingData(inputLocation, true);
        this.test(t, outputLocation);
    }

    /**
     * Predicts hidden states for each sequence in {@code data}. Each input, its known labels and
     * the predicted states are passed to the output handler, which is then notified of completion.
     *
     * @param location output location handed to the output handler (may be {@code null})
     */
    public void test(List<? extends TrainingSequence<?>> data, String location) throws IOException {
        this.print("Beginning test");
        // The weight table can be long; only build it when it will actually be logged
        // (same approach as the feature-name logging in trainFeatures).
        if (log.isDebugEnabled()) {
            log.debug(this.printWeights());
        }
        this.outputHandler.setOutputLocation(location);
        for (TrainingSequence<?> dr : data) {
            CRFInference.InferenceResult predictedHiddenSequence = this.predict(dr);
            this.outputHandler.writeTestOutput(dr.getInputSequence(), dr.getY(), predictedHiddenSequence.hiddenStates);
        }
        this.print("Testing complete");
        this.outputHandler.outputComplete();
    }

    /** Runs inference on {@code data} with the current model and weights. */
    public CRFInference.InferenceResult predict(InputSequence data) {
        return this.inference.predict(this.model, data, this.weights);
    }

    /** Replaces the weights (the array is stored, not copied). */
    public void setWeights(double[] weights) {
        this.weights = weights;
    }

    /** Returns the model's name for feature {@code index}. */
    public String getFeatureName(int index) {
        return this.model.getFeatureName(index);
    }

    /** Returns the accumulated training time in seconds. */
    public double getTrainingTime() {
        return this.trainingTime;
    }

    /** Returns the model's feature count. */
    public int getNumFeatures() {
        return this.model.getNumFeatures();
    }

    /** Returns the model's state count. */
    public int getNumStates() {
        return this.model.getNumStates();
    }

    /** Returns the model's name for {@code state}. */
    public String getStateName(int state) {
        return this.model.getStateName(state);
    }

    public ModelManager getModel() {
        return this.model;
    }

    public CRFTraining getOptimizer() {
        return this.optimizer;
    }

    /** Returns the weights array itself (not a copy); {@code null} if none are set. */
    public double[] getWeights() {
        return this.weights;
    }

    public CRFInference getInference() {
        return this.inference;
    }

    public void setInference(CRFInference inference) {
        this.inference = inference;
    }

    public void setModel(ModelManager model) {
        this.model = model;
    }

    public void setOptimizer(CRFTraining optimizer) {
        this.optimizer = optimizer;
    }

    /**
     * Formats the weights as one line per weight: the value to five decimals, a tab, then the
     * feature name.
     *
     * @throws NullPointerException if no weights have been set
     */
    public String printWeights() {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.weights.length; ++i) {
            sb.append(String.format("%.5f\t%s\n", this.weights[i], this.getFeatureName(i)));
        }
        return sb.toString();
    }

    /** Writes a progress message to standard output. */
    private void print(String msg) {
        System.out.println(msg);
    }

    /**
     * Builds a Spring context from {@link #configXml} and assigns the {@code inference} and
     * {@code optimizer} beans.
     *
     * <p>NOTE(review): the context is returned but never closed. Closing it here could run
     * destroy callbacks on beans still in use, so it is left open deliberately -- confirm.
     *
     * @return the refreshed context, so the constructor can read additional beans
     */
    private ApplicationContext initSolver() {
        GenericApplicationContext ctx = new GenericApplicationContext();
        XmlBeanDefinitionReader xmlReader = new XmlBeanDefinitionReader(ctx);
        Resource config = new ByteArrayResource(this.configXml);
        xmlReader.loadBeanDefinitions(config);
        ctx.refresh();
        this.inference = (CRFInference)ctx.getBean("inference");
        this.optimizer = (CRFTraining)ctx.getBean("optimizer");
        return ctx;
    }

    /** Prints command-line usage and terminates the JVM with status -1. */
    private static void usage() {
        System.out.println("       Conrad train(Features) configFile data modelFile");
        System.out.println(" or    Conrad trainWeights modelFileIn data modelFileOut");
        System.out.println(" or    Conrad test modelFile inputData outputData");
        System.out.println(" or    Conrad predict modelFile inputData outputData");
        System.exit(-1);
    }

    public InputHandler getInputHandler() {
        return this.inputHandler;
    }

    public void setInputHandler(InputHandler inputHandler) {
        this.inputHandler = inputHandler;
    }

    public OutputHandler getOutputHandler() {
        return this.outputHandler;
    }

    public void setOutputHandler(OutputHandler outputHandler) {
        this.outputHandler = outputHandler;
    }
}

