#include "knoxai_prov.h"
#include "knoxai_drk.h"
#include "knoxai_rsa.h"
#include "knoxai_logger.h"
#include "knoxai_vendor_utils.h"

#include <openssl/evp.h>
#include <openssl/digest.h>

#include <tee_internal_api.h>


/*
 * Scratch output buffer for knoxai_wrap() when knoxai_get_device_certificates()
 * re-wraps a drk_parsed_object_t. It is sized for the object plus
 * SO_LOCAL_HEADERSIZE plus 16 spare bytes (presumably wrapping overhead/slack —
 * confirm against knoxai_wrap()).
 * so_buffer_len is passed by pointer to knoxai_wrap() and equals the buffer's
 * full capacity when idle.
 */
static uint8_t so_buffer[sizeof(drk_parsed_object_t) + SO_LOCAL_HEADERSIZE + 16];
static uint32_t so_buffer_len = sizeof(so_buffer);

/*
 * knoxai_get_device_certificates
 *
 * Unwraps the provisioning blob from the request (bound to PROV_UUID) and
 * parses it into a drk_parsed_object_t with get_cert_chain_rsakey(). It then
 * runs convert_der_to_b64() over the certificate chain (presumably DER ->
 * base64 in place; confirm) and fills the response with:
 *   - plain_CERT_LEAF : drk_cert_chain[1]
 *   - plain_CERT_ROOT : drk_cert_chain[0]
 *   - wrapped_SVC_key : the whole parsed object re-wrapped for this TA
 *                       (SELF_UUID). knoxai_set_prov_data() later unwraps
 *                       this back into a drk_parsed_object_t.
 *
 * @param sendmsg  request; wrapped_CERT_DRK holds the wrapped blob.
 * @param respmsg  response; its three output buffers must be non-NULL.
 * @return KNOXAI_SUCCESS on success. Otherwise KNOXAI_FAILURE, or the error
 *         code returned by the helper that failed.
 *
 * The file-scope so_buffer is used as scratch and is cleared before return.
 * The unwrapped blob and the parsed object (which contains the RSA private
 * key) are also zeroed before return.
 */
