/**
 * \file crypto_kdf.c
 * \brief Secure-object key derivation and AES-128-GCM wrap/unwrap helpers.
 * \author Vladislav Kovalenko v.kovalenko@samsung.com
 * \author Vladyslav Figol v.figol@samsung.com
 * \author Andrey Neyvanov a.neyvanov@samsung.com
 * \version 1.0
 * \date Created Jan 22, 2014 12:15
 * \par In Samsung Ukraine R&D Center (SURC) under a contract between
 * \par LLC "Samsung Electronics Ukraine Company" (Kyiv, Ukraine) and "Samsung Electronics Co", Ltd (Seoul, Republic of Korea)
 * \par Copyright: (c) Samsung Electronics Co, Ltd 2013. All rights reserved.
 **/

#include "custom_so.h"
#include "utilities.h"
#include "wsm_definitions_v3.h"
#include "wsm_log.h"
#include "wsm_rand.h"
#include "wsm_v1_crypto.h"
#include "wsm_v2_crypto.h"

#if defined(MOBICORE) && defined(TA_BUILD)
    #include "tlStd.h"
    #include "TlApi/TlApi.h"
#elif defined(QSEE) && defined(TA_BUILD)
    #include "qsee_kdf.h"
#elif defined(GP_API) && defined(TA_BUILD)
    #include <tee_internal_api.h>
    #include <tees_kdf.h>
#endif /* if defined(MOBICORE) && defined(TA_BUILD) */

/* Convert a length expressed in bytes into the equivalent number of bits. */
#define BITS(nbytes) ((nbytes) * 8)

/* Key derivation variant used by wrap_so2()/unwrap_so2(); defined further below. */
static return_t custom_so_derive_key2(uint8_t *key, uint32_t key_len,
                                      uint8_t *salt, uint32_t salt_len);

/**
 * \brief Derive a secure-object key of \p key_len bytes from \p salt with the platform KDF.
 *
 * The backend is selected at build time (all real backends require TA_BUILD):
 *  - MOBICORE: tlApiDeriveKey(salt -> key) with MC_SO_CONTEXT_TLT / MC_SO_LIFETIME_PERMANENT.
 *  - QSEE:     qsee_kdf() with a NULL input key of QSEE_KDF_KEY_LEN bytes; the salt is XORed
 *              into fixed 32-byte label and context vectors.
 *  - GP_API:   TEES_DeriveKeyKDF() into a transient generic-secret object whose secret value
 *              is then read back into \p key.
 *  - otherwise: no real derivation. \p key is zero-filled and "key derive is faked" is
 *              logged, but WSM_RET_SUC is still returned.
 *
 * NOTE(review): changing any derivation input here would presumably make secure objects
 * wrapped by earlier builds undecryptable, so only error handling is touched.
 *
 * \param key      [out] Key buffer of \p key_len bytes. On the GP_API path its initial
 *                 contents are also passed to the KDF as the context argument (callers in
 *                 this file pass a zero-filled buffer). Wiped if the platform KDF fails.
 * \param key_len  Size of \p key in bytes.
 * \param salt     Salt bytes.
 * \param salt_len Number of salt bytes.
 * \return WSM_RET_SUC on success, WSM_RET_E_BAD_PARAM if \p key or \p salt is NULL,
 *         WSM_RET_E_PBKDF if the platform KDF fails.
 */
static return_t custom_so_derive_key(uint8_t     *key, uint32_t    key_len, uint8_t     *salt,
                                     uint32_t    salt_len)
{
    // salt_len may end up unused when logging is compiled out and no TA backend is selected.
    (void)salt_len;

    if ((key == 0) ||
        (salt == 0))
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_BAD_PARAM, "[%s] input params are bad \n", __func__);
        return WSM_RET_E_BAD_PARAM;
    }

    WSM_LOG(err_level_info, LOG_TAG, "[%s] Entry key_len : %x, salt_len %x \n",
            __func__,
            key_len,
            salt_len);
    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Salt for key:", (char *)salt, salt_len);

