/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.searchrelevance.ml;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.searchrelevance.common.MLConstants;
import org.opensearch.searchrelevance.ml.TokenizerUtil;

/**
 * Converts search hits into {@link MLInput} requests for a remote chat-style model and
 * extracts the textual answer from the model's {@link MLOutput}.
 *
 * <p>Hits are packed into as few requests as possible while keeping the formatted message
 * of every request within a caller-supplied token limit, as measured by
 * {@link TokenizerUtil#countTokens}.
 */
public class MLInputOutputTransformer {
    @Generated
    private static final Logger log = LogManager.getLogger(MLInputOutputTransformer.class);

    /**
     * Splits {@code hits} into chunks whose formatted messages stay within {@code tokenLimit}
     * and builds one {@link MLInput} per chunk.
     *
     * <p>An entry that does not fit even when it is the only hit in a request is truncated via
     * {@link #handleOversizedEntry} and sent in a request of its own.
     *
     * @param tokenLimit maximum number of tokens allowed for one formatted message
     * @param searchText the search text embedded in every message
     * @param reference  optional reference answer; may be {@code null} or empty
     * @param hits       hit id to hit source; {@code null} is treated as empty
     * @return the inputs in the order they were produced; empty when there are no hits
     */
    public List<MLInput> createMLInputs(int tokenLimit, String searchText, String reference, Map<String, String> hits) {
        List<MLInput> mlInputs = new ArrayList<>();
        if (hits == null || hits.isEmpty()) {
            return mlInputs;
        }

        Map<String, String> currentChunk = new HashMap<>();
        for (Map.Entry<String, String> entry : hits.entrySet()) {
            Map<String, String> candidate = new HashMap<>(currentChunk);
            candidate.put(entry.getKey(), entry.getValue());
            if (!this.exceedsLimit(searchText, reference, candidate, tokenLimit)) {
                currentChunk = candidate;
                continue;
            }

            if (!currentChunk.isEmpty()) {
                // The accumulated chunk is full: flush it, then re-evaluate the entry on its own.
                // Without this second check an oversized entry that follows a non-empty chunk
                // would later be emitted untruncated.
                mlInputs.add(this.createMLInput(searchText, reference, currentChunk));
                currentChunk = new HashMap<>();

                Map<String, String> entryAlone = new HashMap<>();
                entryAlone.put(entry.getKey(), entry.getValue());
                if (!this.exceedsLimit(searchText, reference, entryAlone, tokenLimit)) {
                    currentChunk = entryAlone;
                    continue;
                }
            }

            // The entry does not fit even in an otherwise empty request.
            mlInputs.add(this.handleOversizedEntry(entry, searchText, reference, tokenLimit));
        }

        if (!currentChunk.isEmpty()) {
            mlInputs.add(this.createMLInput(searchText, reference, currentChunk));
        }
        return mlInputs;
    }

    /** Returns whether the message formatted for {@code chunk} counts more tokens than {@code tokenLimit}. */
    private boolean exceedsLimit(String searchText, String reference, Map<String, String> chunk, int tokenLimit) {
        String messages = this.formatMessages(searchText, reference, chunk);
        return TokenizerUtil.countTokens(messages) > tokenLimit;
    }

    /**
     * Builds a single-hit {@link MLInput} for an entry whose formatted message exceeds the limit
     * on its own, shortening the hit source by the number of excess tokens.
     *
     * <p>At least one token of the source is always requested from
     * {@link TokenizerUtil#truncateString}, even when the fixed part of the message alone
     * already exceeds the limit.
     */
    private MLInput handleOversizedEntry(Map.Entry<String, String> entry, String searchText, String reference, int tokenLimit) {
        log.warn("Entry with key {} causes total tokens to exceed limit of {}", entry.getKey(), tokenLimit);
        Map<String, String> testChunk = Map.of(entry.getKey(), entry.getValue());
        String testMessages = this.formatMessages(searchText, reference, testChunk);
        int excessTokens = TokenizerUtil.countTokens(testMessages) - tokenLimit;
        int currentTokens = TokenizerUtil.countTokens(entry.getValue());
        int allowedTokens = Math.max(1, currentTokens - excessTokens);
        String truncatedValue = TokenizerUtil.truncateString(entry.getValue(), allowedTokens);
        return this.createMLInput(searchText, reference, Map.of(entry.getKey(), truncatedValue));
    }

    /**
     * Creates a remote-inference {@link MLInput} whose only parameter, {@code "messages"},
     * holds the formatted chat messages for the given hits.
     *
     * @param searchText the search text embedded in the message
     * @param reference  optional reference answer; may be {@code null} or empty
     * @param hits       hit id to hit source
     * @return the input to submit to the remote model
     * @throws IllegalArgumentException if the hits cannot be serialised to JSON
     */
    public MLInput createMLInput(String searchText, String reference, Map<String, String> hits) {
        Map<String, String> parameters = new HashMap<>();
        parameters.put("messages", this.formatMessages(searchText, reference, hits));
        return MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(new RemoteInferenceInputDataSet(parameters)).build();
    }

