/*
 * Decompiled with CFR 0.152.
 */
package com.kms.katalon.ai.services;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jdk8.Jdk8Module;
import com.kms.katalon.ai.core.model.config.AwsBedrockConfig;
import com.kms.katalon.ai.core.model.exception.StudioAssistException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiAuthenticationException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiClientException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiCostQuotaExceededException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiServerContentViolatedException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiServerException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiServerNoAnswerException;
import com.kms.katalon.ai.core.model.exception.StudioAssistLlmApiServerTokenExceededException;
import com.kms.katalon.ai.core.model.llm.AssistantMessage;
import com.kms.katalon.ai.core.model.llm.CompletionOptions;
import com.kms.katalon.ai.core.model.llm.LlmMessage;
import com.kms.katalon.ai.core.model.llm.SystemMessage;
import com.kms.katalon.ai.core.model.llm.ToolCall;
import com.kms.katalon.ai.core.model.llm.UserMessage;
import com.kms.katalon.ai.core.services.ILlmService;
import com.kms.katalon.ai.services.aws.AwsBedrockChatCompleteOptions;
import io.modelcontextprotocol.spec.McpSchema;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.core.SdkNumber;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.http.apache.ApacheHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.AccessDeniedException;
import software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeException;
import software.amazon.awssdk.services.bedrockruntime.model.ConflictException;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ConversationRole;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.InternalServerException;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.ModelErrorException;
import software.amazon.awssdk.services.bedrockruntime.model.ModelNotReadyException;
import software.amazon.awssdk.services.bedrockruntime.model.ModelTimeoutException;
import software.amazon.awssdk.services.bedrockruntime.model.ResourceNotFoundException;
import software.amazon.awssdk.services.bedrockruntime.model.ServiceQuotaExceededException;
import software.amazon.awssdk.services.bedrockruntime.model.ServiceUnavailableException;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ThrottlingException;
import software.amazon.awssdk.services.bedrockruntime.model.Tool;
import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ValidationException;

/**
 * {@link ILlmService} backed by the AWS Bedrock Runtime {@code Converse} API.
 *
 * <p>Requests are built from the generic {@link LlmMessage} history plus the MCP tool definitions in
 * {@link CompletionOptions}. Responses are mapped back to an {@link AssistantMessage} that holds either
 * text or tool calls. A Bedrock client is created for each request and closed when that request
 * completes, so credentials and region are read from {@link #config} at call time.
 */