#if defined(MOBICORE) && defined(TA_BUILD)

    tlApiResult_t   res = tlApiDeriveKey(salt,
                                         salt_len,
                                         key,
                                         key_len,
                                         MC_SO_CONTEXT_TLT,
                                         MC_SO_LIFETIME_PERMANENT);
    if (res != TLAPI_OK)
    {
        // Do not hand back partially derived key material.
        _secure_memset(key, key_len);
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] key derive error %x \n", __func__, res);
        return WSM_RET_E_PBKDF;
    }
    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (const char *)key, key_len);

#elif defined(QSEE) && defined(TA_BUILD)

    const uint32_t QSEE_KDF_KEY_LEN = 32;
    const size_t label_len = 32;
    const size_t context_len = 32;

    // NIST vectors, personalised below by XOR-ing the salt into them.
    uint8_t label[] = { 0x24, 0xe0, 0x9c, 0x85, 0x1f, 0x0b, 0xab, 0x38, 0xac, 0xf4,
                        0x51, 0xcb, 0x74, 0x98, 0x4b, 0x17, 0xb8, 0x29, 0x56, 0x3a,
                        0x44, 0x9e, 0xfd, 0x32, 0x12, 0x8d, 0x46, 0x27, 0x6c, 0x4f,
                        0x2f, 0x3a };
    uint8_t context[] = { 0x1a, 0x20, 0xc4, 0xbe, 0xe1, 0x74, 0xc0, 0x05, 0x11, 0x78,
                          0x5a, 0x79, 0xdf, 0x33, 0xbf, 0x89, 0x1a, 0xbe, 0x3a, 0x1c,
                          0x09, 0x81, 0x59, 0xc5, 0xcc, 0xf4, 0xb1, 0xd6, 0xaf, 0xd0,
                          0x59, 0x86 };

    for (unsigned int i = 0; ((i < label_len) && (i < salt_len)); ++i)
    {
        label[i] ^= salt[i];
    }

    for (unsigned int i = 0; ((i < context_len) && (i < salt_len)); ++i)
    {
        context[i] ^= salt[i];
    }

    int32_t     res = qsee_kdf(NULL,
                               QSEE_KDF_KEY_LEN,
                               label,
                               label_len,
                               context,
                               context_len,
                               key,
                               key_len);
    if (res < 0)
    {
        // Do not hand back partially derived key material.
        _secure_memset(key, key_len);
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] key derive error %x <<\n", __func__, res);
        return WSM_RET_E_PBKDF;
    }

    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (const char *)key, key_len);

#elif defined(GP_API) && defined(TA_BUILD)

    TEE_Result          res = TEE_ERROR_GENERIC;
    TEE_ObjectHandle    obj_hndl = TEE_HANDLE_NULL;
    // TEE_GetObjectBufferAttribute() rewrites key_len (on a short buffer it may report a
    // larger required size), so remember the real buffer size for wiping on failure.
    const uint32_t      key_buf_len = key_len;

    res = TEE_AllocateTransientObject(TEE_TYPE_GENERIC_SECRET, // object type
                                      BITS(key_len), // max object size
                                      &obj_hndl); // ptr to obj hndl
    if (TEE_SUCCESS != res)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] TEE_AllocateTransientObject error %x <<\n",
                  __func__, res);
        return WSM_RET_E_PBKDF;
    }

    res = TEES_DeriveKeyKDF(salt, // label
                            salt_len, // label_len
                            key, // context
                            key_len, // context len in bytes
                            key_len, // KDF key leng required in bytes
                            obj_hndl); // obj hndl
    if (TEE_SUCCESS != res)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_TA_DERIVE_KEY, "[%s] key derive error %x <<\n", __func__, res);
        goto free_obj_hndl;
    }

    res = TEE_GetObjectBufferAttribute(obj_hndl, // obj hndl
                                       TEE_ATTR_SECRET_VALUE, // attribute ID
                                       key, // out buff
                                       &key_len); // out buff len
    if (TEE_SUCCESS != res)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_TA_GET_OBJECT_ATTRIBUTE,
                  "[%s] TEE_GetObjectBufferAttribute error %x <<\n", __func__, res);
        goto free_obj_hndl;
    }

