
/*
 * =====================================================================================
 *
 *       Filename:  pebble_drk.c
 *
 *    Description:  PEBBLE DRK manipulation
 *
 *        Version:  1.0
 *        Created:  06/03/2020
 *       Revision:  none
 *       Compiler:  gcc
 *
 *        Company:  Samsung Electronics
 *        Copyright (c) 2020 by Samsung Electronics, All rights reserved.
 *
 * =====================================================================================
 */

/** Includes */
#include "pebble_drk.h"

/**
 * Static function prototypes
 */
/* Search the subject of one DER certificate for the DRK_V2 tag and copy the b64 device id out of it. */
static pebble_return_code_t parseX509_find_DRKV2(uint8_t *p_certificate, int certificate_len, char *b64_IMEI_SERIAL);

/* Parse one DER INTEGER starting at *index; on success *output/*len describe its content and *index moves past it. */
static pebble_return_code_t parse_key_int(uint8_t *key, uint32_t key_buffer_len,
                                       uint32_t *index, uint8_t **output, uint32_t *len);

/* Locate the modulus (n), public exponent (e) and private exponent (d) inside a DER RSAPrivateKey. */
static pebble_return_code_t parse_private_key(uint8_t *private_key, uint32_t private_key_len, uint32_t index,
                                           uint8_t **modulus, uint32_t *modulus_len, uint8_t **priv_expo,
                                           uint32_t *priv_expo_len, uint8_t **pub_expo,
                                           uint32_t *pub_expo_len);


/**
 * @brief
 * get_b64_hash
 * Parse an unwrapped DRK object and extract the DRK_V2 device id
 * b64( H( H(IMEI1|IMEI2) | H(SERIAL) ) ) from the first certificate that carries it.
 *
 * The object is a sequence of entries: a 1-byte type, a 2-byte little-endian
 * payload length (low byte first) and the payload. Certificate payloads are
 * handed to parseX509_find_DRKV2() while b64_IMEI_SERIAL is still empty;
 * private-key, TL-name and unknown entries are skipped.
 *
 * @param[in]      *unwrap_object   - pointer to DRK unwrap object
 * @param[in]       unwrap_len      - length of unwrap object in bytes
 * @param[in/out]   b64_IMEI_SERIAL - output buffer; certificates are only parsed while
 *                                    b64_IMEI_SERIAL[0] == 0, and a found id is written
 *                                    as DEVICE_ID_B64_LEN characters followed by '\0'
 *
 * @return PEBBLE_STATUS_SUCCESS once the object has been walked to its end, even when
 *         no DRK_V2 id was found (check b64_IMEI_SERIAL[0]); PEBBLE_KEY_ERROR on a NULL
 *         argument, a truncated entry header, or a certificate whose declared length
 *         runs past the end of the object.
 */
pebble_return_code_t get_b64_hash(uint8_t *unwrap_object, int unwrap_len, char *b64_IMEI_SERIAL) {
        PEBBLE_LOG_DEBUG("get_b64_hash()");

        int i = 0;

        if (unwrap_object == NULL || b64_IMEI_SERIAL == NULL) {
                PEBBLE_LOG_DEBUG("get_b64_hash failed, NULL argument");
                return PEBBLE_KEY_ERROR;
        }

        while (i < unwrap_len) {
                // An entry needs its 3-byte header plus at least one payload byte.
                if (i + 3 >= unwrap_len) {
                        if (i + 1 == unwrap_len) {
                                // end of key_file reached
                                return PEBBLE_STATUS_SUCCESS;
                        }

                        PEBBLE_LOG_DEBUG("parse_cert failed, unwrap_len exceeded i = %d unwrap_len = %d", i, unwrap_len);
                        return PEBBLE_KEY_ERROR;
                }

                uint8_t entry_type = unwrap_object[i];
                // 16-bit little-endian length: at most 0xFFFF, so it cannot overflow an int.
                int entry_len = unwrap_object[i + 1] | (unwrap_object[i + 2] << 8);
                // Payload bytes actually present after the 3-byte header (>= 1 here).
                int remaining = unwrap_len - (i + 3);

                if (entry_type == DRK_CERT_KEY_TYPE_RSA_CERT) {
                        // The certificate payload is read, so it must lie inside the object.
                        if (entry_len > remaining) {
                                PEBBLE_LOG_DEBUG("get_b64_hash failed, cert_len exceeds unwrap_len");
                                return PEBBLE_KEY_ERROR;
                        }

                        i += 3;

                        if (b64_IMEI_SERIAL[0] == 0) {
                                // Best effort: a certificate without DRK_V2 is not an error here,
                                // the next certificate in the object is tried instead.
                                (void) parseX509_find_DRKV2(unwrap_object + i, entry_len, b64_IMEI_SERIAL);
                        }

                        i += entry_len;

                } else if (entry_type == DRK_CERT_KEY_TYPE_RSA_PRIVATE) {
                        // private key: not needed here, skip header + payload
                        i += 3 + entry_len;

                } else {
                        if (entry_type != DRK_CERT_KEY_TYPE_TL_NAME) {
                                PEBBLE_LOG_DEBUG("Unknown type in key file: %u", entry_type);
                                PEBBLE_LOG_DEBUG("Unknown type in key file index: %d", i);
                        }
                        // NOTE(review): TL_NAME and unknown entries advance by 2 + len while
                        // certificate/private-key entries advance by 3 + len (header + payload).
                        // This looks like an off-by-one, but the "i + 1 == unwrap_len" end check
                        // above appears to rely on it when a TL_NAME entry is last. Preserved as-is
                        // pending confirmation against the DRK blob format.
                        i += 2 + entry_len;
                }
        }

        return PEBBLE_STATUS_SUCCESS;
}

