/**
* \file keyManager.c
* \brief Key management functions.
* \author Dmytro Podgornyi (d.podgornyi@samsung.com)
* \version 0.1
* \date Created May 28, 2013
* \par In Samsung Ukraine R&D Center (SURC) under a contract between
* \par LLC "Samsung Electronics Ukraine Company" (Kiev, Ukraine) and
* \par "Samsung Electronics Co", Ltd (Seoul, Republic of Korea)
* \par Copyright: (c) Samsung Electronics Co, Ltd 2012. All rights reserved.
**/

#include <stdint.h>
#include "CommLayerData.h"
#include "commonConfig.h"
#include "cryptoPlatform.h"
#include "secMemoryManager.h"
#include "teeCryptoApi.h"
#include "log.h"
#include "keyManager.h"
#include "asn1gen.h"
#include "sha/sha.h"
#include "rsa/rsa.h"
#include "asn1rsa.h"
#include "asn1ec.h"
#include "asn1build_ec.h"
#include "asn1build_rsa.h"
#include "objects/obj_mac.h"

typedef int32_t (*asn1_build_keypair_func)(uint8_t*, uint32_t*, const crypto_t);
typedef int32_t (*asn1_build_pub_func)(const crypto_t, uint8_t*, uint32_t*);

typedef int32_t (*asn1_parse_pub_func)(const uint8_t *, size_t, uint8_t **, uint32_t*, 
                                                                uint8_t **, uint32_t*);
typedef int32_t (*asn1_parse_keypair)(const uint8_t *, size_t, uint8_t **, uint32_t*, 
                                        uint8_t **, uint32_t*, uint8_t **, uint32_t*);


/*
 * Per-algorithm table of ASN.1 build/parse callbacks, selected by key type
 * through ASN1_get_api().
 *
 * The struct tag deliberately has no leading double underscore: identifiers
 * containing "__" are reserved for the implementation (C11 7.1.3).
 */
typedef struct asn1_key_handlers {
    asn1_build_keypair_func     build_keypair;  /* encode the key pair */
    asn1_build_pub_func         build_public;   /* encode the public key */
    asn1_parse_pub_func         parse_public;   /* decode public key components */
    asn1_parse_keypair          parse_keypair;  /* decode key pair components */
} ASN1_API;


/*
 * Concrete layout behind the KEY handle used throughout this file
 * (presumably `typedef struct key_t KEY;` in keyManager.h — confirm).
 * Created by KEY_new(), released by KEY_free().
 */
struct key_t {
    uint32_t type;      /* key type selector, e.g. RSA_KEY or ECC_KEY */
    crypto_t key;       /* algorithm-specific key object returned by crypto.new() */
    CRYPTO_API crypto;  /* crypto callbacks, copied from CRYPTO_get_api(type) */
    ASN1_API asn1;      /* ASN.1 callbacks, copied from ASN1_get_api(type) */
};


static ASN1_API asn1_rsa = {
    asn1_build_keypair_rsa,
    asn1_build_pub_rsa,
    rsa_public_key_parse,
    rsa_keypair_parse
};

static ASN1_API asn1_ec = {
    asn1_build_pri_ec,
    asn1_build_pub_ec,
    ec_public_key_parse,
    ec_keypair_parse
};

/*
 * Return (by value) the ASN.1 callback table for a key type:
 * asn1_rsa for RSA_KEY, asn1_ec for ECC_KEY. Any other type is logged and
 * falls back to the RSA table.
 */
ASN1_API ASN1_get_api(uint32_t type)
{
    if (type == RSA_KEY)
        return asn1_rsa;
    if (type == ECC_KEY)
        return asn1_ec;

    /* `type` is unsigned: print it with %u (and an explicit cast so the
     * argument matches the conversion regardless of how uint32_t is typedef'd). */
    LOGE("Unknown key type: %u", (unsigned int)type);
    return asn1_rsa; // default
}

