#include "knoxai_io_datatypes.h"
#include "knoxai_logger.h"

#include <qsee_log.h> // qsee_log (get_alt_rot_distname)
#include <qsee_message.h> // qsee_decapsulate_inter_app_message
#include <qsee_cfg_prop.h> // qsee_cfg_getpropval
#include <qsee_kdf.h> // qsee_kdf
#include <qsee_prng.h> // qsee_prng_getdata
#include <qsee_cipher.h> // qsee_cipher_*
#include <qsee_hmac.h> // qsee_hmac
#include <string.h>

/*
 * knoxai_q_unwrap - unwrap a blob produced by knoxai_q_wrap (local/self exchange).
 *
 * Expected blob layout (mirrors the writes in knoxai_q_wrap):
 *   [ nonce : MAX_NONCE_LEN/8 ][ tag : KNOXAI_FAC_KEY_LEN ][ AES-256-CTR ciphertext ]
 *
 * Steps:
 *   1. derive the AES key with qsee_kdf() from @salt and a fixed context string;
 *   2. decrypt the ciphertext into the local scratch buffer so_buffer;
 *   3. recompute qsee_hmac(SHA256) over that plaintext and compare it with the
 *      stored tag;
 *   4. only if the tag matches and @dest_ptr is large enough, copy the
 *      plaintext to the caller.
 * Unauthenticated plaintext is therefore never written to the caller's buffer.
 * Key, nonce, tag and plaintext scratch buffers are wiped before returning.
 *
 * @salt      NUL-terminated string used as the KDF salt.
 * @blob_ptr  wrapped input blob.
 * @blob_len  size of @blob_ptr in bytes.
 * @dest_ptr  output buffer for the recovered plaintext.
 * @dest_len  in: capacity of @dest_ptr; out: plaintext length (set on success only).
 *
 * Returns KNOXAI_SUCCESS, or KNOXAI_FAILURE on any error.
 *
 * NOTE(review): the nonce (stored in clear at the front of the blob) appears to
 * be what is passed to qsee_hmac() as the HMAC key, so the tag would not depend
 * on any secret -- confirm against the qsee_hmac signature. Changing the scheme
 * would invalidate blobs already produced by knoxai_q_wrap, so it is kept.
 */
