#include "cipher.h"
#include "openssl/aes/aes.h"
#include "openssl/modes/modes_lcl.h"

#if defined(USE_BF)
#include "tee_internal_api.h"
#endif

int32_t aes_gcm_encrypt( const uint8_t* key, uint32_t key_size,
                                 const uint8_t* iv,  uint32_t iv_size,
				 const uint8_t* aad, uint32_t aad_size,
                                 uint8_t* data, uint32_t input_size,
                                 uint32_t* output_size)
{
    int32_t         status;
    GCM128_CONTEXT  ctx;
    AES_KEY         aes_key;
    // guarding against integer overflow
    if((input_size > (0xffffffff - AES_GCM_TAG_SIZE)) || ( *output_size < (input_size + AES_GCM_TAG_SIZE)))
    {
        CC_LOGE( "tima_aes_gcm_encrypt :: output_size is too small" );
        return BUFFER_TOO_SMALL;
    }

    status = AES_set_encrypt_key( (const unsigned char *) key, key_size*8, &aes_key );
    if( 0 != status )
    {
        CC_LOGE( "tima_aes_gcm_encrypt :: AES_set_encrypt_key failed with error : 0x%08X", status );
        status = CRYPTO_FAILED;
        goto cleanup;
    }

    CRYPTO_gcm128_init( &ctx, &aes_key, (block128_f)AES_encrypt );
    CRYPTO_gcm128_setiv( &ctx, iv, iv_size );

    if (aad != NULL && aad_size > 0)
    {
        status = CRYPTO_gcm128_aad(&ctx, aad, aad_size);

        if (0 != status)
        {
	    CC_LOGE( "aes_gcm_encrypt ::  CRYPTO_gcm128_aad failed with error: 0x%08X", status );
	    status = CRYPTO_FAILED;
	    goto cleanup;
        }
    }

    status = CRYPTO_gcm128_encrypt( &ctx, data, data, input_size );
    if( 0 != status )
    {
        CC_LOGE( "tima_aes_gcm_encrypt :: CRYPTO_gcm128_encrypt failed with error: 0x%08X", status );
        status = CRYPTO_FAILED;
        goto cleanup;
    }

    CRYPTO_gcm128_tag( &ctx, data + input_size, AES_GCM_TAG_SIZE );
    *output_size = input_size + AES_GCM_TAG_SIZE;

    status = SUCCESS;

cleanup:
#if defined(USE_BF)
    TEE_MemFill( &ctx, 0xAA, sizeof(ctx) );
    TEE_MemFill( &aes_key, 0xAA, sizeof(aes_key) );
#else
    memset( &ctx, 0xAA, sizeof(ctx) );
    memset( &aes_key, 0xAA, sizeof(aes_key) );
#endif
    return status;
}


/*
 * Fill a buffer holding sensitive data with a fixed byte pattern.
 * In the non-TEE build the stores go through a volatile pointer so the
 * compiler cannot drop them as dead stores on a local that is about to go
 * out of scope (which it may legally do with plain memset).
 */
static void aes_gcm_decrypt_wipe( void* buf, uint8_t value, size_t len )
{
#if defined(USE_BF)
    TEE_MemFill( buf, value, len );
#else
    volatile uint8_t* p = (volatile uint8_t*) buf;
    while( len-- > 0 )
    {
        *p++ = value;
    }
#endif
}

/*
 * Constant-time byte-string equality: returns 1 if a[0..len) == b[0..len),
 * else 0. Every byte is always examined, so timing does not reveal the
 * position of the first mismatch (memcmp may return early).
 */
static int aes_gcm_decrypt_tag_equal( const uint8_t* a, const uint8_t* b, size_t len )
{
    const volatile uint8_t* pa   = a;
    const volatile uint8_t* pb   = b;
    uint8_t                 diff = 0;
    size_t                  i;

    for( i = 0; i < len; i++ )
    {
        diff |= (uint8_t)( pa[i] ^ pb[i] );
    }
    return 0 == diff;
}

/*
 * AES-GCM decrypt `data` in place and verify its trailing authentication tag.
 *
 * key, key_size    : raw AES key; key_size must be 16, 24 or 32 bytes.
 * iv, iv_size      : GCM IV; must be non-NULL with iv_size > 0.
 * aad, aad_size    : optional additional authenticated data; aad may be
 *                    NULL only when aad_size is 0.
 * data, input_size : ciphertext followed by an AES_GCM_TAG_SIZE-byte tag.
 *                    On SUCCESS the first input_size - AES_GCM_TAG_SIZE bytes
 *                    hold the plaintext. On any failure after decryption has
 *                    started, those bytes are zeroed so unauthenticated
 *                    plaintext is never left in the caller's buffer.
 * output_size      : out (on SUCCESS): input_size - AES_GCM_TAG_SIZE.
 *
 * Returns SUCCESS, INVALID_ARGUMENT, CRYPTO_FAILED or WRONG_TAG.
 */
