/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.parser.dvparser.CacheParseHypotheses;
import edu.stanford.nlp.parser.lexparser.BinaryGrammar;
import edu.stanford.nlp.parser.lexparser.BinaryRule;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.UnaryGrammar;
import edu.stanford.nlp.parser.lexparser.UnaryRule;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.PrintStream;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.regex.Pattern;
import org.ejml.data.DenseMatrix64F;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public class DVModel
implements Serializable {
    private static Redwood.RedwoodChannels log = Redwood.channels(DVModel.class);
    // Binary-rule transforms keyed by (left child basic category, right child basic category).
    // When created by addRandomBinaryMatrix, each is numRows x (2 * numCols + 1), or
    // numRows x (4 * numCols + 1) when op.trainOptions.useContextWords is set.
    public TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform;
    // Unary-rule transforms keyed by the child's basic category. When created by
    // addRandomUnaryMatrix, each is numRows x (numCols + 1), or numRows x (3 * numCols + 1)
    // with context words.
    public Map<String, SimpleMatrix> unaryTransform;
    // Binary-rule score vectors (1 x numCols when randomly initialized), same keys as binaryTransform.
    public TwoDimensionalMap<String, String, SimpleMatrix> binaryScore;
    // Unary-rule score vectors (1 x numCols when randomly initialized), same keys as unaryTransform.
    public Map<String, SimpleMatrix> unaryScore;
    // Word -> embedding. Keys are words after op.wordFunction (if set), plus special keys:
    // UNKNOWN_WORD always, and the other UNKNOWN_* / START_WORD / END_WORD keys depending on options.
    public Map<String, SimpleMatrix> wordVectors;
    // Matrix counts and the flattened element count of a single matrix of each kind;
    // totalParamSize() and the *Index() methods use these to describe the flattened parameter vector.
    int numBinaryMatrices;
    int numUnaryMatrices;
    int binaryTransformSize;
    int unaryTransformSize;
    int binaryScoreSize;
    int unaryScoreSize;
    Options op;
    // Both are set from op.lexOptions.numHid, which readWordVectors fills in from the vector
    // file when it is not configured (<= 0).
    final int numCols;
    final int numRows;
    // numRows x numRows identity used by the random initializers; transient, so readObject
    // rebuilds it after deserialization.
    transient SimpleMatrix identity;
    // Source of randomness for initialization, seeded from op.trainOptions.randomSeed.
    Random rand;
    // Special keys in wordVectors. getVocabWord maps unseen words onto the UNKNOWN_* keys;
    // START_WORD / END_WORD are only added when op.trainOptions.useContextWords is set.
    static final String UNKNOWN_WORD = "*UNK*";
    static final String UNKNOWN_NUMBER = "*NUM*";
    static final String UNKNOWN_CAPS = "*CAPS*";
    static final String UNKNOWN_CHINESE_YEAR = "*ZH_YEAR*";
    static final String UNKNOWN_CHINESE_NUMBER = "*ZH_NUM*";
    static final String UNKNOWN_CHINESE_PERCENT = "*ZH_PERCENT*";
    static final String START_WORD = "*START*";
    static final String END_WORD = "*END*";
    // Converters between EJML SimpleMatrix and its backing DenseMatrix64F; not referenced in the
    // portion of this file reviewed here.
    private static final Function<SimpleMatrix, DenseMatrix64F> convertSimpleMatrix = matrix -> matrix.getMatrix();
    private static final Function<DenseMatrix64F, SimpleMatrix> convertDenseMatrix = matrix -> SimpleMatrix.wrap((DenseMatrix64F)matrix);
    // Unknown-word class patterns. readWordVectors averages the vectors of matching words into the
    // class vectors; getVocabWord maps unseen matching words onto those classes.
    // Optional minus, a digit, then digits and number punctuation.
    static final Pattern NUMBER_PATTERN = Pattern.compile("-?[0-9][-0-9,.:]*");
    // ASCII letters only, with at least one upper-case letter.
    static final Pattern CAPS_PATTERN = Pattern.compile("[a-zA-Z]*[A-Z][a-zA-Z]*");
    // Four Chinese or full-width digits followed by the Chinese year character.
    static final Pattern CHINESE_YEAR_PATTERN = Pattern.compile("[\u3007\u96f6\u4e00\u4e8c\u4e09\u56db\u4e94\u516d\u4e03\u516b\u4e5d\uff10\uff11\uff12\uff13\uff14\uff15\uff16\uff17\uff18\uff19]{4}+\u5e74");
    // Runs of Chinese/full-width digits and magnitude characters (ten, hundred, thousand, ...),
    // each optionally followed by a "point" or "more" character.
    static final Pattern CHINESE_NUMBER_PATTERN = Pattern.compile("(?:[\u3007\uff10\u96f6\u4e00\u4e8c\u4e09\u56db\u4e94\u516d\u4e03\u516b\u4e5d\uff10\uff11\uff12\uff13\uff14\uff15\uff16\uff17\uff18\uff19\u5341\u767e\u4e07\u5343\u4ebf]+[\u70b9\u591a]?)+");
    // The Chinese "percent of" prefix followed by digit / ten / point characters.
    static final Pattern CHINESE_PERCENT_PATTERN = Pattern.compile("\u767e\u5206\u4e4b[\u3007\uff10\u96f6\u4e00\u4e8c\u4e09\u56db\u4e94\u516d\u4e03\u516b\u4e5d\uff10\uff11\uff12\uff13\uff14\uff15\uff16\uff17\uff18\uff19\u5341\u70b9]+");
    // Any word containing "DG" (presumably a digit placeholder used by some vector files; confirm).
    // Only readWordVectors uses it, folding such words into the number and Chinese-number averages.
    static final Pattern DG_PATTERN = Pattern.compile(".*DG.*");
    private static final long serialVersionUID = 1L;

    /**
     * Java serialization hook. Restores every non-transient field through the default mechanism,
     * then recreates the transient {@code identity} matrix, which is not part of the stream.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        identity = SimpleMatrix.identity(numRows);
    }

    /**
     * Builds a model with randomly initialized parameters for the rules of the given grammars.
     *
     * <p>Word vectors are loaded first via {@link #readWordVectors()}. Then one unary
     * transform/score pair is created per distinct child basic category, and one binary pair per
     * distinct (left, right) basic-category combination.
     *
     * @param op parser/training options; also supplies the random seed
     * @param stateIndex maps grammar state ids to state labels
     * @param unaryGrammar unary rules whose child categories receive unary matrices
     * @param binaryGrammar binary rules whose child categories receive binary matrices
     */
    public DVModel(Options op, Index<String> stateIndex, UnaryGrammar unaryGrammar, BinaryGrammar binaryGrammar) {
        this.op = op;
        rand = new Random(op.trainOptions.randomSeed);
        // Must happen before numRows/numCols are read: readWordVectors may fill in
        // op.lexOptions.numHid from the vector file when it was left unset (<= 0).
        readWordVectors();

        final int hidden = op.lexOptions.numHid;
        numRows = hidden;
        numCols = hidden;

        binaryTransform = TwoDimensionalMap.treeMap();
        unaryTransform = Generics.newTreeMap();
        binaryScore = TwoDimensionalMap.treeMap();
        unaryScore = Generics.newTreeMap();
        numBinaryMatrices = 0;
        numUnaryMatrices = 0;

        // Flattened sizes: binary = [left | right | extra column], unary = [child | extra column].
        binaryTransformSize = numRows * (2 * numCols + 1);
        unaryTransformSize = numRows * (numCols + 1);
        binaryScoreSize = numCols;
        unaryScoreSize = numCols;
        if (op.trainOptions.useContextWords) {
            // Two additional numRows x numCols blocks for the surrounding context words.
            int contextSize = 2 * numRows * numCols;
            binaryTransformSize += contextSize;
            unaryTransformSize += contextSize;
        }

        // The random initializers below rely on the identity matrix.
        identity = SimpleMatrix.identity(numRows);

        for (UnaryRule rule : unaryGrammar) {
            addRandomUnaryMatrix(basicCategory(stateIndex.get(rule.child)));
        }
        for (BinaryRule rule : binaryGrammar) {
            String left = basicCategory(stateIndex.get(rule.leftChild));
            String right = basicCategory(stateIndex.get(rule.rightChild));
            addRandomBinaryMatrix(left, right);
        }
    }

    /**
     * Wraps existing parameter maps as a model. The maps are stored as given, not copied.
     *
     * <p>Per-matrix sizes are read from the first entry of each map, so every matrix of a kind is
     * expected to share its shape. When a kind has no matrices, its sizes are 0.
     *
     * @param binaryTransform binary transforms keyed by (left, right) basic category
     * @param unaryTransform unary transforms keyed by child basic category
     * @param binaryScore binary score vectors, same keys as {@code binaryTransform}
     * @param unaryScore unary score vectors, same keys as {@code unaryTransform}
     * @param wordVectors word -> embedding
     * @param op options supplying numHid and the random seed
     */
    public DVModel(TwoDimensionalMap<String, String, SimpleMatrix> binaryTransform, Map<String, SimpleMatrix> unaryTransform, TwoDimensionalMap<String, String, SimpleMatrix> binaryScore, Map<String, SimpleMatrix> unaryScore, Map<String, SimpleMatrix> wordVectors, Options op) {
        this.op = op;
        this.binaryTransform = binaryTransform;
        this.unaryTransform = unaryTransform;
        this.binaryScore = binaryScore;
        this.unaryScore = unaryScore;
        this.wordVectors = wordVectors;

        numBinaryMatrices = binaryTransform.size();
        numUnaryMatrices = unaryTransform.size();

        if (numBinaryMatrices == 0) {
            binaryTransformSize = 0;
            binaryScoreSize = 0;
        } else {
            binaryTransformSize = binaryTransform.iterator().next().getValue().getNumElements();
            binaryScoreSize = binaryScore.iterator().next().getValue().getNumElements();
        }

        if (numUnaryMatrices == 0) {
            unaryTransformSize = 0;
            unaryScoreSize = 0;
        } else {
            unaryTransformSize = unaryTransform.values().iterator().next().getNumElements();
            unaryScoreSize = unaryScore.values().iterator().next().getNumElements();
        }

        numRows = op.lexOptions.numHid;
        numCols = op.lexOptions.numHid;
        identity = SimpleMatrix.identity(numRows);
        rand = new Random(op.trainOptions.randomSeed);
    }

    /**
     * Creates the numRows x (2 * numCols) context-word block of a transform. It holds two copies
     * of the identity scaled by {@code scalingForInit * 0.1}, side by side, plus uniform noise in
     * +-1/sqrt(100 * numCols).
     */
    private SimpleMatrix randomContextMatrix() {
        SimpleMatrix scaledIdentity = identity.scale(op.trainOptions.scalingForInit * 0.1);
        SimpleMatrix context = new SimpleMatrix(numRows, numCols * 2);
        context.insertIntoThis(0, 0, scaledIdentity);
        context.insertIntoThis(0, numCols, scaledIdentity);

        double bound = 1.0 / Math.sqrt(numCols * 100.0);
        SimpleMatrix noise = SimpleMatrix.random(numRows, numCols * 2, -bound, bound, rand);
        return context.plus(noise);
    }

    /**
     * Creates one numRows x numCols transform block according to
     * {@code op.trainOptions.transformMatrixType}:
     * <ul>
     *   <li>DIAGONAL: identity plus uniform noise in +-1/sqrt(100 * numCols)</li>
     *   <li>RANDOM: uniform noise in +-1/sqrt(numCols), no identity</li>
     *   <li>OFF_DIAGONAL: as DIAGONAL, then numCols random cells shifted by -1, 0 or +1</li>
     *   <li>RANDOM_ZEROS: as DIAGONAL, then numCols random cells set to 0</li>
     * </ul>
     *
     * @throws IllegalArgumentException for any other initialization type
     */
    private SimpleMatrix randomTransformMatrix() {
        double nearIdentityBound = 1.0 / Math.sqrt(numCols * 100.0);
        SimpleMatrix result;
        switch (op.trainOptions.transformMatrixType) {
            case DIAGONAL:
                result = SimpleMatrix.random(numRows, numCols, -nearIdentityBound, nearIdentityBound, rand).plus(identity);
                break;
            case RANDOM: {
                double bound = 1.0 / Math.sqrt(numCols);
                result = SimpleMatrix.random(numRows, numCols, -bound, bound, rand);
                break;
            }
            case OFF_DIAGONAL:
                result = SimpleMatrix.random(numRows, numCols, -nearIdentityBound, nearIdentityBound, rand).plus(identity);
                for (int n = 0; n < numCols; n++) {
                    int row = rand.nextInt(numCols);
                    int col = rand.nextInt(numCols);
                    int delta = rand.nextInt(3) - 1;
                    result.set(row, col, result.get(row, col) + delta);
                }
                break;
            case RANDOM_ZEROS:
                result = SimpleMatrix.random(numRows, numCols, -nearIdentityBound, nearIdentityBound, rand).plus(identity);
                for (int n = 0; n < numCols; n++) {
                    int row = rand.nextInt(numCols);
                    int col = rand.nextInt(numCols);
                    result.set(row, col, 0.0);
                }
                break;
            default:
                throw new IllegalArgumentException("Unexpected matrix initialization type " + op.trainOptions.transformMatrixType);
        }
        return result;
    }

    /**
     * Adds randomly initialized unary parameters for {@code childBasic}, unless a unary transform
     * already exists for it.
     *
     * <p>The score is a 1 x numCols uniform vector in +-1/sqrt(numCols). The transform has the
     * child block in columns [0, numCols), an extra column numCols left at zero and, with context
     * words, a context block starting at column numCols + 1. Both are multiplied by
     * {@code scalingForInit}.
     *
     * @param childBasic basic category of the unary rule's child
     */
    public void addRandomUnaryMatrix(String childBasic) {
        if (unaryTransform.get(childBasic) != null) {
            return;
        }
        numUnaryMatrices++;

        double scoreBound = 1.0 / Math.sqrt(numCols);
        SimpleMatrix score = SimpleMatrix.random(1, numCols, -scoreBound, scoreBound, rand);
        unaryScore.put(childBasic, score.scale(op.trainOptions.scalingForInit));

        boolean withContext = op.trainOptions.useContextWords;
        SimpleMatrix transform = new SimpleMatrix(numRows, (withContext ? numCols * 3 : numCols) + 1);
        if (withContext) {
            transform.insertIntoThis(0, numCols + 1, randomContextMatrix());
        }
        transform.insertIntoThis(0, 0, randomTransformMatrix());
        unaryTransform.put(childBasic, transform.scale(op.trainOptions.scalingForInit));
    }

    /**
     * Adds randomly initialized binary parameters for the (leftBasic, rightBasic) pair, unless a
     * binary transform already exists for it.
     *
     * <p>The score is a 1 x numCols uniform vector in +-1/sqrt(numCols). The transform holds the
     * left block in columns [0, numCols) and the right block in [numCols, 2 * numCols). Column
     * 2 * numCols is left at zero. With context words, a context block starts at column
     * 2 * numCols + 1. Both are multiplied by {@code scalingForInit}.
     *
     * @param leftBasic basic category of the left child
     * @param rightBasic basic category of the right child
     */
    public void addRandomBinaryMatrix(String leftBasic, String rightBasic) {
        if (binaryTransform.get(leftBasic, rightBasic) != null) {
            return;
        }
        numBinaryMatrices++;

        double scoreBound = 1.0 / Math.sqrt(numCols);
        SimpleMatrix score = SimpleMatrix.random(1, numCols, -scoreBound, scoreBound, rand);
        binaryScore.put(leftBasic, rightBasic, score.scale(op.trainOptions.scalingForInit));

        boolean withContext = op.trainOptions.useContextWords;
        SimpleMatrix transform = new SimpleMatrix(numRows, (withContext ? numCols * 4 : numCols * 2) + 1);
        if (withContext) {
            transform.insertIntoThis(0, numCols * 2 + 1, randomContextMatrix());
        }
        SimpleMatrix leftBlock = randomTransformMatrix();
        SimpleMatrix rightBlock = randomTransformMatrix();
        transform.insertIntoThis(0, 0, leftBlock);
        transform.insertIntoThis(0, numCols, rightBlock);
        binaryTransform.put(leftBasic, rightBasic, transform.scale(op.trainOptions.scalingForInit));
    }

    /**
     * Ensures the model has parameters for every rule and word used by the training data, then
     * filters the model down to exactly those.
     *
     * <p>Rules and words are gathered from each gold tree and from the hypothesis trees that
     * {@code CacheParseHypotheses.convertToTrees} decodes from its {@code compressedTrees} entry.
     * Rules without parameters get freshly randomized matrices before filtering.
     *
     * @param sentences the gold training trees
     * @param compressedTrees encoded hypotheses keyed by gold tree
     */
    public void setRulesForTrainingSet(List<Tree> sentences, Map<Tree, byte[]> compressedTrees) {
        TwoDimensionalSet<String, String> binaryRules = TwoDimensionalSet.treeSet();
        Set<String> unaryRules = new HashSet<>();
        Set<String> words = new HashSet<>();
        for (Tree tree : sentences) {
            searchRulesForBatch(binaryRules, unaryRules, words, tree);
            // NOTE(review): assumes every sentence has an entry in compressedTrees; what
            // convertToTrees does with a missing (null) entry is not visible here. Confirm.
            for (Tree hypothesis : CacheParseHypotheses.convertToTrees(compressedTrees.get(tree))) {
                searchRulesForBatch(binaryRules, unaryRules, words, hypothesis);
            }
        }
        // The set yields Pair<String, String>; iterate it typed instead of as a raw Pair with casts.
        for (Pair<String, String> rule : binaryRules) {
            addRandomBinaryMatrix(rule.first(), rule.second());
        }
        for (String child : unaryRules) {
            addRandomUnaryMatrix(child);
        }
        filterRulesForBatch(binaryRules, unaryRules, words);
    }

    /**
     * Restricts the model to the rules and words that occur in {@code trees}.
     *
     * @param trees binarized trees whose rules and words should be kept
     */
    public void filterRulesForBatch(Collection<Tree> trees) {
        TwoDimensionalSet<String, String> binaryRules = TwoDimensionalSet.treeSet();
        Set<String> unaryRules = new HashSet<>();
        Set<String> words = new HashSet<>();
        trees.forEach(tree -> searchRulesForBatch(binaryRules, unaryRules, words, tree));
        filterRulesForBatch(binaryRules, unaryRules, words);
    }

    /**
     * Restricts the model to the rules and words that occur in the gold trees (the map keys) and
     * in the hypothesis trees decoded from each value by
     * {@code CacheParseHypotheses.convertToTrees}.
     *
     * @param compressedTrees encoded hypotheses keyed by gold tree
     */
    public void filterRulesForBatch(Map<Tree, byte[]> compressedTrees) {
        TwoDimensionalSet<String, String> binaryRules = TwoDimensionalSet.treeSet();
        Set<String> unaryRules = new HashSet<>();
        Set<String> words = new HashSet<>();
        compressedTrees.forEach((gold, encoded) -> {
            searchRulesForBatch(binaryRules, unaryRules, words, gold);
            for (Tree hypothesis : CacheParseHypotheses.convertToTrees(encoded)) {
                searchRulesForBatch(binaryRules, unaryRules, words, hypothesis);
            }
        });
        filterRulesForBatch(binaryRules, unaryRules, words);
    }

    /**
     * Replaces the parameter maps with copies that contain only the given rules and words, and
     * updates the matrix counts to match. Rules or words without parameters are skipped.
     *
     * <p>NOTE(review): only keys in {@code words} survive. START_WORD/END_WORD are not produced by
     * searchRulesForBatch, so with useContextWords those vectors appear to be dropped here.
     * Confirm whether callers rely on this.
     *
     * @param binaryRules (left, right) basic-category pairs to keep
     * @param unaryRules child basic categories to keep
     * @param words wordVectors keys to keep
     * @throws AssertionError if a rule has a transform but no score, or a score but no transform
     */
    public void filterRulesForBatch(TwoDimensionalSet<String, String> binaryRules, Set<String> unaryRules, Set<String> words) {
        // Value type must be SimpleMatrix: these maps replace the SimpleMatrix-typed fields below.
        TwoDimensionalMap<String, String, SimpleMatrix> newBinaryTransforms = TwoDimensionalMap.treeMap();
        TwoDimensionalMap<String, String, SimpleMatrix> newBinaryScores = TwoDimensionalMap.treeMap();
        for (Pair<String, String> rule : binaryRules) {
            String left = rule.first();
            String right = rule.second();
            SimpleMatrix transform = binaryTransform.get(left, right);
            SimpleMatrix score = binaryScore.get(left, right);
            if ((transform == null) != (score == null)) {
                throw new AssertionError("Binary rule " + left + ":" + right + " has a transform or a score but not both");
            }
            if (transform != null) {
                newBinaryTransforms.put(left, right, transform);
                newBinaryScores.put(left, right, score);
            }
        }
        binaryTransform = newBinaryTransforms;
        binaryScore = newBinaryScores;
        numBinaryMatrices = binaryTransform.size();

        Map<String, SimpleMatrix> newUnaryTransforms = Generics.newTreeMap();
        Map<String, SimpleMatrix> newUnaryScores = Generics.newTreeMap();
        for (String child : unaryRules) {
            SimpleMatrix transform = unaryTransform.get(child);
            SimpleMatrix score = unaryScore.get(child);
            if ((transform == null) != (score == null)) {
                throw new AssertionError("Unary rule " + child + " has a transform or a score but not both");
            }
            if (transform != null) {
                newUnaryTransforms.put(child, transform);
                newUnaryScores.put(child, score);
            }
        }
        unaryTransform = newUnaryTransforms;
        unaryScore = newUnaryScores;
        numUnaryMatrices = unaryTransform.size();

        Map<String, SimpleMatrix> newWordVectors = Generics.newTreeMap();
        for (String word : words) {
            SimpleMatrix vector = wordVectors.get(word);
            if (vector != null) {
                newWordVectors.put(word, vector);
            }
        }
        wordVectors = newWordVectors;
    }

    /**
     * Walks a binarized tree and collects what the model needs for it. Child basic categories of
     * unary nodes go into {@code unaryRules}. (left, right) child basic categories of binary nodes
     * go into {@code binaryRules}. The vocabulary key of each preterminal's word goes into
     * {@code words}.
     *
     * @throws AssertionError if an internal node has more than two children
     */
    private void searchRulesForBatch(TwoDimensionalSet<String, String> binaryRules, Set<String> unaryRules, Set<String> words, Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        Tree[] kids = tree.children();
        if (tree.isPreTerminal()) {
            words.add(getVocabWord(kids[0].value()));
            return;
        }
        switch (kids.length) {
            case 1:
                unaryRules.add(basicCategory(kids[0].value()));
                break;
            case 2:
                binaryRules.add(basicCategory(kids[0].value()), basicCategory(kids[1].value()));
                break;
            default:
                throw new AssertionError("Expected a binarized tree");
        }
        for (Tree kid : kids) {
            searchRulesForBatch(binaryRules, unaryRules, words, kid);
        }
    }

    /**
     * Returns the category key used to index the rule matrices.
     *
     * <p>In the simplified model every category maps to the empty string, so all rules share
     * parameters. Otherwise this is the language pack's basic category with one leading '@'
     * removed (presumably the binarizer's marker for intermediate nodes).
     *
     * @param category a full tree label
     * @return the key used in the transform/score maps
     */
    public String basicCategory(String category) {
        if (op.trainOptions.dvSimplifiedModel) {
            return "";
        }
        String basic = op.langpack().basicCategory(category);
        return basic.startsWith("@") ? basic.substring(1) : basic;
    }

    /**
     * Loads the embeddings from {@code op.lexOptions.wordVectorFile} into {@code wordVectors} and
     * adds the special entries.
     *
     * <p>Each file word is stored under its {@code op.wordFunction} form (if set). If
     * {@code op.lexOptions.numHid} is not positive, it is set from the vector length. For each
     * enabled unknown-word class (number, caps, Chinese year/number/percent), the vectors of
     * matching words are averaged into that class's special entry. A class with no matches gets
     * a copy of the unknown-word vector. With context words, random START/END vectors are added.
     *
     * @throws RuntimeException if the file has no vector for {@code op.trainOptions.unkWord}
     */
    public void readWordVectors() {
        wordVectors = Generics.newTreeMap();

        // Running sums (null until the first match) and match counts, one per unknown-word class.
        SimpleMatrix numberSum = null;
        SimpleMatrix capsSum = null;
        SimpleMatrix chineseYearSum = null;
        SimpleMatrix chineseNumberSum = null;
        SimpleMatrix chinesePercentSum = null;
        int numberCount = 0;
        int capsCount = 0;
        int chineseYearCount = 0;
        int chineseNumberCount = 0;
        int chinesePercentCount = 0;

        Embedding rawWordVectors = new Embedding(op.lexOptions.wordVectorFile, op.lexOptions.numHid);
        for (String rawWord : rawWordVectors.keySet()) {
            SimpleMatrix vector = rawWordVectors.get(rawWord);
            String word = (op.wordFunction == null) ? rawWord : op.wordFunction.apply(rawWord);
            wordVectors.put(word, vector);
            if (op.lexOptions.numHid <= 0) {
                op.lexOptions.numHid = vector.getNumElements();
            }
            if (op.trainOptions.unknownNumberVector
                    && (NUMBER_PATTERN.matcher(word).matches() || DG_PATTERN.matcher(word).matches())) {
                numberCount++;
                numberSum = addToVectorSum(numberSum, vector);
            }
            if (op.trainOptions.unknownCapsVector && CAPS_PATTERN.matcher(word).matches()) {
                capsCount++;
                capsSum = addToVectorSum(capsSum, vector);
            }
            if (op.trainOptions.unknownChineseYearVector && CHINESE_YEAR_PATTERN.matcher(word).matches()) {
                chineseYearCount++;
                chineseYearSum = addToVectorSum(chineseYearSum, vector);
            }
            if (op.trainOptions.unknownChineseNumberVector
                    && (CHINESE_NUMBER_PATTERN.matcher(word).matches() || DG_PATTERN.matcher(word).matches())) {
                chineseNumberCount++;
                chineseNumberSum = addToVectorSum(chineseNumberSum, vector);
            }
            if (op.trainOptions.unknownChinesePercentVector && CHINESE_PERCENT_PATTERN.matcher(word).matches()) {
                chinesePercentCount++;
                chinesePercentSum = addToVectorSum(chinesePercentSum, vector);
            }
        }

        String unkWord = op.trainOptions.unkWord;
        if (op.wordFunction != null) {
            unkWord = op.wordFunction.apply(unkWord);
        }
        SimpleMatrix unknownWordVector = wordVectors.get(unkWord);
        // Validate before storing so a failed load does not leave a null UNKNOWN_WORD entry behind.
        if (unknownWordVector == null) {
            throw new RuntimeException("Unknown word vector not specified in the word vector file");
        }
        wordVectors.put(UNKNOWN_WORD, unknownWordVector);

        if (op.trainOptions.unknownNumberVector) {
            wordVectors.put(UNKNOWN_NUMBER, averageOrFallback(numberSum, numberCount, unknownWordVector));
        }
        if (op.trainOptions.unknownCapsVector) {
            wordVectors.put(UNKNOWN_CAPS, averageOrFallback(capsSum, capsCount, unknownWordVector));
        }
        if (op.trainOptions.unknownChineseYearVector) {
            log.info("Matched " + chineseYearCount + " chinese year vectors");
            wordVectors.put(UNKNOWN_CHINESE_YEAR, averageOrFallback(chineseYearSum, chineseYearCount, unknownWordVector));
        }
        if (op.trainOptions.unknownChineseNumberVector) {
            log.info("Matched " + chineseNumberCount + " chinese number vectors");
            wordVectors.put(UNKNOWN_CHINESE_NUMBER, averageOrFallback(chineseNumberSum, chineseNumberCount, unknownWordVector));
        }
        if (op.trainOptions.unknownChinesePercentVector) {
            log.info("Matched " + chinesePercentCount + " chinese percent vectors");
            wordVectors.put(UNKNOWN_CHINESE_PERCENT, averageOrFallback(chinesePercentSum, chinesePercentCount, unknownWordVector));
        }
        if (op.trainOptions.useContextWords) {
            // START is drawn before END, which keeps the RNG sequence stable.
            int dim = op.lexOptions.numHid;
            wordVectors.put(START_WORD, SimpleMatrix.random(dim, 1, -0.5, 0.5, rand));
            wordVectors.put(END_WORD, SimpleMatrix.random(dim, 1, -0.5, 0.5, rand));
        }
    }

    /** Returns {@code sum + vector}, or a copy of {@code vector} when no sum has been started yet. */
    private static SimpleMatrix addToVectorSum(SimpleMatrix sum, SimpleMatrix vector) {
        return sum == null ? new SimpleMatrix(vector) : sum.plus(vector);
    }

    /** Returns {@code sum / count}, or a copy of {@code fallback} when nothing was accumulated. */
    private static SimpleMatrix averageOrFallback(SimpleMatrix sum, int count, SimpleMatrix fallback) {
        return count > 0 ? sum.divide(count) : new SimpleMatrix(fallback);
    }

    /**
     * Returns the length of the flattened parameter vector: all binary and unary transforms and
     * scores, plus every word vector when {@code trainWordVectors} is set.
     */
    public int totalParamSize() {
        int size = numBinaryMatrices * (binaryTransformSize + binaryScoreSize)
                + numUnaryMatrices * (unaryTransformSize + unaryScoreSize);
        if (op.trainOptions.trainWordVectors) {
            size += wordVectors.size() * op.lexOptions.numHid;
        }
        return size;
    }

    /**
     * Flattens the parameters, multiplied by {@code scale}, into one array. The order is binary
     * transforms, unary transforms, binary scores, unary scores, then word vectors when
     * {@code trainWordVectors} is set.
     *
     * @param scale factor applied to every parameter
     * @return an array of length {@link #totalParamSize()}
     */
    public double[] paramsToVector(double scale) {
        int totalSize = totalParamSize();
        if (!op.trainOptions.trainWordVectors) {
            return NeuralUtils.paramsToVector(scale, totalSize, binaryTransform.valueIterator(), unaryTransform.values().iterator(), binaryScore.valueIterator(), unaryScore.values().iterator());
        }
        return NeuralUtils.paramsToVector(scale, totalSize, binaryTransform.valueIterator(), unaryTransform.values().iterator(), binaryScore.valueIterator(), unaryScore.values().iterator(), wordVectors.values().iterator());
    }

    /**
     * Flattens the parameters into one array. The order is binary transforms, unary transforms,
     * binary scores, unary scores, then word vectors when {@code trainWordVectors} is set.
     *
     * @return an array of length {@link #totalParamSize()}
     */
    public double[] paramsToVector() {
        int totalSize = totalParamSize();
        if (!op.trainOptions.trainWordVectors) {
            return NeuralUtils.paramsToVector(totalSize, binaryTransform.valueIterator(), unaryTransform.values().iterator(), binaryScore.valueIterator(), unaryScore.values().iterator());
        }
        return NeuralUtils.paramsToVector(totalSize, binaryTransform.valueIterator(), unaryTransform.values().iterator(), binaryScore.valueIterator(), unaryScore.values().iterator(), wordVectors.values().iterator());
    }

    /**
     * Writes the values of {@code theta} back into the parameter matrices, in the same order
     * that {@link #paramsToVector()} reads them.
     *
     * @param theta flattened parameters
     */
    public void vectorToParams(double[] theta) {
        if (!op.trainOptions.trainWordVectors) {
            NeuralUtils.vectorToParams(theta, binaryTransform.valueIterator(), unaryTransform.values().iterator(), binaryScore.valueIterator(), unaryScore.values().iterator());
            return;
        }
        NeuralUtils.vectorToParams(theta, binaryTransform.valueIterator(), unaryTransform.values().iterator(), binaryScore.valueIterator(), unaryScore.values().iterator(), wordVectors.values().iterator());
    }

    /**
     * Returns the transform matrix for {@code node}, selected by its children's basic categories,
     * or null if the model has no matrix for them.
     *
     * @throws AssertionError if the node has neither one nor two children
     */
    public SimpleMatrix getWForNode(Tree node) {
        Tree[] children = node.children();
        switch (children.length) {
            case 1:
                return unaryTransform.get(basicCategory(children[0].value()));
            case 2:
                return binaryTransform.get(basicCategory(children[0].value()), basicCategory(children[1].value()));
            default:
                throw new AssertionError("Should only have unary or binary nodes");
        }
    }

    /**
     * Returns the score vector for {@code node}, selected by its children's basic categories, or
     * null if the model has no score for them.
     *
     * @throws AssertionError if the node has neither one nor two children
     */
    public SimpleMatrix getScoreWForNode(Tree node) {
        Tree[] children = node.children();
        switch (children.length) {
            case 1:
                return unaryScore.get(basicCategory(children[0].value()));
            case 2:
                return binaryScore.get(basicCategory(children[0].value()), basicCategory(children[1].value()));
            default:
                throw new AssertionError("Should only have unary or binary nodes");
        }
    }

    /** Returns the vector stored under START_WORD, or null when none was added. */
    public SimpleMatrix getStartWordVector() {
        SimpleMatrix start = wordVectors.get(START_WORD);
        return start;
    }

    /** Returns the vector stored under END_WORD, or null when none was added. */
    public SimpleMatrix getEndWordVector() {
        SimpleMatrix end = wordVectors.get(END_WORD);
        return end;
    }

    /** Returns the vector for {@code word} after mapping it through {@link #getVocabWord(String)}. */
    public SimpleMatrix getWordVector(String word) {
        String key = getVocabWord(word);
        return wordVectors.get(key);
    }

    /**
     * Maps a surface word to the key under which its vector is stored in {@code wordVectors}.
     *
     * <p>The word first goes through {@code op.wordFunction} (if set) and is lowercased if
     * {@code lowercaseWordVectors} is set. A known word maps to itself. Otherwise the enabled
     * unknown-word classes are tried in order: number, caps, Chinese year, Chinese number,
     * Chinese percent. Then, with {@code unknownDashedWordVectors}, the piece after the last '-'
     * is looked up recursively. Anything else maps to UNKNOWN_WORD.
     *
     * @param word the surface word
     * @return a key into {@code wordVectors}
     */
    public String getVocabWord(String word) {
        if (op.wordFunction != null) {
            word = op.wordFunction.apply(word);
        }
        if (op.trainOptions.lowercaseWordVectors) {
            // Locale.ROOT: default-locale lowercasing maps 'I' to a dotless i under a Turkish
            // locale, which would make vocabulary lookups depend on the JVM's locale.
            word = word.toLowerCase(Locale.ROOT);
        }
        if (wordVectors.containsKey(word)) {
            return word;
        }
        if (op.trainOptions.unknownNumberVector && NUMBER_PATTERN.matcher(word).matches()) {
            return UNKNOWN_NUMBER;
        }
        if (op.trainOptions.unknownCapsVector && CAPS_PATTERN.matcher(word).matches()) {
            return UNKNOWN_CAPS;
        }
        if (op.trainOptions.unknownChineseYearVector && CHINESE_YEAR_PATTERN.matcher(word).matches()) {
            return UNKNOWN_CHINESE_YEAR;
        }
        if (op.trainOptions.unknownChineseNumberVector && CHINESE_NUMBER_PATTERN.matcher(word).matches()) {
            return UNKNOWN_CHINESE_NUMBER;
        }
        if (op.trainOptions.unknownChinesePercentVector && CHINESE_PERCENT_PATTERN.matcher(word).matches()) {
            return UNKNOWN_CHINESE_PERCENT;
        }
        if (op.trainOptions.unknownDashedWordVectors) {
            int dash = word.lastIndexOf('-');
            if (dash >= 0) {
                // The recursive call applies op.wordFunction and lowercasing again to the suffix.
                String suffixKey = getVocabWord(word.substring(dash + 1));
                if (suffixKey != null) {
                    return suffixKey;
                }
            }
        }
        return UNKNOWN_WORD;
    }

    /** Returns the vector stored under UNKNOWN_WORD. */
    public SimpleMatrix getUnknownWordVector() {
        SimpleMatrix unknown = wordVectors.get(UNKNOWN_WORD);
        return unknown;
    }

    /**
     * Prints the key of every binary transform ("left:right") and every unary transform (child
     * category), each indented by two spaces under a section header.
     *
     * @param out2 destination stream
     */
    public void printMatrixNames(PrintStream out2) {
        out2.println("Binary matrices:");
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> binary : binaryTransform) {
            out2.println("  " + binary.getFirstKey() + ":" + binary.getSecondKey());
        }
        out2.println("Unary matrices:");
        unaryTransform.keySet().forEach(child -> out2.println("  " + child));
    }

    /**
     * Logs the matrix counts. Then, for each binary transform, prints the squared Frobenius norm
     * of the whole matrix, of its left numHid x numHid block and of its right block.
     *
     * <p>The count line goes to the class logger; everything else goes to {@code out2}.
     *
     * @param out2 destination stream for the per-matrix statistics
     */
    public void printMatrixStats(PrintStream out2) {
        log.info("Model loaded with " + numUnaryMatrices + " unary and " + numBinaryMatrices + " binary");
        int hid = op.lexOptions.numHid;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryTransform) {
            String left = entry.getFirstKey();
            String right = entry.getSecondKey();
            SimpleMatrix w = entry.getValue();
            out2.println("Binary transform " + left + ":" + right);

            double total = w.normF();
            out2.println("  Total norm " + total * total);

            double leftNorm = w.extractMatrix(0, hid, 0, hid).normF();
            out2.println("  Left norm (" + left + ") " + leftNorm * leftNorm);

            double rightNorm = w.extractMatrix(0, hid, hid, hid * 2).normF();
            out2.println("  Right norm (" + right + ") " + rightNorm * rightNorm);
        }
    }

    /**
     * Prints every parameter matrix with a header line. The order is binary transforms, binary
     * scores, unary transforms, then unary scores.
     *
     * @param out2 destination stream
     */
    public void printAllMatrices(PrintStream out2) {
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryTransform) {
            out2.println("Binary transform " + entry.getFirstKey() + ":" + entry.getSecondKey());
            out2.println(entry.getValue());
        }
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryScore) {
            out2.println("Binary score " + entry.getFirstKey() + ":" + entry.getSecondKey());
            out2.println(entry.getValue());
        }
        // Typed entries instead of raw Map.Entry plus (String) casts.
        for (Map.Entry<String, SimpleMatrix> entry : unaryTransform.entrySet()) {
            out2.println("Unary transform " + entry.getKey());
            out2.println(entry.getValue());
        }
        for (Map.Entry<String, SimpleMatrix> entry : unaryScore.entrySet()) {
            out2.println("Unary score " + entry.getKey());
            out2.println(entry.getValue());
        }
    }

    /**
     * Returns the offset of the (leftChild, rightChild) binary transform in the flattened
     * parameter vector, or -1 if there is no such transform. Binary transforms come first, in
     * map iteration order (presumably the order NeuralUtils.paramsToVector writes them; confirm).
     */
    public int binaryTransformIndex(String leftChild, String rightChild) {
        int offset = 0;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> candidate : binaryTransform) {
            boolean match = candidate.getFirstKey().equals(leftChild) && candidate.getSecondKey().equals(rightChild);
            if (match) {
                return offset;
            }
            offset += candidate.getValue().getNumElements();
        }
        return -1;
    }

    /**
     * Returns the offset of the unary transform for {@code child} in the flattened parameter
     * vector, or -1 if there is none. Unary transforms start right after all binary transforms.
     */
    public int unaryTransformIndex(String child) {
        int offset = numBinaryMatrices * binaryTransformSize;
        for (Map.Entry<String, SimpleMatrix> candidate : unaryTransform.entrySet()) {
            if (candidate.getKey().equals(child)) {
                return offset;
            }
            offset += candidate.getValue().getNumElements();
        }
        return -1;
    }

    /**
     * Finds where the binary score matrix for a (left, right) child pair starts in the flat
     * parameter vector.
     * <p>
     * Binary scores follow all binary transforms and all unary transforms. Each preceding
     * binary score matrix adds its own element count.
     *
     * @param leftChild  label of the left child
     * @param rightChild label of the right child
     * @return starting offset of the matching matrix, or -1 if no such matrix exists
     */
    public int binaryScoreIndex(String leftChild, String rightChild) {
        int binaryTransformBlock = this.numBinaryMatrices * this.binaryTransformSize;
        int unaryTransformBlock = this.numUnaryMatrices * this.unaryTransformSize;
        int offset = binaryTransformBlock + unaryTransformBlock;
        for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> candidate : this.binaryScore) {
            boolean matches = candidate.getFirstKey().equals(leftChild)
                    && candidate.getSecondKey().equals(rightChild);
            if (!matches) {
                offset += candidate.getValue().getNumElements();
                continue;
            }
            return offset;
        }
        return -1;
    }

    /**
     * Finds where the unary score matrix for {@code child} starts in the flat parameter vector.
     * <p>
     * Unary scores come last, after binary transforms, unary transforms and binary scores.
     * Each preceding unary score matrix adds its own element count.
     *
     * @param child label of the unary child
     * @return starting offset of the matching matrix, or -1 if no such matrix exists
     */
    public int unaryScoreIndex(String child) {
        int binaryBlock = this.numBinaryMatrices * (this.binaryTransformSize + this.binaryScoreSize);
        int unaryTransformBlock = this.numUnaryMatrices * this.unaryTransformSize;
        int offset = binaryBlock + unaryTransformBlock;
        for (Map.Entry<String, SimpleMatrix> candidate : this.unaryScore.entrySet()) {
            if (!candidate.getKey().equals(child)) {
                offset += candidate.getValue().getNumElements();
                continue;
            }
            return offset;
        }
        return -1;
    }

    /**
     * Maps a flat parameter index to the binary transform matrix that contains it.
     * <p>
     * Binary transforms occupy positions {@code [0, numBinaryMatrices * binaryTransformSize)}
     * of the parameter vector, in the iteration order of {@link #binaryTransform}. This method
     * assumes every binary transform has exactly {@code binaryTransformSize} elements.
     *
     * @param pos flat index into the parameter vector
     * @return the (left, right) child labels of the owning matrix, or {@code null} if
     *     {@code pos} is outside the binary transform range (including negative values)
     */
    public Pair<String, String> indexToBinaryTransform(int pos) {
        if (pos >= 0 && pos < this.numBinaryMatrices * this.binaryTransformSize) {
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.binaryTransform) {
                // This matrix owns slots [0, binaryTransformSize). A pos equal to the size is
                // already the first slot of the next matrix, so the comparison must be >=.
                if (pos >= this.binaryTransformSize) {
                    pos -= this.binaryTransformSize;
                    continue;
                }
                return Pair.makePair(entry.getFirstKey(), entry.getSecondKey());
            }
        }
        return null;
    }

    /**
     * Maps a flat parameter index to the unary transform matrix that contains it.
     * <p>
     * Unary transforms start right after the binary transforms, at offset
     * {@code numBinaryMatrices * binaryTransformSize}, and are laid out in the iteration order
     * of {@link #unaryTransform}. This method assumes every unary transform has exactly
     * {@code unaryTransformSize} elements.
     *
     * @param pos flat index into the parameter vector
     * @return the child label of the owning matrix, or {@code null} if {@code pos} is
     *     outside the unary transform range
     */
    public String indexToUnaryTransform(int pos) {
        pos -= this.numBinaryMatrices * this.binaryTransformSize;
        if (pos >= 0 && pos < this.numUnaryMatrices * this.unaryTransformSize) {
            for (Map.Entry<String, SimpleMatrix> entry : this.unaryTransform.entrySet()) {
                // This matrix owns slots [0, unaryTransformSize). A pos equal to the size
                // belongs to the next matrix.
                if (pos >= this.unaryTransformSize) {
                    pos -= this.unaryTransformSize;
                    continue;
                }
                return entry.getKey();
            }
        }
        return null;
    }

    /**
     * Maps a flat parameter index to the binary score matrix that contains it.
     * <p>
     * Binary scores start after all binary and unary transforms, at offset
     * {@code numBinaryMatrices * binaryTransformSize + numUnaryMatrices * unaryTransformSize},
     * and are laid out in the iteration order of {@link #binaryScore}. This method assumes
     * every binary score matrix has exactly {@code binaryScoreSize} elements.
     *
     * @param pos flat index into the parameter vector
     * @return the (left, right) child labels of the owning matrix, or {@code null} if
     *     {@code pos} is outside the binary score range
     */
    public Pair<String, String> indexToBinaryScore(int pos) {
        pos -= this.numBinaryMatrices * this.binaryTransformSize + this.numUnaryMatrices * this.unaryTransformSize;
        if (pos >= 0 && pos < this.numBinaryMatrices * this.binaryScoreSize) {
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : this.binaryScore) {
                // This matrix owns slots [0, binaryScoreSize). A pos equal to the size
                // belongs to the next matrix.
                if (pos >= this.binaryScoreSize) {
                    pos -= this.binaryScoreSize;
                    continue;
                }
                return Pair.makePair(entry.getFirstKey(), entry.getSecondKey());
            }
        }
        return null;
    }

    /**
     * Maps a flat parameter index to the unary score matrix that contains it.
     * <p>
     * Unary scores come last. They start at offset
     * {@code numBinaryMatrices * (binaryTransformSize + binaryScoreSize) + numUnaryMatrices * unaryTransformSize}
     * and are laid out in the iteration order of {@link #unaryScore}. This method assumes
     * every unary score matrix has exactly {@code unaryScoreSize} elements.
     *
     * @param pos flat index into the parameter vector
     * @return the child label of the owning matrix, or {@code null} if {@code pos} is
     *     outside the unary score range
     */
    public String indexToUnaryScore(int pos) {
        pos -= this.numBinaryMatrices * (this.binaryTransformSize + this.binaryScoreSize) + this.numUnaryMatrices * this.unaryTransformSize;
        if (pos >= 0 && pos < this.numUnaryMatrices * this.unaryScoreSize) {
            for (Map.Entry<String, SimpleMatrix> entry : this.unaryScore.entrySet()) {
                // This matrix owns slots [0, unaryScoreSize). A pos equal to the size
                // belongs to the next matrix.
                if (pos >= this.unaryScoreSize) {
                    pos -= this.unaryScoreSize;
                    continue;
                }
                return entry.getKey();
            }
        }
        return null;
    }

    /**
     * Describes which parameter matrix a flat index falls into, and the offset within it.
     * <p>
     * The segments are checked in layout order: binary transforms, unary transforms,
     * binary scores, then unary scores. The first segment that claims {@code pos} produces
     * a one-line description. If none claims it, an "unknown" line is printed.
     *
     * @param pos  flat index into the parameter vector
     * @param out2 stream the description is written to
     */
    public void printParameterType(int pos, PrintStream out2) {
        // Start offset of each parameter segment after the binary transforms.
        final int unaryTransformStart = this.numBinaryMatrices * this.binaryTransformSize;
        final int binaryScoreStart = unaryTransformStart + this.numUnaryMatrices * this.unaryTransformSize;
        final int unaryScoreStart = binaryScoreStart + this.numBinaryMatrices * this.binaryScoreSize;

        Pair<String, String> binaryTransformKey = this.indexToBinaryTransform(pos);
        if (binaryTransformKey != null) {
            int offset = pos % this.binaryTransformSize;
            out2.println("Entry " + pos + " is entry " + offset + " of binary transform " + binaryTransformKey.first() + ":" + binaryTransformKey.second());
            return;
        }

        String unaryTransformKey = this.indexToUnaryTransform(pos);
        if (unaryTransformKey != null) {
            int offset = (pos - unaryTransformStart) % this.unaryTransformSize;
            out2.println("Entry " + pos + " is entry " + offset + " of unary transform " + unaryTransformKey);
            return;
        }

        Pair<String, String> binaryScoreKey = this.indexToBinaryScore(pos);
        if (binaryScoreKey != null) {
            int offset = (pos - binaryScoreStart) % this.binaryScoreSize;
            out2.println("Entry " + pos + " is entry " + offset + " of binary score " + binaryScoreKey.first() + ":" + binaryScoreKey.second());
            return;
        }

        String unaryScoreKey = this.indexToUnaryScore(pos);
        if (unaryScoreKey != null) {
            int offset = (pos - unaryScoreStart) % this.unaryScoreSize;
            out2.println("Entry " + pos + " is entry " + offset + " of unary score " + unaryScoreKey);
            return;
        }

        out2.println("Index " + pos + " unknown");
    }
}

