/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.classify;

import edu.stanford.nlp.classify.KNNClassifier;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CollectionValuedMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;

/**
 * Builds {@link KNNClassifier} instances that all share the same configuration
 * ({@code k}, {@code weightedVotes}, {@code l2NormalizeVectors}).
 *
 * <p>Every {@code train} overload creates a fresh classifier from the stored settings,
 * loads it with training datums via {@code addInstances}, and returns it.
 *
 * @param <K> the label type
 * @param <V> the feature type of the training vectors
 */
public class KNNClassifierFactory<K, V> {
    /** Value of {@code k} handed to every classifier this factory builds. */
    private final int k;
    /** Weighted-votes flag handed to every classifier this factory builds. */
    private final boolean weightedVotes;
    /**
     * Normalization flag handed to every classifier this factory builds. When true, the
     * {@link Counter}-based {@code train} overloads also normalize a copy of each training
     * vector before storing it.
     */
    private final boolean l2NormalizeVectors;

    /**
     * Creates a factory whose classifiers will all use the given settings.
     *
     * @param k value of {@code k} passed to each {@link KNNClassifier}
     * @param weightedVotes passed to each {@link KNNClassifier}
     * @param l2NormalizeVectors passed to each {@link KNNClassifier}; when true, the
     *     {@link Counter}-based {@code train} overloads additionally normalize a copy of
     *     each training vector
     */
    public KNNClassifierFactory(int k, boolean weightedVotes, boolean l2NormalizeVectors) {
        this.k = k;
        this.weightedVotes = weightedVotes;
        this.l2NormalizeVectors = l2NormalizeVectors;
    }

    /**
     * Builds a classifier whose training examples are exactly the given labeled datums.
     *
     * <p>The datums are handed to the classifier unchanged: this overload does NOT
     * L2-normalize them itself, even when {@code l2NormalizeVectors} is true (the flag is
     * still passed on to the classifier's constructor).
     *
     * @param instances labeled training datums
     * @return a new classifier loaded with {@code instances}
     */
    public KNNClassifier<K, V> train(Collection<RVFDatum<K, V>> instances) {
        KNNClassifier<K, V> classifier = newClassifier();
        classifier.addInstances(instances);
        return classifier;
    }

    /**
     * Builds a classifier from feature vectors plus a map that supplies each vector's label.
     *
     * <p>If {@code l2NormalizeVectors} is true, each vector is copied into a
     * {@link ClassicCounter} and the copy is normalized, so the caller's counters are never
     * passed to {@link Counters#L2Normalize}.
     *
     * <p>NOTE(review): {@code labelMap} is declared as {@code Map<V, K>} but is queried with
     * the {@code Counter<V>} vector itself. Unless {@code V} is itself a counter type, the
     * lookup will generally miss and the datum gets a {@code null} label (sorted maps such as
     * {@code TreeMap} may instead throw {@code ClassCastException}). The key type probably
     * ought to be {@code Counter<V>}; it is left unchanged because altering the signature
     * would break existing callers.
     *
     * @param vectors training feature vectors
     * @param labelMap map consulted via {@code get(vector)} for each vector's label
     * @return a new classifier loaded with one datum per vector
     */
    public KNNClassifier<K, V> train(Collection<Counter<V>> vectors, Map<V, K> labelMap) {
        KNNClassifier<K, V> classifier = newClassifier();
        Collection<RVFDatum<K, V>> instances = new ArrayList<RVFDatum<K, V>>();
        for (Counter<V> vector : vectors) {
            K label = labelMap.get(vector);
            instances.add(makeDatum(vector, label));
        }
        classifier.addInstances(instances);
        return classifier;
    }

    /**
     * Builds a classifier from vectors already grouped by label. Every vector stored under a
     * key of {@code vecBag} becomes one training datum that carries that key as its label.
     *
     * <p>If {@code l2NormalizeVectors} is true, each vector is copied into a
     * {@link ClassicCounter} and the copy is normalized before being stored.
     *
     * @param vecBag label-to-vectors multimap of training data
     * @return a new classifier loaded with one datum per stored vector
     */
    public KNNClassifier<K, V> train(CollectionValuedMap<K, Counter<V>> vecBag) {
        KNNClassifier<K, V> classifier = newClassifier();
        Collection<RVFDatum<K, V>> instances = new ArrayList<RVFDatum<K, V>>();
        for (K label : vecBag.keySet()) {
            for (Counter<V> vector : vecBag.get(label)) {
                instances.add(makeDatum(vector, label));
            }
        }
        classifier.addInstances(instances);
        return classifier;
    }

    /** Creates an empty classifier configured with this factory's settings. */
    private KNNClassifier<K, V> newClassifier() {
        return new KNNClassifier<K, V>(this.k, this.weightedVotes, this.l2NormalizeVectors);
    }

    /** Wraps one training vector and its label in a datum, normalizing a copy if configured. */
    private RVFDatum<K, V> makeDatum(Counter<V> vector, K label) {
        if (this.l2NormalizeVectors) {
            // Normalize a copy so the caller's counter is never handed to L2Normalize.
            return new RVFDatum<K, V>(Counters.L2Normalize(new ClassicCounter<V>(vector)), label);
        }
        return new RVFDatum<K, V>(vector, label);
    }
}

