/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.grmm.learning.extract;

import cc.mallet.extract.Extraction;
import cc.mallet.extract.TokenizationFilter;
import cc.mallet.grmm.inference.Inferencer;
import cc.mallet.grmm.learning.ACRF;
import cc.mallet.grmm.learning.ACRFEvaluator;
import cc.mallet.grmm.learning.ACRFTrainer;
import cc.mallet.grmm.learning.AcrfSerialEvaluator;
import cc.mallet.grmm.learning.DefaultAcrfTrainer;
import cc.mallet.grmm.learning.extract.ACRFExtractor;
import cc.mallet.grmm.util.PipedIterator;
import cc.mallet.grmm.util.RememberTokenizationPipe;
import cc.mallet.pipe.Pipe;
import cc.mallet.pipe.PipeUtils;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CollectionUtils;
import cc.mallet.util.FileUtils;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Timing;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.logging.Logger;

/**
 * Fluent configuration object that trains an {@link ACRF} and wraps it in an
 * {@link ACRFExtractor}.
 *
 * <p>Data can be supplied either as ready-made instance lists via
 * {@link #setData(InstanceList, InstanceList)} or as raw instance iterators via
 * {@link #setDataSource(Iterator, Iterator)}; in the latter case the iterators are
 * run through the tokenization pipe and then the feature pipe by {@link #setupData()}.
 *
 * <p>All setters return {@code this} so that calls can be chained.
 */
public class ACRFExtractorTrainer {
    private static final Logger logger = MalletLogger.getLogger(ACRFExtractorTrainer.class.getName());

    /** Iteration limit handed to the trainer for the main training run. */
    private int numIter = 99999;
    protected ACRF.Template[] tmpls;
    protected InstanceList training;
    protected InstanceList testing;
    private Iterator<Instance> testIterator;
    private Iterator<Instance> trainIterator;
    ACRFTrainer trainer = new DefaultAcrfTrainer();
    protected Pipe featurePipe;
    protected Pipe tokPipe;
    protected ACRFEvaluator evaluator = new DefaultAcrfTrainer.LogEvaluator();
    TokenizationFilter filter;
    private Inferencer inferencer;
    private Inferencer viterbiInferencer;

    /** Checkpointing is enabled only when this is greater than zero. */
    private int numCheckpointIterations = -1;
    private File checkpointDirectory = null;
    private boolean usePerTemplateTrain = false;

    /** Iteration limit handed to the trainer for each per-template round. */
    private int perTemplateIterations = 100;
    private boolean cacheUnrolledGraphs;

    /** Random source used when subsetting data; set by {@link #setDataSubsets}. */
    private Random r;

    /** Fraction of the training data to keep; subsetting is skipped unless this is positive. */
    private double trainingPct = -1.0;

    /** Fraction of the testing data to keep; subsetting is skipped unless this is positive. */
    private double testingPct = -1.0;

    /**
     * True once the current data source iterators have been run through
     * {@link #setupData()} by this class. Prevents re-reading iterators that have
     * already been consumed.
     */
    private boolean dataSourceLoaded = false;

    /**
     * Sets the templates the trained ACRF is built from.
     *
     * @param tmpls templates passed to the {@link ACRF} constructor
     * @return this trainer
     */
    public ACRFExtractorTrainer setTemplates(ACRF.Template[] tmpls) {
        this.tmpls = tmpls;
        return this;
    }

    /**
     * Sets the raw instance iterators that {@link #setupData()} reads from.
     *
     * @param trainIterator source of training instances
     * @param testIterator source of testing instances; may be {@code null}, in which
     *        case no testing list is built
     * @return this trainer
     */
    public ACRFExtractorTrainer setDataSource(Iterator<Instance> trainIterator, Iterator<Instance> testIterator) {
        this.trainIterator = trainIterator;
        this.testIterator = testIterator;
        // A fresh source has not been read yet.
        this.dataSourceLoaded = false;
        return this;
    }

    /**
     * Sets already-built training and testing instance lists directly.
     *
     * @param training training instances
     * @param testing testing instances
     * @return this trainer
     */
    public ACRFExtractorTrainer setData(InstanceList training, InstanceList testing) {
        this.training = training;
        this.testing = testing;
        return this;
    }

    /**
     * Sets the iteration limit passed to the trainer for the main training run.
     *
     * @param numIter iteration limit
     * @return this trainer
     */
    public ACRFExtractorTrainer setNumIterations(int numIter) {
        this.numIter = numIter;
        return this;
    }

    /**
     * @return the iteration limit passed to the trainer for the main training run
     */
    public int getNumIter() {
        return this.numIter;
    }

