/*
 * app_drk.c
 */

#include <comdef.h>
#include <stdio.h>
#include <string.h>

#include <secmath.h> // secmath_BIGINT_read_unsigned_bin
#include <secrsa_err.h> // E_SECMATH_SUCCESS, E_SECMATH_FAILURE
#include <secrsa_padding.h> // S_BIGINT_POS
#include <qsee_heap.h> // qsee_malloc, qsee_free
#include <qsee_rsa.h> // QSEE_RSA_KEY, QSEE_RSA_KEY_PRIVATE, QSEE_S_BIGINT

#include "app_main.h"
#include "app_drk.h"

#include "tz_iccc_comdef.h"

// Raw contents of the ICCC key file; load_iccc_key() parses it via
// parse_iccc_key_file(key_file_buffer, key_file_buffer_len, ...).
// NOTE(review): filled outside this section — confirm the writer keeps
// key_file_buffer_len <= ICCC_MAX_KEY_BUF.
uint8_t key_file_buffer[ICCC_MAX_KEY_BUF];
// Number of bytes of key_file_buffer handed to the parser.
uint32_t key_file_buffer_len;

// ICCC certs
// Filled by parse_iccc_key_file(): each entry points directly into the parsed
// key file (no copy is made); iccc_certs_len holds the matching byte length.
uint8_t *iccc_certs[MAX_CERTS];
uint32_t iccc_certs_len[MAX_CERTS];
// Number of valid entries in the two arrays above; reset to 0 on every parse.
uint8_t iccc_certs_num = 0;

/*
 * Parse one ASN.1 DER INTEGER from a private key buffer.
 * In:
 *  key: the key buffer
 *  key_buffer_len: number of valid bytes in key
 *  index: offset of the INTEGER tag (0x02) to parse
 * Out:
 *  index: offset of the element that follows this INTEGER
 *  output: points into key at the first content byte (a single leading
 *          0x00 sign byte is skipped)
 *  len: number of content bytes available at output
 *
 * Returns ICCC_SUCCESS, or ICCC_KEY_ERROR when the element is not an INTEGER,
 * is empty, or its header/content does not fit inside key_buffer_len.
 */
iccc_error_code_t parse_key_int(uint8_t *key, uint32_t key_buffer_len, uint32_t *index, uint8_t **output, uint32_t *len)
{
    uint32_t i, len_size;
    i = *index;
    // Need the tag, a length byte and one more byte; written so that the
    // bound check itself cannot wrap around.
    if (key_buffer_len < 3 || i > key_buffer_len - 3) {
        ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
        return ICCC_KEY_ERROR;
    }

    if (key[i++] != 0x02) {
        // 0x02 means integer in ASN1
        ICCC_LOG("TZ_ICCC: Not 0x02: i=%u", i);
        return ICCC_KEY_ERROR;
    }
    // Get the length of integer
    if (key[i] <= 127) {
        // Short form: the length is the byte itself.
        *len = key[i++];
    } else {
        // Long form: the low 7 bits count the big-endian length bytes that
        // follow (key[i] >= 128 here, so masking cannot underflow).
        len_size = key[i++] & 0x7Fu;
        *len = 0;
        while (len_size > 0) {
            if (i + 1 >= key_buffer_len) {
                ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
                return ICCC_KEY_ERROR;
            }
            // Check before shifting: (*len << 8) silently drops high bits, so
            // the overflow must be detected on the unshifted value.
            if (*len > (0xFFFFFFFFu >> 8)) {
                ICCC_LOG("TZ_ICCC: parse_key_int failed, uint exceeded");
                return ICCC_KEY_ERROR;
            }
            *len = (*len << 8) | key[i++];
            len_size--;
        }
    }

    // A DER INTEGER always has at least one content byte; an empty one would
    // also make the leading-zero skip below underflow *len.
    if (*len == 0) {
        ICCC_LOG("TZ_ICCC: parse_key_int failed, empty integer");
        return ICCC_KEY_ERROR;
    }

    // Skip the leading 0 if any. Leading zero means positive
    if (key[i] == 0x0) {
        if (i + 1 >= key_buffer_len) {
            ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
            return ICCC_KEY_ERROR;
        }
        i++;
        (*len)--;
    }

    // i < key_buffer_len holds here, so the subtraction cannot wrap. The whole
    // content must lie inside the buffer before it is handed to the caller.
    if (*len > key_buffer_len - i) {
        ICCC_LOG("TZ_ICCC: parse_key_int failed, key_buffer_len exceeded");
        return ICCC_KEY_ERROR;
    }

    *output = key + i;
    *index = i + *len;
    return ICCC_SUCCESS;
}