__inline static uint32_t knoxai_q_unwrap(char *salt, uint8_t *blob_ptr, uint32_t blob_len, uint8_t *dest_ptr, uint32_t* dest_len){
    tz_knoxai_return_type   ret = KNOXAI_FAILURE;
    int                     unwrap_ret;
    char                    key_context[] = {"HW Crypto AES key"};
    uint32_t                decKey_len = KNOXAI_FAC_KEY_LEN;
    uint8_t                 decKey[KNOXAI_FAC_KEY_LEN];
    uint32_t                auth_len = KNOXAI_FAC_KEY_LEN;
    uint8_t                 auth_data[KNOXAI_FAC_KEY_LEN];
    uint8_t                 *auth_ptr;
    uint32_t                nonce_len = MAX_NONCE_LEN/8; // 128bit
    uint8_t                 nonce[MAX_NONCE_LEN/8] = {0};
    qsee_cipher_ctx         *ctx = 0;
    QSEE_CIPHER_MODE_ET     mode = QSEE_CIPHER_MODE_CTR;
    QSEE_CIPHER_PAD_ET      padMode = QSEE_CIPHER_PAD_PKCS7;
    uint8_t                 *cipher_ptr;
    uint32_t                cipher_len;
    uint8_t                 so_buffer[sizeof(drk_parsed_object_t)+SO_LOCAL_HEADERSIZE+16];
    uint32_t                so_buffer_len = sizeof(so_buffer);
    uint8_t                 tag_diff = 0;
    uint32_t                i;

    if ( blob_len < nonce_len + auth_len ) {
        KNOXAI_LOG("q_unwrap: out buffer is small in[%d] ", blob_len);
        goto exit;
    }

    /* Split the blob into nonce | tag | ciphertext. */
    TEE_MemMove(nonce, blob_ptr, nonce_len);
    auth_ptr   = blob_ptr + nonce_len;
    cipher_ptr = blob_ptr + nonce_len + auth_len;
    cipher_len = blob_len - nonce_len - auth_len;

    unwrap_ret = qsee_kdf(NULL, KNOXAI_FAC_KEY_LEN, (void*)salt, strlen(salt), (void*)key_context, strlen(key_context), decKey, decKey_len);
    if ( unwrap_ret != QSEE_KDF_SUCCESS ) {
        KNOXAI_LOG("q_unwrap: Key Bigger than space");
        goto exit;
    }
    if (qsee_cipher_init(QSEE_CIPHER_ALGO_AES_256, &ctx) < 0) {
        KNOXAI_LOG("q_unwrap: cipher init err");              goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_KEY, decKey, decKey_len) < 0) {
        KNOXAI_LOG("q_unwrap: cipher setparam key err"  );    goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_MODE, &mode, sizeof(mode)) < 0) {
        KNOXAI_LOG("q_unwrap: cipher setparam mode err" );    goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_PAD, &padMode, sizeof(padMode)) < 0) {
        KNOXAI_LOG("q_unwrap: cipher setparam pad err"  );    goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_IV, nonce, nonce_len) < 0) {
        KNOXAI_LOG("q_unwrap: cipher setparam nonce err");    goto exit;
    } else if (qsee_cipher_decrypt(ctx, cipher_ptr, cipher_len, so_buffer, &so_buffer_len) < 0) {
        /* so_buffer_len is in/out: capacity of so_buffer in, plaintext length out. */
        KNOXAI_LOG("q_unwrap: cipher dec err");               goto exit;
    }

    KNOXAI_DEBUG_LOG("q_unwrap: dec succeeded dest_len = %d, so_buffer_len = %d", *dest_len, so_buffer_len);

    /* Authenticate while the plaintext is still confined to the local buffer. */
    unwrap_ret = qsee_hmac(QSEE_HMAC_SHA256, so_buffer, so_buffer_len, nonce, nonce_len, auth_data);
    if ( unwrap_ret != 0 ) {
        KNOXAI_LOG("q_unwrap: hash error[%d]", unwrap_ret);
        goto exit;
    }
    /* Constant-time tag comparison: accumulate differences over the whole tag
     * instead of stopping at the first mismatching byte. */
    for (i = 0; i < auth_len; i++) {
        tag_diff |= (uint8_t)(auth_ptr[i] ^ auth_data[i]);
    }
    if ( tag_diff != 0 ) {
        KNOXAI_LOG("q_unwrap: auth error 0x[%x], 0x[%x], len[%d]", auth_ptr[0], auth_data[0], auth_len);
        goto exit;
    }

    /* *dest_len holds the caller's capacity on entry; never write past it. */
    if ( *dest_len < so_buffer_len ) {
        KNOXAI_LOG("q_unwrap: out buffer is small out[%d] need[%d]", *dest_len, so_buffer_len);
        goto exit;
    }
    TEE_MemMove(dest_ptr, so_buffer, so_buffer_len);
    *dest_len = so_buffer_len;
    ret = KNOXAI_SUCCESS;
exit:
    TEE_MemFill(nonce, 0x00, sizeof(nonce));
    TEE_MemFill(decKey, 0x00, sizeof(decKey));
    TEE_MemFill(auth_data, 0x00, sizeof(auth_data));
    TEE_MemFill(so_buffer, 0x00, sizeof(so_buffer));
    if (ctx) {
        if (qsee_cipher_free_ctx(ctx) < 0) {
            KNOXAI_LOG("q_unwrap: cipher free error");
        }
    }
    return ret;
}