/*
 * Allocate a zeroed KEY of the given type, attach the matching crypto and
 * ASN.1 callback tables, and create the algorithm-specific key object via
 * crypto.new(). Returns NULL (after logging and freeing the container) if the
 * container cannot be allocated, the type has no allocator callback, or the
 * allocator fails.
 */
KEY* KEY_new(uint32_t type)
{
    KEY *k = (KEY *)secMemoryManagerMalloc(sizeof *k);

    if (k == NULL)
    {
        LOGE("KEY_new: bad pointer allocated for key");
        return NULL;
    }

    memset(k, 0, sizeof *k);
    k->type   = type;
    k->crypto = CRYPTO_get_api(type);
    k->asn1   = ASN1_get_api(type);

    if (k->crypto.new == NULL)
    {
        LOGE("KEY_new: bad pointer to key allocator function");
        goto fail;
    }

    k->key = k->crypto.new();
    if (k->key == NULL)
    {
        LOGE("KEY_new: bad pointer allocated for key data structure");
        goto fail;
    }
    return k;

fail:
    secMemoryManagerFree(k);
    return NULL;
}

/*
 * Release a KEY created by KEY_new(): the algorithm-specific key object via
 * crypto.free (when the callback exists), then the KEY container itself.
 * A NULL key is ignored.
 */
void KEY_free(KEY* key)
{
    if (!key)
        return;
    /* A missing free callback means the inner object can't be released here,
     * but that is no reason to leak the container as well. */
    if (key->crypto.free)
        key->crypto.free(key->key);
    secMemoryManagerFree(key);
}

int32_t KEY_signature_size(const KEY* key)
{
    if (!key || !key->crypto.size)
        return 0;
    return key->crypto.size(key->key);
}

/*
 * Generate new key material into `key` according to `info`.
 * Returns WRONG_DATA if `key`, its generate callback, or `info` is missing;
 * NOT_ERROR if the callback returns non-zero; RSA_GEN_ERROR otherwise
 * (for every key type).
 */
int32_t KEY_generate_key(KEY* key, const struct KeyGenInfo *info)
{
    if (key == NULL || key->crypto.generate == NULL || info == NULL)
    {
        return WRONG_DATA;
    }

    if (key->crypto.generate(key->key, info))
    {
        return NOT_ERROR;
    }
    return RSA_GEN_ERROR;
}

static int32_t KEY_parse_public(const KEY *key, const uint8_t *buf, size_t len,
                          uint8_t **part1, uint32_t *part1len, 
                          uint8_t **part2, uint32_t *part2len)
{
    if (!key || !key->asn1.parse_public)
        return 0;
    return key->asn1.parse_public(buf, len, part1, part1len, part2, part2len);
}

static int32_t KEY_parse_keypair(const KEY *key, const uint8_t *buf, size_t len,
                          uint8_t ** privkey, uint32_t * privkeylen, 
                          uint8_t ** pubkey1, uint32_t * pubkey1len, 
                          uint8_t ** pubkey2, uint32_t * pubkey2len)
{
    if (!key || !key->asn1.parse_keypair)
        return 0;
    return key->asn1.parse_keypair(buf, len, privkey, privkeylen, pubkey1, pubkey1len, pubkey2, pubkey2len);
}

