#include "kg_x509.h"

#include <limits.h>

/*
 * Append a PEM certificate string to cert_bio and parse one X.509
 * certificate back out of it.
 *
 * pem_string  NUL-terminated PEM text written into cert_bio.
 * cert_bio    Caller-owned BIO that is written to and then read from
 *             (presumably a read/write memory BIO — confirm with callers).
 *             It is not freed here.
 *
 * Returns the parsed certificate, which the caller must release with
 * X509_free(), or NULL if an argument is NULL, the text is too long for
 * BIO_write's int length, the write fails or is short, or parsing fails.
 */
X509 *parse_pem_cert(char *pem_string, BIO * cert_bio)
{
    KG_LOG("parse_pem_cert()\n");

    if (NULL == pem_string || NULL == cert_bio) {
        KG_LOG("parse_pem_cert invalid argument!\n");
        return NULL;
    }

    size_t certLen = strlen(pem_string);
    /* BIO_write takes an int length; refuse text it cannot represent. */
    if (certLen > (size_t)INT_MAX) {
        KG_LOG("parse_pem_cert input too large!\n");
        return NULL;
    }

    int written = BIO_write(cert_bio, pem_string, (int)certLen);
    /* A failed or partial write would make the parse below see truncated data. */
    if (written < 0 || (size_t)written != certLen) {
        KG_LOG("parse_pem_cert failed to write BIO!\n");
        return NULL;
    }

    X509 *certX509 = PEM_read_bio_X509(cert_bio, NULL, NULL, NULL);

    if (NULL == certX509) {
        KG_LOG("parse_pem_cert failed!\n");
    } else {
        KG_LOG("parse_pem_cert success!\n");
    }

    return certX509;
}

/*
 * Obtain the RSA public key embedded in a PEM-encoded X.509 certificate.
 *
 * pem  NUL-terminated PEM certificate text; it is wrapped in a read-only
 *      memory BIO (BIO_new_mem_buf with length -1).
 *
 * Returns the RSA key from EVP_PKEY_get1_RSA(), owned by the caller and
 * released with RSA_free(). Returns NULL if the BIO cannot be created, the
 * certificate does not parse, it has no public key, or the key is not RSA.
 * The certificate, EVP_PKEY and BIO are released on every path.
 */
RSA* PEM_read_RSA_pub_key(const char* pem)
{
    RSA *result = NULL;
    EVP_PKEY *public_key = NULL;
    X509 *x509 = NULL;
    BIO *pem_bio = BIO_new_mem_buf(pem, -1);

    /* Each stage runs only if every earlier stage succeeded. */
    if (pem_bio == NULL) {
        KG_LOG("Failed to create cert BIO\n");
    } else if ((x509 = PEM_read_bio_X509(pem_bio, NULL, NULL, NULL)) == NULL) {
        KG_LOG("Error loading x509 cert into memory\n");
    } else if ((public_key = X509_get_pubkey(x509)) == NULL) {
        KG_LOG("Error loading pkey into memory\n");
    } else if ((result = EVP_PKEY_get1_RSA(public_key)) == NULL) {
        KG_LOG("Failed to create RSA key\n");
    }

    /* One cleanup path for success and failure; result holds its own reference. */
    if (x509 != NULL) {
        X509_free(x509);
    }
    if (public_key != NULL) {
        EVP_PKEY_free(public_key);
    }
    if (pem_bio != NULL) {
        BIO_free(pem_bio);
    }

    return result;
}

/*
 * Parse a PEM-encoded X.509 certificate and return its public key.
 *
 * pem_cert  NUL-terminated PEM certificate text.
 *
 * Returns the key from X509_get_pubkey(), which the caller must release
 * with EVP_PKEY_free(). Returns NULL if the memory BIO cannot be created,
 * the certificate cannot be parsed, or it carries no usable public key.
 * The temporary BIO and certificate are always released.
 */
EVP_PKEY *get_public_key(const char *pem_cert)
{
    KG_LOG("get_public_key()\n");

    X509 *certreq = NULL;
    EVP_PKEY *pkey = NULL;
    BIO *reqbio = BIO_new_mem_buf(pem_cert, -1);

    if (NULL == reqbio) {
        KG_LOG("Failed to create cert BIO\n");
        return NULL;
    }

    certreq = PEM_read_bio_X509(reqbio, NULL, NULL, NULL);
    /* The BIO is no longer needed whether or not parsing succeeded;
     * freeing it here fixes the leak on the parse-failure path. */
    BIO_free(reqbio);
    if (NULL == certreq) {
        KG_LOG("Error read bio\n");
        return NULL;
    }

    pkey = X509_get_pubkey(certreq);
    X509_free(certreq);
    return pkey;
}

