/*
 * =====================================================================================
 *
 *       Filename:  certParser.c
 *
 *    Description:  X.509 certificate parser.
 *
 *        Version:  1.0
 *        Created:  12/05/2017 10:35:26 AM
 *       Compiler:  armcc
 *
 *         Author:  Dongwook Shim (), dw.shim@samsung.com
 *        Company:  Samsung Electronics
 *
 *        Copyright (c) 2017 by Samsung Electronics, All rights reserved.
 *
 * =====================================================================================
 */

#include <stdint.h>
#include "commonConfig.h"
#include "certParser.h"
#include "keyManager.h"
#include "log.h"
#include "teeCryptoApi.h"
#include "x509v3.h"
#include "objects/obj_mac.h"
#include "sha/sha.h"

#include "caPubKey.h"
#include "hsmPubKey.h"

/* Issuer DN expected in a DRK certificate, compared against x509_name_string() output. */
#define DRK_ISSUER              "C=KR, L=Suwon city, OU=Samsung Mobile, CN=Samsung corporation"

/*
 * DRK v1 subject UID layout, fields separated by DRK_V1_DELIMETER:
 *
 *     <keyclass>:<date>:<server>:<client>:<serial>:<suffix>
 *
 * keyclass is DRK_V1_KEYCLASS_P or DRK_V1_KEYCLASS_D, the four middle fields
 * have the fixed lengths below and suffix is DRK_V1_SUFFIX.
 * DRK_V1_MAX_LEN is the total length of a UID in this layout:
 * 5 + 1 + 8 + 1 + 2 + 1 + 2 + 1 + 8 + 1 + 4 = 34.
 */
#define DRK_V1_MAX_LEN          34
#define DRK_V1_KEYCLASS_P       "PHN-P"
#define DRK_V1_KEYCLASS_D       "PHN-D"
#define DRK_V1_DATE_LEN         8
#define DRK_V1_SERVER_LEN       2
#define DRK_V1_CLIENT_LEN       2
#define DRK_V1_SERIAL_LEN       8
#define DRK_V1_SUFFIX           "ROOT"
#define DRK_V1_DELIMETER        ":"

/* Extension flags checked against exts.extensions_present for DRK v1 certificates. */
#define DRK_V1_EXTS_PRESENT     (X509_EXT_BASIC_CONSTRAINTS | X509_EXT_PATH_LEN_CONSTRAINT | \
                                X509_EXT_KEY_USAGE | X509_EXT_SUBJECT_KEY_IDENTIFIER | \
                                X509_EXT_AUTHORITY_KEY_IDENTIFIER)

/*
 * Key usage bits checked for DRK v1 certificates.
 * NOTE(review): the operand list previously repeated X509_KEY_USAGE_CRL_SIGN
 * twice, which looks like a typo for keyCertSign | cRLSign. The value is kept
 * as cRLSign only so validation behaves exactly as before — confirm the
 * intended bits before adding keyCertSign.
 */
#define DRK_V1_EXTS_KEY_USAGE   (X509_KEY_USAGE_CRL_SIGN)

/*
 * Minimal strtok() replacement used to split DRK UID fields.
 *
 * Differences from the C library strtok():
 *  - leading/consecutive delimiters are NOT skipped, so "a::b" yields
 *    "a", "", "b" (callers can detect empty fields by their length);
 *  - a delimiter at the very end does not produce a trailing empty token.
 *
 * Like strtok(), it writes '\0' over the delimiter it stops at and keeps its
 * position in a static pointer, so it is neither reentrant nor thread-safe.
 *
 * str  - string to start tokenizing, or NULL to continue the previous one.
 * deli - set of delimiter characters; NULL makes the call return NULL.
 *
 * Returns the next token, or NULL once the string is exhausted.
 */
