/*
 * app_drk.c
 */

#include <tee_internal_api.h>
#include <tees_secure_object.h>

#include <string.h>

#include "app_main.h"
#include "app_core.h"
#include "app_drk.h"

#include "icccOperations_v4.h"

// RSA key fields (modulus n, public exponent e, private exponent d) filled in
// by load_iccc_key(). The pointers reference the key-file buffer that was
// parsed; the key material is not copied.
tz_iccc_rsakey_t rsakey;
// Staging buffers for unwrap(): srctmpData receives a copy of the wrapped
// blob, dsttmpData receives the unwrapped output (ICCC_MAX_KEY_BUF bytes each).
uint8_t srctmpData[ICCC_MAX_KEY_BUF];
uint8_t dsttmpData[ICCC_MAX_KEY_BUF];

// ICCC certs recorded by parse_iccc_key_file(): each entry points into the
// key-file buffer (not a copy) with its length in iccc_certs_len.
uint8_t *iccc_certs[MAX_CERTS];
uint32_t iccc_certs_len[MAX_CERTS];
// Number of valid entries in iccc_certs/iccc_certs_len (at most MAX_CERTS);
// reset to 0 at the start of every parse_iccc_key_file() call.
uint8_t iccc_certs_num = 0;

/*
 * Parse one ASN.1 DER INTEGER from a key buffer.
 * In:
 *  key:            the key buffer
 *  key_buffer_len: number of readable bytes in key
 *  index:          offset of the INTEGER tag (0x02) to be parsed
 * Out:
 *  index:  offset just past the integer content (the next element)
 *  output: points into key at the first content byte; a single leading
 *          0x00 (the sign byte of a positive value) is skipped
 *  len:    number of content bytes at *output
 *
 * Returns ICCC_SUCCESS, or ICCC_KEY_ERROR when the tag is not INTEGER, the
 * length encoding is invalid or does not fit in 32 bits, the integer is
 * empty, or the content would run past key_buffer_len. On success the range
 * [*output, *output + *len) lies entirely inside key.
 */
iccc_error_code_t parse_key_int(uint8_t *key, uint32_t key_buffer_len, uint32_t *index, uint8_t **output, uint32_t *len)
{
    uint32_t i = *index;
    uint32_t len_size;

    // Need the tag, one length byte and at least one more byte. This is written
    // as a subtraction so that a large *index cannot wrap the comparison.
    if (i >= key_buffer_len || key_buffer_len - i < 3) {
        ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
        return ICCC_KEY_ERROR;
    }

    if (key[i++] != 0x02) {
        // 0x02 means integer in ASN1
        ICCC_LOG("TZ_ICCC: Not 0x02: i=%u", i);
        return ICCC_KEY_ERROR;
    }

    // Get the length of integer
    if (key[i] <= 127) {
        // Short form: this byte is the length itself.
        *len = key[i++];
    } else {
        // Long form: the low 7 bits give the number of big-endian length bytes.
        len_size = key[i++] - 128;
        *len = 0;
        while (len_size > 0) {
            // i < key_buffer_len holds here, so i + 1 cannot wrap.
            if (i + 1 >= key_buffer_len) {
                ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
                return ICCC_KEY_ERROR;
            }
            // Reject before shifting: one more byte would overflow 32 bits.
            if (*len > (0xFFFFFFFFu >> 8)) {
                ICCC_LOG("TZ_ICCC: parse_key_int failed, uint exceeded");
                return ICCC_KEY_ERROR;
            }
            *len = (*len << 8) | key[i++];
            len_size--;
        }
    }

    // DER INTEGER content is at least one octet. This also keeps the
    // leading-zero skip below from decrementing a zero length.
    if (*len == 0) {
        ICCC_LOG("TZ_ICCC: parse_key_int failed, empty integer");
        return ICCC_KEY_ERROR;
    }

    // The content must lie inside the buffer. i < key_buffer_len here, so the
    // subtraction cannot wrap.
    if (*len > key_buffer_len - i) {
        ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
        return ICCC_KEY_ERROR;
    }

    // Skip the leading 0 if any. Leading zero means positive.
    // key[i] is in bounds because *len >= 1 and fits in the buffer.
    if (key[i] == 0x0) {
        i++;
        (*len)--;
    }

    *output = key + i;
    // Cannot overflow: i + *len <= key_buffer_len by the check above.
    *index = i + *len;
    return ICCC_SUCCESS;
}

