/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sentiment;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.neural.NeuralUtils;
import edu.stanford.nlp.neural.SimpleTensor;
import edu.stanford.nlp.sentiment.RNNOptions;
import edu.stanford.nlp.sentiment.SentimentModel;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;
import org.ejml.simple.SimpleMatrix;

/**
 * Command-line utility that converts RNTN sentiment parameters exported from Matlab into a
 * serialized {@link SentimentModel}.
 *
 * <p>Usage: {@code ConvertMatlabModel [-path DIR] [-slices N] [-useEscapedParens] [-output FILE]}
 *
 * <p>The parameter directory must contain the matrices {@code Wt_1 .. Wt_N}, {@code W},
 * {@code Wcat} and {@code Wv}. Each is read from {@code bin/<name>.bin} if that exists, otherwise
 * from {@code <name>.txt}. The directory also holds a vocabulary file ({@code vocab_1.txt}, falling
 * back to {@code words.txt}) whose i-th line names the word whose vector is column i of {@code Wv}.
 */
public class ConvertMatlabModel {
    private static final Redwood.RedwoodChannels log = Redwood.channels(ConvertMatlabModel.class);

    /** Parameter directory used when {@code -path} is not given. */
    private static final String DEFAULT_BASE_PATH = "/user/socherr/scr/projects/semComp/RNTN/src/params/";

    /** Number of tensor slices (and hidden size) used when {@code -slices} is not given. */
    private static final int DEFAULT_NUM_SLICES = 25;

    /** File the converted model is written to when {@code -output} is not given. */
    private static final String DEFAULT_OUTPUT_PATH = "matlab.ser.gz";

    /** Exit status used for command-line usage errors. */
    private static final int USAGE_ERROR_STATUS = 2;

    private ConvertMatlabModel() {
    }