/**
 * @brief
 * get_cert_chain_rsakey
 * Parse an unwrapped DRK object and collect its certificate chain and RSA private key.
 *
 * The object is a sequence of entries: a 1-byte type, a 2-byte little-endian
 * payload length (low byte first) and the payload. Each certificate payload is
 * copied into drk_parsed_object->drk_cert_chain. The private-key payload is parsed
 * in place: the modulus/exponent pointers stored in drk_rsa_private_key point into
 * unwrap_object, so unwrap_object must outlive drk_parsed_object's use of them.
 *
 * @param[in]      *unwrap_object     - pointer to DRK unwrap object
 * @param[in]       unwrap_len        - length of unwrap object in bytes
 * @param[in/out]   drk_parsed_object - receives the certificates (num_certs is reset to 0
 *                                      first) and the private-key components
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_KEY_ERROR on a NULL argument, a truncated
 *         entry, too many / too large certificates, or a malformed private key.
 */
pebble_return_code_t get_cert_chain_rsakey(uint8_t *unwrap_object, int unwrap_len,
                                                drk_parsed_object_t *drk_parsed_object) {
        PEBBLE_LOG_DEBUG("get_cert_chain_rsakey()");

        pebble_return_code_t ret = PEBBLE_STATUS_SUCCESS;
        int i = 0;

        if (unwrap_object == NULL || drk_parsed_object == NULL) {
                PEBBLE_LOG_DEBUG("get_cert_chain_rsakey failed, NULL argument");
                return PEBBLE_KEY_ERROR;
        }

        drk_parsed_object->drk_cert_chain.num_certs = 0;

        while (i < unwrap_len) {
                // An entry needs its 3-byte header plus at least one payload byte.
                if (i + 3 >= unwrap_len) {
                        if (i + 1 == unwrap_len) {
                                // end of key_file reached
                                return PEBBLE_STATUS_SUCCESS;
                        }

                        PEBBLE_LOG_DEBUG("parse_cert failed, unwrap_len exceeded i = %d unwrap_len = %d", i, unwrap_len);
                        return PEBBLE_KEY_ERROR;
                }

                uint8_t entry_type = unwrap_object[i];
                // 16-bit little-endian length: at most 0xFFFF, so it cannot overflow an int.
                int entry_len = unwrap_object[i + 1] | (unwrap_object[i + 2] << 8);
                // Payload bytes actually present after the 3-byte header (>= 1 here).
                int remaining = unwrap_len - (i + 3);

                if (entry_type == DRK_CERT_KEY_TYPE_RSA_CERT) {
                        if (drk_parsed_object->drk_cert_chain.num_certs >= CERTS_MAX_NUM) {
                                PEBBLE_LOG_DEBUG("get_cert_chain_rsakey failed, num_certificates exceeded");
                                return PEBBLE_KEY_ERROR;
                        }

                        if (entry_len > CERT_MAX_LEN) {
                                PEBBLE_LOG_DEBUG("get_cert_chain_rsakey failed, cert_len exceeded");
                                return PEBBLE_KEY_ERROR;
                        }

                        // The copy below reads entry_len bytes: they must exist in the object.
                        if (entry_len > remaining) {
                                PEBBLE_LOG_DEBUG("get_cert_chain_rsakey failed, cert_len exceeds unwrap_len");
                                return PEBBLE_KEY_ERROR;
                        }

                        i += 3;

                        //TODO: check stack memory, maybe just to use a pointer to wrap object, move value is not necessary
                        TEE_MemMove(drk_parsed_object->drk_cert_chain.cert[drk_parsed_object->drk_cert_chain.num_certs].blob,
                                        unwrap_object + i, entry_len);
                        drk_parsed_object->drk_cert_chain.cert[drk_parsed_object->drk_cert_chain.num_certs++].len = entry_len;

                        i += entry_len;

                } else if (entry_type == DRK_CERT_KEY_TYPE_RSA_PRIVATE) {
                        // The key is parsed in place, so its payload must exist in the object.
                        if (entry_len > remaining) {
                                PEBBLE_LOG_DEBUG("get_cert_chain_rsakey failed, pvtkey_len exceeds unwrap_len");
                                return PEBBLE_KEY_ERROR;
                        }

                        i += 3;

                        // Bound DER parsing to this entry's payload only (index 0 = no offset).
                        ret = parse_private_key(unwrap_object + i, (uint32_t) entry_len, 0,
                                                &drk_parsed_object->drk_rsa_private_key.modulus,
                                                &drk_parsed_object->drk_rsa_private_key.modulus_len,
                                                &drk_parsed_object->drk_rsa_private_key.priv_expo,
                                                &drk_parsed_object->drk_rsa_private_key.priv_expo_len,
                                                &drk_parsed_object->drk_rsa_private_key.pub_expo,
                                                &drk_parsed_object->drk_rsa_private_key.pub_expo_len);
                        if (ret != PEBBLE_STATUS_SUCCESS) {
                                PEBBLE_LOG_DEBUG("Error parsing private key");
                                return PEBBLE_KEY_ERROR;
                        }

                        i += entry_len;

                } else {
                        if (entry_type != DRK_CERT_KEY_TYPE_TL_NAME) {
                                PEBBLE_LOG_DEBUG("Unknown type in key file: %u", entry_type);
                                PEBBLE_LOG_DEBUG("Unknown type in key file index: %d", i);
                        }
                        // NOTE(review): TL_NAME and unknown entries advance by 2 + len while
                        // certificate/private-key entries advance by 3 + len (header + payload).
                        // This looks like an off-by-one, but the "i + 1 == unwrap_len" end check
                        // above appears to rely on it when a TL_NAME entry is last. Preserved as-is
                        // pending confirmation against the DRK blob format.
                        i += 2 + entry_len;
                }
        }
        return ret;
}

