/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.inference;

import edu.stanford.nlp.loglinear.inference.TableFactor;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import edu.stanford.nlp.loglinear.model.NDArrayDoubles;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.function.Supplier;

/**
 * Exact inference over a {@link GraphicalModel} by message passing between cliques.
 *
 * <p>Every factor that still has at least one unobserved variable becomes one clique (a {@link TableFactor});
 * factors whose variables are all observed take no part in message passing and only contribute a constant
 * to the partition function. A variable counts as observed when its metadata contains
 * {@link #VARIABLE_OBSERVED_VALUE}, whose value is parsed as the integer index of the observed assignment.
 *
 * <p>The tables built for each factor, and the messages of the last completed pass, are cached on the
 * instance and reused by later calls when the observations allow it. These caches are mutated without any
 * synchronization, so an instance is not safe for concurrent use.
 */
public class CliqueTree {
    private static final Redwood.RedwoodChannels log = Redwood.channels(CliqueTree.class);

    /** Metadata key marking a variable as observed; its value is the observed assignment index as a string. */
    public static final String VARIABLE_OBSERVED_VALUE = "inference.CliqueTree.VARIABLE_OBSERVED_VALUE";

    /** Whether messages of the previous pass may be reused when exactly one clique had to be rebuilt. */
    private static final boolean CACHE_MESSAGES = true;

    private final GraphicalModel model;
    private final ConcatVector weights;

    /** Clique tables built by earlier calls, keyed by factor identity, with the observations used to build them. */
    private final IdentityHashMap<GraphicalModel.Factor, CachedFactorWithObservations> cachedFactors =
            new IdentityHashMap<>();

    /** Cliques of the last completed message pass; {@code null} until a pass completes. */
    private TableFactor[] cachedCliqueList;
    /** {@code cachedMessages[i][j]} is the message clique i sent to clique j in the last completed pass. */
    private TableFactor[][] cachedMessages;
    /** {@code cachedBackwardPassedMessages[i][j]} is true if that message was sent in the leaves-to-root pass. */
    private boolean[][] cachedBackwardPassedMessages;

    /**
     * Creates a clique tree for {@code model}.
     *
     * @param model the model to run inference on; its variable metadata is re-read on every call, so
     *     observations changed between calls are picked up
     * @param weights feature weights; a deep copy is stored, so later changes by the caller have no effect
     */
    public CliqueTree(GraphicalModel model, ConcatVector weights) {
        this.model = model;
        this.weights = weights.deepClone();
    }

    /**
     * Runs sum-product message passing and returns single-variable marginals, the partition function and a
     * joint marginal table for every factor of the model.
     */
    public MarginalResult calculateMarginals() {
        return this.messagePassing(MarginalizationMethod.SUM, true);
    }

    /**
     * Runs sum-product message passing and returns only the single-variable marginals, indexed by variable.
     * Entries for variable indices that occur in no clique and are not observed are {@code null}.
     */
    public double[][] calculateMarginalsJustSingletons() {
        return this.messagePassing(MarginalizationMethod.SUM, false).marginals;
    }

