#include "kg_dh.h"
#include "kg_state.h"

/*
 * Prime p of the 2048-bit MODP group from RFC 3526 ("group 14"),
 * big-endian, 256 bytes. Used with generator dhg_2048 by get_RFC_dh2048().
 * Read-only: it is only ever passed to BN_bin2bn(), so it is const.
 */
static const unsigned char dhp_2048[] = {
        0xFF,0xFF,0xFF,0xFF,0xFF,0xFF,0xFF,0xFF,
        0xC9,0x0F,0xDA,0xA2,0x21,0x68,0xC2,0x34,
        0xC4,0xC6,0x62,0x8B,0x80,0xDC,0x1C,0xD1,
        0x29,0x02,0x4E,0x08,0x8A,0x67,0xCC,0x74,
        0x02,0x0B,0xBE,0xA6,0x3B,0x13,0x9B,0x22,
        0x51,0x4A,0x08,0x79,0x8E,0x34,0x04,0xDD,
        0xEF,0x95,0x19,0xB3,0xCD,0x3A,0x43,0x1B,
        0x30,0x2B,0x0A,0x6D,0xF2,0x5F,0x14,0x37,
        0x4F,0xE1,0x35,0x6D,0x6D,0x51,0xC2,0x45,
        0xE4,0x85,0xB5,0x76,0x62,0x5E,0x7E,0xC6,
        0xF4,0x4C,0x42,0xE9,0xA6,0x37,0xED,0x6B,
        0x0B,0xFF,0x5C,0xB6,0xF4,0x06,0xB7,0xED,
        0xEE,0x38,0x6B,0xFB,0x5A,0x89,0x9F,0xA5,
        0xAE,0x9F,0x24,0x11,0x7C,0x4B,0x1F,0xE6,
        0x49,0x28,0x66,0x51,0xEC,0xE4,0x5B,0x3D,
        0xC2,0x00,0x7C,0xB8,0xA1,0x63,0xBF,0x05,
        0x98,0xDA,0x48,0x36,0x1C,0x55,0xD3,0x9A,
        0x69,0x16,0x3F,0xA8,0xFD,0x24,0xCF,0x5F,
        0x83,0x65,0x5D,0x23,0xDC,0xA3,0xAD,0x96,
        0x1C,0x62,0xF3,0x56,0x20,0x85,0x52,0xBB,
        0x9E,0xD5,0x29,0x07,0x70,0x96,0x96,0x6D,
        0x67,0x0C,0x35,0x4E,0x4A,0xBC,0x98,0x04,
        0xF1,0x74,0x6C,0x08,0xCA,0x18,0x21,0x7C,
        0x32,0x90,0x5E,0x46,0x2E,0x36,0xCE,0x3B,
        0xE3,0x9E,0x77,0x2C,0x18,0x0E,0x86,0x03,
        0x9B,0x27,0x83,0xA2,0xEC,0x07,0xA2,0x8F,
        0xB5,0xC5,0x5D,0xF0,0x6F,0x4C,0x52,0xC9,
        0xDE,0x2B,0xCB,0xF6,0x95,0x58,0x17,0x18,
        0x39,0x95,0x49,0x7C,0xEA,0x95,0x6A,0xE5,
        0x15,0xD2,0x26,0x18,0x98,0xFA,0x05,0x10,
        0x15,0x72,0x8E,0x5A,0x8A,0xAC,0xAA,0x68,
        0xFF,0xFF,0xFF,0xFF,0xFF,0xFF,0xFF,0xFF
};

