/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.classifier;

import com.google.common.base.Preconditions;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

/**
 * Base class for classifiers that map a feature {@link Vector} to per-category scores.
 *
 * <p>Subclasses supply {@link #numCategories()}, {@link #classify(Vector)} and
 * {@link #classifyScalar(Vector)}; this class derives the "full" and batch (matrix) variants and
 * a clamped log-likelihood from them. Throughout this class the result of
 * {@link #classify(Vector)} is treated as holding one score for each category except category 0,
 * and the category 0 score is taken to be one minus the sum of the others.
 */
public abstract class AbstractVectorClassifier {
    /** Lower bound applied to every value returned by {@link #logLikelihood(int, Vector)}. */
    public static final double MIN_LOG_LIKELIHOOD = -100.0;

    /**
     * Returns the number of categories this classifier distinguishes.
     *
     * @return the category count; the batch and "full" methods size their results from it
     */
    public abstract int numCategories();

    /**
     * Scores a single instance for every category except category 0.
     *
     * @param instance the feature vector to classify
     * @return a vector that callers in this class treat as having {@code numCategories() - 1}
     *         entries, entry {@code i} belonging to category {@code i + 1}
     */
    public abstract Vector classify(Vector instance);

    /**
     * Scores an instance without applying a link function. The default implementation does not
     * support this and always throws; subclasses may override it.
     *
     * @param features the feature vector to classify
     * @return never returns normally in this default implementation
     * @throws UnsupportedOperationException always, unless overridden
     */
    public Vector classifyNoLink(Vector features) {
        throw new UnsupportedOperationException(this.getClass().getName() + " doesn't support classification without a link");
    }

    /**
     * Scores a single instance as one number.
     *
     * @param instance the feature vector to classify
     * @return the scalar score; {@link #logLikelihood(int, Vector)} uses it as the probability of
     *         the non-zero category when there are exactly two categories
     */
    public abstract double classifyScalar(Vector instance);

    /**
     * Scores an instance for all categories, including category 0, into a newly allocated vector.
     *
     * @param instance the feature vector to classify
     * @return a new vector of {@code numCategories()} entries
     */
    public Vector classifyFull(Vector instance) {
        return this.classifyFull(new DenseVector(this.numCategories()), instance);
    }

    /**
     * Scores an instance for all categories, writing the result into {@code r}.
     *
     * <p>Slots {@code 1 .. numCategories() - 1} receive the output of {@link #classify(Vector)};
     * slot 0 receives one minus the sum of those slots. Any previous contents of {@code r} are
     * overwritten, so the vector may be reused across calls.
     *
     * @param r        the destination vector; must have room for {@code numCategories()} entries
     * @param instance the feature vector to classify
     * @return {@code r}, after it has been filled in
     */
    public Vector classifyFull(Vector r, Vector instance) {
        Vector scores = r.viewPart(1, this.numCategories() - 1).assign(this.classify(instance));
        // Sum only the slots that were just written. Summing all of r would fold the stale
        // value of r[0] into the total and corrupt the category 0 score for a reused vector.
        r.setQuick(0, 1.0 - scores.zSum());
        return r;
    }

    /**
     * Applies {@link #classify(Vector)} to every row of {@code data}.
     *
     * @param data the instances to classify, one per row
     * @return a new matrix with one row per input row and {@code numCategories() - 1} columns
     */
    public Matrix classify(Matrix data) {
        int rows = data.numRows();
        Matrix r = new DenseMatrix(rows, this.numCategories() - 1);
        for (int row = 0; row < rows; ++row) {
            r.assignRow(row, this.classify(data.viewRow(row)));
        }
        return r;
    }

    /**
     * Applies {@link #classifyFull(Vector, Vector)} to every row of {@code data}.
     *
     * @param data the instances to classify, one per row
     * @return a new matrix with one row per input row and {@code numCategories()} columns
     */
    public Matrix classifyFull(Matrix data) {
        int rows = data.numRows();
        Matrix r = new DenseMatrix(rows, this.numCategories());
        for (int row = 0; row < rows; ++row) {
            // viewRow exposes the result row itself, so the scores are written in place.
            this.classifyFull(r.viewRow(row), data.viewRow(row));
        }
        return r;
    }

    /**
     * Applies {@link #classifyScalar(Vector)} to every row of {@code data}.
     *
     * @param data the instances to classify, one per row
     * @return a new vector holding one scalar score per input row
     * @throws IllegalArgumentException if this classifier does not have exactly two categories
     */
    public Vector classifyScalar(Matrix data) {
        Preconditions.checkArgument(this.numCategories() == 2, "Can only call classifyScalar with two categories");
        int rows = data.numRows();
        Vector r = new DenseVector(rows);
        for (int row = 0; row < rows; ++row) {
            r.set(row, this.classifyScalar(data.viewRow(row)));
        }
        return r;
    }

    /**
     * Computes the log of the score assigned to the category {@code actual} for one instance,
     * clamped from below at {@link #MIN_LOG_LIKELIHOOD}.
     *
     * <p>With exactly two categories the score comes from {@link #classifyScalar(Vector)};
     * otherwise it comes from {@link #classify(Vector)}. Any {@code actual <= 0} is handled as
     * category 0, whose score is one minus the (sum of the) other scores.
     *
     * @param actual the index of the true category
     * @param data   the feature vector of the instance
     * @return the clamped log-likelihood of {@code actual} given {@code data}
     */
    public double logLikelihood(int actual, Vector data) {
        if (this.numCategories() == 2) {
            double p = this.classifyScalar(data);
            if (actual > 0) {
                return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p));
            }
            // log1p(-p) evaluates log(1 - p) without forming 1 - p explicitly.
            return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p));
        }
        Vector p = this.classify(data);
        if (actual > 0) {
            // classify() omits category 0, so category k is stored at index k - 1.
            return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p.get(actual - 1)));
        }
        return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p.zSum()));
    }
}