    /**
     * Runs max-product message passing and returns, for each variable index, the assignment with the
     * largest max-marginal (the first one on ties). Observed variables get their observed value; variables
     * with no marginal get 0.
     */
    public int[] calculateMAP() {
        double[][] mapMarginals = this.messagePassing(MarginalizationMethod.MAX, false).marginals;
        int[] result = new int[mapMarginals.length];
        for (int i = 0; i < result.length; i++) {
            if (mapMarginals[i] != null) {
                for (int j = 0; j < mapMarginals[i].length; j++) {
                    if (mapMarginals[i][j] > mapMarginals[i][result[i]]) {
                        result[i] = j;
                    }
                }
            }
            Map<String, String> metadata = this.model.getVariableMetaDataByReference(i);
            if (metadata.containsKey(VARIABLE_OBSERVED_VALUE)) {
                result[i] = Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE));
            }
        }
        return result;
    }

    /**
     * Runs sum-product or max-product message passing under the model's current observations.
     *
     * @param marginalize whether messages sum or maximize variables out
     * @param includeJointMarginalsAndPartition when true (and {@code marginalize} is SUM), also compute the
     *     partition function and a joint marginal table for every factor
     * @return single-variable marginals indexed by variable, the partition function (left at 1.0 when
     *     {@code includeJointMarginalsAndPartition} is false) and the joint marginals (empty unless requested).
     *     If an observation is detected to be impossible, uniform marginals for the clique variables, a
     *     partition function of 1.0 and all-zero joint marginals are returned instead.
     * @throws IllegalStateException if an observed value lies outside its variable's domain
     */
    private MarginalResult messagePassing(MarginalizationMethod marginalize, boolean includeJointMarginalsAndPartition) {
        boolean impossibleObservationMade = false;
        double partitionFunction = 1.0;

        // Fully observed factors only scale the partition function by exp(weights . features(observation)).
        if (includeJointMarginalsAndPartition) {
            for (GraphicalModel.Factor f : this.model.factors) {
                if (!this.allNeighborsObserved(f)) {
                    continue;
                }
                Supplier<?> features = (Supplier<?>) f.featuresTable.getAssignmentValue(this.getObservedAssignments(f));
                double assignmentValue = ((ConcatVector) features.get()).dotProduct(this.weights);
                if (Double.isInfinite(assignmentValue)) {
                    impossibleObservationMade = true;
                } else {
                    partitionFunction *= Math.exp(assignmentValue);
                }
            }
        }

        // Build (or reuse from the cache) one clique per factor that has an unobserved variable.
        ArrayList<TableFactor> cliquesList = new ArrayList<>();
        // cliqueFactors.get(i) is the model factor that clique i was built from.
        ArrayList<GraphicalModel.Factor> cliqueFactors = new ArrayList<>();
        int numFactorsCached = 0;
        for (GraphicalModel.Factor f : this.model.factors) {
            if (this.allNeighborsObserved(f)) {
                continue;
            }
            TableFactor clique;
            CachedFactorWithObservations cached = this.cachedFactors.get(f);
            if (cached != null && this.observationsMatch(f, cached.observations)) {
                clique = cached.cachedFactor;
                numFactorsCached++;
                if (cached.impossibleObservation) {
                    impossibleObservationMade = true;
                }
            } else {
                int[] observations = this.getObservedAssignments(f);
                clique = new TableFactor(this.weights, f, observations);
                CachedFactorWithObservations entry = new CachedFactorWithObservations();
                entry.cachedFactor = clique;
                entry.observations = observations;
                // A table with no positive entry means the observations cannot occur together.
                if (!hasPositiveEntry(clique)) {
                    impossibleObservationMade = true;
                    entry.impossibleObservation = true;
                }
                this.cachedFactors.put(f, entry);
            }
            cliqueFactors.add(f);
            cliquesList.add(clique);
        }
        TableFactor[] cliques = cliquesList.toArray(new TableFactor[0]);

        int maxVar = 0;
        for (GraphicalModel.Factor f : this.model.factors) {
            for (int n : f.neigborIndices) {
                maxVar = Math.max(maxVar, n);
            }
        }

        if (impossibleObservationMade) {
            return this.impossibleObservationResult(cliques, maxVar, includeJointMarginalsAndPartition);
        }

        TableFactor[][] messages = new TableFactor[cliques.length][cliques.length];
        boolean[][] backwardPassedMessages = new boolean[cliques.length][cliques.length];

        // If all cliques but one were reused from the cache, root the traversal at the single new clique so
        // that messages sent towards it from reused cliques may be taken from the previous pass (subject to
        // the checks in the backward pass). cachedCliquesBackPointers[i] is clique i's index in
        // cachedCliqueList, or -1.
        int forceRootForCachedMessagePassing = -1;
        int[] cachedCliquesBackPointers = null;
        // cachedCliqueList is still null if every earlier call returned early on an impossible observation,
        // which fills cachedFactors but not the message cache.
        if (CACHE_MESSAGES && this.cachedCliqueList != null
                && numFactorsCached > 0 && numFactorsCached == cliques.length - 1) {
            cachedCliquesBackPointers = new int[cliques.length];
            boolean backPointersConsistent = true;
            for (int i = 0; i < cliques.length; i++) {
                cachedCliquesBackPointers[i] = indexOfIdentical(this.cachedCliqueList, cliques[i]);
                if (cachedCliquesBackPointers[i] != -1) {
                    continue;
                }
                if (forceRootForCachedMessagePassing != -1) {
                    // More than one clique is new; the cached messages cannot be used.
                    backPointersConsistent = false;
                    break;
                }
                forceRootForCachedMessagePassing = i;
            }
            if (!backPointersConsistent) {
                forceRootForCachedMessagePassing = -1;
            }
        }

        // Build a spanning forest over the cliques by breadth-first search. A clique becomes a child of the
        // current clique only if every one of its variables already seen in the search also belongs to the
        // current clique.
        boolean[] visited = new boolean[cliques.length];
        int numVisited = 0;
        int[] visitedOrder = new int[cliques.length];
        int[] parent = new int[cliques.length];
        Arrays.fill(parent, -1);
        int[] trees = new int[cliques.length];
        int treeIndex = -1;
        boolean[] seenVariable = new boolean[maxVar + 1];
        while (numVisited < cliques.length) {
            treeIndex++;
            int root = -1;
            if (forceRootForCachedMessagePassing != -1 && !visited[forceRootForCachedMessagePassing]) {
                root = forceRootForCachedMessagePassing;
            } else {
                // Root each new tree at the unvisited clique with the most variables (first one on ties).
                for (int i = 0; i < cliques.length; i++) {
                    if (!visited[i] && (root == -1
                            || cliques[i].neighborIndices.length > cliques[root].neighborIndices.length)) {
                        root = i;
                    }
                }
            }
            assert root != -1;
            ArrayDeque<Integer> toVisit = new ArrayDeque<>();
            toVisit.add(root);
            boolean[] toVisitArray = new boolean[cliques.length];
            toVisitArray[root] = true;
            while (!toVisit.isEmpty()) {
                int cursor = toVisit.poll();
                trees[cursor] = treeIndex;
                if (visited[cursor]) {
                    log.info("Visited contains: " + cursor);
                    log.info("Visited: " + Arrays.toString(visited));
                    log.info("To visit: " + toVisit);
                }
                assert !visited[cursor];
                visited[cursor] = true;
                visitedOrder[numVisited] = cursor;
                for (int n : cliques[cursor].neighborIndices) {
                    seenVariable[n] = true;
                }
                numVisited++;
                candidates:
                for (int other = 0; other < cliques.length; other++) {
                    if (other == cursor || other == parent[cursor] || !domainsOverlap(cliques[cursor], cliques[other])) {
                        continue;
                    }
                    for (int n : cliques[other].neighborIndices) {
                        if (seenVariable[n] && !contains(cliques[cursor].neighborIndices, n)) {
                            continue candidates;
                        }
                    }
                    if (parent[other] != -1 || visited[other]) {
                        continue;
                    }
                    if (!toVisitArray[other]) {
                        toVisit.add(other);
                        toVisitArray[other] = true;
                        for (int n : cliques[other].neighborIndices) {
                            seenVariable[n] = true;
                        }
                    }
                    parent[other] = cursor;
                }
            }
            assert parent[root] == -1;
        }
        assert numVisited == cliques.length;

        // Backward pass: in reverse BFS order every clique sends a message to its parent.
        for (int i = numVisited - 1; i >= 0; i--) {
            int cursor = visitedOrder[i];
            int target = parent[cursor];
            if (target == -1) {
                continue;
            }
            backwardPassedMessages[cursor][target] = true;
            if (forceRootForCachedMessagePassing != -1
                    && cachedCliquesBackPointers[cursor] != -1
                    && cachedCliquesBackPointers[target] != -1) {
                int oldFrom = cachedCliquesBackPointers[cursor];
                int oldTo = cachedCliquesBackPointers[target];
                if (this.cachedMessages[oldFrom][oldTo] != null && this.cachedBackwardPassedMessages[oldFrom][oldTo]) {
                    messages[cursor][target] = this.cachedMessages[oldFrom][oldTo];
                    continue;
                }
            }
            messages[cursor][target] = marginalizeMessage(
                    multiplyIncoming(cliques, messages, cursor, target), cliques[target].neighborIndices, marginalize);
            if (forceRootForCachedMessagePassing != -1 && cachedCliquesBackPointers[target] != -1) {
                // The target's outgoing messages depend on this freshly computed one, so its cached
                // outgoing messages can no longer be reused.
                Arrays.fill(this.cachedMessages[cachedCliquesBackPointers[target]], null);
            }
        }

        // Forward pass: in BFS order every clique sends a message to each of its children.
        for (int i = 0; i < numVisited; i++) {
            int cursor = visitedOrder[i];
            for (int child = 0; child < cliques.length; child++) {
                if (parent[child] != cursor) {
                    continue;
                }
                messages[cursor][child] = marginalizeMessage(
                        multiplyIncoming(cliques, messages, cursor, child), cliques[child].neighborIndices, marginalize);
            }
        }

        this.cachedCliqueList = cliques;
        this.cachedMessages = messages;
        this.cachedBackwardPassedMessages = backwardPassedMessages;

        // Observed variables get a one-hot marginal on their observed value.
        double[][] marginals = new double[maxVar + 1][];
        for (GraphicalModel.Factor fac : this.model.factors) {
            for (int i = 0; i < fac.neigborIndices.length; i++) {
                int n = fac.neigborIndices[i];
                Map<String, String> metadata = this.model.getVariableMetaDataByReference(n);
                if (!metadata.containsKey(VARIABLE_OBSERVED_VALUE)) {
                    continue;
                }
                double[] deterministic = new double[fac.featuresTable.getDimensions()[i]];
                int assignment = Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE));
                // Valid indices are 0..length-1; reject anything else here rather than failing with an
                // ArrayIndexOutOfBoundsException on the write below.
                if (assignment < 0 || assignment >= deterministic.length) {
                    throw new IllegalStateException("Variable " + n + ": Can't have as assignment (" + assignment + ") that is out of bounds for dimension size (" + deterministic.length + ")");
                }
                deterministic[assignment] = 1.0;
                marginals[n] = deterministic;
            }
        }

        IdentityHashMap<GraphicalModel.Factor, TableFactor> jointMarginals = new IdentityHashMap<>();
        if (marginalize == MarginalizationMethod.SUM && includeJointMarginalsAndPartition) {
            boolean[] partitionIncludesTrees = new boolean[treeIndex + 1];
            double[] treePartitionFunctions = new double[treeIndex + 1];
            for (int i = 0; i < cliques.length; i++) {
                TableFactor convergedClique = multiplyIncoming(cliques, messages, i, i);
                int tree = trees[i];
                if (!partitionIncludesTrees[tree]) {
                    // Each tree contributes its normalizer once, taken from its first clique.
                    partitionIncludesTrees[tree] = true;
                    treePartitionFunctions[tree] = convergedClique.valueSum();
                    partitionFunction *= treePartitionFunctions[tree];
                } else if (assertsEnabled()) {
                    // Sanity check: every clique of a tree should yield (nearly) the same normalizer.
                    double valueSum = convergedClique.valueSum();
                    if (Double.isFinite(valueSum) && Double.isFinite(treePartitionFunctions[tree])) {
                        boolean agrees = Math.abs(treePartitionFunctions[tree] - valueSum) < 0.001 * treePartitionFunctions[tree];
                        if (!agrees) {
                            log.info("Different partition functions for tree " + tree + ": ");
                            log.info("Pre-existing for tree: " + treePartitionFunctions[tree]);
                            log.info("This clique for tree: " + valueSum);
                        }
                        assert agrees;
                    }
                }
                GraphicalModel.Factor f = cliqueFactors.get(i);
                if (!jointMarginals.containsKey(f)) {
                    jointMarginals.put(f, this.normalizedJointMarginal(f, convergedClique));
                }
                if (anyMarginalMissing(marginals, convergedClique.neighborIndices)) {
                    fillMissingMarginals(marginals, convergedClique.neighborIndices,
                            cliqueMarginals(convergedClique, marginalize));
                }
            }
            // Fully observed factors have no clique; their joint marginal puts all mass on the observation.
            for (GraphicalModel.Factor f : this.model.factors) {
                if (jointMarginals.containsKey(f)) {
                    continue;
                }
                TableFactor deterministicJointMarginal = new TableFactor(f.neigborIndices, f.featuresTable.getDimensions());
                int[] observedAssignment = this.getObservedAssignments(f);
                for (int value : observedAssignment) {
                    assert value != -1;
                }
                deterministicJointMarginal.setAssignmentValue(observedAssignment, 1.0);
                jointMarginals.put(f, deterministicJointMarginal);
            }
        } else {
            // First take marginals from cliques none of whose variables has one yet, then fill any gaps.
            for (int i = 0; i < cliques.length; i++) {
                if (!allMarginalsMissing(marginals, cliques[i].neighborIndices)) {
                    continue;
                }
                TableFactor convergedClique = multiplyIncoming(cliques, messages, i, i);
                fillMissingMarginals(marginals, convergedClique.neighborIndices,
                        cliqueMarginals(convergedClique, marginalize));
            }
            for (int i = 0; i < cliques.length; i++) {
                if (!anyMarginalMissing(marginals, cliques[i].neighborIndices)) {
                    continue;
                }
                TableFactor convergedClique = multiplyIncoming(cliques, messages, i, i);
                fillMissingMarginals(marginals, convergedClique.neighborIndices,
                        cliqueMarginals(convergedClique, marginalize));
            }
        }
        return new MarginalResult(marginals, partitionFunction, jointMarginals);
    }

    /**
     * Builds the result returned when the observations were found to be impossible: a uniform marginal for
     * every variable of every clique, partition function 1.0, and (if requested) an all-zero joint marginal
     * for every factor. The marginals array has the same length ({@code maxVar + 1}) as on the normal path.
     */
    private MarginalResult impossibleObservationResult(TableFactor[] cliques, int maxVar, boolean includeJointMarginals) {
        double[][] result = new double[maxVar + 1][];
        for (TableFactor c : cliques) {
            int[] dimensions = c.getDimensions();
            for (int i = 0; i < c.neighborIndices.length; i++) {
                double[] uniform = new double[dimensions[i]];
                Arrays.fill(uniform, 1.0 / uniform.length);
                result[c.neighborIndices[i]] = uniform;
            }
        }
        IdentityHashMap<GraphicalModel.Factor, TableFactor> jointMarginals = new IdentityHashMap<>();
        if (includeJointMarginals) {
            for (GraphicalModel.Factor f : this.model.factors) {
                TableFactor uniformZero = new TableFactor(f.neigborIndices, f.featuresTable.getDimensions());
                for (int[] assignment : uniformZero) {
                    uniformZero.setAssignmentValue(assignment, 0.0);
                }
                jointMarginals.put(f, uniformZero);
            }
        }
        return new MarginalResult(result, 1.0, jointMarginals);
    }

    /**
     * Normalizes a converged clique into a joint marginal over all of {@code f}'s variables, putting the
     * observed variables at their observed values.
     */
    private TableFactor normalizedJointMarginal(GraphicalModel.Factor f, TableFactor convergedClique) {
        int[] observedAssignments = this.getObservedAssignments(f);
        // backPointers[j] is the position of unobserved factor variable j in the clique's assignments, or -1
        // if j is observed. This assumes the clique table ranges over the factor's unobserved variables in
        // their original order.
        int[] backPointers = new int[observedAssignments.length];
        int next = 0;
        for (int j = 0; j < observedAssignments.length; j++) {
            backPointers[j] = observedAssignments[j] == -1 ? next++ : -1;
        }
        double sum = convergedClique.valueSum();
        TableFactor jointMarginal = new TableFactor(f.neigborIndices, f.featuresTable.getDimensions());
        Iterator<int[]> iterator = convergedClique.fastPassByReferenceIterator();
        int[] assignment = iterator.next();
        while (true) {
            double probability = convergedClique.getAssignmentValue(assignment) / sum;
            if (backPointers.length == assignment.length) {
                jointMarginal.setAssignmentValue(assignment, probability);
            } else {
                int[] jointAssignment = new int[backPointers.length];
                for (int j = 0; j < jointAssignment.length; j++) {
                    jointAssignment[j] = observedAssignments[j] != -1 ? observedAssignments[j] : assignment[backPointers[j]];
                }
                jointMarginal.setAssignmentValue(jointAssignment, probability);
            }
            if (!iterator.hasNext()) {
                break;
            }
            assignment = iterator.next();
        }
        return jointMarginal;
    }

    /** Returns true if variable {@code n} has an observed value. */
    private boolean isObserved(int n) {
        return this.model.getVariableMetaDataByReference(n).containsKey(VARIABLE_OBSERVED_VALUE);
    }

    /** Returns true if every variable of {@code f} is observed. */
    private boolean allNeighborsObserved(GraphicalModel.Factor f) {
        for (int n : f.neigborIndices) {
            if (!this.isObserved(n)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true if the current observations of {@code f}'s variables equal {@code cachedObservations}
     * (-1 meaning "unobserved"). An observed variable never matches a cached -1.
     */
    private boolean observationsMatch(GraphicalModel.Factor f, int[] cachedObservations) {
        for (int i = 0; i < f.neigborIndices.length; i++) {
            Map<String, String> metadata = this.model.getVariableMetaDataByReference(f.neigborIndices[i]);
            if (metadata.containsKey(VARIABLE_OBSERVED_VALUE)) {
                if (cachedObservations[i] == -1
                        || Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE)) != cachedObservations[i]) {
                    return false;
                }
            } else if (cachedObservations[i] != -1) {
                return false;
            }
        }
        return true;
    }

    /** Returns, for each variable of {@code f}, its observed value, or -1 if it is unobserved. */
    private int[] getObservedAssignments(GraphicalModel.Factor f) {
        int[] observedAssignments = new int[f.neigborIndices.length];
        for (int i = 0; i < observedAssignments.length; i++) {
            Map<String, String> metadata = this.model.getVariableMetaDataByReference(f.neigborIndices[i]);
            observedAssignments[i] = metadata.containsKey(VARIABLE_OBSERVED_VALUE)
                    ? Integer.parseInt(metadata.get(VARIABLE_OBSERVED_VALUE))
                    : -1;
        }
        return observedAssignments;
    }

    /** Returns true if some entry of {@code factor} is strictly positive. */
    private static boolean hasPositiveEntry(TableFactor factor) {
        for (int[] assignment : factor) {
            if (factor.getAssignmentValue(assignment) > 0.0) {
                return true;
            }
        }
        return false;
    }

    /**
     * Multiplies clique {@code clique} by every message it has received, skipping the message from clique
     * {@code excluded}, in increasing sender order.
     */
    private static TableFactor multiplyIncoming(TableFactor[] cliques, TableFactor[][] messages, int clique, int excluded) {
        TableFactor product = cliques[clique];
        for (int k = 0; k < cliques.length; k++) {
            if (k == excluded || messages[k][clique] == null) {
                continue;
            }
            product = product.multiply(messages[k][clique]);
        }
        return product;
    }

    /** Returns the per-variable summed or maxed marginals of {@code clique}. */
    private static double[][] cliqueMarginals(TableFactor clique, MarginalizationMethod marginalize) {
        switch (marginalize) {
            case SUM:
                return clique.getSummedMarginals();
            case MAX:
                return clique.getMaxedMarginals();
            default:
                throw new IllegalStateException("Unknown marginalization method: " + marginalize);
        }
    }

    /** Copies {@code cliqueMarginals[j]} into {@code marginals[variables[j]]} wherever that slot is still empty. */
    private static void fillMissingMarginals(double[][] marginals, int[] variables, double[][] cliqueMarginals) {
        for (int j = 0; j < variables.length; j++) {
            if (marginals[variables[j]] == null) {
                marginals[variables[j]] = cliqueMarginals[j];
            }
        }
    }

    /** Returns true if none of {@code variables} has a marginal yet. */
    private static boolean allMarginalsMissing(double[][] marginals, int[] variables) {
        for (int v : variables) {
            if (marginals[v] != null) {
                return false;
            }
        }
        return true;
    }

    /** Returns true if at least one of {@code variables} has no marginal yet. */
    private static boolean anyMarginalMissing(double[][] marginals, int[] variables) {
        for (int v : variables) {
            if (marginals[v] == null) {
                return true;
            }
        }
        return false;
    }

    /**
     * Sums or maxes out of {@code message}, one at a time, every variable that is not in {@code relevant}.
     */
    private static TableFactor marginalizeMessage(TableFactor message, int[] relevant, MarginalizationMethod marginalize) {
        TableFactor result = message;
        for (int variable : message.neighborIndices) {
            if (contains(relevant, variable)) {
                continue;
            }
            result = marginalize == MarginalizationMethod.SUM ? result.sumOut(variable) : result.maxOut(variable);
        }
        return result;
    }

    /** Returns true if the two factors share at least one variable. */
    private static boolean domainsOverlap(TableFactor f1, TableFactor f2) {
        for (int n : f1.neighborIndices) {
            if (contains(f2.neighborIndices, n)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code values} contains {@code value}. */
    private static boolean contains(int[] values, int value) {
        for (int v : values) {
            if (v == value) {
                return true;
            }
        }
        return false;
    }

    /** Returns the first index whose element is the very same object as {@code target}, or -1. */
    private static int indexOfIdentical(TableFactor[] factors, TableFactor target) {
        for (int i = 0; i < factors.length; i++) {
            if (factors[i] == target) {
                return i;
            }
        }
        return -1;
    }

    /** Returns whether Java assertions ({@code -ea}) are enabled for this class. */
    private static boolean assertsEnabled() {
        boolean enabled = false;
        // Intentional side effect: this assignment only executes when assertions are enabled.
        assert enabled = true;
        return enabled;
    }

    /** A clique table cached for a factor, with the observations it was built under (-1 = unobserved). */
    private static class CachedFactorWithObservations {
        TableFactor cachedFactor;
        int[] observations;
        /** True if the table had no positive entry, i.e. the observations were impossible. */
        boolean impossibleObservation;
    }

    /** How messages eliminate variables: sum-product or max-product. */
    private enum MarginalizationMethod {
        SUM,
        MAX
    }

    /** The output of a message-passing run. */
    public static class MarginalResult {
        /** Single-variable marginals indexed by variable; {@code null} where none was computed. */
        public double[][] marginals;
        /** The partition function, or 1.0 when it was not requested. */
        public double partitionFunction;
        /** Joint marginal per factor (by identity); empty when not requested. */
        public Map<GraphicalModel.Factor, TableFactor> jointMarginals;

        public MarginalResult(double[][] marginals, double partitionFunction, Map<GraphicalModel.Factor, TableFactor> jointMarginals) {
            this.marginals = marginals;
            this.partitionFunction = partitionFunction;
            this.jointMarginals = jointMarginals;
        }
    }
}

