#include "ta_logger.h"

#include "tzWrappers/TzwMemory.h"

#include "ifaa_tlv_parser.h"
#include "ifaa_mem_utils.h"
#include "ifaa_tlv_common.h"

/*#include "ifaa_log_utils.h"*/
#include <string.h>

/*
 * Release `node` together with every descendant reachable through its
 * child links. Siblings of `node` are NOT released; instead the sibling
 * pointer is returned so a caller can walk (and free) a sibling chain one
 * element at a time.
 *
 * Returns node->sibling, or NULL when `node` is NULL.
 */
tlv_node_t* free_tlv_tree(tlv_node_t* node)
{
    tlv_node_t* next_sibling;
    tlv_node_t* child;

    if (node == NULL) {
        return NULL;
    }

    /* Each recursive call frees one child subtree and hands back the next child. */
    for (child = node->child; child != NULL; ) {
        child = free_tlv_tree(child);
    }

    /* Read the link before the node's memory is released. */
    next_sibling = node->sibling;
    tzwFree(node);

    return next_sibling;
}

/*
 * Depth-first (pre-order) search of the subtree rooted at `node` for the
 * first element whose tag equals `tag`. `node` itself is tested first, then
 * its children in sibling order; siblings of `node` are not searched.
 *
 * A NULL `node` yields NULL, so lookups can be chained safely, e.g.
 * find_node_by_tag(find_node_by_tag(root, A), B).
 *
 * Returns the matching node, or NULL if none is found.
 */
tlv_node_t* find_node_by_tag(tlv_node_t* node, tag_t tag)
{
    if (node == NULL) {
        return NULL;
    }

    if (node->tag == tag) {
        return node;
    }

    for (tlv_node_t* ptr = node->child; ptr != NULL; ptr = ptr->sibling) {
        tlv_node_t* found = find_node_by_tag(ptr, tag);
        if (found != NULL) {
            return found;
        }
    }

    return NULL;
}

/*
 * Classify a tag: a tag with any TAG_LEAF_NODE_BASE bit set is a leaf,
 * every other tag is a directory (a container whose value holds child
 * TLV elements).
 *
 * Returns true for directory tags, false for leaf tags.
 */
bool is_directory(tag_t tag)
{
    return (tag & TAG_LEAF_NODE_BASE) == 0;
}


/*
 * Upper bound on directory nesting accepted from an incoming request.
 * The parser below recurses once per nested directory, and so do
 * free_tlv_tree() and find_node_by_tag() on the resulting tree. Without a
 * bound, a hostile buffer made only of nested directory headers (TLSIZE
 * bytes per level) could drive the recursion thousands of levels deep and
 * exhaust the stack.
 * NOTE(review): 32 is assumed to be well above the nesting of any legitimate
 * IFAA request; confirm against the protocol definition.
 */
enum { TLV_PARSE_MAX_DEPTH = 32 };

/*
 * Recursive worker for parse_request_internal(). `depth` is the nesting
 * level of `node` (0 for the element passed to parse_request_internal).
 */