free_obj_hndl:

    TEE_ResetTransientObject(obj_hndl);
    TEE_CloseObject(obj_hndl);

    if (TEE_SUCCESS != res)
    {
        // Never return or log partially derived key material; wipe the whole caller buffer.
        _secure_memset(key, key_buf_len);
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] key derive error %x <<\n", __func__, res);
        return WSM_RET_E_PBKDF;
    }

    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (const char *)key, key_len);
#else

    _secure_memset(key, key_len);

    #ifndef WSM_FUZZING
    WSM_LOG_E(LOG_TAG, WSM_RET_E_FAKED_DERIVE_KEY, "[%s] key derive is faked \n", __func__);

    #else
    WSM_LOG(err_level_warning, LOG_TAG, "[%s] key derive is faked \n", __func__);
    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (char *)key, key_len);

    #endif /* ifdef FUZZING */

#endif /* if defined(MOBICORE) && defined(TA_BUILD) */

    WSM_LOG(err_level_info, LOG_TAG, "[%s] Exit >>\n", __func__);

    return WSM_RET_SUC;
}

// -----------------------------------------------------------------------------------------------------
/**
 * \brief Encrypt a secure object with AES-128-GCM under a key bound to \p id.
 *
 * A random IV_KEY_SIZE-byte IV is generated. IV || id, zero-padded up to a multiple of
 * AES_BLOCK_SIZE, is the salt for custom_so_derive_key(). The same buffer (starting with
 * the IV) is passed as the GCM IV argument.
 *
 * Output layout in \p dst: the encrypt output followed by the IV_KEY_SIZE-byte IV.
 *
 * \param src     Plaintext, \p srcSize bytes.
 * \param srcSize Plaintext length in bytes.
 * \param dst     Output buffer. NOTE(review): its capacity is not passed in; the caller
 *                must size it for the encrypt output plus IV_KEY_SIZE bytes.
 * \param dstSize [out] Number of bytes written to \p dst.
 * \param id      Object identifier mixed into the key derivation.
 * \param id_len  Length of \p id; must not exceed WSM_MAX_ID_LENGTH.
 * \return WSM_RET_SUC, WSM_RET_E_BAD_PARAM, or the error returned by the RNG, KDF or cipher.
 */
return_t wrap_so(const uint8_t *src, uint32_t srcSize, uint8_t *dst, uint32_t *dstSize,
                 const uint8_t *id, uint32_t id_len)
{
    uint8_t         key[IV_KEY_SIZE] = { 0 };
    return_t        ret = WSM_RET_SUC;
    // IV + ID, plus one block of zero slack. buff_len is rounded up to a multiple of
    // AES_BLOCK_SIZE below, and the KDF reads (and we wipe) that many bytes. Without the
    // slack this overruns buff whenever WSM_MAX_ID_LENGTH + IV_KEY_SIZE is not block aligned.
    // The extra bytes stay zero, so derived keys are unchanged.
    uint8_t         buff[WSM_MAX_ID_LENGTH + IV_KEY_SIZE + AES_BLOCK_SIZE] = { 0 };
    uint32_t        buff_len = 0;

    WSM_LOG(err_level_info, LOG_TAG, "[%s] data  %d\n", __func__, srcSize);

    if ((src == 0) ||
        (dst == 0) ||
        (dstSize == 0) ||
        (id == 0) ||
        (id_len > WSM_MAX_ID_LENGTH))
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_BAD_PARAM, "[%s] Input params are bad %p %p %p %p %d\n",
                  __func__,
                  src,
                  dst,
                  dstSize,
                  id,
                  id_len);
        return WSM_RET_E_BAD_PARAM;
    }

    // Generate IV for Secure Object encryption.
    ret = wsm_GenerateRandombuffer(IV_KEY_SIZE, buff);
    if (ret != WSM_RET_SUC)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_RANDOM, "[%s] wsm_GenerateRandombuffer, failed \n",
                  __func__);
        return ret;
    }
    buff_len += IV_KEY_SIZE;

    // Concatenate id and IV for more secure key generation.
    memcpy(&buff[buff_len], id, id_len);
    buff_len += id_len;

    // Round the salt length up to a whole number of AES blocks (zero padded).
    buff_len = ((buff_len + (AES_BLOCK_SIZE - 1)) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE;

    ret = custom_so_derive_key(key,
                               sizeof(key),
                               buff,
                               buff_len);
    if (ret != WSM_RET_SUC)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_TA_DERIVE_KEY, "[%s] key derive error %d \n", __func__, ret);
        _secure_memset(key, sizeof(key));
        _secure_memset(buff, sizeof(buff));
        return ret;
    }

    size_t dstSize_local = srcSize;
    ret = WSMv2_CRYPTO_aes128_gcm_encrypt(dst, src, &dstSize_local, buff, key, AES_KEY_SIZE_128,
                                          NULL, 0);
    *dstSize = (__typeof__(*dstSize))dstSize_local;
    if (ret != WSM_RET_SUC)
    {
        _secure_memset(key, sizeof(key));
        _secure_memset(buff, sizeof(buff));
        return ret;
    }

    // Store aesGcmIV in to Secure Object end.
    memcpy(dst + *dstSize, buff, IV_KEY_SIZE);
    *dstSize += IV_KEY_SIZE;
    _secure_memset(key, sizeof(key));
    _secure_memset(buff, sizeof(buff));

    WSM_LOG(err_level_info, LOG_TAG, "[%s] out data  %d\n", __func__, (uint32_t)*dstSize);

    return ret;
}

