/*
 * (c) Copyright 2020 Samsung Research America, Inc.
 *                  All rights reserved
 *
 *  MPS Lab
 *
 * File: cipher.c
 * Author: jianwei.qian@samsung.com
 * Update Date: Aug 12, 2020
 *
 */
/*
 * Copyright (C) 2020 Samsung Electronics Co., Ltd. All rights reserved.
 *
 * Mobile Communication Division,
 * Digital Media & Communications Business, Samsung Electronics Co., Ltd.
 *
 * This software and its documentation are confidential and proprietary
 * information of Samsung Electronics Co., Ltd.  No part of the software and
 * documents may be copied, reproduced, transmitted, translated, or reduced to
 * any electronic medium or machine-readable form without the prior written
 * consent of Samsung Electronics.
 *
 * Samsung Electronics makes no representations with respect to the contents,
 * and assumes no responsibility for any errors that might appear in the
 * software and documents. This publication and the contents hereof are subject
 * to change without notice.
 */

// functions here are implemented based on SCrypto / openssl
#include "cipher.h"
#include "TZ_Vendor_tl.h"
#include <limits.h>
#include <string.h>
#include <openssl/aes.h>
#include <openssl/bn.h>
#include <openssl/crypto.h>
#include <openssl/evp.h>
#include <openssl/rsa.h>

#define NID_sha256              672   // openssl/obj_mac.h

/**
 * Encrypt a buffer with an RSA public key supplied as raw byte strings.
 *
 * The modulus and public exponent are converted with BN_bin2bn() and
 * attached to a temporary RSA object that is released before returning.
 *
 * @param keyMod        modulus bytes (as consumed by BN_bin2bn())
 * @param keyModLen     length of keyMod in bytes
 * @param keyPubExp     public exponent bytes (as consumed by BN_bin2bn())
 * @param keyPubExpLen  length of keyPubExp in bytes
 * @param data          plaintext to encrypt
 * @param dataLen       length of data in bytes
 * @param out           buffer receiving the ciphertext
 * @param pOutLen       in: capacity of out, must be at least keyModLen;
 *                      out: number of ciphertext bytes written
 * @param padding       padding mode handed unchanged to RSA_public_encrypt()
 * @return TZ_API_OK on success, TZ_API_ERROR on any failure
 */