/**
 * @brief
 * parse_key_int
 * Parse one DER INTEGER (tag 0x02) from a key buffer.
 *
 * Both short-form and long-form (up to 4 length bytes) DER lengths are accepted.
 * A single leading 0x00 content byte, which DER adds to keep a positive value's
 * sign bit clear, is stripped from the returned content.
 *
 * @param[in]      *key            - key buffer
 * @param[in]       key_buffer_len - number of readable bytes in key
 * @param[in/out]   index          - offset of the INTEGER tag / on success, offset just
 *                                   past the INTEGER's content
 * @param[out]    **output         - on success, points to the content inside key (no copy)
 * @param[out]      len            - on success, content length in bytes
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_KEY_ERROR if the element is not an INTEGER,
 *         uses an unsupported length encoding, or does not fit in key_buffer_len.
 *         On error *output, *len and *index are left untouched.
 */
static pebble_return_code_t parse_key_int(uint8_t *key, uint32_t key_buffer_len,
                                uint32_t *index, uint8_t **output, uint32_t *len) {

        PEBBLE_LOG_DEBUG("parse_key_int()");

        pebble_return_code_t ret = PEBBLE_KEY_ERROR;
        uint32_t i = *index;
        uint32_t content_len = 0;

        // Need at least the tag, the first length byte and one content byte.
        // Written as a subtraction so that a large i cannot wrap around.
        if (i >= key_buffer_len || key_buffer_len - i < 3) {
                PEBBLE_LOG_DEBUG("parse_key_int failed, key_buffer_len exceeded");
                goto exit;
        }

        if (key[i++] != 0x02) {
                // 0x02 means integer in ASN1
                PEBBLE_LOG_DEBUG("Not 0x02: i=%u", i);
                goto exit;
        }

        // Get the length of integer
        if (key[i] <= 127) {
                // short form: the byte itself is the length
                content_len = key[i++];

        } else {
                // long form: low 7 bits give the number of length bytes that follow
                uint32_t len_size = key[i++] - 128;

                // 0x80 (indefinite form) is not allowed in DER, and more than 4 length
                // bytes cannot be represented in a uint32_t without wrapping.
                if (len_size == 0 || len_size > sizeof(uint32_t)) {
                        PEBBLE_LOG_DEBUG("parse_key_int failed, unsupported length encoding");
                        goto exit;
                }

                while (len_size > 0) {
                        // the length byte and at least one byte after it must exist
                        if (i + 1 >= key_buffer_len) {
                                PEBBLE_LOG_DEBUG("parse_key_int failed, key_buffer_len exceeded");
                                goto exit;
                        }

                        // at most 4 iterations, so this shift never drops bits
                        content_len = (content_len << 8) | key[i++];
                        len_size--;
                }
        }

        // i < key_buffer_len here; the whole content must fit in the buffer.
        if (content_len == 0 || content_len > key_buffer_len - i) {
                PEBBLE_LOG_DEBUG("parse_key_int failed, key_buffer_len exceeded");
                goto exit;
        }

        // Skip the leading 0 if any. Leading zero means positive
        if (key[i] == 0x0) {
                i++;
                content_len--;
        }

        *output = key + i;
        *len = content_len;
        *index = i + content_len;
        ret = PEBBLE_STATUS_SUCCESS;

exit:
        return ret;
}