// -----------------------------------------------------------------------------------------------------
/**
 * \brief Encrypt a secure object with AES-128-GCM using the custom_so_derive_key2() KDF.
 *
 * Same scheme and output layout as wrap_so(); only the key derivation differs.
 * - Salt: random IV || id, zero-padded to a multiple of AES_BLOCK_SIZE.
 * - GCM IV: the salt buffer, which starts with the IV.
 * - \p dst receives the encrypt output followed by the IV_KEY_SIZE-byte IV.
 *
 * \param src     Plaintext, \p srcSize bytes.
 * \param srcSize Plaintext length in bytes.
 * \param dst     Output buffer. NOTE(review): its capacity is not passed in; the caller
 *                must size it for the encrypt output plus IV_KEY_SIZE bytes.
 * \param dstSize [out] Number of bytes written to \p dst.
 * \param id      Object identifier mixed into the key derivation.
 * \param id_len  Length of \p id; must not exceed WSM_MAX_ID_LENGTH.
 * \return WSM_RET_SUC, WSM_RET_E_BAD_PARAM, or the error returned by the RNG, KDF or cipher.
 */
return_t wrap_so2(const uint8_t          *src, uint32_t        srcSize, uint8_t         *dst,
                  uint32_t        *dstSize, const uint8_t   *id, uint32_t        id_len)
{
    uint8_t         key[IV_KEY_SIZE] = { 0 };
    return_t        ret = WSM_RET_SUC;
    // IV + ID, plus one block of zero slack so rounding buff_len up to a multiple of
    // AES_BLOCK_SIZE can never make the KDF read, or the wipe write, past the array.
    // The extra bytes stay zero, so derived keys are unchanged.
    uint8_t         buff[WSM_MAX_ID_LENGTH + IV_KEY_SIZE + AES_BLOCK_SIZE] = { 0 };
    uint32_t        buff_len = 0;

    WSM_LOG(err_level_info, LOG_TAG, "[%s] data  %d\n", __func__, srcSize);

    if ((src == 0) ||
        (dst == 0) ||
        (dstSize == 0) ||
        (id == 0) ||
        (id_len > WSM_MAX_ID_LENGTH))
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_BAD_PARAM, "[%s] Input params are bad %p %p %p %p %d\n",
                  __func__,
                  src,
                  dst,
                  dstSize,
                  id,
                  id_len);
        return WSM_RET_E_BAD_PARAM;
    }

    // Generate IV for Secure Object encryption.
    ret = wsm_GenerateRandombuffer(IV_KEY_SIZE, buff);
    if (ret != WSM_RET_SUC)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_RANDOM,
                  "[%s] SVOIP_RETURN_GENERAL_FAILURE when Generate aesGcmIV for Secure Object encryption \n",
                  __func__);
        return ret;
    }
    buff_len += IV_KEY_SIZE;

    // Concatenate id and IV for more secure key generation.
    memcpy(&buff[buff_len], id, id_len);
    buff_len += id_len;

    // Round the salt length up to a whole number of AES blocks (zero padded).
    buff_len = ((buff_len + (AES_BLOCK_SIZE - 1)) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE;

    ret = custom_so_derive_key2(key,
                                sizeof(key),
                                buff,
                                buff_len);
    if (ret != WSM_RET_SUC)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_TA_DERIVE_KEY, "[%s] key derive error %d \n", __func__, ret);
        _secure_memset(key, sizeof(key));
        _secure_memset(buff, sizeof(buff));
        return ret;
    }

    size_t dstSize_local = srcSize;
    ret = WSMv2_CRYPTO_aes128_gcm_encrypt(dst, src, &dstSize_local, buff, key, AES_KEY_SIZE_128,
                                          NULL, 0);
    *dstSize = (__typeof__(*dstSize))dstSize_local;
    if (ret != WSM_RET_SUC)
    {
        _secure_memset(key, sizeof(key));
        _secure_memset(buff, sizeof(buff));
        return ret;
    }

    // Store aesGcmIV in to Secure Object end.
    memcpy(dst + *dstSize, buff, IV_KEY_SIZE);
    *dstSize += IV_KEY_SIZE;
    _secure_memset(key, sizeof(key));
    _secure_memset(buff, sizeof(buff));
    WSM_LOG(err_level_info, LOG_TAG, "[%s] out data  %d\n", __func__, (uint32_t)*dstSize);

    return ret;
}