/* Generator g = 2 for the RFC 3526 2048-bit group (big-endian, 1 byte). */
static const unsigned char dhg_2048[] = {
        0x02
};
/*
 * Compatibility shims for the DH accessor functions DH_get0_pqg(),
 * DH_set0_pqg(), DH_get0_key() and DH_set0_key(), implemented by direct
 * access to the DH struct fields. They appear to mirror the OpenSSL 1.1.0
 * implementations.
 *
 * Compiled out with #if 0. Presumably this is because the linked crypto
 * library already provides these functions (this file also uses the
 * BoringSSL-specific BN_bn2bin_padded). TODO confirm, and delete if no
 * longer needed.
 *
 * Note: this DH_set0_key() returns 0 when dh->pub_key is NULL and no
 * pub_key is passed. KG_reload_dh() has a dummy-public-key fallback for
 * libraries that behave this way.
 */
#if 0
/* Return borrowed pointers to p, q and g; any out-pointer may be NULL. */
void DH_get0_pqg(const DH *dh,
                 const BIGNUM **p, const BIGNUM **q, const BIGNUM **g)
 {
    if (p != NULL)
        *p = dh->p;
    if (q != NULL)
        *q = dh->q;
    if (g != NULL)
        *g = dh->g;
 }


/*
 * Replace p, q and/or g, taking ownership of each non-NULL argument and
 * freeing the previous value. Returns 0 if p or g would be left unset.
 */
int DH_set0_pqg(DH *dh, BIGNUM *p, BIGNUM *q, BIGNUM *g)
{
   /* If the fields p and g in dh are NULL, the corresponding input
    * parameters MUST be non-NULL.  q may remain NULL.
    */
   if ((dh->p == NULL && p == NULL)
       || (dh->g == NULL && g == NULL))
       return 0;

   if (p != NULL) {
       BN_free(dh->p);
       dh->p = p;
   }
   if (q != NULL) {
       BN_free(dh->q);
       dh->q = q;
   }
   if (g != NULL) {
       BN_free(dh->g);
       dh->g = g;
   }

   /* When a subgroup order is given, record its bit length in dh->length. */
   if (q != NULL) {
       dh->length = BN_num_bits(q);
   }

   return 1;
}


/* Return borrowed pointers to the public/private keys; either may be NULL. */
void DH_get0_key(const DH *dh, const BIGNUM **pub_key, const BIGNUM **priv_key)
 {
    if (pub_key != NULL)
        *pub_key = dh->pub_key;
    if (priv_key != NULL)
        *priv_key = dh->priv_key;
 }

 /*
  * Replace pub_key and/or priv_key, taking ownership of each non-NULL
  * argument and freeing the previous value.
  */
 int DH_set0_key(DH *dh, BIGNUM *pub_key, BIGNUM *priv_key)
 {
    /* If the field pub_key in dh is NULL, the corresponding input
     * parameters MUST be non-NULL.  The priv_key field may
     * be left NULL.
     */
    if (dh->pub_key == NULL && pub_key == NULL)
        return 0;

    if (pub_key != NULL) {
        BN_free(dh->pub_key);
        dh->pub_key = pub_key;
    }
    if (priv_key != NULL) {
        BN_free(dh->priv_key);
        dh->priv_key = priv_key;
    }

    return 1;
 }
#endif 

// TODO: change the log function
// NOTE(review): this helper itself does not log; the TODO above presumably
// refers to the KG_LOG* calls in its callers.
/*
 * Build a DH object holding the fixed 2048-bit group parameters:
 * prime dhp_2048, generator dhg_2048, and no subgroup order q.
 * No key pair is generated here.
 *
 * Returns a new DH owned by the caller (release with DH_free()), or NULL if
 * allocation or parameter installation fails.
 */
static DH *get_RFC_dh2048(void){
    DH *dh = DH_new();
    BIGNUM *dhp_bn = NULL;
    BIGNUM *dhg_bn = NULL;

    if (dh == NULL){
        return NULL;
    }
    dhp_bn = BN_bin2bn(dhp_2048, sizeof (dhp_2048), NULL);
    dhg_bn = BN_bin2bn(dhg_2048, sizeof (dhg_2048), NULL);

    /* On success DH_set0_pqg() takes ownership of dhp_bn and dhg_bn;
     * on any failure we still own them and must free them here. */
    if (dhp_bn == NULL || dhg_bn == NULL
            || !DH_set0_pqg(dh, dhp_bn, NULL, dhg_bn)) {
        DH_free(dh);
        BN_free(dhp_bn);
        BN_free(dhg_bn);
        return NULL;
    }
    return dh;
}