    /**
     * Sets the tokenization and feature pipes. The feature pipe is stored with a
     * {@link RememberTokenizationPipe} concatenated in front of it.
     *
     * @param tokPipe pipe applied to raw instances before the feature pipe
     * @param featurePipe pipe that produces the features used for training
     * @return this trainer
     */
    public ACRFExtractorTrainer setPipes(Pipe tokPipe, Pipe featurePipe) {
        RememberTokenizationPipe rtp = new RememberTokenizationPipe();
        this.featurePipe = PipeUtils.concatenatePipes(rtp, featurePipe);
        this.tokPipe = tokPipe;
        return this;
    }

    /**
     * Replaces the evaluator handed to the trainer.
     *
     * @param evaluator evaluator to use
     * @return this trainer
     */
    public ACRFExtractorTrainer setEvaluator(ACRFEvaluator evaluator) {
        this.evaluator = evaluator;
        return this;
    }

    /**
     * Replaces the underlying ACRF trainer.
     *
     * @param acrfTrainer trainer to use
     * @return this trainer
     */
    public ACRFExtractorTrainer setTrainingMethod(ACRFTrainer acrfTrainer) {
        this.trainer = acrfTrainer;
        return this;
    }

    /**
     * Sets the tokenization filter installed on extractors built by
     * {@link #trainExtractor()}. The method name is misspelled; it is retained for
     * backward compatibility. Prefer {@link #setTokenizationFilter(TokenizationFilter)}.
     *
     * @param filter filter to install; {@code null} means none is installed
     * @return this trainer
     */
    public ACRFExtractorTrainer setTokenizatioFilter(TokenizationFilter filter) {
        this.filter = filter;
        return this;
    }

    /**
     * Correctly spelled equivalent of {@link #setTokenizatioFilter(TokenizationFilter)}.
     *
     * @param filter filter to install; {@code null} means none is installed
     * @return this trainer
     */
    public ACRFExtractorTrainer setTokenizationFilter(TokenizationFilter filter) {
        return this.setTokenizatioFilter(filter);
    }

    /**
     * Controls whether {@code setCacheUnrolledGraphs(true)} is called on each ACRF
     * before training.
     *
     * @param cacheUnrolledGraphs whether to enable the ACRF's unrolled-graph cache
     * @return this trainer
     */
    public ACRFExtractorTrainer setCacheUnrolledGraphs(boolean cacheUnrolledGraphs) {
        this.cacheUnrolledGraphs = cacheUnrolledGraphs;
        return this;
    }

    /**
     * Sets the checkpoint interval. Checkpointing is only active for values greater
     * than zero.
     *
     * @param numCheckpointIterations interval, in iterations, between checkpoints
     * @return this trainer
     */
    public ACRFExtractorTrainer setNumCheckpointIterations(int numCheckpointIterations) {
        this.numCheckpointIterations = numCheckpointIterations;
        return this;
    }

    /**
     * Sets the directory checkpoint files are written into.
     *
     * @param checkpointDirectory parent directory for checkpoint files
     * @return this trainer
     */
    public ACRFExtractorTrainer setCheckpointDirectory(File checkpointDirectory) {
        this.checkpointDirectory = checkpointDirectory;
        return this;
    }

    /**
     * Selects between per-template training and plain training in
     * {@link #trainExtractor()}.
     *
     * @param usePerTemplateTrain {@code true} to train by adding one template per round
     * @return this trainer
     */
    public ACRFExtractorTrainer setUsePerTemplateTrain(boolean usePerTemplateTrain) {
        this.usePerTemplateTrain = usePerTemplateTrain;
        return this;
    }

    /**
     * Sets the iteration limit used for each per-template training round.
     *
     * @param numIter iteration limit per round
     * @return this trainer
     */
    public ACRFExtractorTrainer setPerTemplateIterations(int numIter) {
        this.perTemplateIterations = numIter;
        return this;
    }

    /**
     * @return the underlying ACRF trainer
     */
    public ACRFTrainer getTrainer() {
        return this.trainer;
    }

    /**
     * @return the configured tokenization filter, or {@code null} if none was set
     */
    public TokenizationFilter getFilter() {
        return this.filter;
    }

    /**
     * Trains an ACRF (per-template or plain, depending on configuration) and wraps
     * it in an extractor that uses this trainer's pipes and, if set, its
     * tokenization filter.
     *
     * @return the newly built extractor
     */
    public ACRFExtractor trainExtractor() {
        ACRF acrf = this.usePerTemplateTrain ? this.perTemplateTrain() : this.trainAcrf();
        ACRFExtractor extor = new ACRFExtractor(acrf, this.tokPipe, this.featurePipe);
        if (this.filter != null) {
            extor.setTokenizationFilter(this.filter);
        }
        return extor;
    }