static char *_strtok(char *str, const char *deli)
{
    static char *p = NULL;
    char *s, *ret;
    const char *d;

    if(deli == NULL)
        return NULL;

    if(str)
        p = str;
    else if(p == NULL || *p == '\0')
        return NULL;

    ret = p;

    for(s = p; *s != '\0'; s++)
    {
        for(d = deli; *d != '\0'; d++)
        {
            if(*s == *d)
            {
                /* Terminate the token and resume after the delimiter next time. */
                *s = '\0';
                p = s + 1;

                return ret;
            }
        }
    }

    /* No delimiter left: the rest of the string is the last token. */
    p = s;

    return ret;
}
/*
 * Map the certificate's TBS signature algorithm OID (cert->signature) to the
 * digest NID used to hash the TBSCertificate.
 *
 * Returns NID_sha256 for sha256WithRSAEncryption or ecdsa-with-SHA256,
 * NID_sha1 for sha1WithRSAEncryption, ERR_TA_INVALID_ARGUMENT for a NULL
 * cert, and ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_ALGORITHM_IDENTIFIER
 * for any other OID.
 *
 * NOTE(review): only the first ARRAY_SIZE(xxxOID) arcs of the parsed OID are
 * compared, so a longer OID that starts with a known one also matches.
 * Confirm whether the parsed OID length should be checked too.
 */
static int getCertSignHash(const struct x509_certificate *cert)
{
    const unsigned long sha1WithRSAEncryptionOID[] = {OBJ_sha1WithRSAEncryption};
    const unsigned long sha256WithRSAEncryptionOID[] = {OBJ_sha256WithRSAEncryption};
    const unsigned long ecdsaWithSHA256EncryptionOID[] = {OBJ_ecdsa_with_SHA256};
    const unsigned long *oid = NULL;

    if(cert == NULL)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    oid = cert->signature.oid.oid;

    /* The memcmp() calls below read that many arcs from the parsed OID array,
     * which is presumably sized ASN1_MAX_OID_LEN. */
    if(ARRAY_SIZE(sha1WithRSAEncryptionOID) > ASN1_MAX_OID_LEN ||
            ARRAY_SIZE(sha256WithRSAEncryptionOID) > ASN1_MAX_OID_LEN ||
            ARRAY_SIZE(ecdsaWithSHA256EncryptionOID) > ASN1_MAX_OID_LEN)
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_ALGORITHM_IDENTIFIER;

    if(memcmp(sha256WithRSAEncryptionOID, oid, sizeof(sha256WithRSAEncryptionOID)) == 0 ||
            memcmp(ecdsaWithSHA256EncryptionOID, oid, sizeof(ecdsaWithSHA256EncryptionOID)) == 0)
        return NID_sha256;

    if(memcmp(sha1WithRSAEncryptionOID, oid, sizeof(sha1WithRSAEncryptionOID)) == 0)
        return NID_sha1;

    /* OID arcs are unsigned long, so print with %lu (was %ld). */
    LOGE("Invalid Oid : %lu", *oid);

    return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_ALGORITHM_IDENTIFIER;
}

/*
 * Copy the value of the first attribute of a given type from an X.509 name.
 *
 * name    - parsed X.509 name to search.
 * type    - attribute type to look for.
 * data    - output buffer; zero-filled before the value is copied in.
 * dataLen - size of data in bytes.
 *
 * The value is copied as raw bytes. data ends up NUL-terminated only when
 * dataLen is larger than the value length; a value that exactly fills the
 * buffer is accepted without a terminator.
 *
 * Returns NOT_ERROR on success, ERR_TA_INVALID_ARGUMENT for NULL pointers,
 * ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_INVALID_ARGUMENTS when no
 * attribute of that type exists, or ERR_TA_BUFFER_OVERFLOW when the value is
 * longer than dataLen.
 */
int32_t getX509NameAttr(const struct x509_name *name, enum x509_name_attr_type type, char *data, size_t dataLen)
{
    size_t index, len = 0;

    if(name == NULL || data == NULL)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    for(index = 0; index < name->num_attr; index++)
    {
        if(name->attr[index].type == type)
            break;
    }

    /* The loop leaves index == num_attr when nothing matched; accessing
     * attr[num_attr] would be out of range, so ">=" is required here. */
    if(index >= name->num_attr)
    {
        LOGE("Failed to find type %d in name.", type);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_INVALID_ARGUMENTS;
    }

    len = name->attr[index].value_size;

    if(dataLen < len)
    {
        LOGE("Data buffer is too small - %zu %zu", dataLen, len);
        return ERR_TA_BUFFER_OVERFLOW;
    }

    memset(data, 0, dataLen);
    memcpy(data, name->attr[index].value, len);

    return NOT_ERROR;
}

/*
 * Reject a February date past the 28th in a common (non-leap) year.
 *
 * Only that one case is checked; any other date is reported as valid here.
 *
 * Returns false for a NULL time or for Feb 29+ in a common year, true otherwise.
 */