/*
 * Retrieve the RSA modulus, public exponent and private exponent from a
 * DER-encoded RSAPrivateKey.
 *
 * NOTE: this is not a complete DER parser. The key must start with
 * {0x30, 0x82, len_hi, len_lo, 0x02, 0x01, version}: a SEQUENCE with a
 * two-byte length, followed by a one-byte version INTEGER. The version value
 * itself is not inspected. The next three INTEGERs (n, e, d) are then read
 * with parse_key_int. The remaining fields are not parsed.
 *
 * In:
 *  private_key:     DER buffer
 *  private_key_len: number of readable bytes in private_key
 *  index:           unused (kept for interface compatibility)
 * Out:
 *  modulus, modulus_len:     n
 *  pub_expo, pub_expo_len:   e
 *  priv_expo, priv_expo_len: d
 * The output pointers point into private_key; nothing is copied.
 *
 * Returns ICCC_SUCCESS or ICCC_KEY_ERROR.
 *
 * RSAPrivateKey ::= SEQUENCE {
 *  version           Version,
 *  modulus           INTEGER,  -- n
 *  publicExponent    INTEGER,  -- e
 *  privateExponent   INTEGER,  -- d
 *  prime1            INTEGER,  -- p
 *  prime2            INTEGER,  -- q
 *  exponent1         INTEGER,  -- d mod (p-1)
 *  exponent2         INTEGER,  -- d mod (q-1)
 *  coefficient       INTEGER,  -- (inverse of q) mod p
 *  otherPrimeInfos   OtherPrimeInfos OPTIONAL
 * }
 */
iccc_error_code_t parse_private_key(uint8_t *private_key, uint32_t private_key_len, uint32_t index,
                                    uint8_t **modulus, uint32_t *modulus_len, uint8_t **priv_expo,
                                    uint32_t *priv_expo_len, uint8_t **pub_expo,
                                    uint32_t *pub_expo_len)
{
    uint32_t i = 0;

    (void)index; // not used by this parser

    // The fixed header occupies bytes 0..6 and the modulus starts at offset 7,
    // so more than 7 bytes are required.
    if (7 >= private_key_len) {
        ICCC_LOG("TZ_ICCC: parse_private_key failed, private_key_len exceeded");
        return ICCC_KEY_ERROR;
    }
    if (private_key[i] != 0x30 || private_key[i + 1] != 0x82 || private_key[i + 4] != 0x02 ||
        private_key[i + 5] != 0x01) {
        ICCC_LOG("TZ_ICCC: Private key must start with {0x30,0x82,xx,xx,0x02,0x01}");
        ICCC_LOG("TZ_ICCC: actually start with {0x%02x,0x%02x,0x%02x,0x%02x,0x%02x,0x%02x}", private_key[i], private_key[i + 1], private_key[i + 2], private_key[i + 3], private_key[i + 4], private_key[i + 5]);
        return ICCC_KEY_ERROR;
    }
    i += 7; // skip SEQUENCE header and version; point to modulus

    if (parse_key_int(private_key, private_key_len, &i, modulus, modulus_len)) {
        ICCC_LOG("TZ_ICCC: Modulus error");
        return ICCC_KEY_ERROR;
    }
    // i now points to publicExponent, usually 65537 (0x010001)
    if (parse_key_int(private_key, private_key_len, &i, pub_expo, pub_expo_len)) {
        ICCC_LOG("TZ_ICCC: Public key exponent error, i = %u, len = %u", i, *pub_expo_len);
        return ICCC_KEY_ERROR;
    }
    // i now points to privateExponent
    if (parse_key_int(private_key, private_key_len, &i, priv_expo, priv_expo_len)) {
        ICCC_LOG("TZ_ICCC: Private key exponent error, i = %u, len = %u", i, *priv_expo_len);
        return ICCC_KEY_ERROR;
    }

    return ICCC_SUCCESS;
}

