/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.NaiveBayesClassifierFactory;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.optimization.AbstractCachingDiffFunction;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.IntTriple;
import edu.stanford.nlp.util.IntTuple;
import edu.stanford.nlp.util.IntUni;
import java.util.Arrays;

/**
 * Differentiable objective whose value is the negative conditional
 * log-likelihood of the training labels under a naive-Bayes-shaped model,
 * plus an optional penalty on the raw parameters.
 *
 * <p>The parameter vector holds one entry per class (indexed by an
 * {@link IntUni}) and one entry per (class, feature, value) combination
 * (indexed by an {@link IntTriple}). Before scoring, the vector is
 * log-normalized (see {@link #normalize(double[])}): the class entries are
 * shifted by the log-sum over classes, and for every class and feature the
 * value entries are shifted by the log-sum over that feature's values.
 *
 * <p>The penalty selected by the {@code prior} constructor argument is always
 * computed on the raw, un-normalized input vector.
 */
public class LogConditionalEqConstraintFunction extends AbstractCachingDiffFunction {
    /** No penalty is added to the objective. */
    public static final int NO_PRIOR = 0;
    /** Penalty {@code w^2 / (2 sigma^2)} per parameter. */
    public static final int QUADRATIC_PRIOR = 1;
    /** Huber-style penalty: quadratic below {@code epsilon}, linear above. */
    public static final int HUBER_PRIOR = 2;
    /** Penalty {@code w^4 / (2 sigma^4)} per parameter. */
    public static final int QUARTIC_PRIOR = 3;

    protected int numFeatures = 0;
    protected int numClasses = 0;
    /** One row per datum; {@code data[d][f]} is the value of feature {@code f}. */
    protected int[][] data = null;
    /** Class label of each datum, parallel to {@link #data}. */
    protected int[] labels = null;
    /** Number of values for each feature, as returned by {@code NaiveBayesClassifierFactory.numberValues}. */
    protected int[] numValues = null;

    private final int prior;
    private final double sigma;
    private final double epsilon;
    /** Maps every parameter tuple to its position in the parameter vector. */
    private final Index<IntTuple> featureIndex;

    /**
     * Creates the function with a quadratic prior and {@code sigma = 1.0}.
     */
    public LogConditionalEqConstraintFunction(int numFeatures, int numClasses, int[][] data, int[] labels) {
        this(numFeatures, numClasses, data, labels, 1.0);
    }

    /**
     * Creates the function with a quadratic prior and the given {@code sigma}.
     */
    public LogConditionalEqConstraintFunction(int numFeatures, int numClasses, int[][] data, int[] labels, double sigma) {
        this(numFeatures, numClasses, data, labels, QUADRATIC_PRIOR, sigma, 0.0);
    }

    /**
     * Creates the function.
     *
     * @param numFeatures number of features per datum
     * @param numClasses  number of classes
     * @param data        feature values, one row per datum
     * @param labels      class label per datum
     * @param prior       one of {@link #NO_PRIOR}, {@link #QUADRATIC_PRIOR},
     *                    {@link #HUBER_PRIOR}, {@link #QUARTIC_PRIOR}
     * @param sigma       scale of the penalty
     * @param epsilon     switch-over point of the Huber penalty
     * @throws IllegalArgumentException if {@code prior} is not one of the four constants
     */
    public LogConditionalEqConstraintFunction(int numFeatures, int numClasses, int[][] data, int[] labels, int prior, double sigma, double epsilon) {
        this.numFeatures = numFeatures;
        this.numClasses = numClasses;
        this.data = data;
        this.labels = labels;
        if (prior < NO_PRIOR || prior > QUARTIC_PRIOR) {
            throw new IllegalArgumentException("Invalid prior: " + prior);
        }
        this.prior = prior;
        this.epsilon = epsilon;
        this.sigma = sigma;
        this.numValues = NaiveBayesClassifierFactory.numberValues(data, numFeatures);
        // Diagnostic output kept as-is: callers may rely on seeing it.
        for (int i = 0; i < this.numValues.length; ++i) {
            System.out.println("numValues " + i + " " + this.numValues[i]);
        }
        // NOTE(review): createIndex() is overridable and runs before any
        // subclass constructor body; overriders must only rely on the fields
        // assigned above.
        this.featureIndex = this.createIndex();
    }

    /**
     * Returns the number of parameters, i.e. the size of the parameter index.
     */
    @Override
    public int domainDimension() {
        return this.featureIndex.size();
    }