    /**
     * Trains in rounds: round {@code ti} builds a new ACRF from the first
     * {@code ti + 1} templates and trains it for at most
     * {@link #perTemplateIterations} iterations. If the last round's training call
     * returned {@code false}, the last ACRF is trained again with the full
     * iteration limit.
     *
     * @return the ACRF built in the final round
     * @throws IllegalStateException if no templates have been configured
     */
    private ACRF perTemplateTrain() {
        if (this.tmpls == null || this.tmpls.length == 0) {
            // Without this check the loop never runs and a null model reaches the trainer.
            throw new IllegalStateException("Per-template training requires at least one template");
        }
        Timing timing = new Timing();
        boolean hasConverged = false;
        ACRF miniAcrf = null;
        if (this.training == null) {
            this.loadDataFromSource();
        }
        for (int ti = 0; ti < this.tmpls.length; ++ti) {
            ACRF.Template[] theseTmpls = Arrays.copyOf(this.tmpls, ti + 1);
            logger.info("***PerTemplateTrain: Round " + ti + "\n  Templates: " + CollectionUtils.dumpToString(Arrays.asList(theseTmpls), " "));
            miniAcrf = new ACRF(this.featurePipe, theseTmpls);
            this.setupAcrf(miniAcrf);
            ACRFEvaluator eval = this.setupEvaluator("tmpl" + ti);
            hasConverged = this.trainer.train(miniAcrf, this.training, null, this.testing, eval, this.perTemplateIterations);
            timing.tick("PerTemplateTrain round " + ti);
        }
        ACRFEvaluator eval = this.setupEvaluator("full");
        if (!hasConverged) {
            this.trainer.train(miniAcrf, this.training, null, this.testing, eval, this.numIter);
        }
        return miniAcrf;
    }

    /**
     * Builds an ACRF from all configured templates and trains it on the training
     * data, loading that data from the data source first if it has not been set.
     *
     * @return the trained ACRF
     */
    public ACRF trainAcrf() {
        if (this.training == null) {
            this.loadDataFromSource();
        }
        ACRF acrf = new ACRF(this.featurePipe, this.tmpls);
        this.setupAcrf(acrf);
        ACRFEvaluator eval = this.setupEvaluator("");
        this.trainer.train(acrf, this.training, null, this.testing, eval, this.numIter);
        return acrf;
    }

    /**
     * Applies the optional settings (graph caching, inferencer, Viterbi inferencer)
     * to a freshly constructed ACRF. Settings that were never configured are left
     * at the ACRF's own defaults.
     */
    private void setupAcrf(ACRF acrf) {
        if (this.cacheUnrolledGraphs) {
            acrf.setCacheUnrolledGraphs(true);
        }
        if (this.inferencer != null) {
            acrf.setInferencer(this.inferencer);
        }
        if (this.viterbiInferencer != null) {
            acrf.setViterbiInferencer(this.viterbiInferencer);
        }
    }

    /**
     * Returns the evaluator to pass to the trainer: the configured evaluator alone,
     * or, when checkpointing is enabled, an {@link AcrfSerialEvaluator} combining it
     * with a checkpointing evaluator.
     *
     * @param checkpointPrefix currently unused.
     *        NOTE(review): checkpoint file names do not include this prefix, so
     *        per-template rounds write to the same file names; confirm whether that
     *        is intended before changing the on-disk naming.
     */
    private ACRFEvaluator setupEvaluator(String checkpointPrefix) {
        ACRFEvaluator eval = this.evaluator;
        if (this.numCheckpointIterations > 0) {
            List<ACRFEvaluator> evals = new ArrayList<ACRFEvaluator>();
            evals.add(this.evaluator);
            evals.add(new CheckpointingEvaluator(this.checkpointDirectory, this.numCheckpointIterations, this.tokPipe, this.featurePipe));
            eval = new AcrfSerialEvaluator(evals);
        }
        return eval;
    }

    /**
     * Runs {@link #setupData()} and records that the current data source has been
     * read, so that later getters do not read the same iterators again.
     */
    private void loadDataFromSource() {
        this.setupData();
        this.dataSourceLoaded = true;
    }