// -----------------------------------------------------------------------------------------------------
/**
 * \brief Decrypt a secure object produced by wrap_so().
 *
 * The trailing IV_KEY_SIZE bytes of \p src are taken as the IV. IV || id, zero-padded to a
 * multiple of AES_BLOCK_SIZE, is re-derived into the key via custom_so_derive_key(). The
 * remaining srcSize - IV_KEY_SIZE bytes are then passed to the AES-128-GCM decrypt.
 *
 * \param src     Wrapped object: encrypt output followed by the IV.
 * \param srcSize Total length of \p src; must be at least IV_KEY_SIZE.
 * \param dst     Output buffer for the plaintext.
 * \param dstSize [out] Value reported by the decrypt call (set even if decryption fails).
 * \param id      Object identifier used when wrapping.
 * \param id_len  Length of \p id; must not exceed WSM_MAX_ID_LENGTH.
 * \return WSM_RET_SUC, WSM_RET_E_LENGTH_MISMATCH for bad arguments, or the KDF/cipher error.
 */
return_t unwrap_so(const uint8_t   *src, uint32_t        srcSize, uint8_t         *dst,
                   uint32_t        *dstSize, const uint8_t   *id, uint32_t        id_len)
{
    WSM_LOG(err_level_info, LOG_TAG, "[%s] Entry \n", __func__);

    uint8_t     key[IV_KEY_SIZE] = { 0 };
    // IV + ID, plus one block of zero slack so rounding buff_len up to a multiple of
    // AES_BLOCK_SIZE can never make the KDF read, or the wipe write, past the array.
    // The extra bytes stay zero, so derived keys are unchanged.
    uint8_t     buff[WSM_MAX_ID_LENGTH + IV_KEY_SIZE + AES_BLOCK_SIZE] = { 0 };
    uint32_t    buff_len = 0;
    return_t    ret = WSM_RET_SUC;

    if ((src == 0) ||
        (dst == 0) ||
        (dstSize == 0) ||
        (id == 0) ||
        (id_len > WSM_MAX_ID_LENGTH) ||
        (srcSize < IV_KEY_SIZE))
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_LENGTH_MISMATCH, "[%s] Input params are bad %p %p %p %p %d\n",
                  __func__,
                  src,
                  dst,
                  dstSize,
                  id,
                  id_len);
        return WSM_RET_E_LENGTH_MISMATCH;
    }

    // Cut aesGcmIV from Secure Object end and put it into aesGcmIV vector.
    memcpy(buff, src + srcSize - IV_KEY_SIZE, IV_KEY_SIZE);
    buff_len += IV_KEY_SIZE;

    // Concatenate id and aesGcmIV for more secure key generation.
    memcpy(&buff[buff_len], id, id_len);
    buff_len += id_len;

    // Round the salt length up to a whole number of AES blocks (zero padded), as wrap_so() does.
    buff_len = ((buff_len + (AES_BLOCK_SIZE - 1)) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE;

    ret = custom_so_derive_key(key,
                               sizeof(key),
                               buff,
                               buff_len);
    if (ret != WSM_RET_SUC)
    {
        _secure_memset(key, sizeof(key));
        _secure_memset(buff, sizeof(buff));
        return ret;
    }

    size_t  dstSize_local = srcSize - IV_KEY_SIZE;
    ret = WSMv2_CRYPTO_aes128_gcm_decrypt(src, dst, &dstSize_local, buff, key, AES_KEY_SIZE_128,
                                          NULL, 0);
    *dstSize = (__typeof__(*dstSize))dstSize_local;

    _secure_memset(key, sizeof(key));
    _secure_memset(buff, sizeof(buff));

    return ret;
}

