/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.loglinear.inference.CliqueTree;
import edu.stanford.nlp.loglinear.learning.AbstractDifferentiableFunction;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.loglinear.model.GraphicalModel;
import java.util.Iterator;
import java.util.function.Supplier;

/**
 * Objective that computes the log-likelihood of a single {@link GraphicalModel} training instance
 * under a weight vector, and accumulates the gradient of that log-likelihood into a caller-supplied
 * vector.
 *
 * <p>The observed value of each variable is read from the variable's metadata under
 * {@link #VARIABLE_TRAINING_VALUE}. The exception is a variable whose computed marginal is
 * deterministic (exactly one state with probability 1.0, all others exactly 0.0); that variable
 * takes the deterministic state instead.
 */
public class LogLikelihoodDifferentiableFunction
extends AbstractDifferentiableFunction<GraphicalModel> {
    /** Metadata key under which a variable stores its observed training value, parsed as an int. */
    public static final String VARIABLE_TRAINING_VALUE = "learning.LogLikelihoodDifferentiableFunction.VARIABLE_TRAINING_VALUE";

    /**
     * Computes the log-likelihood of {@code model}'s observed assignment and adds its gradient into
     * {@code gradient} in place.
     *
     * <p>The returned value is {@code -log(partitionFunction)} plus, for every factor, the dot
     * product of {@code weights} with that factor's features at the observed assignment.
     *
     * <p>The gradient contribution has two parts:
     * <ul>
     *   <li>Each factor's observed feature vector is added.</li>
     *   <li>Each of the factor's feature vectors is subtracted, weighted by the probability that the
     *       joint marginal of that factor assigns to the corresponding assignment.</li>
     * </ul>
     *
     * @param model the training instance; every variable whose marginal is not deterministic must
     *     carry an integer string under {@link #VARIABLE_TRAINING_VALUE} in its metadata
     * @param weights the weights to evaluate at
     * @param gradient accumulator that receives this instance's gradient
     * @return the log-likelihood, or {@code 0.0} when the log partition function is infinite (in
     *     which case {@code gradient} is left untouched)
     * @throws NumberFormatException if a required training value cannot be parsed as an int
     */
    @Override
    public double getSummaryForInstance(GraphicalModel model, ConcatVector weights, ConcatVector gradient) {
        CliqueTree.MarginalResult result = new CliqueTree(model, weights).calculateMarginals();

        // Each factor's feature table is queried repeatedly below, so cache its vectors up front.
        for (GraphicalModel.Factor factor : model.factors) {
            factor.featuresTable.cacheVectors();
        }
        try {
            double logLikelihood = 0.0;
            logLikelihood -= Math.log(result.partitionFunction);
            // An infinite log partition function makes this instance unusable; skip it.
            if (Double.isInfinite(logLikelihood)) {
                return 0.0;
            }

            // Add the features of the observed assignment to both the objective and the gradient.
            for (GraphicalModel.Factor factor : model.factors) {
                ConcatVector features = featuresFor(factor, observedAssignment(model, result, factor));
                logLikelihood += features.dotProduct(weights);
                gradient.addVectorInPlace(features, 1.0);
            }

            subtractWeightedFeatures(model, result, gradient);
            return logLikelihood;
        } finally {
            // Always release the caches, including on the early return and on exceptions;
            // previously both of those paths skipped this step.
            for (GraphicalModel.Factor factor : model.factors) {
                factor.featuresTable.releaseCache();
            }
        }
    }

    /**
     * Builds the observed assignment for the neighbors of {@code factor}.
     *
     * <p>A neighbor whose marginal is deterministic takes that state. Otherwise its value is parsed
     * from its {@link #VARIABLE_TRAINING_VALUE} metadata.
     */
    private static int[] observedAssignment(GraphicalModel model, CliqueTree.MarginalResult result, GraphicalModel.Factor factor) {
        int[] assignment = new int[factor.neigborIndices.length];
        for (int i = 0; i < assignment.length; ++i) {
            int deterministicValue = getDeterministicAssignment(result.marginals[factor.neigborIndices[i]]);
            assignment[i] = deterministicValue != -1
                    ? deterministicValue
                    : Integer.parseInt(model.getVariableMetaDataByReference(factor.neigborIndices[i]).get(VARIABLE_TRAINING_VALUE));
        }
        return assignment;
    }

    /**
     * Subtracts from {@code gradient} every feature vector of every factor, each scaled by the
     * probability that the joint marginal of its factor assigns to that assignment.
     *
     * <p>Assignments with probability 0 are skipped.
     */
    private static void subtractWeightedFeatures(GraphicalModel model, CliqueTree.MarginalResult result, ConcatVector gradient) {
        for (GraphicalModel.Factor factor : model.factors) {
            Iterator<int[]> assignments = factor.featuresTable.fastPassByReferenceIterator();
            // Only the array returned by the first next() is kept. Later next() calls are relied on
            // to advance by mutating that same array in place, so their return values are ignored.
            int[] assignment = assignments.next();
            while (true) {
                double assignmentProb = result.jointMarginals.get(factor).getAssignmentValue(assignment);
                if (assignmentProb > 0.0) {
                    gradient.addVectorInPlace(featuresFor(factor, assignment), -assignmentProb);
                }
                if (!assignments.hasNext()) {
                    break;
                }
                assignments.next();
            }
        }
    }

    /** Returns the feature vector that the table of {@code factor} holds for {@code assignment}. */
    private static ConcatVector featuresFor(GraphicalModel.Factor factor, int[] assignment) {
        return (ConcatVector) ((Supplier<?>) factor.featuresTable.getAssignmentValue(assignment)).get();
    }

    /**
     * Returns the index of the single state with probability exactly 1.0, provided every other
     * state has probability exactly 0.0.
     *
     * <p>Returns -1 if there is no such state: more than one state is 1.0, some state is
     * fractional or NaN, or the distribution is empty or all zeros.
     */
    private static int getDeterministicAssignment(double[] distribution) {
        int assignment = -1;
        for (int i = 0; i < distribution.length; ++i) {
            if (distribution[i] == 1.0) {
                if (assignment != -1) {
                    return -1; // a second state with probability 1.0
                }
                assignment = i;
            } else if (distribution[i] != 0.0) {
                return -1; // a fractional (or NaN) probability
            }
        }
        return assignment;
    }
}

