/*
 * Decompiled with CFR 0.152.
 */
package org.apache.seatunnel.transform.nlpmodel.llm.remote.custom;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.groovy.util.Maps;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.api.table.type.SqlType;
import org.apache.seatunnel.shade.com.fasterxml.jackson.core.type.TypeReference;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.JsonNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ArrayNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.TextNode;
import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting;
import org.apache.seatunnel.transform.nlpmodel.CustomConfigPlaceholder;
import org.apache.seatunnel.transform.nlpmodel.llm.remote.AbstractModel;

/**
 * LLM backend that calls an arbitrary HTTP endpoint described entirely by user configuration.
 *
 * <p>For each request, the configured {@code body} template is copied. In every string value,
 * the placeholders named {@code input}, {@code prompt} and {@code model} are substituted; their
 * syntax is defined by {@link CustomConfigPlaceholder}. The result is serialized to JSON and
 * POSTed to {@code apiPath} with the configured headers. The answer is extracted from the
 * response body with the JsonPath expression held in {@code parse}.
 */
public class CustomModel extends AbstractModel {
    /** Client shared by all requests of this model; released in {@link #close()}. */
    private final CloseableHttpClient client;
    /** Value substituted for the {@code model} placeholder. */
    private final String model;
    /** URL every request is POSTed to. */
    private final String apiPath;
    /** Headers set on every request (never {@code null}). */
    private final Map<String, String> header;
    /** Request body template; its string values may contain placeholders. */
    private final Map<String, Object> body;
    /** JsonPath expression that selects the answer from the response body. */
    private final String parse;

    /**
     * Creates a model bound to a custom HTTP endpoint.
     *
     * @param rowType forwarded to {@link AbstractModel}
     * @param outputType forwarded to {@link AbstractModel}
     * @param projectionColumns forwarded to {@link AbstractModel}
     * @param prompt forwarded to {@link AbstractModel}
     * @param model value substituted for the {@code model} placeholder
     * @param apiPath URL the request is POSTed to
     * @param header HTTP headers set on every request; {@code null} is treated as no headers
     * @param body request body template; assumed non-null (TODO confirm against config defaults)
     * @param parse JsonPath expression selecting the answer from the response body
     */
    public CustomModel(
            SeaTunnelRowType rowType,
            SqlType outputType,
            List<String> projectionColumns,
            String prompt,
            String model,
            String apiPath,
            Map<String, String> header,
            Map<String, Object> body,
            String parse) {
        super(rowType, outputType, projectionColumns, prompt);
        this.apiPath = apiPath;
        this.model = model;
        // Treat a missing header map as "no headers" rather than failing on every request.
        this.header = header == null ? Collections.<String, String>emptyMap() : header;
        this.body = body;
        this.parse = parse;
        this.client = HttpClients.createDefault();
    }