    /**
     * Copies the vector for {@code source} to {@code target}, but only if {@code target} has no
     * vector yet and {@code source} does. The copy is a new matrix, so the two entries do not share
     * storage.
     *
     * @param wordVectors map from word to its vector; modified in place
     * @param source word whose vector is copied
     * @param target word that receives the copy
     */
    public static void copyWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
        if (wordVectors.containsKey(target) || !wordVectors.containsKey(source)) {
            return;
        }
        log.info("Using wordVector " + source + " for " + target);
        wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
    }

    /**
     * Sets {@code target}'s vector to a copy of {@code source}'s vector, overwriting any existing
     * vector for {@code target}. Does nothing if {@code source} has no vector.
     *
     * @param wordVectors map from word to its vector; modified in place
     * @param source word whose vector is copied
     * @param target word that receives the copy
     */
    public static void replaceWordVector(Map<String, SimpleMatrix> wordVectors, String source, String target) {
        if (!wordVectors.containsKey(source)) {
            return;
        }
        wordVectors.put(target, new SimpleMatrix(wordVectors.get(source)));
    }

    /**
     * Loads a matrix, preferring the EJML binary file and falling back to the text file.
     *
     * @param binaryName path of the binary matrix file, read with {@link SimpleMatrix#loadBinary}
     * @param textName path of the text matrix file, read with {@link NeuralUtils#loadTextMatrix}
     * @return the loaded matrix
     * @throws IOException if the chosen file cannot be read
     * @throws RuntimeException if neither file exists
     */
    public static SimpleMatrix loadMatrix(String binaryName, String textName) throws IOException {
        File binaryFile = new File(binaryName);
        if (binaryFile.exists()) {
            return SimpleMatrix.loadBinary(binaryFile.getPath());
        }
        File textFile = new File(textName);
        if (textFile.exists()) {
            return NeuralUtils.loadTextMatrix(textFile);
        }
        throw new RuntimeException("Could not find either " + binaryName + " or " + textName);
    }

    /**
     * Entry point: parses arguments, loads the Matlab parameters, and writes the serialized model.
     * Unknown flags, missing flag values and invalid {@code -slices} values terminate the JVM with
     * exit status 2.
     *
     * @param args command-line arguments (see the class documentation)
     * @throws IOException if a parameter file cannot be read or the model cannot be saved
     */
    public static void main(String[] args) throws IOException {
        String basePath = DEFAULT_BASE_PATH;
        int numSlices = DEFAULT_NUM_SLICES;
        boolean useEscapedParens = false;
        String outputPath = DEFAULT_OUTPUT_PATH;

        int argIndex = 0;
        while (argIndex < args.length) {
            String flag = args[argIndex];
            if (flag.equalsIgnoreCase("-slices")) {
                numSlices = parsePositiveInt(flag, argValue(args, argIndex));
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-path")) {
                basePath = argValue(args, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-useEscapedParens")) {
                useEscapedParens = true;
                argIndex += 1;
            } else if (flag.equalsIgnoreCase("-output")) {
                outputPath = argValue(args, argIndex);
                argIndex += 2;
            } else {
                exitWithUsageError("Unknown argument " + flag);
                return; // not reached: exitWithUsageError terminates the JVM
            }
        }

        SimpleTensor tensor = loadTensor(basePath, numSlices);
        SimpleMatrix W = loadMatrix(basePath + "bin/W.bin", basePath + "W.txt");
        log.info("W matrix size: " + W.numRows() + "x" + W.numCols());
        SimpleMatrix Wcat = loadMatrix(basePath + "bin/Wcat.bin", basePath + "Wcat.txt");
        log.info("W cat size: " + Wcat.numRows() + "x" + Wcat.numCols());
        SimpleMatrix combinedWV = loadMatrix(basePath + "bin/Wv.bin", basePath + "Wv.txt");
        log.info("Word matrix size: " + combinedWV.numRows() + "x" + combinedWV.numCols());

        List<String> vocab = loadVocabulary(basePath);
        TreeMap<String, SimpleMatrix> wordVectors = buildWordVectors(vocab, combinedWV, numSlices);
        addPunctuationFallbacks(wordVectors, useEscapedParens);

        RNNOptions op = new RNNOptions();
        op.numHid = numSlices;
        op.lowercaseWordVectors = false;
        // A two-row classification matrix means the Matlab model was binary (negative/positive).
        if (Wcat.numRows() == 2) {
            op.classNames = new String[]{"Negative", "Positive"};
            op.equivalenceClasses = new int[][]{{0}, {1}};
            op.numClasses = 2;
        }
        // Guarantee an unknown-word vector exists; use tiny random values when none was exported.
        if (!wordVectors.containsKey("*UNK*")) {
            wordVectors.put("*UNK*", SimpleMatrix.random(numSlices, 1, -1.0E-5, 1.0E-5, new Random()));
        }

        SentimentModel model = SentimentModel.modelFromMatrices(W, Wcat, tensor, wordVectors, op);
        log.info("Saving model to " + outputPath);
        model.saveSerialized(outputPath);
    }

    /** Returns the value following the flag at {@code flagIndex}, or exits if it is missing. */
    private static String argValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            exitWithUsageError("Missing value for argument " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }

    /** Parses {@code text} as a strictly positive int, or exits with a usage error. */
    private static int parsePositiveInt(String flag, String text) {
        try {
            int value = Integer.parseInt(text);
            if (value > 0) {
                return value;
            }
        } catch (NumberFormatException ignored) {
            // Reported below as a usage error.
        }
        exitWithUsageError("Argument " + flag + " expects a positive integer, got " + text);
        return -1; // not reached: exitWithUsageError terminates the JVM
    }

    /** Logs {@code message} as an error and terminates the JVM with the usage-error status. */
    private static void exitWithUsageError(String message) {
        log.error(message);
        System.exit(USAGE_ERROR_STATUS);
    }

    /** Loads the slices {@code Wt_1 .. Wt_numSlices} and stacks them into one tensor. */
    private static SimpleTensor loadTensor(String basePath, int numSlices) throws IOException {
        SimpleMatrix[] slices = new SimpleMatrix[numSlices];
        for (int i = 0; i < numSlices; ++i) {
            // Matlab files are numbered from 1.
            int fileIndex = i + 1;
            slices[i] = loadMatrix(basePath + "bin/Wt_" + fileIndex + ".bin", basePath + "Wt_" + fileIndex + ".txt");
        }
        SimpleTensor tensor = new SimpleTensor(slices);
        log.info("W tensor size: " + tensor.numRows() + "x" + tensor.numCols() + "x" + tensor.numSlices());
        return tensor;
    }

    /**
     * Reads the vocabulary file ({@code vocab_1.txt}, else {@code words.txt}) with each line
     * trimmed. Line i corresponds to column i of the word matrix.
     */
    private static List<String> loadVocabulary(String basePath) throws IOException {
        File vocabFile = new File(basePath + "vocab_1.txt");
        if (!vocabFile.exists()) {
            vocabFile = new File(basePath + "words.txt");
        }
        if (!vocabFile.exists()) {
            throw new RuntimeException("Could not find either " + basePath + "vocab_1.txt or " + basePath + "words.txt");
        }
        List<String> lines = new ArrayList<>();
        for (String line : IOUtils.readLines(vocabFile)) {
            lines.add(line.trim());
        }
        log.info("Lines in vocab file: " + lines.size());
        return lines;
    }

    /**
     * Maps each single-token vocabulary line to the first {@code numRows} rows of the matching
     * column of {@code combinedWV}. Blank lines and lines with several tokens are skipped, but they
     * still consume their column so later words stay aligned.
     */
    private static TreeMap<String, SimpleMatrix> buildWordVectors(List<String> vocab, SimpleMatrix combinedWV, int numRows) {
        int usable = Math.min(vocab.size(), combinedWV.numCols());
        if (vocab.size() != combinedWV.numCols()) {
            log.warn("Vocab has " + vocab.size() + " lines but the word matrix has " + combinedWV.numCols()
                    + " columns; only the first " + usable + " entries are used");
        }
        TreeMap<String, SimpleMatrix> wordVectors = Generics.newTreeMap();
        for (int i = 0; i < usable; ++i) {
            String[] pieces = vocab.get(i).split(" +");
            // "".split(" +") yields [""], so blank lines must be checked explicitly.
            if (pieces.length != 1 || pieces[0].isEmpty()) {
                continue;
            }
            String word = pieces[0];
            SimpleMatrix vector = combinedWV.extractMatrix(0, numRows, i, i + 1);
            wordVectors.put(word, vector);
            if (word.equals("UNK")) {
                // Separate copy, matching copyWordVector, so the two entries never share storage.
                wordVectors.put("*UNK*", new SimpleMatrix(vector));
            }
        }
        return wordVectors;
    }

    /**
     * Fills in vectors for punctuation tokens that were stored under HTML-escaped or substitute
     * forms in the Matlab vocabulary. When {@code useEscapedParens} is set, the vectors for
     * "(" and ")" also replace those for "-LRB-" and "-RRB-".
     */
    private static void addPunctuationFallbacks(Map<String, SimpleMatrix> wordVectors, boolean useEscapedParens) {
        copyWordVector(wordVectors, "&#44", ",");
        copyWordVector(wordVectors, ".", ",");
        copyWordVector(wordVectors, "&#59", ";");
        copyWordVector(wordVectors, ".", ";");
        copyWordVector(wordVectors, "&#96&#96", "``");
        copyWordVector(wordVectors, "''", "``");
        if (useEscapedParens) {
            replaceWordVector(wordVectors, "(", "-LRB-");
            replaceWordVector(wordVectors, ")", "-RRB-");
        }
    }
}

