/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.coref.statistical;

import edu.stanford.nlp.coref.statistical.Clusterer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Pair;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class EvalUtils {
    /**
     * Scores one document with a fresh {@link CombinedEvaluator} and returns its F1.
     *
     * @param mucWeight       weight handed to the {@link CombinedEvaluator}; it scales the MUC F1,
     *                        while {@code 1 - mucWeight} scales the B3 F1
     * @param gold            gold clusters, each a list of mention ids
     * @param clusters        system clusters
     * @param mentionToGold   mention id to the gold cluster containing it
     * @param mentionToSystem mention id to the system cluster containing it
     * @return the weighted F1 for this single update
     */
    public static double getCombinedF1(double mucWeight, List<List<Integer>> gold, List<Clusterer.Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Clusterer.Cluster> mentionToSystem) {
        Evaluator evaluator = new CombinedEvaluator(mucWeight);
        evaluator.update(gold, clusters, mentionToGold, mentionToSystem);
        double score = evaluator.getF1();
        return score;
    }

    /**
     * Computes F1 from the numerators and denominators of precision and recall.
     *
     * <p>A ratio whose numerator is zero is taken to be zero without dividing, so a
     * zero denominator is harmless in that case. When precision is zero the result is zero.
     *
     * @param pNum precision numerator
     * @param pDen precision denominator
     * @param rNum recall numerator
     * @param rDen recall denominator
     * @return {@code 2 * p * r / (p + r)}, or {@code 0} when precision is zero
     */
    public static double f1(double pNum, double pDen, double rNum, double rDen) {
        double precision = 0.0;
        if (pNum != 0.0) {
            precision = pNum / pDen;
        }

        double recall = 0.0;
        if (rNum != 0.0) {
            recall = rNum / rDen;
        }

        // Guard on precision only: with recall == 0 the formula already yields 0 / p == 0.
        if (precision == 0.0) {
            return 0.0;
        }
        return 2.0 * precision * recall / (precision + recall);
    }

    /**
     * Link-counting scorer named for the MUC coreference metric.
     */
    public static class MUCEvaluator extends AbstractEvaluator {
        /**
         * Counts links for one direction of the comparison.
         *
         * <p>For every cluster the denominator grows by {@code size - 1}. The numerator grows by
         * the number of its mentions that have an entry in {@code mentionToGold}, minus the
         * number of distinct target clusters those mentions map to.
         *
         * @param clusters      clusters being scored, each a list of mention ids
         * @param mentionToGold mention id to the cluster it belongs to on the other side;
         *                      mentions without an entry are not counted in the numerator
         * @return pair of (numerator, denominator)
         */
        @Override
        public Pair<Double, Double> getScore(List<List<Integer>> clusters, Map<Integer, List<Integer>> mentionToGold) {
            int correctLinks = 0;
            int possibleLinks = 0;
            for (List<Integer> cluster : clusters) {
                int aligned = 0;
                HashSet<List<Integer>> partitions = new HashSet<>();
                for (int mention : cluster) {
                    List<Integer> target = mentionToGold.get(mention);
                    if (target != null) {
                        aligned++;
                        partitions.add(target);
                    }
                }
                possibleLinks += cluster.size() - 1;
                // Each distinct target cluster hit costs one link.
                correctLinks += aligned - partitions.size();
            }
            return new Pair<>((double) correctLinks, (double) possibleLinks);
        }
    }

    /**
     * Mention-based scorer named for the B-cubed (B3) coreference metric.
     */
    public static class B3Evaluator extends AbstractEvaluator {
        /**
         * Computes the numerator and denominator for one direction of the comparison.
         *
         * <p>Clusters of size one are skipped entirely. For each remaining cluster, the mentions
         * are grouped by the cluster they map to in {@code mentionToGold}; every group whose
         * target cluster is not itself of size one contributes the square of its count. The sum
         * of squares divided by the cluster size is added to the numerator, and the cluster size
         * is added to the denominator.
         *
         * @param clusters      clusters being scored, each a list of mention ids
         * @param mentionToGold mention id to the cluster it belongs to on the other side;
         *                      mentions without an entry are ignored
         * @return pair of (numerator, denominator)
         */
        @Override
        public Pair<Double, Double> getScore(List<List<Integer>> clusters, Map<Integer, List<Integer>> mentionToGold) {
            double num = 0.0;
            int dem = 0;
            for (List<Integer> c : clusters) {
                if (c.size() == 1) {
                    continue;
                }

                // Count how many mentions of this cluster land in each target cluster.
                ClassicCounter<List<Integer>> goldCounts = new ClassicCounter<>();
                for (int m : c) {
                    List<Integer> goldCluster = mentionToGold.get(m);
                    if (goldCluster != null) {
                        goldCounts.incrementCount(goldCluster);
                    }
                }

                double correct = 0.0;
                for (Map.Entry<List<Integer>, Double> entry : goldCounts.entrySet()) {
                    // Singleton target clusters earn no credit.
                    if (entry.getKey().size() != 1) {
                        double overlap = entry.getValue();
                        correct += overlap * overlap;
                    }
                }

                num += correct / (double) c.size();
                dem += c.size();
            }
            return new Pair<>(num, (double) dem);
        }
    }

    /**
     * Base class for evaluators that accumulate precision and recall as running
     * numerator/denominator sums across calls to {@link #update}.
     *
     * <p>Subclasses supply {@link #getScore}, which is invoked twice per update: once with the
     * system clusters against the gold mapping (precision) and once with the gold clusters
     * against the system mapping (recall).
     */
    public static abstract class AbstractEvaluator
    implements Evaluator {
        /** Running precision numerator. */
        public double pNum;
        /** Running precision denominator. */
        public double pDen;
        /** Running recall numerator. */
        public double rNum;
        /** Running recall denominator. */
        public double rDen;

        /**
         * Scores one set of gold and system clusters and adds the result to the running totals.
         *
         * @param gold            gold clusters, each a list of mention ids
         * @param clusters        system clusters; only their {@code mentions} lists are read
         * @param mentionToGold   mention id to the gold cluster containing it
         * @param mentionToSystem mention id to the system cluster containing it
         */
        @Override
        public void update(List<List<Integer>> gold, List<Clusterer.Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Clusterer.Cluster> mentionToSystem) {
            // Unwrap the system clusters so both sides use plain mention-id lists.
            List<List<Integer>> clustersAsList = clusters.stream()
                    .map(c -> c.mentions)
                    .collect(Collectors.toList());
            Map<Integer, List<Integer>> mentionToSystemLists = mentionToSystem.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().mentions));

            Pair<Double, Double> prec = this.getScore(clustersAsList, mentionToGold);
            Pair<Double, Double> rec = this.getScore(gold, mentionToSystemLists);

            this.pNum += prec.first;
            this.pDen += prec.second;
            this.rNum += rec.first;
            this.rDen += rec.second;
        }

        /**
         * @return F1 over everything accumulated so far, as computed by {@code EvalUtils.f1}
         */
        @Override
        public double getF1() {
            return EvalUtils.f1(this.pNum, this.pDen, this.rNum, this.rDen);
        }

        /**
         * @return accumulated recall ({@code rNum / rDen}), or {@code 0} when the numerator is zero
         */
        public double getRecall() {
            return this.rNum == 0.0 ? 0.0 : this.rNum / this.rDen;
        }

        /**
         * @return accumulated precision ({@code pNum / pDen}), or {@code 0} when the numerator is zero
         */
        public double getPrecision() {
            return this.pNum == 0.0 ? 0.0 : this.pNum / this.pDen;
        }

        /**
         * Computes one direction of the metric.
         *
         * @param clusters         clusters being scored, each a list of mention ids
         * @param mentionToCluster mention id to the cluster it belongs to on the other side
         * @return pair of (numerator, denominator) to add to the running totals
         */
        public abstract Pair<Double, Double> getScore(List<List<Integer>> clusters, Map<Integer, List<Integer>> mentionToCluster);
    }

    /**
     * Evaluator that blends MUC and B3 F1 scores: {@code mucWeight} scales the MUC F1 and
     * {@code 1 - mucWeight} scales the B3 F1. A component whose weight is exactly zero is
     * neither updated nor queried.
     */
    public static class CombinedEvaluator implements Evaluator {
        private final B3Evaluator b3Evaluator = new B3Evaluator();
        private final MUCEvaluator mucEvaluator = new MUCEvaluator();
        private final double mucWeight;

        /**
         * @param mucWeight weight applied to the MUC F1; the B3 F1 is weighted by {@code 1 - mucWeight}
         */
        public CombinedEvaluator(double mucWeight) {
            this.mucWeight = mucWeight;
        }

        /**
         * Forwards the update to each component evaluator that carries a non-zero weight.
         */
        @Override
        public void update(List<List<Integer>> gold, List<Clusterer.Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Clusterer.Cluster> mentionToSystem) {
            boolean b3Contributes = mucWeight != 1.0;
            boolean mucContributes = mucWeight != 0.0;
            if (b3Contributes) {
                b3Evaluator.update(gold, clusters, mentionToGold, mentionToSystem);
            }
            if (mucContributes) {
                mucEvaluator.update(gold, clusters, mentionToGold, mentionToSystem);
            }
        }

        /**
         * @return {@code mucWeight * MUC F1 + (1 - mucWeight) * B3 F1}, where a term whose
         *         weight is exactly zero is taken as zero without querying its evaluator
         */
        @Override
        public double getF1() {
            double mucTerm;
            if (mucWeight == 0.0) {
                mucTerm = 0.0;
            } else {
                mucTerm = mucWeight * mucEvaluator.getF1();
            }

            double b3Term;
            if (mucWeight == 1.0) {
                b3Term = 0.0;
            } else {
                b3Term = (1.0 - mucWeight) * b3Evaluator.getF1();
            }

            return mucTerm + b3Term;
        }
    }

    /**
     * Contract for coreference scorers that are fed gold and system clusterings and report
     * an F1 score.
     */
    public interface Evaluator {
        /**
         * Feeds one set of gold and system clusters to the evaluator.
         *
         * @param gold            gold clusters, each a list of mention ids
         * @param clusters        system clusters
         * @param mentionToGold   mention id to the gold cluster containing it
         * @param mentionToSystem mention id to the system cluster containing it
         */
        void update(List<List<Integer>> gold, List<Clusterer.Cluster> clusters, Map<Integer, List<Integer>> mentionToGold, Map<Integer, Clusterer.Cluster> mentionToSystem);

        /**
         * @return the evaluator's F1 score
         */
        double getF1();
    }
}