uint32_t pebble_rsa_encrypt(uint8_t *keyMod, uint32_t keyModLen,
		uint8_t *keyPubExp, uint32_t keyPubExpLen, uint8_t *data,
		uint32_t dataLen, uint8_t *out, uint32_t *pOutLen, uint32_t padding) {
	uint32_t ret = TZ_API_ERROR;
	RSA *rsa = NULL;
	int rsaSize = 0;
	int opslRet = 0;

	TTY_LOG("%s: RSA encryption", __FUNCTION__);

	/* The OpenSSL calls below take lengths as int, so reject anything that
	 * would not survive the uint32_t -> int conversion. */
	if (keyMod == NULL || keyModLen == 0 || keyModLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input key modulus", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyPubExp == NULL || keyPubExpLen == 0
			|| keyPubExpLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input key public exponent", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (data == NULL || dataLen == 0 || dataLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (out == NULL || pOutLen == NULL || *pOutLen < keyModLen) {
		TTY_LOG("%s invalid output data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	rsa = RSA_new();
	if (NULL == rsa) {
		TTY_LOG("%s RSA_new() returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	rsa->e = BN_bin2bn(keyPubExp, (int) keyPubExpLen, rsa->e);
	if (NULL == rsa->e) {
		TTY_LOG("%s BN_bin2bn() for rsa->e returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	rsa->n = BN_bin2bn(keyMod, (int) keyModLen, rsa->n);
	if (NULL == rsa->n) {
		TTY_LOG("%s BN_bin2bn() for rsa->n returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	/* The ciphertext length is the size of the modulus *value*, which is
	 * shorter than keyModLen when keyMod carries leading zero bytes. Use
	 * RSA_size() as the reference instead of keyModLen so such keys are
	 * not rejected after a successful encryption. */
	rsaSize = RSA_size(rsa);
	if (rsaSize <= 0 || (uint32_t) rsaSize > *pOutLen) {
		TTY_LOG("%s invalid output data", __FUNCTION__);
		goto EXIT;
	}

	opslRet = RSA_public_encrypt((int) dataLen, data, out, rsa,
			(int) padding); //RSA_PKCS1_PADDING
	if (opslRet != rsaSize) {
		TTY_LOG("%s RSA_public_encrypt() returned an error!", __FUNCTION__);
		goto EXIT;
	}
	*pOutLen = (uint32_t) opslRet;

	ret = TZ_API_OK;

	EXIT: if (rsa != NULL) {
		RSA_free(rsa);
	}
	return ret;
}

/**
 * Decrypt a buffer with an RSA private key supplied as raw byte strings.
 *
 * Only n, e and d are loaded (no CRT parameters). The decryption is done
 * into a scratch buffer of RSA_size() bytes and copied to the caller's
 * buffer only when it fits, because RSA_private_decrypt() has no knowledge
 * of the capacity of its destination.
 *
 * @param keyMod        modulus bytes (as consumed by BN_bin2bn())
 * @param keyModLen     length of keyMod in bytes
 * @param keyPubExp     public exponent bytes
 * @param keyPubExpLen  length of keyPubExp in bytes
 * @param keyPriExp     private exponent bytes
 * @param keyPriExpLen  length of keyPriExp in bytes
 * @param in            ciphertext
 * @param inLen         length of in, in bytes
 * @param out           buffer receiving the plaintext
 * @param pOutLen       in: capacity of out (non-zero);
 *                      out: number of plaintext bytes written
 * @param padding       padding mode handed unchanged to RSA_private_decrypt()
 * @return TZ_API_OK on success, TZ_API_ERROR on any failure (including a
 *         plaintext that does not fit into out)
 */
uint32_t pebble_rsa_decrypt(uint8_t *keyMod, uint32_t keyModLen,
		uint8_t *keyPubExp, uint32_t keyPubExpLen, uint8_t *keyPriExp,
		uint32_t keyPriExpLen, uint8_t *in, uint32_t inLen, uint8_t *out,
		uint32_t *pOutLen, uint32_t padding) {
	uint32_t ret = TZ_API_ERROR;
	RSA *rsa = NULL;
	uint8_t *plain = NULL;
	int rsaSize = 0;
	int opslRet = 0;

	TTY_LOG("%s: RSA decryption", __FUNCTION__);

	/* The OpenSSL calls below take lengths as int, so reject anything that
	 * would not survive the uint32_t -> int conversion. */
	if (keyMod == NULL || keyModLen == 0 || keyModLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input key modulus", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyPubExp == NULL || keyPubExpLen == 0
			|| keyPubExpLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input key public exponent", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyPriExp == NULL || keyPriExpLen == 0
			|| keyPriExpLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input key private exponent", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (in == NULL || inLen == 0 || inLen > (uint32_t) INT_MAX) {
		TTY_LOG("%s invalid input data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (out == NULL || pOutLen == NULL || *pOutLen == 0) {
		TTY_LOG("%s invalid output data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	rsa = RSA_new();
	if (NULL == rsa) {
		TTY_LOG("%s RSA_new() returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	rsa->e = BN_bin2bn(keyPubExp, (int) keyPubExpLen, rsa->e);
	if (NULL == rsa->e) {
		TTY_LOG("%s BN_bin2bn() for rsa->e returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	rsa->d = BN_bin2bn(keyPriExp, (int) keyPriExpLen, rsa->d);
	if (NULL == rsa->d) {
		TTY_LOG("%s BN_bin2bn() for rsa->d returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	rsa->n = BN_bin2bn(keyMod, (int) keyModLen, rsa->n);
	if (NULL == rsa->n) {
		TTY_LOG("%s BN_bin2bn() for rsa->n returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	/* RSA_private_decrypt() may write up to RSA_size(rsa) bytes, regardless
	 * of how large the caller's buffer is. Decrypt into a scratch buffer of
	 * that size and bounds-check before touching out. */
	rsaSize = RSA_size(rsa);
	if (rsaSize <= 0) {
		TTY_LOG("%s invalid input key modulus", __FUNCTION__);
		goto EXIT;
	}

	plain = OPENSSL_malloc(rsaSize);
	if (NULL == plain) {
		TTY_LOG("%s OPENSSL_malloc() returned NULL!", __FUNCTION__);
		goto EXIT;
	}

	opslRet = RSA_private_decrypt((int) inLen, in, plain, rsa,
			(int) padding); //RSA_PKCS1_PADDING
	if (opslRet < 0) {
		TTY_LOG("%s RSA_private_decrypt() returned an error!", __FUNCTION__);
		goto EXIT;
	}

	if ((uint32_t) opslRet > *pOutLen) {
		TTY_LOG("%s output buffer too small for decrypted data", __FUNCTION__);
		goto EXIT;
	}

	memcpy(out, plain, (size_t) opslRet);
	*pOutLen = (uint32_t) opslRet;

	ret = TZ_API_OK;

	EXIT: if (plain != NULL) {
		/* The scratch buffer held plaintext; wipe it before release. */
		OPENSSL_cleanse(plain, (size_t) rsaSize);
		OPENSSL_free(plain);
	}
	if (rsa != NULL) {
		RSA_free(rsa);
	}
	return ret;
}

uint32_t pebble_aes_gcm_encrypt(const uint8_t *key, uint32_t key_size,
		const uint8_t *iv, uint32_t iv_size, const uint8_t *aad,
		uint32_t aad_size, const uint8_t *data, uint32_t input_size,
		uint8_t *out, uint8_t *tag) {
	EVP_CIPHER_CTX ctx;
	int tmp_len;
	//int tmp_ciphertext_len;
	//int gcm_ct_len;
	uint32_t ret = FAILED, r;

	/* initialise the context */
	EVP_CIPHER_CTX_init(&ctx);

	/* Initialise the encryption operation. */
	if (key_size == AES256_KEY_SIZE) {
		r = EVP_EncryptInit_ex(&ctx, EVP_aes_256_gcm(), NULL, NULL, NULL);
	} else if (key_size == AES128_KEY_SIZE) {
		r = EVP_EncryptInit_ex(&ctx, EVP_aes_128_gcm(), NULL, NULL, NULL);
	} else {
		PEBBLE_LOG("%s: unsupported aes_gcm key length %d", __FUNCTION__,
				key_size);
		goto exit;
	}

	if (r != 1) {
		PEBBLE_LOG("%s: EVP_EncryptInit_ex failed", __FUNCTION__);
		goto exit;
	}

	/* Set IV length. IV other than 12 bytes (96 bits) is not appropriate */
	if (1 != EVP_CIPHER_CTX_ctrl(&ctx, EVP_CTRL_GCM_SET_IVLEN, iv_size, NULL)) {
		PEBBLE_LOG("%s: EVP_CIPHER_CTX_ctrl(EVP_CTRL_GCM_SET_IVLEN) failed",
				__FUNCTION__);
		goto exit;
	}

	/* Initialise key and IV */
	if (1 != EVP_EncryptInit_ex(&ctx, NULL, NULL, key, iv)) {
		PEBBLE_LOG("%s: EVP_EncryptInit_ex(key,iv) failed", __FUNCTION__);
		goto exit;
	}

	EVP_CIPHER_CTX_set_padding(&ctx, 0);

	/* Provide any AAD data. This can be called zero or more times as
	 * required
	 */
	if (1 != EVP_EncryptUpdate(&ctx, NULL, &tmp_len, aad, aad_size)) {
		PEBBLE_LOG("%s: EVP_EncryptUpdate(aad) failed", __FUNCTION__);
		goto exit;
	}

	/* Provide the message to be encrypted, and obtain the encrypted output.
	 * EVP_EncryptUpdate can be called multiple times if necessary
	 */
	if (1 != EVP_EncryptUpdate(&ctx, out, &tmp_len, data, input_size)) {
		PEBBLE_LOG("%s: EVP_EncryptUpdate(ciphertext, plaintext) failed",
				__FUNCTION__);
		goto exit;
	}

	//tmp_ciphertext_len = tmp_len;

	/* Finalise the encryption. Normally ciphertext bytes may be written at
	 * this stage, but this does not occur in GCM mode
	 */
	if (1 != EVP_EncryptFinal_ex(&ctx, out + tmp_len, &tmp_len)) {
		PEBBLE_LOG("%s: EVP_EncryptFinal_ex(ciphertext, plaintext) failed",
				__FUNCTION__);
		goto exit;
	}

	//tmp_ciphertext_len += tmp_len;

	/* Get the tag */
	if (1
			!= EVP_CIPHER_CTX_ctrl(&ctx, EVP_CTRL_GCM_GET_TAG, AES_GCM_TAG_SIZE,
					tag)) {
		PEBBLE_LOG("%s: EVP_CIPHER_CTX_ctrl(EVP_CTRL_GCM_GET_TAG, tag) failed",
				__FUNCTION__);
		goto exit;
	}

	//gcm_ct_len = tmp_ciphertext_len;

	ret = SUCCESS;

	exit: EVP_CIPHER_CTX_cleanup(&ctx);
	return ret;
}

uint32_t pebble_aes_gcm_decrypt(const uint8_t *key, uint32_t key_size,
		const uint8_t *iv, uint32_t iv_size, const uint8_t *aad,
		uint32_t aad_size, const uint8_t *data, uint32_t input_size, // cipher text
		const uint8_t *tag, uint32_t tag_size, uint8_t *out) //plaintext
{
	EVP_CIPHER_CTX ctx;
	int tmp_len;
	uint32_t ret = FAILED, r;

	/* initialise the context */
	EVP_CIPHER_CTX_init(&ctx);

	/* Initialise the encryption operation. */
	if (key_size == AES256_KEY_SIZE) {
		r = EVP_DecryptInit_ex(&ctx, EVP_aes_256_gcm(), NULL, NULL, NULL);
	} else if (key_size == AES128_KEY_SIZE) {
		r = EVP_DecryptInit_ex(&ctx, EVP_aes_128_gcm(), NULL, NULL, NULL);
	} else {
		PEBBLE_LOG("%s: unsupported aes_gcm key length %d", __FUNCTION__,
				key_size);
		goto exit;
	}

	if (r != 1) {
		PEBBLE_LOG("%s: EVP_EncryptInit_ex failed", __FUNCTION__);
		goto exit;
	}

	/* Set IV length. IV other than 12 bytes (96 bits) is not appropriate */
	if (1 != EVP_CIPHER_CTX_ctrl(&ctx, EVP_CTRL_GCM_SET_IVLEN, iv_size, NULL)) {
		PEBBLE_LOG("%s: EVP_CIPHER_CTX_ctrl(EVP_CTRL_GCM_SET_IVLEN) failed",
				__FUNCTION__);
		goto exit;
	}

	/* Initialise key and IV */
	if (1 != EVP_DecryptInit_ex(&ctx, NULL, NULL, key, iv)) {
		PEBBLE_LOG("%s: EVP_EncryptInit_ex(key,iv) failed", __FUNCTION__);
		goto exit;
	}

	/* Provide tag */
	if (1
			!= EVP_CIPHER_CTX_ctrl(&ctx, EVP_CTRL_GCM_SET_TAG, tag_size,
					(void*) tag)) {
		PEBBLE_LOG("%s: EVP_CIPHER_CTX_ctrl(EVP_CTRL_GCM_SET_TAG) failed",
				__FUNCTION__);
		goto exit;
	}

	EVP_CIPHER_CTX_set_padding(&ctx, 0);

	/* Provide any AAD data. */
	if (1 != EVP_DecryptUpdate(&ctx, NULL, &tmp_len, aad, aad_size)) {
		PEBBLE_LOG("%s: EVP_DecryptUpdate(aad) failed", __FUNCTION__);
		goto exit;
	}

	/* Provide the ciphertext to be decrypted, and obtain the plaintext output.
	 * EVP_EncryptUpdate can be called multiple times if necessary
	 */
	if (1 != EVP_DecryptUpdate(&ctx, out, &tmp_len, data, input_size)) {
		PEBBLE_LOG("%s: EVP_DecryptUpdate(plaintext, ciphertext) failed",
				__FUNCTION__);
		goto exit;
	}

	/* Finalise the decryption.*/
	if (1 != EVP_DecryptFinal_ex(&ctx, out + tmp_len, &tmp_len)) {
		PEBBLE_LOG("%s: EVP_DecryptFinal_ex(plaintext, ciphertext) failed",
				__FUNCTION__);
		goto exit;
	}

	ret = SUCCESS;

	exit: EVP_CIPHER_CTX_cleanup(&ctx);
	return ret;
}

/**
 * Compute the SHA-256 digest of a buffer.
 *
 * Input: data (may be NULL only when len is 0), len in bytes
 * Output: checksum, filled by SHA256_Final()
 * Return: 1 on success, 0 on failure (invalid argument or a SHA256_* call
 *         failing)
 *
 * NOTE(review): assumes CHECKSUM_LEN is at least the SHA-256 digest size
 * (32 bytes) -- confirm against cipher.h.
 */
int gen_checksum(unsigned char *data, size_t len,
		unsigned char checksum[CHECKSUM_LEN]) {
	SHA256_CTX sha256;

	/* Validate pointers like the other entry points in this file do,
	 * instead of dereferencing NULL inside the hash routines. */
	if (checksum == NULL || (data == NULL && len != 0)) {
		return 0;
	}

	if (SHA256_Init(&sha256) != 1) { // 1 on success, 0 on failure
		return 0;
	}
	if (SHA256_Update(&sha256, data, len) != 1) { // 1 on success, 0 on failure
		return 0;
	}
	if (SHA256_Final(checksum, &sha256) != 1) { // 1 on success, 0 on failure
		return 0;
	}
	return 1;
}

/**
 * Recompute the checksum of data with gen_checksum() and compare it with
 * the supplied one over CHECKSUM_LEN bytes.
 *
 * Input: data (may be NULL only when len is 0), len in bytes, and the
 *        checksum to compare against
 * Return: 1 on valid checksum, 0 on invalid checksum, -1 on unable to verify
 *         (invalid argument or gen_checksum failed)
 */
int verify_checksum(unsigned char *data, size_t len,
		unsigned char checksums[CHECKSUM_LEN]) {
	unsigned char expected[CHECKSUM_LEN];

	/* Nothing can be verified without a reference checksum or with a
	 * missing non-empty input. */
	if (checksums == NULL || (data == NULL && len != 0)) {
		return -1;
	}

	if (gen_checksum(data, len, expected) != 1) {
		return -1;
	}

	for (int i = 0; i < CHECKSUM_LEN; i++) {
		if (expected[i] != checksums[i]) {
			return 0; // invalid
		}
	}
	return 1;
}

#if 0// not used
/*
 * Sign a message: hash it with TZ_digest_SHA256(), then sign the digest with
 * RSA_sign(NID_sha256, ...) using a key built from raw byte strings.
 * This function sits inside the `#if 0` region and is not compiled.
 *
 * keyMod/keyModLen       modulus bytes; keyModLen must not exceed
 *                        MAX_SIGNATURE_SIZE
 * keyPubExp/keyPubExpLen public exponent bytes
 * keyPriExp/keyPriExpLen private exponent bytes
 * messageData/messageLen message to sign (non-empty)
 * signature              buffer receiving the signature
 * pSigLen                in: capacity of signature, must be >= keyModLen;
 *                        passed to RSA_sign() to receive the signature length
 *
 * Returns TZ_API_OK on success, TZ_API_ERROR on any failure.
 */
uint32_t pebble_sign_CKM_SHA256_RSA_PKCS(
    uint8_t *keyMod,
    uint32_t keyModLen,
    uint8_t *keyPubExp,
    uint32_t keyPubExpLen,
    uint8_t *keyPriExp,
    uint32_t keyPriExpLen,
    uint8_t * messageData,
    uint32_t messageLen,
    uint8_t * signature,
    uint32_t * pSigLen
)
{
	uint32_t ret= TZ_API_OK;
	uint8_t digest[32]= {0}; // SHA256_LEN = 32
	uint32_t digestLen=sizeof(digest);
	RSA *rsa = NULL;
	int opslRet=0;

	TTY_LOG("%s: sign using RSA key", __FUNCTION__);

	/* Argument validation: each failure returns directly, before any
	 * resource has been allocated. */
	if (keyMod == NULL || keyModLen == 0) {
		TTY_LOG("%s invalid input key modulus", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyModLen > MAX_SIGNATURE_SIZE) {
		TTY_LOG("%s signature length exceeds MAX_SIGNATURE_SIZE", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyPubExp == NULL || keyPubExpLen == 0) {
		TTY_LOG("%s invalid input key public exponent", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyPriExp == NULL || keyPriExpLen == 0) {
		TTY_LOG("%s invalid input key private exponent", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (messageData == NULL || messageLen == 0) {
		TTY_LOG("%s invalid input signing data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (signature == NULL || pSigLen == NULL|| *pSigLen < keyModLen) {
		TTY_LOG("%s invalid output signature data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	/* Hash the message; the digest is what gets signed below. */
	ret= TZ_digest_SHA256(messageData, messageLen, digest, &digestLen);
	if (TZ_API_OK != ret) {
		TTY_LOG("%s: TZ_digest_SHA256() returns an error code 0x%x", __FUNCTION__, ret);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	/* Build a temporary RSA key from e, d and n; it is freed at EXIT. */
	rsa = RSA_new();
	if(NULL==rsa)
	{
		TTY_LOG("%s RSA_new() returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	rsa->e=BN_bin2bn(keyPubExp, keyPubExpLen,rsa->e);
	if(NULL==rsa->e)
	{
		TTY_LOG("%s BN_bin2bn() for rsa->e returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	rsa->d=BN_bin2bn(keyPriExp, keyPriExpLen,rsa->d);
	if(NULL==rsa->d)
	{
		TTY_LOG("%s BN_bin2bn() for rsa->d returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	rsa->n=BN_bin2bn(keyMod, keyModLen,rsa->n);
	if(NULL==rsa->n)
	{
		TTY_LOG("%s BN_bin2bn() for rsa->n returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	/* Anything other than 1 from RSA_sign() is treated as failure. */
	opslRet = RSA_sign(NID_sha256, digest, digestLen, signature, pSigLen, rsa);
	if (opslRet != 1)
	{
		TTY_LOG("%s RSA_sign() returned an error", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	ret= TZ_API_OK;

EXIT:
	/* Single cleanup point for every path that reached RSA_new(). */
	if(rsa!=NULL) {
		RSA_free(rsa);
	}
	return ret;
}

// not used
/*
 * Verify a signature: hash the message with TZ_digest_SHA256(), then check
 * the signature over that digest with RSA_verify(NID_sha256, ...) using a
 * public key built from raw byte strings.
 * This function sits inside the `#if 0` region and is not compiled.
 *
 * keyMod/keyModLen       modulus bytes
 * keyPubExp/keyPubExpLen public exponent bytes
 * messageData/messageLen message that was signed (non-empty)
 * signature/sigLen       signature to check (non-empty)
 * isValidSig             out: true only when RSA_verify() returns 1;
 *                        set to false before any processing starts
 *
 * Returns TZ_API_OK when the check could be carried out (whatever the
 * verdict in *isValidSig), TZ_API_ERROR when an argument is invalid or a
 * hashing/key-setup step fails.
 */
uint32_t pebble_verify_CKM_SHA256_RSA_PKCS(
    uint8_t *keyMod,
    uint32_t keyModLen,
    uint8_t *keyPubExp,
    uint32_t keyPubExpLen,
    uint8_t * messageData,
    uint32_t messageLen,
    uint8_t * signature,
    uint32_t sigLen,
    bool * isValidSig
)
{
	uint32_t ret= TZ_API_OK;
	uint8_t digest[32]= {0}; //SHA256_LENGTH = 32
	uint32_t digestLen=sizeof(digest);
	RSA *rsa = NULL;
	int opslRet=0;

	TTY_LOG("%s: verify RSA signature", __FUNCTION__);

	/* Argument validation: each failure returns directly, before any
	 * resource has been allocated. */
	if (keyMod == NULL || keyModLen == 0) {
		TTY_LOG("%s: invalid input key modulus", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (keyPubExp == NULL || keyPubExpLen == 0) {
		TTY_LOG("%s: invalid input key public exponent", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (messageData == NULL || messageLen == 0) {
		TTY_LOG("%s: invalid input signing data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (signature == NULL || sigLen == 0) {
		TTY_LOG("%s: invalid output signature data", __FUNCTION__);
		return TZ_API_ERROR;
	}

	if (isValidSig == NULL ) {
		TTY_LOG("%s: invalid isValidSig pointer", __FUNCTION__);
		return TZ_API_ERROR;
	}
	/* Default verdict is "invalid" so every later error path leaves it so. */
	*isValidSig = false;

	/* Hash the message; the signature is checked against this digest. */
	ret= TZ_digest_SHA256(messageData, messageLen, digest, &digestLen);
	if (TZ_API_OK != ret) {
		TTY_LOG("%s: TZ_digest_SHA256() returns an error code 0x%x", __FUNCTION__, ret);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	/* Build a temporary public key from e and n; it is freed at EXIT. */
	rsa = RSA_new();
	if(NULL==rsa)
	{
		TTY_LOG("%s RSA_new() returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}


	rsa->e=BN_bin2bn(keyPubExp, keyPubExpLen,rsa->e);
	if(NULL==rsa->e)
	{
		TTY_LOG("%s BN_bin2bn() for rsa->e returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	rsa->n=BN_bin2bn(keyMod, keyModLen,rsa->n);
	if(NULL==rsa->n)
	{
		TTY_LOG("%s BN_bin2bn() for rsa->n returned NULL!", __FUNCTION__);
		ret = TZ_API_ERROR;
		goto EXIT;
	}

	/* A non-1 result is reported through *isValidSig, not as an API error. */
	opslRet = RSA_verify(NID_sha256, digest, digestLen, signature, sigLen, rsa);
	if(opslRet==1) {
		*isValidSig=true;
	} else {
		*isValidSig=false;
		TTY_LOG("%s: RSA_verify() invalid signature", __FUNCTION__);
	}

	ret= TZ_API_OK;

EXIT:
	/* Single cleanup point for every path that reached RSA_new(). */
	if(rsa!=NULL) {
		RSA_free(rsa);
	}
	return ret;
}
#endif
