package com.kms.katalon.core.webservice.aws;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Objects;
import java.util.Optional;

import org.apache.commons.io.IOUtils;
import org.apache.http.HttpEntity;
import org.apache.http.HttpEntityEnclosingRequest;
import org.apache.http.client.methods.HttpRequestWrapper;
import org.apache.http.client.methods.HttpUriRequest;

import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.signer.Aws4Signer;
import software.amazon.awssdk.auth.signer.AwsSignerExecutionAttribute;
import software.amazon.awssdk.core.interceptor.ExecutionAttributes;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.BinaryUtils;

import com.kms.katalon.core.testobject.authorization.AwsSignatureAuthorization;
import com.kms.katalon.core.testobject.authorization.AwsSignatureLocation;
import com.kms.katalon.core.testobject.authorization.RequestAuthorization;
import com.kms.katalon.core.webservice.exception.WebServiceException;

/**
 * Signs Apache HttpClient requests with AWS Signature Version 4 using the AWS SDK v2 {@link Aws4Signer}.
 * <p>
 * The signature is placed either in the {@code Authorization} header or in the query string,
 * depending on {@link AwsSignatureAuthorization#getSignatureLocation()}.
 */
public class Authenticator {
    /** Header carrying the hex-encoded SHA-256 digest of the request payload. */
    private static final String AWS_CONTENT_HASH_HEADER_NAME = "x-amz-content-sha256";

    private final Aws4Signer aws4Signer = Aws4Signer.create();

    /**
     * Signs {@code request} according to {@code authenticationSetting}.
     * <p>
     * The result wraps the original request. For {@link AwsSignatureLocation#QUERY_STRING} the URI is
     * replaced by the presigned URI. Otherwise the signer's {@code Authorization} header replaces any
     * existing one. Every other header produced while signing is copied only when the request does not
     * already contain a header with that name.
     *
     * @param request the request to sign; its entity (if any) is read to compute the payload hash
     * @param authenticationSetting the authorization settings, adapted to {@link AwsSignatureAuthorization}
     * @return a {@link HttpRequestWrapper} around {@code request} that carries the signature
     * @throws IOException declared for callers; propagated from signing
     * @throws GeneralSecurityException declared for callers; propagated from signing
     * @throws WebServiceException if the request body cannot be read or hashed
     */
    public HttpUriRequest sign(HttpUriRequest request, RequestAuthorization authenticationSetting)
            throws IOException, GeneralSecurityException, WebServiceException {
        AwsSignatureAuthorization awsSetting = AwsSignatureAuthorization.adapt(authenticationSetting);
        boolean signedQueryParam = awsSetting.getSignatureLocation() == AwsSignatureLocation.QUERY_STRING;
        SdkHttpFullRequest signedRequest = signInternal(request, awsSetting, signedQueryParam);
        HttpRequestWrapper outputSignedRequest = HttpRequestWrapper.wrap(request);

        if (signedQueryParam) {
            outputSignedRequest.setURI(signedRequest.getUri());
        }

        // Add headers which could be added for the signed AWS request
        signedRequest.headers().forEach((name, values) -> {
            String joinedValue = String.join(",", values);
            if (!signedQueryParam && "authorization".equalsIgnoreCase(name)) {
                // The Authorization header contains the signature token, so it overrides
                // whatever the original request carried.
                outputSignedRequest.removeHeaders(name);
                outputSignedRequest.addHeader(name, joinedValue);
            } else if (!outputSignedRequest.containsHeader(name)) {
                outputSignedRequest.addHeader(name, joinedValue);
            }
        });

        return outputSignedRequest;
    }