int32_t KEY_populate_keys(KEY* key, const uint8_t *pubkey, size_t pubkeylen,
                          const uint8_t *prkey, size_t prkeylen)
{
    int32_t res = NOT_ERROR;
    uint8_t *pubkey1 = NULL;
    uint32_t pubkey1len = 0;
    uint8_t *pubkey2 = NULL;
    uint32_t pubkey2len = 0;
    uint8_t* privkey = NULL;
    uint32_t privkeylen = 0;

    struct asn1_hdr hdr = {0};
    const uint8_t *key_end = NULL;
    const uint8_t *pos = NULL;

    if (!key || !key->crypto.populate)
        return WRONG_DATA;
    if (NULL != pubkey)
    {
        // Parse algorithm identifier
        if (asn1_get_next(pubkey, pubkeylen, &hdr) || hdr.class != ASN1_CLASS_UNIVERSAL ||
            hdr.tag != ASN1_TAG_SEQUENCE)
        {
            return WRONG_RSA_CERT;
        }
        key_end = hdr.payload + hdr.length;
        if (asn1_get_next(hdr.payload, hdr.length, &hdr) < 0 || hdr.class != ASN1_CLASS_UNIVERSAL || 
            hdr.tag != ASN1_TAG_SEQUENCE)
        {
            return WRONG_RSA_CERT;
        }
        pos = hdr.payload + hdr.length;
        if (asn1_get_next(hdr.payload, hdr.length, &hdr) < 0 || hdr.class != ASN1_CLASS_UNIVERSAL || 
            hdr.tag != ASN1_TAG_OID)
        {
            return WRONG_RSA_CERT;
        }
        if (asn1_get_next(pos, key_end - pos, &hdr) < 0 || hdr.class != ASN1_CLASS_UNIVERSAL || 
            hdr.tag != ASN1_TAG_BITSTRING)
        {
            return WRONG_RSA_CERT;
        }
        pos = hdr.payload + 1;

        res = KEY_parse_public(key, pos, key_end - pos, &pubkey1, &pubkey1len, 
                                                       &pubkey2, &pubkey2len);
        if (res != NOT_ERROR)
        {
            LOGE("Wrong public key");
            return WRONG_RSA_CERT;
        }
    }
    if (NULL != prkey)
    {
        res = KEY_parse_keypair(key, prkey, prkeylen, &privkey, &privkeylen, NULL, NULL, NULL, NULL);
        if (res != NOT_ERROR)
        {
            LOGE("Wrong private key");
            return WRONG_PRIV_KEY;
        }
    }

    res = key->crypto.populate(&key->key, pubkey1, pubkey1len, pubkey2, pubkey2len,
                               privkey, privkeylen);
    if (res != NOT_ERROR)
    {
        LOGE("Can't populate keys: %d", res);
    }
    return res;
}

/*
 * Load `key` from a single encoded key pair (private component plus two
 * public components), parsed with the key type's ASN.1 key-pair parser.
 *
 * Returns WRONG_DATA for a NULL key, a missing populate callback, or an
 * empty/NULL input; WRONG_RSA_CERT if the key pair cannot be parsed (error
 * code kept as-is for existing callers); otherwise the result of
 * crypto.populate().
 */
int32_t KEY_populate_keypair(KEY* key, const uint8_t *keypair, size_t keypairLen)
{
    int32_t res = NOT_ERROR;
    uint8_t *pubkey1 = NULL;
    uint32_t pubkey1len = 0;
    uint8_t *pubkey2 = NULL;
    uint32_t pubkey2len = 0;
    uint8_t* privkey = NULL;
    uint32_t privkeylen = 0;

    if (!key || !key->crypto.populate)
    {
        LOGE("Bad populate function ptr registered in KEY data");
        return WRONG_DATA;
    }
    if (!keypair || !keypairLen)
    {
        LOGE("Bad ptr to keypair");
        return WRONG_DATA;
    }

    res = KEY_parse_keypair(key, keypair, keypairLen, &privkey, &privkeylen, 
                             &pubkey1, &pubkey1len, &pubkey2, &pubkey2len);
    if (res != NOT_ERROR)
    {
        /* The whole key pair failed to parse, not just the public part. */
        LOGE("Wrong key pair");
        return WRONG_RSA_CERT;
    }

    res = key->crypto.populate(&key->key, pubkey1, pubkey1len, pubkey2, pubkey2len,
                               privkey, privkeylen);
    if (res != NOT_ERROR)
    {
        LOGE("Can't populate keys: %d", res);
    }
    return res;
}