static bool isValidateLeapYear(struct x509_parser_time *time)
{
    bool leapYear;

    if(time == NULL)
    {
        LOGE("%s : Invalid argument.", __func__);
        return false;
    }

    /* Gregorian rule: divisible by 4, except centuries not divisible by 400. */
    leapYear = (time->year % 4 == 0) && (time->year % 100 != 0 || time->year % 400 == 0);

    return leapYear || time->mon != 2 || time->mday <= 28;
}

/*
 * Locate the public key bytes inside a DER SubjectPublicKeyInfo:
 *
 *   SubjectPublicKeyInfo ::= SEQUENCE {
 *       algorithm         SEQUENCE { OID, parameters },
 *       subjectPublicKey  BIT STRING }
 *
 * pDerPubKey   - DER-encoded SubjectPublicKeyInfo.
 * derPubKeyLen - length of pDerPubKey in bytes.
 * pPubKey      - receives a pointer into pDerPubKey at the BIT STRING
 *                contents, just past the leading unused-bits byte.
 *
 * Only the tag of the algorithm OID is checked, not its value.
 *
 * Returns the number of key bytes at *pPubKey (BIT STRING length minus the
 * unused-bits byte), or an ERR_TA_* code on failure.
 */
static int32_t getRsaPubKey(const uint8_t *pDerPubKey, size_t derPubKeyLen, uint8_t **pPubKey)
{
    struct asn1_hdr hdr;
    const uint8_t *pos = NULL, *end = NULL;

    if(pDerPubKey == NULL || pPubKey == NULL)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    memset(&hdr, 0, sizeof(hdr));

    /* Outer SubjectPublicKeyInfo SEQUENCE. */
    if(asn1_get_next(pDerPubKey, derPubKeyLen, &hdr) < 0 ||
            hdr.class != ASN1_CLASS_UNIVERSAL ||
            hdr.tag != ASN1_TAG_SEQUENCE)
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_INVALID_SEQUENCE;

    end = hdr.payload + hdr.length;

    /* AlgorithmIdentifier SEQUENCE; pos is set to the element after it. */
    if(asn1_get_next(hdr.payload, hdr.length, &hdr) < 0 ||
            hdr.class != ASN1_CLASS_UNIVERSAL ||
            hdr.tag != ASN1_TAG_SEQUENCE)
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_INVALID_SEQUENCE;

    pos = hdr.payload + hdr.length;

    /* Algorithm OID (tag only). */
    if(asn1_get_next(hdr.payload, hdr.length, &hdr) < 0 ||
            hdr.class != ASN1_CLASS_UNIVERSAL ||
            hdr.tag != ASN1_TAG_OID)
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_INVALID_OID;

    /* subjectPublicKey BIT STRING. */
    if(asn1_get_next(pos, end - pos, &hdr) < 0 ||
            hdr.class != ASN1_CLASS_UNIVERSAL ||
            hdr.tag != ASN1_TAG_BITSTRING)
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXPECTED_BITSTRING;

    /* The first content byte is the unused-bits count. An empty BIT STRING
     * would make "length - 1" wrap, and a DER-encoded key is a whole number
     * of bytes, so the unused-bits count must be 0. */
    if(hdr.length < 1 || hdr.payload[0] != 0)
    {
        LOGE("Invalid public key BIT STRING.");
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXPECTED_BITSTRING;
    }

    *pPubKey = (uint8_t *)(hdr.payload + 1);
    return (int32_t)(hdr.length - 1);
}

/*
 * Verify the signature of a parsed certificate with the given CA key.
 *
 * The TBSCertificate is hashed by getCertHash() and the digest is checked
 * against cert->sign_value with KEY_verify().
 *
 * Returns NOT_ERROR on success, ERR_TA_INVALID_ARGUMENT for NULL arguments,
 * or the error code reported by getCertHash() / KEY_verify().
 */
static int32_t verifyCertSign(const struct x509_certificate* cert, KEY* ca)
{
    uint8_t hash[SHA256_DIGEST_LENGTH] = {0};   /* big enough for SHA-1 and SHA-256 */
    uint32_t hashLen = sizeof(hash);
    int32_t hashAlgo = 0;
    int32_t rc;

    if(cert == NULL || ca == NULL)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    rc = getCertHash(cert, &hashAlgo, hash, &hashLen);
    if(rc != NOT_ERROR)
    {
        LOGE("Hash digest calculating error with error %d.", rc);
        return rc;
    }

    /* The cast drops const from the signature buffer for KEY_verify(). */
    rc = KEY_verify(ca, hashAlgo, hash, hashLen, (uint8_t*)cert->sign_value, cert->sign_value_len);
    if(rc != NOT_ERROR)
        LOGE("Failed to certificate signature with error %d.", rc);

    return rc;
}

