#include "storage.h"

#include "stdlib.h"

#ifndef BORING_SSL
#include <openssl/opensslconf.h>
#else
#include <openssl/dh.h>
#endif

#include <tee_internal_api.h>

#include "openssl/rsa.h"
#ifndef OPENSSL_NO_EC
#include "openssl/ec.h"
#endif
#ifndef OPENSSL_NO_DSA
#include "openssl/dsa.h"
#endif

#include "openssl/bn.h"
#include "openssl/err.h"

#include <misc_defs.h>
#include <gpapi_log.h>
#include <crypto_rand.h>

/* Number of primality-test rounds passed as the `checks` argument of
 * BN_is_prime_fasttest_ex() when validating DSA domain parameters.
 * Value taken from SCrypto; may be overridden at build time. */
#ifndef DSS_prime_checks
  #define DSS_prime_checks 50
#endif /* !DSS_prime_checks */

/* Maps the value of a TEE_ATTR_ECC_CURVE attribute to an OpenSSL curve NID
 * stored in *out_nid. TEE_GenerateKey treats a zero return as an unknown or
 * unsupported curve. Defined in another translation unit. */
int ec_determine_nid_by_tee_id(uint32_t ec_tee_curve_id, int *out_nid);

/* Number of bytes needed to hold `a` bits, rounded up. The argument is
 * parenthesized so that expressions such as `x << 4` expand correctly. */
#define ROUND_UP_BITS_TO_BYTE(a) (((a) + 7) >> 3)

/*
 * TEE_GenerateKey - generate key material into an uninitialized transient
 * object.
 *
 * @object:     transient object to fill; must not already be initialized.
 * @keySize:    requested key size in bits; must not exceed the object's
 *              maxKeySize. EC, DSA and DH additionally enforce their own
 *              size limits below.
 * @params:     extra generation attributes (curve for EC, domain parameters
 *              for DSA/DH); ignored for symmetric and RSA keys.
 * @paramCount: number of entries in @params.
 *
 * On success the generated attributes are written to tr->attr.attr_array and
 * TEE_HANDLE_FLAG_INITIALIZED is set. For RSA/EC/DSA/DH keys the attribute
 * "buffer" fields hold BIGNUM pointers (not serialized bytes). The "length"
 * fields hold the matching BN_num_bytes() value.
 *
 * Returns TEE_SUCCESS, TEE_ERROR_BAD_PARAMETERS (key generation failed, an
 * unknown EC curve, or rejected DSA/DH domain parameters), or
 * TEE_ERROR_OUT_OF_MEMORY (OpenSSL allocation failure). Already-initialized
 * objects, oversized keys, unsupported EC/DSA/DH sizes and missing mandatory
 * attributes cause TEE_Panic. Unknown handles also panic when
 * STORAGE_HANDLES_VALIDATION is defined.
 */
