/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.inference.external.amazonbedrock;

import java.util.ArrayDeque;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Strings;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;

class AmazonBedrockStreamingChatProcessor
implements Flow.Processor<ConverseStreamOutput, StreamingChatCompletionResults.Results> {
    private static final Logger logger = LogManager.getLogger(AmazonBedrockStreamingChatProcessor.class);
    private final AtomicReference<Throwable> error = new AtomicReference<Object>(null);
    private final AtomicLong demand = new AtomicLong(0L);
    private final AtomicBoolean isDone = new AtomicBoolean(false);
    private final AtomicBoolean onCompleteCalled = new AtomicBoolean(false);
    private final AtomicBoolean onErrorCalled = new AtomicBoolean(false);
    private final ThreadPool threadPool;
    private volatile Flow.Subscriber<? super StreamingChatCompletionResults.Results> downstream;
    private volatile Flow.Subscription upstream;

    AmazonBedrockStreamingChatProcessor(ThreadPool threadPool) {
        this.threadPool = threadPool;
    }

    @Override
    public void subscribe(Flow.Subscriber<? super StreamingChatCompletionResults.Results> subscriber) {
        if (this.downstream == null) {
            this.downstream = subscriber;
            this.downstream.onSubscribe(new StreamSubscription());
        } else {
            subscriber.onError(new IllegalStateException("Subscriber already set."));
        }
    }

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        if (this.upstream == null) {
            this.upstream = subscription;
            long currentRequestCount = this.demand.getAndUpdate(i -> 0L);
            if (currentRequestCount > 0L) {
                this.upstream.request(currentRequestCount);
            }
        } else {
            subscription.cancel();
        }
    }

    @Override
    public void onNext(ConverseStreamOutput item) {
        if (item.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA) {
            this.demand.set(0L);
            item.accept(ConverseStreamResponseHandler.Visitor.builder().onContentBlockDelta(this::sendDownstreamOnAnotherThread).build());
        } else {
            this.upstream.request(1L);
        }
    }

    private void sendDownstreamOnAnotherThread(ContentBlockDeltaEvent event) {
        this.runOnUtilityThreadPool(() -> {
            String text = event.delta().text();
            ArrayDeque<StreamingChatCompletionResults.Result> result = new ArrayDeque<StreamingChatCompletionResults.Result>(1);
            result.offer(new StreamingChatCompletionResults.Result(text));
            StreamingChatCompletionResults.Results results = new StreamingChatCompletionResults.Results(result);
            this.downstream.onNext((StreamingChatCompletionResults.Results)results);
        });
    }

    @Override
    public void onError(Throwable amazonBedrockRuntimeException) {
        ExceptionsHelper.maybeDieOnAnotherThread((Throwable)amazonBedrockRuntimeException);
        this.error.set(new ElasticsearchException(Strings.format((String)"AmazonBedrock StreamingChatProcessor failure: [%s]", (Object[])new Object[]{amazonBedrockRuntimeException.getMessage()}), amazonBedrockRuntimeException, new Object[0]));
        if (this.isDone.compareAndSet(false, true) && this.checkAndResetDemand() && this.onErrorCalled.compareAndSet(false, true)) {
            this.runOnUtilityThreadPool(() -> this.downstream.onError(amazonBedrockRuntimeException));
        }
    }

    private boolean checkAndResetDemand() {
        return this.demand.getAndUpdate(i -> 0L) > 0L;
    }

    @Override
    public void onComplete() {
        if (this.isDone.compareAndSet(false, true) && this.checkAndResetDemand() && this.onCompleteCalled.compareAndSet(false, true)) {
            this.downstream.onComplete();
        }
    }

    private void runOnUtilityThreadPool(Runnable runnable) {
        try {
            this.threadPool.executor("inference_utility").execute(runnable);
        }
        catch (Exception e) {
            logger.error(Strings.format((String)"failed to fork [%s] to utility thread pool", (Object[])new Object[]{runnable}), (Throwable)e);
        }
    }

    private class StreamSubscription
    implements Flow.Subscription {
        private StreamSubscription() {
        }

        @Override
        public void request(long n) {
            if (n > 0L) {
                AmazonBedrockStreamingChatProcessor.this.demand.updateAndGet(i -> {
                    long sum = i + n;
                    return sum >= 0L ? sum : Long.MAX_VALUE;
                });
                if (AmazonBedrockStreamingChatProcessor.this.upstream == null) {
                    return;
                }
                if (this.upstreamIsRunning()) {
                    this.requestOnMlThread(n);
                } else if (AmazonBedrockStreamingChatProcessor.this.error.get() != null && AmazonBedrockStreamingChatProcessor.this.onErrorCalled.compareAndSet(false, true)) {
                    AmazonBedrockStreamingChatProcessor.this.downstream.onError(AmazonBedrockStreamingChatProcessor.this.error.get());
                } else if (AmazonBedrockStreamingChatProcessor.this.onCompleteCalled.compareAndSet(false, true)) {
                    AmazonBedrockStreamingChatProcessor.this.downstream.onComplete();
                }
            } else {
                this.cancel();
                AmazonBedrockStreamingChatProcessor.this.downstream.onError(new IllegalStateException("Cannot request a negative number."));
            }
        }

        private boolean upstreamIsRunning() {
            return !AmazonBedrockStreamingChatProcessor.this.isDone.get() && AmazonBedrockStreamingChatProcessor.this.error.get() == null;
        }

        private void requestOnMlThread(long n) {
            String currentThreadPool = EsExecutors.executorName((String)Thread.currentThread().getName());
            if ("inference_utility".equalsIgnoreCase(currentThreadPool)) {
                AmazonBedrockStreamingChatProcessor.this.upstream.request(n);
            } else {
                AmazonBedrockStreamingChatProcessor.this.runOnUtilityThreadPool(() -> AmazonBedrockStreamingChatProcessor.this.upstream.request(n));
            }
        }

        @Override
        public void cancel() {
            if (AmazonBedrockStreamingChatProcessor.this.upstream != null && this.upstreamIsRunning()) {
                AmazonBedrockStreamingChatProcessor.this.upstream.cancel();
            }
        }
    }
}