    /**
     * Sends one request to the custom endpoint and extracts the model output.
     *
     * @param promptWithLimit value substituted for the {@code prompt} placeholder
     * @param rowsJson value substituted for the {@code input} placeholder
     * @return the value selected by {@code parse}, converted to a list of strings; a value that
     *     cannot be converted to a list is converted to one string and wrapped in a singleton list
     * @throws IOException if the HTTP call fails or the endpoint does not answer with status 200
     * @throws IllegalArgumentException if the selected value converts neither to a list of
     *     strings nor to a single string
     */
    @Override
    protected List<String> chatWithModel(String promptWithLimit, String rowsJson)
            throws IOException {
        HttpPost post = new HttpPost(this.apiPath);
        for (Map.Entry<String, String> entry : this.header.entrySet()) {
            post.setHeader(entry.getKey(), entry.getValue());
        }
        String requestBody =
                OBJECT_MAPPER.writeValueAsString(
                        this.createJsonNodeFromData(promptWithLimit, rowsJson));
        post.setEntity(new StringEntity(requestBody, StandardCharsets.UTF_8));

        // Always close the response so its connection goes back to the pool, even when
        // reading the body or parsing the answer throws.
        try (CloseableHttpResponse response = this.client.execute(post)) {
            int statusCode = response.getStatusLine().getStatusCode();
            // UTF-8 is only the fallback used when the response does not declare a charset,
            // replacing HttpClient's ISO-8859-1 default that would garble non-ASCII answers.
            String responseStr =
                    response.getEntity() == null
                            ? ""
                            : EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);
            if (statusCode != 200) {
                throw new IOException(
                        "Failed to get response from custom model, status code: "
                                + statusCode
                                + ", response: "
                                + responseStr);
            }
            return this.toStringList(this.parseResponse(responseStr));
        }
    }

    /**
     * Converts a JsonPath result to a list of strings.
     *
     * <p>It first tries a list conversion. If that fails, it converts the whole value to a single
     * string and wraps it in a singleton list.
     */
    private List<String> toStringList(Object parsed) {
        try {
            return OBJECT_MAPPER.convertValue(parsed, new TypeReference<List<String>>() {});
        } catch (IllegalArgumentException notAList) {
            try {
                String single = OBJECT_MAPPER.convertValue(parsed, new TypeReference<String>() {});
                return Collections.singletonList(single);
            } catch (IllegalArgumentException notAString) {
                // Keep the first failure visible for debugging.
                notAString.addSuppressed(notAList);
                throw notAString;
            }
        }
    }

    /**
     * Applies the configured JsonPath expression to a raw response body.
     *
     * @param responseStr response body returned by the endpoint
     * @return the object selected by JsonPath; its concrete type depends on the JsonPath
     *     provider and the shape of the response
     */
    @VisibleForTesting
    public Object parseResponse(String responseStr) {
        return JsonPath.parse(responseStr).read(this.parse);
    }

    /**
     * Builds the request body from the configured template and fills in placeholders.
     *
     * @param prompt value for the {@code prompt} placeholder
     * @param data value for the {@code input} placeholder
     * @return a freshly built JSON object; the configured {@code body} map is not modified
     * @throws IOException if the template cannot be serialized or re-read as JSON
     */
    @VisibleForTesting
    public ObjectNode createJsonNodeFromData(String prompt, String data) throws IOException {
        // Round-trip through a string so each request mutates its own copy of the template.
        JsonNode template = OBJECT_MAPPER.readTree(OBJECT_MAPPER.writeValueAsString(this.body));
        Map<String, String> placeholderValues =
                Maps.of("input", data, "prompt", prompt, "model", this.model);
        return (ObjectNode) replacePlaceholders(template, placeholderValues);
    }

    /**
     * Recursively substitutes placeholders in every string value of {@code node}.
     *
     * <p>Object and array nodes are updated in place and returned. A textual node is returned as a
     * new {@link TextNode} holding the substituted text. Any other node is returned unchanged.
     *
     * @param node JSON node to process
     * @param placeholderValues placeholder name to replacement value
     * @return the processed node
     */
    private static JsonNode replacePlaceholders(
            JsonNode node, Map<String, String> placeholderValues) {
        if (node.isObject()) {
            ObjectNode objectNode = (ObjectNode) node;
            Iterator<Map.Entry<String, JsonNode>> fields = objectNode.fields();
            while (fields.hasNext()) {
                Map.Entry<String, JsonNode> field = fields.next();
                // Only the value of an existing key is overwritten; no keys are added or removed
                // while iterating.
                objectNode.set(
                        field.getKey(), replacePlaceholders(field.getValue(), placeholderValues));
            }
            return objectNode;
        }
        if (node.isArray()) {
            ArrayNode arrayNode = (ArrayNode) node;
            for (int i = 0; i < arrayNode.size(); i++) {
                arrayNode.set(i, replacePlaceholders(arrayNode.get(i), placeholderValues));
            }
            return arrayNode;
        }
        if (node.isTextual()) {
            String text = node.asText();
            for (Map.Entry<String, String> placeholder : placeholderValues.entrySet()) {
                // Boolean.TRUE.equals tolerates a null result instead of failing on unboxing.
                if (Boolean.TRUE.equals(
                        CustomConfigPlaceholder.findPlaceholder(text, placeholder.getKey()))) {
                    text =
                            CustomConfigPlaceholder.replacePlaceholders(
                                    text, placeholder.getKey(), placeholder.getValue(), null);
                }
            }
            return new TextNode(text);
        }
        return node;
    }

    /** Releases the underlying HTTP client. */
    @Override
    public void close() throws IOException {
        if (this.client != null) {
            this.client.close();
        }
    }
}