/*
 * Extract the RSA public modulus and public exponent from a PEM-encoded
 * X.509 certificate as unsigned big-endian byte strings.
 *
 * pem_cert     NUL-terminated PEM certificate text.
 * pub_mod      Output buffer for the modulus, written from index 0.
 * pub_mod_len  In: capacity of pub_mod (must be at least 512).
 *              Out: number of modulus bytes written (256 for a 2048-bit
 *              key, 512 for a 4096-bit key, the exact size otherwise).
 * pub_exp      Output buffer for the exponent.
 * pub_exp_len  In: capacity of pub_exp (must be at least 3).
 *              Out: number of exponent bytes written. An exponent that fits
 *              in 3 bytes is left-padded with zeros to exactly 3 bytes
 *              (65537 -> 01 00 01, 3 -> 00 00 03); wider exponents use
 *              their minimal length.
 *
 * Both buffers are zero-filled before the key material is written.
 * Returns KG_SUCCESS, or KG_CRYPTO_PKEY_PARSE_FAIL on invalid arguments, an
 * unparsable certificate, a non-RSA key, or a key component that does not
 * fit the caller's buffer (the certificate is untrusted input, so sizes are
 * bound-checked before BN_bn2bin writes anything).
 */
uint32_t extract_public_keybytes(const char *pem_cert, uint8_t *pub_mod, 
    uint32_t *pub_mod_len, uint8_t *pub_exp, uint32_t *pub_exp_len)
{
    uint32_t ret = KG_CRYPTO_PKEY_PARSE_FAIL;
    EVP_PKEY *pkey = NULL;
    RSA *rsakey = NULL;
    const BIGNUM *n = NULL;
    const BIGNUM *e = NULL;
    int mod_bytes;
    int exp_bytes;
    int exp_out_len;

    /* Length pointers are checked before they are dereferenced. */
    if (NULL == pub_mod || NULL == pub_exp ||
        NULL == pub_mod_len || NULL == pub_exp_len ||
        *pub_mod_len < 512 || *pub_exp_len < 3) {
        KG_LOG("Received invalid result buffer\n");
        goto exit;
    }

    pkey = get_public_key(pem_cert);
    if (NULL == pkey) {
        KG_LOG("Invalid pem cert to extract the public key\n");
        goto exit;
    }

    rsakey = EVP_PKEY_get1_RSA(pkey);
    if (NULL == rsakey) {
        KG_LOG("Failed to create RSA key\n");
        goto exit;
    }

    RSA_get0_key(rsakey, &n, &e, NULL);
    if (NULL == n || NULL == e) {
        KG_LOG("RSA key has no public components\n");
        goto exit;
    }

    mod_bytes = BN_num_bytes(n);
    exp_bytes = BN_num_bytes(e);
    /* Keep the historical fixed 3-byte exponent width for small exponents. */
    exp_out_len = (exp_bytes < 3) ? 3 : exp_bytes;

    if (mod_bytes <= 0 || exp_bytes <= 0 ||
        (uint32_t)mod_bytes > *pub_mod_len ||
        (uint32_t)exp_out_len > *pub_exp_len) {
        KG_LOG("Public key component does not fit result buffer\n");
        goto exit;
    }

    TEE_MemFill(pub_mod, 0x0, *pub_mod_len);
    TEE_MemFill(pub_exp, 0x0, *pub_exp_len);

    BN_bn2bin(n, pub_mod);
    /* pub_exp is already zeroed, so writing at this offset left-pads e
     * and keeps its big-endian value intact. */
    BN_bn2bin(e, pub_exp + (exp_out_len - exp_bytes));

    *pub_mod_len = (uint32_t)mod_bytes;
    *pub_exp_len = (uint32_t)exp_out_len;

    ret = KG_SUCCESS;
exit:
    /* Both free functions accept NULL. rsakey holds its own reference
     * (get1), so the release order does not matter. */
    EVP_PKEY_free(pkey);
    RSA_free(rsakey);
    return ret;
}