static int32_t parse_tlv_node_bounded(tlv_node_t* node, uint8_t* buf, uint32_t len, uint32_t depth)
{
    /* The whole element (header + declared value) must fit in what is left. */
    if (len < TLSIZE || len < GET_NODE_SIZE(buf)) {
        return IFAA_ERR_BAD_PARAM;
    }

    node->tag = read16(buf);
    node->length = GET_NODE_DATA_SIZE(buf);
    node->value = GET_NODE_DATA_PTR(buf);

    /* Leaves reference their bytes in place; there is nothing more to parse. */
    if (!is_directory(node->tag)) return GET_NODE_SIZE(buf);

    if (depth >= TLV_PARSE_MAX_DEPTH) {
        LOG_E("tlv nesting exceeds limit");
        return IFAA_ERR_BAD_PARAM;
    }

    /*
     * A directory's value is a run of child elements. An empty directory
     * (length 0) fails in the first child parse with IFAA_ERR_BAD_PARAM.
     * NOTE(review): confirm empty directories are meant to be rejected.
     */
    buf = node->value;
    len = node->length;

    tlv_node_t* ptr = (tlv_node_t*)tzwMalloc(sizeof(tlv_node_t));
    if (ptr == NULL) {
        LOG_E("ptr is NULL");
        return IFAA_ERR_MALLOC_FAILED;
    }
    memset(ptr, 0, sizeof(tlv_node_t));
    /* Linked before parsing so that free_tlv_tree(node) reclaims it on any error. */
    node->child = ptr;

    while (true) {
        int32_t ret = parse_tlv_node_bounded(ptr, buf, len, depth + 1);
        // propagate error
        if (ret < 0) return ret;
        if (len < (uint32_t)ret) return IFAA_ERR_BAD_PARAM;
        if (len == (uint32_t)ret) break;

        buf += ret;
        len -= (uint32_t)ret;
        ptr->sibling = (tlv_node_t*)tzwMalloc(sizeof(tlv_node_t));
        if (ptr->sibling == NULL) {
            LOG_E("ptr->sibling is NULL");
            return IFAA_ERR_MALLOC_FAILED;
        }
        memset(ptr->sibling, 0, sizeof(tlv_node_t));
        ptr = ptr->sibling;
    }

    return (node->length + TLSIZE);
}

/*
 * Parse the TLV element at `buf` (at most `len` bytes) into `node`.
 *
 * - Leaf elements: node->value points into `buf` (no copy is made).
 * - Directory elements: each child is allocated with tzwMalloc, zeroed, and
 *   linked through node->child / ->sibling. Children are parsed recursively,
 *   up to TLV_PARSE_MAX_DEPTH levels of nesting.
 * - node->child / node->sibling are written only for directories, so the
 *   caller should zero `node` beforehand.
 * - On failure, children already allocated stay linked from `node`; release
 *   them with free_tlv_tree(node).
 *
 * Returns the encoded size of the element (TLSIZE header + value) on
 * success. On failure it returns IFAA_ERR_BAD_PARAM (malformed, truncated
 * or too deeply nested input) or IFAA_ERR_MALLOC_FAILED.
 */
int32_t parse_request_internal(tlv_node_t* node, uint8_t* buf, uint32_t len) {
    if (node == NULL || buf == NULL) {
        return IFAA_ERR_BAD_PARAM;
    }
    return parse_tlv_node_bounded(node, buf, len, 0);
}

/*
 * Parse the TLV-encoded request `req` into a freshly allocated tree and run
 * `validator` on it.
 *
 * On a parse failure the partial tree is freed, *node is set to NULL, and
 * IFAA_ERR_BAD_PARAM is returned (also when the failure was an allocation
 * failure inside the parser). Otherwise validator's result is returned, and
 * *node owns the tree whether or not validation passed. The caller must
 * release it with free_tlv_tree().
 *
 * `type` is currently unused.
 */
IFAA_Result parse_request(vlb_t* req, request_type type, tlv_node_t** node, IFAA_Validator validator)
{
    (void)type;

    if (req == NULL || node == NULL) return IFAA_ERR_BAD_PARAM;

    *node = (tlv_node_t*)tzwMalloc(sizeof(tlv_node_t));
    if (*node == NULL) return IFAA_ERR_OUT_OF_MEM;
    /*
     * The parser only sets child/sibling for directories, and free_tlv_tree
     * walks both links. Start from a zeroed root so a leaf root, or an early
     * parse failure, never exposes uninitialised pointers.
     */
    memset(*node, 0, sizeof(tlv_node_t));

    int32_t parse_ret = parse_request_internal(*node, req->buf, req->len);
    if (parse_ret < 0) {
        LOG_E("parse_request_internal failed: 0x%08x", parse_ret);
        free_tlv_tree(*node);
        *node = NULL; /* do not hand a dangling pointer back to the caller */
        return IFAA_ERR_BAD_PARAM;
    }

    return validator(*node);
}

/*
 * Return the value length of `node`, computing it for directories from
 * their children.
 *
 * A node whose length is already non-zero (a leaf with data, or a directory
 * computed earlier) or that has no children is returned as-is. Otherwise
 * the length is the sum over children of (TLSIZE + child value length),
 * cached into node->length. A NULL node has length 0.
 */
