/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.dvparser;

import edu.stanford.nlp.parser.dvparser.DVModel;
import edu.stanford.nlp.parser.dvparser.DVParser;
import edu.stanford.nlp.parser.lexparser.LexicalizedParser;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.TwoDimensionalMap;
import edu.stanford.nlp.util.TwoDimensionalSet;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import org.ejml.simple.SimpleBase;
import org.ejml.simple.SimpleMatrix;

public class AverageDVModels {
    private static Redwood.RedwoodChannels log = Redwood.channels(AverageDVModels.class);

    public static TwoDimensionalSet<String, String> getBinaryMatrixNames(List<TwoDimensionalMap<String, String, SimpleMatrix>> maps) {
        TwoDimensionalSet<String, String> matrixNames = new TwoDimensionalSet<String, String>();
        for (TwoDimensionalMap<String, String, SimpleMatrix> map : maps) {
            for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : map) {
                matrixNames.add(entry.getFirstKey(), entry.getSecondKey());
            }
        }
        return matrixNames;
    }

    public static Set<String> getUnaryMatrixNames(List<Map<String, SimpleMatrix>> maps) {
        Set<String> matrixNames = Generics.newHashSet();
        for (Map<String, SimpleMatrix> map : maps) {
            for (Map.Entry<String, SimpleMatrix> entry : map.entrySet()) {
                matrixNames.add(entry.getKey());
            }
        }
        return matrixNames;
    }

    public static TwoDimensionalMap<String, String, SimpleMatrix> averageBinaryMatrices(List<TwoDimensionalMap<String, String, SimpleMatrix>> maps) {
        TwoDimensionalMap<String, String, SimpleMatrix> averages = TwoDimensionalMap.treeMap();
        for (Pair<String, String> pair : AverageDVModels.getBinaryMatrixNames(maps)) {
            int count = 0;
            SimpleMatrix matrix = null;
            for (TwoDimensionalMap<String, String, SimpleMatrix> map : maps) {
                if (!map.contains(pair.first(), pair.second())) continue;
                SimpleMatrix original = map.get(pair.first(), pair.second());
                ++count;
                if (matrix == null) {
                    matrix = original;
                    continue;
                }
                matrix = (SimpleMatrix)matrix.plus((SimpleBase)original);
            }
            matrix = (SimpleMatrix)matrix.divide((double)count);
            averages.put(pair.first(), pair.second(), matrix);
        }
        return averages;
    }

    public static Map<String, SimpleMatrix> averageUnaryMatrices(List<Map<String, SimpleMatrix>> maps) {
        TreeMap<String, SimpleMatrix> averages = Generics.newTreeMap();
        for (String name : AverageDVModels.getUnaryMatrixNames(maps)) {
            int count = 0;
            SimpleMatrix matrix = null;
            for (Map<String, SimpleMatrix> map : maps) {
                if (!map.containsKey(name)) continue;
                SimpleMatrix original = map.get(name);
                ++count;
                if (matrix == null) {
                    matrix = original;
                    continue;
                }
                matrix = (SimpleMatrix)matrix.plus((SimpleBase)original);
            }
            matrix = (SimpleMatrix)matrix.divide((double)count);
            averages.put(name, matrix);
        }
        return averages;
    }

    public static void main(String[] args) {
        String outputModelFilename = null;
        ArrayList<String> inputModelFilenames = Generics.newArrayList();
        int argIndex = 0;
        while (argIndex < args.length) {
            if (args[argIndex].equalsIgnoreCase("-output")) {
                outputModelFilename = args[argIndex + 1];
                argIndex += 2;
                continue;
            }
            if (args[argIndex].equalsIgnoreCase("-input")) {
                ++argIndex;
                while (argIndex < args.length && !args[argIndex].startsWith("-")) {
                    inputModelFilenames.addAll(Arrays.asList(args[argIndex].split(",")));
                    ++argIndex;
                }
                continue;
            }
            throw new RuntimeException("Unknown argument " + args[argIndex]);
        }
        if (outputModelFilename == null) {
            log.info("Need to specify output model name with -output");
            System.exit(2);
        }
        if (inputModelFilenames.size() == 0) {
            log.info("Need to specify input model names with -input");
            System.exit(2);
        }
        log.info("Averaging " + inputModelFilenames);
        log.info("Outputting result to " + outputModelFilename);
        LexicalizedParser lexparser = null;
        ArrayList<DVModel> models = Generics.newArrayList();
        for (String filename : inputModelFilenames) {
            LexicalizedParser parser = LexicalizedParser.loadModel(filename, new String[0]);
            if (lexparser == null) {
                lexparser = parser;
            }
            models.add(DVParser.getModelFromLexicalizedParser(parser));
        }
        List<TwoDimensionalMap<String, String, SimpleMatrix>> binaryTransformMaps = CollectionUtils.transformAsList(models, model -> model.binaryTransform);
        List<TwoDimensionalMap<String, String, SimpleMatrix>> binaryScoreMaps = CollectionUtils.transformAsList(models, model -> model.binaryScore);
        List<Map<String, SimpleMatrix>> unaryTransformMaps = CollectionUtils.transformAsList(models, model -> model.unaryTransform);
        List<Map<String, SimpleMatrix>> unaryScoreMaps = CollectionUtils.transformAsList(models, model -> model.unaryScore);
        List<Map<String, SimpleMatrix>> wordMaps = CollectionUtils.transformAsList(models, model -> model.wordVectors);
        TwoDimensionalMap<String, String, SimpleMatrix> binaryTransformAverages = AverageDVModels.averageBinaryMatrices(binaryTransformMaps);
        TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreAverages = AverageDVModels.averageBinaryMatrices(binaryScoreMaps);
        Map<String, SimpleMatrix> unaryTransformAverages = AverageDVModels.averageUnaryMatrices(unaryTransformMaps);
        Map<String, SimpleMatrix> unaryScoreAverages = AverageDVModels.averageUnaryMatrices(unaryScoreMaps);
        Map<String, SimpleMatrix> wordAverages = AverageDVModels.averageUnaryMatrices(wordMaps);
        DVModel newModel = new DVModel(binaryTransformAverages, unaryTransformAverages, binaryScoreAverages, unaryScoreAverages, wordAverages, lexparser.getOp());
        DVParser newParser = new DVParser(newModel, lexparser);
        newParser.saveModel(outputModelFilename);
    }
}

