/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.WeightedDataset;
import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.parser.lexparser.ChineseTreebankParserParams;
import edu.stanford.nlp.parser.lexparser.ChineseWordFeatureExtractor;
import edu.stanford.nlp.parser.lexparser.IntTaggedWord;
import edu.stanford.nlp.parser.lexparser.Lexicon;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreeAnnotator;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.lexparser.UnknownWordModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.Distribution;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.util.CollectionValuedMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.FileFilter;
import java.io.IOException;
import java.io.Writer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Pattern;

/**
 * A {@link Lexicon} for Chinese that scores (word, tag) pairs with a maximum-entropy
 * {@link LinearClassifier} over features produced by {@link ChineseWordFeatureExtractor}.
 *
 * <p>Training collects (word, tag) counts; {@link #finishTraining()} turns them into a
 * {@link WeightedDataset}, applies per-feature-class count thresholds, and fits the classifier
 * with a {@link QNMinimizer}. At scoring time the classifier's log P(tag | word) has the log of
 * a Laplace-smoothed tag prior subtracted from it; by Bayes' rule this differs from
 * log P(word | tag) only by a per-word constant.
 *
 * <p>Not thread-safe: the scores of the most recently queried word are cached in mutable
 * (transient) instance state.
 */
public class ChineseMaxentLexicon
implements Lexicon {
    private static final Redwood.RedwoodChannels log = Redwood.channels(ChineseMaxentLexicon.class);
    private static final long serialVersionUID = 238834703409896852L;
    private static final boolean verbose = true;
    public static final boolean seenTagsOnly = false;
    private ChineseWordFeatureExtractor featExtractor;
    public static final boolean fixUnkFunctionWords = false;
    // Feature-name patterns used to group features into classes for count thresholding.
    private static final Pattern wordPattern = Pattern.compile(".*-W");
    private static final Pattern charPattern = Pattern.compile(".*-.C");
    private static final Pattern bigramPattern = Pattern.compile(".*-.B");
    private static final Pattern conjPattern = Pattern.compile(".*&&.*");
    // Minimum feature counts per feature class; a class is thresholded only when its value is > 0.
    private final Pair<Pattern, Integer> wordThreshold = new Pair<Pattern, Integer>(wordPattern, 0);
    private final Pair<Pattern, Integer> charThreshold = new Pair<Pattern, Integer>(charPattern, 2);
    private final Pair<Pattern, Integer> bigramThreshold = new Pair<Pattern, Integer>(bigramPattern, 3);
    private final Pair<Pattern, Integer> conjThreshold = new Pair<Pattern, Integer>(conjPattern, 3);
    private final List<Pair<Pattern, Integer>> featureThresholds = new ArrayList<Pair<Pattern, Integer>>();
    private final int universalThreshold = 0;
    private LinearClassifier scorer;
    // Words whose tag is fixed; such words bypass the classifier entirely (see ensureProbs).
    private Map<String, String> functionWordTags = Generics.newHashMap();
    private Distribution<String> tagDist;
    private final Index<String> wordIndex;
    private final Index<String> tagIndex;
    // Cached per-tag log scores for lastWord; null until the first ensureProbs call
    // (and again after deserialization, since the field is transient).
    private transient Counter<String> logProbs;
    // Tags scoring more than this far (in log space) below the best tag are pruned.
    private double iteratorCutoffFactor = 4.0;
    private transient int lastWord = -1;
    // Whether the cached logProbs had the tag prior subtracted; part of the cache key.
    private transient boolean lastSubtractTagScore;
    // Counters.max(logProbs), cached alongside logProbs.
    private transient double lastMaxLogProb;
    String initialWeightFile = null;
    boolean trainFloat = false;
    private static final String featureDir = "gbfeatures";
    private double tol = 1.0E-4;
    private double sigma = 0.4;
    static final boolean tuneSigma = false;
    static final int trainCountThreshold = 5;
    final int featureLevel;
    static final int DEFAULT_FEATURE_LEVEL = 2;
    private boolean trainOnLowCount = false;
    private boolean trainByType = false;
    private final TreebankLangParserParams tlpParams;
    private final TreebankLanguagePack ctlp;
    private final Options op;
    public CollectionValuedMap<String, String> tagsForWord = new CollectionValuedMap<String, String>();
    transient IntCounter<TaggedWord> datumCounter;

    /** Returns whether the word with index {@code word} was seen during training. */
    @Override
    public boolean isKnown(int word) {
        return this.isKnown(this.wordIndex.get(word));
    }

    /** Returns whether {@code word} was seen during training. */
    @Override
    public boolean isKnown(String word) {
        return this.tagsForWord.containsKey(word);
    }

    /**
     * Returns the set of all tags in the tag index after mapping each through
     * {@code basicCategoryFunction}.
     */
    @Override
    public Set<String> tagSet(Function<String, String> basicCategoryFunction) {
        HashSet<String> tagSet = new HashSet<String>();
        for (String tag : this.tagIndex.objectsList()) {
            tagSet.add(basicCategoryFunction.apply(tag));
        }
        return tagSet;
    }

    /** Ensures {@link #logProbs} holds prior-subtracted scores for {@code word}. */
    private void ensureProbs(int word) {
        this.ensureProbs(word, true);
    }

    /**
     * Makes {@link #logProbs} (and {@link #lastMaxLogProb}) hold the tag scores for
     * {@code word}. It recomputes only when the cached scores belong to another word or were
     * computed with a different {@code subtractTagScore}.
     *
     * <p>Words listed in {@link #functionWordTags} get score 0 for every tag whose basic
     * category equals the fixed tag and negative infinity for all other tags. Every other word
     * is scored by the classifier. When {@code subtractTagScore} is true, the log probability
     * of each tag under the smoothed tag distribution is subtracted.
     *
     * @param word index of the word in {@link #wordIndex}
     * @param subtractTagScore whether to subtract the log tag prior from classifier scores
     */
    private void ensureProbs(int word, boolean subtractTagScore) {
        // The null check also guards a freshly deserialized instance, where the transient
        // lastWord defaults to 0 rather than -1 and would otherwise give a false cache hit.
        if (this.logProbs != null && word == this.lastWord && subtractTagScore == this.lastSubtractTagScore) {
            return;
        }
        String wordString = this.wordIndex.get(word);
        Counter<String> scores;
        if (this.functionWordTags.containsKey(wordString)) {
            scores = this.fixedTagScores(this.functionWordTags.get(wordString));
        } else {
            BasicDatum datum = new BasicDatum(this.featExtractor.makeFeatures(wordString));
            scores = this.scorer.logProbabilityOf(datum);
            if (subtractTagScore) {
                for (String tag : scores.keySet()) {
                    scores.incrementCount(tag, -Math.log(this.tagDist.probabilityOf(tag)));
                }
            }
        }
        // Only update the cache key once scoring succeeded, so a failure cannot leave a
        // stale entry that later calls would treat as valid.
        this.logProbs = scores;
        this.lastMaxLogProb = Counters.max(scores);
        this.lastWord = word;
        this.lastSubtractTagScore = subtractTagScore;
    }

    /**
     * Builds scores that allow only the tags whose basic category equals {@code trueTag}:
     * 0.0 for those, negative infinity for all other tags in the tag index.
     */
    private Counter<String> fixedTagScores(String trueTag) {
        Counter<String> scores = new ClassicCounter<String>();
        for (String tag : this.tagIndex.objectsList()) {
            boolean allowed = this.ctlp.basicCategory(tag).equals(trueTag);
            scores.setCount(tag, allowed ? 0.0 : Double.NEGATIVE_INFINITY);
        }
        return scores;
    }

    /**
     * Returns the (word, tag) pairs whose score is within {@link #iteratorCutoffFactor} of the
     * best-scoring tag for {@code word}. The bound is strict, so tags exactly at the cutoff are
     * excluded.
     */
    @Override
    public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
        this.ensureProbs(word);
        double cutoff = this.lastMaxLogProb - this.iteratorCutoffFactor;
        List<IntTaggedWord> rules = new ArrayList<IntTaggedWord>();
        for (int tag = 0; tag < this.tagIndex.size(); ++tag) {
            double score = this.logProbs.getCount(this.tagIndex.get(tag));
            if (score > cutoff) {
                rules.add(new IntTaggedWord(word, tag));
            }
        }
        return rules.iterator();
    }

    /**
     * String form of {@link #ruleIteratorByWord(int, int, String)}. Unseen words are added to
     * the word index (as {@link #getTag} does), because the feature-based classifier can score
     * any word. Previously an unseen word yielded index -1 and crashed.
     */
    @Override
    public Iterator<IntTaggedWord> ruleIteratorByWord(String word, int loc, String featureSpec) {
        return this.ruleIteratorByWord(this.wordIndex.addToIndex(word), loc, featureSpec);
    }

    /** Counts the rules surviving the cutoff over every word in the word index. */
    @Override
    public int numRules() {
        int accumulated = 0;
        int tot = this.wordIndex.size();
        for (int w = 0; w < tot; ++w) {
            Iterator<IntTaggedWord> iter = this.ruleIteratorByWord(w, 0, null);
            while (iter.hasNext()) {
                iter.next();
                ++accumulated;
            }
        }
        return accumulated;
    }

    /**
     * Returns the tag that maximizes the classifier score for {@code word}, without the tag
     * prior subtracted. The word is added to the word index if it is not already there.
     */
    private String getTag(String word) {
        int iW = this.wordIndex.addToIndex(word);
        this.ensureProbs(iW, false);
        return Counters.argmax(this.logProbs);
    }

    private void verbose(String s) {
        log.info(s);
    }

    /**
     * @param op parser options; supplies the treebank language parameters and language pack
     * @param wordIndex shared word index (words may be added to it during scoring)
     * @param tagIndex shared tag index
     * @param featureLevel feature level passed to {@link ChineseWordFeatureExtractor}
     */
    public ChineseMaxentLexicon(Options op, Index<String> wordIndex, Index<String> tagIndex, int featureLevel) {
        this.op = op;
        this.tlpParams = op.tlpParams;
        this.ctlp = op.tlpParams.treebankLanguagePack();
        this.wordIndex = wordIndex;
        this.tagIndex = tagIndex;
        this.featureLevel = featureLevel;
    }

    /**
     * Prepares for a training pass: creates the feature extractor if needed and resets the
     * (word, tag) count table. The {@code numTrees} argument is not used.
     */
    @Override
    public void initializeTraining(double numTrees) {
        this.verbose("Training ChineseMaxentLexicon.");
        this.verbose("trainOnLowCount = " + this.trainOnLowCount + ", trainByType = " + this.trainByType + ", featureLevel = " + this.featureLevel + ", tuneSigma = " + tuneSigma);
        this.verbose("Making dataset...");
        if (this.featExtractor == null) {
            this.featExtractor = new ChineseWordFeatureExtractor(this.featureLevel);
        }
        this.datumCounter = new IntCounter<TaggedWord>();
    }

    /** Throws a descriptive exception if {@link #initializeTraining} has not been called. */
    private void checkTrainingStarted() {
        if (this.datumCounter == null || this.featExtractor == null) {
            throw new IllegalStateException("initializeTraining() must be called before training ChineseMaxentLexicon");
        }
    }

    @Override
    public final void train(Collection<Tree> trees) {
        this.train(trees, 1.0);
    }

    @Override
    public void train(Collection<Tree> trees, double weight) {
        for (Tree tree : trees) {
            this.train(tree, weight);
        }
    }

    @Override
    public void train(Tree tree, double weight) {
        this.train((List<TaggedWord>)tree.taggedYield(), weight);
    }

    /**
     * Feeds one tagged sentence to the feature extractor and records each (word, tag) pair,
     * both in the weighted count table and in {@link #tagsForWord}.
     *
     * @throws IllegalStateException if {@link #initializeTraining} was not called first
     */
    @Override
    public void train(List<TaggedWord> sentence, double weight) {
        this.checkTrainingStarted();
        this.featExtractor.train(sentence, weight);
        for (TaggedWord word : sentence) {
            this.datumCounter.incrementCount(word, weight);
            this.tagsForWord.add(word.word(), word.tag());
        }
    }

    @Override
    public void trainUnannotated(List<TaggedWord> sentence, double weight) {
        throw new UnsupportedOperationException("This version of the parser does not support non-tree training data");
    }

    @Override
    public void incrementTreesRead(double weight) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void train(TaggedWord tw, int loc, double weight) {
        throw new UnsupportedOperationException();
    }

    /**
     * Builds the training dataset from the accumulated (word, tag) counts, estimates the
     * smoothed tag prior, applies feature-count thresholds and trains the classifier. The
     * count table is released afterwards.
     *
     * <p>Function words are skipped. When {@code trainOnLowCount} is set, pairs seen more than
     * {@link #trainCountThreshold} times are skipped. When {@code trainByType} is set, every
     * kept pair gets weight 1.
     *
     * @throws IllegalStateException if {@link #initializeTraining} was not called first
     */
    @Override
    public void finishTraining() {
        this.checkTrainingStarted();
        IntCounter<String> tagCounter = new IntCounter<String>();
        WeightedDataset<String, String> data = new WeightedDataset<String, String>(this.datumCounter.size());
        for (TaggedWord word : this.datumCounter.keySet()) {
            int count = this.datumCounter.getIntCount(word);
            boolean tooFrequent = this.trainOnLowCount && count > trainCountThreshold;
            if (tooFrequent || this.functionWordTags.containsKey(word.word())) {
                continue;
            }
            // The tag prior is estimated over (word, tag) types, not tokens.
            tagCounter.incrementCount(word.tag());
            if (this.trainByType) {
                count = 1;
            }
            data.add(new BasicDatum<String, String>(this.featExtractor.makeFeatures(word.word()), word.tag()), (float)count);
        }
        this.datumCounter = null;
        this.tagDist = Distribution.laplaceSmoothedDistribution(tagCounter, tagCounter.size(), 0.5);
        this.applyThresholds(data);
        this.verbose("Making classifier...");
        QNMinimizer minim = new QNMinimizer();
        LinearClassifierFactory factory = new LinearClassifierFactory(minim);
        factory.setTol(this.tol);
        factory.setSigma(this.sigma);
        this.scorer = factory.trainClassifier((GeneralDataset)data);
        this.verbose("Done training.");
    }

    /**
     * Applies the enabled per-feature-class count thresholds to {@code data} and logs how many
     * feature types were removed.
     */
    private void applyThresholds(WeightedDataset data) {
        // Rebuild from scratch so repeated training runs do not accumulate duplicate entries.
        this.featureThresholds.clear();
        this.addThresholdIfEnabled(true, this.wordThreshold);
        this.addThresholdIfEnabled(this.featExtractor.chars, this.charThreshold);
        this.addThresholdIfEnabled(this.featExtractor.bigrams, this.bigramThreshold);
        this.addThresholdIfEnabled(this.featExtractor.conjunctions || this.featExtractor.mildConjunctions, this.conjThreshold);
        int typesBefore = data.numFeatureTypes();
        if (!this.featureThresholds.isEmpty()) {
            data.applyFeatureCountThreshold(this.featureThresholds);
        }
        int numRemoved = typesBefore - data.numFeatureTypes();
        if (numRemoved > 0) {
            this.verbose("Thresholding removed " + numRemoved + " features.");
        }
    }

    /** Adds {@code threshold} when its feature class is enabled and its count is positive. */
    private void addThresholdIfEnabled(boolean enabled, Pair<Pattern, Integer> threshold) {
        if (enabled && threshold.second > 0) {
            this.featureThresholds.add(threshold);
        }
    }

    /**
     * Trains on one file range of a Chinese treebank and reports tagging accuracy on another.
     *
     * <p>Arguments: {@code treebankPath trainRanges testRanges [featureLevel]}. The ranges are
     * number-range specifications for {@link NumberRangesFileFilter}.
     */
    public static void main(String[] args) {
        if (args.length < 3) {
            throw new IllegalArgumentException("Usage: ChineseMaxentLexicon treebankPath trainRanges testRanges [featureLevel]");
        }
        ChineseTreebankParserParams tlpParams = new ChineseTreebankParserParams();
        Options op = new Options(tlpParams);
        TreeAnnotator ta = new TreeAnnotator(tlpParams.headFinder(), tlpParams, op);
        log.info("Reading Trees...");
        NumberRangesFileFilter trainFilter = new NumberRangesFileFilter(args[1], true);
        MemoryTreebank trainTreebank = tlpParams.memoryTreebank();
        trainTreebank.loadPath(args[0], (FileFilter)trainFilter);
        log.info("Annotating trees...");
        List<Tree> trainTrees = new ArrayList<Tree>();
        for (Tree tree : trainTreebank) {
            trainTrees.add(ta.transformTree(tree));
        }
        trainTreebank = null; // allow the raw treebank to be garbage-collected
        log.info("Training lexicon...");
        HashIndex<String> wordIndex = new HashIndex<String>();
        HashIndex<String> tagIndex = new HashIndex<String>();
        int featureLevel = DEFAULT_FEATURE_LEVEL;
        if (args.length > 3) {
            featureLevel = Integer.parseInt(args[3]);
        }
        ChineseMaxentLexicon lex = new ChineseMaxentLexicon(op, wordIndex, tagIndex, featureLevel);
        lex.initializeTraining(trainTrees.size());
        lex.train(trainTrees);
        lex.finishTraining();
        log.info("Testing");
        NumberRangesFileFilter testFilter = new NumberRangesFileFilter(args[2], true);
        MemoryTreebank testTreebank = tlpParams.memoryTreebank();
        testTreebank.loadPath(args[0], (FileFilter)testFilter);
        List<TaggedWord> testWords = new ArrayList<TaggedWord>();
        for (Tree t : testTreebank) {
            for (TaggedWord tw : t.taggedYield()) {
                testWords.add(tw);
            }
        }
        int[] totalAndCorrect = lex.testOnTreebank(testWords);
        log.info("done.");
        System.out.println(totalAndCorrect[1] + " correct out of " + totalAndCorrect[0] + " -- ACC: " + (double)totalAndCorrect[1] / (double)totalAndCorrect[0]);
    }

    /**
     * Tags each word independently and compares the basic category of the guess with the
     * gold tag.
     *
     * @return a two-element array: {total words, correctly tagged words}
     */
    private int[] testOnTreebank(Collection<TaggedWord> testWords) {
        int total = 0;
        int correct = 0;
        for (TaggedWord word : testWords) {
            String guessTag = this.ctlp.basicCategory(this.getTag(word.word()));
            ++total;
            if (word.tag().equals(guessTag)) {
                ++correct;
            }
        }
        return new int[]{total, correct};
    }

    /**
     * Returns the (prior-subtracted) log score of {@code iTW}. Tags not within
     * {@link #iteratorCutoffFactor} of the best tag get negative infinity, consistent with
     * {@link #ruleIteratorByWord(int, int, String)}.
     */
    @Override
    public float score(IntTaggedWord iTW, int loc, String word, String featureSpec) {
        this.ensureProbs(iTW.word());
        double score = this.logProbs.getCount(iTW.tagString(this.tagIndex));
        if (score > this.lastMaxLogProb - this.iteratorCutoffFactor) {
            return (float)score;
        }
        return Float.NEGATIVE_INFINITY;
    }

    @Override
    public void writeData(Writer w) throws IOException {
        throw new UnsupportedOperationException();
    }

    @Override
    public void readData(BufferedReader in) throws IOException {
        throw new UnsupportedOperationException();
    }

    /** This lexicon has no separate unknown-word model; always returns null. */
    @Override
    public UnknownWordModel getUnknownWordModel() {
        return null;
    }

    /** Ignored: this lexicon has no separate unknown-word model. */
    @Override
    public void setUnknownWordModel(UnknownWordModel uwm) {
    }

    /** Trains on {@code trees}; {@code rawTrees} is not used. */
    @Override
    public void train(Collection<Tree> trees, Collection<Tree> rawTrees) {
        this.train(trees);
    }
}

