/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.benchmarks;

import edu.stanford.nlp.loglinear.benchmarks.CoNLLBenchmark;
import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.ConcatVectorNamespace;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Stack;

/**
 * Benchmarks repeated clique-tree inference on CoNLL sentence models while simulated
 * "human observation" factors are pushed onto and popped off the model.
 *
 * <p>{@link #main} loads the CoNLL train/testa/testb splits from {@link #DATA_PATH}, builds one
 * {@link GraphicalModel} per training sentence, generates random feature vectors and weights
 * from a fixed seed, warms up the JIT on the first few models, and then times the same
 * workload again.
 */
public class GamePlayerBenchmark {
    private static final Redwood.RedwoodChannels log = Redwood.channels(GamePlayerBenchmark.class);
    static final String DATA_PATH = "/u/nlp/data/ner/conll/";

    /** Seed for the benchmark's random number generator, so runs are reproducible. */
    private static final long RANDOM_SEED = 10L;
    /** Number of components in every randomly generated {@link ConcatVector}. */
    private static final int NUM_FEATURES = 5;
    /** Length of each dense component, and the exclusive bound for sparse component indices. */
    private static final int FEATURE_LENGTH = 30;
    /** Size of the pool of random feature vectors used by the simulated observation factors. */
    private static final int NUM_HUMAN_FEATURE_VECTORS = 1000;
    /** Number of sentence models played, both during warm-up and during the timed run. */
    private static final int NUM_GAMES = 10;
    /** Number of random rollouts ("samples") taken per sentence model. */
    private static final int SAMPLES_PER_GAME = 1000;
    /** Number of observation moves pushed in each rollout. */
    private static final int MOVES_PER_SAMPLE = 10;

    /**
     * Runs the benchmark end to end and logs timings.
     *
     * @param args ignored
     * @throws IOException if the CoNLL data or embeddings cannot be read
     * @throws ClassNotFoundException propagated from loading the serialized embeddings
     */
    public static void main(String[] args) throws IOException, ClassNotFoundException {
        CoNLLBenchmark coNLL = new CoNLLBenchmark();
        List<CoNLLBenchmark.CoNLLSentence> train = coNLL.getSentences(DATA_PATH + "conll.iob.4class.train");
        List<CoNLLBenchmark.CoNLLSentence> testA = coNLL.getSentences(DATA_PATH + "conll.iob.4class.testa");
        List<CoNLLBenchmark.CoNLLSentence> testB = coNLL.getSentences(DATA_PATH + "conll.iob.4class.testb");
        List<CoNLLBenchmark.CoNLLSentence> allData = new ArrayList<>();
        allData.addAll(train);
        allData.addAll(testA);
        allData.addAll(testB);

        List<String> tags = collectTags(allData);
        coNLL.embeddings = coNLL.getEmbeddings(DATA_PATH + "google-300-trimmed.ser.gz", allData);

        log.info("Making the training set...");
        ConcatVectorNamespace namespace = new ConcatVectorNamespace();
        GraphicalModel[] trainingSet = buildTrainingSet(coNLL, namespace, train, tags);

        // Order matters: feature vectors are drawn before weights from the same seeded stream.
        Random r = new Random(RANDOM_SEED);
        ConcatVector[] humanFeatureVectors = randomHumanFeatureVectors(r, NUM_HUMAN_FEATURE_VECTORS);
        ConcatVector weights = randomWeights(r);

        // Guard against a training set smaller than the number of games we want to play.
        int numGames = Math.min(NUM_GAMES, trainingSet.length);

        log.info("Warming up the JIT...");
        playGames(r, trainingSet, numGames, weights, humanFeatureVectors);

        log.info("Timing actual run...");
        long start = System.nanoTime();
        playGames(r, trainingSet, numGames, weights, humanFeatureVectors);
        long duration = nanosToMillis(System.nanoTime() - start);
        log.info("Duration: " + duration);
    }

    /** Returns every distinct NER tag that appears in {@code sentences}. */
    private static List<String> collectTags(List<CoNLLBenchmark.CoNLLSentence> sentences) {
        HashSet<String> tagSet = new HashSet<>();
        for (CoNLLBenchmark.CoNLLSentence sentence : sentences) {
            for (String nerTag : sentence.ner) {
                tagSet.add(nerTag);
            }
        }
        return new ArrayList<>(tagSet);
    }