/*
 * knoxai_q_wrap - wrap a blob for the local (self) exchange.
 *
 * Output layout written to @dest_ptr:
 *   [ nonce : MAX_NONCE_LEN/8 ][ tag : KNOXAI_FAC_KEY_LEN ][ AES-256-CTR ciphertext ]
 *
 * A fresh nonce is drawn with qsee_prng_getdata() and used as the cipher IV.
 * The AES key is derived with qsee_kdf() from @salt and a fixed context string.
 * The tag is qsee_hmac(SHA256) computed over the plaintext, with the nonce
 * passed as the key argument.
 * NOTE(review): the nonce is stored in clear, so the tag does not appear to be
 * secret-keyed; it is kept as-is for compatibility with knoxai_q_unwrap.
 *
 * @salt      NUL-terminated string used as the KDF salt.
 * @blob_ptr  plaintext input.
 * @blob_len  size of @blob_ptr in bytes.
 * @dest_ptr  output buffer.
 * @dest_len  in: capacity of @dest_ptr; out: bytes written (set on success only).
 *
 * Returns KNOXAI_SUCCESS, or KNOXAI_FAILURE on any error.
 */
__inline static uint32_t knoxai_q_wrap(char *salt, uint8_t *blob_ptr, uint32_t blob_len, uint8_t *dest_ptr, uint32_t* dest_len){
    tz_knoxai_return_type   ret = KNOXAI_FAILURE;
    int                     wrap_ret;
    char                    key_context[] = {"HW Crypto AES key"};
    uint32_t                encKey_len = KNOXAI_FAC_KEY_LEN;
    uint8_t                 encKey[KNOXAI_FAC_KEY_LEN];
    uint32_t                auth_len = KNOXAI_FAC_KEY_LEN;
    uint8_t                 auth_ptr[KNOXAI_FAC_KEY_LEN];
    uint32_t                nonce_len = MAX_NONCE_LEN/8; // 128bit
    uint8_t                 nonce[MAX_NONCE_LEN/8] = {0};
    uint32_t                prng_len;
    qsee_cipher_ctx         *ctx = 0;
    QSEE_CIPHER_MODE_ET     mode = QSEE_CIPHER_MODE_CTR;
    QSEE_CIPHER_PAD_ET      padMode = QSEE_CIPHER_PAD_PKCS7;
    uint8_t                 *encrypted_object = NULL;
    /* The allocation size is tracked separately because encrypted_object_len
     * is overwritten by qsee_cipher_encrypt() with the actual output length,
     * and the exit path must wipe exactly what was allocated. */
    uint32_t                encrypted_alloc_len = blob_len + nonce_len + KNOXAI_FAC_KEY_LEN*2;//iv+data+sha256
    uint32_t                encrypted_object_len = encrypted_alloc_len;

    if ( encrypted_alloc_len < blob_len ) {
        /* uint32_t wrap-around: blob_len is too large to size the work buffer. */
        KNOXAI_LOG("q_wrap: blob too large %d", blob_len);
        goto exit;
    }
    encrypted_object = TEE_Malloc(encrypted_alloc_len, 0);
    if ( encrypted_object == NULL ) {
        KNOXAI_LOG("q_wrap: malloc error");
        goto exit;
    }
    /* Keep nonce_len intact; the PRNG result is checked in its own variable. */
    prng_len = qsee_prng_getdata(nonce, nonce_len);
    if ( prng_len != nonce_len ) {
        KNOXAI_LOG("q_wrap: prng erro %d", prng_len);
        goto exit;
    }
    wrap_ret = qsee_kdf(NULL, KNOXAI_FAC_KEY_LEN, (void*)salt, strlen(salt), (void*)key_context, strlen(key_context), encKey, encKey_len);
    if ( wrap_ret != QSEE_KDF_SUCCESS ) {
        KNOXAI_LOG("q_wrap: Key Bigger than space");
        goto exit;
    }
    if (qsee_cipher_init(QSEE_CIPHER_ALGO_AES_256, &ctx) < 0) {
        KNOXAI_LOG("q_wrap: cipher init err");              goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_KEY, encKey, encKey_len) < 0) {
        KNOXAI_LOG("q_wrap: cipher setparam key err"  );    goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_MODE, &mode, sizeof(mode)) < 0) {
        KNOXAI_LOG("q_wrap: cipher setparam mode err" );    goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_PAD, &padMode, sizeof(padMode)) < 0) {
        KNOXAI_LOG("q_wrap: cipher setparam pad err"  );    goto exit;
    } else if (qsee_cipher_set_param(ctx, QSEE_CIPHER_PARAM_IV, nonce, nonce_len) < 0) {
        KNOXAI_LOG("q_wrap: cipher setparam nonce err");    goto exit;
    } else if (qsee_cipher_encrypt(ctx, blob_ptr, blob_len, encrypted_object, &encrypted_object_len) < 0) {
        KNOXAI_LOG("q_wrap: cipher enc err");               goto exit;
    }
    KNOXAI_DEBUG_LOG("q_wrap: enc successed %d", encrypted_object_len);
    wrap_ret = qsee_hmac(QSEE_HMAC_SHA256, blob_ptr, blob_len, nonce, nonce_len, auth_ptr);
    if ( wrap_ret != 0 ) {
        KNOXAI_LOG("q_wrap: hash error[%d]", wrap_ret);
        goto exit;
    }
    /* Overflow-safe form of: *dest_len < nonce_len + auth_len + encrypted_object_len */
    if ( *dest_len < nonce_len + auth_len || *dest_len - nonce_len - auth_len < encrypted_object_len ) {
        KNOXAI_LOG("q_wrap: out buffer is small out[%d] new[%d]", *dest_len, encrypted_object_len + nonce_len + auth_len);
        goto exit;
    }
    KNOXAI_DEBUG_LOG("q_wrap: len %d %d %d", nonce_len, auth_len, *dest_len);
    TEE_MemMove(dest_ptr, nonce, nonce_len);
    TEE_MemMove(dest_ptr + nonce_len, auth_ptr, auth_len);
    TEE_MemMove(dest_ptr + nonce_len + auth_len, encrypted_object, encrypted_object_len);
    *dest_len = nonce_len + auth_len + encrypted_object_len;
    ret = KNOXAI_SUCCESS;