uint32_t cal_node_length(tlv_node_t* node) {
    uint32_t total = 0;
    tlv_node_t* child;

    if (node == NULL) {
        return 0;
    }

    if (node->length != 0 || node->child == NULL) {
        return node->length;
    }

    for (child = node->child; child != NULL; child = child->sibling) {
        total += TLSIZE + cal_node_length(child);
    }

    /* Cache the result; return the value as stored in the node. */
    node->length = total;
    return node->length;
}

/*
 * Serialize `node` as a TLV element into `buf` (capacity `buf_len` bytes).
 *
 * The header (tag, then node->length, each written with write16) is
 * followed by:
 * - node->value copied verbatim, when value is non-NULL; otherwise
 * - the children, serialized recursively in sibling order.
 *
 * In the second case node->value is set to the start of the encoded
 * children, so callers can later use (value, length) as the encoded body,
 * for example to sign it. node->length must already be final (see
 * cal_node_length), because the header is written before the children.
 *
 * Every write is bounded by the space actually remaining in `buf`, including
 * across siblings.
 *
 * Returns a pointer one past the last byte written, `buf` unchanged for a
 * NULL node, or NULL if `buf` is NULL or the element does not fit.
 */
uint8_t *construct_node_buf(tlv_node_t *node, uint8_t *buf, uint32_t buf_len) {
    if (node == NULL) return buf;
    if (buf == NULL) return NULL;

    if (buf_len < TLSIZE) {
        LOG_E("construct_node_buf ,can't use memcpy operation because of overflow.");
        return NULL;
    }
    write16(buf, node->tag);
    write16(buf + 2, node->length);

    uint8_t *buf_ptr = buf + TLSIZE;
    uint32_t remaining = buf_len - TLSIZE;

    if (node->value != NULL) {
        if (remaining < (uint32_t) node->length) {
            LOG_E("construct_node_buf node->value,can't use memcpy operation because of overflow.");
            return NULL;
        }
        memcpy((void *) buf_ptr, (void *) node->value, (uint32_t) node->length);
        return buf_ptr + node->length;
    }

    node->value = buf_ptr;

    for (tlv_node_t *child = node->child; child != NULL; child = child->sibling) {
        /* Each child may only use what earlier siblings left over. */
        uint8_t *next = construct_node_buf(child, buf_ptr, remaining);
        if (next == NULL) {
            return NULL; /* propagate overflow instead of writing through NULL */
        }
        remaining -= (uint32_t)(next - buf_ptr);
        buf_ptr = next;
    }

    return buf_ptr;
}



/*
 * Build the TLV-encoded registration response.
 *
 * The KRD directory holds token, pub_alg_encode, pubkey, key_type,
 * challenge, device_id, reg_type and level, plus reg_info / ext_info /
 * idlist when they are non-NULL. It is serialized into a scratch buffer,
 * and its encoded body (node_krd.value / node_krd.length) is passed to
 * `signer`. The response root (TAG_REG_RESPONSE) then contains the
 * signature, sign_algorithm and the KRD, in that order.
 *
 * reg_info, ext_info and idlist may be NULL. All other vlb arguments are
 * dereferenced unconditionally.
 *
 * On success response->buf is a tzwMalloc'd buffer of response->len bytes
 * that the caller releases with tzwFree. `response` is not touched if the
 * scratch allocation, KRD encoding or signing fails. If building the
 * response itself fails, response->buf is NULL and response->len is 0.
 *
 * Returns:
 * - IFAA_ERR_SUCCESS on success;
 * - IFAA_ERR_OUT_OF_MEM on allocation failure;
 * - IFAA_ERR_SIGN when the signer fails or reports an oversized signature;
 * - IFAA_ERR_BAD_PARAM when the tree cannot be encoded into the computed size.
 */