/*
 * Debug-only helpers, compiled out with #if 0.
 * NOTE(review): DBG_dump_dh_keys() prints the DH *private* key in hex. It
 * must never be enabled in a release build.
 */
#if 0 

/* Log the hex encoding of dh's private and public keys (if present). */
void DBG_dump_dh_keys(DH *dh){
    const BIGNUM *pub_key = NULL, *priv_key = NULL;

    if(!dh) return; 
    
    DH_get0_key(dh, &pub_key, &priv_key);
    if(NULL != priv_key){
        /* NOTE(review): BN_bn2hex() may return NULL on allocation failure;
         * the result is not checked before being logged. */
        char * priv_key_str  = BN_bn2hex(priv_key);
        KG_LOG_DBG("/**************Private Key********************/\n");
        KG_LOG_DBG("%s\n", priv_key_str);
        KG_LOG_DBG("\n");

        OPENSSL_free(priv_key_str);
    }

    if(NULL != pub_key){
        char * pub_key_str  = BN_bn2hex(pub_key);
        KG_LOG_DBG("/**************Public Key********************/\n");
        KG_LOG_DBG("%s\n", pub_key_str);
        KG_LOG_DBG("\n");

        OPENSSL_free(pub_key_str);
    }
    return;
}



/*
 * Generate a throw-away key pair over the fixed 2048-bit group, log its
 * public key, and return a caller-owned copy of that public key
 * (BN_free() it). Returns NULL on failure.
 */
BIGNUM* DBG_KG_get_server_pub(){
    DH *dh;
    dh = get_RFC_dh2048(); 
    if(!dh){
        KG_LOG_DBG("error generate dh\n");
        return NULL; 
    }
    if(!DH_generate_key(dh)){
        DH_free(dh); 
        return NULL; 
    }

    const BIGNUM *pub_key, *priv_key;
    DH_get0_key(dh, &pub_key, &priv_key);

    if(NULL != pub_key){
        char * pub_key_str  = BN_bn2hex(pub_key);
        KG_LOG_DBG("/**************Server Public Key********************/\n");
        KG_LOG_DBG("%s\n", pub_key_str);
        KG_LOG_DBG("\n");

        OPENSSL_free(pub_key_str);
    }
    /* NOTE(review): BN_dup() is reached even when pub_key is NULL. */
    BIGNUM *key = BN_dup(pub_key);

    DH_free(dh);
    return key; 
}

#endif 

/*
 * Build a DH object over the fixed 2048-bit group (get_RFC_dh2048()) and
 * load a previously saved private key into it.
 *
 * Ownership of priv_key:
 *   - on success the returned DH owns priv_key. Do not free it separately;
 *     DH_free() on the result releases it.
 *   - on failure (NULL return) priv_key has NOT been consumed, and the caller
 *     is still responsible for freeing it.
 *
 * Some library versions reject DH_set0_key() when the DH has no public key
 * and none is supplied. For backward compatibility, a dummy public key of
 * value 0 is installed together with the private key in that case.
 */
DH* KG_reload_dh(BIGNUM *priv_key){
    DH *dh = NULL;
    BIGNUM *tmp_bn = NULL;

    if(NULL == priv_key){
        KG_LOG_DBG("Error in generating dh\n");
        return NULL;
    }

    dh = get_RFC_dh2048();
    if(!dh){
        KG_LOG_DBG("Error in generating dh\n");
        return NULL;
    }

    /* Normal path: attach the private key alone. */
    if(DH_set0_key(dh, NULL, priv_key)){
        return dh;
    }

    /* for backward compatibility, setting the dummy public key */
    tmp_bn = BN_new();
    /* BN_set_word() returns 1 on success. On any failure nothing has been
     * attached to dh, so only tmp_bn and dh are released; priv_key remains
     * the caller's. */
    if(NULL == tmp_bn || !BN_set_word(tmp_bn, 0)
            || !DH_set0_key(dh, tmp_bn, priv_key)){
        /*should not reach here, something serious wrong */
        KG_LOG_DBG("setting the dummy pub key failed!\n");
        BN_free(tmp_bn);
        DH_free(dh);
        return NULL;
    }
    KG_LOG("set the dummpy pub key!\n");
    return dh;
}

