/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.dvparser.DVModelReranker;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.parser.lexparser.RerankingParserQuery;
import edu.stanford.nlp.trees.DeepTree;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.Scored;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedWriter;
import java.io.FileFilter;
import java.io.FileWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public class FindNearestNeighbors {
    private static Redwood.RedwoodChannels log = Redwood.channels(FindNearestNeighbors.class);
    static final int numNeighbors = 5;
    static final int maxLength = 8;

    public static void main(String[] args) throws Exception {
        String modelPath = null;
        String outputPath = null;
        String testTreebankPath = null;
        FileFilter testTreebankFilter = null;
        ArrayList<String> unusedArgs = new ArrayList<String>();
        int argIndex = 0;
        while (argIndex < args.length) {
            if (args[argIndex].equalsIgnoreCase("-model")) {
                modelPath = args[argIndex + 1];
                argIndex += 2;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-testTreebank")) {
                Pair<String, FileFilter> treebankDescription = ArgUtils.getTreebankDescription(args, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(args, argIndex) + 1;
                testTreebankPath = treebankDescription.first();
                testTreebankFilter = treebankDescription.second();
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-output")) {
                outputPath = args[argIndex + 1];
                argIndex += 2;
                continue;
            }
            unusedArgs.add(args[argIndex++]);
        }
        if (modelPath == null) {
            throw new IllegalArgumentException("Need to specify -model");
        }
        if (testTreebankPath == null) {
            throw new IllegalArgumentException("Need to specify -testTreebank");
        }
        if (outputPath == null) {
            throw new IllegalArgumentException("Need to specify -output");
        }
        String[] newArgs = unusedArgs.toArray(new String[unusedArgs.size()]);
        LexicalizedParser lexparser = LexicalizedParser.loadModel(modelPath, newArgs);
        MemoryTreebank testTreebank = null;
        if (testTreebankPath != null) {
            log.info("Reading in trees from " + testTreebankPath);
            if (testTreebankFilter != null) {
                log.info("Filtering on " + testTreebankFilter);
            }
            testTreebank = lexparser.getOp().tlpParams.memoryTreebank();
            testTreebank.loadPath(testTreebankPath, testTreebankFilter);
            log.info("Read in " + ((Treebank)testTreebank).size() + " trees for testing");
        }
        FileWriter out2 = new FileWriter(outputPath);
        BufferedWriter bout = new BufferedWriter(out2);
        log.info("Parsing " + ((Treebank)testTreebank).size() + " trees");
        int count = 0;
        ArrayList<ParseRecord> records = Generics.newArrayList();
        for (Object goldTree : testTreebank) {
            ArrayList<Word> tokens = ((Tree)goldTree).yieldWords();
            ParserQuery parserQuery = lexparser.parserQuery();
            if (!parserQuery.parse(tokens)) {
                throw new AssertionError((Object)("Could not parse: " + tokens));
            }
            if (!(parserQuery instanceof RerankingParserQuery)) {
                throw new IllegalArgumentException("Expected a LexicalizedParser with a Reranker attached");
            }
            RerankingParserQuery rerankingParserQuery = (RerankingParserQuery)parserQuery;
            if (!(rerankingParserQuery.rerankerQuery() instanceof DVModelReranker.Query)) {
                throw new IllegalArgumentException("Expected a LexicalizedParser with a DVModel attached");
            }
            DeepTree tree = ((DVModelReranker.Query)rerankingParserQuery.rerankerQuery()).getDeepTrees().get(0);
            SimpleMatrix rootVector = null;
            for (Map.Entry<Tree, SimpleMatrix> entry : tree.getVectors().entrySet()) {
                if (!entry.getKey().label().value().equals("ROOT")) continue;
                rootVector = entry.getValue();
                break;
            }
            if (rootVector == null) {
                throw new AssertionError((Object)"Could not find root nodevector");
            }
            out2.write(tokens + "\n");
            out2.write(tree.getTree() + "\n");
            for (int i = 0; i < rootVector.getNumElements(); ++i) {
                out2.write("  " + rootVector.get(i));
            }
            out2.write("\n\n\n");
            if (++count % 10 == 0) {
                log.info("  " + count);
            }
            records.add(new ParseRecord(tokens, (Tree)goldTree, tree.getTree(), rootVector, tree.getVectors()));
        }
        log.info("  done parsing");
        ArrayList<Pair<Tree, SimpleMatrix>> subtrees = Generics.newArrayList();
        for (ParseRecord record : records) {
            for (Map.Entry entry : record.nodeVectors.entrySet()) {
                if (((Tree)entry.getKey()).getLeaves().size() > 8) continue;
                subtrees.add(Pair.makePair((Tree)entry.getKey(), (SimpleMatrix)entry.getValue()));
            }
        }
        log.info("There are " + subtrees.size() + " subtrees in the set of trees");
        PriorityQueue<Scored> bestmatches = new PriorityQueue<Scored>(101, ScoredComparator.DESCENDING_COMPARATOR);
        for (int i = 0; i < subtrees.size(); ++i) {
            log.info(((Tree)((Pair)subtrees.get(i)).first()).yieldWords());
            log.info(((Pair)subtrees.get(i)).first());
            for (int j = 0; j < subtrees.size(); ++j) {
                if (i == j) continue;
                double d = ((SimpleMatrix)((SimpleMatrix)((Pair)subtrees.get(i)).second()).minus((SimpleBase)((SimpleMatrix)((Pair)subtrees.get(j)).second()))).normF();
                bestmatches.add(new ScoredObject<Pair<Tree, Tree>>(Pair.makePair((Tree)((Pair)subtrees.get(i)).first(), (Tree)((Pair)subtrees.get(j)).first()), d));
                if (bestmatches.size() <= 100) continue;
                bestmatches.poll();
            }
            ArrayList<ScoredObject> ordered = Generics.newArrayList();
            while (bestmatches.size() > 0) {
                ordered.add((ScoredObject)bestmatches.poll());
            }
            Collections.reverse(ordered);
            for (ScoredObject pair : ordered) {
                log.info(" MATCHED " + ((Tree)((Pair)pair.object()).second).yieldWords() + " ... " + ((Pair)pair.object()).second() + " with a score of " + pair.score());
            }
            log.info(new Object[0]);
            log.info(new Object[0]);
            bestmatches.clear();
        }
        bout.flush();
        out2.flush();
        out2.close();
    }

    public static class ParseRecord {
        final List<Word> sentence;
        final Tree goldTree;
        final Tree parse;
        final SimpleMatrix rootVector;
        final IdentityHashMap<Tree, SimpleMatrix> nodeVectors;

        public ParseRecord(List<Word> sentence, Tree goldTree, Tree parse, SimpleMatrix rootVector, IdentityHashMap<Tree, SimpleMatrix> nodeVectors) {
            this.sentence = sentence;
            this.goldTree = goldTree;
            this.parse = parse;
            this.rootVector = rootVector;
            this.nodeVectors = nodeVectors;
        }
    }
}