IFAA_Result generate_reg_response(const vlb_t *user_token,
                                  const vlb_t *pub_alg_encode, const vlb_t *pubkey,
                                  const vlb_t *key_type, const vlb_t *challenge,
                                  const vlb_t *device_id, const vlb_t *reg_type,
                                  const vlb_t *reg_info, const vlb_t *idlist, const vlb_t *level,
                                  const vlb_t *ext_info, const vlb_t *sign_algorithm,
                                  IFAA_RegRespSigner signer, vlb_t *response)
{
    /* Mandatory KRD leaves. */
    CREATE_TLV_NODE(node_token, TAG_TOKEN, user_token->len, user_token->buf);
    CREATE_TLV_NODE(node_pub_alg_encode, TAG_PUB_ALG_ENCODE, pub_alg_encode->len, pub_alg_encode->buf);
    CREATE_TLV_NODE(node_pubkey, TAG_PUBKEY, pubkey->len, pubkey->buf);
    CREATE_TLV_NODE(node_key_type, TAG_KEY_TYPE, key_type->len, key_type->buf);
    CREATE_TLV_NODE(node_challenge, TAG_CHALLENGE, challenge->len, challenge->buf);
    CREATE_TLV_NODE(node_deviceid, TAG_DEVICE_ID, device_id->len, device_id->buf);
    CREATE_TLV_NODE(node_reg_type, TAG_REG_TYPE, reg_type->len, reg_type->buf);
    CREATE_TLV_NODE(node_level, TAG_LEVEL, level->len, level->buf);

    CREATE_TLV_NODE(node_krd, TAG_KRD, 0, NULL);

    TLV_ADD_CHILD(node_krd, node_token);
    TLV_ADD_CHILD(node_krd, node_pub_alg_encode);
    TLV_ADD_CHILD(node_krd, node_pubkey);
    TLV_ADD_CHILD(node_krd, node_key_type);
    TLV_ADD_CHILD(node_krd, node_challenge);
    TLV_ADD_CHILD(node_krd, node_deviceid);
    TLV_ADD_CHILD(node_krd, node_reg_type);
    TLV_ADD_CHILD(node_krd, node_level);

    /* Optional KRD leaves, appended only when supplied. */
    CREATE_TLV_NODE(node_reg_info, TAG_REGINFO, 0, NULL);
    CREATE_TLV_NODE(node_ext_info, TAG_EXTINFO, 0, NULL);
    CREATE_TLV_NODE(node_idlist, TAG_TMPIDLIST, 0, NULL);

    if (reg_info != NULL) {
        SET_TLV_VALUE(node_reg_info, TAG_REGINFO, reg_info->len, reg_info->buf);
        TLV_ADD_CHILD(node_krd, node_reg_info);
    }
    if (ext_info != NULL) {
        SET_TLV_VALUE(node_ext_info, TAG_EXTINFO, ext_info->len, ext_info->buf);
        TLV_ADD_CHILD(node_krd, node_ext_info);
    }
    if (idlist != NULL) {
        SET_TLV_VALUE(node_idlist, TAG_TMPIDLIST, idlist->len, idlist->buf);
        TLV_ADD_CHILD(node_krd, node_idlist);
    }

    /* Scratch buffer for the encoded KRD; its body is what gets signed. */
    uint32_t buf_len = cal_node_length(&node_krd) + TLSIZE;
    uint8_t* buf = (uint8_t*)tzwMalloc(buf_len);

    if (buf == NULL) {
        LOG_E("tzwMalloc for node_krd failed");
        return IFAA_ERR_OUT_OF_MEM;
    }

    uint8_t sig[256] = {0};
    uint32_t sig_len = sizeof(sig);
    IFAA_Result ret = IFAA_ERR_SUCCESS;

    do {
        /* Also points node_krd.value at the encoded children inside buf. */
        if (construct_node_buf(&node_krd, buf, buf_len) == NULL) {
            LOG_E("construct_node_buf for node_krd failed");
            ret = IFAA_ERR_BAD_PARAM;
            break;
        }

        ret = signer(node_krd.value, node_krd.length, sig, &sig_len);
        if (ret != IFAA_ERR_SUCCESS) {
            LOG_E("TEE_AuthenticatorSignDigest failed: 0x%08x", ret);
            ret = IFAA_ERR_SIGN;
            break;
        }
        LOG_D("skpm sign result =0x%08x", ret);

        /* sig_len is reported by the signer; never read past the stack buffer. */
        if (sig_len > sizeof(sig)) {
            LOG_E("signature length exceeds buffer");
            ret = IFAA_ERR_SIGN;
            break;
        }

        CREATE_TLV_NODE(node_root, TAG_REG_RESPONSE, 0, NULL);
        CREATE_TLV_NODE(node_algorithm, TAG_SIGN_ALGORITHM, sign_algorithm->len, sign_algorithm->buf);
        CREATE_TLV_NODE(node_signature, TAG_SIGNATURE, sig_len, sig);
        TLV_ADD_CHILD(node_root, node_signature);
        TLV_ADD_CHILD(node_root, node_algorithm);
        TLV_ADD_CHILD(node_root, node_krd);

        response->len = cal_node_length(&node_root) + TLSIZE;
        response->buf = (uint8_t*)tzwMalloc(response->len);

        if (response->buf == NULL) {
            LOG_E("tzwMalloc for node_root failed");
            response->len = 0;
            ret = IFAA_ERR_OUT_OF_MEM;
            break;
        }
        if (construct_node_buf(&node_root, response->buf, response->len) == NULL) {
            LOG_E("construct_node_buf for node_root failed");
            tzwFree(response->buf);
            response->buf = NULL;
            response->len = 0;
            ret = IFAA_ERR_BAD_PARAM;
            break;
        }
    } while (false);

    tzwFree(buf);

    return ret;
}