//
// Gear vp1 and vp2 backward compatibility: key derivation used by wrap_so2()/unwrap_so2().
//
/**
 * \brief Derive a secure-object key for wrap_so2()/unwrap_so2().
 *
 * Identical to custom_so_derive_key() except on QSEE. There the salt itself (truncated to
 * 32 bytes) is passed to qsee_kdf() as the input key instead of NULL/QSEE_KDF_KEY_LEN.
 *
 * Backends, selected at build time:
 *  - MOBICORE: tlApiDeriveKey(salt -> key) with MC_SO_CONTEXT_TLT / MC_SO_LIFETIME_PERMANENT.
 *  - QSEE:     qsee_kdf(salt[0..min(salt_len,32)), label ^ salt, context ^ salt).
 *  - GP_API:   TEES_DeriveKeyKDF() into a transient generic-secret object, then copied out.
 *  - otherwise: \p key is zero-filled and "key derive is faked" is logged, but WSM_RET_SUC
 *              is still returned.
 *
 * NOTE(review): derivation inputs are presumably frozen for compatibility with objects
 * wrapped by older builds; only error handling is touched here.
 *
 * \param key      [out] Key buffer of \p key_len bytes. On GP_API its initial contents are
 *                 also the KDF context (callers in this file pass a zero-filled buffer).
 *                 Wiped if the platform KDF fails.
 * \param key_len  Size of \p key in bytes.
 * \param salt     Salt bytes.
 * \param salt_len Number of salt bytes.
 * \return WSM_RET_SUC on success, WSM_RET_E_BAD_PARAM if \p key or \p salt is NULL,
 *         WSM_RET_E_PBKDF if the platform KDF fails.
 */
static return_t custom_so_derive_key2(uint8_t *key, uint32_t key_len, uint8_t *salt,
                                      uint32_t salt_len)
{
    // salt_len may end up unused when logging is compiled out and no TA backend is selected.
    (void)salt_len;

    if ((key == 0) || (salt == 0))
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_BAD_PARAM, "[%s] input params are bad \n", __func__);
        return WSM_RET_E_BAD_PARAM;
    }

    WSM_LOG(err_level_info, LOG_TAG, "[%s] Entry key_len : %x, salt_len %x \n",
            __func__, key_len, salt_len);

    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Salt for key:", (char *)salt, salt_len);

#if defined(MOBICORE) && defined(TA_BUILD)

    tlApiResult_t   res = tlApiDeriveKey(salt,
                                         salt_len,
                                         key,
                                         key_len,
                                         MC_SO_CONTEXT_TLT,
                                         MC_SO_LIFETIME_PERMANENT);
    if (res != TLAPI_OK)
    {
        // Do not hand back partially derived key material.
        _secure_memset(key, key_len);
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] key derive error %x \n", __func__, res);
        return WSM_RET_E_PBKDF;
    }

    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (const char *)key, key_len);