    /**
     * Returns the class component (element 0) of the tuple stored at {@code index}.
     */
    int classOf(int index) {
        IntTuple tuple = this.featureIndex.get(index);
        return tuple.get(0);
    }

    /**
     * Returns the feature component (element 1) of the tuple stored at
     * {@code index}, or {@code -1} when the tuple has length 1 (a class-only entry).
     */
    int featureOf(int index) {
        IntTuple tuple = this.featureIndex.get(index);
        if (tuple.length() == 1) {
            return -1;
        }
        return tuple.get(1);
    }

    /**
     * Returns the parameter position of the class-only entry for class {@code c}.
     */
    protected int indexOf(int c) {
        return this.featureIndex.indexOf(new IntUni(c));
    }

    /**
     * Returns the parameter position of the (class, feature, value) entry.
     * Note the argument order is feature, class, value while the stored tuple
     * is (class, feature, value).
     */
    protected int indexOf(int f, int c, int val) {
        return this.featureIndex.indexOf(new IntTriple(c, f, val));
    }

    /**
     * Builds the parameter index: for every class, first its class-only entry,
     * then one entry per feature and per value of that feature.
     */
    protected Index<IntTuple> createIndex() {
        HashIndex<IntTuple> index = new HashIndex<IntTuple>();
        for (int c = 0; c < this.numClasses; ++c) {
            index.add(new IntUni(c));
            for (int f = 0; f < this.numFeatures; ++f) {
                for (int val = 0; val < this.numValues[f]; ++val) {
                    index.add(new IntTriple(c, f, val));
                }
            }
        }
        return index;
    }

    /**
     * Normalizes {@code x1} and unpacks the (class, feature, value) entries
     * into an array indexed as {@code [class][feature][value]}.
     */
    public double[][][] to3D(double[] x1) {
        double[] x = this.normalize(x1);
        double[][][] result = new double[this.numClasses][this.numFeatures][];
        for (int c = 0; c < this.numClasses; ++c) {
            for (int f = 0; f < this.numFeatures; ++f) {
                double[] perValue = new double[this.numValues[f]];
                for (int val = 0; val < perValue.length; ++val) {
                    perValue[val] = x[this.indexOf(f, c, val)];
                }
                result[c][f] = perValue;
            }
        }
        return result;
    }

    /**
     * Normalizes {@code x1} and returns the class-only entries, one per class.
     */
    public double[] priors(double[] x1) {
        double[] x = this.normalize(x1);
        double[] result = new double[this.numClasses];
        for (int c = 0; c < this.numClasses; ++c) {
            result[c] = x[this.indexOf(c)];
        }
        return result;
    }

    /**
     * Returns a log-normalized copy of {@code x}; the input is not modified.
     * Class entries are shifted by the log-sum over all classes; for each
     * class and feature, the value entries are shifted by the log-sum over
     * that feature's values.
     */
    private double[] normalize(double[] x) {
        double[] normalized = Arrays.copyOf(x, x.length);

        // Class-only entries: subtract the log-sum over classes.
        int[] classIdx = new int[this.numClasses];
        double[] sums = new double[this.numClasses];
        for (int c = 0; c < this.numClasses; ++c) {
            classIdx[c] = this.indexOf(c);
            sums[c] += x[classIdx[c]];
        }
        double total = ArrayMath.logSum(sums);
        for (int c = 0; c < this.numClasses; ++c) {
            normalized[classIdx[c]] -= total;
        }

        // Value entries: subtract the log-sum over the values of each (class, feature).
        for (int c = 0; c < this.numClasses; ++c) {
            for (int f = 0; f < this.numFeatures; ++f) {
                int count = this.numValues[f];
                int[] valueIdx = new int[count];
                double[] vals = new double[count];
                for (int val = 0; val < count; ++val) {
                    // Remember the position so the second pass needs no further index lookup.
                    valueIdx[val] = this.indexOf(f, c, val);
                    vals[val] = x[valueIdx[val]];
                }
                total = ArrayMath.logSum(vals);
                for (int val = 0; val < count; ++val) {
                    normalized[valueIdx[val]] -= total;
                }
            }
        }
        return normalized;
    }