exit:
    TEE_MemFill(nonce, 0x00, sizeof(nonce));
    TEE_MemFill(encKey, 0x00, sizeof(encKey));
    TEE_MemFill(auth_ptr, 0x00, sizeof(auth_ptr));
    if ( encrypted_object != NULL) {
        TEE_MemFill(encrypted_object, 0, encrypted_alloc_len);
        TEE_Free(encrypted_object);
        encrypted_object = NULL;
    }
    if (ctx) {
        if (qsee_cipher_free_ctx(ctx) < 0) {
            KNOXAI_LOG("q_wrap: cipher free error");
        }
    }
    return ret;
}

/*
 * get_alt_rot_distname - build "<alt RoT prefix><i_appname>" into o_distname,
 * taking the prefix from the devcfg property "alt_rot_domain_name_dot".
 *
 * The property value is copied starting at val[1], which skips its first
 * (quote) character. If the property cannot be read, or its derived length is
 * out of range, o_distname falls back to a plain copy of i_appname (bounded by
 * MAX_TANAME_SZ).
 *
 * @i_appname   NUL-terminated TA name.
 * @o_distname  output buffer; assumed to hold at least MAX_FULLNAME_SZ bytes
 *              (TODO confirm against callers).
 */
void get_alt_rot_distname(char *i_appname, char *o_distname)
{
    const char *prop_name = "alt_rot_domain_name_dot";
    uint32_t ret = 0;
    uint32_t ret_size = 0;
    size_t len = 0;
    qsee_cfg_propvar_t *ptr = NULL;
    uint32_t prop[2 + (MAX_DISTNAME_PREFIX_SZ / sizeof(uint32_t))] = {0};
    char distname_prefix[MAX_DISTNAME_PREFIX_SZ + 1] = {0};

    ret = qsee_cfg_getpropval(prop_name,
                              strlen(prop_name) + 1, 0,
                              (qsee_cfg_propvar_t *)&prop,
                              sizeof(prop), &ret_size);
    if (QSEE_CFG_SUCCESS != ret) {
        KNOXAI_LOG("'alt_rot_domain_name_dot' read failed %d, using legacy appname", ret);
        strlcpy(o_distname, i_appname, MAX_TANAME_SZ);
        return;
    }
    ptr = (qsee_cfg_propvar_t *)prop;
    /* len = ret_size - sizeof(qsee_cfg_propvar_t) + padding.
     * Unsigned arithmetic: a too-small ret_size wraps to a huge value and is
     * rejected by the upper bound; len == 0 must be rejected explicitly because
     * the copy below uses len - 1, which would wrap to SIZE_MAX. */
    len = ret_size - sizeof(*ptr) + 2 * sizeof(ptr->val) + 1;
    if (len == 0 || len > MAX_DISTNAME_PREFIX_SZ) {
        KNOXAI_LOG("'alt_rot_domain_name_dot' len invalid, using legacy appname");
        strlcpy(o_distname, i_appname, MAX_TANAME_SZ);
        return;
    }
    /* remove the quotes only when read from devcfg: start after val[0] */
    TEE_MemMove(distname_prefix, &ptr->val[1], len - 1);
    distname_prefix[len] = '\0';
    /* finalize fully qualified distname */
    strlcpy(o_distname, distname_prefix, MAX_DISTNAME_PREFIX_SZ);
    strlcat(o_distname, i_appname, MAX_FULLNAME_SZ);
}
/*
 * add_rot_alt_name - build "<alt RoT prefix><i_appname>" into o_distname using
 * a prefix compiled in for the target chipset.
 *
 * Each prefix has the form "alt." + 64 hex characters + "." (3 + 1 + 64 + 1
 * characters). For sha384 RoT hashes only the first 64 hex characters are
 * used. The prefix is selected at compile time; the copy/concat is shared.
 *
 * The prefix is copied directly from a NUL-terminated literal rather than via
 * a fixed-size local array. That array would lack a terminator if its size
 * equalled the literal length.
 *
 * @i_appname   NUL-terminated TA name appended after the prefix.
 * @o_distname  output buffer; assumed to hold at least MAX_FULLNAME_SZ bytes
 *              (TODO confirm against callers).
 */