#elif defined(QSEE) && defined(TA_BUILD)

    size_t len = 0; // < input key length for qsee_kdf
    size_t label_len = 32;
    size_t context_len = 32;

    // NIST vectors, personalised below by XOR-ing the salt into them.
    uint8_t label[] = { 0x24, 0xe0, 0x9c, 0x85, 0x1f, 0x0b, 0xab, 0x38, 0xac, 0xf4,
                        0x51, 0xcb, 0x74, 0x98, 0x4b, 0x17, 0xb8, 0x29, 0x56, 0x3a,
                        0x44, 0x9e, 0xfd, 0x32, 0x12, 0x8d, 0x46, 0x27, 0x6c, 0x4f,
                        0x2f, 0x3a };
    uint8_t context[] = { 0x1a, 0x20, 0xc4, 0xbe, 0xe1, 0x74, 0xc0, 0x05, 0x11, 0x78,
                          0x5a, 0x79, 0xdf, 0x33, 0xbf, 0x89, 0x1a, 0xbe, 0x3a, 0x1c,
                          0x09, 0x81, 0x59, 0xc5, 0xcc, 0xf4, 0xb1, 0xd6, 0xaf, 0xd0,
                          0x59, 0x86 };

    for (unsigned int i = 0; ((i < label_len) && (i < salt_len)); ++i)
    {
        label[i] ^= salt[i];
    }

    for (unsigned int i = 0; ((i < context_len) && (i < salt_len)); ++i)
    {
        context[i] ^= salt[i];
    }

    // Undocumented Qualcomm limitation on the key_length argument (second parameter):
    // on the SM-R730T_NA_TMB device, qsee_kdf() must not be given more than 32 bytes.
    if (salt_len > 32)
    {
        len = 32;
    }
    else
    {
        len = (size_t)salt_len;
    }

    int32_t     res = qsee_kdf(salt,
                               len,
                               label,
                               label_len,
                               context,
                               context_len,
                               key,
                               key_len);
    if (res < 0)
    {
        // Do not hand back partially derived key material.
        _secure_memset(key, key_len);
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] key derive error %x <<\n", __func__, res);
        return WSM_RET_E_PBKDF;
    }

    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (const char *)key, key_len);

#elif defined(GP_API) && defined(TA_BUILD)
    TEE_Result          res = TEE_ERROR_GENERIC;
    TEE_ObjectHandle    obj_hndl = TEE_HANDLE_NULL;
    // TEE_GetObjectBufferAttribute() rewrites key_len (on a short buffer it may report a
    // larger required size), so remember the real buffer size for wiping on failure.
    const uint32_t      key_buf_len = key_len;

    res = TEE_AllocateTransientObject(TEE_TYPE_GENERIC_SECRET, // object type
                                      BITS(key_len), // max object size
                                      &obj_hndl); // ptr to obj hndl
    if (TEE_SUCCESS != res)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] TEE_AllocateTransientObject error %x <<\n",
                  __func__, res);
        return WSM_RET_E_PBKDF;
    }

    res = TEES_DeriveKeyKDF(salt, // label
                            salt_len, // label_len
                            key, // context
                            key_len, // context len in bytes
                            key_len, // KDF key leng required in bytes
                            obj_hndl); // obj hndl
    if (TEE_SUCCESS != res)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_TA_DERIVE_KEY, "[%s] key derive error %x <<\n", __func__, res);
        goto free_obj_hndl;
    }

    res = TEE_GetObjectBufferAttribute(obj_hndl, // obj hndl
                                       TEE_ATTR_SECRET_VALUE, // attribute ID
                                       key, // out buff
                                       &key_len); // out buff len
    if (TEE_SUCCESS != res)
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_TA_GET_OBJECT_ATTRIBUTE,
                  "[%s] TEE_GetObjectBufferAttribute error %x <<\n", __func__, res);
        goto free_obj_hndl;
    }

free_obj_hndl:

    TEE_ResetTransientObject(obj_hndl);
    TEE_CloseObject(obj_hndl);

    if (TEE_SUCCESS != res)
    {
        // Never return or log partially derived key material; wipe the whole caller buffer.
        _secure_memset(key, key_buf_len);
        WSM_LOG_E(LOG_TAG, WSM_RET_E_PBKDF, "[%s] key derive error %x <<\n", __func__, res);
        return WSM_RET_E_PBKDF;
    }

    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (const char *)key, key_len);
#else

    _secure_memset(key, key_len);

    #ifndef WSM_FUZZING
    WSM_LOG_E(LOG_TAG, WSM_RET_E_FAKED_DERIVE_KEY, "[%s] key derive is faked \n", __func__);

    #else
    WSM_LOG(err_level_warning, LOG_TAG, "[%s] key derive is faked \n", __func__);
    WSM_LOG_HEX(err_level_debug, LOG_TAG, "Derived key :", (char *)key, key_len);

    #endif /* ifdef FUZZING */

