#include "kg_policy_cmd.h"
#include "kg_read_data.h"

/*
 * KG_get_policy - return the stored policy file to the caller.
 *
 * Reads the info object to learn the stored policy length, reads the policy
 * file itself and copies its bytes into respmsg->payload.resp.
 *
 * @sendmsg: request payload; not read by this command.
 * @respmsg: response payload; policy_len and policy_buf are filled on
 *           success. policy_buf is assumed to hold at least
 *           KG_POLICY_LEN_MAX bytes (TODO confirm against the payload type).
 *
 * Returns KG_SUCCESS, or the first error code encountered. A read helper
 * that reports success but hands back no buffer is treated as a failure.
 * Buffers handed back by the read helpers are freed before returning.
 */
uint32_t KG_get_policy(tz_get_policy_payload_t *sendmsg, tz_get_policy_payload_t *respmsg) {
    KG_LOG("KG get policy\n");
    uint32_t ret = KG_SUCCESS;
    kg_rpmb_info_t *info = NULL;
    uint8_t *policy_file = NULL;
    /*
     * kg_rpmb_info has packed members, so &info->policy_file_len can be a
     * misaligned pointer (clang -Waddress-of-packed-member). Always pass an
     * aligned local copy to read_policy_file, on every platform.
     */
    uint32_t policy_file_len = 0;

    (void)sendmsg; /* not used by this command */

#ifdef CONFIG_QSEE
    if (KG_SUCCESS != (ret = kg_rpmb_init())) {
        KG_LOG("KG rpmb is unavailable when reading policy\n");
        goto exit;
    }
#endif

    ret = read_info_object(&info);
    if (KG_SUCCESS == ret && NULL == info) {
        /* Success without an object must not be reported as success. */
        ret = KG_ALLOC_BUFFER_FAIL;
    }
    if (KG_SUCCESS != ret) {
        KG_LOG("failed to read info object\n");
        goto exit;
    }

    policy_file_len = info->policy_file_len;
    if (policy_file_len > KG_POLICY_LEN_MAX) {
        KG_LOG("Received invalid size for policy_file\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    ret = read_policy_file(&policy_file_len, &policy_file);
    if (KG_SUCCESS == ret && NULL == policy_file) {
        ret = KG_ALLOC_BUFFER_FAIL;
    }
    if (KG_SUCCESS != ret) {
        KG_LOG("failed to read policy file\n");
        goto exit;
    }

    /*
     * read_policy_file receives the length by pointer and may update it, so
     * re-check it before it sizes the copy into the response buffer.
     */
    if (policy_file_len > KG_POLICY_LEN_MAX) {
        KG_LOG("Received invalid size for policy_file\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    // TODO, return policy file ONLY in clear text format
    /* The policy bytes are length-delimited, not NUL-terminated: bound the print. */
    KG_LOG_DBG("policy content is %.*s\n", (int)policy_file_len, (const char *)policy_file);

    respmsg->payload.resp.policy_len = policy_file_len;
    TEE_MemMove(respmsg->payload.resp.policy_buf, policy_file, policy_file_len);

exit:
    if (info != NULL) {
        TEE_Free(info);
        info = NULL;
    }
    if (policy_file != NULL) {
        TEE_Free(policy_file);
        policy_file = NULL;
    }
    return ret;
}

/*
 * KG_verify_policy - verify a signed policy and persist it.
 *
 * Flow:
 *   1. Base64-decode the policy and its signature from sendmsg.
 *   2. Read the info object and the wrapped kg secure data, then unwrap it.
 *   3. Select the policy-signing certificate by the region stored in the
 *      secure data (REGION_EU / REGION_US) and extract its public key.
 *   4. Check the signature with TZ_verify_CKM_SHA256_RSA_PKCS.
 *   5. Parse the policy into the secure metadata, re-wrap the secure data,
 *      and write back the wrap data, the policy file and the info object.
 *   6. Copy the decoded policy into respmsg.
 *
 * @sendmsg: request; policy_buf/policy_len and policy_sign_buf/policy_sign_len
 *           carry base64 text, bounded by KG_POLICY_LEN_MAX and
 *           KG_POLICY_SIGN_MAX respectively.
 * @respmsg: response; policy_buf (cleared for KG_POLICY_LEN_MAX bytes) and
 *           policy_len receive the decoded policy on success.
 *
 * Returns KG_SUCCESS only after the signature verified and every write
 * succeeded; otherwise the first error code encountered. A read helper that
 * reports success but hands back no buffer is treated as a failure.
 */
uint32_t KG_verify_policy(tz_verify_policy_payload_t *sendmsg, tz_verify_policy_payload_t *respmsg) {
    KG_LOG("[TRACE] KG_verify_policy start\n");
    uint32_t ret = KG_SUCCESS;
    uint8_t *policy_buffer = NULL;
    uint32_t policy_buffer_len = KG_POLICY_LEN_MAX;
    uint8_t *policy_sig = NULL;
    uint32_t policy_sig_len = KG_POLICY_SIGN_MAX;
    uint8_t *pub_mod = NULL;
    uint32_t pub_mod_len = KG_BUF_LEN;
    uint8_t *pub_exp = NULL;
    uint32_t pub_exp_len = KG_BUF_LEN;

    kg_rpmb_info_t *info = NULL;

    uint8_t *wrap_data = NULL;
    /*
     * kg_rpmb_info has packed members, so &info->kg_wrap_data_len can be a
     * misaligned pointer (clang -Waddress-of-packed-member). Always pass an
     * aligned local copy to read_wrap_data, on every platform.
     */
    uint32_t wrap_data_len = 0;

    /* Plaintext secure data: the whole allocation is zeroized on exit. */
    uint8_t *unwrap_data = NULL;
    const uint32_t unwrap_data_cap = KG_BUF_LEN;
    uint32_t unwrap_data_len = unwrap_data_cap;
    kg_secure_data_t *secure_data = NULL;
    bool verified = false;

    uint8_t *rewrap_data = NULL;
    uint32_t rewrap_data_len = KG_SECURE_DATA_LEN;

    if (sendmsg->payload.cmd.policy_len > KG_POLICY_LEN_MAX) {
        KG_LOG_DBG("Received invalid policy buffer\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    if (sendmsg->payload.cmd.policy_sign_len > KG_POLICY_SIGN_MAX) {
        KG_LOG_DBG("Received invalid policy signature buffer\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    policy_buffer = TEE_Malloc(policy_buffer_len, 0);
    if (NULL == policy_buffer) {
        KG_LOG("Failed to alloc buffer for policy b64\n");
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

    policy_sig = TEE_Malloc(policy_sig_len, 0);
    if (NULL == policy_sig) {
        KG_LOG("Failed to alloc buffer for policy sig b64\n");
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

    pub_mod = TEE_Malloc(pub_mod_len, 0);
    if (NULL == pub_mod) {
        KG_LOG("Failed to alloc buffer for pub_mod\n");
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

    pub_exp = TEE_Malloc(pub_exp_len, 0);
    if (NULL == pub_exp) {
        KG_LOG("Failed to alloc buffer for pub_exp\n");
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

    /* Input buffers come from the caller and need not be NUL-terminated. */
    KG_LOG_DBG("Input policy len %u\n", (unsigned int)sendmsg->payload.cmd.policy_len);
    KG_LOG_DBG("Input policy content %.*s\n", (int)sendmsg->payload.cmd.policy_len,
               (const char *)sendmsg->payload.cmd.policy_buf);
    if (BASE64_OK != base64_decode(sendmsg->payload.cmd.policy_buf,
            sendmsg->payload.cmd.policy_len, policy_buffer, &policy_buffer_len)) {
        KG_LOG("Failed to decode base64 encoded input policy buffer\n");
        ret = KG_BASE64_DECODE_FAIL;
        goto exit;
    }
    if (policy_buffer_len > KG_POLICY_LEN_MAX) {
        KG_LOG("Decoded policy buffer length check failed\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    KG_LOG_DBG("Input policy signature len %u\n", (unsigned int)sendmsg->payload.cmd.policy_sign_len);
    KG_LOG_DBG("Input policy signature content %.*s\n", (int)sendmsg->payload.cmd.policy_sign_len,
               (const char *)sendmsg->payload.cmd.policy_sign_buf);
    if (BASE64_OK != base64_decode(sendmsg->payload.cmd.policy_sign_buf,
            sendmsg->payload.cmd.policy_sign_len, policy_sig, &policy_sig_len)) {
        KG_LOG("Failed to decode base64 encoded input policy signature\n");
        ret = KG_BASE64_DECODE_FAIL;
        goto exit;
    }
    if (policy_sig_len != KG_SIG_LEN) {
        KG_LOG("Decoded policy signature buffer length check failed\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

#ifdef CONFIG_QSEE
    if (KG_SUCCESS != (ret = kg_rpmb_init())) {
        KG_LOG("KG rpmb is unavailable when verifying policy\n");
        goto exit;
    }
#endif

    ret = read_info_object(&info);
    if (KG_SUCCESS == ret && NULL == info) {
        /* Never let a missing object turn into a "verified" result. */
        ret = KG_ALLOC_BUFFER_FAIL;
    }
    if (KG_SUCCESS != ret) {
        KG_LOG("failed to read info object\n");
        goto exit;
    }

    wrap_data_len = info->kg_wrap_data_len;
    if (wrap_data_len > KG_SECURE_DATA_LEN) {
        KG_LOG("Received invalid size for wrap_data\n");
        ret = KG_BUFFER_SIZE_FAIL;
        goto exit;
    }

    ret = read_wrap_data(&wrap_data_len, &wrap_data);
    if (KG_SUCCESS == ret && NULL == wrap_data) {
        ret = KG_ALLOC_BUFFER_FAIL;
    }
    if (KG_SUCCESS != ret) {
        KG_LOG("failed to read wrap data\n");
        goto exit;
    }

    unwrap_data = TEE_Malloc(unwrap_data_cap, 0);
    if (NULL == unwrap_data) {
        KG_LOG("failed to alloc unwrap data\n");
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

#ifdef CONFIG_QSEE
    if (TZ_API_OK != TZ_unwrap_persist_data((uint8_t *)KG_NAME, strlen(KG_NAME),
        wrap_data, wrap_data_len, unwrap_data, &unwrap_data_len)) {
        KG_LOG("Failed to unwrap kg metadata structure on QC\n");
        ret = KG_TZ_API_FAIL;
        goto exit;
    }
#else
    if (TZ_API_OK != TZ_unwrap_data_with_derived_key((uint8_t *)KG_NAME, strlen(KG_NAME),
        wrap_data, wrap_data_len, unwrap_data, &unwrap_data_len)) {
        KG_LOG("failed to unwrap data\n");
        ret = KG_TZ_API_FAIL;
        goto exit;
    }
#endif

    if (unwrap_data_len != sizeof(kg_secure_data_t)) {
        KG_LOG("KG TA recovering unwrapped secure data size check failed\n");
        ret = KG_RPMB_UNWRAP_FAIL;
        goto exit;
    }

    secure_data = (kg_secure_data_t *)unwrap_data;
    if (secure_data->kg_metadata.reg_info == REGION_EU) {
        if (KG_SUCCESS != extract_public_keybytes(policy_sign_cert_EU, pub_mod, &pub_mod_len, pub_exp, &pub_exp_len)) {
            KG_LOG_DBG("Failed to extract public key bytes from policy signing cert EU\n");
            ret = KG_CRYPTO_PKEY_PARSE_FAIL;
            goto exit;
        }
    } else if (secure_data->kg_metadata.reg_info == REGION_US) {
        if (KG_SUCCESS != extract_public_keybytes(policy_sign_cert_US, pub_mod, &pub_mod_len, pub_exp, &pub_exp_len)) {
            KG_LOG_DBG("Failed to extract public key bytes from policy signing cert US\n");
            ret = KG_CRYPTO_PKEY_PARSE_FAIL;
            goto exit;
        }
    } else {
        KG_LOG("target region is unknown %d\n", secure_data->kg_metadata.reg_info);
        ret = KG_REGION_NOT_CONFIG;
        goto exit;
    }

    KG_DUMP_DBG("dump pubkey n:\n", pub_mod, pub_mod_len);
    KG_DUMP_DBG("dump pubkey e:\n", pub_exp, pub_exp_len);
    if (TZ_API_OK != TZ_verify_CKM_SHA256_RSA_PKCS(pub_mod, pub_mod_len, pub_exp, pub_exp_len,
            policy_buffer, policy_buffer_len, policy_sig, KG_SIG_LEN, &verified)) {
        KG_LOG_DBG("TZ API failed when verifying policy signature\n");
        ret = KG_TZ_API_FAIL;
        goto exit;
    }
    if (false == verified) {
        KG_LOG_DBG("Failed to verify policy signature\n");
        ret = KG_POLICY_VERIFY_FAIL;
        goto exit;
    }
    KG_LOG("Policy signature has been verified successfully\n");

    // TODO: Update with end-to-end test cases
    if (KG_SUCCESS != parse_policy((char *)policy_buffer, policy_buffer_len, &(secure_data->kg_metadata))) {
        KG_LOG_DBG("Failed to verify the policy\n");
        ret = KG_POLICY_VERIFY_FAIL;
        goto exit;
    }

    // save kg_wrap_data
    rewrap_data = TEE_Malloc(rewrap_data_len, 0);
    if (NULL == rewrap_data) {
        KG_LOG("KG TA failed to alloc buffer to hold wrap data\n");
        ret = KG_ALLOC_BUFFER_FAIL;
        goto exit;
    }

#ifdef CONFIG_QSEE
    if (TZ_API_OK != TZ_wrap_persist_data((uint8_t *)KG_NAME, strlen(KG_NAME),
        unwrap_data, sizeof(kg_secure_data_t), rewrap_data, &rewrap_data_len)) {
        KG_LOG("Failed to wrap kg secure data structure in QC\n");
        ret = KG_TZ_API_FAIL;
        goto exit;
    }
#else
    if (TZ_API_OK != TZ_wrap_data_with_derived_key((uint8_t *)KG_NAME, strlen(KG_NAME),
        (uint8_t *)secure_data, sizeof(kg_secure_data_t), rewrap_data, &rewrap_data_len)) {
        KG_LOG("Failed to wrap kg secure data structure\n");
        ret = KG_TZ_API_FAIL;
        goto exit;
    }
#endif

    if (rewrap_data_len > KG_SECURE_DATA_LEN) {
        KG_LOG("Wrapped kg secure data size overflow\n");
        ret = KG_RPMB_WRAP_FAIL;
        goto exit;
    }

    info->kg_wrap_data_len = rewrap_data_len;
    if (KG_SUCCESS != (ret = write_wrap_data(info->kg_wrap_data_len, rewrap_data))) {
        KG_LOG("failed to write wrap data\n");
        goto exit;
    }

    info->policy_file_len = policy_buffer_len;
    if (KG_SUCCESS != (ret = write_policy_file(info->policy_file_len, policy_buffer))) {
        KG_LOG("failed to write policy file\n");
        goto exit;
    }

    if (KG_SUCCESS != (ret = write_info_object(info))) {
        KG_LOG("failed to write info object\n");
        goto exit;
    }

    /* policy_buffer_len was bounded by KG_POLICY_LEN_MAX right after decoding. */
    TEE_MemFill(respmsg->payload.resp.policy_buf, 0, KG_POLICY_LEN_MAX);
    respmsg->payload.resp.policy_len = policy_buffer_len;
    TEE_MemMove(respmsg->payload.resp.policy_buf, policy_buffer, policy_buffer_len);
    ret = KG_SUCCESS;
exit:
    if (policy_sig != NULL) {
        TEE_Free(policy_sig);
        policy_sig = NULL;
    }
    if (policy_buffer != NULL) {
        TEE_Free(policy_buffer);
        policy_buffer = NULL;
    }
    if (pub_exp != NULL) {
        TEE_Free(pub_exp);
        pub_exp = NULL;
    }
    if (pub_mod != NULL) {
        TEE_Free(pub_mod);
        pub_mod = NULL;
    }
    if (info != NULL) {
        TEE_Free(info);
        info = NULL;
    }
    if (wrap_data != NULL) {
        TEE_Free(wrap_data);
        wrap_data = NULL;
    }
    if (unwrap_data != NULL) {
        /*
         * Clear the whole allocation: unwrap_data_len was overwritten by the
         * unwrap call and is not the buffer size.
         */
        TEE_MemFill(unwrap_data, 0, unwrap_data_cap);
        TEE_Free(unwrap_data);
        unwrap_data = NULL;
    }
    if (rewrap_data != NULL) {
        TEE_Free(rewrap_data);
        rewrap_data = NULL;
    }
    KG_LOG("[TRACE] KG_verify_policy end\n");
    return ret;
}
