/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.sequences;

import edu.stanford.nlp.ie.EmpiricalNERPriorBIO;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.sequences.BestSequenceFinder;
import edu.stanford.nlp.sequences.CoolingSchedule;
import edu.stanford.nlp.sequences.SequenceListener;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.Set;

/**
 * A {@link BestSequenceFinder} that decodes a {@link SequenceModel} by Gibbs sampling.
 *
 * <p>Two search strategies are offered: {@link #findBestUsingSampling} collects samples and keeps
 * the one with the highest {@link SequenceModel#scoreOf} value, and {@link #findBestUsingAnnealing}
 * runs a temperature-controlled sampler driven by a {@link CoolingSchedule}.
 *
 * <p>Positions are visited according to {@code samplingStyle}: sequentially left to right, at
 * random (one random position per sequence element), or "chromatically", where each group in
 * {@code partition} is resampled as a unit; groups larger than {@code chromaticSize} are split into
 * chunks processed through a {@link MulticoreWrapper} built with {@code chromaticSize}.
 *
 * <p>Single-position updates made through {@link #samplePosition} are reported to the
 * {@link SequenceListener}; chromatic sampling writes into the sequence without notifying it.
 */
public class SequenceGibbsSampler
implements BestSequenceFinder {
    private static final Redwood.RedwoodChannels log = Redwood.channels(SequenceGibbsSampler.class);
    /** Shared random source, seeded with a fixed value. */
    private static final Random random = new Random(Integer.MAX_VALUE);
    /** Logging level: 0 is quiet, &gt; 0 logs progress, &gt; 1 also dumps all samples. */
    public static int verbose = 0;
    /** Tokens used only to label rows in {@link #printSamples}; may be {@code null}. */
    private List<?> document;
    private int numSamples;
    private int sampleInterval;
    /**
     * When &gt; 0, annealing resamples every position for this many iterations and afterwards only
     * the positions that differed from the initial sequence during that warm-up.
     */
    private int speedUpThreshold = -1;
    private SequenceListener listener;
    private static final int RANDOM_SAMPLING = 0;
    private static final int SEQUENTIAL_SAMPLING = 1;
    private static final int CHROMATIC_SAMPLING = 2;
    /** Stored from the constructor; not read by this class. */
    EmpiricalNERPriorBIO priorEn;
    /** Stored from the constructor; not read by this class. */
    EmpiricalNERPriorBIO priorCh = null;
    /** If true, annealing returns the final sequence instead of the best-scoring one. */
    public boolean returnLastFoundSequence = false;
    private int samplingStyle;
    private int chromaticSize;
    private List<List<Integer>> partition;

    /**
     * Returns a fresh copy of {@code a}.
     *
     * @param a the array to copy (must not be null)
     * @return a new array with the same contents
     */
    public static int[] copy(int[] a) {
        return Arrays.copyOf(a, a.length);
    }

    /**
     * Draws, for every position of the model, a value uniformly from that position's possible values.
     *
     * @param model supplies the sequence length and the allowed values per position
     * @return a new random sequence of length {@code model.length()}
     */
    public static int[] getRandomSequence(SequenceModel model) {
        int[] result = new int[model.length()];
        for (int i = 0; i < result.length; ++i) {
            int[] classes = model.getPossibleValues(i);
            result[i] = classes[random.nextInt(classes.length)];
        }
        return result;
    }

    /**
     * Runs {@link #findBestUsingSampling} from a random start using the configured
     * {@code numSamples} and {@code sampleInterval}.
     */
    @Override
    public int[] bestSequence(SequenceModel model) {
        int[] initialSequence = SequenceGibbsSampler.getRandomSequence(model);
        return this.findBestUsingSampling(model, this.numSamples, this.sampleInterval, initialSequence);
    }

    /**
     * Collects samples starting from {@code initialSequence} and returns the one with the highest
     * {@link SequenceModel#scoreOf} value.
     *
     * @return the best sample, or {@code null} if no sample scored above negative infinity
     *         (for example when {@code numSamples <= 0})
     */
    public int[] findBestUsingSampling(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
        List<int[]> samples = this.collectSamples(model, numSamples, sampleInterval, initialSequence);
        int[] best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int[] sample : samples) {
            double score = model.scoreOf(sample);
            if (score > bestScore) {
                best = sample;
                bestScore = score;
                // Gated like every other progress message in this class.
                if (verbose > 0) {
                    log.info("found new best (" + bestScore + ")");
                    log.info(ArrayMath.toString(best));
                }
            }
        }
        return best;
    }

    /** Runs {@link #findBestUsingAnnealing(SequenceModel, CoolingSchedule, int[])} from a random start. */
    public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule) {
        int[] initialSequence = SequenceGibbsSampler.getRandomSequence(model);
        return this.findBestUsingAnnealing(model, schedule, initialSequence);
    }

    /**
     * Simulated annealing: performs one forward sweep per schedule iteration at the schedule's
     * temperature and keeps either the last sequence ({@code returnLastFoundSequence}) or the
     * sequence with the highest value returned by {@link #sampleSequenceForward}.
     *
     * <p>NOTE(review): outside chromatic sampling that value is the probability of the last sampled
     * position, not a whole-sequence score.
     *
     * @param initialSequence starting point; it is not modified
     * @return the selected sequence; if no iteration produced a score above negative infinity (e.g.
     *         zero iterations or an empty sequence) the unmodified working copy of the start
     * @throws RuntimeInterruptedException if the current thread is interrupted between iterations
     */
    public int[] findBestUsingAnnealing(SequenceModel model, CoolingSchedule schedule, int[] initialSequence) {
        if (verbose > 0) {
            log.info("Doing annealing");
        }
        this.listener.setInitialSequence(initialSequence);
        // Snapshots for the verbose > 1 dump; each entry must be its own array.
        List<int[]> result = new ArrayList<int[]>();
        // Work on a copy so the caller's initial sequence is never modified.
        int[] sequence = SequenceGibbsSampler.copy(initialSequence);
        int[] best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        Set<Integer> positionsChanged = null;
        if (this.speedUpThreshold > 0) {
            positionsChanged = Generics.newHashSet();
        }
        for (int i = 0; i < schedule.numIterations(); ++i) {
            if (Thread.interrupted()) {
                throw new RuntimeInterruptedException();
            }
            double temperature = schedule.getTemperature(i);
            double score;
            if (this.speedUpThreshold <= 0) {
                score = this.sampleSequenceForward(model, sequence, temperature, null);
            } else if (i < this.speedUpThreshold) {
                score = this.sampleSequenceForward(model, sequence, temperature, null);
                // Remember every position that has moved away from the start during warm-up.
                for (int j = 0; j < sequence.length; ++j) {
                    if (sequence[j] != initialSequence[j]) {
                        positionsChanged.add(j);
                    }
                }
            } else {
                score = this.sampleSequenceForward(model, sequence, temperature, positionsChanged);
            }
            if (verbose > 1) {
                result.add(SequenceGibbsSampler.copy(sequence));
            }
            if (this.returnLastFoundSequence) {
                // The working array ends up holding the final sequence, so no copy is needed.
                best = sequence;
            } else if (score > bestScore) {
                // Snapshot: later iterations keep mutating 'sequence' in place.
                best = SequenceGibbsSampler.copy(sequence);
                bestScore = score;
            }
            if (i % 50 == 0 && verbose > 1) {
                log.info("itr " + i + ": " + bestScore + "\t");
            }
            if (verbose > 0) {
                log.info(".");
            }
        }
        if (verbose > 1) {
            log.info(new Object[0]);
            this.printSamples(result, System.err);
        }
        if (verbose > 0) {
            log.info("done.");
        }
        return best != null ? best : sequence;
    }

    /** Runs {@link #collectSamples(SequenceModel, int, int, int[])} from a random start. */
    public List<int[]> collectSamples(SequenceModel model, int numSamples, int sampleInterval) {
        int[] initialSequence = SequenceGibbsSampler.getRandomSequence(model);
        return this.collectSamples(model, numSamples, sampleInterval, initialSequence);
    }

    /**
     * Runs a Gibbs chain and records {@code numSamples} states, {@code sampleInterval} forward sweeps
     * apart. Each sample continues the chain from the previous one (the first from
     * {@code initialSequence}) and is stored in its own array.
     *
     * @param initialSequence starting point; it is not modified
     * @return the collected samples, in chain order
     */
    public List<int[]> collectSamples(SequenceModel model, int numSamples, int sampleInterval, int[] initialSequence) {
        if (verbose > 0) {
            log.info("Collecting samples");
        }
        this.listener.setInitialSequence(initialSequence);
        List<int[]> result = new ArrayList<int[]>(Math.max(0, numSamples));
        int[] sequence = initialSequence;
        for (int i = 0; i < numSamples; ++i) {
            // Fresh array per sample so neither the start nor an already stored sample is changed.
            sequence = SequenceGibbsSampler.copy(sequence);
            this.listener.setInitialSequence(sequence);
            // Sample this array in place; sampleSequenceRepeatedly would work on a private copy and
            // leave 'sequence' untouched, so every stored sample would equal the start.
            for (int iter = 0; iter < sampleInterval; ++iter) {
                this.sampleSequenceForward(model, sequence);
            }
            result.add(sequence);
            if (verbose > 0) {
                log.info(".");
            }
            System.err.flush();
        }
        if (verbose > 1) {
            log.info(new Object[0]);
            this.printSamples(result, System.err);
        }
        if (verbose > 0) {
            log.info("done.");
        }
        return result;
    }

    /**
     * Performs {@code numSamples} forward sweeps on a private copy of {@code sequence}; the caller's
     * array is not modified.
     *
     * @return the value returned by the last sweep, or negative infinity if {@code numSamples <= 0}
     */
    public double sampleSequenceRepeatedly(SequenceModel model, int[] sequence, int numSamples) {
        sequence = SequenceGibbsSampler.copy(sequence);
        this.listener.setInitialSequence(sequence);
        double returnScore = Double.NEGATIVE_INFINITY;
        for (int iter = 0; iter < numSamples; ++iter) {
            returnScore = this.sampleSequenceForward(model, sequence);
        }
        return returnScore;
    }

    /** Runs {@link #sampleSequenceRepeatedly(SequenceModel, int[], int)} from a random start. */
    public double sampleSequenceRepeatedly(SequenceModel model, int numSamples) {
        int[] sequence = SequenceGibbsSampler.getRandomSequence(model);
        return this.sampleSequenceRepeatedly(model, sequence, numSamples);
    }

    /** One forward sweep at temperature 1.0 over positions chosen by the sampling style. */
    public double sampleSequenceForward(SequenceModel model, int[] sequence) {
        return this.sampleSequenceForward(model, sequence, 1.0, null);
    }

    /**
     * Resamples positions of {@code sequence} in place.
     *
     * <p>If {@code onlySampleThesePositions} is non-null, exactly those positions are resampled.
     * Otherwise the configured sampling style decides: sequential (every position left to right),
     * random ({@code sequence.length} uniformly chosen positions) or chromatic (every partition group).
     *
     * @return for chromatic sampling, {@code model.scoreOf(sequence)}; otherwise the probability of the
     *         last sampled tag as returned by {@link #samplePosition}; negative infinity if nothing was
     *         sampled or the sampling style is unknown
     */
    public double sampleSequenceForward(final SequenceModel model, final int[] sequence, final double temperature, Set<Integer> onlySampleThesePositions) {
        double returnScore = Double.NEGATIVE_INFINITY;
        if (onlySampleThesePositions != null) {
            for (int pos : onlySampleThesePositions) {
                returnScore = this.samplePosition(model, sequence, pos, temperature);
            }
        } else if (this.samplingStyle == SEQUENTIAL_SAMPLING) {
            for (int pos = 0; pos < sequence.length; ++pos) {
                returnScore = this.samplePosition(model, sequence, pos, temperature);
            }
        } else if (this.samplingStyle == RANDOM_SAMPLING) {
            for (int k = 0; k < sequence.length; ++k) {
                int pos = random.nextInt(sequence.length);
                returnScore = this.samplePosition(model, sequence, pos, temperature);
            }
        } else if (this.samplingStyle == CHROMATIC_SAMPLING) {
            this.sampleChromatically(model, sequence, temperature);
            returnScore = model.scoreOf(sequence);
        }
        return returnScore;
    }

    /**
     * Resamples every group of {@code partition}. Within a group all new values are drawn against
     * the sequence as it was before the group started and are written back afterwards; small groups
     * are processed inline, larger ones in chunks through a {@link MulticoreWrapper}.
     */
    private void sampleChromatically(final SequenceModel model, final int[] sequence, final double temperature) {
        List<Pair<Integer, Integer>> results = new ArrayList<Pair<Integer, Integer>>();
        for (List<Integer> indieList : this.partition) {
            if (indieList.size() <= this.chromaticSize) {
                for (int pos : indieList) {
                    Pair<Integer, Double> newPosProb = this.samplePositionHelper(model, sequence, pos, temperature);
                    sequence[pos] = newPosProb.first();
                }
                continue;
            }
            MulticoreWrapper<List<Integer>, List<Pair<Integer, Integer>>> wrapper = new MulticoreWrapper<List<Integer>, List<Pair<Integer, Integer>>>(this.chromaticSize, new ThreadsafeProcessor<List<Integer>, List<Pair<Integer, Integer>>>(){

                @Override
                public List<Pair<Integer, Integer>> process(List<Integer> posList) {
                    List<Pair<Integer, Integer>> allPos = new ArrayList<Pair<Integer, Integer>>(posList.size());
                    for (int pos : posList) {
                        // Reads 'sequence' only; writes are deferred until the whole group is done.
                        Pair<Integer, Double> newPosProb = SequenceGibbsSampler.this.samplePositionHelper(model, sequence, pos, temperature);
                        allPos.add(new Pair<Integer, Integer>(pos, newPosProb.first()));
                    }
                    return allPos;
                }

                @Override
                public ThreadsafeProcessor<List<Integer>, List<Pair<Integer, Integer>>> newInstance() {
                    return this;
                }
            });
            results.clear();
            int indieListSize = indieList.size();
            // Guard against chromaticSize == 0, which previously divided by zero.
            int interval = this.chromaticSize > 0 ? Math.max(1, indieListSize / this.chromaticSize) : 1;
            for (int begin = 0; begin < indieListSize; begin += interval) {
                int end = Math.min(begin + interval, indieListSize);
                wrapper.put(indieList.subList(begin, end));
                while (wrapper.peek()) {
                    results.addAll(wrapper.poll());
                }
            }
            wrapper.join();
            while (wrapper.peek()) {
                results.addAll(wrapper.poll());
            }
            for (Pair<Integer, Integer> posVal : results) {
                sequence[posVal.first()] = posVal.second();
            }
        }
    }

    /** One backward sweep (right to left) at temperature 1.0. */
    public double sampleSequenceBackward(SequenceModel model, int[] sequence) {
        return this.sampleSequenceBackward(model, sequence, 1.0);
    }

    /**
     * Resamples every position of {@code sequence} in place, from the last to the first.
     *
     * @return the probability of the tag sampled at position 0, or negative infinity for an empty
     *         sequence
     */
    public double sampleSequenceBackward(SequenceModel model, int[] sequence, double temperature) {
        double returnScore = Double.NEGATIVE_INFINITY;
        for (int pos = sequence.length - 1; pos >= 0; --pos) {
            returnScore = this.samplePosition(model, sequence, pos, temperature);
        }
        return returnScore;
    }

    /** Resamples one position at temperature 1.0; see {@link #samplePosition(SequenceModel, int[], int, double)}. */
    public double samplePosition(SequenceModel model, int[] sequence, int pos) {
        return this.samplePosition(model, sequence, pos, 1.0);
    }

    /**
     * Draws a new value for {@code pos} without modifying {@code sequence}.
     *
     * <p>The model's scores for the position are divided by {@code temperature} (temperature 0 puts
     * all mass on the argmax), log-normalized and exponentiated before sampling.
     *
     * @return the sampled tag index and its probability under that distribution
     */
    private Pair<Integer, Double> samplePositionHelper(SequenceModel model, int[] sequence, int pos, double temperature) {
        double[] distribution = model.scoresOf(sequence, pos);
        if (temperature != 1.0) {
            if (temperature == 0.0) {
                int argmax = ArrayMath.argmax(distribution);
                Arrays.fill(distribution, Double.NEGATIVE_INFINITY);
                distribution[argmax] = 0.0;
            } else {
                ArrayMath.multiplyInPlace(distribution, 1.0 / temperature);
            }
        }
        ArrayMath.logNormalize(distribution);
        ArrayMath.expInPlace(distribution);
        int newTag = ArrayMath.sampleFromDistribution(distribution, random);
        double newProb = distribution[newTag];
        return new Pair<Integer, Double>(newTag, newProb);
    }

    /**
     * Resamples {@code sequence[pos]} in place and notifies the listener of the change.
     *
     * @return the probability of the newly sampled tag
     */
    public double samplePosition(SequenceModel model, int[] sequence, int pos, double temperature) {
        int oldTag = sequence[pos];
        Pair<Integer, Double> newPosProb = this.samplePositionHelper(model, sequence, pos, temperature);
        sequence[pos] = newPosProb.first();
        this.listener.updateSequenceElement(sequence, pos, oldTag);
        return newPosProb.second();
    }

    /**
     * Prints one row per position: the token (or {@code "null"}) padded to 10 characters, followed
     * by that position's value in each sample.
     *
     * @param samples elements must be {@code int[]}
     * @param out2 destination stream
     */
    public void printSamples(List<?> samples, PrintStream out2) {
        int length;
        if (this.document != null) {
            length = this.document.size();
        } else if (samples.isEmpty()) {
            length = 0;
        } else {
            // No document (e.g. built by a constructor that passes null): size rows by the samples.
            length = ((int[]) samples.get(0)).length;
        }
        for (int i = 0; i < length; ++i) {
            String s = "null";
            if (this.document != null) {
                HasWord word = (HasWord) this.document.get(i);
                if (word != null) {
                    s = word.word();
                }
            }
            out2.print(StringUtils.padOrTrim(s, 10));
            for (Object sample : samples) {
                int[] sequence = (int[]) sample;
                out2.print(" " + StringUtils.padLeft(sequence[i], 2));
            }
            out2.println();
        }
    }

    /**
     * Full constructor.
     *
     * @param samplingStyle 0 = random, 1 = sequential, 2 = chromatic
     * @param chromaticSize group-size threshold and {@link MulticoreWrapper} argument for chromatic sampling
     * @param partition position groups used by chromatic sampling
     * @param speedUpThreshold see the field of the same name; values &lt;= 0 disable it
     */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List<?> document, boolean returnLastFoundSequence, int samplingStyle, int chromaticSize, List<List<Integer>> partition, int speedUpThreshold, EmpiricalNERPriorBIO priorEn, EmpiricalNERPriorBIO priorCh) {
        this.numSamples = numSamples;
        this.sampleInterval = sampleInterval;
        this.listener = listener;
        this.document = document;
        this.returnLastFoundSequence = returnLastFoundSequence;
        this.samplingStyle = samplingStyle;
        if (verbose > 0) {
            switch (samplingStyle) {
                case RANDOM_SAMPLING:
                    log.info("Using random sampling");
                    break;
                case CHROMATIC_SAMPLING:
                    log.info("Using chromatic sampling with " + chromaticSize + " threads");
                    break;
                case SEQUENTIAL_SAMPLING:
                    log.info("Using sequential sampling");
                    break;
                default:
                    break;
            }
        }
        this.chromaticSize = chromaticSize;
        this.partition = partition;
        this.speedUpThreshold = speedUpThreshold;
        this.priorEn = priorEn;
        this.priorCh = priorCh;
    }

    /** Sequential sampling, no speed-up, no priors. */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, List<?> document) {
        this(numSamples, sampleInterval, listener, document, false, SEQUENTIAL_SAMPLING, 0, null, -1, null, null);
    }

    /** Sequential sampling without a document. */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener) {
        this(numSamples, sampleInterval, listener, null);
    }

    /** Full configuration without a document; {@code returnLastFoundSequence} is false. */
    public SequenceGibbsSampler(int numSamples, int sampleInterval, SequenceListener listener, int samplingStyle, int chromaticSize, List<List<Integer>> partition, int speedUpThreshold, EmpiricalNERPriorBIO priorEn, EmpiricalNERPriorBIO priorCh) {
        this(numSamples, sampleInterval, listener, null, false, samplingStyle, chromaticSize, partition, speedUpThreshold, priorEn, priorCh);
    }
}

