/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.rest;

import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.get.GetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.common.xcontent.support.XContentHttpChunk;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.http.HttpChunk;
import org.opensearch.ml.action.execute.TransportExecuteStreamTaskAction;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agui.AGUIInputConverter;
import org.opensearch.ml.common.agui.RunErrorEvent;
import org.opensearch.ml.common.agui.RunStartedEvent;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.StreamingRestChannel;
import org.opensearch.transport.StreamTransportResponseHandler;
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.client.node.NodeClient;
import org.opensearch.transport.stream.StreamTransportResponse;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

/**
 * REST handler for {@code POST /_plugins/_ml/agents/{agent_id}/_execute/stream}.
 *
 * <p>Before the stream opens, {@link #prepareRequest} does three things:
 * <ul>
 *   <li>checks the MCP-header-passthrough and streaming feature flags;</li>
 *   <li>loads the agent;</li>
 *   <li>when the agent declares an LLM model id, verifies that the model can be fetched.</li>
 * </ul>
 *
 * <p>Once a streaming channel is available, the handler works as follows:
 * <ul>
 *   <li>It collects the request body chunks and builds an {@link MLExecuteTaskRequest}.</li>
 *   <li>It sends that request to the local node over the stream transport.</li>
 *   <li>It writes every streamed {@link MLTaskResponse} back to the client as a Server-Sent
 *       Events chunk ({@code text/event-stream}).</li>
 * </ul>
 * Requests in AG-UI format are answered with AG-UI event JSON. All other requests get the ML
 * tensor-output JSON.
 */
