/*
 * (c) Copyright 2016 Samsung Research America, Inc.
 *                  All rights reserved
 *
 *  MPS Lab
 *
 * File         : aes_hsha2_aead.c
 * Author       : r.kothari@samsung.com
 * Creation Date: Apr 15, 2016
 * Description  : Implementation of the AES_CBC_HMAC_SHA2 algorithms as described
 *                in RFC 7518, JSON Web Algorithms (JWA), Section 5.2
 *                https://tools.ietf.org/pdf/rfc7518.pdf
 *
 */
/*
 * Copyright (C) 2016 Samsung Electronics Co., Ltd. All rights reserved.
 *
 * Mobile Communication Division,
 * Digital Media & Communications Business, Samsung Electronics Co., Ltd.
 *
 * This software and its documentation are confidential and proprietary
 * information of Samsung Electronics Co., Ltd.  No part of the software and
 * documents may be copied, reproduced, transmitted, translated, or reduced to
 * any electronic medium or machine-readable form without the prior written
 * consent of Samsung Electronics.
 *
 * Samsung Electronics makes no representations with respect to the contents,
 * and assumes no responsibility for any errors that might appear in the
 * software and documents. This publication and the contents hereof are subject
 * to change without notice.
 */
#include "pebble_platform_interface.h"
#include "pebble_defs.h"
#include "aes_hsha2_aead.h"

/*
 * Sanity-check the arguments handed to aes_hsha2_aead_encrypt().
 *
 * algo : AEAD suite selector; any value above A256CBC_HS512 is rejected.
 * in   : encryption inputs. key, iv, aad and plain_text must be set and
 *        have non-zero lengths; iv_len must equal AES_HSHA2_IV_SIZE.
 * out  : encryption outputs. cipher_text, cipher_text_len, tag and tag_len
 *        must all be set.
 *
 * Returns AES_OK when the arguments are acceptable,
 * AES_GEN_ERROR_INVALID_INPUT_PARAM when something is missing or empty, and
 * AES_GEN_ERROR_INVALID_INPUT_SIZE when the key length does not match the
 * selected suite or *out->tag_len is smaller than the suite's tag size.
 */