void add_rot_alt_name(char *i_appname, char *o_distname) {
#if defined(SDM845) // RoT hash value for QSEE_SDM845_TA (sha256)
    static const char prefix[] = "alt.2945FB3C624A03E83D9E7D892DF938559E7D0F56B1475660E0887BAE4D01DA77.";
#elif defined(SM8250) // RoT hash value of QSEE_SM8250_TA (sha384) - first 64 bytes only
    static const char prefix[] = "alt.288717EFA81760C347CDB3A0CA23723C92AC2EE97AD36A7BB9EE3EEE76678BEA.";
#elif defined(SM6150) || defined(SM7150) // RoT hash value of QSEE_SM6150_TA (sha384) - first 64 bytes only
    static const char prefix[] = "alt.FEF4EFBC5D6689C2939E7C410088053F102F5FB4D7B7319C3144BD79E5464654.";
#elif defined(SM8150_FUSION) // RoT hash value of QSEE_SM8150_FUSION_TA (sha384) - first 64 bytes only
    /* Must stay ahead of the SM8150 branch so that FUSION builds take this prefix. */
    static const char prefix[] = "alt.984426BE79B6C60F9265EDF6ECDD4CFF1827E9CB0B439D1CBFDCE378B76001DC.";
#elif defined(SM8150) // RoT hash value of QSEE_SM8150_TA (sha384) - first 64 bytes only
    /* remaining hash bytes (unused): 4E836F360B516BE1022A16C369F62E5A */
    static const char prefix[] = "alt.9361A53CBE05BDDEEEF7DDCD9E0D2AD472FF0A1AEE5F3EE4C7E416162273D237.";
#elif defined(SDM710) // RoT hash value for QSEE_SDM710_TA (sha256)
    static const char prefix[] = "alt.C6A1B1F001AA41325C471DEFBB767B962D381CF0FE7224E73EC4305C0C848736.";
#elif defined(SM8350) // RoT hash value for QSEE_SM8350_TA (sha384) - first 64 bytes only
    static const char prefix[] = "alt.30F6F65747F73446B6BB31DB855F712EBBD6CE7493A40A6F674035047299C892.";
#elif defined(SM8450) // RoT hash value for QSEE_SM8450_TA (sha384) - first 64 bytes only
    static const char prefix[] = "alt.D16E188A7DDD76A0A409E541F6D2B9FD3FD0684E2B742383AF639324BE86561B.";
#else // default: same prefix as SDM710
    static const char prefix[] = "alt.C6A1B1F001AA41325C471DEFBB767B962D381CF0FE7224E73EC4305C0C848736.";
#endif

    strlcpy(o_distname, prefix, MAX_DISTNAME_PREFIX_SZ);
    strlcat(o_distname, i_appname, MAX_FULLNAME_SZ);
}