/*
 * Compute the DH shared secret between the stored private key
 * (dh_data->dh_key, KG_DH_PRIV_KEY_LEN bytes) and the peer public key.
 *
 * server_pub     peer public value, big-endian with leading-zero padding.
 *                NOTE(review): an earlier comment described it as
 *                "257 bits"; the check below only enforces
 *                len == KG_DH_PUB_KEY_LEN.
 * len            length of server_pub in bytes.
 * dh_secret      output buffer for the shared secret.
 * dh_secret_len  in: capacity of dh_secret; out: number of bytes written.
 * dh_data        persisted DH state; must not be in DH_STATE_ERROR or
 *                DH_STATE_PRE.
 *
 * Returns KG_SUCCESS on success, KG_DH_GEN_SECRET_FAIL otherwise.
 * The written length is whatever DH_compute_key() returns, which
 * presumably may be shorter than DH_size() when the secret has leading
 * zero bytes (see DH_compute_key vs DH_compute_key_padded).
 */
int KG_dh_gen_shared_key(uint8_t *server_pub, int len, uint8_t *dh_secret, uint32_t *dh_secret_len, kg_dh_data_t* dh_data){
    DH *dh = NULL;
    BIGNUM *server_pub_bn = NULL;
    BIGNUM *priv_bn = NULL;
    unsigned char *secret = NULL;
    int dh_len = 0;
    int secret_size = 0;
    int ret = KG_DH_GEN_SECRET_FAIL;

    if(server_pub == NULL || KG_DH_PUB_KEY_LEN != len
            || dh_secret == NULL || dh_secret_len == NULL || dh_data == NULL){
        KG_LOG("Invalid argument\n");
        return KG_DH_GEN_SECRET_FAIL;
    }

    /*convert the bin to bn */
    if(NULL == (server_pub_bn = BN_bin2bn(server_pub, len, NULL))){
        KG_LOG("Error convert server_pub\n");
        goto out;
    }

    if(DH_STATE_ERROR == dh_data->dh_state || DH_STATE_PRE == dh_data->dh_state){
        KG_LOG_DBG("Wrong dh states\n");
        goto out;
    }
    /* NOTE(review): dumps private key bytes; make sure KG_DUMP_DBG is
     * compiled out of release builds. */
    KG_DUMP_DBG("Retrived priv key byte: \n", (uint8_t *)&(dh_data->dh_key), KG_DH_PRIV_KEY_LEN);

    if(NULL == (priv_bn = BN_bin2bn(dh_data->dh_key, KG_DH_PRIV_KEY_LEN, NULL))){
        KG_LOG("Error convert priv byte\n");
        goto out;
    }
    dh = KG_reload_dh(priv_bn);
    if(NULL == dh){
        /* priv_bn was not consumed; it is freed at "out". */
        KG_LOG_DBG("Error reloading the DH!\n");
        goto out;
    }
    /* dh now owns the private key; DH_free(dh) releases it. Drop our
     * reference so it is not freed twice. */
    priv_bn = NULL;

    if (BN_is_zero(server_pub_bn)) {
        KG_LOG("Server pub is zero\n");
        goto out;
    }

    KG_LOG_DBG("Generating DH secret key \n");

    dh_len = DH_size(dh);
    KG_LOG_DBG("DH size: %d\n", dh_len);
    if(dh_len <= 0){
        goto out;
    }
    secret = TEE_Malloc(sizeof(uint8_t) * (dh_len), 0);
    if(NULL == secret){
        KG_LOG("Memory alloc failed\n");
        goto out;
    }

    secret_size = DH_compute_key(secret, server_pub_bn, dh);
    if(secret_size <= 0){
        KG_LOG("secret_size %d\n", secret_size);
        goto out;
    }

    KG_LOG_DBG("secret_size %d\n", secret_size);

    /* secret_size > 0 here, so the unsigned comparison is exact. */
    if ((uint32_t)secret_size > *dh_secret_len) {
        KG_LOG_DBG("dh_secret Buffer length check failed\n");
        goto out;
    }

    *dh_secret_len = (uint32_t)secret_size;
    TEE_MemMove(dh_secret, (void *)secret, secret_size);
    ret = KG_SUCCESS;

out:
    /* Single cleanup path for success and failure. BN_free()/DH_free()
     * accept NULL. */
    BN_free(server_pub_bn);
    BN_free(priv_bn);
    if(NULL != secret){
        /* wipe the shared secret before releasing the buffer */
        TEE_MemFill(secret, 0, sizeof(uint8_t) * (dh_len));
        TEE_Free(secret);
    }
    DH_free(dh);
    return ret;
}