public class RestMLExecuteStreamAction extends BaseRestHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(RestMLExecuteStreamAction.class);
    private static final String ML_EXECUTE_STREAM_ACTION = "ml_execute_stream_action";
    // Thread pool that handles stream responses and fetches each subsequent response.
    private static final String STREAM_EXECUTOR = "opensearch_ml_execute_stream";
    // Upper bound for the blocking agent/model look-ups done while preparing the request.
    private static final long LOOKUP_TIMEOUT_SECONDS = 5L;
    // Agent parameters whose presence marks a request as AG-UI.
    private static final String AGUI_THREAD_ID = "agui_thread_id";
    private static final String AGUI_RUN_ID = "agui_run_id";

    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private ClusterService clusterService;
    private MLModelManager mlModelManager;

    public RestMLExecuteStreamAction(MLModelManager mlModelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting, ClusterService clusterService) {
        this.mlModelManager = mlModelManager;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.clusterService = clusterService;
    }

    @Override
    public String getName() {
        return ML_EXECUTE_STREAM_ACTION;
    }

    @Override
    public List<RestHandler.Route> routes() {
        return ImmutableList.of(
            new RestHandler.Route(
                RestRequest.Method.POST,
                String.format(Locale.ROOT, "%s/agents/{%s}/_execute/stream", "/_plugins/_ml", "agent_id")
            )
        );
    }

    @Override
    public boolean supportsContentStream() {
        return true;
    }

    @Override
    public boolean supportsStreaming() {
        return true;
    }

    @Override
    public boolean allowsUnsafeBuffers() {
        return true;
    }

    /**
     * Validates the request up front and returns a consumer that streams the agent execution.
     *
     * @throws IllegalArgumentException if MCP headers are present but MCP header passthrough is
     *     disabled
     * @throws IllegalStateException if streaming is disabled
     * @throws OpenSearchStatusException ({@code NOT_FOUND}) if the agent or its LLM model cannot be
     *     found
     */
    @Override
    public BaseRestHandler.RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
        if (RestActionUtils.hasMcpHeaders(request) && !mlFeatureEnabledSetting.isMcpHeaderPassthroughEnabled()) {
            throw new IllegalArgumentException("MCP header passthrough is not enabled. To enable, please update the setting: plugins.ml_commons.mcp_header_passthrough_enabled");
        }
        RestActionUtils.putMcpRequestHeaders(request, (Client) client);
        if (!mlFeatureEnabledSetting.isStreamEnabled()) {
            throw new IllegalStateException("Streaming is currently disabled. To enable it, update the setting \"plugins.ml_commons.stream_enabled\" to true.");
        }
        String agentId = request.param("agent_id");
        MLAgent agent = validateAndGetAgent(agentId, client);
        if (agent.getLlm() != null && agent.getLlm().getModelId() != null && !isModelValid(agent.getLlm().getModelId(), request, client)) {
            throw new OpenSearchStatusException("Failed to find model", RestStatus.NOT_FOUND);
        }

        BaseRestHandler.StreamingRestChannelConsumer consumer = channel -> {
            Supplier<ThreadContext.StoredContext> contextSupplier = client.threadPool().getThreadContext().newRestorableContext(true);
            Map<String, List<String>> headers = Map.of(
                "Content-Type", List.of("text/event-stream"),
                "Cache-Control", List.of("no-cache"),
                "Connection", List.of("keep-alive")
            );
            channel.prepareResponse(RestStatus.OK, headers);
            // Buffer the whole request body, execute the agent, and forward the terminal chunk
            // produced once the transport stream is exhausted.
            Flux.from(channel)
                .ofType(HttpChunk.class)
                .collectList()
                .flatMap(chunks -> executeStream(contextSupplier, agentId, request, channel, client, chunks))
                .doOnNext(channel::sendChunk)
                .onErrorResume(ex -> {
                    log.error("Error occurred", ex);
                    sendErrorChunk(channel, ex);
                    return Mono.empty();
                })
                .subscribe();
        };
        return channel -> {
            if (channel instanceof StreamingRestChannel) {
                consumer.accept((StreamingRestChannel) channel);
            } else {
                ActionRequestValidationException validationError = new ActionRequestValidationException();
                validationError.addValidationError("Unable to initiate request / response streaming over non-streaming channel");
                channel.sendResponse(new BytesRestResponse(channel, validationError));
            }
        };
    }

    /**
     * Best-effort: writes a final SSE chunk of the form {@code data: {"error": "..."}}.
     * Failures while sending are logged and swallowed.
     */
    private void sendErrorChunk(StreamingRestChannel channel, Throwable ex) {
        try {
            String errorMessage = ex instanceof IOException
                ? "Failed to parse request: " + ex.getMessage()
                : "Error processing request: " + ex.getMessage();
            // Let the JSON builder escape the message: exception text can contain quotes,
            // backslashes or newlines that would otherwise break the event payload.
            XContentBuilder errorBuilder = XContentFactory.jsonBuilder().startObject().field("error", errorMessage).endObject();
            channel.sendChunk(createHttpChunk("data: " + errorBuilder.toString() + "\n\n", true));
        } catch (Exception e) {
            log.error("Failed to send error chunk", e);
        }
    }

    /**
     * Loads and parses the agent document, blocking for at most {@code LOOKUP_TIMEOUT_SECONDS}.
     *
     * @param agentId id of the agent document
     * @param client client used to issue the get request
     * @return the parsed agent
     * @throws OpenSearchStatusException ({@code NOT_FOUND}) on any failure: missing document,
     *     parse error, timeout or interruption
     */
    @VisibleForTesting
    MLAgent validateAndGetAgent(String agentId, NodeClient client) {
        try {
            CompletableFuture<MLAgent> future = new CompletableFuture<>();
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<GetResponse> listener = ActionListener.runBefore(ActionListener.wrap(response -> {
                    if (!response.isExists()) {
                        future.completeExceptionally(new OpenSearchStatusException("Agent not found", RestStatus.NOT_FOUND));
                        return;
                    }
                    // The parser is closed once the agent has been read.
                    try (
                        XContentParser parser = JsonXContent.jsonXContent.createParser(
                            null,
                            LoggingDeprecationHandler.INSTANCE,
                            response.getSourceAsString()
                        )
                    ) {
                        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                        future.complete(MLAgent.parse(parser));
                    } catch (Exception e) {
                        future.completeExceptionally(e);
                    }
                }, future::completeExceptionally), context::restore);
                client.get(new GetRequest(".plugins-ml-agent", agentId), listener);
            }
            return future.get(LOOKUP_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            log.error("Failed to validate agent {}", agentId, e);
            throw new OpenSearchStatusException("Failed to find agent with the provided agent id: " + agentId, RestStatus.NOT_FOUND);
        }
    }

    /**
     * Checks that the model can be fetched within {@code LOOKUP_TIMEOUT_SECONDS}.
     *
     * @return {@code true} if the model was fetched, {@code false} on any failure (logged)
     */
    @VisibleForTesting
    boolean isModelValid(String modelId, RestRequest request, NodeClient client) throws IOException {
        try {
            CompletableFuture<MLModel> future = new CompletableFuture<>();
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                ActionListener<MLModel> listener = ActionListener.runBefore(
                    ActionListener.wrap(future::complete, future::completeExceptionally),
                    context::restore
                );
                String tenantId = TenantAwareHelper.getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
                mlModelManager.getModel(modelId, tenantId, listener);
            }
            future.get(LOOKUP_TIMEOUT_SECONDS, TimeUnit.SECONDS);
            return true;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            log.error("Failed to validate model {}", modelId, e);
            return false;
        }
    }

    /**
     * Builds the streaming agent execution request from the full request body.
     *
     * <p>The body is either an AG-UI payload (detected by {@code AGUIInputConverter.isAGUIInput})
     * or a regular agent {@link MLInput}. In both cases {@code stream=true} is added to the
     * agent parameters.
     *
     * @throws IllegalStateException if the agent framework, or AG-UI for AG-UI input, is disabled
     * @throws OpenSearchStatusException ({@code FORBIDDEN}) if a memory configuration is supplied
     *     while remote agentic memory is disabled
     * @throws IllegalArgumentException if the input dataset is not a
     *     {@link RemoteInferenceInputDataSet}
     */
    @VisibleForTesting
    MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesReference content) throws IOException {
        try (
            XContentParser parser = request.getMediaType()
                .xContent()
                .createParser(request.getXContentRegistry(), LoggingDeprecationHandler.INSTANCE, content.streamInput())
        ) {
            boolean async = RestActionUtils.isAsync(request);
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
            if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) {
                throw new IllegalStateException("Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.");
            }
            String tenantId = TenantAwareHelper.getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request);
            FunctionName functionName = FunctionName.AGENT;
            String requestBodyJson = content.utf8ToString();

            MLInput input;
            if (AGUIInputConverter.isAGUIInput(requestBodyJson)) {
                if (!mlFeatureEnabledSetting.isAGUIEnabled()) {
                    throw new IllegalStateException(MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE);
                }
                log.debug("AG-UI: Detected AG-UI input format for streaming agent: {}", agentId);
                input = AGUIInputConverter.convertFromAGUIInput(requestBodyJson, agentId, tenantId, async);
            } else {
                input = MLInput.parse(parser, functionName.name());
                AgentMLInput agentInput = (AgentMLInput) input;
                agentInput.setAgentId(agentId);
                agentInput.setTenantId(tenantId);
                agentInput.setIsAsync(Boolean.valueOf(async));
            }

            MLInputDataset dataset = ((AgentMLInput) input).getInputDataset();
            if (!(dataset instanceof RemoteInferenceInputDataSet)) {
                throw new IllegalArgumentException("Expected RemoteInferenceInputDataSet for agent execution");
            }
            RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) dataset;
            Map<String, String> parameters = inputDataSet.getParameters();
            if (!mlFeatureEnabledSetting.isRemoteAgenticMemoryEnabled()
                && parameters != null
                && !Strings.isNullOrEmpty(parameters.get("memory_configuration"))) {
                throw new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_REMOTE_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN);
            }
            inputDataSet.getParameters().put("stream", String.valueOf(true));
            return new MLExecuteTaskRequest(functionName, input);
        }
    }

    /**
     * Returns the agent's input parameters. Returns an empty map when the request is not an
     * agent request backed by a {@link RemoteInferenceInputDataSet}, or has no parameters.
     */
    private Map<String, String> agentParameters(MLExecuteTaskRequest request) {
        if (!(request.getInput() instanceof AgentMLInput)) {
            return Map.of();
        }
        MLInputDataset dataset = ((AgentMLInput) request.getInput()).getInputDataset();
        if (!(dataset instanceof RemoteInferenceInputDataSet)) {
            return Map.of();
        }
        Map<String, String> parameters = ((RemoteInferenceInputDataSet) dataset).getParameters();
        return parameters != null ? parameters : Map.of();
    }

    /** Returns true when the request carries an AG-UI thread id or run id parameter. */
    private boolean isAGUIAgent(MLExecuteTaskRequest request) {
        Map<String, String> parameters = agentParameters(request);
        return parameters.containsKey(AGUI_THREAD_ID) || parameters.containsKey(AGUI_RUN_ID);
    }

    /** Returns the AG-UI thread id, or a {@code thread_<epochMillis>} fallback. */
    private String extractThreadId(MLExecuteTaskRequest request) {
        String threadId = agentParameters(request).get(AGUI_THREAD_ID);
        return threadId != null ? threadId : "thread_" + System.currentTimeMillis();
    }

    /** Returns the AG-UI run id, or a {@code run_<epochMillis>} fallback. */
    private String extractRunId(MLExecuteTaskRequest request) {
        String runId = agentParameters(request).get(AGUI_RUN_ID);
        return runId != null ? runId : "run_" + System.currentTimeMillis();
    }

    /**
     * Converts one streamed task response into an SSE chunk.
     *
     * <p>An {@code error} entry in the response data ends the stream with the error text as
     * content. If extracting the data fails, the chunk carries {@code "Processing failed"} and is
     * marked last. AG-UI requests are rendered by {@link #convertToAGUIEvent}. All others are
     * rendered as tensor output with {@code memory_id}, {@code parent_interaction_id} and a
     * {@code response} tensor holding {@code content} and {@code is_last}.
     */
    private HttpChunk convertToHttpChunk(MLTaskResponse response, boolean isAGUIAgent) throws IOException {
        String memoryId = "";
        String parentInteractionId = "";
        String content;
        boolean isLast;
        try {
            Map<String, ?> dataMap = extractDataMap(response);
            if (dataMap.containsKey("error")) {
                content = (String) dataMap.get("error");
                isLast = true;
            } else {
                memoryId = extractTensorResult(response, "memory_id");
                parentInteractionId = extractTensorResult(response, "parent_interaction_id");
                content = dataMap.containsKey("content") ? (String) dataMap.get("content") : "";
                isLast = Boolean.TRUE.equals(dataMap.get("is_last"));
            }
        } catch (Exception e) {
            log.error("Failed to process response", e);
            content = "Processing failed";
            isLast = true;
        }

        if (isAGUIAgent) {
            return convertToAGUIEvent(content, isLast);
        }

        // Plain LinkedHashMap (instead of an anonymous subclass) keeps "content" before "is_last".
        Map<String, Object> responseData = new LinkedHashMap<>();
        responseData.put("content", content);
        responseData.put("is_last", isLast);
        List<ModelTensor> orderedTensors = List.of(
            ModelTensor.builder().name("memory_id").result(memoryId).build(),
            ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(),
            ModelTensor.builder().name("response").dataAsMap(responseData).build()
        );
        ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build();
        ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build();
        XContentBuilder builder = XContentFactory.jsonBuilder();
        tensorOutput.toXContent(builder, ToXContent.EMPTY_PARAMS);
        return createHttpChunk("data: " + builder.toString() + "\n\n", isLast);
    }

    /**
     * Returns the non-null {@code result} of the first tensor named {@code tensorName} in the
     * first output, or {@code ""} when there is none.
     */
    private String extractTensorResult(MLTaskResponse response, String tensorName) {
        ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
        if (output == null || output.getMlModelOutputs().isEmpty()) {
            return "";
        }
        ModelTensors tensors = output.getMlModelOutputs().get(0);
        for (ModelTensor tensor : tensors.getMlModelTensors()) {
            if (tensorName.equals(tensor.getName()) && tensor.getResult() != null) {
                return tensor.getResult();
            }
        }
        return "";
    }

    /**
     * Returns the data map of the first tensor in the first output that is named
     * {@code error}, {@code llm_response} or {@code response} and has a non-null map.
     * Returns an empty map otherwise.
     */
    private Map<String, ?> extractDataMap(MLTaskResponse response) {
        ModelTensorOutput output = (ModelTensorOutput) response.getOutput();
        if (output == null || output.getMlModelOutputs().isEmpty()) {
            return Map.of();
        }
        ModelTensors tensors = output.getMlModelOutputs().get(0);
        for (ModelTensor tensor : tensors.getMlModelTensors()) {
            String name = tensor.getName();
            boolean isDataTensor = "error".equals(name) || "llm_response".equals(name) || "response".equals(name);
            if (isDataTensor) {
                Map<String, ?> dataMap = tensor.getDataAsMap();
                if (dataMap != null) {
                    return dataMap;
                }
            }
        }
        return Map.of();
    }

    /**
     * Wraps AG-UI content in an SSE chunk.
     *
     * <p>How the content is handled:
     * <ul>
     *   <li>JSON content is forwarded as a single {@code data:} event.</li>
     *   <li>Non-JSON content, or content that fails to process, becomes a {@link RunErrorEvent}
     *       and the chunk is marked last.</li>
     *   <li>Null or empty content produces an empty chunk, keeping {@code isLast} as passed in.</li>
     * </ul>
     */
    private HttpChunk convertToAGUIEvent(String content, boolean isLast) {
        log.debug(
            "RestMLExecuteStreamAction: convertToAGUIEvent() called - contentLength={}, isLast={}",
            content != null ? Integer.valueOf(content.length()) : "null",
            isLast
        );
        StringBuilder sseResponse = new StringBuilder();
        boolean last = isLast;
        if (content == null || content.isEmpty()) {
            log.warn("Received null or empty AG-UI content chunk");
        } else {
            log.debug("RestMLExecuteStreamAction: Processing content: '{}'", content);
            try {
                if (StringUtils.isJson(content)) {
                    JsonElement element = JsonParser.parseString(content);
                    sseResponse.append("data: ").append(element).append("\n\n");
                    log.debug("RestMLExecuteStreamAction: Processing json element: '{}'", element);
                } else {
                    log.warn("Unexpected content received - not valid JSON: {}", content);
                    RunErrorEvent runErrorEvent = new RunErrorEvent("Unexpected chunk: " + content, null);
                    sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n");
                    last = true;
                }
            } catch (Exception e) {
                log.error("Failed to process AG-UI events chunk content {}", content, e);
                RunErrorEvent runErrorEvent = new RunErrorEvent("Unexpected error: " + e.getMessage(), null);
                sseResponse.append("data: ").append(runErrorEvent.toJsonString()).append("\n\n");
                last = true;
            }
        }
        String finalSse = sseResponse.toString();
        log.debug("RestMLExecuteStreamAction: Returning chunk - length={}", finalSse.length());
        return createHttpChunk(finalSse, last);
    }

    /** Concatenates the content of all request chunks, in order, into one byte reference. */
    @VisibleForTesting
    BytesReference combineChunks(List<HttpChunk> chunks) {
        try {
            ByteArrayOutputStream buffer = new ByteArrayOutputStream();
            for (HttpChunk chunk : chunks) {
                chunk.content().writeTo(buffer);
            }
            return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray()));
        } catch (IOException e) {
            log.error("Failed to combine chunks", e);
            throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e);
        }
    }

    /** Wraps already-formatted SSE text in an {@link HttpChunk}. */
    private HttpChunk createHttpChunk(String sseData, boolean isLast) {
        // Event streams are UTF-8; do not depend on the JVM's platform default charset.
        BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes(StandardCharsets.UTF_8)));
        return new HttpChunk() {
            @Override
            public void close() {
                if (bytesRef instanceof Releasable) {
                    ((Releasable) bytesRef).close();
                }
            }

            @Override
            public boolean isLast() {
                return isLast;
            }

            @Override
            public BytesReference content() {
                return bytesRef;
            }
        };
    }

    /**
     * Parses the buffered request body, optionally emits an AG-UI {@code RUN_STARTED} event, and
     * sends the execute request to the local node over the stream transport.
     *
     * <p>Responses are written to {@code channel} as they arrive. The returned {@link Mono}:
     * <ul>
     *   <li>emits the terminal {@link XContentHttpChunk#last()} chunk when the stream ends;</li>
     *   <li>errors if parsing, sending or stream handling fails.</li>
     * </ul>
     */
    private Mono<HttpChunk> executeStream(
        Supplier<ThreadContext.StoredContext> contextSupplier,
        String agentId,
        RestRequest request,
        StreamingRestChannel channel,
        NodeClient client,
        List<HttpChunk> chunks
    ) {
        // Re-applies the thread context captured when the channel was opened (presumably the
        // caller's request context); the stored context is closed when this method exits.
        try (ThreadContext.StoredContext storedContext = contextSupplier.get()) {
            BytesReference completeContent = combineChunks(chunks);
            MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent);
            boolean isAGUI = isAGUIAgent(mlExecuteTaskRequest);
            if (isAGUI) {
                String threadId = extractThreadId(mlExecuteTaskRequest);
                String runId = extractRunId(mlExecuteTaskRequest);
                RunStartedEvent runStartedEvent = new RunStartedEvent(threadId, runId);
                channel.sendChunk(createHttpChunk("data: " + runStartedEvent.toJsonString() + "\n\n", false));
                log.debug("AG-UI: RestMLExecuteStreamAction: Sent RUN_STARTED event - threadId={}, runId={}", threadId, runId);
            }

            CompletableFuture<HttpChunk> future = new CompletableFuture<>();
            StreamTransportService streamTransportService = TransportExecuteStreamTaskAction.getStreamTransportService();
            streamTransportService.sendRequest(
                clusterService.localNode(),
                "cluster:admin/opensearch/ml/execute/stream",
                mlExecuteTaskRequest,
                TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(),
                createStreamHandler(channel, client, isAGUI, future)
            );
            return Mono.fromCompletionStage(future);
        } catch (Exception e) {
            log.error("Failed to parse or process request", e);
            return Mono.error(e);
        }
    }

    /**
     * Creates the transport handler that forwards streamed responses to {@code channel}.
     *
     * <p>Each response is sent as one chunk, then the next one is read on the stream executor.
     * At the end of the stream, {@code future} completes with the terminal chunk and the stream
     * is closed. On any failure, {@code future} completes exceptionally and the stream is closed,
     * so it does not keep producing after, for example, the client has gone away.
     */
    private StreamTransportResponseHandler<MLTaskResponse> createStreamHandler(
        StreamingRestChannel channel,
        NodeClient client,
        boolean isAGUI,
        CompletableFuture<HttpChunk> future
    ) {
        return new StreamTransportResponseHandler<MLTaskResponse>() {
            @Override
            public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) {
                boolean closeAttempted = false;
                try {
                    MLTaskResponse response = streamResponse.nextResponse();
                    if (response != null) {
                        channel.sendChunk(convertToHttpChunk(response, isAGUI));
                        // Fetch the next response as a new task rather than looping on this thread.
                        client.threadPool().executor(STREAM_EXECUTOR).execute(() -> handleStreamResponse(streamResponse));
                    } else {
                        log.info("No more responses, closing stream");
                        future.complete(XContentHttpChunk.last());
                        closeAttempted = true;
                        streamResponse.close();
                    }
                } catch (Exception e) {
                    future.completeExceptionally(e);
                    log.error("Error in stream handling", e);
                    if (!closeAttempted) {
                        closeQuietly(streamResponse);
                    }
                }
            }

            @Override
            public void handleException(TransportException exp) {
                future.completeExceptionally(exp);
            }

            @Override
            public String executor() {
                return STREAM_EXECUTOR;
            }

            @Override
            public MLTaskResponse read(StreamInput in) throws IOException {
                return new MLTaskResponse(in);
            }
        };
    }

    /** Closes a stream after a failure. A secondary close error is logged, not thrown. */
    private static void closeQuietly(StreamTransportResponse<?> streamResponse) {
        try {
            streamResponse.close();
        } catch (Exception closeError) {
            log.warn("Failed to close stream response after error", closeError);
        }
    }
}