#endif /* if defined(MOBICORE) && defined(TA_BUILD) */

    WSM_LOG(err_level_info, LOG_TAG, "[%s] Exit >>\n", __func__);

    return WSM_RET_SUC;
}

/**
 * \brief Decrypt a secure object produced by wrap_so2().
 *
 * Same layout and steps as unwrap_so(), but the key is re-derived with custom_so_derive_key2().
 * - IV: the trailing IV_KEY_SIZE bytes of \p src.
 * - Salt: IV || id, zero-padded to a multiple of AES_BLOCK_SIZE.
 * - Ciphertext: the remaining srcSize - IV_KEY_SIZE bytes, AES-128-GCM decrypted into \p dst.
 *
 * \param src     Wrapped object: encrypt output followed by the IV.
 * \param srcSize Total length of \p src; must be at least IV_KEY_SIZE.
 * \param dst     Output buffer for the plaintext.
 * \param dstSize [out] Value reported by the decrypt call (set even if decryption fails).
 * \param id      Object identifier used when wrapping.
 * \param id_len  Length of \p id; must not exceed WSM_MAX_ID_LENGTH.
 * \return WSM_RET_SUC, WSM_RET_E_LENGTH_MISMATCH for bad arguments, or the KDF/cipher error.
 */
return_t unwrap_so2(const uint8_t   *src, uint32_t        srcSize, uint8_t         *dst,
                    uint32_t        *dstSize, const uint8_t   *id, uint32_t        id_len)
{
    uint8_t     key[IV_KEY_SIZE] = { 0 };
    // IV + ID, plus one block of zero slack so rounding buff_len up to a multiple of
    // AES_BLOCK_SIZE can never make the KDF read, or the wipe write, past the array.
    // The extra bytes stay zero, so derived keys are unchanged.
    uint8_t     buff[WSM_MAX_ID_LENGTH + IV_KEY_SIZE + AES_BLOCK_SIZE] = { 0 };
    uint32_t    buff_len = 0;
    return_t    ret = WSM_RET_SUC;

    WSM_LOG(err_level_info, LOG_TAG, "[%s] Entry \n", __func__);

    if ((src == 0) ||
        (dst == 0) ||
        (dstSize == 0) ||
        (id == 0) ||
        (id_len > WSM_MAX_ID_LENGTH) ||
        (srcSize < IV_KEY_SIZE))
    {
        WSM_LOG_E(LOG_TAG, WSM_RET_E_LENGTH_MISMATCH, "[%s] Input params are bad %p %p %p %p %d\n",
                  __func__,
                  src,
                  dst,
                  dstSize,
                  id,
                  id_len);
        return WSM_RET_E_LENGTH_MISMATCH;
    }

    // Cut aesGcmIV from Secure Object end and put it into aesGcmIV vector.
    memcpy(buff, src + srcSize - IV_KEY_SIZE, IV_KEY_SIZE);
    buff_len += IV_KEY_SIZE;

    // Concatenate id and aesGcmIV for more secure key generation.
    memcpy(&buff[buff_len], id, id_len);
    buff_len += id_len;

    // Round the salt length up to a whole number of AES blocks (zero padded), as wrap_so2() does.
    buff_len = ((buff_len + (AES_BLOCK_SIZE - 1)) / AES_BLOCK_SIZE) * AES_BLOCK_SIZE;

    ret = custom_so_derive_key2(key,
                                sizeof(key),
                                buff,
                                buff_len);
    if (ret != WSM_RET_SUC)
    {
        _secure_memset(key, sizeof(key));
        _secure_memset(buff, sizeof(buff));
        return ret;
    }

    size_t dstSize_local = srcSize - IV_KEY_SIZE;
    ret = WSMv2_CRYPTO_aes128_gcm_decrypt(src, dst, &dstSize_local, buff, key, AES_KEY_SIZE_128,
                                          NULL, 0);
    *dstSize = (__typeof__(*dstSize))dstSize_local;

    _secure_memset(key, sizeof(key));
    _secure_memset(buff, sizeof(buff));

    return ret;
}