tz_knoxai_return_type knoxai_get_device_certificates(knoxai_get_device_certificates_payload_t *sendmsg, knoxai_get_device_certificates_payload_t *respmsg) {
    uint32_t ret                   = KNOXAI_FAILURE;
    uint32_t blob_len              = sendmsg->content.knoxai_req.wrapped_CERT_DRK.wrapped_key_len;
    uint8_t* blob_ptr              = sendmsg->content.knoxai_req.wrapped_CERT_DRK.wrapped_key_buf;
    uint8_t* blob_buffer_ptr       = NULL;
    uint32_t blob_buffer_len       = 0;
    uint32_t unwrapleaf_object_len = respmsg->content.knoxai_rsp.plain_CERT_LEAF.plain_key_len;
    uint8_t* unwrapleaf_object     = respmsg->content.knoxai_rsp.plain_CERT_LEAF.plain_key_buf;
    uint32_t unwraproot_object_len = respmsg->content.knoxai_rsp.plain_CERT_ROOT.plain_key_len;
    uint8_t* unwraproot_object     = respmsg->content.knoxai_rsp.plain_CERT_ROOT.plain_key_buf;
    uint32_t wrapsvc_object_len    = respmsg->content.knoxai_rsp.wrapped_SVC_key.wrapped_key_len;
    uint8_t* wrapsvc_object        = respmsg->content.knoxai_rsp.wrapped_SVC_key.wrapped_key_buf;
    drk_parsed_object_t drk_data;
    drk_parsed_object_t *pDrk_data = &drk_data;

    if ( blob_ptr == NULL || unwrapleaf_object == NULL || unwraproot_object == NULL || wrapsvc_object == NULL
        || blob_len > MAX_WRAPPED_KEY_LEN || unwrapleaf_object_len > MAX_WRAPPED_KEY_LEN
        || unwraproot_object_len > MAX_WRAPPED_KEY_LEN || wrapsvc_object_len > MAX_WRAPPED_KEY_LEN ) {
        KNOXAI_LOG("rewrapProvData() input value is null or error\n");
        return ret;
    }
    blob_buffer_ptr = TEE_Malloc(MAX_WRAPPED_KEY_LEN, 0);
    if (blob_buffer_ptr == NULL) {
        ret = KNOXAI_FAILURE;
        KNOXAI_LOG("Error to alloc memory");
        goto exit;
    }
    blob_buffer_len = MAX_WRAPPED_KEY_LEN;
    ret = knoxai_unwrap((TEE_UUID)PROV_UUID, blob_ptr, blob_len, blob_buffer_ptr, &blob_buffer_len);
    if (ret != TEE_SUCCESS) {
        KNOXAI_LOG("rewrapProvData unwrap failed with ret=0x%08X, exit", ret);
        goto exit;
    }

    TEE_MemFill(pDrk_data, 0, sizeof(drk_parsed_object_t));
    ret = get_cert_chain_rsakey(blob_buffer_ptr, blob_buffer_len, pDrk_data);
    if (ret != TEE_SUCCESS) {
        /* Continuing here would export an empty or partially parsed chain. */
        KNOXAI_LOG("get_cert_chain_rsakey FAIL ret = %d", ret);
        goto exit;
    }
    ret = convert_der_to_b64(pDrk_data->drk_cert_chain, pDrk_data->num_certificates);
    if (ret != TEE_SUCCESS) {
        KNOXAI_LOG("convert_der_to_b64 FAIL ret = %d", ret);
        goto exit;
    }

    /* Keep the parsed object, wrapped for this TA, so the caller can hand it back later. */
    ret = knoxai_wrap((TEE_UUID)SELF_UUID, (uint8_t*)pDrk_data, sizeof(drk_parsed_object_t), so_buffer, &so_buffer_len);
    if (ret != TEE_SUCCESS) {
        KNOXAI_LOG("rewrapProvData wrap failed with ret=0x%08X, exit", ret);
        goto exit;
    }
    /* Cast: size_t is not unsigned long on every TEE ABI. */
    KNOXAI_DEBUG_LOG("rewrap %lu, %lu, %u", (unsigned long)sizeof(so_buffer),
                     (unsigned long)sizeof(drk_parsed_object_t), so_buffer_len);

    unwrapleaf_object_len = pDrk_data->drk_cert_chain[1].certificate_len;
    unwraproot_object_len = pDrk_data->drk_cert_chain[0].certificate_len;

    /*
     * Bound every copy into respmsg. The lengths come from parsed or wrapped
     * data, not from the caller.
     * NOTE(review): assumes each response buffer holds MAX_WRAPPED_KEY_LEN
     * bytes, the same bound the input check above applies. This also matches
     * knoxai_set_prov_data(), which rejects a wrapped_SVC_key longer than
     * MAX_WRAPPED_KEY_LEN.
     */
    if (unwrapleaf_object_len > MAX_WRAPPED_KEY_LEN
        || unwraproot_object_len > MAX_WRAPPED_KEY_LEN
        || so_buffer_len > MAX_WRAPPED_KEY_LEN) {
        KNOXAI_LOG("output too large: leaf=%u root=%u svc=%u",
                   unwrapleaf_object_len, unwraproot_object_len, so_buffer_len);
        ret = KNOXAI_FAILURE;
        goto exit;
    }

    respmsg->content.knoxai_rsp.plain_CERT_LEAF.plain_key_len = unwrapleaf_object_len;
    TEE_MemMove(unwrapleaf_object, pDrk_data->drk_cert_chain[1].certificate, unwrapleaf_object_len);
    respmsg->content.knoxai_rsp.plain_CERT_ROOT.plain_key_len = unwraproot_object_len;
    TEE_MemMove(unwraproot_object, pDrk_data->drk_cert_chain[0].certificate, unwraproot_object_len);
    respmsg->content.knoxai_rsp.wrapped_SVC_key.wrapped_key_len = so_buffer_len;
    TEE_MemMove(wrapsvc_object, so_buffer, so_buffer_len);

    ret = KNOXAI_SUCCESS;
    respmsg->content.knoxai_rsp.result_code = ret;
exit:
    /* Scrub key material: the unwrapped blob, the parsed object and the wrap scratch. */
    if ( blob_buffer_ptr != NULL) {
        TEE_MemFill(blob_buffer_ptr, 0, MAX_WRAPPED_KEY_LEN);
        TEE_Free(blob_buffer_ptr);
        blob_buffer_ptr = NULL;
    }
    TEE_MemFill(&drk_data, 0, sizeof(drk_parsed_object_t));
    pDrk_data = NULL;
    so_buffer_len = sizeof(so_buffer);
    TEE_MemFill(so_buffer, 0, so_buffer_len);
    return ret;
}