TEE_Result TEE_GenerateKey(TEE_ObjectHandle object, uint32_t keySize, const TEE_Attribute* params, uint32_t paramCount)
{
    /* Silence unused-parameter warnings when every asymmetric branch that
     * reads params is compiled out. */
    (void)paramCount;
    (void)params;
#ifdef STORAGE_HANDLES_VALIDATION
    if (!in_objects_list(object)) {
        MB_LOGE("Panic Reason: can't find object in list\n");
        TEE_Panic(ID_TEE_GenerateKey);
    }
#endif

    struct TransientObject* tr = &object->tr;

    if (tr->info.handleFlags & TEE_HANDLE_FLAG_INITIALIZED) {
        MB_LOGE("Panic Reason: object already initialized\n");
        TEE_Panic(ID_TEE_GenerateKey);
    }

    if (keySize > tr->info.maxKeySize) {
        MB_LOGE("Panic Reason: key size exceeds max key size\n");
        TEE_Panic(ID_TEE_GenerateKey);
    }

    tr->info.keySize = keySize;

    switch(tr->info.objectType)
    {
    case TEE_TYPE_AES:
    case TEE_TYPE_DES:
    case TEE_TYPE_DES3:
    case TEE_TYPE_HMAC_MD5:
    case TEE_TYPE_HMAC_SHA1:
    case TEE_TYPE_HMAC_SHA224:
    case TEE_TYPE_HMAC_SHA256:
    case TEE_TYPE_HMAC_SHA384:
    case TEE_TYPE_HMAC_SHA512:
    case TEE_TYPE_GENERIC_SECRET:
        /* A single random TEE_ATTR_SECRET_VALUE of ceil(keySize / 8) bytes,
         * stored in the object's own buffer. */
        TEE_GenerateRandom(tr->attr.buffer, ROUND_UP_BITS_TO_BYTE(keySize));
        tr->attr.attr_array[0].content.ref.buffer = tr->attr.buffer;
        tr->attr.attr_array[0].content.ref.length = ROUND_UP_BITS_TO_BYTE(keySize);
        tr->attr.attr_array[0].attributeID = TEE_ATTR_SECRET_VALUE;
        tr->attr.attr_number = 1;
        /* The reported key size is rounded up to a whole number of bytes. */
        tr->info.keySize = ROUND_UP_BITS_TO_BYTE(keySize) << 3;
        tr->info.handleFlags |= TEE_HANDLE_FLAG_INITIALIZED;
        break;

    case TEE_TYPE_RSA_KEYPAIR:
    {
        /* Generate an RSA key pair and expose its 8 components as attributes. */
        int rc;

        RSA *rsa = (RSA*)tr->attr.buffer;
#ifdef USE_SCRYPTO_VER2_4
        rc = RSA_generate_key_fips(rsa, keySize, NULL);
#else
        /* Public exponent 65537. Both BN_new() and BN_set_word() can fail
         * on allocation, and BN_set_word(NULL, ...) would dereference NULL. */
        BIGNUM *e = BN_new();
        if (!e) {
            MB_LOGE("Failed to allocate RSA public exponent\n");
            return TEE_ERROR_OUT_OF_MEMORY;
        }
        if (!BN_set_word(e, 65537)) {
            BN_free(e);
            MB_LOGE("Failed to set RSA public exponent\n");
            return TEE_ERROR_OUT_OF_MEMORY;
        }

        rc = RSA_generate_key_ex(rsa, keySize, e, NULL);
        BN_free(e);
#endif

        if (!rc) return TEE_ERROR_BAD_PARAMETERS;

        // don't use TEE_PopulateTransientObject as our attributes pointers to internal RSA structure !
        tr->attr.attr_array[0].content.ref.buffer = rsa->n;
        tr->attr.attr_array[0].content.ref.length = BN_num_bytes(rsa->n);
        tr->attr.attr_array[0].attributeID = TEE_ATTR_RSA_MODULUS;

        tr->attr.attr_array[1].content.ref.buffer = rsa->e;
        tr->attr.attr_array[1].content.ref.length = BN_num_bytes(rsa->e);
        tr->attr.attr_array[1].attributeID = TEE_ATTR_RSA_PUBLIC_EXPONENT;

        tr->attr.attr_array[2].content.ref.buffer = rsa->d;
        tr->attr.attr_array[2].content.ref.length = BN_num_bytes(rsa->d);
        tr->attr.attr_array[2].attributeID = TEE_ATTR_RSA_PRIVATE_EXPONENT;

        tr->attr.attr_array[3].content.ref.buffer = rsa->p;
        tr->attr.attr_array[3].content.ref.length = BN_num_bytes(rsa->p);
        tr->attr.attr_array[3].attributeID = TEE_ATTR_RSA_PRIME1;

        tr->attr.attr_array[4].content.ref.buffer = rsa->q;
        tr->attr.attr_array[4].content.ref.length = BN_num_bytes(rsa->q);
        tr->attr.attr_array[4].attributeID = TEE_ATTR_RSA_PRIME2;

        tr->attr.attr_array[5].content.ref.buffer = rsa->dmp1;
        tr->attr.attr_array[5].content.ref.length = BN_num_bytes(rsa->dmp1);
        tr->attr.attr_array[5].attributeID = TEE_ATTR_RSA_EXPONENT1;

        tr->attr.attr_array[6].content.ref.buffer = rsa->dmq1;
        tr->attr.attr_array[6].content.ref.length = BN_num_bytes(rsa->dmq1);
        tr->attr.attr_array[6].attributeID = TEE_ATTR_RSA_EXPONENT2;

        tr->attr.attr_array[7].content.ref.buffer = rsa->iqmp;
        tr->attr.attr_array[7].content.ref.length = BN_num_bytes(rsa->iqmp);
        tr->attr.attr_array[7].attributeID = TEE_ATTR_RSA_COEFFICIENT;

        tr->attr.attr_number = 8;
        tr->info.handleFlags |= TEE_HANDLE_FLAG_INITIALIZED;
    }
    break;

#ifndef OPENSSL_NO_EC
    case TEE_TYPE_ECDSA_KEYPAIR:
    case TEE_TYPE_ECDH_KEYPAIR:
    {
        /* Generate an EC key on the curve named by the mandatory
         * TEE_ATTR_ECC_CURVE parameter. */
        EC_KEY *ec_key = (EC_KEY *)tr->attr.buffer;
        EC_GROUP *ec_group;
        uint32_t curve_id = 0;
        int curve_id_is_set = 0;
        uint32_t i;
        int nid;

        if (keySize < 192 || keySize > 528) {
            MB_LOGE("Panic Reason: key size is less than 192 "
                    "or bigger than 528\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        if (paramCount > 0) {
            if (!params) {
                MB_LOGE("Panic Reason: Parameter TEE_Attribute is NULL\n");
                TEE_Panic(ID_TEE_GenerateKey);
            }
        }

        for (i = 0; i < paramCount; i++) {
            if (params[i].attributeID == TEE_ATTR_ECC_CURVE) {
                curve_id = params[i].content.value.a;
                curve_id_is_set = 1;
                /* Only one attribute is mandatory here - no need for further search */
                break;
            }
        }

        /* Check mandatory parameter */
        if (!curve_id_is_set) {
            MB_LOGE("Panic Reason: curve id isn't set\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        if (!ec_determine_nid_by_tee_id(curve_id, &nid))
            return TEE_ERROR_BAD_PARAMETERS; /* incorrect or inconsistent attribute is detected */

        ec_group = EC_GROUP_new_by_curve_name(nid);
        if (!ec_group) {
            MB_LOGE("Panic reason: EC group is NULL\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        /* The group is freed right after EC_KEY_set_group(), whether or not
         * that call succeeds. */
        int rv;
        rv = EC_KEY_set_group(ec_key, ec_group);
        EC_GROUP_free(ec_group);
        ec_group = NULL;

        if (!rv) {
            MB_LOGE("Panic reason: failed to set EC group\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        /* Generate the private and public key */
#ifdef USE_SCRYPTO_VER2_4
        if (!EC_KEY_generate_key_fips(ec_key)) {
#else
        if (!EC_KEY_generate_key(ec_key)) {
#endif
            MB_LOGE("Panic reason: failed to generate EC key\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        if (!EC_KEY_check_key(ec_key)) {
            MB_LOGE("Panic reason: EC key check failed\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }


        /* Private scalar d and affine public coordinates (Qx, Qy) are copied
         * into freshly allocated BIGNUMs that the attributes point at. */
        const EC_POINT *Q = EC_KEY_get0_public_key(ec_key);
        BIGNUM *Qx = NULL, *Qy = NULL, *d = NULL;

        d = BN_dup(EC_KEY_get0_private_key(ec_key));
        if (!d) goto err_free;

        Qx = BN_new();
        if (!Qx) goto err_free;

        Qy = BN_new();
        if (!Qy) goto err_free;


        if (!EC_POINT_get_affine_coordinates_GFp(EC_KEY_get0_group(ec_key), Q, Qx, Qy, NULL))
            goto err_free;

        tr->attr.attr_array[0].content.ref.buffer = d;
        tr->attr.attr_array[0].content.ref.length = BN_num_bytes(d);
        tr->attr.attr_array[0].attributeID = TEE_ATTR_ECC_PRIVATE_VALUE;

        tr->attr.attr_array[1].content.ref.buffer = Qx;
        tr->attr.attr_array[1].content.ref.length = BN_num_bytes(Qx);
        tr->attr.attr_array[1].attributeID = TEE_ATTR_ECC_PUBLIC_VALUE_X;

        tr->attr.attr_array[2].content.ref.buffer = Qy;
        tr->attr.attr_array[2].content.ref.length = BN_num_bytes(Qy);
        tr->attr.attr_array[2].attributeID = TEE_ATTR_ECC_PUBLIC_VALUE_Y;

        tr->attr.attr_array[3].content.value.a = curve_id;
        tr->attr.attr_array[3].content.value.b = 0;
        tr->attr.attr_array[3].attributeID = TEE_ATTR_ECC_CURVE;

        tr->attr.attr_number = 4;
        tr->info.handleFlags |= TEE_HANDLE_FLAG_INITIALIZED;
        break;

err_free:
        /* The private scalar is cleared before being released. */
        if (d) BN_clear_free(d);
        if (Qx) BN_free(Qx);
        if (Qy) BN_free(Qy);
        PRINT_OSSL_ERROR_AND_PANIC(ID_TEE_GenerateKey);
    }
    break;
#endif /* OPENSSL_NO_EC */
#ifndef OPENSSL_NO_DSA
    case TEE_TYPE_DSA_KEYPAIR:
    {
        /* Generate a DSA key from caller-supplied domain parameters
         * (PRIME, BASE and SUBPRIME are all mandatory). */
        DSA *ctx = (DSA *)tr->attr.buffer;
        BIGNUM *bn_p = NULL, *bn_g = NULL, *bn_q = NULL;
        BIGNUM *bn_p_orig, *bn_g_orig, *bn_q_orig;
        uint32_t check = 0;     /* bit 0: PRIME, bit 1: BASE, bit 2: SUBPRIME */
        unsigned int i;
        int panic = 1;          /* cleared only when the parameters are merely invalid */

        ERR_clear_error(); // clear Open SSL errors;

        if ((keySize < 512)
         || (keySize > 1024 && keySize != 2048 && keySize != 3072)
         || (keySize % 64)) {
            MB_LOGE("Panic reason: keysize less than 512 or bigger than 1024"
                      " or isn't multiple of 64 (for DSA_SHA1); or != 2048 "
                      "(for DSA_SHA224 and DSA_SHA256); or != 3072 (for "
                      "DSA_SHA256)\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        /* Same guard as the EC branch: params is dereferenced below. */
        if (paramCount > 0 && !params) {
            MB_LOGE("Panic Reason: Parameter TEE_Attribute is NULL\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        /* Remember the original BIGNUMs so the error path can restore them
         * and free only what this call allocated. */
        bn_p_orig = ctx->p;
        bn_g_orig = ctx->g;
        bn_q_orig = ctx->q;

        if (ctx->p) {
            bn_p = ctx->p;
        } else if (!(bn_p = BN_new())) {
            goto dsa_err;
        }

        if (ctx->g) {
            bn_g = ctx->g;
        } else if (!(bn_g = BN_new())) {
            goto dsa_err;
        }

        if (ctx->q) {
            bn_q = ctx->q;
        } else if (!(bn_q = BN_new())) {
            goto dsa_err;
        }


        for(i = 0; i < paramCount; i++) {
            size_t len;
            const unsigned char *buf;
            len = params[i].content.ref.length;
            buf = params[i].content.ref.buffer;
            if (params[i].attributeID == TEE_ATTR_DSA_PRIME) {
                check |= 0x01;
                if (!BN_bin2bn(buf, len, bn_p))
                    goto dsa_err;
                /* The prime's byte length must match the requested key size. */
                if (params[i].content.ref.length*8 != keySize) goto dsa_err;
            } else if (params[i].attributeID == TEE_ATTR_DSA_BASE) {
                check |= 0x02;
                if (!BN_bin2bn(buf, len, bn_g))
                    goto dsa_err;
            } else if (params[i].attributeID == TEE_ATTR_DSA_SUBPRIME) {
                check |= 0x04;
                if (!BN_bin2bn(buf, len, bn_q))
                    goto dsa_err;
            }
        }

        if (check != 0x07) {
            goto dsa_err;
        }

        /* Check DSA parameters */
        if (!BN_is_prime_fasttest_ex(bn_q, DSS_prime_checks, NULL, 1, NULL) ||
            !BN_is_prime_fasttest_ex(bn_p, DSS_prime_checks, NULL, 1, NULL)) {
            panic = 0;
            goto dsa_err;
        }

        ctx->p = bn_p;
        ctx->g = bn_g;
        ctx->q = bn_q;

        if (!DSA_generate_key(ctx)) {
            goto dsa_err;
        }

        tr->attr.attr_array[0].content.ref.buffer = ctx->p;
        tr->attr.attr_array[0].content.ref.length = BN_num_bytes(ctx->p);
        tr->attr.attr_array[0].attributeID = TEE_ATTR_DSA_PRIME;

        tr->attr.attr_array[1].content.ref.buffer = ctx->q;
        tr->attr.attr_array[1].content.ref.length = BN_num_bytes(ctx->q);
        tr->attr.attr_array[1].attributeID = TEE_ATTR_DSA_SUBPRIME;

        tr->attr.attr_array[2].content.ref.buffer = ctx->g;
        tr->attr.attr_array[2].content.ref.length = BN_num_bytes(ctx->g);
        tr->attr.attr_array[2].attributeID = TEE_ATTR_DSA_BASE;

        tr->attr.attr_array[3].content.ref.buffer = ctx->priv_key;
        tr->attr.attr_array[3].content.ref.length = BN_num_bytes(ctx->priv_key);
        tr->attr.attr_array[3].attributeID = TEE_ATTR_DSA_PRIVATE_VALUE;

        tr->attr.attr_array[4].content.ref.buffer = ctx->pub_key;
        tr->attr.attr_array[4].content.ref.length = BN_num_bytes(ctx->pub_key);
        tr->attr.attr_array[4].attributeID = TEE_ATTR_DSA_PUBLIC_VALUE;

        tr->attr.attr_number = 5;
        tr->info.handleFlags |= TEE_HANDLE_FLAG_INITIALIZED;

        break;

    dsa_err:
        /* Restore the original pointers; free only BIGNUMs allocated here
         * (those whose original slot was NULL). */
        ctx->p = bn_p_orig;
        ctx->g = bn_g_orig;
        ctx->q = bn_q_orig;

        if (bn_p && !ctx->p) BN_free(bn_p);
        if (bn_g && !ctx->g) BN_free(bn_g);
        if (bn_q && !ctx->q) BN_free(bn_q);

        if (CHECK_OSSL_MALLOC_FAILURE) {
            PRINT_OSSL_ERROR();
            return TEE_ERROR_OUT_OF_MEMORY;
        }
        if (panic) {
            TEE_Panic(ID_TEE_GenerateKey);
        }
        return TEE_ERROR_BAD_PARAMETERS;
    }
    break;
#endif /* OPENSSL_NO_DSA */

#ifndef OPENSSL_NO_DH
    case TEE_TYPE_DH_KEYPAIR:
    {
        /* Generate a DH key from caller-supplied domain parameters. PRIME
         * and BASE are mandatory; SUBPRIME and X_BITS are optional. */
        DH *dh = (DH *)tr->attr.buffer;
        BIGNUM *bn_p = NULL, *bn_g = NULL, *bn_q = NULL;
        BIGNUM *bn_p_orig, *bn_g_orig, *bn_q_orig;
        uint32_t check = 0;     /* bits: 0 PRIME, 1 BASE, 2 SUBPRIME, 3 X_BITS */
        int check_dh_params;
        uint32_t dh_x_bits = 0;
        int panic = 1;          /* cleared only when DH_check rejects the parameters */
        uint32_t i;

        ERR_clear_error(); // clear Open SSL errors;

        if (keySize < 256 || keySize > 2048) {
            MB_LOGE("Panic reason: keysize less than 256 "
                    "or bigger than 2048\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        /* Same guard as the EC branch: params is dereferenced below. */
        if (paramCount > 0 && !params) {
            MB_LOGE("Panic Reason: Parameter TEE_Attribute is NULL\n");
            TEE_Panic(ID_TEE_GenerateKey);
        }

        /* Remember the original BIGNUMs so the error path can restore them
         * and free only what this call allocated. */
        bn_p_orig = dh->p;
        bn_g_orig = dh->g;
        bn_q_orig = dh->q;

        if (dh->p) {
            bn_p = dh->p;
        } else if (!(bn_p = BN_new())) {
            goto dh_err;
        }

        if (dh->g) {
            bn_g = dh->g;
        } else if (!(bn_g = BN_new())) {
            goto dh_err;
        }

        for(i = 0; i < paramCount; i++) {
            if (params[i].attributeID == TEE_ATTR_DH_X_BITS) {
                check |= 0x08;
                dh_x_bits = params[i].content.value.a;
                continue;
            }

            size_t len;
            const unsigned char *buf;
            len = params[i].content.ref.length;
            buf = params[i].content.ref.buffer;
            if (params[i].attributeID == TEE_ATTR_DH_PRIME) {
                check |= 0x01;
                if (!BN_bin2bn(buf, len, bn_p))
                    goto dh_err;
                /* NOTE(review): keySize > maxKeySize already panics at the
                 * top of this function, so this test never fires here. */
                if (keySize > tr->info.maxKeySize) goto dh_err;
            } else if (params[i].attributeID == TEE_ATTR_DH_BASE) {
                check |= 0x02;
                if (!BN_bin2bn(buf, len, bn_g))
                    goto dh_err;
            } else if (params[i].attributeID == TEE_ATTR_DH_SUBPRIME) {
                check |= 0x04;
                /* Pick the target BIGNUM only once. A repeated SUBPRIME
                 * attribute reuses it instead of leaking a new allocation. */
                if (!bn_q) {
                    if (dh->q) bn_q = dh->q;
                    else if (!(bn_q = BN_new())) goto dh_err;
                }
                if (!BN_bin2bn(buf, len, bn_q))
                    goto dh_err;
            }
        }

        /* Check mandatory parameters */
        if (!(check & 0x01) || !(check & 0x02)) {
            goto dh_err;
        }

        dh->p = bn_p;
        dh->g = bn_g;

        /* Check that dh->p is a safe prime and dh->g is a suitable generator.
         * Note: FIPS Cryptocore's check supports only 2 and 5 generators. */
        if (!DH_check(dh, &check_dh_params)) {
            goto dh_err;
        }
        if (check_dh_params) {
            panic = 0;
            goto dh_err;
        }

        /* Install the supplied subprime, or drop a stale one when none was
         * supplied (bn_q_orig is cleared so the error path won't restore it). */
        if (check & 0x04) {
            dh->q = bn_q;
        } else if (dh->q) {
            BN_free(dh->q);
            bn_q_orig = dh->q = NULL;
        }

        if (check & 0x08) {
#ifndef BORING_SSL
            dh->length = dh_x_bits;
#else
            /* Suppress "unused variable" compiler error in release buid
             * when SCrypto is used */
            (void)dh_x_bits;
            MB_LOGD("Skipped setting x bits %u!\n", dh_x_bits);
#endif
        }

        panic = 1;
        if (!DH_generate_key(dh)) {
            goto dh_err;
        }

        tr->attr.attr_array[0].content.ref.buffer = dh->p;
        tr->attr.attr_array[0].content.ref.length = BN_num_bytes(dh->p);
        tr->attr.attr_array[0].attributeID = TEE_ATTR_DH_PRIME;

        tr->attr.attr_array[1].content.ref.buffer = dh->g;
        tr->attr.attr_array[1].content.ref.length = BN_num_bytes(dh->g);
        tr->attr.attr_array[1].attributeID = TEE_ATTR_DH_BASE;

        tr->attr.attr_array[2].content.ref.buffer = dh->priv_key;
        tr->attr.attr_array[2].content.ref.length = BN_num_bytes(dh->priv_key);
        tr->attr.attr_array[2].attributeID = TEE_ATTR_DH_PRIVATE_VALUE;

        tr->attr.attr_array[3].content.ref.buffer = dh->pub_key;
        tr->attr.attr_array[3].content.ref.length = BN_num_bytes(dh->pub_key);
        tr->attr.attr_array[3].attributeID = TEE_ATTR_DH_PUBLIC_VALUE;

        /* X_BITS reports the bit length of the generated private value. */
        tr->attr.attr_array[4].content.value.a = BN_num_bits(dh->priv_key);
        tr->attr.attr_array[4].content.value.b = 0;
        tr->attr.attr_array[4].attributeID = TEE_ATTR_DH_X_BITS;

        tr->attr.attr_number = 5;

        if (dh->q) {
            tr->attr.attr_array[5].content.ref.buffer = dh->q;
            tr->attr.attr_array[5].content.ref.length = BN_num_bytes(dh->q);
            tr->attr.attr_array[5].attributeID = TEE_ATTR_DH_SUBPRIME;
            tr->attr.attr_number = 6;
        }
        tr->info.handleFlags |= TEE_HANDLE_FLAG_INITIALIZED;
        break;
    dh_err:
        /* Restore the original pointers; free only BIGNUMs allocated here
         * (those whose original slot was NULL). */
        dh->p = bn_p_orig;
        dh->g = bn_g_orig;
        dh->q = bn_q_orig;

        if (bn_p && !dh->p) BN_free(bn_p);
        if (bn_g && !dh->g) BN_free(bn_g);
        if (bn_q && !dh->q) BN_free(bn_q);

        if (CHECK_OSSL_MALLOC_FAILURE) {
            PRINT_OSSL_ERROR();
            return TEE_ERROR_OUT_OF_MEMORY;
        }

        if (panic) {
            PRINT_OSSL_ERROR_AND_PANIC(ID_TEE_GenerateKey);
        }
        return TEE_ERROR_BAD_PARAMETERS;
        break;
    }
#endif /* OPENSSL_NO_DH */

    default:
        /* NOTE(review): any other object type returns TEE_SUCCESS with
         * keySize recorded but the object left uninitialized (unchanged
         * behavior) - confirm this is intended. */
        break;
    }

    return TEE_SUCCESS;
}