    /**
     * Computes the objective value and gradient at {@code x1}, storing them in
     * the inherited {@code value} and {@code derivative} fields.
     *
     * <p>For each datum the per-class score is the normalized class entry plus
     * the normalized entries of the observed feature values; the value
     * accumulates the negative log of the label's share of those scores. The
     * configured prior penalty is then added using the raw {@code x1}.
     */
    @Override
    protected void calculate(double[] x1) {
        double[] x = this.normalize(x1);
        double[] xExp = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            xExp[i] = Math.exp(x[i]);
        }

        this.value = 0.0;
        Arrays.fill(this.derivative, 0.0);

        double[] sums = new double[this.numClasses];
        double[] probs = new double[this.numClasses];
        for (int d = 0; d < this.data.length; ++d) {
            int[] features = this.data[d];
            int label = this.labels[d];

            // Unnormalized log score of every class for this datum.
            Arrays.fill(sums, 0.0);
            for (int c = 0; c < this.numClasses; ++c) {
                sums[c] += x[this.indexOf(c)];
                for (int f = 0; f < features.length; ++f) {
                    sums[c] += x[this.indexOf(f, c, features[f])];
                }
            }
            double total = ArrayMath.logSum(sums);

            for (int c = 0; c < this.numClasses; ++c) {
                probs[c] = Math.exp(sums[c] - total);
                this.derivative[this.indexOf(c)] += probs[c];
                // Contribution of the per-(class, feature) normalizer, which
                // touches every value of the feature, not just the observed one.
                for (int f = 0; f < features.length; ++f) {
                    for (int val = 0; val < this.numValues[f]; ++val) {
                        int i = this.indexOf(f, c, val);
                        double theta = xExp[i];
                        this.derivative[i] -= probs[c] * theta;
                        if (label == c) {
                            this.derivative[i] += theta;
                        }
                    }
                }
            }

            // Contribution of the observed feature values themselves.
            for (int f = 0; f < features.length; ++f) {
                this.derivative[this.indexOf(f, label, features[f])] -= 1.0;
                for (int c = 0; c < this.numClasses; ++c) {
                    this.derivative[this.indexOf(f, c, features[f])] += probs[c];
                }
            }

            this.value -= sums[label] - total;
            this.derivative[this.indexOf(label)] -= 1.0;
        }

        this.addPriorPenalty(x1);
    }

    /**
     * Adds the configured penalty on the raw parameters {@code weights} to the
     * inherited {@code value} and {@code derivative} fields.
     */
    private void addPriorPenalty(double[] weights) {
        switch (this.prior) {
            case QUADRATIC_PRIOR: {
                double sigmaSq = this.sigma * this.sigma;
                for (int i = 0; i < weights.length; ++i) {
                    double w = weights[i];
                    this.value += w * w / 2.0 / sigmaSq;
                    this.derivative[i] += w / sigmaSq;
                }
                break;
            }
            case HUBER_PRIOR: {
                double sigmaSq = this.sigma * this.sigma;
                for (int i = 0; i < weights.length; ++i) {
                    double w = weights[i];
                    double wabs = Math.abs(w);
                    if (wabs < this.epsilon) {
                        // Quadratic region near zero.
                        this.value += w * w / 2.0 / this.epsilon / sigmaSq;
                        this.derivative[i] += w / this.epsilon / sigmaSq;
                    } else {
                        // Linear region; the offset makes the two pieces meet at epsilon.
                        this.value += (wabs - this.epsilon / 2.0) / sigmaSq;
                        this.derivative[i] += (w < 0.0 ? -1.0 : 1.0) / sigmaSq;
                    }
                }
                break;
            }
            case QUARTIC_PRIOR: {
                double sigmaQu = this.sigma * this.sigma * this.sigma * this.sigma;
                for (int i = 0; i < weights.length; ++i) {
                    double w = weights[i];
                    this.value += w * w * w * w / 2.0 / sigmaQu;
                    // d/dw [w^4 / (2 sigma^4)] = 2 w^3 / sigma^4, so the gradient
                    // stays consistent with the value added above.
                    this.derivative[i] += 2.0 * w * w * w / sigmaQu;
                }
                break;
            }
            default:
                // NO_PRIOR: nothing to add.
                break;
        }
    }

    /**
     * Returns a starting point whose entries are drawn independently from
     * {@code Math.random() - 0.5}, i.e. in {@code [-0.5, 0.5)}.
     */
    @Override
    public double[] initial() {
        double[] start = new double[this.domainDimension()];
        for (int i = 0; i < start.length; ++i) {
            start[i] = Math.random() - 0.5;
        }
        return start;
    }
}

