/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types;

import cc.mallet.classify.Classification;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelVector;
import cc.mallet.types.Labeling;
import cc.mallet.types.RankedFeatureVector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

public class GradientGain
extends RankedFeatureVector {
    private static double[] calcGradientGains(InstanceList ilist, LabelVector[] classifications) {
        int numInstances = ilist.size();
        int numClasses = ilist.getTargetAlphabet().size();
        int numFeatures = ilist.getDataAlphabet().size();
        double[] gradientgains = new double[numFeatures];
        for (int i = 0; i < ilist.size(); ++i) {
            assert (classifications[i].getLabelAlphabet() == ilist.getTargetAlphabet());
            Instance inst = (Instance)ilist.get(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv = (FeatureVector)inst.getData();
            double instanceWeight = ilist.getInstanceWeight(i);
            double labelWeightSum = 0.0;
            for (int ll = 0; ll < labeling.numLocations(); ++ll) {
                int li = labeling.indexAtLocation(ll);
                double labelWeight = labeling.value(li);
                labelWeightSum += labelWeight;
                double labelWeightDiff = Math.abs(labelWeight - classifications[i].value(li));
                for (int fl = 0; fl < fv.numLocations(); ++fl) {
                    int fli;
                    int n = fli = fv.indexAtLocation(fl);
                    gradientgains[n] = gradientgains[n] + fv.valueAtLocation(fl) * labelWeightDiff * instanceWeight;
                }
            }
            assert (Math.abs(labelWeightSum - 1.0) < 1.0E-4);
        }
        return gradientgains;
    }

    public GradientGain(InstanceList ilist, LabelVector[] classifications) {
        super(ilist.getDataAlphabet(), GradientGain.calcGradientGains(ilist, classifications));
    }

    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] c) {
        LabelVector[] ret = new LabelVector[c.length];
        for (int i = 0; i < c.length; ++i) {
            ret[i] = c[i].getLabelVector();
        }
        return ret;
    }

    public GradientGain(InstanceList ilist, Classification[] classifications) {
        super(ilist.getDataAlphabet(), GradientGain.calcGradientGains(ilist, GradientGain.getLabelVectorsFromClassifications(classifications)));
    }

    public static class Factory
    implements RankedFeatureVector.Factory {
        LabelVector[] classifications;
        private static final long serialVersionUID = 1L;
        private static final int CURRENT_SERIAL_VERSION = 0;

        public Factory(LabelVector[] classifications) {
            this.classifications = classifications;
        }

        public RankedFeatureVector newRankedFeatureVector(InstanceList ilist) {
            return new GradientGain(ilist, this.classifications);
        }

        private void writeObject(ObjectOutputStream out) throws IOException {
            out.writeInt(0);
            out.writeInt(this.classifications.length);
            for (int i = 0; i < this.classifications.length; ++i) {
                out.writeObject(this.classifications[i]);
            }
        }

        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            int version = in.readInt();
            int n = in.readInt();
            this.classifications = new LabelVector[n];
            for (int i = 0; i < n; ++i) {
                this.classifications[i] = (LabelVector)in.readObject();
            }
        }
    }
}