/*
 * Retrieve the private key modulus and exponents.
 * NOTE: here we are not writing a complete DER parser. The header is checked
 * for the expected shape and then the first three INTEGERs after the version
 * field (n, e, d) are returned; the remaining fields are ignored.
 *
 * RSAPrivateKey ::= SEQUENCE {
 *  version           Version,
 *  modulus           INTEGER,  -- n
 *  publicExponent    INTEGER,  -- e
 *  privateExponent   INTEGER,  -- d
 *  prime1            INTEGER,  -- p
 *  prime2            INTEGER,  -- q
 *  exponent1         INTEGER,  -- d mod (p-1)
 *  exponent2         INTEGER,  -- d mod (q-1)
 *  coefficient       INTEGER,  -- (inverse of q) mod p
 *  otherPrimeInfos   OtherPrimeInfos OPTIONAL
 * }
 *
 * In:
 *  private_key, private_key_len: the DER bytes and their length
 *  index: unused, kept for interface compatibility
 * Out:
 *  modulus/modulus_len, pub_expo/pub_expo_len, priv_expo/priv_expo_len:
 *  pointers into private_key and content lengths, as parsed by parse_key_int.
 *  All outputs are reset to NULL/0 on entry, so they are defined on error too.
 */
iccc_error_code_t parse_private_key(uint8_t *private_key, uint32_t private_key_len, uint32_t index,
                                    uint8_t **modulus, uint32_t *modulus_len, uint8_t **priv_expo,
                                    uint32_t *priv_expo_len, uint8_t **pub_expo,
                                    uint32_t *pub_expo_len)
{
    uint32_t i = 0;

    (void)index; // parsing always starts at offset 0 of private_key

    // Callers may pass uninitialized variables, and the error logs below print
    // these lengths, so give every output a defined value first.
    *modulus = NULL;
    *modulus_len = 0;
    *pub_expo = NULL;
    *pub_expo_len = 0;
    *priv_expo = NULL;
    *priv_expo_len = 0;

    if (7 >= private_key_len) {
        ICCC_LOG("TZ_ICCC: parse_private_key failed, private_key_len exceeded");
        return ICCC_KEY_ERROR;
    }
    // Expect a SEQUENCE with a two-byte length (0x30 0x82 xx xx) followed by
    // a one-byte version INTEGER (0x02 0x01 vv).
    if (private_key[i] != 0x30 || private_key[i + 1] != 0x82 || private_key[i + 4] != 0x02 ||
        private_key[i + 5] != 0x01) {
        ICCC_LOG("TZ_ICCC: Private key must start with {0x30,0x82,xx,xx,0x02,0x01}");
        ICCC_LOG("TZ_ICCC: actually start with {0x%02x,0x%02x,0x%02x,0x%02x,0x%02x,0x%02x}", private_key[i], private_key[i + 1], private_key[i + 2], private_key[i + 3], private_key[i + 4], private_key[i + 5]);
        return ICCC_KEY_ERROR;
    }
    i += 7; // skip SEQUENCE header and version: point to modulus

    if (parse_key_int(private_key, private_key_len, &i, modulus, modulus_len)) {
        ICCC_LOG("TZ_ICCC: Modulus error");
        return ICCC_KEY_ERROR;
    }
    // i now points to publicExponent (commonly 65537)
    if (parse_key_int(private_key, private_key_len, &i, pub_expo, pub_expo_len)) {
        ICCC_LOG("TZ_ICCC: Public key exponent error, i = %u, len = %u", i, *pub_expo_len);
        return ICCC_KEY_ERROR;
    }
    // i now points to privateExponent
    if (parse_key_int(private_key, private_key_len, &i, priv_expo, priv_expo_len)) {
        ICCC_LOG("TZ_ICCC: Private key exponent error, i = %u, len = %u", i, *priv_expo_len);
        return ICCC_KEY_ERROR;
    }

    return ICCC_SUCCESS;
}