    /**
     * Copies method, URI, headers and body of {@code request} into an AWS SDK request ready for signing.
     * <p>
     * The body is buffered once into memory. Each call to the content stream provider therefore gets a
     * fresh stream, and the entity is never serialized twice. NOTE(review): the original entity is still
     * consumed by this read; non-repeatable entities may not be re-sendable afterwards — confirm callers
     * only use repeatable entities.
     *
     * @param request the Apache request to convert
     * @param awsSetting the AWS settings (currently unused here)
     * @return the equivalent {@link SdkHttpFullRequest}, including the {@code x-amz-content-sha256} header
     * @throws WebServiceException if the body cannot be read or hashed
     */
    private SdkHttpFullRequest convertToAwsRequest(HttpUriRequest request, AwsSignatureAuthorization awsSetting)
            throws WebServiceException {
        final SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder()
                .method(SdkHttpMethod.fromValue(request.getMethod()))
                .uri(request.getURI());
        for (var h : request.getAllHeaders()) {
            builder.appendHeader(h.getName(), h.getValue());
        }

        byte[] payload = new byte[0];
        boolean requestPayloadAvailable = false;
        if (request instanceof HttpEntityEnclosingRequest entityRequest) {
            HttpEntity entity = entityRequest.getEntity();
            if (Objects.nonNull(entity)) {
                try {
                    payload = readPayload(entity);
                } catch (IOException e) {
                    throw new WebServiceException("Converting HttpUriRequest to SdkHttpFullRequest fails on copying request body from HttpUriRequest", e);
                }
                final byte[] body = payload;
                // Hand out a new stream per call so repeated reads by the signer always see the full body.
                builder.contentStreamProvider(() -> new ByteArrayInputStream(body));
                requestPayloadAvailable = true;
            }
        }

        // To comply with S3 service :(
        // https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
        // A caller-supplied value (e.g. a precomputed hash) is kept as-is. Appending a second value would
        // make the signed header differ from the one actually sent.
        if (!request.containsHeader(AWS_CONTENT_HASH_HEADER_NAME)) {
            try {
                builder.appendHeader(AWS_CONTENT_HASH_HEADER_NAME, hash(payload));
            } catch (NoSuchAlgorithmException e) {
                String message = requestPayloadAvailable
                        ? String.format("Fails to hash the request body from HttpUriRequest for the header %s", AWS_CONTENT_HASH_HEADER_NAME)
                        : String.format("Fails to create hash for the empty payload for the header %s", AWS_CONTENT_HASH_HEADER_NAME);
                throw new WebServiceException(message, e);
            }
        }

        return builder.build();
    }

    /**
     * Converts {@code request} and signs it with {@link Aws4Signer}.
     * <p>
     * Session credentials are used when a session token is configured; otherwise basic credentials are used.
     *
     * @param request the request to sign
     * @param awsSetting source of credentials, service name and region
     * @param presign {@code true} to put the signature in the query string ({@code presign}),
     *            {@code false} to put it in headers ({@code sign})
     * @return the signed AWS SDK request
     * @throws WebServiceException if the body cannot be read or hashed
     */
    private SdkHttpFullRequest signInternal(HttpUriRequest request, AwsSignatureAuthorization awsSetting,
            boolean presign) throws IOException, GeneralSecurityException, WebServiceException {
        SdkHttpFullRequest awsRequest = convertToAwsRequest(request, awsSetting);

        String accessKey = awsSetting.getAwsAccessKey().get();
        String secretKey = awsSetting.getAwsSecretKey().get();
        AwsCredentials awsCredentials = awsSetting.getAwsSessionToken().isPresent()
                ? AwsSessionCredentials.create(accessKey, secretKey, awsSetting.getAwsSessionToken().get())
                : AwsBasicCredentials.create(accessKey, secretKey);

        var executionAttributes = new ExecutionAttributes()
                .putAttribute(AwsSignerExecutionAttribute.SERVICE_SIGNING_NAME, awsSetting.getAwsServiceName().get())
                .putAttribute(AwsSignerExecutionAttribute.SIGNING_REGION, Region.of(awsSetting.getAwsRegion()))
                .putAttribute(AwsSignerExecutionAttribute.AWS_CREDENTIALS, awsCredentials);

        // See https://docs.aws.amazon.com/IAM/latest/UserGuide/aws-signing-authentication-methods.html
        return presign ? aws4Signer.presign(awsRequest, executionAttributes)
                : aws4Signer.sign(awsRequest, executionAttributes);
    }

    /**
     * Serializes {@code entity} into a byte array.
     * <p>
     * Unlike a pipe sized from {@code getContentLength()}, this works when the length is unknown (-1),
     * zero, larger than {@code int}, or does not match the bytes actually written.
     *
     * @param entity the entity to read
     * @return the entity's bytes
     * @throws IOException if writing the entity fails
     */
    private byte[] readPayload(HttpEntity entity) throws IOException {
        long declaredLength = entity.getContentLength();
        // Only a sizing hint; the buffer grows as needed.
        int initialCapacity = declaredLength > 0 ? (int) Math.min(declaredLength, 64 * 1024) : 1024;
        ByteArrayOutputStream buffer = new ByteArrayOutputStream(initialCapacity);
        entity.writeTo(buffer);
        return buffer.toByteArray();
    }

    /**
     * Computes the lowercase hex SHA-256 digest of {@code payload}.
     *
     * @param payload the bytes to hash (an empty array for requests without a body)
     * @return hex-encoded digest
     * @throws NoSuchAlgorithmException if SHA-256 is unavailable
     */
    private String hash(byte[] payload) throws NoSuchAlgorithmException {
        MessageDigest digest = MessageDigest.getInstance("SHA-256");
        return BinaryUtils.toHex(digest.digest(payload));
    }
}