/*
 * knoxai_set_prov_data
 *
 * Derives the provisioning secret key and returns it wrapped for this TA.
 *   1. Unwrap the request's wrapped_SVC_key (SELF_UUID) back into a
 *      drk_parsed_object_t. This is the object knoxai_get_device_certificates()
 *      wraps.
 *   2. Decrypt encrypted_PROV_key with knoxai_rsa_decrypt(), using the RSA
 *      private key stored in that object.
 *   3. Split the plaintext as kek(KNOXAI_FAC_KEY_LEN) | salt(KNOXAI_FAC_KEY_LEN)
 *      and derive sk = PBKDF2-HMAC-SHA256(kek, salt, 2 iterations).
 *   4. Wrap sk with SELF_UUID and store it in respmsg's wrapped_PROV_key.
 *
 * @param sendmsg  request with encrypted_PROV_key and wrapped_SVC_key.
 * @param respmsg  response; wrapped_PROV_key receives the wrapped sk.
 * @return KNOXAI_SUCCESS on success. Otherwise KNOXAI_FAILURE, or the error
 *         code returned by the helper that failed.
 *
 * sk, kek, salt, the decrypted plaintext and the parsed DRK object are zeroed
 * before return.
 */
tz_knoxai_return_type knoxai_set_prov_data(knoxai_set_provision_payload_t *sendmsg, knoxai_set_provision_payload_t *respmsg) {
    uint32_t ret                   = KNOXAI_FAILURE;
    uint32_t blob_len              = sendmsg->content.knoxai_req.encrypted_PROV_key.wrapped_key_len;
    uint8_t* blob_ptr              = sendmsg->content.knoxai_req.encrypted_PROV_key.wrapped_key_buf;
    uint32_t wrapsvc_object_len    = sendmsg->content.knoxai_req.wrapped_SVC_key.wrapped_key_len;
    uint8_t* wrapsvc_object        = sendmsg->content.knoxai_req.wrapped_SVC_key.wrapped_key_buf;
    uint8_t  wrap_object[REAL_WRAPPED_KEY_LEN];
    uint32_t wrap_object_len       = REAL_WRAPPED_KEY_LEN;
    uint32_t decrypted_object_len  = 0;
    uint8_t* decrypted_object      = NULL;
    drk_parsed_object_t drk_data;
    drk_parsed_object_t *pDrk_data = &drk_data;
    uint32_t sk_len                = 0;
    uint8_t  sk[KNOXAI_FAC_KEY_LEN];
    uint32_t kek_len               = 0;
    uint8_t  kek[KNOXAI_FAC_KEY_LEN];
    uint32_t salt_len              = 0;
    uint8_t  salt[SALT_32_BYTES_LEN];
    int      ref_next              = 0;
    uint32_t local_so_len          = sizeof(drk_parsed_object_t);

    if ( blob_ptr == NULL || wrapsvc_object == NULL
         || blob_len > MAX_WRAPPED_KEY_LEN || wrapsvc_object_len > MAX_WRAPPED_KEY_LEN ) {
        KNOXAI_LOG("getProvData() input value is null or error");
        return ret;
    }

    /* Restore the parsed DRK object (certificate chain + RSA private key). */
    ret = knoxai_unwrap((TEE_UUID)SELF_UUID, wrapsvc_object, wrapsvc_object_len, (uint8_t*)pDrk_data, &local_so_len);
    if (ret != TEE_SUCCESS || local_so_len != sizeof(drk_parsed_object_t)) {
        KNOXAI_DEBUG_LOG("rewrapProvData unwrap failed with ret=0x%08X, exit", ret);
        /* A successful unwrap with the wrong size must still fail the call. */
        if (ret == TEE_SUCCESS) {
            ret = KNOXAI_FAILURE;
        }
        goto exit;
    }

    /* Decrypt the provisioning blob with the DRK private key. */
    decrypted_object = TEE_Malloc(MAX_WRAPPED_KEY_LEN, 0);
    if ( decrypted_object == NULL ) {
        KNOXAI_LOG("malloc error");
        ret = KNOXAI_FAILURE;   /* ret still holds the unwrap success code here */
        goto exit;
    }
    decrypted_object_len = MAX_WRAPPED_KEY_LEN;
    ret = knoxai_rsa_decrypt(pDrk_data->drk_rsa_private_key.modulus, pDrk_data->drk_rsa_private_key.modulus_len,
         pDrk_data->drk_rsa_private_key.pub_expo, pDrk_data->drk_rsa_private_key.pub_expo_len,
         pDrk_data->drk_rsa_private_key.priv_expo, pDrk_data->drk_rsa_private_key.priv_expo_len,
         blob_ptr, blob_len, decrypted_object, &decrypted_object_len);
    if ( ret != KNOXAI_SUCCESS ) {
        KNOXAI_LOG("dec error, %d", ret);
        goto exit;
    }

    /* Plaintext layout: kek(KNOXAI_FAC_KEY_LEN) | salt(KNOXAI_FAC_KEY_LEN). */
    if ( decrypted_object_len < KNOXAI_FAC_KEY_LEN*2 ) {
        KNOXAI_LOG("message is wrong %d", decrypted_object_len);
        ret = KNOXAI_FAILURE;   /* ret still holds the decrypt success code here */
        goto exit;
    }
    /* salt[] is SALT_32_BYTES_LEN bytes but receives KNOXAI_FAC_KEY_LEN bytes below. */
    if ( KNOXAI_FAC_KEY_LEN > SALT_32_BYTES_LEN ) {
        KNOXAI_LOG("salt buffer too small");
        ret = KNOXAI_FAILURE;
        goto exit;
    }
    ref_next = 0;
    kek_len = KNOXAI_FAC_KEY_LEN;
    TEE_MemMove(kek, decrypted_object + ref_next, kek_len);
    ref_next += kek_len;
    salt_len = KNOXAI_FAC_KEY_LEN;
    TEE_MemMove(salt, decrypted_object + ref_next, salt_len);

    /*
     * Derive sk. The iteration count (2) and digest are part of the derivation
     * result; changing either changes sk (NOTE(review): presumably agreed with
     * the provisioning side; confirm before altering).
     * sk_len is passed by value, so PBKDF2 always produces exactly
     * KNOXAI_FAC_KEY_LEN bytes on success.
     */
    sk_len = KNOXAI_FAC_KEY_LEN;
    if(!PKCS5_PBKDF2_HMAC((const char *)kek, kek_len, (const unsigned char *)salt, salt_len,
                            2, EVP_sha256(), sk_len, sk)) {
        KNOXAI_LOG("PKCS5_PBKDF2_HMAC(pwd, salt, out) failed");
        ret = KNOXAI_FAILURE;
        goto exit;
    }
#if DEBUG_KNOXAI
    { // printout public x,y
        uint32_t getPublicFromPrivate(uint8_t *pri, uint32_t key_len, uint8_t *out_x, uint8_t *out_y);
        uint8_t             pubPairX[KNOXAI_FAC_KEY_LEN] = {0};
        uint8_t             pubPairY[KNOXAI_FAC_KEY_LEN] = {0};
        getPublicFromPrivate(sk, sk_len, pubPairX, pubPairY);
        KNOXAI_DBG_DUMP("public x", pubPairX, sk_len);
        KNOXAI_DBG_DUMP("public y", pubPairY, sk_len);
    }
#endif

    /* Return sk wrapped for this TA; the plaintext key never leaves the TEE. */
    ret = knoxai_wrap((TEE_UUID)SELF_UUID, sk, sk_len, wrap_object, &wrap_object_len);
    if ( ret != KNOXAI_SUCCESS ) {
        KNOXAI_LOG("wrap error, %d", ret);
        goto exit;
    }
    KNOXAI_DBG_DUMP("wrap", wrap_object, (wrap_object_len>384? 384: wrap_object_len));
    TEE_MemMove(respmsg->content.knoxai_rsp.wrapped_PROV_key.wrapped_key_buf, wrap_object, wrap_object_len);
    respmsg->content.knoxai_rsp.wrapped_PROV_key.wrapped_key_len = wrap_object_len;
exit:
    /* Scrub every intermediate secret before returning. */
    TEE_MemFill(sk, 0, KNOXAI_FAC_KEY_LEN);
    TEE_MemFill(kek, 0, KNOXAI_FAC_KEY_LEN);
    TEE_MemFill(salt, 0, SALT_32_BYTES_LEN);
    if ( decrypted_object != NULL) {
        TEE_MemFill(decrypted_object, 0, MAX_WRAPPED_KEY_LEN);
        TEE_Free(decrypted_object);
        decrypted_object = NULL;
    }
    TEE_MemFill(&drk_data, 0, sizeof(drk_parsed_object_t));
    pDrk_data = NULL;
    return ret;
}