    /** Builds one sentence model per sentence, logging progress every 10 sentences. */
    private static GraphicalModel[] buildTrainingSet(CoNLLBenchmark coNLL, ConcatVectorNamespace namespace,
                                                     List<CoNLLBenchmark.CoNLLSentence> sentences, List<String> tags) {
        int size = sentences.size();
        GraphicalModel[] models = new GraphicalModel[size];
        for (int i = 0; i < size; ++i) {
            if (i % 10 == 0) {
                log.info(i + "/" + size);
            }
            models[i] = coNLL.generateSentenceModel(namespace, sentences.get(i), tags);
        }
        return models;
    }

    /**
     * Generates {@code count} random vectors. Each component is sparse (one random index and
     * value) or dense (FEATURE_LENGTH random values), with equal probability.
     */
    private static ConcatVector[] randomHumanFeatureVectors(Random r, int count) {
        ConcatVector[] vectors = new ConcatVector[count];
        for (int i = 0; i < count; ++i) {
            vectors[i] = new ConcatVector(NUM_FEATURES);
            for (int component = 0; component < NUM_FEATURES; ++component) {
                if (r.nextBoolean()) {
                    vectors[i].setSparseComponent(component, r.nextInt(FEATURE_LENGTH), r.nextDouble());
                } else {
                    vectors[i].setDenseComponent(component, randomDense(r));
                }
            }
        }
        return vectors;
    }

    /** Generates a weight vector whose components are all dense and random. */
    private static ConcatVector randomWeights(Random r) {
        ConcatVector weights = new ConcatVector(NUM_FEATURES);
        for (int component = 0; component < NUM_FEATURES; ++component) {
            weights.setDenseComponent(component, randomDense(r));
        }
        return weights;
    }

    /** Returns a FEATURE_LENGTH array of values drawn from {@code r.nextDouble()}. */
    private static double[] randomDense(Random r) {
        double[] dense = new double[FEATURE_LENGTH];
        for (int k = 0; k < dense.length; ++k) {
            dense[k] = r.nextDouble();
        }
        return dense;
    }

    /** Plays {@link #gameplay} on the first {@code numGames} models, logging each game index. */
    private static void playGames(Random r, GraphicalModel[] models, int numGames, ConcatVector weights,
                                  ConcatVector[] humanFeatureVectors) {
        for (int i = 0; i < numGames; ++i) {
            log.info(i);
            gameplay(r, models[i], weights, humanFeatureVectors);
        }
    }

    /** Converts a {@link System#nanoTime()} interval to whole milliseconds. */
    private static long nanosToMillis(long nanos) {
        return nanos / 1_000_000L;
    }

    /**
     * Simulates SAMPLES_PER_GAME rollouts on one sentence model.
     *
     * <p>Each rollout takes MOVES_PER_SAMPLE moves. A move pushes a {@link SampleState}, which
     * adds its observation factor and marks its variable as observed. It then computes singleton
     * marginals for the modified model. Marginals are cached on the rollout tree nodes, so only
     * the first visit to a node is timed. After each rollout every pushed state is popped, so
     * the model returns to its original factor set.
     */
    private static void gameplay(Random r, GraphicalModel model, ConcatVector weights, ConcatVector[] humanFeatureVectors) {
        // Insertion-ordered so variable indices line up with the original first-seen scan order;
        // the first factor mentioning a variable determines its recorded size.
        Map<Integer, Integer> sizeByVariable = new LinkedHashMap<>();
        for (GraphicalModel.Factor f : model.factors) {
            int[] dimensions = f.featuresTable.getDimensions();
            for (int k = 0; k < f.neigborIndices.length; ++k) {
                sizeByVariable.putIfAbsent(f.neigborIndices[k], dimensions[k]);
            }
        }
        int[] variables = sizeByVariable.keySet().stream().mapToInt(Integer::intValue).toArray();
        int[] variableSizes = sizeByVariable.values().stream().mapToInt(Integer::intValue).toArray();

        List<SampleState> childrenOfRoot = new ArrayList<>();
        CliqueTree tree = new CliqueTree(model, weights);
        int initialFactors = model.factors.size();

        // nanoTime is monotonic and fine-grained; summing millisecond wall-clock deltas of
        // sub-millisecond calls would truncate most of them to zero.
        long start = System.nanoTime();
        long marginalsNanos = 0L;
        int marginalCount = 0;
        for (int sample = 0; sample < SAMPLES_PER_GAME; ++sample) {
            log.info("\tTaking sample " + sample);
            Deque<SampleState> pushed = new ArrayDeque<>();
            SampleState state = selectOrCreateChildAtRandom(r, model, variables, variableSizes, childrenOfRoot, humanFeatureVectors);
            long localMarginalsNanos = 0L;
            for (int move = 0; move < MOVES_PER_SAMPLE; ++move) {
                state.push(model);
                assert model.factors.size() == initialFactors + move + 1;
                // This is the operation being benchmarked; revisited tree nodes reuse their cache.
                if (state.cachedMarginal == null) {
                    long s = System.nanoTime();
                    state.cachedMarginal = tree.calculateMarginalsJustSingletons();
                    localMarginalsNanos += System.nanoTime() - s;
                    marginalCount++;
                }
                pushed.push(state);
                state = selectOrCreateChildAtRandom(r, model, variables, variableSizes, state.children, humanFeatureVectors);
            }
            log.info("\t\t" + nanosToMillis(localMarginalsNanos) + " ms");
            marginalsNanos += localMarginalsNanos;

            // Undo moves in LIFO order to restore the model.
            while (!pushed.isEmpty()) {
                pushed.pop().pop(model);
            }
            assert model.factors.size() == initialFactors;
        }
        log.info("Marginals time: " + nanosToMillis(marginalsNanos) + " ms");
        // Average over the marginals actually computed (cache misses), not a fixed constant.
        if (marginalCount > 0) {
            log.info("Avg time per marginal: " + (marginalsNanos / 1e6 / marginalCount) + " ms");
        } else {
            log.info("Avg time per marginal: n/a (no marginals computed)");
        }
        log.info("Total time: " + nanosToMillis(System.nanoTime() - start));
    }

