#include "kg_utils.h"

/*
 * Compute HMAC-SHA256(key, message_data) into digest.
 *
 * The HMAC construction H((K ^ opad) || H((K ^ ipad) || m)) is assembled by
 * hand in the fixed-size stack scratch buffer kx, so the message must fit in
 * KG_BUF_LEN - KG_HMAC_BLOCK bytes.
 *
 * @param key               HMAC key; must be non-NULL with key_len > 0. Keys
 *                          longer than KG_HMAC_BLOCK are first hashed with
 *                          SHA-256 into a private buffer; the caller's key
 *                          buffer is never modified.
 * @param key_len           Length of key in bytes.
 * @param message_data      Message to authenticate; non-NULL and non-empty.
 * @param message_data_len  Length of message_data in bytes.
 * @param digest            Output buffer, at least SHA256_DIGEST_LENGTH bytes.
 *                          It also holds the inner hash between the two passes.
 * @param digest_len        In: capacity of digest. Out: value written by the
 *                          final TZ_digest_SHA256 call.
 * @return KG_SUCCESS; KG_ALLOC_BUFFER_FAIL if the scratch allocation fails;
 *         KG_BUFFER_SIZE_FAIL for invalid arguments or an oversized message;
 *         KG_TZ_API_FAIL if a hashing call fails.
 */