/*
 * Parse a DER certificate and verify its signature with the given CA key.
 *
 * buf - DER-encoded certificate.
 * len - length of buf in bytes.
 * ca  - CA public key used for verification.
 *
 * Returns NOT_ERROR on success, ERR_TA_INVALID_ARGUMENT for a NULL/empty
 * buffer, ERR_TA_X509_PARSE_BASE + <parser error> when parsing fails, or the
 * error from verifyCertSign().
 */
int32_t checkCertificate(const uint8_t *buf, uint32_t len, KEY* ca)
{
    struct x509_certificate cert;
    x509_parser_error_t x509Ret = X509_PARSE_ERROR_UNDEFINED;
    int32_t ret = NOT_ERROR;

    if(buf == NULL || len == 0)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    /* Start from a zeroed structure, as the other parse sites in this file do,
     * so fields the parser does not set are not stack garbage. */
    memset(&cert, 0, sizeof(cert));

    if((x509Ret = x509_certificate_parse(buf, len, &cert)) != X509_PARSE_OK)
    {
        ret = ERR_TA_X509_PARSE_BASE + x509Ret;
        LOGE("Cannot parse certificate. Error code: %d", ret);
        return ret;
    }

    if((ret = verifyCertSign(&cert, ca)) != NOT_ERROR)
        LOGE("Failed to certificate verification with error %d.", ret);

    return ret;
}

/*
 * Create a KEY object for one of the built-in public keys.
 *
 * ca      - in/out; if *ca is already non-NULL nothing is done.
 * keyType - RSA_KEY or ECC_KEY selects the CA key of that type;
 *           RSA_ENC_KEY selects hsmRsaPubKey, which is loaded as an RSA_KEY.
 *
 * On failure *ca is left NULL. Returns NOT_ERROR on success,
 * ERR_TA_INVALID_ARGUMENT for a NULL ca or unknown keyType,
 * ERR_TA_NOT_ENOUGH_MEMORY if KEY_new() fails, or the KEY_populate_keys()
 * error.
 */
int32_t getCAPublicKey(KEY** ca, uint32_t keyType)
{
    const uint8_t *keyDer = NULL;
    uint32_t keyDerLen = 0;
    uint32_t allocType = keyType;
    int32_t rc;

    if(ca == NULL)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    if(*ca != NULL)
    {
        LOGI("CA crypto context have been already initialized...");
        return NOT_ERROR;
    }

    if(keyType == RSA_KEY)
    {
        keyDer = CaRsaPubKey;
        keyDerLen = sizeof(CaRsaPubKey);
    }
    else if(keyType == ECC_KEY)
    {
        keyDer = CaEcPubKey;
        keyDerLen = sizeof(CaEcPubKey);
    }
    else if(keyType == RSA_ENC_KEY)
    {
        /* The HSM key is created as a plain RSA key. */
        keyDer = hsmRsaPubKey;
        keyDerLen = sizeof(hsmRsaPubKey);
        allocType = RSA_KEY;
    }
    else
    {
        LOGE("Unknown key type: 0x%x", keyType);
        return ERR_TA_INVALID_ARGUMENT;
    }

    *ca = KEY_new(allocType);
    if(*ca == NULL)
    {
        LOGE("Failed to allocat new key.");
        return ERR_TA_NOT_ENOUGH_MEMORY;
    }

    rc = KEY_populate_keys(*ca, keyDer, keyDerLen, NULL, 0);
    if(rc != NOT_ERROR)
    {
        /* Do not hand back a half-initialized key. */
        KEY_free(*ca);
        *ca = NULL;
        LOGE("Failed to key popluation with errot %d", rc);
    }

    return rc;
}

/*
 * Hash the TBSCertificate of a parsed certificate with the digest implied
 * by its signature algorithm (as reported by getCertSignHash()).
 *
 * cert      - parsed certificate.
 * algo      - receives the digest NID (NID_sha256 / NID_sha1); on an
 *             unsupported algorithm it holds getCertSignHash()'s result.
 * digest    - output buffer for the digest.
 * digestLen - in: size of digest; out: digest length actually produced.
 *
 * Returns the getShaDigest() result, ERR_TA_INVALID_ARGUMENT for NULL
 * arguments or an unsupported algorithm, or ERR_TA_BUFFER_OVERFLOW when
 * *digestLen is too small.
 */