/*
 * Parse the key file, get private key and public key certificate(s).
 * The key file is a sequence of items:
 *
 * item_type (1 byte) | length (2 bytes, little endian: low byte first) | data (length bytes)
 * item_type:
 *  ICCC_KEY_TYPE_RSA_CERT    (0x01): RSA cert; recorded in iccc_certs /
 *                                    iccc_certs_len (at most MAX_CERTS)
 *  ICCC_KEY_TYPE_RSA_PRIVATE (0x03): DER RSAPrivateKey; n, e and d are returned
 *                                    through the out-parameters
 *  ICCC_KEY_TYPE_TL_NAME:            TL name of at most TL_NAME_MAX_LEN bytes;
 *                                    only logged
 *  any other type:                   skipped
 *
 * Certificate and key pointers point into key_file; nothing is copied. If no
 * private-key item is present, the key out-parameters are left untouched and
 * ICCC_SUCCESS is still returned. A single trailing byte after the last item
 * is tolerated.
 *
 * Returns ICCC_SUCCESS, or ICCC_KEY_ERROR for a truncated or malformed file.
 */
iccc_error_code_t parse_iccc_key_file(uint8_t *key_file, uint32_t key_file_len, uint8_t **modulus,
                                      uint32_t *modulus_len, uint8_t **priv_expo,
                                      uint32_t *priv_expo_len, uint8_t **pub_expo,
                                      uint32_t *pub_expo_len)
{
    uint32_t i = 0;   // offset of the current item's type byte
    uint32_t data;    // offset of the current item's data
    uint32_t len;     // length of the current item's data
    uint8_t type;
    iccc_error_code_t ret;
    // One spare byte keeps the name NUL-terminated for the %s log below even
    // when it is exactly TL_NAME_MAX_LEN bytes long.
    uint8_t tl_name[TL_NAME_MAX_LEN + 1] = {0};

    iccc_certs_num = 0;
    // FIXME : All RSA operations assume Big Endian - switching endianness is needed
    while (i < key_file_len) {
        // An item needs its 3-byte header plus at least one data byte.
        // i < key_file_len here, so the subtraction cannot wrap.
        if (key_file_len - i <= 3) {
            if (key_file_len - i == 1) {
                // end of key_file reached; one trailing byte is tolerated
                return ICCC_SUCCESS;
            }
            ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, key_file_len exceeded i = %u key_file_len = %u", i, key_file_len);
            return ICCC_KEY_ERROR;
        }

        type = key_file[i];
        // The 16-bit length is stored low byte first; it cannot overflow uint32_t.
        len = (uint32_t)key_file[i + 1] | ((uint32_t)key_file[i + 2] << 8);
        data = i + 3; // < key_file_len by the check above

        if (type == ICCC_KEY_TYPE_RSA_CERT) {
            // public key certs
            if (len > key_file_len - data) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, cert len exceeds key_file_len");
                return ICCC_KEY_ERROR;
            }
            if (iccc_certs_num >= MAX_CERTS) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, iccc_certs_num exceeded");
                return ICCC_KEY_ERROR;
            }
            iccc_certs[iccc_certs_num] = key_file + data;
            iccc_certs_len[iccc_certs_num++] = len;
        } else if (type == ICCC_KEY_TYPE_RSA_PRIVATE) {
            // private key
            if (len > key_file_len - data) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, private key len exceeds key_file_len");
                return ICCC_KEY_ERROR;
            }
            // Bound the DER parser by this item's length, not by the rest of
            // the file, so it cannot wander into the following items.
            ret = parse_private_key(key_file + data, len, data,
                                    modulus, modulus_len, priv_expo, priv_expo_len, pub_expo, pub_expo_len);
            if (ret != ICCC_SUCCESS) {
                ICCC_LOG("TZ_ICCC: Error parsing private key");
                return ICCC_KEY_ERROR;
            }
        } else if (type == ICCC_KEY_TYPE_TL_NAME) {
            if (len > TL_NAME_MAX_LEN) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, len exceeded");
                return ICCC_KEY_ERROR;
            }
            if (len > key_file_len - data) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, TL name len exceeds key_file_len");
                return ICCC_KEY_ERROR;
            }
            TEE_MemMove(tl_name, key_file + data, len);
            tl_name[len] = '\0'; // also clears leftovers from a longer earlier name
            ICCC_LOG("TZ_ICCC: TL NAME : %s, length : %u", (char *)tl_name, len);
        } else {
            ICCC_LOG("TZ_ICCC: Unknown type in key file: %u", (unsigned)type);
            ICCC_LOG("TZ_ICCC: Unknown type in key file index: %u", i);
            // Security team added new type. Skip its data without inspecting it.
        }

        i = data + len;
    }
    return ICCC_SUCCESS;
}