/*
 * Parse the key file, collecting the public key certificates and the RSA
 * private key components.
 *
 * Key file format: a sequence of items, each laid out as
 *
 *   item_type (1 byte) | length (2 bytes, little endian) | data (length bytes)
 *
 * item_type:
 * 	0x01: RSA cert (ICCC_KEY_TYPE_RSA_CERT)
 * 	0x03: RSA private key, DER RSAPrivateKey (ICCC_KEY_TYPE_RSA_PRIVATE)
 * 	ICCC_KEY_TYPE_TL_NAME: TL name, only logged
 * 	any other type: skipped
 *
 * Out:
 *  iccc_certs / iccc_certs_len / iccc_certs_num: one entry per cert item,
 *   pointing into key_file (no copy), so key_file must outlive their use.
 *  modulus, priv_expo, pub_expo and their lengths: pointers into key_file from
 *   the last private key item; NULL/0 when the file has no such item.
 *
 * A single trailing byte after the last item is tolerated. Returns
 * ICCC_KEY_ERROR when a cert, private key or TL name item is malformed or
 * extends past key_file_len.
 */
iccc_error_code_t parse_iccc_key_file(uint8_t *key_file, uint32_t key_file_len, uint8_t **modulus,
                                      uint32_t *modulus_len, uint8_t **priv_expo,
                                      uint32_t *priv_expo_len, uint8_t **pub_expo,
                                      uint32_t *pub_expo_len)
{
    uint32_t i = 0;
    uint32_t item_start;
    uint32_t len;
    uint32_t remaining;
    uint32_t name_len;
    uint8_t item_type;
    iccc_error_code_t ret;
    uint8_t tl_name[TL_NAME_MAX_LEN] = {0};

    // Leave well-defined outputs when the file carries no private key item.
    *modulus = NULL;
    *modulus_len = 0;
    *priv_expo = NULL;
    *priv_expo_len = 0;
    *pub_expo = NULL;
    *pub_expo_len = 0;

    iccc_certs_num = 0;
    // FIXME : All RSA operations assume Big Endian - Switching endianness is needed
    while (i < key_file_len) {
        // Each item needs its 3-byte header plus at least one data byte.
        if (i + 3 >= key_file_len) {
            if (i + 1 == key_file_len) {
                // end of key_file reached
                return ICCC_SUCCESS;
            }
            ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, key_file_len exceeded i = %u key_file_len = %u", i, key_file_len);
            return ICCC_KEY_ERROR;
        }

        // Decode the item header: type, then a little-endian 16-bit length.
        item_start = i;
        item_type = key_file[i];
        len = (uint32_t)key_file[i + 1] | ((uint32_t)key_file[i + 2] << 8);
        i += 3;
        remaining = key_file_len - i; // > 0 by the header check above

        if (item_type == ICCC_KEY_TYPE_RSA_CERT) {
            // public key certs
            if (iccc_certs_num >= MAX_CERTS) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, iccc_certs_num exceeded");
                return ICCC_KEY_ERROR;
            }
            // The stored (pointer, length) pair must describe bytes that exist.
            if (len > remaining) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, item length exceeded i = %u len = %u", item_start, len);
                return ICCC_KEY_ERROR;
            }
            iccc_certs[iccc_certs_num] = key_file + i;
            iccc_certs_len[iccc_certs_num++] = len;
        } else if (item_type == ICCC_KEY_TYPE_RSA_PRIVATE) {
            // private key
            if (len > remaining) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, item length exceeded i = %u len = %u", item_start, len);
                return ICCC_KEY_ERROR;
            }
            // Bound the DER parser to this item so it cannot run into the
            // following items.
            ret = parse_private_key(key_file + i, len, i,
                                    modulus, modulus_len, priv_expo, priv_expo_len, pub_expo, pub_expo_len);
            if (ret) {
                ICCC_LOG("TZ_ICCC: Error parsing private key");
                return ICCC_KEY_ERROR;
            }
        } else if (item_type == ICCC_KEY_TYPE_TL_NAME) {
            if (len > TL_NAME_MAX_LEN) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, len exceeded");
                return ICCC_KEY_ERROR;
            }
            if (len > remaining) {
                ICCC_LOG("TZ_ICCC: parse_iccc_key_file failed, item length exceeded i = %u len = %u", item_start, len);
                return ICCC_KEY_ERROR;
            }
            // Keep room for a terminator so the %s log below cannot over-read;
            // a name of exactly TL_NAME_MAX_LEN bytes is logged truncated.
            name_len = (len < TL_NAME_MAX_LEN) ? len : (uint32_t)(TL_NAME_MAX_LEN - 1);
            memcpy(tl_name, key_file + i, name_len);
            tl_name[name_len] = '\0';
            ICCC_LOG("TZ_ICCC: TL NAME : %s, length : %u", (char *)tl_name, len);
        } else {
            // Security team added new type. Skip
            ICCC_LOG("TZ_ICCC: Unknown type in key file: %u", (unsigned int)item_type);
            ICCC_LOG("TZ_ICCC: Unknown type in key file index: %u", item_start);
        }
        i += len;
    }
    return ICCC_SUCCESS;
}