static uint32_t
validate_enc_params (aead_algo_type algo,
                     aead_enc_in_params  *in,
                     aead_enc_out_params *out) {
    int args_ok = 0;
    int sizes_ok = 0;

    /* Left-to-right evaluation keeps every dereference behind its NULL check. */
    args_ok = !(algo > A256CBC_HS512) &&
              in && in->key && in->key_len &&
              in->iv && in->iv_len && (in->iv_len == AES_HSHA2_IV_SIZE) &&
              in->aad && in->aad_len &&
              in->plain_text && in->plain_text_len &&
              out && out->cipher_text && out->cipher_text_len &&
              out->tag && out->tag_len;

    if (!args_ok) {
        PEBBLE_LOG("Invalid Parameters");
        return AES_GEN_ERROR_INVALID_INPUT_PARAM;
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("key, len = %d", in->key_len);
    DBG_DUMP(in->key, in->key_len);
    DBG_LOG("iv, len = %d", in->iv_len);
    DBG_DUMP(in->iv, in->iv_len);
    DBG_LOG("aad, len = %d", in->aad_len);
    DBG_DUMP(in->aad, in->aad_len);
    DBG_LOG("plain text, len = %d", in->plain_text_len);
    DBG_DUMP(in->plain_text, in->plain_text_len);
    DBG_LOG("cipher text len = %d", *out->cipher_text_len);
    DBG_LOG("auth tag len = %d", *out->tag_len);
#endif

    /* Per-suite rule: exact key length, and a tag buffer at least tag-sized. */
    switch (algo) {
    case A128CBC_HS256:
        sizes_ok = (in->key_len == A128CBC_HS256_KEY_SIZE) &&
                   !(*out->tag_len < A128CBC_HS256_TAG_SIZE);
        break;

    case A192CBC_HS384:
        sizes_ok = (in->key_len == A192CBC_HS384_KEY_SIZE) &&
                   !(*out->tag_len < A192CBC_HS384_TAG_SIZE);
        break;

    case A256CBC_HS512:
        sizes_ok = (in->key_len == A256CBC_HS512_KEY_SIZE) &&
                   !(*out->tag_len < A256CBC_HS512_TAG_SIZE);
        break;

    default:
        /* No size rule is attached to any other selector value. */
        sizes_ok = 1;
        break;
    }

    if (!sizes_ok) {
        PEBBLE_LOG("Insufficient key length or output auth tag len");
        return AES_GEN_ERROR_INVALID_INPUT_SIZE;
    }

    return AES_OK;
}

/*
 * Look up the key split and tag length belonging to an AEAD suite.
 *
 * algo        : suite selector (A128CBC_HS256, A192CBC_HS384 or
 *               A256CBC_HS512).
 * mac_key_len : out, receives half of the suite's combined key size.
 * enc_key_len : out, receives half of the suite's combined key size.
 * tag_len     : out, receives the suite's authentication tag size.
 *
 * Returns AES_OK on success. Returns AES_GEN_ERROR_INVALID_INPUT_PARAM when
 * an output pointer is NULL or the selector is not one of the three suites;
 * in that case none of the outputs is written.
 */
uint32_t
get_keys_tag_len (aead_algo_type algo, uint32_t * mac_key_len,
                  uint32_t * enc_key_len, uint32_t * tag_len) {
    uint32_t half_key = 0;
    uint32_t suite_tag = 0;

    if (!(mac_key_len && enc_key_len && tag_len)) {
        PEBBLE_LOG("Invalid Parameters");
        return AES_GEN_ERROR_INVALID_INPUT_PARAM;
    }

    switch (algo) {
    case A128CBC_HS256:
        half_key = (A128CBC_HS256_KEY_SIZE / 2);
        suite_tag = A128CBC_HS256_TAG_SIZE;
        break;

    case A192CBC_HS384:
        half_key = (A192CBC_HS384_KEY_SIZE / 2);
        suite_tag = A192CBC_HS384_TAG_SIZE;
        break;

    case A256CBC_HS512:
        half_key = (A256CBC_HS512_KEY_SIZE / 2);
        suite_tag = A256CBC_HS512_TAG_SIZE;
        break;

    default:
        PEBBLE_LOG("Unhandled algo_type");
        return AES_GEN_ERROR_INVALID_INPUT_PARAM;
    }

    /* Both halves of the combined key have the same length. */
    *enc_key_len = half_key;
    *mac_key_len = half_key;
    *tag_len = suite_tag;

    return AES_OK;
}

/*
 * Compute the authentication tag over  AAD || IV || cipher text || AL,
 * where AL is the AAD length in bits encoded as a 64-bit big-endian integer
 * (RFC 7518 Section 5.2.2.1), and copy its leading bytes to the caller.
 *
 * algo            : AEAD suite selector; decides how many MAC bytes form the tag.
 * aad, aad_len    : additional authenticated data (must be non-empty).
 * iv, iv_len      : initialization vector (must be non-empty).
 * cipher_text,
 * cipher_text_len : cipher text to authenticate (must be non-empty).
 * key, key_len    : MAC key (must be non-empty).
 * tag             : out, receives the tag.
 * out_tag_len     : in/out; on entry the capacity of `tag` (at least
 *                   AES_HSHA2_MIN_TAG_SIZE), on success the tag length written.
 *
 * Returns AES_OK on success, AES_GEN_ERROR_INVALID_INPUT_PARAM for missing or
 * empty arguments, AES_GEN_ERROR_INSUFFICIENT_BUFFER when the concatenated
 * message does not fit the internal buffer or the tag does not fit the MAC /
 * the caller's buffer, AES_CRYPT_ERROR_HMAC_ERROR when the HMAC primitive
 * fails, or the error reported by get_keys_tag_len().
 *
 * NOTE(review): the MAC is always produced with hmac_sha256_sign(), whatever
 * `algo` is. RFC 7518 specifies HMAC-SHA-384 for A192CBC-HS384 and
 * HMAC-SHA-512 for A256CBC-HS512, so tags for those suites presumably do not
 * interoperate with other implementations -- confirm against the platform's
 * available HMAC primitives.
 */
static uint32_t
compute_tag (aead_algo_type algo, uint8_t * aad, uint32_t aad_len, uint8_t * iv, uint32_t iv_len,
             uint8_t * cipher_text, uint32_t cipher_text_len, uint8_t * key, uint32_t key_len,
             uint8_t * tag, uint32_t * out_tag_len) {
    uint32_t ret = AES_GEN_ERROR_INTERNAL;
    uint8_t al[AES_HSHA2_AL_SIZE] = {0};
    uint64_t al_bits = 0;
    uint64_t total_len = 0;
    int i = 0;
    uint8_t  message[AES_HSHA2_MAX_BUF_SIZE];
    uint32_t msg_len = 0;
    uint8_t mac[SHA256_DIGEST_LENGTH] = {0};
    uint32_t mac_len = sizeof(mac);
    uint32_t mac_key_len = 0;
    uint32_t enc_key_len = 0;
    uint32_t tag_len = 0;

    if (!aad || !aad_len || !iv || !iv_len || !cipher_text || !cipher_text_len ||
            !key || !key_len || !tag || !out_tag_len  || *out_tag_len < AES_HSHA2_MIN_TAG_SIZE ) {
        PEBBLE_LOG("Invalid Parameters");
        return AES_GEN_ERROR_INVALID_INPUT_PARAM;
    }

    /*AL is # of bits in aad, as a 64 bit unsigned big -endian integer
     *If aad_len = 51, al_bits = 51*8 = 408 bits,= 0x0198
     *    then, in hex, al = {0x00,0x00,0x00,0x00,0x00,0x00,0x01,0x98}
     *    and, in decimals al = {0,0,0,0,0,0,1,152} - See RFC example
     *
     * Widen before multiplying: aad_len * 8 evaluated in 32 bits would wrap
     * for aad_len >= 2^29 and yield a wrong AL.
    */
    al_bits = (uint64_t)aad_len * 8;
    for (i = AES_HSHA2_AL_SIZE - 1; i >= 0; i--) {
        al[AES_HSHA2_AL_SIZE - 1 - i] = (uint8_t)((al_bits >> (8 * i)) & 0xff);
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("compute_hmac: AL, len = %d", (int)sizeof(al));
    DBG_DUMP(al, sizeof(al));
#endif

    /* Sum in 64 bits so that three large 32-bit lengths cannot wrap around
     * and slip past the capacity check guarding the memcpy calls below. */
    total_len = (uint64_t)aad_len + (uint64_t)iv_len +
                (uint64_t)cipher_text_len + (uint64_t)sizeof(al);
    if (total_len > (uint64_t)sizeof(message)) {
        PEBBLE_LOG("Insufficient buffer for message ");
        ret = AES_GEN_ERROR_INSUFFICIENT_BUFFER;
        goto error;
    }

    /* message = AAD || IV || cipher text || AL */
    msg_len = 0;
    memcpy(message, aad, aad_len);
    msg_len += aad_len;
    memcpy (message + msg_len, iv, iv_len);
    msg_len += iv_len;
    memcpy (message + msg_len, cipher_text, cipher_text_len);
    msg_len += cipher_text_len;
    memcpy (message + msg_len, al, sizeof(al));
    msg_len += (uint32_t)sizeof(al);

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("compute_hmac: Input to HMAC, len = %d", msg_len);
    DBG_DUMP(message, msg_len);
#endif

    /* hmac_sha256_sign() stands in for TZ_hmac_generate(), which the original
     * author noted is not provided by pebble's TZ_Vendor_tl.h. */
    ret = hmac_sha256_sign(key, key_len, message, msg_len, mac, &mac_len);
    if (ret != TZ_API_OK) {
        PEBBLE_LOG("TZ_hmac_generate failed ret = %d", ret);
        ret = AES_CRYPT_ERROR_HMAC_ERROR;
        goto error;
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("compute_tag: HMAC, len = %d", mac_len);
    DBG_DUMP(mac, mac_len);
#endif

    ret = get_keys_tag_len (algo, &mac_key_len, &enc_key_len, &tag_len);

    if (ret != AES_OK) {
        PEBBLE_LOG("get_keys_tag_len failed, ret = %d", ret);
        goto error;
    }

    /* Never read past the computed MAC nor write past the caller's buffer. */
    if (tag_len > mac_len || tag_len > *out_tag_len) {
        PEBBLE_LOG("Insufficient buffer for auth tag");
        ret = AES_GEN_ERROR_INSUFFICIENT_BUFFER;
        goto error;
    }

    /* First tag_len of mac is the tag */
    memcpy (tag, mac, tag_len);
    *out_tag_len = tag_len;

    ret =  AES_OK;

error :
    /* Scrub the working copies before leaving. */
    memset(message, 0, sizeof(message));
    memset(al, 0, sizeof(al));
    memset(mac, 0, sizeof(mac));

    return ret;
}


/*
 * Authenticated encryption with an AES_CBC_HMAC_SHA2 suite.
 *
 * The combined key in->key is split in two: the leading bytes form the MAC
 * key and the bytes after it form the encryption key. The plain text is
 * encrypted in CBC mode with PKCS7 padding, then compute_tag() produces the
 * authentication tag over the AAD, IV and cipher text.
 *
 * algo : AEAD suite selector.
 * in   : key, IV, AAD and plain text (checked by validate_enc_params()).
 * out  : receives the cipher text and its length, and the tag and its length.
 *
 * Returns AES_OK on success; otherwise the validation / lookup / tag error,
 * AES_GEN_ERROR_INVALID_INPUT_SIZE for a too-short key, or
 * AES_CRYPT_ERROR_ENCRYPT_ERROR when the cipher call fails.
 */
uint32_t
aes_hsha2_aead_encrypt (aead_algo_type algo,
                        aead_enc_in_params  *in,
                        aead_enc_out_params *out) {
    uint32_t status = AES_GEN_ERROR_INTERNAL;
    uint32_t hmac_key_size = 0;
    uint32_t cipher_key_size = 0;
    uint32_t suite_tag_size = 0;
    uint8_t *hmac_key = NULL;
    uint8_t *cipher_key = NULL;

    status = validate_enc_params (algo, in, out);
    if (AES_OK != status) {
        PEBBLE_LOG("Invalid Parameters");
        return status;
    }

    status = get_keys_tag_len (algo, &hmac_key_size, &cipher_key_size, &suite_tag_size);
    if (AES_OK != status) {
        PEBBLE_LOG("get_keys_tag_len failed, ret = %d", status);
        return status;
    }

    /* The combined key has to cover both halves. */
    if ((hmac_key_size + cipher_key_size) > in->key_len) {
        PEBBLE_LOG("Insufficient key length or output auth tag len");
        return AES_GEN_ERROR_INVALID_INPUT_SIZE;
    }

    /* MAC key first, encryption key right behind it. */
    hmac_key = in->key;
    cipher_key = in->key + hmac_key_size;

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("aes_hsha2_aead_encrypt: Mac Key, len = %d", hmac_key_size);
    DBG_DUMP(hmac_key, hmac_key_size);

    DBG_LOG("aes_hsha2_aead_encrypt: Enc Key, len = %d", cipher_key_size);
    DBG_DUMP(cipher_key, cipher_key_size);
#endif

    status = aes_encrypt_with_params (cipher_key, cipher_key_size,
                                      in->plain_text, in->plain_text_len,
                                      out->cipher_text, out->cipher_text_len,
                                      CIPHER_MODE_CBC, CIPHER_PAD_PKCS7,
                                      in->iv, in->iv_len);
    if (TZ_API_OK != status) {
        PEBBLE_LOG("TZ_aes_encrypt_with_params failed ret = %d", status);
        return AES_CRYPT_ERROR_ENCRYPT_ERROR;
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("aes_hsha2_aead_encrypt: CipherText , len = %d", *out->cipher_text_len);
    DBG_DUMP(out->cipher_text, *out->cipher_text_len);
#endif

    /* Authenticate the freshly produced cipher text. */
    status = compute_tag (algo, in->aad, in->aad_len, in->iv, in->iv_len,
                          out->cipher_text, *out->cipher_text_len,
                          hmac_key, hmac_key_size, out->tag, out->tag_len);
    if (AES_OK != status) {
        PEBBLE_LOG("compute_tag failed, ret = %d", status);
        return status;
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("aes_hsha2_aead_encrypt: Auth Tag, length = %d", *out->tag_len);
    DBG_DUMP(out->tag, *out->tag_len);
#endif

    return AES_OK;
}

/*
 * Sanity-check the arguments handed to aes_hsha2_aead_decrypt().
 *
 * algo : AEAD suite selector; any value above A256CBC_HS512 is rejected.
 * in   : decryption inputs. key, iv, aad, tag and cipher_text must be set
 *        and have non-zero lengths; iv_len must equal AES_HSHA2_IV_SIZE and
 *        tag_len must be at least AES_HSHA2_MIN_TAG_SIZE.
 * out  : decryption outputs. plain_text and plain_text_len must be set and
 *        *plain_text_len must be non-zero.
 *
 * Returns AES_OK when the arguments are acceptable,
 * AES_GEN_ERROR_INVALID_INPUT_PARAM when something is missing, empty or
 * out of range, and AES_GEN_ERROR_INVALID_INPUT_SIZE when the key or tag
 * length does not match the selected suite exactly.
 */
static uint32_t
validate_dec_params (aead_algo_type algo,
                     aead_dec_in_params *in,
                     aead_dec_out_params *out) {

    /* The IV length must match exactly, as validate_enc_params() requires on
     * the encrypt side; an oversized IV can never belong to a message this
     * module produced, so it is rejected here rather than fed to the MAC. */
    if (!in || !in->key || !in->key_len || !in->iv || !in->iv_len || !in->aad || !in->aad_len ||
            !in->tag || !in->tag_len || !in->cipher_text || !in->cipher_text_len ||
            !out || !out->plain_text || !out->plain_text_len || !*out->plain_text_len ||
            in->iv_len != AES_HSHA2_IV_SIZE || in->tag_len < AES_HSHA2_MIN_TAG_SIZE ||
            algo > A256CBC_HS512) {
        PEBBLE_LOG("Invalid Parameters");
        return AES_GEN_ERROR_INVALID_INPUT_PARAM;
    }

    /* Per-suite rule: both the key and the received tag must have the exact
     * length defined for the suite. */
    if ((A128CBC_HS256 == algo &&
            (in->key_len != A128CBC_HS256_KEY_SIZE || in->tag_len != A128CBC_HS256_TAG_SIZE)) ||
            (A192CBC_HS384 == algo &&
             (in->key_len != A192CBC_HS384_KEY_SIZE || in->tag_len != A192CBC_HS384_TAG_SIZE)) ||
            (A256CBC_HS512 == algo &&
             (in->key_len != A256CBC_HS512_KEY_SIZE || in->tag_len != A256CBC_HS512_TAG_SIZE))) {
        PEBBLE_LOG("Insufficient key length or tag len");
        return AES_GEN_ERROR_INVALID_INPUT_SIZE;
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("key, len = %d", in->key_len);
    DBG_DUMP(in->key, in->key_len);
    DBG_LOG("iv, len = %d", in->iv_len);
    DBG_DUMP(in->iv, in->iv_len);
    DBG_LOG("aad, len = %d", in->aad_len);
    DBG_DUMP(in->aad, in->aad_len);
    DBG_LOG("cipher text len = %d", in->cipher_text_len);
    DBG_DUMP(in->cipher_text, in->cipher_text_len);
    DBG_LOG("tag len = %d", in->tag_len);
    DBG_DUMP(in->tag, in->tag_len);
    DBG_LOG("plain text, len = %d", *out->plain_text_len);
#endif

    return AES_OK;
}


/*
 * Compare two equal-length byte strings without an early exit, so the time
 * taken does not reveal the position of the first differing byte.
 * Returns non-zero when the buffers differ, 0 when they are identical.
 */
static int
aead_tags_differ_ct (const uint8_t *lhs, const uint8_t *rhs, uint32_t len) {
    volatile uint8_t diff = 0;
    uint32_t idx = 0;

    for (idx = 0; idx < len; idx++) {
        diff |= (uint8_t)(lhs[idx] ^ rhs[idx]);
    }

    return diff != 0;
}

/*
 * Authenticated decryption with an AES_CBC_HMAC_SHA2 suite.
 *
 * The combined key in->key is split in two: the leading bytes form the MAC
 * key and the bytes after it form the encryption key. The tag is recomputed
 * over the AAD, IV and cipher text and compared with in->tag; only when it
 * matches is the cipher text decrypted (CBC mode, PKCS7 padding).
 *
 * algo : AEAD suite selector.
 * in   : key, IV, AAD, cipher text and received tag (checked by
 *        validate_dec_params()).
 * out  : receives the plain text and its length.
 *
 * Returns AES_OK on success; otherwise the validation / lookup / tag
 * computation error, AES_GEN_ERROR_INVALID_INPUT_SIZE for a too-short key,
 * AES_CRYPT_ERROR_HMAC_ERROR when the tag does not match, or the value
 * returned by aes_decrypt_with_params() when decryption fails.
 */
uint32_t
aes_hsha2_aead_decrypt (aead_algo_type algo,
                        aead_dec_in_params  *in,
                        aead_dec_out_params *out) {
    uint32_t ret = AES_GEN_ERROR_INTERNAL;
    uint8_t  expected_tag[AES_HSHA2_MAX_TAG_SIZE] = {0};
    uint32_t expected_tag_len = sizeof(expected_tag);
    uint8_t *mac_key = NULL;
    uint32_t mac_key_len = 0;
    uint8_t *enc_key = NULL;
    uint32_t enc_key_len = 0;
    uint32_t tag_len = 0;

    ret = validate_dec_params (algo, in, out);

    if (ret != AES_OK) {
        PEBBLE_LOG("validate_dec_params failed, ret = %d", ret);
        goto error;
    }

    ret = get_keys_tag_len (algo, &mac_key_len, &enc_key_len, &tag_len);

    if (ret != AES_OK) {
        PEBBLE_LOG("get_keys_tag_len failed, ret = %d", ret);
        goto error;
    }

    /* The combined key has to cover both halves. */
    if (in->key_len < (mac_key_len + enc_key_len)) {
        PEBBLE_LOG("Insufficient key length or output auth tag len");
        ret =  AES_GEN_ERROR_INVALID_INPUT_SIZE;
        goto error;
    }

    /* MAC key first, encryption key right behind it. */
    mac_key = in->key;
    enc_key = in->key + mac_key_len;

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("aes_hsha2_aead_decrypt: Mac Key, len = %d", mac_key_len);
    DBG_DUMP(mac_key, mac_key_len);

    DBG_LOG("aes_hsha2_aead_decrypt: Enc Key, len = %d", enc_key_len);
    DBG_DUMP(enc_key, enc_key_len);
#endif

    ret = compute_tag (algo, in->aad, in->aad_len, in->iv, in->iv_len, in->cipher_text,
                       in->cipher_text_len, mac_key, mac_key_len, expected_tag, &expected_tag_len);

    if (ret != AES_OK) {
        PEBBLE_LOG("compute_tag failed, ret = %d", ret);
        goto error;
    }

    /* The length is public, so it may short-circuit; the tag bytes are
     * compared in constant time to avoid a byte-by-byte timing oracle. */
    if (expected_tag_len != in->tag_len ||
            aead_tags_differ_ct (expected_tag, in->tag, expected_tag_len)) {
        PEBBLE_LOG("auth tag doesn't match");
        ret = AES_CRYPT_ERROR_HMAC_ERROR;
        goto error;
    }

#ifdef DEBUG_AES_HSHA2_AEAD
    DBG_LOG("aes_hsha2_aead_decrypt :  auth tag verified");
#endif
    /* Decrypt only after the tag has been verified. */
    ret = aes_decrypt_with_params (enc_key, enc_key_len, in->cipher_text, in->cipher_text_len,
                                   out->plain_text, out->plain_text_len,
                                   CIPHER_MODE_CBC, CIPHER_PAD_PKCS7, in->iv, in->iv_len);
    if (ret != TZ_API_OK) {
        /* NOTE(review): unlike the encrypt path, the raw return code of the
         * cipher call is passed through unmapped -- confirm callers expect it. */
        PEBBLE_LOG("TZ_aes_decrypt_with_params failed, ret = %d", ret);
        goto error;
    }

    ret = AES_OK;

error:

    /* Do not leave the recomputed tag behind on the stack. */
    memset(expected_tag, 0, sizeof(expected_tag));
    mac_key = NULL;
    enc_key = NULL;

    return ret;
}

