/*
 * Decompiled with CFR 0.152.
 */
package jdk.test.lib.thread;

import java.lang.foreign.Arena;
import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.Linker;
import java.lang.foreign.MemoryLayout;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SymbolLookup;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.file.Path;
import jdk.test.lib.thread.VThreadRunner;

/**
 * Helper that runs a task from inside an upcall made by a native function, so
 * that a native frame sits on the stack underneath the task while it executes.
 *
 * <p>Going by the class name, the intent is that a virtual thread running the
 * task stays pinned to its carrier for the duration of the task; that effect
 * depends on the runtime and on the native {@code VThreadPinner} library, neither
 * of which is visible here.
 */
public class VThreadPinner {
    /**
     * The {@code java.library.path} system property interpreted as a single path.
     * NOTE(review): a property value with several separator-delimited entries is
     * not split here -- confirm the property always holds exactly one directory.
     */
    private static final Path JAVA_LIBRARY_PATH = Path.of(System.getProperty("java.library.path"));

    /** Platform-specific file name of the "VThreadPinner" library, resolved against the library path. */
    private static final Path LIB_PATH = JAVA_LIBRARY_PATH.resolve(System.mapLibraryName("VThreadPinner"));

    /** Downcall handle for the native symbol {@code call}: one address argument, no return value. */
    private static final MethodHandle INVOKER = invoker();

    /** Upcall stub (native function pointer) that targets {@link #callback()}. */
    private static final MemorySegment UPCALL_STUB = upcallStub();

    /** Holds the runner of the {@link #runPinned} call currently in flight on this thread. */
    private static final ThreadLocal<TaskRunner> TASK_RUNNER = new ThreadLocal<>();

    /**
     * Target of {@link #UPCALL_STUB}: runs the task stashed in the thread local.
     * Presumably reached when the native {@code call} function invokes the
     * function pointer it was handed -- the native side is not visible here.
     */
    private static void callback() {
        TASK_RUNNER.get().run();
    }

    /**
     * Runs {@code task} via an upcall from the native {@code call} function.
     *
     * <p>When the current thread is not a virtual thread, the work is handed to
     * {@code VThreadRunner.run}, with a lambda that calls this method again
     * (presumably on a virtual thread -- see {@code VThreadRunner}). Otherwise the
     * task is stashed in a thread local, the native function is invoked with the
     * upcall stub, and anything the task threw is rethrown once the native call
     * has returned.
     *
     * @param task the task to run
     * @param <X>  the type of exception the task may throw
     * @throws X                if the task threw it
     * @throws RuntimeException if invoking the native function itself failed
     *                          (the failure is the cause), or if the task threw one
     */
    public static <X extends Throwable> void runPinned(VThreadRunner.ThrowingRunnable<X> task) throws X {
        if (!Thread.currentThread().isVirtual()) {
            VThreadRunner.run(() -> runPinned(task));
            return;
        }

        TaskRunner runner = new TaskRunner(task);
        TASK_RUNNER.set(runner);
        try {
            INVOKER.invoke(UPCALL_STUB);
        } catch (Throwable e) {
            throw new RuntimeException(e);
        } finally {
            // Always clear the slot so the runner (and its task) is not retained by the thread.
            TASK_RUNNER.remove();
        }

        // The task's own failure was captured by the runner rather than propagated
        // through the native frame; surface it now with its original type.
        Throwable ex = runner.exception();
        if (ex == null) {
            return;
        }
        if (ex instanceof RuntimeException unchecked) {
            throw unchecked;
        }
        if (ex instanceof Error error) {
            throw error;
        }
        // Anything else is a checked exception thrown by task.run(), which is
        // declared to throw X. The cast is erased at runtime, so it cannot fail.
        @SuppressWarnings("unchecked")
        X checked = (X) ex;
        throw checked;
    }

    /**
     * Creates the downcall handle for the {@code call} symbol found in the library
     * at {@link #LIB_PATH}. The library is loaded into the global arena.
     *
     * @return a method handle of shape {@code (MemorySegment) -> void}
     * @throws RuntimeException wrapping any failure to load the library, find the
     *                          symbol, or link it
     */
    private static MethodHandle invoker() {
        Linker abi = Linker.nativeLinker();
        try {
            SymbolLookup lib = SymbolLookup.libraryLookup(LIB_PATH, Arena.global());
            MemorySegment symbol = lib.find("call").orElseThrow();
            FunctionDescriptor desc = FunctionDescriptor.ofVoid(ValueLayout.ADDRESS);
            return abi.downcallHandle(symbol, desc);
        } catch (Throwable e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an upcall stub for {@link #callback()} (no arguments, no return
     * value), allocated in the global arena.
     *
     * @return the memory segment holding the stub
     * @throws RuntimeException wrapping any failure to look up the callback or
     *                          create the stub
     */
    private static MemorySegment upcallStub() {
        Linker abi = Linker.nativeLinker();
        try {
            MethodHandle callback = MethodHandles.lookup()
                    .findStatic(VThreadPinner.class, "callback", MethodType.methodType(void.class));
            return abi.upcallStub(callback, FunctionDescriptor.ofVoid(), Arena.global());
        } catch (Throwable e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Runs a task and captures, instead of propagating, whatever it throws, so
     * that nothing is thrown out of the upcall.
     */
    private static class TaskRunner implements Runnable {
        private final VThreadRunner.ThrowingRunnable<?> task;
        /** What the task threw, or {@code null} if it has not thrown anything. */
        private Throwable throwable;

        TaskRunner(VThreadRunner.ThrowingRunnable<?> task) {
            this.task = task;
        }

        /** Runs the task, recording any exception or error it throws. */
        @Override
        public void run() {
            try {
                task.run();
            } catch (Throwable ex) {
                throwable = ex;
            }
        }

        /**
         * @return the exception or error captured by {@link #run()}, or
         *         {@code null} if none was captured
         */
        Throwable exception() {
            return throwable;
        }
    }
}