/*
 * Build the TLV-encoded authentication response.
 *
 * The signed-data directory holds auth_token, challenge, device_id,
 * auth_type and level, plus auth_info / ext_info / idlist when they are
 * non-NULL. It is serialized into a scratch buffer, and its encoded body
 * (node_signed_data.value / .length) is passed to `signer` together with
 * the raw auth_token and `key`. The response root (TAG_AUTH_RESPONSE) then
 * contains the signed data, the signature and sign_algorithm, in that order.
 *
 * auth_info, ext_info and idlist may be NULL. All other vlb arguments are
 * dereferenced unconditionally.
 *
 * On success response->buf is a tzwMalloc'd buffer of response->len bytes
 * that the caller releases with tzwFree. `response` is not touched if the
 * scratch allocation, encoding or signing fails. If building the response
 * itself fails, response->buf is NULL and response->len is 0.
 *
 * Returns:
 * - IFAA_ERR_SUCCESS on success;
 * - the signer's own error code unchanged when signing fails;
 * - IFAA_ERR_SIGN for an oversized signature;
 * - IFAA_ERR_OUT_OF_MEM on allocation failure;
 * - IFAA_ERR_BAD_PARAM when the tree cannot be encoded into the computed size.
 */
IFAA_Result generate_auth_response(const vlb_t *auth_token,
                                   const vlb_t *challenge, const vlb_t *device_id,
                                   const vlb_t *auth_type, const vlb_t *auth_info,
                                   const vlb_t *idlist, 
                                   const vlb_t *level, const vlb_t *ext_info,
                                   const vlb_t *sign_algorithm,
                                   IFAA_AuthRespSigner signer, IFAA_AsymKeyAndType *key, vlb_t *response)
{
    /* Mandatory signed-data leaves. */
    CREATE_TLV_NODE(node_auth_token, TAG_TOKEN, auth_token->len, auth_token->buf);
    CREATE_TLV_NODE(node_challenge, TAG_CHALLENGE, challenge->len, challenge->buf);
    CREATE_TLV_NODE(node_deviceid, TAG_DEVICE_ID, device_id->len, device_id->buf);
    CREATE_TLV_NODE(node_authtype, TAG_AUTH_TYPE, auth_type->len, auth_type->buf);
    CREATE_TLV_NODE(node_level, TAG_LEVEL, level->len, level->buf);

    CREATE_TLV_NODE(node_signed_data, TAG_SIGNED_DATA, 0, NULL);

    TLV_ADD_CHILD(node_signed_data, node_auth_token);
    TLV_ADD_CHILD(node_signed_data, node_challenge);
    TLV_ADD_CHILD(node_signed_data, node_deviceid);
    TLV_ADD_CHILD(node_signed_data, node_authtype);
    TLV_ADD_CHILD(node_signed_data, node_level);

    /* Optional signed-data leaves, appended only when supplied. */
    CREATE_TLV_NODE(node_authinfo, TAG_AUTHINFO, 0, NULL);
    CREATE_TLV_NODE(node_ext_info, TAG_EXTINFO, 0, NULL);
    CREATE_TLV_NODE(node_idlist, TAG_TMPIDLIST, 0, NULL);

    if (auth_info != NULL) {
        SET_TLV_VALUE(node_authinfo, TAG_AUTHINFO, auth_info->len, auth_info->buf);
        TLV_ADD_CHILD(node_signed_data, node_authinfo);
    }
    if (ext_info != NULL) {
        SET_TLV_VALUE(node_ext_info, TAG_EXTINFO, ext_info->len, ext_info->buf);
        TLV_ADD_CHILD(node_signed_data, node_ext_info);
    }
    if (idlist != NULL) {
        SET_TLV_VALUE(node_idlist, TAG_TMPIDLIST, idlist->len, idlist->buf);
        TLV_ADD_CHILD(node_signed_data, node_idlist);
    }

    /* Scratch buffer for the encoded signed data; its body is what gets signed. */
    uint32_t buf_len = cal_node_length(&node_signed_data) + TLSIZE;
    uint8_t* buf = (uint8_t*)tzwMalloc(buf_len);
    if (buf == NULL) {
        LOG_E("tzwMalloc for node_signed_data failed");
        return IFAA_ERR_OUT_OF_MEM;
    }

    uint8_t sig[256] = {0};
    uint32_t sig_len = sizeof(sig);
    IFAA_Result ret = IFAA_ERR_SUCCESS;

    do {
        /* Also points node_signed_data.value at the encoded children inside buf. */
        if (construct_node_buf(&node_signed_data, buf, buf_len) == NULL) {
            LOG_E("construct_node_buf for node_signed_data failed");
            ret = IFAA_ERR_BAD_PARAM;
            break;
        }

        ret = signer((char*)auth_token->buf, auth_token->len, node_signed_data.value, node_signed_data.length, sig, &sig_len, key);
        if (ret != IFAA_ERR_SUCCESS) {
            LOG_E("rsa_sign_digest failed");
            break;
        }

        /* sig_len is reported by the signer; never read past the stack buffer. */
        if (sig_len > sizeof(sig)) {
            LOG_E("signature length exceeds buffer");
            ret = IFAA_ERR_SIGN;
            break;
        }

        CREATE_TLV_NODE(node_root, TAG_AUTH_RESPONSE, 0, NULL);
        CREATE_TLV_NODE(node_algorithm, TAG_SIGN_ALGORITHM, sign_algorithm->len, sign_algorithm->buf);
        CREATE_TLV_NODE(node_sig, TAG_SIGNATURE, sig_len, sig);
        TLV_ADD_CHILD(node_root, node_signed_data);
        TLV_ADD_CHILD(node_root, node_sig);
        TLV_ADD_CHILD(node_root, node_algorithm);

        response->len = cal_node_length(&node_root) + TLSIZE;
        response->buf = (uint8_t*)tzwMalloc(response->len);

        if (response->buf == NULL) {
            /* buf is released once, after the loop; freeing it here was a double free. */
            LOG_E("tzwMalloc for node_root failed");
            response->len = 0;
            ret = IFAA_ERR_OUT_OF_MEM;
            break;
        }
        if (construct_node_buf(&node_root, response->buf, response->len) == NULL) {
            LOG_E("construct_node_buf for node_root failed");
            tzwFree(response->buf);
            response->buf = NULL;
            response->len = 0;
            ret = IFAA_ERR_BAD_PARAM;
            break;
        }
    } while (false);

    tzwFree(buf);

    return ret;
}