public class AwsBedrockService
implements ILlmService {
    private static final String JSON_MARKDOWN_PREFIX = "```json";
    private static final String JSON_MARKDOWN_SUFFIX = "```";
    private final Logger logger = LoggerFactory.getLogger(AwsBedrockService.class);
    private final ObjectMapper objectMapper = new ObjectMapper();
    protected AwsBedrockConfig config;

    /**
     * @param config Bedrock settings: credentials, region, model id and max-token limit
     */
    public AwsBedrockService(AwsBedrockConfig config) {
        this.config = config;
        // Skip null / empty-Optional properties when serializing (e.g. MCP tool schemas sent to Bedrock).
        this.objectMapper.setSerializationInclusion(JsonInclude.Include.NON_ABSENT);
        this.objectMapper.registerModule(new Jdk8Module());
    }

    /**
     * Sends the conversation to AWS Bedrock and returns the model's reply.
     *
     * @param messages conversation history; {@link SystemMessage}s are sent as the Bedrock system prompt
     * @param options completion options whose tools (if any) are offered to the model
     * @return the assistant reply: either text or a list of tool calls
     * @throws StudioAssistLlmApiClientException wrapping any failure, including the more specific
     *     StudioAssist exceptions raised while sending the request or reading the response
     */
    public AssistantMessage getChatCompletions(List<LlmMessage> messages, CompletionOptions options) {
        try {
            AwsBedrockChatCompleteOptions chatCompleteOptions = this.buildRequest(messages, options);
            ConverseResponse response = this.executeConverse(chatCompleteOptions);
            return this.handleConverseResponse(response);
        }
        catch (Exception e) {
            throw new StudioAssistLlmApiClientException((Throwable)e);
        }
    }

    /**
     * Sends one {@code Converse} request and translates Bedrock SDK failures into StudioAssist exceptions.
     *
     * <p>The client is created for this call and closed by try-with-resources, so a closed client is
     * never reused by a later call.
     */
    private ConverseResponse executeConverse(AwsBedrockChatCompleteOptions options) throws StudioAssistException {
        try (BedrockRuntimeClient client = this.createBedrockClient()) {
            return client.converse(this.buildConverseRequest(options));
        }
        catch (AccessDeniedException e) {
            throw new StudioAssistLlmApiAuthenticationException("Access denied to AWS Bedrock: " + e.getMessage());
        }
        catch (ValidationException e) {
            throw new StudioAssistLlmApiClientException("Invalid request to AWS Bedrock: " + e.getMessage());
        }
        catch (ThrottlingException e) {
            throw new StudioAssistLlmApiClientException("Request throttled by AWS Bedrock: " + e.getMessage());
        }
        catch (ServiceQuotaExceededException e) {
            throw new StudioAssistLlmApiCostQuotaExceededException("Service quota exceeded for AWS Bedrock: " + e.getMessage());
        }
        catch (ModelNotReadyException e) {
            throw new StudioAssistLlmApiClientException("Model not ready in AWS Bedrock: " + e.getMessage());
        }
        catch (ModelErrorException e) {
            throw new StudioAssistLlmApiClientException("Model error in AWS Bedrock: " + e.getMessage());
        }
        catch (ModelTimeoutException e) {
            throw new StudioAssistLlmApiClientException("Model timeout in AWS Bedrock: " + e.getMessage());
        }
        catch (ResourceNotFoundException e) {
            throw new StudioAssistLlmApiClientException("Resource not found in AWS Bedrock: " + e.getMessage());
        }
        catch (InternalServerException e) {
            throw new StudioAssistLlmApiServerException("Internal server error in AWS Bedrock: " + e.getMessage());
        }
        catch (ServiceUnavailableException e) {
            throw new StudioAssistLlmApiServerException("Service unavailable in AWS Bedrock: " + e.getMessage());
        }
        catch (ConflictException e) {
            throw new StudioAssistLlmApiClientException("Conflict in AWS Bedrock request: " + e.getMessage());
        }
        catch (BedrockRuntimeException e) {
            // Any other Bedrock service error not mapped above.
            throw new StudioAssistLlmApiClientException("AWS Bedrock runtime error: " + e.getMessage());
        }
        catch (Exception e) {
            // Includes client-creation failures raised by createBedrockClient().
            throw new StudioAssistLlmApiClientException("Unexpected error in AWS Bedrock: " + e.getMessage());
        }
    }

    /**
     * Builds the {@code Converse} request. The system prompt, inference config and tool config are
     * included only when present.
     */
    private ConverseRequest buildConverseRequest(AwsBedrockChatCompleteOptions options) {
        ConverseRequest.Builder requestBuilder = ConverseRequest.builder()
                .modelId(options.getModelId())
                .messages(options.getMessages());
        if (CollectionUtils.isNotEmpty(options.getSystemMessages())) {
            requestBuilder.system(options.getSystemMessages());
        }
        if (options.getInferenceConfig() != null) {
            requestBuilder.inferenceConfig(options.getInferenceConfig());
        }
        if (options.getToolConfig() != null) {
            requestBuilder.toolConfig(options.getToolConfig());
        }
        return requestBuilder.build();
    }

    /**
     * Creates a Bedrock Runtime client from the current {@link #config}.
     *
     * <p>Session credentials are used when a session token is configured. Otherwise the access/secret key
     * pair is used as basic credentials, because {@code AwsSessionCredentials} rejects a null token. The
     * HTTP client is passed as a builder so the SDK owns it and closes it together with the Bedrock client.
     *
     * @throws StudioAssistLlmApiClientException if the client cannot be built
     */
    private BedrockRuntimeClient createBedrockClient() {
        try {
            String accessKey = this.config.getAwsAccessKey();
            String secretKey = this.config.getAwsSecretKey();
            String sessionToken = this.config.getAwsSessionToken();
            AwsCredentials credentials;
            if (StringUtils.isBlank(sessionToken)) {
                credentials = AwsBasicCredentials.create(accessKey, secretKey);
            }
            else {
                credentials = AwsSessionCredentials.create(accessKey, secretKey, sessionToken);
            }
            return BedrockRuntimeClient.builder()
                    .region(Region.of(this.config.getAwsRegion()))
                    .credentialsProvider(StaticCredentialsProvider.create(credentials))
                    .httpClientBuilder(ApacheHttpClient.builder())
                    .build();
        }
        catch (Exception e) {
            throw new StudioAssistLlmApiClientException("Failed to initialize AWS Bedrock client: " + e.getMessage());
        }
    }

    /**
     * Converts the generic messages and options into Bedrock request parts.
     *
     * <p>System messages become system content blocks and all other messages become conversation
     * messages. Temperature is fixed at 0, and max tokens comes from {@link #config}.
     */
    private AwsBedrockChatCompleteOptions buildRequest(List<LlmMessage> messages, CompletionOptions options) {
        List<Message> conversation = messages.stream()
                .filter(message -> !(message instanceof SystemMessage))
                .flatMap(message -> this.mapToRequestMessages(message).stream())
                .toList();
        List<SystemContentBlock> systemPrompt = messages.stream()
                .filter(message -> message instanceof SystemMessage)
                .flatMap(message -> this.mapToSystemRequestMessages(message).stream())
                .toList();
        InferenceConfiguration inferenceConfig = InferenceConfiguration.builder()
                .maxTokens(Integer.valueOf(this.config.getMaxToken()))
                .temperature(Float.valueOf(0.0f))
                .build();
        ToolConfiguration toolConfig = this.buildToolConfiguration(options);
        return new AwsBedrockChatCompleteOptions(this.config.getModel(), conversation, systemPrompt, inferenceConfig, toolConfig);
    }

    /**
     * Converts the MCP tools in {@code options} into a Bedrock tool configuration.
     *
     * <p>A tool that fails to convert is logged and skipped.
     *
     * @return the configuration, or {@code null} when there are no tools or none could be converted
     */
    private ToolConfiguration buildToolConfiguration(CompletionOptions options) {
        if (options == null || CollectionUtils.isEmpty(options.getTools())) {
            return null;
        }
        List<Tool> bedrockTools = new ArrayList<>();
        for (McpSchema.Tool tool : options.getTools()) {
            try {
                bedrockTools.add(this.toBedrockTool(tool));
            }
            catch (Exception e) {
                this.logger.warn("Error while building tool specification for tool: {}", tool.name(), e);
            }
        }
        return bedrockTools.isEmpty() ? null : ToolConfiguration.builder().tools(bedrockTools).build();
    }

    /** Builds a Bedrock tool from one MCP tool definition (name, description, JSON input schema). */
    private Tool toBedrockTool(McpSchema.Tool tool) {
        Document parametersDocument = this.convertJsonSchemaToDocument(tool.inputSchema());
        ToolInputSchema inputSchema = ToolInputSchema.builder().json(parametersDocument).build();
        ToolSpecification toolSpec = ToolSpecification.builder()
                .name(tool.name())
                .description(StringUtils.defaultString(tool.description()))
                .inputSchema(inputSchema)
                .build();
        return Tool.builder().toolSpec(toolSpec).build();
    }

    /**
     * Maps one non-system message to zero or more Bedrock messages.
     *
     * <ul>
     *   <li>A user message becomes one USER text message. A blank user message maps to nothing.</li>
     *   <li>An assistant message with text becomes one ASSISTANT text message.</li>
     *   <li>Otherwise, each tool call of the assistant message yields an ASSISTANT tool-use message (if it
     *       has input), followed by a USER tool-result message (if it has output).</li>
     *   <li>Any other message type maps to nothing.</li>
     * </ul>
     */
    private List<Message> mapToRequestMessages(LlmMessage message) {
        if (message instanceof UserMessage userMessage) {
            if (StringUtils.isBlank(userMessage.getContent())) {
                return List.of();
            }
            return List.of(textMessage(ConversationRole.USER, userMessage.getContent()));
        }
        if (message instanceof AssistantMessage assistantMessage) {
            if (StringUtils.isNotBlank(assistantMessage.getContent())) {
                return List.of(textMessage(ConversationRole.ASSISTANT, assistantMessage.getContent()));
            }
            List<Message> toolMessages = new ArrayList<>();
            // emptyIfNull: an assistant message with neither text nor tool calls maps to nothing instead of NPE.
            for (ToolCall toolCall : CollectionUtils.emptyIfNull(assistantMessage.getToolCalls())) {
                if (toolCall.getInput() != null) {
                    toolMessages.add(this.toolUseMessage(toolCall));
                }
                if (toolCall.getOutput() != null) {
                    toolMessages.add(toolResultMessage(toolCall));
                }
            }
            return toolMessages;
        }
        return List.of();
    }

    /** Builds a single-block text message with the given role. */
    private static Message textMessage(ConversationRole role, String text) {
        ContentBlock contentBlock = ContentBlock.builder().text(text).build();
        return Message.builder().role(role).content(contentBlock).build();
    }

    /** Wraps a tool call's JSON input as an ASSISTANT tool-use message. */
    private Message toolUseMessage(ToolCall toolCall) {
        ToolUseBlock toolUseBlock = ToolUseBlock.builder()
                .toolUseId(toolCall.getCallId())
                .name(toolCall.getName())
                .input(this.convertJsonToDocument(toolCall.getInput()))
                .build();
        ContentBlock toolUseContent = ContentBlock.builder().toolUse(toolUseBlock).build();
        return Message.builder().role(ConversationRole.ASSISTANT).content(toolUseContent).build();
    }

    /** Wraps a tool call's text output as a USER tool-result message linked by the call id. */
    private static Message toolResultMessage(ToolCall toolCall) {
        ToolResultContentBlock resultContent = ToolResultContentBlock.builder().text(toolCall.getOutput()).build();
        ToolResultBlock toolResultBlock = ToolResultBlock.builder()
                .toolUseId(toolCall.getCallId())
                .content(resultContent)
                .build();
        ContentBlock toolResultContent = ContentBlock.builder().toolResult(toolResultBlock).build();
        return Message.builder().role(ConversationRole.USER).content(toolResultContent).build();
    }

    /** Maps a non-blank {@link SystemMessage} to one system content block; anything else maps to nothing. */
    private List<SystemContentBlock> mapToSystemRequestMessages(LlmMessage message) {
        if (message instanceof SystemMessage systemMessage && StringUtils.isNotBlank(systemMessage.getContent())) {
            return List.of(SystemContentBlock.builder().text(systemMessage.getContent()).build());
        }
        return List.of();
    }

    /**
     * Converts a {@code Converse} response into an {@link AssistantMessage}.
     *
     * <p>Text blocks are concatenated, and tool-use blocks are collected. If any tool use is present, the
     * reply is a tool-call message; otherwise it is the text with any surrounding json code fence removed.
     *
     * @throws StudioAssistLlmApiServerNoAnswerException if the response or its answer is empty
     * @throws StudioAssistLlmApiServerTokenExceededException if generation stopped at the token limit
     * @throws StudioAssistLlmApiServerContentViolatedException if Bedrock filtered the content
     */
    private AssistantMessage handleConverseResponse(ConverseResponse response) {
        if (response == null || response.output() == null) {
            throw new StudioAssistLlmApiServerNoAnswerException("Empty response from AWS Bedrock");
        }
        StringBuilder text = new StringBuilder();
        List<ToolUseBlock> toolUses = new ArrayList<>();
        Message outputMessage = response.output().message();
        if (outputMessage != null && CollectionUtils.isNotEmpty(outputMessage.content())) {
            for (ContentBlock contentBlock : outputMessage.content()) {
                if (contentBlock.text() != null) {
                    text.append(contentBlock.text());
                }
                else if (contentBlock.toolUse() != null) {
                    toolUses.add(contentBlock.toolUse());
                }
            }
        }
        StopReason stopReason = response.stopReason();
        if (stopReason == StopReason.MAX_TOKENS) {
            throw new StudioAssistLlmApiServerTokenExceededException("Token limit exceeded in AWS Bedrock");
        }
        if (stopReason == StopReason.CONTENT_FILTERED) {
            throw new StudioAssistLlmApiServerContentViolatedException("Content filtered by AWS Bedrock");
        }
        boolean hasToolCall = !toolUses.isEmpty();
        if (StringUtils.isBlank(text) && !hasToolCall) {
            throw new StudioAssistLlmApiServerNoAnswerException("Empty answer from AWS Bedrock");
        }
        if (hasToolCall) {
            // convertToolUseBlockToToolCall returns null on failure; drop those rather than pass nulls on.
            List<ToolCall> toolCalls = toolUses.stream()
                    .map(this::convertToolUseBlockToToolCall)
                    .filter(toolCall -> toolCall != null)
                    .toList();
            return AssistantMessage.of(toolCalls);
        }
        return AssistantMessage.of(this.removeMarkdownCodeBlocks(text.toString()));
    }

    /**
     * Strips a surrounding markdown fence that opens with {@code ```json} and closes with {@code ```}.
     *
     * @return the trimmed inner text if the trimmed input is fenced; otherwise the input unchanged
     */
    private String removeMarkdownCodeBlocks(String text) {
        if (StringUtils.isBlank(text)) {
            return text;
        }
        String trimmedText = text.trim();
        if (!trimmedText.startsWith(JSON_MARKDOWN_PREFIX) || !trimmedText.endsWith(JSON_MARKDOWN_SUFFIX)) {
            return text;
        }
        // The prefix ends in "json", so it can never overlap the closing fence; cut both directly.
        int end = trimmedText.length() - JSON_MARKDOWN_SUFFIX.length();
        return trimmedText.substring(JSON_MARKDOWN_PREFIX.length(), end).trim();
    }

    /**
     * Converts a Bedrock tool-use block into a {@link ToolCall} whose input is JSON text.
     *
     * @return the tool call, or {@code null} if conversion fails (the failure is logged)
     */
    private ToolCall convertToolUseBlockToToolCall(ToolUseBlock toolUseBlock) {
        try {
            ToolCall tool = new ToolCall();
            tool.setCallId(toolUseBlock.toolUseId());
            tool.setName(toolUseBlock.name());
            tool.setInput(this.convertDocumentToJson(toolUseBlock.input()));
            return tool;
        }
        catch (Exception e) {
            this.logger.warn("Error converting ToolUseBlock to ToolCall", e);
            return null;
        }
    }

    /** Serializes an SDK {@link Document} to JSON; returns {@code "{}"} for null input or on failure. */
    private String convertDocumentToJson(Document document) {
        String jsonStringDefault = "{}";
        if (document == null) {
            return jsonStringDefault;
        }
        try {
            return this.objectMapper.writeValueAsString(document.unwrap());
        }
        catch (Exception e) {
            this.logger.error("Error converting AWS SDK Document to JSON", e);
            return jsonStringDefault;
        }
    }

    /** Parses JSON text into a {@link Document}; blank or malformed input yields an empty map document. */
    private Document convertJsonToDocument(String jsonInput) {
        if (StringUtils.isBlank(jsonInput)) {
            return Document.fromMap(Collections.emptyMap());
        }
        try {
            Object jsonObject = this.objectMapper.readValue(jsonInput, Object.class);
            return this.convertObjectToDocument(jsonObject);
        }
        catch (JsonProcessingException e) {
            this.logger.error("Failed to convert JSON to Document, using empty object", e);
            return Document.fromMap(Collections.emptyMap());
        }
    }

    /**
     * Converts an MCP JSON schema into a {@link Document} by round-tripping it through JSON.
     * A null schema, or a serialization failure, yields an empty map document.
     */
    private Document convertJsonSchemaToDocument(McpSchema.JsonSchema schema) {
        if (schema == null) {
            return Document.fromMap(Collections.emptyMap());
        }
        try {
            String schemaJson = this.objectMapper.writeValueAsString(schema);
            Map<?, ?> schemaMap = this.objectMapper.readValue(schemaJson, Map.class);
            return this.convertMapToDocument(schemaMap);
        }
        catch (JsonProcessingException e) {
            this.logger.warn("Failed to convert JsonSchema to Document: " + e.getMessage(), e);
            return Document.fromMap(Collections.emptyMap());
        }
    }

    /**
     * Recursively converts a plain Java value into a {@link Document}. This is the single type dispatcher
     * used for top-level values, map values and list items.
     *
     * <p>Null, String, Boolean, Number, Map, List and Object[] map to their Document equivalents. Any
     * other value is stored as its {@code toString()}.
     */
    private Document convertObjectToDocument(Object obj) {
        if (obj == null) {
            return Document.fromNull();
        }
        if (obj instanceof String string) {
            return Document.fromString(string);
        }
        if (obj instanceof Boolean bool) {
            return Document.fromBoolean(bool);
        }
        if (obj instanceof Number number) {
            return Document.fromNumber(toSdkNumber(number));
        }
        if (obj instanceof Map<?, ?> map) {
            return this.convertMapToDocument(map);
        }
        if (obj instanceof List<?> list) {
            return this.convertListToDocument(list);
        }
        if (obj instanceof Object[] array) {
            return this.convertListToDocument(Arrays.asList(array));
        }
        return Document.fromString(obj.toString());
    }

    /**
     * Converts a boxed number to an {@link SdkNumber} without losing range: {@code Long} uses
     * {@code fromLong}, never {@code intValue()}. Unknown {@code Number} types fall back to double.
     */
    private static SdkNumber toSdkNumber(Number number) {
        if (number instanceof Integer || number instanceof Short || number instanceof Byte) {
            return SdkNumber.fromInteger(number.intValue());
        }
        if (number instanceof Long) {
            return SdkNumber.fromLong(number.longValue());
        }
        if (number instanceof Float) {
            return SdkNumber.fromFloat(number.floatValue());
        }
        if (number instanceof BigInteger bigInteger) {
            return SdkNumber.fromBigInteger(bigInteger);
        }
        if (number instanceof BigDecimal bigDecimal) {
            return SdkNumber.fromBigDecimal(bigDecimal);
        }
        return SdkNumber.fromDouble(number.doubleValue());
    }

    /** Converts a map into a map {@link Document}; keys are stringified, values converted recursively. */
    private Document convertMapToDocument(Map<?, ?> map) {
        Document.MapBuilder mapBuilder = Document.mapBuilder();
        for (Map.Entry<?, ?> entry : map.entrySet()) {
            mapBuilder.putDocument(String.valueOf(entry.getKey()), this.convertObjectToDocument(entry.getValue()));
        }
        return mapBuilder.build();
    }

    /** Converts a list into a list {@link Document}, converting each item recursively. */
    private Document convertListToDocument(List<?> list) {
        Document.ListBuilder listBuilder = Document.listBuilder();
        for (Object item : list) {
            listBuilder.addDocument(this.convertObjectToDocument(item));
        }
        return listBuilder.build();
    }
}