uint32_t hmac_sha256(uint8_t *key, uint32_t key_len, uint8_t *message_data,
    uint32_t message_data_len, uint8_t *digest, uint32_t *digest_len) {
    uint32_t ret = KG_SUCCESS;
    uint8_t kx[KG_BUF_LEN];
    uint8_t *key_hash = NULL;
    uint32_t key_hash_len = SHA256_DIGEST_LENGTH;
    const uint8_t *k = key;   /* effective key: caller's key or its hash */
    uint32_t i;

    TEE_MemFill(kx, 0, KG_BUF_LEN);

    key_hash = TEE_Malloc(SHA256_DIGEST_LENGTH, 0);
    if (NULL == key_hash) {
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

    if (NULL == key || 0 == key_len) {
        KG_LOG("KG hmac_sha256, invalid input key file");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    /* The message is copied into kx right after the padded key block, so it must fit. */
    if (NULL == message_data || 0 == message_data_len ||
        message_data_len > (uint32_t)(KG_BUF_LEN - KG_HMAC_BLOCK)) {
        KG_LOG("KG hmac_sha256, invalid message input");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    if (NULL == digest || NULL == digest_len || *digest_len < SHA256_DIGEST_LENGTH) {
        KG_LOG("KG hmac_sha256, invalid digest buffer");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    /*
     * Keys longer than the block size are replaced by H(key). The hash goes
     * into key_hash rather than back over the caller's key buffer.
     */
    if (key_len > KG_HMAC_BLOCK) {
        if (TZ_digest_SHA256(key, key_len, key_hash, &key_hash_len) != TZ_API_OK) {
            ret = KG_TZ_API_FAIL;
            goto exit;
        }
        k = key_hash;
        key_len = key_hash_len;
    }

    /* Inner pass: kx = (K ^ ipad) || message, with K zero-padded to the block size. */
    for (i = 0; i < KG_HMAC_BLOCK; i++) {
        kx[i] = (uint8_t)(I_PAD ^ (i < key_len ? k[i] : 0));
    }
    TEE_MemMove(kx + KG_HMAC_BLOCK, message_data, message_data_len);

    if (TZ_digest_SHA256(kx, KG_HMAC_BLOCK + message_data_len, digest, digest_len) != TZ_API_OK) {
        ret = KG_TZ_API_FAIL;
        goto exit;
    }

    /* The inner hash is copied back into kx below; make sure it fits. */
    if (*digest_len > (uint32_t)(KG_BUF_LEN - KG_HMAC_BLOCK)) {
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    /* Outer pass: kx = (K ^ opad) || inner_hash. */
    for (i = 0; i < KG_HMAC_BLOCK; i++) {
        kx[i] = (uint8_t)(O_PAD ^ (i < key_len ? k[i] : 0));
    }
    TEE_MemMove(kx + KG_HMAC_BLOCK, digest, *digest_len);

    if (TZ_digest_SHA256(kx, KG_HMAC_BLOCK + *digest_len, digest, digest_len) != TZ_API_OK) {
        ret = KG_TZ_API_FAIL;
    }

exit:
    /* Scrub key-derived material from the stack and heap before returning. */
    TEE_MemFill(kx, 0x0, KG_BUF_LEN);
    if (key_hash != NULL) {
        TEE_MemFill(key_hash, 0x0, SHA256_DIGEST_LENGTH);
        TEE_Free(key_hash);
        key_hash = NULL;
    }
    return ret;
}

/*
 * Reverse the first `length` characters of str in place.
 *
 * @param str     Buffer to reverse; must be non-NULL and hold at least
 *                `length` characters. No NUL terminator is read or written.
 * @param length  Number of characters to reverse; must be in (0, KG_BUF_LEN).
 * @return KG_SUCCESS, or KG_BUFFER_SIZE_FAIL for a NULL buffer or an
 *         out-of-range length.
 */
uint32_t reverse(char str[], int length)
{
    if (NULL == str) {
        KG_LOG("reverse Invalid buffer!");
        return KG_BUFFER_SIZE_FAIL;
    }
    if (length <= 0 || length >= KG_BUF_LEN) {
        KG_LOG("reverse Invalid buffer length!");
        return KG_BUFFER_SIZE_FAIL;
    }

    /* Swap symmetric pairs, walking inward from both ends. */
    for (int start = 0, end = length - 1; start < end; start++, end--) {
        char tmp = str[start];
        str[start] = str[end];
        str[end] = tmp;
    }
    return KG_SUCCESS;
}

/*
 * Hex-dump `data_len` bytes of `data` to the log, 16 bytes per line.
 * A header line carrying `label` and the byte count comes first.
 *
 * @param label     Caption for the header line; NULL is logged as "(null)".
 * @param data      Bytes to dump; if NULL, only the header line is logged.
 * @param data_len  Number of bytes to dump.
 */
void kg_dump(
    char *label,
    uint8_t *data,
    uint32_t data_len
)
{
    /* 16 bytes * "XX " (3 chars each) + NUL terminator = 49 bytes. */
    char buf[50] = { 0 };
    uint32_t i;

    /* uint32_t is not guaranteed to be int; cast explicitly for %u. */
    KG_LOG(">>KG_TA_DUMP: %s(%u)", (label != NULL) ? label : "(null)",
           (unsigned int)data_len);
    if (NULL == data) {
        return;
    }

    for (i = 0; i < data_len; i++) {
        uint32_t col = i % 16;
        snprintf(buf + 3 * col, sizeof(buf) - 3 * col, "%02X ", (unsigned int)data[i]);
        if (col == 15) {
            KG_LOG("%s", buf);
            TEE_MemFill(buf, 0, sizeof(buf));
        }
    }

    /* Flush a trailing partial line; a complete final line was already logged. */
    if (data_len % 16 != 0) {
        KG_LOG("%s", buf);
    }
}

/*
 * AES-256-GCM encrypt `plaintext` with `key`/`iv`. Writes the ciphertext
 * and a 16-byte authentication tag.
 *
 * @param plaintext      Data to encrypt; non-NULL, plaintext_len > 0.
 * @param plaintext_len  Length of plaintext in bytes.
 * @param key            AES-256 key (EVP_aes_256_gcm expects 32 bytes).
 * @param iv             IV/nonce; iv_len must be in 1..16.
 * @param iv_len         Length of iv in bytes.
 * @param ciphertext     Output buffer; GCM output is as long as the input,
 *                       so it must hold at least plaintext_len bytes.
 * @param cipher_len     Out: number of ciphertext bytes written (0 on failure).
 * @param tag            Output buffer for the 16-byte GCM tag.
 * @return KG_SUCCESS; KG_CRYPTO_INVALID_PARAM for bad arguments;
 *         KG_CRYPTO_AES_GCM_FAIL if any OpenSSL step fails.
 */
int KG_gcm_encrypt(unsigned char *plaintext, int plaintext_len,
                unsigned char *key,
                unsigned char *iv, int iv_len,
                unsigned char *ciphertext,
                int *cipher_len,
                unsigned char *tag)
{
    EVP_CIPHER_CTX *ctx = NULL;
    int ret = KG_CRYPTO_AES_GCM_FAIL;
    int len = 0;
    int ciphertext_len = 0;

    /*
     * Lengths are signed ints. A negative length would slip past an "== 0"
     * check, and OpenSSL's GCM path would reinterpret it as a huge size.
     * So reject every length <= 0.
     */
    if (NULL == plaintext || plaintext_len <= 0) {
        KG_LOG("Error: KG AES GCM plaintext invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == key) {
        KG_LOG("Error: KG AES GCM key invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == iv || iv_len <= 0 || iv_len > 16) {
        KG_LOG("Error: KG AES GCM iv invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == ciphertext || NULL == cipher_len) {
        KG_LOG("Error: KG AES GCM ciphertext invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == tag) {
        KG_LOG("Error: KG AES GCM tag invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }

    ctx = EVP_CIPHER_CTX_new();
    if (NULL == ctx) {
        goto out;
    }

    /* Select the cipher first so the IV length can be set before key/IV. */
    if (1 != EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)) {
        goto out;
    }
    if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL)) {
        goto out;
    }
    if (1 != EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv)) {
        goto out;
    }

    if (1 != EVP_EncryptUpdate(ctx, ciphertext, &len, plaintext, plaintext_len)) {
        goto out;
    }
    ciphertext_len = len;

    /* GCM normally emits no bytes at Final, but account for any that appear. */
    if (1 != EVP_EncryptFinal_ex(ctx, ciphertext + ciphertext_len, &len)) {
        goto out;
    }
    ciphertext_len += len;

    if (1 != EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, 16, tag)) {
        goto out;
    }

    ret = KG_SUCCESS;

out:
    if (NULL != ctx) {
        EVP_CIPHER_CTX_free(ctx);
        ctx = NULL;
    }
    /* Report a length only on success; partial output must not be used. */
    *cipher_len = (KG_SUCCESS == ret) ? ciphertext_len : 0;
    return ret;
}


/*
 * AES-256-GCM decrypt `ciphertext` with `key`/`iv` and verify it against
 * the 16-byte `tag`.
 *
 * @param ciphertext      Data to decrypt; non-NULL, ciphertext_len > 0.
 * @param ciphertext_len  Length of ciphertext in bytes.
 * @param tag             Expected 16-byte GCM authentication tag.
 * @param key             AES-256 key (EVP_aes_256_gcm expects 32 bytes).
 * @param iv              IV/nonce; iv_len must be in 1..16.
 * @param iv_len          Length of iv in bytes.
 * @param plaintext       Output buffer; must hold at least ciphertext_len bytes.
 *                        On failure, any bytes written to it are zeroed, so
 *                        unauthenticated plaintext is never left behind.
 * @param plain_len       Out: number of plaintext bytes (0 on failure).
 * @return KG_SUCCESS; KG_CRYPTO_INVALID_PARAM for bad arguments;
 *         KG_CRYPTO_AES_GCM_FAIL on any OpenSSL failure, including a tag
 *         mismatch.
 */
int KG_gcm_decrypt(unsigned char *ciphertext, int ciphertext_len,
                unsigned char *tag,
                unsigned char *key,
                unsigned char *iv, int iv_len,
                unsigned char *plaintext,
                int *plain_len)
{
    EVP_CIPHER_CTX *ctx = NULL;
    int len = 0;
    int plaintext_len = 0;
    int output_started = 0;   /* set once OpenSSL may have written to plaintext */
    int ret = KG_CRYPTO_AES_GCM_FAIL;

    if (NULL == plaintext || NULL == plain_len) {
        KG_LOG("Error: KG AES GCM plaintext invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == key) {
        KG_LOG("Error: KG AES GCM key invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    /* Signed lengths: reject negatives too, not just zero. */
    if (NULL == iv || iv_len <= 0 || iv_len > 16) {
        KG_LOG("Error: KG AES GCM iv invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == ciphertext || ciphertext_len <= 0) {
        KG_LOG("Error: KG AES GCM ciphertext invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == tag) {
        KG_LOG("Error: KG AES GCM tag invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }

    ctx = EVP_CIPHER_CTX_new();
    if (NULL == ctx) {
        goto out;
    }

    /* Select the cipher first so the IV length can be set before key/IV. */
    if (!EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)) {
        goto out;
    }
    if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, iv_len, NULL)) {
        goto out;
    }
    if (!EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv)) {
        goto out;
    }

    output_started = 1;
    if (!EVP_DecryptUpdate(ctx, plaintext, &len, ciphertext, ciphertext_len)) {
        goto out;
    }
    plaintext_len = len;

    /* Set the expected tag (OpenSSL 1.0.1d and later). */
    if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, 16, tag)) {
        goto out;
    }

    /*
     * Final verifies the tag. Any result <= 0 means authentication failed,
     * and the decrypted bytes must not be trusted.
     */
    if (EVP_DecryptFinal_ex(ctx, plaintext + plaintext_len, &len) <= 0) {
        goto out;
    }
    plaintext_len += len;

    ret = KG_SUCCESS;

out:
    if (NULL != ctx) {
        EVP_CIPHER_CTX_free(ctx);
        ctx = NULL;
    }
    if (KG_SUCCESS != ret && output_started) {
        /* GCM output length equals input length, so this covers every byte written. */
        TEE_MemFill(plaintext, 0, (uint32_t)ciphertext_len);
    }
    *plain_len = (KG_SUCCESS == ret) ? plaintext_len : 0;
    return ret;
}

/*
 * RSA-OAEP encrypt `data` with the public component of `key`.
 *
 * @param data      Plaintext; non-NULL, data_len > 0. Per OpenSSL, OAEP
 *                  input may be at most RSA_size(key) - 42 bytes.
 * @param data_len  Length of data in bytes.
 * @param key       RSA key holding at least the public component.
 * @param out       Output buffer; OpenSSL requires RSA_size(key) bytes.
 * @param out_len   Out: number of bytes written to out (set on success only).
 * @return KG_SUCCESS; KG_CRYPTO_INVALID_PARAM for bad arguments;
 *         KG_CRYPTO_RSA_ENC_FAIL if OpenSSL reports an error.
 */
int kg_public_encrypt(unsigned char * data,int data_len, RSA* key, unsigned char *out, int *out_len){
    if (NULL == data || data_len <= 0) {
        KG_LOG("Error: KG RSA  plaintext invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == key) {
        KG_LOG("Error: KG RSA  key invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    /* out_len is written on success, so it must be validated like out. */
    if (NULL == out || NULL == out_len) {
        KG_LOG("Error: KG encrypt buffer invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }

    int ret = RSA_public_encrypt(data_len, data, out, key, RSA_PKCS1_OAEP_PADDING);
    if (ret < 0) {
        KG_LOG("Error: KG RSA encrypt failed!");
        return KG_CRYPTO_RSA_ENC_FAIL;
    }

    *out_len = ret;
    return KG_SUCCESS;
}

/*
 * RSA-OAEP decrypt `data` with the private component of `key`.
 *
 * @param data      Ciphertext; non-NULL, data_len > 0.
 * @param data_len  Length of data in bytes.
 * @param key       RSA key holding the private component.
 * @param out       Output buffer; OpenSSL requires room for the largest
 *                  possible plaintext (RSA_size(key) bytes is always enough).
 * @param out_len   Out: number of plaintext bytes (set on success only).
 * @return KG_SUCCESS; KG_CRYPTO_INVALID_PARAM for bad arguments;
 *         KG_CRYPTO_RSA_DEC_FAIL if OpenSSL reports an error.
 */
int KG_private_decrypt(unsigned char * data, int data_len, RSA* key, unsigned char *out, int *out_len){
    /* These messages previously said "plaintext"/"encrypt", copied from the encrypt path. */
    if (NULL == data || data_len <= 0) {
        KG_LOG("Error: KG RSA ciphertext invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == key) {
        KG_LOG("Error: KG RSA  key invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }
    if (NULL == out || NULL == out_len) {
        KG_LOG("Error: KG decrypt buffer invalid!");
        return KG_CRYPTO_INVALID_PARAM;
    }

    int ret = RSA_private_decrypt(data_len, data, out, key, RSA_PKCS1_OAEP_PADDING);
    if (ret < 0) {
        KG_LOG("Error: KG RSA decrypt failed!");
        return KG_CRYPTO_RSA_DEC_FAIL;
    }

    *out_len = ret;
    return KG_SUCCESS;
}

/*
 * Render n as a fixed-width, NUL-terminated string of 32 '0'/'1' characters.
 * The most significant bit comes first and leading zeros are kept.
 *
 * @param n  Value to render.
 * @return A 33-byte buffer obtained from TEE_Malloc (release with TEE_Free),
 *         or NULL if the allocation fails.
 */
char *decimal_to_binary(uint32_t n)
{
    char *bits = (char *)TEE_Malloc(33, 0);
    if (NULL == bits) {
        return NULL;
    }

    /* Position 0 holds bit 31, position 31 holds bit 0. */
    for (int pos = 0; pos < 32; pos++) {
        uint32_t bit = (n >> (31 - pos)) & 1u;
        bits[pos] = bit ? '1' : '0';
    }
    bits[32] = '\0';

    return bits;
}