int32_t aes_gcm_decrypt( const uint8_t* key, uint32_t key_size,
                         const uint8_t* iv,  uint32_t iv_size,
                         const uint8_t* aad, uint32_t aad_size,
                         uint8_t* data, uint32_t input_size,
                         uint32_t* output_size)
{
    int32_t         status;
    GCM128_CONTEXT  ctx;
    AES_KEY         aes_key;
    uint8_t         tag[ AES_GCM_TAG_SIZE ];
    uint32_t        plain_size;
    int             plaintext_written = 0;

    if( ( NULL == data ) || ( NULL == output_size ) )
    {
        CC_LOGE( "aes_gcm_decrypt :: data or output_size is NULL" );
        return INVALID_ARGUMENT;
    }

    if( input_size < AES_GCM_TAG_SIZE )
    {
        CC_LOGE( "tima_aes_gcm_decrypt :: input_size is too small" );
        return INVALID_ARGUMENT;
    }

    // An empty IV would yield the same counter block for every message
    // under this key (SP 800-38D requires len(IV) >= 1).
    if( ( NULL == iv ) || ( 0 == iv_size ) )
    {
        CC_LOGE( "aes_gcm_decrypt :: invalid IV" );
        return INVALID_ARGUMENT;
    }

    // Refuse rather than silently skip AAD the caller asked to authenticate.
    if( ( NULL == aad ) && ( aad_size > 0 ) )
    {
        CC_LOGE( "aes_gcm_decrypt :: aad is NULL but aad_size is non-zero" );
        return INVALID_ARGUMENT;
    }

    // Validate before key_size*8 so the multiplication cannot wrap; report
    // CRYPTO_FAILED, the code a rejected key produced before this check.
    if( ( NULL == key ) || ( ( 16 != key_size ) && ( 24 != key_size ) && ( 32 != key_size ) ) )
    {
        CC_LOGE( "aes_gcm_decrypt :: invalid key" );
        return CRYPTO_FAILED;
    }

    plain_size = input_size - AES_GCM_TAG_SIZE;

    status = AES_set_encrypt_key( (const unsigned char *) key, (int)( key_size * 8 ), &aes_key );
    if( 0 != status )
    {
        CC_LOGE( "tima_aes_gcm_decrypt :: AES_set_encrypt_key failed with error : 0x%08X", status );
        status = CRYPTO_FAILED;
        goto cleanup;
    }

    // GCM uses the block cipher's forward direction for decryption too.
    CRYPTO_gcm128_init( &ctx, &aes_key, (block128_f)AES_encrypt );
    CRYPTO_gcm128_setiv( &ctx, iv, iv_size );

    if( aad_size > 0 )
    {
        status = CRYPTO_gcm128_aad( &ctx, aad, aad_size );
        if( 0 != status )
        {
            CC_LOGE( "aes_gcm_decrypt ::  CRYPTO_gcm128_aad failed with error: 0x%08X", status );
            status = CRYPTO_FAILED;
            goto cleanup;
        }
    }

    // From here on `data` may hold plaintext that is not yet authenticated.
    plaintext_written = 1;
    status = CRYPTO_gcm128_decrypt( &ctx, data, data, plain_size );
    if( 0 != status )
    {
        CC_LOGE( "tima_aes_gcm_decrypt :: CRYPTO_gcm128_decrypt failed with error: 0x%08X", status );
        status = CRYPTO_FAILED;
        goto cleanup;
    }

    CRYPTO_gcm128_tag( &ctx, tag, AES_GCM_TAG_SIZE );
    if( !aes_gcm_decrypt_tag_equal( tag, data + plain_size, AES_GCM_TAG_SIZE ) )
    {
        CC_LOGE( "tima_aes_gcm_decrypt :: wrong MAC value" );
        status = WRONG_TAG;
        goto cleanup;
    }

    *output_size = plain_size;

    status = SUCCESS;

cleanup:
    if( plaintext_written && ( SUCCESS != status ) )
    {
        // Never hand unauthenticated plaintext back to the caller.
        aes_gcm_decrypt_wipe( data, 0x00, plain_size );
    }
    aes_gcm_decrypt_wipe( tag, 0xAA, sizeof(tag) );
    aes_gcm_decrypt_wipe( &ctx, 0xAA, sizeof(ctx) );
    aes_gcm_decrypt_wipe( &aes_key, 0xAA, sizeof(aes_key) );
    return status;
}