    /**
     * Renders the system prompt and the user content as a JSON array of two chat messages.
     *
     * <p>The user content is passed through {@link MLConstants#escapeJson}; the system prompt
     * {@link MLConstants#PROMPT_SEARCH_RELEVANCE} is inserted verbatim.
     * NOTE(review): this presumes the prompt constant is already JSON-safe - confirm.
     *
     * @param searchText the search text embedded in the user message
     * @param reference  optional reference answer; may be {@code null} or empty
     * @param hits       hit id to hit source
     * @return the JSON text of the message array
     * @throws IllegalArgumentException if the hits cannot be serialised to JSON
     */
    public String formatMessages(String searchText, String reference, Map<String, String> hits) {
        try {
            String hitsJson = this.buildHitsJson(hits);
            String userContent = this.buildUserContent(searchText, reference, hitsJson);
            return String.format(
                Locale.ROOT,
                "[{\"role\":\"system\",\"content\":\"%s\"},{\"role\":\"user\",\"content\":\"%s\"}]",
                MLConstants.PROMPT_SEARCH_RELEVANCE,
                MLConstants.escapeJson(userContent)
            );
        } catch (IOException e) {
            log.error("Error converting hits to JSON string", e);
            throw new IllegalArgumentException("Failed to process hits", e);
        }
    }

    /** Serialises the hits as a JSON array of {@code {"id": ..., "source": ...}} objects. */
    private String buildHitsJson(Map<String, String> hits) throws IOException {
        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
            builder.startArray();
            for (Map.Entry<String, String> hit : hits.entrySet()) {
                builder.startObject();
                builder.field("id", hit.getKey());
                builder.field("source", hit.getValue());
                builder.endObject();
            }
            builder.endArray();
            return builder.toString();
        }
    }

    /** Builds the user message text; the reference section is omitted when no reference is given. */
    private String buildUserContent(String searchText, String reference, String hitsJson) {
        if (Objects.isNull(reference) || reference.isEmpty()) {
            return String.format(Locale.ROOT, "SearchText - %s; Hits - %s", searchText, hitsJson);
        }
        return String.format(Locale.ROOT, "SearchText: %s; Reference: %s; Hits: %s", searchText, reference, hitsJson);
    }

    /**
     * Extracts {@code choices[0].message.content} from the data map of the first model tensor.
     *
     * @param mlOutput the model output; must be a {@link ModelTensorOutput}
     * @return the content string, or {@code null} if the message carries no {@code content}
     * @throws IllegalArgumentException if {@code mlOutput} is {@code null} or not a {@link ModelTensorOutput}
     * @throws IllegalStateException    if the output has no tensors or its data map does not have
     *                                  the expected {@code choices -> message -> content} shape
     */
    public String extractResponseContent(MLOutput mlOutput) {
        if (!(mlOutput instanceof ModelTensorOutput)) {
            // instanceof is false for null, so guard the getClass() call.
            String actualType = mlOutput == null ? "null" : mlOutput.getClass().getSimpleName();
            throw new IllegalArgumentException("Expected ModelTensorOutput, but got " + actualType);
        }
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        List<ModelTensors> tensorOutputList = modelTensorOutput.getMlModelOutputs();
        if (CollectionUtils.isEmpty(tensorOutputList)
            || tensorOutputList.get(0) == null
            || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) {
            throw new IllegalStateException(
                "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]"
            );
        }

        ModelTensor tensor = tensorOutputList.get(0).getMlModelTensors().get(0);
        Map<String, ?> dataMap = tensor == null ? null : tensor.getDataAsMap();
        if (dataMap == null) {
            throw new IllegalStateException("Model tensor does not contain a data map");
        }

        Object choices = dataMap.get("choices");
        if (!(choices instanceof List) || ((List<?>) choices).isEmpty()) {
            throw new IllegalStateException("Model response does not contain a non-empty [choices] list");
        }
        Object firstChoice = ((List<?>) choices).get(0);
        if (!(firstChoice instanceof Map)) {
            throw new IllegalStateException("Model response [choices[0]] is not an object");
        }
        Object message = ((Map<?, ?>) firstChoice).get("message");
        if (!(message instanceof Map)) {
            throw new IllegalStateException("Model response [choices[0].message] is missing or not an object");
        }
        Object content = ((Map<?, ?>) message).get("content");
        if (content != null && !(content instanceof String)) {
            throw new IllegalStateException("Model response [choices[0].message.content] is not a string");
        }
        return (String) content;
    }
}

