/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.Classifier;
import edu.stanford.nlp.classify.ClassifierFactory;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.ArrayMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.Collection;
import java.util.Map;

/**
 * A multi-class classifier built from one binary classifier per label ("one vs. all").
 *
 * <p>Each binary classifier is trained on data where the examples of its label are mapped to
 * {@code "+1"} and all other examples to {@code "-1"}. At prediction time the score for a label is
 * the {@code "+1"} score reported by that label's binary classifier, and {@link #classOf} returns
 * the label with the highest such score.
 *
 * @param <L> the label type of the multi-class problem
 * @param <F> the feature type
 */
public class OneVsAllClassifier<L, F>
implements Classifier<L, F> {
    private static final long serialVersionUID = -743792054415242776L;

    /** Label given to examples of the label a binary classifier is responsible for. */
    private static final String POS_LABEL = "+1";
    /** Label given to every other example in a binary problem. */
    private static final String NEG_LABEL = "-1";

    /** Shared label index for all binary datasets; contains exactly POS_LABEL and NEG_LABEL. */
    private static final Index<String> binaryIndex = new HashIndex<String>();
    /** Position of POS_LABEL within {@link #binaryIndex} (only used for logging). */
    private static final int posIndex;

    static {
        binaryIndex.add(POS_LABEL);
        binaryIndex.add(NEG_LABEL);
        posIndex = binaryIndex.indexOf(POS_LABEL);
    }

    private static final Redwood.RedwoodChannels logger = Redwood.channels(OneVsAllClassifier.class);

    private Index<F> featureIndex;
    private Index<L> labelIndex;
    /** Binary classifier per label; labels without an entry are not scored. */
    private Map<L, Classifier<String, F>> binaryClassifiers;
    /** Returned by {@link #classOf} when no label can be scored. */
    private L defaultLabel;

    /**
     * Creates a classifier with no binary classifiers yet (add them with
     * {@link #addBinaryClassifier}) and a {@code null} default label.
     */
    public OneVsAllClassifier(Index<F> featureIndex, Index<L> labelIndex) {
        this(featureIndex, labelIndex, Generics.newHashMap(), null);
    }

    /**
     * Creates a classifier from existing per-label binary classifiers, with a {@code null}
     * default label.
     */
    public OneVsAllClassifier(Index<F> featureIndex, Index<L> labelIndex, Map<L, Classifier<String, F>> binaryClassifiers) {
        this(featureIndex, labelIndex, binaryClassifiers, null);
    }

    /**
     * Creates a classifier.
     *
     * @param featureIndex feature index of the training data
     * @param labelIndex the labels this classifier reports and iterates over when scoring
     * @param binaryClassifiers binary classifier per label; the map is stored, not copied
     * @param defaultLabel label returned by {@link #classOf} when no label can be scored
     */
    public OneVsAllClassifier(Index<F> featureIndex, Index<L> labelIndex, Map<L, Classifier<String, F>> binaryClassifiers, L defaultLabel) {
        this.featureIndex = featureIndex;
        this.labelIndex = labelIndex;
        this.binaryClassifiers = binaryClassifiers;
        this.defaultLabel = defaultLabel;
    }

    /** Registers (or replaces) the binary classifier used to score {@code label}. */
    public void addBinaryClassifier(L label, Classifier<String, F> classifier) {
        this.binaryClassifiers.put(label, classifier);
    }

    /** Returns the binary classifier for {@code label}, or {@code null} if none is registered. */
    protected Classifier<String, F> getBinaryClassifier(L label) {
        return this.binaryClassifiers.get(label);
    }

    /**
     * Returns the highest-scoring label for {@code example}, or the default label when no label
     * could be scored (empty label index, or no binary classifier for any label).
     */
    @Override
    public L classOf(Datum<L, F> example) {
        Counter<L> scores = this.scoresOf(example);
        // scoresOf never returns null, but it can return an empty counter; fall back to the
        // default label rather than relying on what argmax does with no entries.
        if (scores == null || scores.size() == 0) {
            return this.defaultLabel;
        }
        return Counters.argmax(scores);
    }

    /**
     * Scores every label in the label index that has a binary classifier. The score is that
     * classifier's {@code "+1"} score on a copy of {@code example} relabeled for the binary
     * problem.
     *
     * <p>Labels without a binary classifier are omitted from the result. This happens, for
     * example, when the instance was produced by {@link #train(ClassifierFactory, GeneralDataset,
     * Collection)} with a subset of the dataset's labels.
     *
     * @param example the datum to score
     * @return label scores; never {@code null}, possibly empty
     */
    @Override
    public Counter<L> scoresOf(Datum<L, F> example) {
        Counter<L> scores = new ClassicCounter<L>();
        for (L label : this.labelIndex) {
            Classifier<String, F> binaryClassifier = this.getBinaryClassifier(label);
            if (binaryClassifier == null) {
                // No classifier was trained for this label, so it cannot be scored. Skipping it
                // avoids a NullPointerException.
                continue;
            }
            Map<L, String> posLabelMap = new ArrayMap<L, String>();
            posLabelMap.put(label, POS_LABEL);
            Datum<String, F> binDatum = GeneralDataset.mapDatum(example, posLabelMap, NEG_LABEL);
            Counter<String> binScores = binaryClassifier.scoresOf(binDatum);
            scores.setCount(label, binScores.getCount(POS_LABEL));
        }
        return scores;
    }

    /** Returns the labels of this classifier's label index. */
    @Override
    public Collection<L> labels() {
        return this.labelIndex.objectsList();
    }

    /**
     * Trains one binary classifier for every label in the dataset's label index.
     *
     * @param classifierFactory factory used to train each binary classifier
     * @param dataset multi-class training data
     * @return the trained one-vs-all classifier
     */
    public static <L, F> OneVsAllClassifier<L, F> train(ClassifierFactory<String, F, Classifier<String, F>> classifierFactory, GeneralDataset<L, F> dataset) {
        Index<L> labelIndex = dataset.labelIndex();
        return OneVsAllClassifier.train(classifierFactory, dataset, labelIndex.objectsList());
    }

    /**
     * Trains one binary classifier for each label in {@code trainLabels}.
     *
     * <p>The result keeps the dataset's full label index. Labels in that index but not in
     * {@code trainLabels} get no binary classifier and are left out of {@link #scoresOf}.
     *
     * @param classifierFactory factory used to train each binary classifier
     * @param dataset multi-class training data
     * @param trainLabels labels to train binary classifiers for
     * @return the trained one-vs-all classifier
     */
    public static <L, F> OneVsAllClassifier<L, F> train(ClassifierFactory<String, F, Classifier<String, F>> classifierFactory, GeneralDataset<L, F> dataset, Collection<L> trainLabels) {
        Index<L> labelIndex = dataset.labelIndex();
        Index<F> featureIndex = dataset.featureIndex();
        Map<L, Classifier<String, F>> classifiers = Generics.newHashMap();
        for (L label : trainLabels) {
            int i = labelIndex.indexOf(label);
            logger.info("Training " + label + " = " + i + ", posIndex = " + posIndex);
            // Map this label to "+1"; every other label falls through to the "-1" default.
            Map<L, String> posLabelMap = new ArrayMap<L, String>();
            posLabelMap.put(label, POS_LABEL);
            GeneralDataset<String, F> binaryDataset = dataset.mapDataset(dataset, binaryIndex, posLabelMap, NEG_LABEL);
            Classifier<String, F> binaryClassifier = classifierFactory.trainClassifier(binaryDataset);
            classifiers.put(label, binaryClassifier);
        }
        return new OneVsAllClassifier<L, F>(featureIndex, labelIndex, classifiers);
    }
}