    /**
     * Builds the training list (and the testing list, if a test iterator was
     * supplied) by piping the raw iterators through the tokenization pipe and then
     * the feature pipe. Each list is reduced to a random subset when its
     * corresponding fraction is positive.
     */
    protected void setupData() {
        Timing timing = new Timing();
        this.training = new InstanceList(this.featurePipe);
        this.training.addThruPipe(new PipedIterator(this.trainIterator, this.tokPipe));
        if (this.trainingPct > 0.0) {
            this.training = this.subsetData(this.training, this.trainingPct);
        }
        if (this.testIterator != null) {
            this.testing = new InstanceList(this.featurePipe);
            this.testing.addThruPipe(new PipedIterator(this.testIterator, this.tokPipe));
            if (this.testingPct > 0.0) {
                // The testing subset is sized by the testing fraction, not the training one.
                this.testing = this.subsetData(this.testing, this.testingPct);
            }
        }
        timing.tick("Data loading");
    }

    /**
     * Splits {@code data} with proportions {@code pct} and {@code 1 - pct} using the
     * configured random source and returns the first part.
     */
    private InstanceList subsetData(InstanceList data, double pct) {
        InstanceList[] lsts = data.split(this.r, new double[]{pct, 1.0 - pct});
        return lsts[0];
    }

    /**
     * @return the training data, loading it from the data source first if it has
     *         not been set
     */
    public InstanceList getTrainingData() {
        if (this.training == null) {
            this.loadDataFromSource();
        }
        return this.training;
    }

    /**
     * Returns the testing data, loading from the data source if that has not been
     * done yet. Once the source has been read, it is not read again: when no test
     * iterator was supplied the result stays {@code null} rather than triggering a
     * reload that would rebuild the training list from a consumed iterator.
     *
     * @return the testing data, or {@code null} if none is available
     */
    public InstanceList getTestingData() {
        if (this.testing == null && !this.dataSourceLoaded) {
            this.loadDataFromSource();
        }
        return this.testing;
    }

    /**
     * Runs the given extractor over the current testing list. The testing list is
     * used as-is; it is not loaded on demand here.
     *
     * @param extor extractor to apply
     * @return the extraction produced by {@code extor}
     */
    public Extraction extractOnTestData(ACRFExtractor extor) {
        return extor.extract(this.testing);
    }

    /**
     * Sets the inferencer installed on each ACRF before training.
     *
     * @param inferencer inferencer to install; {@code null} leaves the ACRF's own setting
     * @return this trainer
     */
    public ACRFExtractorTrainer setInferencer(Inferencer inferencer) {
        this.inferencer = inferencer;
        return this;
    }

    /**
     * Sets the Viterbi inferencer installed on each ACRF before training.
     *
     * @param viterbiInferencer inferencer to install; {@code null} leaves the ACRF's own setting
     * @return this trainer
     */
    public ACRFExtractorTrainer setViterbiInferencer(Inferencer viterbiInferencer) {
        this.viterbiInferencer = viterbiInferencer;
        return this;
    }

    /**
     * Configures random subsetting of data loaded by {@link #setupData()}.
     *
     * @param random random source used for the split
     * @param trainingPct fraction of training data to keep; non-positive disables subsetting
     * @param testingPct fraction of testing data to keep; non-positive disables subsetting
     * @return this trainer
     */
    public ACRFExtractorTrainer setDataSubsets(Random random, double trainingPct, double testingPct) {
        this.r = random;
        this.trainingPct = trainingPct;
        this.testingPct = testingPct;
        return this;
    }

    /**
     * Evaluator that, at every {@code interval}-th iteration after the first,
     * serializes an extractor wrapping the current ACRF to a gzipped file in
     * {@code directory}.
     */
    private static class CheckpointingEvaluator
    extends ACRFEvaluator {
        private final File directory;
        private final int interval;
        private final Pipe tokPipe;
        private final Pipe featurePipe;

        public CheckpointingEvaluator(File directory, int interval, Pipe tokPipe, Pipe featurePipe) {
            this.directory = directory;
            this.interval = interval;
            this.tokPipe = tokPipe;
            this.featurePipe = featurePipe;
        }

        /**
         * Writes {@code extor.<iter>.ser.gz} when {@code iter} is a positive
         * multiple of the interval.
         *
         * @return always {@code true} (presumably signalling that training should
         *         continue; confirm against {@code ACRFEvaluator})
         */
        @Override
        public boolean evaluate(ACRF acrf, int iter, InstanceList training, InstanceList validation, InstanceList testing) {
            if (iter > 0 && iter % this.interval == 0) {
                ACRFExtractor extor = new ACRFExtractor(acrf, this.tokPipe, this.featurePipe);
                FileUtils.writeGzippedObject(new File(this.directory, "extor." + iter + ".ser.gz"), extor);
            }
            return true;
        }

        /** Intentionally a no-op: this evaluator only writes checkpoints. */
        @Override
        public void test(InstanceList gold, List returned, String description) {
        }
    }
}