int32_t getCertHash(const struct x509_certificate* cert, int32_t* algo, uint8_t* digest, uint32_t* digestLen)
{
    /* Unsigned, to compare cleanly with *digestLen. */
    uint32_t requiredLen = 0;
    DigestAlgo_t digestAlgo;

    if (!cert || !digest || !digestLen || !algo)
    {
        LOGE("%s : Invalid argument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    *algo = getCertSignHash(cert);

    switch (*algo)
    {
        case NID_sha256:
            requiredLen = SHA256_DIGEST_LENGTH;
            digestAlgo = ALGO_SHA256;
            break;

        case NID_sha1:
            requiredLen = SHA_DIGEST_LENGTH;
            digestAlgo = ALGO_SHA1;
            break;

        default:
            LOGE("getCertHash: unknown signature algorithm");
            return ERR_TA_INVALID_ARGUMENT;
    }

    if (*digestLen < requiredLen)
    {
        LOGE("getCertHash: output buffer is too small");
        return ERR_TA_BUFFER_OVERFLOW;
    }

    *digestLen = requiredLen;

    return getShaDigest(cert->tbs_cert_start, cert->tbs_cert_len, digest, requiredLen, digestAlgo);
}

/*
 * Classify the certificate's subject public key algorithm.
 *
 * Returns ECC_KEY for id-ecPublicKey and RSA_KEY for rsaEncryption. RSA_KEY
 * is also returned (with an error log) for a NULL cert or an unrecognized
 * algorithm. The OIDs are matched on their leading arcs only.
 */
uint8_t getCertPublicKeyType(const struct x509_certificate* cert)
{
    const unsigned long rsaPublicKeyOID[] = {OBJ_rsaEncryption};
    const unsigned long ecPublicKeyOID[] = {OBJ_X9_62_id_ecPublicKey};
    const unsigned long *algOid;

    if(cert == NULL)
    {
        LOGE("%s : Invalid argument. Return default.", __func__);
        return RSA_KEY;
    }

    algOid = cert->public_key_alg.oid.oid;

    if(memcmp(algOid, rsaPublicKeyOID, sizeof(rsaPublicKeyOID)) == 0)
        return RSA_KEY;

    if(memcmp(algOid, ecPublicKeyOID, sizeof(ecPublicKeyOID)) == 0)
        return ECC_KEY;

    LOGE("Unknown public key algorithm. Default RSA");
    return RSA_KEY;
}

int32_t getCertUID(uint8_t* cert, int16_t certLen, uint8_t* uid, uint32_t uidLen)
{
    struct x509_certificate x509Cert;
    x509_parser_error_t ret = X509_PARSE_ERROR_UNDEFINED;

    memset(&x509Cert, 0, sizeof(x509Cert));

    if((ret = x509_certificate_parse(cert, certLen, &x509Cert)) != X509_PARSE_OK)
    {
        LOGE( "Cannot parse certificate. Error code: %d", ERR_TA_X509_PARSE_BASE + ret);
        return ERR_TA_X509_PARSE_BASE + ret;
    }

    return getX509NameAttr(&(x509Cert.subject), X509_NAME_ATTR_UID, (char *)uid, uidLen);
}

int32_t verifyCertificateWithCA(const uint8_t *derCert, uint16_t derCertLen, struct x509_certificate *x509Cert)
{
    int32_t ret = NOT_ERROR;
    struct x509_certificate x509;
    KEY* caKey = NULL;

    if(derCert == NULL || derCertLen == 0)
    {
        LOGE("%s : Invalid agrument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    memset(&x509, 0, sizeof(x509));

    if((ret = x509_certificate_parse(derCert, derCertLen, &x509)) != X509_PARSE_OK)
    {
        ret += ERR_TA_X509_PARSE_BASE;
        LOGE("Cannot parse certificate. Error code: %d", ret);
        return ret;
    }

    if((ret = getCAPublicKey(&caKey, getCertPublicKeyType(&x509))) != NOT_ERROR)
    {
        LOGE("Failed to get CA public key with error %d.", ret);
        return ret;
    }

    ret = verifyCertSign(&x509, caKey);
    KEY_free(caKey);

    if(ret != NOT_ERROR)
    {
        LOGE("Verification of certification is failed with error %d.", ret);
        return ERR_TA_VERIFICATION_FAILED;
    }

    if(x509Cert != NULL)
        memcpy(x509Cert, &x509, sizeof(x509));

    return NOT_ERROR;
}

/*
 * Check that a parsed certificate matches the DRK v1 certificate profile.
 *
 * Checks, in order:
 *  - X.509 version is v3;
 *  - signature and signatureAlgorithm OIDs are sha256WithRSAEncryption and
 *    the parser stored no parameters for either;
 *  - notBefore / notAfter are not Feb 29+ in a common year;
 *  - the issuer string equals DRK_ISSUER;
 *  - the subject UID has the "<keyclass>:<date>:<server>:<client>:<serial>:ROOT"
 *    layout (key class and suffix are prefix-matched);
 *  - every DRK_V1_EXTS_PRESENT extension is present, basicConstraints has
 *    cA == 0xFF (DER TRUE) and pathLenConstraint == 0, and key usage
 *    contains every DRK_V1_EXTS_KEY_USAGE bit;
 *  - subjectKeyIdentifier is the SHA-1 of cert->public_key and
 *    authorityKeyIdentifier is the SHA-1 of the embedded CA RSA public key.
 *
 * The certificate signature itself is not verified by this function.
 *
 * Returns NOT_ERROR when every check passes, otherwise an ERR_TA_* code.
 */
int32_t validateDrkCert(struct x509_certificate *cert)
{
    static const size_t lengthField[] = {DRK_V1_DATE_LEN, DRK_V1_SERVER_LEN, DRK_V1_CLIENT_LEN, DRK_V1_SERIAL_LEN};
    const unsigned long sha256WithRsaEncryptionOID[] = {OBJ_sha256WithRSAEncryption};
    int32_t ret = NOT_ERROR, pubKeyLen = 0;
    char issuer[MAX_UID_SIZE], uid[MAX_UID_SIZE], *p = NULL;
    uint8_t md[SHA_DIGEST_LENGTH], *pPubKey = NULL;
    size_t i;

    if(cert == NULL)
    {
        LOGE("%s : Invalid agrument.", __func__);
        return ERR_TA_INVALID_ARGUMENT;
    }

    // Check version.
    if(cert->version != X509_CERT_V3)
    {
        LOGE("Invalid X.509 version : %d", cert->version);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_VERSION_TAG;
    }

    // Check signature / signature_alg OID and parameter.
    if(memcmp(cert->signature.oid.oid, sha256WithRsaEncryptionOID, sizeof(sha256WithRsaEncryptionOID)) ||
            memcmp(cert->signature_alg.oid.oid, sha256WithRsaEncryptionOID, sizeof(sha256WithRsaEncryptionOID)))
    {
        LOGE("Invalid signature oid.");
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_INVALID_OID;
    }

    if(cert->signature.param.length != 0 || cert->signature.param.data != NULL ||
            cert->signature_alg.param.length != 0 || cert->signature_alg.param.data != NULL)
    {
        LOGE("Signature Identifier parameter value is invalid.");
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_ALGORITHM_IDENTIFIER;
    }

    // Check validity - no Feb 29+ in a common year.
    if(!isValidateLeapYear(&(cert->not_before)) || !isValidateLeapYear(&(cert->not_after)))
    {
        LOGE("Failed to verify leap year.");
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_VALIDITY;
    }

    // Check issuer.
    memset(issuer, 0, sizeof(issuer));
    x509_name_string(&(cert->issuer), issuer, sizeof(issuer));

    if(strlen(DRK_ISSUER) != strlen(issuer) || memcmp(issuer, DRK_ISSUER, strlen(DRK_ISSUER)))
    {
        LOGE("Invalid issuer : %zu %s", strlen(issuer), issuer);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_NAME;
    }

    // Check UID. Pass sizeof(uid) - 1 so the last byte keeps the zero from the
    // memset: getX509NameAttr() accepts a value that exactly fills the buffer
    // without adding a terminator, and _strtok()/strlen() below need one.
    memset(uid, 0, sizeof(uid));
    if((ret = getX509NameAttr(&(cert->subject), X509_NAME_ATTR_UID, uid, sizeof(uid) - 1)) != NOT_ERROR)
        return ret;

    p = _strtok(uid, DRK_V1_DELIMETER);

    // NOTE(review): key class and suffix are prefix matches (strncmp with the
    // expected length), so e.g. "PHN-PX" or "ROOTX" also pass — confirm whether
    // an exact match is intended.
    if(strncmp(p, DRK_V1_KEYCLASS_P, strlen(DRK_V1_KEYCLASS_P)) != 0 &&
            strncmp(p, DRK_V1_KEYCLASS_D, strlen(DRK_V1_KEYCLASS_D)) != 0)
    {
        LOGE("DRK V1 Keyclass is invalid - %s.", p);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_NAME;
    }

    for(i = 0; i < ARRAY_SIZE(lengthField); i++)
    {
        if((p = _strtok(NULL, DRK_V1_DELIMETER)) == NULL || strlen(p) != lengthField[i])
        {
            LOGE("Field(%zu) length is invalid - %zu %zu.", i, lengthField[i], (p == NULL) ? (size_t)0 : strlen(p));
            return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_NAME;
        }
    }

    if((p = _strtok(NULL, DRK_V1_DELIMETER)) == NULL || strncmp(p, DRK_V1_SUFFIX, strlen(DRK_V1_SUFFIX)) != 0)
    {
        // Passing NULL to %s is undefined behaviour.
        LOGE("Suffix field is invalid - %s.", (p == NULL) ? "(null)" : p);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_NAME;
    }

    // Check extension fields: ALL of DRK_V1_EXTS_PRESENT must be present.
    // (A plain "& mask" test would pass as soon as any single one was present.)
    if((cert->exts.extensions_present & DRK_V1_EXTS_PRESENT) != DRK_V1_EXTS_PRESENT)
    {
        LOGE("Extension fields are invalid - %lX", (unsigned long)cert->exts.extensions_present);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXTENSIONS;
    }

    // Check extension - basic constraints (cA must be DER TRUE, 0xFF).
    if(cert->exts.ca != 0xFF)
    {
        LOGE("Extension field CA is invalid. - %lX", (unsigned long)cert->exts.ca);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXTENSIONS;
    }

    if(cert->exts.path_len_constraint != 0)
    {
        LOGE("Extension field path_len is invalid. - %lX", (unsigned long)cert->exts.path_len_constraint);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXTENSIONS;
    }

    // Check extension - key usage: every DRK_V1_EXTS_KEY_USAGE bit is required.
    if((cert->exts.key_usage & DRK_V1_EXTS_KEY_USAGE) != DRK_V1_EXTS_KEY_USAGE)
    {
        LOGE("Extension fields key usage is invalid - %lX", (unsigned long)cert->exts.key_usage);
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXTENSIONS;
    }

    // Check extension - subject key identifier == SHA-1(certificate public key).
    if((ret = getShaDigest(cert->public_key, cert->public_key_len, md, sizeof(md), ALGO_SHA1)) != NOT_ERROR)
    {
        LOGE("Failed to get digest for SKI with error %d.", ret);
        return ret;
    }

    if(cert->exts.subjectKeyIdentifier.length != sizeof(md) ||
            memcmp(cert->exts.subjectKeyIdentifier.data, md, sizeof(md)))
    {
        LOGE("Extension subject key identifier is invalid.");
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXTENSIONS;
    }

    // Check extension - authority key identifier == SHA-1(embedded CA RSA key).
    // NOTE(review): this assumes ERR_TA_* codes are negative, so "< 0" can
    // tell them apart from a key length — confirm.
    if((pubKeyLen = getRsaPubKey(CaRsaPubKey, sizeof(CaRsaPubKey), &pPubKey)) < 0)
    {
        LOGE("Failed to get CA public key with error %d.", pubKeyLen);
        return pubKeyLen;
    }

    if((ret = getShaDigest(pPubKey, pubKeyLen, md, sizeof(md), ALGO_SHA1)) != NOT_ERROR)
    {
        LOGE("Failed to get digest with error %d.", ret);
        return ret;
    }

    if(cert->exts.authorityKeyIdentifier.length != sizeof(md) ||
            memcmp(cert->exts.authorityKeyIdentifier.data, md, sizeof(md)))
    {
        LOGE("Extension authority key identifier is invalid.");
        return ERR_TA_X509_PARSE_BASE + X509_PARSE_ERROR_EXTENSIONS;
    }

    return NOT_ERROR;
}