/* Defined later in this file; declared here because it is used before its
 * definition (compatible with any declaration in kg_dh.h). */
DH* KG_dh_gen_dh_keys(void);

/*
 * Generate a fresh DH key pair over the fixed 2048-bit group and export it
 * as zero-padded big-endian byte strings.
 *
 * out_pub, out_priv          receive newly TEE_Malloc'ed buffers of
 *                            KG_DH_PUB_KEY_LEN / KG_DH_PRIV_KEY_LEN bytes.
 *                            The caller releases them with TEE_Free(), and
 *                            should wipe *out_priv first because it holds
 *                            the private key. Both are set to NULL on any
 *                            failure after argument validation.
 * out_pub_len, out_priv_len  receive the buffer lengths on success.
 *
 * Return value:
 *   KG_SUCCESS (0)          success.
 *   KG_DH_GEN_DH_FAIL       key generation failed.
 *   KG_DH_GEN_KEYPAIR_FAIL  invalid arguments or an export/allocation
 *                           failure.
 */
int KG_dh_get_keypair(uint8_t **out_pub, int *out_pub_len, uint8_t **out_priv, int *out_priv_len){

    DH *dh = NULL;
    const BIGNUM *pub_key = NULL;
    const BIGNUM *priv_key = NULL;

    uint8_t *pub_bin = NULL;
    uint8_t *priv_bin = NULL;

    if(NULL == out_pub || NULL == out_pub_len || NULL == out_priv || NULL == out_priv_len){
        KG_LOG("Invalid argument\n");
        return KG_DH_GEN_KEYPAIR_FAIL;
    }
    *out_pub = NULL;
    *out_priv = NULL;

    dh = KG_dh_gen_dh_keys();
    if(NULL == dh){
        return KG_DH_GEN_DH_FAIL;
    }

    /* borrowed pointers; owned by dh */
    DH_get0_key(dh, &pub_key, &priv_key);

    if (pub_key == NULL || priv_key == NULL
        ||  BN_num_bytes(pub_key) == 0
        ||  BN_num_bytes(priv_key) == 0)  {
        goto err;
    }

    if(NULL == (priv_bin = TEE_Malloc(KG_DH_PRIV_KEY_LEN * sizeof(uint8_t), 0))){
        KG_LOG("Memory alloc failed\n");
        goto err;
    }

    /* returns 0 if the key does not fit in KG_DH_PRIV_KEY_LEN bytes */
    if(0 == BN_bn2bin_padded(priv_bin, KG_DH_PRIV_KEY_LEN, priv_key)){
        KG_LOG_DBG("Covert private key failed\n");
        goto err;
    }

    if(NULL == (pub_bin = TEE_Malloc(KG_DH_PUB_KEY_LEN * sizeof(uint8_t), 0))){
        KG_LOG("Memory alloc failed\n");
        goto err;
    }

    if(0 == BN_bn2bin_padded(pub_bin, KG_DH_PUB_KEY_LEN, pub_key)){
        KG_LOG_DBG("Covert public key failed\n");
        goto err;
    }

    *out_pub_len = KG_DH_PUB_KEY_LEN;
    *out_priv_len = KG_DH_PRIV_KEY_LEN;
    *out_pub = pub_bin;
    *out_priv = priv_bin;

    DH_free(dh);

    return KG_SUCCESS;

err:
    /* *out_pub / *out_priv are still NULL here; wipe partial key material. */
    if(NULL != pub_bin){
        TEE_MemFill(pub_bin, 0, KG_DH_PUB_KEY_LEN);
        TEE_Free(pub_bin);
    }
    if(NULL != priv_bin){
        TEE_MemFill(priv_bin, 0, KG_DH_PRIV_KEY_LEN);
        TEE_Free(priv_bin);
    }
    DH_free(dh);

    return KG_DH_GEN_KEYPAIR_FAIL;
}