/*
 * Self-test a loaded key. It signs a SHA-256-sized digest with KEY_sign()
 * (NID_sha256) and checks the result with KEY_verify(). The digest is random
 * when getRandBlock() succeeds, otherwise a fixed string.
 * With RUN_FUNC_TESTS defined, RSA keys also get a PKCS#1 public-encrypt /
 * private-decrypt round trip.
 *
 * Returns NOT_ERROR on success; WRONG_DATA if the key reports no usable
 * signature size (including a NULL key); SEC_ALLOC_ERROR if the signature
 * buffer cannot be allocated; otherwise the error from the failing step.
 */
int32_t KEY_check_keypair(KEY* key)
{
    /* SHA256 digest to check */
    uint8_t digest[SHA256_DIGEST_LENGTH] = {0};
    uint8_t* sign = NULL;
    /* Buffer capacity; KEY_sign() overwrites signlen with the produced
     * length, so keep the capacity separately. */
    int32_t sigCapacity = KEY_signature_size(key);
    uint32_t signlen = 0;
    int32_t ret = NOT_ERROR;

    /* KEY_signature_size() yields 0 for a NULL key or a missing size
     * callback; don't allocate a zero/negative-sized buffer in that case. */
    if (sigCapacity <= 0)
    {
        LOGD("Bad signature size (%d)", (int)sigCapacity);
        return WRONG_DATA;
    }
    signlen = (uint32_t)sigCapacity;

    if (getRandBlock(digest, sizeof(digest)) != sizeof(digest))
    {
        /* use hardcoded string if random fails (31 chars + NUL fill the
         * 32-byte digest exactly) */
        strncpy((char *)digest, "1234567890qwertyuiop[]asdfghjkl", sizeof(digest));
    }
    sign = secMemoryManagerMalloc(signlen);
    if ( NULL == sign)
    {
        LOGD("Signature size (%u) allocation error", (unsigned int)signlen);
        return SEC_ALLOC_ERROR;
    }

    memset(sign, 0, signlen);
    
    ret = KEY_sign(key, NID_sha256, digest, SHA256_DIGEST_LENGTH, sign, &signlen);
    if (ret != NOT_ERROR)
    {
        LOGD("KEY_sign() failed, %d", ret);
        goto clean;
    }

    ret = KEY_verify(key, NID_sha256, digest, SHA256_DIGEST_LENGTH, sign, signlen);
    if (ret != NOT_ERROR)
    {
        LOGD("KEY_verify() failed");
        goto clean;
    }
#if defined RUN_FUNC_TESTS // check RSA encrypt/decrypt
    if (key->type == RSA_KEY) 
    {
        const uint8_t text[] = "Text to encrypt & verify";
        uint8_t cipherText[RSA_BIT_SIZE_DEFAULT / 8] = {0};
        uint8_t plainText[RSA_BIT_SIZE_DEFAULT / 8] = {0};

        /* Encryption writes a full modulus-sized block. Refuse keys bigger
         * than the fixed test buffers instead of overflowing them.
         * NOTE(review): assumes crypto.size equals the RSA modulus size for
         * RSA keys — confirm. */
        if ((uint32_t)sigCapacity > sizeof(cipherText))
        {
            LOGE("Key too large for RSA self-test buffers: %d", (int)sigCapacity);
            ret = WRONG_DATA;
            goto clean;
        }

        /* All failures below go through `clean` so `sign` is not leaked. */
        ret = KEY_public_encrypt(key, sizeof(text), text, cipherText, RSA_PKCS1_PADDING);
        if (ret != NOT_ERROR)
        {
            LOGE("KEY_public_encrypt failed: %d", ret);
            goto clean;
        }
        ret = KEY_private_decrypt(key, sizeof(cipherText), cipherText, plainText, RSA_PKCS1_PADDING);
        if (ret != NOT_ERROR)
        {
            LOGE("KEY_private_decrypt failed: %d", ret);
            goto clean;
        }
        if (memcmp(plainText, text, sizeof(text)))
        {
            LOGE("RSA Encrypt/decrypt error");
            ret = WRONG_PRIV_KEY;
            goto clean;
        }
        LOGD("RSA encrypt/decrypt OK...");
    }
#endif  // End of RUN_FUNC_TESTS

    ret = NOT_ERROR;
clean:
    if (NULL != sign)
        secMemoryManagerFree(sign);
    return ret;
}