/**
 * @brief
 * parse_private_key
 *
 * Retrieve the modulus (n), public exponent (e) and private exponent (d) of a
 * DER-encoded RSAPrivateKey.
 * NOTE: here we are not writing a complete DER parser. The key must start with
 * {0x30, 0x82, len_hi, len_lo, 0x02, 0x01, version} (a SEQUENCE with a two-byte
 * length followed by a one-byte version INTEGER); the first three INTEGERs after
 * that (n, e, d) are returned and the remaining fields are ignored.
 *
 * RSAPrivateKey ::= SEQUENCE {
 *  version           Version,
 *  modulus           INTEGER,  -- n
 *  publicExponent    INTEGER,  -- e
 *  privateExponent   INTEGER,  -- d
 *  prime1            INTEGER,  -- p
 *  prime2            INTEGER,  -- q
 *  exponent1         INTEGER,  -- d mod (p-1)
 *  exponent2         INTEGER,  -- d mod (q-1)
 *  coefficient       INTEGER,  -- (inverse of q) mod p
 *  otherPrimeInfos   OtherPrimeInfos OPTIONAL
 * }
 *
 * @param[in]   *private_key     - start of the DER key
 * @param[in]    private_key_len - length of the enclosing buffer, which starts 'index'
 *                                 bytes before private_key
 * @param[in]    index           - offset of private_key inside that buffer; at most
 *                                 private_key_len - index bytes are read from private_key
 * @param[out] **modulus,   *modulus_len   - n (points into private_key, no copy)
 * @param[out] **priv_expo, *priv_expo_len - d (points into private_key, no copy)
 * @param[out] **pub_expo,  *pub_expo_len  - e (points into private_key, no copy)
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_KEY_ERROR if the key is malformed or does
 *         not fit in the available bytes.
 */