    /**
     * Draws a random (variable, observation) pair. Returns the matching child from
     * {@code children} if one exists; otherwise creates it, appends it to {@code children}, and
     * returns it.
     *
     * <p>A new child gets a factor connecting {@code variable} to a fresh variable index. That
     * index is one past the largest neighbor index currently in {@code model.factors}. The
     * factor is removed from the model immediately and is only re-added when the child is pushed.
     */
    private static SampleState selectOrCreateChildAtRandom(Random r, GraphicalModel model, int[] variables, int[] variableSizes, List<SampleState> children, ConcatVector[] humanFeatureVectors) {
        int index = r.nextInt(variables.length);
        int variable = variables[index];
        int variableSize = variableSizes[index];
        int observation = r.nextInt(variableSize);
        for (SampleState child : children) {
            if (child.variable == variable && child.observation == observation) {
                return child;
            }
        }

        int humanObservationVariable = 0;
        for (GraphicalModel.Factor f : model.factors) {
            for (int neighbor : f.neigborIndices) {
                humanObservationVariable = Math.max(humanObservationVariable, neighbor + 1);
            }
        }

        GraphicalModel.Factor factor = model.addFactor(
                new int[]{variable, humanObservationVariable},
                new int[]{variableSize, variableSize},
                assn -> {
                    // Wrap so variables with many values cannot index past the feature pool.
                    int flat = assn[0] * variableSize + assn[1];
                    return humanFeatureVectors[flat % humanFeatureVectors.length];
                });
        // Keep the factor out of the model until the owning SampleState is pushed.
        model.factors.remove(factor);

        SampleState created = new SampleState(factor, variable, observation);
        children.add(created);
        return created;
    }

    /**
     * One node in the rollout tree: an observation factor and the observed value for one variable.
     * Its children are the states reachable by one further move, and it holds the singleton
     * marginals cached the first time it was pushed.
     */
    public static class SampleState {
        /** Variable-metadata key marking an observed value; appears to mirror CliqueTree's constant (TODO confirm). */
        private static final String OBSERVED_VALUE_KEY = "inference.CliqueTree.VARIABLE_OBSERVED_VALUE";

        public GraphicalModel.Factor addedFactor;
        public int variable;
        public int observation;
        public List<SampleState> children = new ArrayList<>();
        public double[][] cachedMarginal = null;

        public SampleState(GraphicalModel.Factor addedFactor, int variable, int observation) {
            this.addedFactor = addedFactor;
            this.variable = variable;
            this.observation = observation;
        }

        /** Adds this state's factor to {@code model} and records {@code observation} for {@code variable}. */
        public void push(GraphicalModel model) {
            assert !model.factors.contains(addedFactor);
            model.factors.add(addedFactor);
            model.getVariableMetaDataByReference(variable).put(OBSERVED_VALUE_KEY, Integer.toString(observation));
        }

        /**
         * Removes this state's factor from {@code model} and clears the observation on {@code variable}.
         * The observation is deleted rather than restored. That is fine for the benchmark, which
         * pops every state of a rollout before running inference again.
         */
        public void pop(GraphicalModel model) {
            assert model.factors.contains(addedFactor);
            model.factors.remove(addedFactor);
            model.getVariableMetaDataByReference(variable).remove(OBSERVED_VALUE_KEY);
        }
    }
}

