/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.loglinear.learning;

import edu.stanford.nlp.loglinear.learning.AbstractBatchOptimizer;
import edu.stanford.nlp.loglinear.model.ConcatVector;
import edu.stanford.nlp.util.logging.Redwood;

/**
 * Batch optimizer that applies AdaGrad-scaled gradient steps and backtracks
 * (halving the previous step) whenever the log-likelihood decreases.
 *
 * <p>Per-run state (the last applied step, the running sum of squared
 * gradients, and the last accepted log-likelihood) lives in
 * {@link AdaGradOptimizationState}.
 */
public class BacktrackingAdaGradOptimizer
extends AbstractBatchOptimizer {
    private static final Redwood.RedwoodChannels log = Redwood.channels(BacktrackingAdaGradOptimizer.class);

    /** Numerator of the per-coordinate AdaGrad step scale {@code alpha / sqrt(accumulator)}. */
    static final double alpha = 0.1;

    /**
     * When the squared L2 norm of the (halved) backtracking step drops below
     * this value, the optimizer stops. Used for both the check and the log
     * message so the two cannot disagree.
     */
    private static final double BACKTRACK_SQUARED_NORM_THRESHOLD = 1.0E-10;

    /**
     * Performs one optimizer update, mutating {@code weights} in place.
     *
     * <ul>
     *   <li>If the log-likelihood did not change since the last accepted step,
     *       nothing is modified and {@code true} is returned.</li>
     *   <li>If it decreased, the previous step is halved and subtracted from
     *       {@code weights} (undoing half of what remains of that step);
     *       {@code true} is returned once that step becomes negligibly small.</li>
     *   <li>Otherwise an AdaGrad-scaled step along {@code gradient} is added to
     *       {@code weights}. Note that {@code gradient} itself is rescaled in
     *       place and retained in the state as the last step.</li>
     * </ul>
     *
     * @param weights current weights; updated in place
     * @param gradient gradient at {@code weights}; overwritten with the applied step in the AdaGrad branch
     * @param logLikelihood objective value at {@code weights}
     * @param optimizationState state previously produced by {@link #getFreshOptimizationState};
     *        must be an {@link AdaGradOptimizationState}
     * @param quiet if {@code true}, suppresses informational logging
     * @return {@code true} when the optimizer decides to stop (logged as "quitting"), {@code false} otherwise
     */
    @Override
    public boolean updateWeights(ConcatVector weights, ConcatVector gradient, double logLikelihood, AbstractBatchOptimizer.OptimizationState optimizationState, boolean quiet) {
        AdaGradOptimizationState s = (AdaGradOptimizationState)optimizationState;
        double logLikelihoodChange = logLikelihood - s.lastLogLikelihood;
        if (logLikelihoodChange == 0.0) {
            if (!quiet) {
                log.info("\tlogLikelihood improvement = 0: quitting");
            }
            return true;
        }
        if (logLikelihoodChange < 0.0) {
            // Backtrack: halve the last step and take that half back out of the weights.
            // lastLogLikelihood is intentionally left untouched so the next call is
            // still compared against the last accepted objective value.
            s.lastDerivative.mapInPlace(d -> d / 2.0);
            weights.addVectorInPlace(s.lastDerivative, -1.0);
            if (!quiet) {
                log.info("\tBACKTRACK...");
            }
            // dotProduct of a vector with itself is its squared L2 norm; compute it once.
            double squaredNorm = s.lastDerivative.dotProduct(s.lastDerivative);
            if (squaredNorm < BACKTRACK_SQUARED_NORM_THRESHOLD) {
                if (!quiet) {
                    log.info("\tBacktracking derivative norm " + squaredNorm + " < " + BACKTRACK_SQUARED_NORM_THRESHOLD + ": quitting");
                }
                return true;
            }
        } else {
            // Accumulate the element-wise squared gradient.
            ConcatVector squared = gradient.deepClone();
            squared.mapInPlace(d -> d * d);
            s.adagradAccumulator.addVectorInPlace(squared, 1.0);

            // Per-coordinate step scale: alpha / sqrt(accumulated squared gradient),
            // falling back to plain alpha where the accumulator is exactly zero
            // to avoid dividing by zero.
            ConcatVector stepScale = s.adagradAccumulator.deepClone();
            stepScale.mapInPlace(d -> {
                if (d == 0.0) {
                    return alpha;
                }
                return alpha / Math.sqrt(d);
            });

            // Turn the gradient into the actual step and apply it.
            gradient.elementwiseProductInPlace(stepScale);
            weights.addVectorInPlace(gradient, 1.0);

            // Remember the step and objective value in case the next call must backtrack.
            s.lastDerivative = gradient;
            s.lastLogLikelihood = logLikelihood;
            if (!quiet) {
                log.info("\tLL: " + logLikelihood);
            }
        }
        return false;
    }

    /**
     * Creates the initial optimizer state for a new run.
     *
     * @param initialWeights starting weights; not consulted by this implementation
     * @return a new, empty {@link AdaGradOptimizationState}
     */
    @Override
    protected AbstractBatchOptimizer.OptimizationState getFreshOptimizationState(ConcatVector initialWeights) {
        return new AdaGradOptimizationState();
    }

    /** Mutable per-run state carried between calls to {@code updateWeights}. */
    protected class AdaGradOptimizationState
    extends AbstractBatchOptimizer.OptimizationState {
        /** The most recently applied step (already AdaGrad-scaled); halved on each backtrack. */
        ConcatVector lastDerivative = new ConcatVector(0);
        /** Running element-wise sum of squared gradients. */
        ConcatVector adagradAccumulator = new ConcatVector(0);
        /** Log-likelihood at the last accepted (non-backtracking) step; starts at negative infinity. */
        double lastLogLikelihood = Double.NEGATIVE_INFINITY;

        protected AdaGradOptimizationState() {
        }
    }
}