/*
 * Parse an in-memory ICCC key file and store its RSA key fields in the
 * global rsakey.
 *
 * rsakey is zeroed first, so any field not found in the file stays zero.
 * The pointers stored in rsakey reference key_file_buffer, which must stay
 * valid for as long as rsakey is used. Parsing also records certificates in
 * iccc_certs as a side effect.
 *
 * Returns the result of parse_iccc_key_file().
 */
iccc_error_code_t load_iccc_key(uint8_t *key_file_buffer, uint32_t len)
{
    memset(&rsakey, 0, sizeof rsakey);

    return parse_iccc_key_file(key_file_buffer, len,
                               &rsakey.modulus, &rsakey.modulus_len,
                               &rsakey.priv_expo, &rsakey.priv_expo_len,
                               &rsakey.pub_expo, &rsakey.pub_expo_len);
}

/*
 * Unwrap a secure-object blob into the global dsttmpData buffer.
 *
 * The blob is first copied into the global srctmpData buffer.
 * TEES_CheckSecureObjectCreator() must then accept it for the SKM TA
 * (TA_SKM_UUID, auth id TA_SKM_TA_AUTH_CRYPTOSUITE, DELEGATED_TA_ID_AC).
 * Only after that is TEES_UnwrapSecureObject() called.
 *
 * In:
 *  blob_ptr: wrapped blob of blob_len bytes; blob_len must be in
 *            1..sizeof(srctmpData)
 * Out:
 *  dest_len: number of unwrapped bytes written to dsttmpData
 *
 * Returns ICCC_SUCCESS, or ICCC_KEY_ERROR for any of:
 *  - NULL or out-of-range arguments
 *  - a failed creator check
 *  - an unwrap failure
 *  - an oversized result
 * If the unwrap call fails or reports an oversized result, dsttmpData is
 * zeroed so no partial plaintext is left in the global buffer.
 */
iccc_error_code_t unwrap(uint8_t *blob_ptr, uint32_t blob_len, uint32_t *dest_len)
{
    iccc_error_code_t ret = ICCC_KEY_ERROR;
    TEE_Result unwrap_ret;
    SO_AccessControlInfoType ac_info;
    TEEC_UUID creator_uuid = (TEEC_UUID)TA_SKM_UUID;

    if (blob_ptr == NULL || dest_len == NULL) {
        goto exit;
    }

    // Bound by the buffer the blob is actually copied into.
    if (blob_len == 0 || blob_len > sizeof(srctmpData)) {
        goto exit;
    }

    TEE_MemMove(srctmpData, blob_ptr, blob_len);

    // Only accept objects created by the SKM TA.
    TEE_MemFill(&ac_info, 0, sizeof ac_info);
    TEE_MemMove(&ac_info.ta_id, &creator_uuid, sizeof(TEEC_UUID));
    TEE_MemMove(&ac_info.auth_id, TA_SKM_TA_AUTH_CRYPTOSUITE, strlen(TA_SKM_TA_AUTH_CRYPTOSUITE));
    ac_info.access_flags = DELEGATED_TA_ID_AC;

    unwrap_ret = TEES_CheckSecureObjectCreator(srctmpData, blob_len, &ac_info);
    if (unwrap_ret != TEE_SUCCESS) {
        ICCC_LOG_DEBUG("TZ_ICCC: TEES_CheckSecureObjectCreator failed with ret = 0x%08X, exit", unwrap_ret);
        goto exit;
    }

    // NOTE(review): this assumes TEES_UnwrapSecureObject may read *dest_len
    // as the destination capacity (in/out length). Seeding it with the real
    // buffer size is harmless if the API only writes it. Confirm against the
    // API documentation.
    *dest_len = sizeof(dsttmpData);
    unwrap_ret = TEES_UnwrapSecureObject(srctmpData, blob_len, dsttmpData, dest_len);
    if (unwrap_ret != TEE_SUCCESS) {
        ICCC_LOG_DEBUG("TZ_ICCC: TEES_UnwrapSecureObject failed with ret = 0x%08X, exit", unwrap_ret);
        TEE_MemFill(dsttmpData, 0, sizeof(dsttmpData));
        goto exit;
    }

    if (*dest_len > sizeof(dsttmpData)) {
        ICCC_LOG_DEBUG("TZ_ICCC: Key Bigger than space.");
        TEE_MemFill(dsttmpData, 0, sizeof(dsttmpData));
        goto exit;
    }

    ret = ICCC_SUCCESS;

exit:
    return ret;
}