int32_t KEY_build_keypair(const KEY *key, uint8_t *out, uint32_t *outLen)
{
    if (!key || !key->asn1.build_keypair)
        return WRONG_DATA;
    return key->asn1.build_keypair(out, outLen, key->key);
}

int32_t KEY_build_public(const KEY *key, uint8_t *out, uint32_t *outLen)
{
    if (!key || !key->asn1.build_public)
        return WRONG_DATA;
    return key->asn1.build_public(key->key, out, outLen);
}

/*
 * Sign `digest` (identified by `digestId`, e.g. NID_sha256) with `key`.
 * signatureLen is passed through to the crypto callback; KEY_check_keypair
 * passes the buffer capacity and then uses the value as the signature length.
 * Returns WRONG_DATA if `key` or its sign callback is missing, NOT_ERROR if
 * the callback returns non-zero, CERT_SIGN_ERROR otherwise.
 */
int32_t KEY_sign(KEY *key, int32_t digestId, uint8_t *digest, uint32_t digestLen, 
                 uint8_t *signature, uint32_t *signatureLen)
{
    if (key == NULL || key->crypto.sign == NULL)
    {
        return WRONG_DATA;
    }

    if (!key->crypto.sign(digestId, digest, digestLen, signature, signatureLen, key->key))
    {
        return CERT_SIGN_ERROR;
    }
    return NOT_ERROR;
}

/*
 * Verify `signature` over `digest` with `key` and map the crypto callback's
 * result to module error codes:
 *   1  -> NOT_ERROR
 *   0  -> SIGNATURE_INVALID_ERROR
 *  -1  -> EC_GEN_ERROR (logged at debug level)
 *  any other value -> EC_GEN_ERROR (logged as unknown)
 * Returns WRONG_DATA if `key` or its verify callback is missing.
 */
int32_t KEY_verify(KEY* key, int32_t digestId, uint8_t *digest, uint32_t digestLen,
                uint8_t *signature, uint32_t signatureLen)
{
    int32_t verdict = 0;

    if (key == NULL || key->crypto.verify == NULL)
        return WRONG_DATA;

    verdict = key->crypto.verify(digestId, digest, digestLen,
                                 signature, signatureLen, key->key);
    if (verdict == 1)
        return NOT_ERROR;

    if (verdict == 0)
    {
        LOGD("KEY_verify: signature invalid");
        return SIGNATURE_INVALID_ERROR;
    }

    if (verdict == -1)
        LOGD("ECDSA_verify error");
    else
        LOGE("Unknown verify result");
    return EC_GEN_ERROR;
}

int32_t KEY_public_encrypt(const KEY* key, int32_t len, const uint8_t *from, uint8_t *to, int32_t padding)
{
    if (!key || !key->crypto.public_encrypt)
        return WRONG_DATA;
    return key->crypto.public_encrypt(key->key, len, from, to, padding) != -1 ? NOT_ERROR : RSA_GEN_ERROR;
}
int32_t KEY_private_decrypt(const KEY* key, int32_t len, const uint8_t *from, uint8_t *to, int32_t padding)
{
    if (!key || !key->crypto.private_decrypt)
        return WRONG_DATA;
    return key->crypto.private_decrypt(key->key, len, from, to, padding) != -1 ? NOT_ERROR : RSA_GEN_ERROR;
}
