/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.ClassifierFactory;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.LogisticUtils;
import edu.stanford.nlp.classify.MultinomialLogisticClassifier;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.classify.ShiftParamsLogisticObjectiveFunction;
import edu.stanford.nlp.optimization.QNMinimizer;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Arrays;

/**
 * Trains a {@link MultinomialLogisticClassifier} with one extra "shift" feature per
 * training example.
 *
 * <p>Before optimisation every training example {@code i} is given an additional
 * feature whose id is {@code numFeatures + i} and whose value is {@code 1.0}. The
 * optimiser therefore works on {@code (numClasses - 1) * (numFeatures + numExamples)}
 * parameters, but only a {@code (numClasses - 1) x numFeatures} weight matrix is handed
 * to the resulting classifier.
 *
 * <p>Training state is kept in instance fields, so a single factory instance should not
 * be used to train from several threads at once.
 *
 * @param <L> label type
 * @param <F> feature type
 */
public class ShiftParamsLogisticClassifierFactory<L, F>
implements ClassifierFactory<L, F, MultinomialLogisticClassifier<L, F>> {
    private static final long serialVersionUID = -8977510677251295037L;

    /** Per-example feature ids, including the appended shift feature (set during training). */
    private int[][] data;
    /** Per-example feature values, parallel to {@link #data} (set during training). */
    private double[][] dataValues;
    /** Per-example label ids (set during training). */
    private int[] labels;
    private int numClasses;
    private int numFeatures;
    /** Prior passed through to the objective function. */
    private final LogPrior prior;
    /** Second argument of {@code QNMinimizer.useOWLQN}; presumably the regularisation strength. */
    private final double lambda;

    /** Creates a factory with a {@code NULL}-type prior and {@code lambda = 0.1}. */
    public ShiftParamsLogisticClassifierFactory() {
        this(new LogPrior(LogPrior.LogPriorType.NULL), 0.1);
    }

    /**
     * Creates a factory with a {@code NULL}-type prior.
     *
     * @param lambda value forwarded to {@code QNMinimizer.useOWLQN}
     */
    public ShiftParamsLogisticClassifierFactory(double lambda) {
        this(new LogPrior(LogPrior.LogPriorType.NULL), lambda);
    }

    /**
     * @param prior  prior handed to the objective function
     * @param lambda value forwarded to {@code QNMinimizer.useOWLQN}
     */
    public ShiftParamsLogisticClassifierFactory(LogPrior prior, double lambda) {
        this.prior = prior;
        this.lambda = lambda;
    }

    /**
     * Trains a classifier on the given dataset.
     *
     * <p>The dataset's arrays are shallow-copied before the shift features are appended,
     * so the caller's dataset is left untouched even if it hands out its backing arrays.
     *
     * @param dataset training data; values are taken from the dataset when it is an
     *                {@link RVFDataset}, otherwise they come from
     *                {@code LogisticUtils.initializeDataValues}
     * @return a classifier built from the learned weights and the dataset's feature and
     *         label indices
     */
    @Override
    public MultinomialLogisticClassifier<L, F> trainClassifier(GeneralDataset<L, F> dataset) {
        this.numClasses = dataset.numClasses();
        this.numFeatures = dataset.numFeatures();
        int[][] datasetData = dataset.getDataArray();
        double[][] datasetValues = dataset instanceof RVFDataset
                ? dataset.getValuesArray()
                : LogisticUtils.initializeDataValues(datasetData);
        // augmentFeatureMatrix replaces whole rows, so copying only the outer arrays is
        // enough to keep the augmentation from leaking back into the dataset.
        this.data = datasetData.clone();
        this.dataValues = datasetValues.clone();
        this.augmentFeatureMatrix(this.data, this.dataValues);
        this.labels = dataset.getLabelsArray();
        return new MultinomialLogisticClassifier<>(this.trainWeights(), dataset.featureIndex, dataset.labelIndex);
    }

    /**
     * Minimises the shift-params objective and extracts the feature weights.
     *
     * @return a {@code (numClasses - 1) x numFeatures} matrix filled by
     *         {@code LogisticUtils.unflatten} from the optimised parameter vector
     */
    private double[][] trainWeights() {
        QNMinimizer minimizer = new QNMinimizer(15, true);
        minimizer.useOWLQN(true, this.lambda);
        int augmentedFeatureCount = this.numFeatures + this.data.length;
        ShiftParamsLogisticObjectiveFunction objective = new ShiftParamsLogisticObjectiveFunction(this.data, this.dataValues, this.convertLabels(this.labels), this.numClasses, augmentedFeatureCount, this.numFeatures, this.prior);
        double[] augmentedThetas = new double[(this.numClasses - 1) * augmentedFeatureCount];
        augmentedThetas = minimizer.minimize(objective, 1.0E-4, augmentedThetas);

        // Diagnostic only: count non-zero entries from index numFeatures onward.
        // NOTE(review): confirm that this start offset matches the flattened parameter
        // layout when there are more than two classes.
        int count = 0;
        for (int j = this.numFeatures; j < augmentedThetas.length; ++j) {
            if (augmentedThetas[j] != 0.0) {
                ++count;
            }
        }
        Redwood.log("NUM NONZERO PARAMETERS: " + count);

        double[][] thetas = new double[this.numClasses - 1][this.numFeatures];
        LogisticUtils.unflatten(augmentedThetas, thetas);
        return thetas;
    }

    /**
     * Appends one shift feature to every example: id {@code numFeatures + i}, value
     * {@code 1.0}. Each row is replaced by a longer copy; the original row arrays are
     * not written to.
     *
     * @param data       per-example feature ids; rows are replaced in place
     * @param dataValues per-example feature values; rows are replaced in place
     */
    private void augmentFeatureMatrix(int[][] data, double[][] dataValues) {
        for (int i = 0; i < data.length; ++i) {
            int newLength = data[i].length + 1;
            data[i] = Arrays.copyOf(data[i], newLength);
            data[i][newLength - 1] = i + this.numFeatures;
            dataValues[i] = Arrays.copyOf(dataValues[i], newLength);
            dataValues[i][newLength - 1] = 1.0;
        }
    }

    /**
     * Converts label ids to one-hot rows.
     *
     * @param labels label id per example; each must be in {@code [0, numClasses)}
     * @return a {@code labels.length x numClasses} matrix with a single 1 per row
     */
    private int[][] convertLabels(int[] labels) {
        int[][] result = new int[labels.length][this.numClasses];
        for (int i = 0; i < labels.length; ++i) {
            result[i][labels[i]] = 1;
        }
        return result;
    }
}