/*
 * getNameFromUuid - map a peer TA UUID to the TA name used for wrapping.
 *
 * PROV_UUID maps to PROV_NAME and SELF_UUID maps to SELF_NAME. Any other UUID
 * is logged and also falls back to SELF_NAME, so knoxai_wrap/knoxai_unwrap
 * treat it as a local (self) exchange.
 *
 * Returns a pointer to one of the name constants (not owned by the caller).
 */
char* getNameFromUuid(TEE_UUID target_uuid) {
    char *name = SELF_NAME; /* default, also used for unrecognised UUIDs */

    if (TEE_MemCompare((void*)&target_uuid, (void*)&((TEE_UUID)PROV_UUID), sizeof(TEE_UUID)) == 0) {
        name = PROV_NAME;
    } else if (TEE_MemCompare((void*)&target_uuid, (void*)&((TEE_UUID)SELF_UUID), sizeof(TEE_UUID)) != 0) {
        /* Neither known UUID matched: keep the SELF_NAME fallback. */
        KNOXAI_LOG("knoxai_wrap not support uuid ");
    }

    KNOXAI_DEBUG_LOG("getNameFromUuid %s ", name);
    return name;
}

/*
 * knoxai_unwrap - unwrap a blob that the TA identified by @target_uuid sent.
 *
 * If the UUID resolves to SELF_NAME, the blob is unwrapped locally with
 * knoxai_q_unwrap(). Unrecognised UUIDs also resolve to SELF_NAME (see
 * getNameFromUuid). Otherwise the blob is decapsulated with
 * qsee_decapsulate_inter_app_message(). The reported sender name must then
 * equal the distinguished name built by get_alt_rot_distname() for the
 * target. On a sender mismatch the decapsulated output is wiped from
 * @dest_ptr and *dest_len is set to 0.
 *
 * @dest_len  in: capacity of @dest_ptr; out: unwrapped length.
 *
 * Returns KNOXAI_SUCCESS or KNOXAI_FAILURE. On the local path, the result of
 * knoxai_q_unwrap() is returned.
 */
uint32_t knoxai_unwrap(TEE_UUID target_uuid, uint8_t *blob_ptr, uint32_t blob_len, uint8_t *dest_ptr, uint32_t* dest_len){
    tz_knoxai_return_type  ret = KNOXAI_FAILURE;
    TEE_Result             unwrap_ret;
    char                   src_app[MAX_DISTNAME_PREFIX_SZ + MAX_TANAME_SZ + 1] = {0};
    char                   full_distname[MAX_DISTNAME_PREFIX_SZ + MAX_TANAME_SZ + 1] = {0}; // skm + RoT
    char                   *target_tz_name = getNameFromUuid(target_uuid);

    if ( blob_ptr == NULL || dest_ptr == NULL || dest_len == NULL) {
        KNOXAI_LOG("ptr must be NOT NULL");
        ret = KNOXAI_FAILURE;
        goto exit;
    }
    /* strcmp rather than a sizeof(SELF_NAME)-byte memcmp: the other candidate
     * name may be shorter than SELF_NAME, and a fixed-length compare would read
     * past its terminator. */
    if ( strcmp(target_tz_name, SELF_NAME) == 0 ) {
        KNOXAI_LOG("knoxai_unwrap - local exchange");
        return knoxai_q_unwrap(target_tz_name, blob_ptr, blob_len, dest_ptr, dest_len);
    }
    if ( blob_len < SO_LOCAL_HEADERSIZE ) {
        KNOXAI_LOG("blob len error %d", blob_len);
        ret = KNOXAI_FAILURE;
        goto exit;
    }
    get_alt_rot_distname(target_tz_name, full_distname);
    unwrap_ret = qsee_decapsulate_inter_app_message(src_app, blob_ptr, blob_len, dest_ptr, dest_len);
    KNOXAI_DEBUG_LOG("knoxai_unwrap len blob[%d]dest[%d]", blob_len, *dest_len);
    if (unwrap_ret != 0) {
        KNOXAI_LOG("knoxai_unwrap decapsulate failed with ret=0x%08X, exit", unwrap_ret);
        ret = KNOXAI_FAILURE;
        goto exit;
    } else if (strcmp(src_app, full_distname) != 0) {
        KNOXAI_LOG("unwrap failed: src_app[%s] was not [%s]", src_app, full_distname);
        /* Do not hand data from an unexpected sender back to the caller. */
        TEE_MemFill(dest_ptr, 0x00, *dest_len);
        *dest_len = 0;
        ret = KNOXAI_FAILURE;
        goto exit;
    }
    ret = KNOXAI_SUCCESS;
exit:
    return ret;
}