static pebble_return_code_t parse_private_key(uint8_t *private_key, uint32_t private_key_len, uint32_t index,
                                           uint8_t **modulus, uint32_t *modulus_len, uint8_t **priv_expo,
                                           uint32_t *priv_expo_len, uint8_t **pub_expo,
                                           uint32_t *pub_expo_len) {

        PEBBLE_LOG_DEBUG("parse_private_key()");

        pebble_return_code_t ret = PEBBLE_KEY_ERROR;
        uint32_t i = 0;
        uint32_t avail = 0;

        if (private_key == NULL || index > private_key_len) {
                PEBBLE_LOG_DEBUG("parse_private_key failed, private_key_len exceeded");
                goto exit;
        }

        // Bytes that may be read starting at private_key[0]. Every bounds check
        // below, including the INTEGER parsing, must use this value rather than
        // private_key_len, because private_key is already offset by index.
        avail = private_key_len - index;

        if (avail <= 7) {
                PEBBLE_LOG_DEBUG("parse_private_key failed, private_key_len exceeded");
                goto exit;
        }

        if (private_key[i] != 0x30 || private_key[i + 1] != 0x82 || private_key[i + 4] != 0x02 ||
                    private_key[i + 5] != 0x01) {
                PEBBLE_LOG_DEBUG("Private key must start with {0x30,0x82,xx,xx,0x02,0x01}");
                PEBBLE_LOG_DEBUG("actually start with {0x%02x,0x%02x,0x%02x,0x%02x,0x%02x,0x%02x}", private_key[i],
                              private_key[i + 1], private_key[i + 2], private_key[i + 3], private_key[i + 4],
                              private_key[i + 5]);
                goto exit;
        }

        i += 7; // skip SEQUENCE header (4 bytes) and version INTEGER (3 bytes): point to modulus

        if (parse_key_int(private_key, avail, &i, modulus, modulus_len) != PEBBLE_STATUS_SUCCESS) {
                PEBBLE_LOG_DEBUG("Modulus error");
                goto exit;
        }

        // i now points to publicExponent, most likely 65537
        if (parse_key_int(private_key, avail, &i, pub_expo, pub_expo_len) != PEBBLE_STATUS_SUCCESS) {
                PEBBLE_LOG_DEBUG("Public key exponent error, i = %u, len=%u", i, *pub_expo_len);
                goto exit;
        }

        // i now points to privateExponent
        if (parse_key_int(private_key, avail, &i, priv_expo, priv_expo_len) != PEBBLE_STATUS_SUCCESS) {
                PEBBLE_LOG_DEBUG("Private key exponent error, i = %u, len=%u", i, *priv_expo_len);
                goto exit;
        }

        ret = PEBBLE_STATUS_SUCCESS;

exit:
        return ret;
}

/**
 * @brief
 * parseX509_find_DRKV2
 * Parse a DER X.509 certificate and look for a subject entry that starts with
 * DRK_V2_CERTIFICATE_UID_DRK_TAG. The DEVICE_ID_B64_LEN characters that follow the
 * last DRK_V2_CERTIFICATE_UID_SEPARATOR in that entry are copied to b64_IMEI_SERIAL.
 * If several subject entries match, the last one wins.
 *
 * @param[in]     *p_certificate   - certificate (DER format)
 * @param[in]      certificate_len - length of certificate in bytes (must be > 0)
 * @param[in/out] *b64_IMEI_SERIAL - receives b64( H( H(IMEI1|IMEI2) | H(SERIAL))) as
 *                                   DEVICE_ID_B64_LEN characters followed by '\0'
 *                                   (needs DEVICE_ID_B64_LEN + 1 bytes)
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_CERT_DRK_FAIL on invalid arguments, a
 *         certificate that cannot be parsed, a DRK_V2 entry that is too short, or
 *         when b64_IMEI_SERIAL is still empty afterwards.
 */