/*
 * Verify that a freshly allocated QSEE_S_BIGINT exists and clear it.
 * Returns E_SECMATH_FAILURE (after logging) for a NULL pointer; otherwise
 * zeroes the whole structure and returns E_SECMATH_SUCCESS.
 */
int mem_check(QSEE_S_BIGINT *p)
{
    int status = E_SECMATH_FAILURE;

    if (p != NULL) {
        (void)memset(p, 0, sizeof(*p));
        status = E_SECMATH_SUCCESS;
    } else {
        ICCC_LOG("TZ_ICCC: No memory for QSEE_S_BIGINT");
    }

    return status;
}

/* Read ICCC key, parse the file and populate the RSA key data */
iccc_error_code_t load_iccc_key(QSEE_RSA_KEY *iccc_key)
{
    iccc_error_code_t ret = ICCC_SUCCESS;

    SECMATH_ERRNO_ET secmath_ret = E_SECMATH_SUCCESS;
    QSEE_BigInt a_key_N, a_key_d, a_key_e;

    uint8 *modulus;
    unsigned int modulus_len;
    uint8 *priv_expo;
    unsigned int priv_expo_len;
    uint8 *pub_expo;
    unsigned int pub_expo_len;

    ret = parse_iccc_key_file(key_file_buffer, key_file_buffer_len, &modulus, &modulus_len, &priv_expo, &priv_expo_len, &pub_expo, &pub_expo_len);
    if (ret) {
        return ret;
    }

    // Create the private key
    iccc_key->type = QSEE_RSA_KEY_PRIVATE;
    iccc_key->bitLength = 2048;

    iccc_key->p = NULL;
    iccc_key->q = NULL;
    iccc_key->qP = NULL;
    iccc_key->dP = NULL;
    iccc_key->dQ = NULL;

    iccc_key->N = (QSEE_S_BIGINT *)qsee_malloc(sizeof(QSEE_S_BIGINT));
    iccc_key->d = (QSEE_S_BIGINT *)qsee_malloc(sizeof(QSEE_S_BIGINT));
    iccc_key->e = (QSEE_S_BIGINT *)qsee_malloc(sizeof(QSEE_S_BIGINT));

    if (mem_check(iccc_key->N) != E_SECMATH_SUCCESS || mem_check(iccc_key->e) != E_SECMATH_SUCCESS || mem_check(iccc_key->d) != E_SECMATH_SUCCESS) {
        ICCC_LOG("TZ_ICCC: Cannot alloc mem for private key");
        return ICCC_KEY_ERROR;
    }

    if (secmath_BIGINT_read_unsigned_bin((BigInt *)&a_key_N, modulus, modulus_len) != E_SECMATH_SUCCESS) {
        ICCC_LOG("TZ_ICCC: Failed to load modulus: %d", secmath_ret);
        return ICCC_KEY_ERROR;
    }
    if (secmath_BIGINT_read_unsigned_bin((BigInt *)&a_key_d, priv_expo, priv_expo_len) != E_SECMATH_SUCCESS) {
        ICCC_LOG("TZ_ICCC: Failed to load private exponent: %d", secmath_ret);
        return ICCC_KEY_ERROR;
    }
    if (secmath_BIGINT_read_unsigned_bin((BigInt *)&a_key_e, pub_expo, pub_expo_len) != E_SECMATH_SUCCESS) {
        ICCC_LOG("TZ_ICCC: Failed to load public exponent: %d", secmath_ret);
        return ICCC_KEY_ERROR;
    }

    (iccc_key->N)->bi = a_key_N;
    (iccc_key->N)->sign = S_BIGINT_POS;
    (iccc_key->d)->bi = a_key_d;
    (iccc_key->d)->sign = S_BIGINT_POS;
    (iccc_key->e)->bi = a_key_e;
    (iccc_key->e)->sign = S_BIGINT_POS;

    return ICCC_SUCCESS;
}

void free_key(QSEE_RSA_KEY *rsa_key)
{
    if (rsa_key->N)
        qsee_free(rsa_key->N);
    if (rsa_key->d)
        qsee_free(rsa_key->d);
    if (rsa_key->e)
        qsee_free(rsa_key->e);
}