/*
 * knoxai_wrap - wrap @blob_ptr for the TA identified by @target_uuid.
 *
 * If the UUID resolves to SELF_NAME, the blob is wrapped locally with
 * knoxai_q_wrap(). Unrecognised UUIDs also resolve to SELF_NAME (see
 * getNameFromUuid). Otherwise the destination distinguished name is built
 * with add_rot_alt_name() and the blob is encapsulated with
 * qsee_encapsulate_inter_app_message().
 *
 * @dest_len  in: capacity of @dest_ptr (must be >= blob_len); out: wrapped length.
 *
 * Returns KNOXAI_SUCCESS, or KNOXAI_FAILURE on argument errors. On
 * encapsulation failure, the raw TEE_Result from
 * qsee_encapsulate_inter_app_message() is returned. On the local path, the
 * result of knoxai_q_wrap() is returned.
 */
uint32_t knoxai_wrap(TEE_UUID target_uuid, uint8_t *blob_ptr, uint32_t blob_len, uint8_t *dest_ptr, uint32_t* dest_len){
    tz_knoxai_return_type  ret = KNOXAI_FAILURE;
    TEE_Result             wrap_ret;
    char                   full_distname[MAX_DISTNAME_PREFIX_SZ + MAX_TANAME_SZ + 1] = {0}; // skm + RoT
    char                   *target_tz_name = getNameFromUuid(target_uuid);

    /* dest_len is checked before *dest_len is read. */
    if ( blob_ptr == NULL || dest_ptr == NULL || dest_len == NULL || blob_len == 0 || blob_len > *dest_len) {
        KNOXAI_LOG("ptr must be NOT NULL");
        ret = KNOXAI_FAILURE;
        goto exit;
    }
    /* strcmp rather than a sizeof(SELF_NAME)-byte memcmp: the other candidate
     * name may be shorter than SELF_NAME, and a fixed-length compare would read
     * past its terminator. */
    if ( strcmp(target_tz_name, SELF_NAME) == 0 ) {
        KNOXAI_LOG("knoxai_wrap - local exchange");
        return knoxai_q_wrap(target_tz_name, blob_ptr, blob_len, dest_ptr, dest_len);
    }

    add_rot_alt_name(target_tz_name, full_distname);
    KNOXAI_DEBUG_LOG("knoxai_wrap full:%s", full_distname);
    wrap_ret = qsee_encapsulate_inter_app_message(full_distname, blob_ptr, blob_len, dest_ptr, dest_len);
    KNOXAI_DEBUG_LOG("knoxai_wrap len blob[%d]dest[%d]", blob_len, *dest_len);
    if (wrap_ret != 0) {
        KNOXAI_LOG("knoxai_wrap wrap failed with ret=0x%08X, exit", wrap_ret);
        ret = wrap_ret;
        goto exit;
    }
    ret = KNOXAI_SUCCESS;
exit:
    return ret;
}