static pebble_return_code_t parseX509_find_DRKV2(uint8_t *p_certificate, int certificate_len, char *b64_IMEI_SERIAL) {
        PEBBLE_LOG_DEBUG("parseX509_find_DRKV2()");

        pebble_return_code_t ret = PEBBLE_STATUS_SUCCESS;
        BIO *cert = NULL;
        X509 *x509Cert = NULL;
        size_t tag_len = strlen(DRK_V2_CERTIFICATE_UID_DRK_TAG);

        // BIO_new_mem_buf treats a length of -1 as "NUL-terminated string", which is
        // never correct for DER data, so reject non-positive lengths up front.
        if (p_certificate == NULL || certificate_len <= 0 || b64_IMEI_SERIAL == NULL) {
                PEBBLE_LOG_DEBUG("parseX509_find_DRKV2 failed, invalid arguments");
                return PEBBLE_CERT_DRK_FAIL;
        }

        OpenSSL_add_all_algorithms();
        ERR_load_BIO_strings();

        cert = BIO_new_mem_buf(p_certificate, certificate_len);
        if (cert == NULL) {
                PEBBLE_LOG_DEBUG("ERROR read bio\n");
                ret = PEBBLE_CERT_DRK_FAIL;
                goto exit;
        }

        x509Cert = d2i_X509_bio(cert, NULL);
        if (x509Cert == NULL) {
                PEBBLE_LOG_DEBUG("Error to parse x509!\n");
                ret = PEBBLE_CERT_DRK_FAIL;
                goto exit;
        }

        X509_NAME *subject_tag = X509_get_subject_name(x509Cert);
        int entry_count = X509_NAME_entry_count(subject_tag);

        for (int subject_field = 0; subject_field < entry_count; subject_field++) {

                //The value returned is an internal pointer which must not be freed
                X509_NAME_ENTRY *e = X509_NAME_get_entry(subject_tag, subject_field);
                //The value returned is an internal pointer which must not be freed
                ASN1_STRING *d = X509_NAME_ENTRY_get_data(e);
                uint8_t *p_subject_field = (uint8_t *) ASN1_STRING_get0_data(d);
                // Use the ASN.1 length rather than strlen: the data may contain NUL bytes.
                int field_len = ASN1_STRING_length(d);
                int ini = 0;

                // The tag comparison reads tag_len bytes, so the field must be that long.
                if (p_subject_field == NULL || field_len < 0 || (size_t) field_len < tag_len) {
                        continue;
                }

                if (TEE_MemCompare((void *) p_subject_field, DRK_V2_CERTIFICATE_UID_DRK_TAG, tag_len) != 0) {
                        continue;
                }

                /*
                 * Splitting on DRK_V2_CERTIFICATE_UID_SEPARATOR: the device id starts
                 * right after the last separator in the field.
                 */
                for (int j = 0; j < field_len; ++j) {
                        if (p_subject_field[j] == DRK_V2_CERTIFICATE_UID_SEPARATOR)
                                ini = j + 1;
                }

                // 0 <= ini <= field_len, so this difference cannot go negative.
                if ((field_len - ini) < DEVICE_ID_B64_LEN) {
                        PEBBLE_LOG_DEBUG("DRK_V2 device id too short");
                        ret = PEBBLE_CERT_DRK_FAIL;
                        goto exit;
                }

                //Copy b64 to struct
                TEE_MemMove(b64_IMEI_SERIAL, &p_subject_field[ini], DEVICE_ID_B64_LEN);
                b64_IMEI_SERIAL[DEVICE_ID_B64_LEN] = '\0';
        }

        if (b64_IMEI_SERIAL[0] == 0) {
                PEBBLE_LOG_DEBUG("DRK_V2 was not found on Certificate");
                ret = PEBBLE_CERT_DRK_FAIL;
        }

exit:
        if (x509Cert != NULL) X509_free(x509Cert);
        if (cert != NULL) BIO_free(cert);

        return ret;
}

/**
 * @brief
 * validate_drk_device_id
 * Recompute the device id from the NWd-supplied IMEI hash and serial number and
 * compare its first DEVICE_ID_B64_LEN bytes with the id unwrapped from the DRK.
 *
 * @param[in]      unwrapped_drk_id - unwrapped DRK device id (at least DEVICE_ID_B64_LEN bytes)
 * @param[in]      nwd_info         - info from the NWd
 *
 * @return PEBBLE_STATUS_SUCCESS when the ids match; the error returned by
 *         gen_drk_device_id() if generation fails; PEBBLE_DEVICE_ID_CHECK_FAIL on a
 *         NULL id, an unexpected generated length, or a mismatch.
 */
pebble_return_code_t validate_drk_device_id(uint8_t *unwrapped_drk_id, pebble_nwd_info_t nwd_info) {
        uint8_t computed_device_id[DEVICE_ID_B64_LEN + 1] = {0,};
        uint32_t len = DEVICE_ID_B64_LEN + 1;   // in: buffer size, out: generated length
        pebble_return_code_t ret = PEBBLE_DEVICE_ID_CHECK_FAIL;

        PEBBLE_LOG_DEBUG("validate_drk_device_id()");

        if (unwrapped_drk_id == NULL) {
                ret = PEBBLE_DEVICE_ID_CHECK_FAIL;
                goto exit;
        }

        ret = gen_drk_device_id(nwd_info.hash_imei, nwd_info.serial_number, computed_device_id, &len);
        if (ret != PEBBLE_STATUS_SUCCESS) {
                PEBBLE_LOG("Fail to generate device id");
                goto exit;
        }

        // Generation "succeeded" but with the wrong size: this must fail the check,
        // otherwise ret would still hold PEBBLE_STATUS_SUCCESS at exit.
        if (len != DEVICE_ID_B64_LEN) {
                PEBBLE_LOG("Fail to generate device id");
                ret = PEBBLE_DEVICE_ID_CHECK_FAIL;
                goto exit;
        }

        if (TEE_MemCompare(unwrapped_drk_id, computed_device_id, DEVICE_ID_B64_LEN) != 0) {
                ret = PEBBLE_DEVICE_ID_CHECK_FAIL;
                goto exit;
        }

        ret = PEBBLE_STATUS_SUCCESS;
exit:
        return ret;
}