/*
 * Create a DH object over the fixed 2048-bit group and generate a fresh
 * key pair in it.
 *
 * Returns the new DH (the caller releases it with DH_free()), or NULL on
 * failure.
 */
DH* KG_dh_gen_dh_keys(void){
    DH *dh;

    dh = get_RFC_dh2048();
    if(!dh){
        KG_LOG_DBG("Error generate DH\n");
        return NULL;
    }

    /*
     * DH_check() is disabled because it is slow. The parameters are the
     * compile-time constants dhp_2048/dhg_2048, so they should be validated
     * once before release. `codes` is declared inside the disabled region so
     * it does not trigger an unused-variable warning while compiled out.
     */
#if 0 //comment out due the check has sluggish performance issue, should check the key is OK before release!
    int codes = 0;
    if(!DH_check(dh, &codes)){
        DH_free(dh);
        KG_LOG("Error in DH checking\n");
        return NULL; //dh check error.
    }else{
        const BIGNUM *p, *g;
        DH_get0_pqg(dh, &p, NULL, &g);

        /*!!!
         * relax the test such that accpet p(mod)24==23 when g=2
         * http://gmssl.org/docs/evp-api.html
         * https://tools.ietf.org/html/rfc2412#page-45
         */
        if (BN_is_word(g, DH_GENERATOR_2)) {
            long residue = BN_mod_word(p, 24);
            if (residue == 11 || residue == 23) {
                codes &= ~DH_NOT_SUITABLE_GENERATOR;
            }

            /*add return value and return in the following block */
            if (codes & DH_UNABLE_TO_CHECK_GENERATOR) {
                KG_LOG("DH_check: failed to test generator");
            }else if (codes & DH_NOT_SUITABLE_GENERATOR) {
                KG_LOG("DH_check: not a suitable generator");
            }else if (codes & DH_CHECK_P_NOT_PRIME) {
                KG_LOG("DH_check: not a prime");
            }else if (codes & DH_CHECK_P_NOT_SAFE_PRIME) {
                KG_LOG("DH_check: not a safe prime");
            }
        }
    }
    if(0 != codes){
        KG_LOG("Error in DH checking with codes %d\n", codes);
        DH_free(dh);
        return  NULL;
    }
#endif

    KG_LOG("DH key generation\n");
    /* DH_generate_key() expects dh to contain the shared parameters dh->p and dh->g.
    It generates a random private DH value unless dh->priv_key is already set,
    and computes the corresponding public value dh->pub_key,
    which can then be published. */
    if(!DH_generate_key(dh)){
        KG_LOG_DBG("Error in DH key generation\n");
        DH_free(dh);
        return NULL;
    }

    return dh;
}


