#include <limits.h>
#include <string.h>
#include "vk_constants.h"
#include "vk_interface.h"
#include "vk_data_struct.h"
#include "vk_error.h"
#include "vk_log.h"
#include "vk_utils.h"
#include "crypto/vk_crypto_aes.h"
#include "openssl/crypto.h"
#include "openssl/evp.h"
#include "openssl/err.h"

/*
 * Encrypt a buffer with AES-256-GCM and produce its authentication tag.
 *
 * plaintext, plaintext_len : data to encrypt. plaintext may be NULL only when
 *                            plaintext_len is 0; plaintext_len must not exceed INT_MAX.
 * ciphertext               : output buffer. GCM output length equals input length, so it
 *                            must hold at least plaintext_len bytes. May be NULL only when
 *                            plaintext_len is 0.
 * ciphertext_len           : receives the number of ciphertext bytes (written on success only).
 * tag, tag_len             : receives the GCM tag; tag_len is passed to EVP_CTRL_GCM_GET_TAG.
 * shared_key               : AES256_KEY_LEN-byte key. When NULL, the key is obtained from
 *                            VK_GET_SECURE_ITEM(VK_SECURE_ITEM_KEY) (vault encryption).
 *                            An all-zero shared key is rejected.
 * random_iv                : AES_IV_LEN-byte IV (required). Used as given; an all-zero IV is
 *                            not rejected, and IV uniqueness per key is the caller's duty.
 *
 * Returns VK_SUCCESS; VK_ERR_INVALID_ARGUMENT for bad arguments; the VK_GET_SECURE_ITEM
 * error code if the vault key cannot be loaded; VK_ERR_TZ_API_CRYPTO on OpenSSL failure.
 */
int vk_crypto_aes_256_gcm_encrypt(unsigned char* plaintext, unsigned int plaintext_len,
									unsigned char* ciphertext, unsigned int* ciphertext_len,
									unsigned char* tag, unsigned int tag_len, unsigned char* shared_key,
									unsigned char* random_iv)
{
	int ret = VK_ERR_GENERAL;
	EVP_CIPHER_CTX* ctx = NULL;
	unsigned char key[AES256_KEY_LEN] = {0,};
	unsigned char iv[AES_IV_LEN] = {0,};
	unsigned char* tail = NULL;
	int len = 0;
	int final_len = 0;

	if (random_iv == NULL) {
		LOGE("%s: IV invalid\n", __func__);
		ret = VK_ERR_INVALID_ARGUMENT;
		goto out;
	}

	/* OpenSSL takes the input length as int; larger values would wrap negative. */
	if (ciphertext_len == NULL || tag == NULL || plaintext_len > (unsigned int)INT_MAX ||
		(plaintext_len > 0 && (plaintext == NULL || ciphertext == NULL))) {
		LOGE("%s: Buffer invalid\n", __func__);
		ret = VK_ERR_INVALID_ARGUMENT;
		goto out;
	}

	memcpy(iv, random_iv, AES_IV_LEN);

	if (shared_key == NULL) {
		/* Vault encryption: use the key returned by VK_GET_SECURE_ITEM. */
		ret = VK_GET_SECURE_ITEM(VK_SECURE_ITEM_KEY, key, AES256_KEY_LEN);
		if (ret != VK_SUCCESS) {
			LOGE("%s: Failed VK_GET_SECURE_ITEM(%d/%d)\n", __func__, VK_SECURE_ITEM_KEY, ret);
			goto out;
		}
	} else {
		/* User message encryption with a caller-supplied key. */
		if (isAllZero(shared_key, AES256_KEY_LEN)) {
			LOGE("%s: Shared key invalid\n", __func__);
			ret = VK_ERR_INVALID_ARGUMENT;
			goto out;
		}
		memcpy(key, shared_key, AES256_KEY_LEN);
	}

	if (!(ctx = EVP_CIPHER_CTX_new())) {
		LOGE("%s: Failed EVP_CIPHER_CTX_new\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	if (!EVP_EncryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)) {
		LOGE("%s: Failed EVP_EncryptInit_ex, gcm\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* The IV length must be configured before the IV itself is supplied. */
	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, AES_IV_LEN, NULL)) {
		LOGE("%s: Failed EVP_CIPHER_CTX_ctrl, salt\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	if (!EVP_EncryptInit_ex(ctx, NULL, NULL, key, iv)) {
		LOGE("%s: EVP_EncryptInit_ex, key and iv \n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* Skip the update for empty input so a NULL buffer is never handed to OpenSSL;
	 * encrypting zero bytes produces no output anyway. */
	if (plaintext_len > 0 &&
		!EVP_EncryptUpdate(ctx, ciphertext, &len, plaintext, (int)plaintext_len)) {
		LOGE("%s: Failed EVP_EncryptUpdate plaintext\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* Finalisation computes the tag; GCM emits no further ciphertext here, but its
	 * output length is still accounted for. Avoid NULL + 0 pointer arithmetic. */
	tail = (ciphertext != NULL) ? ciphertext + len : NULL;
	if (!EVP_EncryptFinal_ex(ctx, tail, &final_len)) {
		LOGE("%s: Failed TAG calculation\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, (int)tag_len, tag)) {
		LOGE("%s: Failed EVP_CIPHER_CTX_ctrl, get tag\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	*ciphertext_len = (unsigned int)(len + final_len);
	ret = VK_SUCCESS;

out:
	/* EVP_CIPHER_CTX_free accepts NULL. */
	EVP_CIPHER_CTX_free(ctx);

	/* OPENSSL_cleanse instead of memset: a memset of a local that is about to go
	 * out of scope is a dead store the optimizer may remove, leaving key bytes on the stack. */
	OPENSSL_cleanse(key, sizeof(key));
	OPENSSL_cleanse(iv, sizeof(iv));
	return ret;
}

/*
 * Decrypt an AES-256-GCM buffer and verify its authentication tag.
 *
 * ciphertext, ciphertext_len : data to decrypt. ciphertext may be NULL only when
 *                              ciphertext_len is 0; ciphertext_len must not exceed INT_MAX.
 * plaintext                  : output buffer. GCM output length equals input length, so it
 *                              must hold at least ciphertext_len bytes. May be NULL only when
 *                              ciphertext_len is 0. If decryption or tag verification fails
 *                              after data was written, the written bytes are wiped.
 * plaintext_len              : receives the number of plaintext bytes (written on success only).
 * tag, tag_len               : expected GCM tag, passed to EVP_CTRL_GCM_SET_TAG.
 * shared_key                 : AES256_KEY_LEN-byte key. When NULL, the key is obtained from
 *                              VK_GET_SECURE_ITEM(VK_SECURE_ITEM_KEY) (vault decryption).
 * random_iv                  : AES_IV_LEN-byte IV (required), used as given.
 *
 * Returns VK_SUCCESS; VK_ERR_INVALID_ARGUMENT for bad arguments; VK_ERR_CLIENT_NOT_INITIALIZED
 * for an all-zero shared key; the VK_GET_SECURE_ITEM error code if the vault key cannot be
 * loaded; VK_ERR_TZ_API_CRYPTO on OpenSSL setup failure; VK_ERR_INTEGRITY_FAILED when
 * finalisation (tag verification) fails.
 */
int vk_crypto_aes_256_gcm_decrypt(unsigned char* ciphertext, unsigned int ciphertext_len,
									unsigned char* plaintext, unsigned int* plaintext_len,
									unsigned char* tag, unsigned int tag_len, unsigned char* shared_key,
									unsigned char* random_iv)
{
	int ret = VK_ERR_GENERAL;
	EVP_CIPHER_CTX *ctx = NULL;
	unsigned char key[AES256_KEY_LEN] = {0,};
	unsigned char iv[AES_IV_LEN] = {0,};
	unsigned char *tail = NULL;
	int len = 0;
	int final_len = 0;

	if (random_iv == NULL) {
		LOGE("%s: IV invalid\n", __func__);
		ret = VK_ERR_INVALID_ARGUMENT;
		goto out;
	}

	/* OpenSSL takes the input length as int; larger values would wrap negative.
	 * A NULL tag would be dereferenced by EVP_CTRL_GCM_SET_TAG. */
	if (plaintext_len == NULL || tag == NULL || ciphertext_len > (unsigned int)INT_MAX ||
		(ciphertext_len > 0 && (ciphertext == NULL || plaintext == NULL))) {
		LOGE("%s: Buffer invalid\n", __func__);
		ret = VK_ERR_INVALID_ARGUMENT;
		goto out;
	}

	memcpy(iv, random_iv, AES_IV_LEN);

	if (shared_key == NULL) {
		/* Vault decryption: use the key returned by VK_GET_SECURE_ITEM. */
		ret = VK_GET_SECURE_ITEM(VK_SECURE_ITEM_KEY, key, AES256_KEY_LEN);
		if (ret != VK_SUCCESS) {
			LOGE("%s: Failed VK_GET_SECURE_ITEM(%d/%d)\n", __func__, VK_SECURE_ITEM_KEY, ret);
			goto out;
		}
	} else {
		/* User message decryption with a caller-supplied key. */
		if (isAllZero(shared_key, AES256_KEY_LEN)) {
			LOGE("%s: Shared key invalid\n", __func__);
			/* NOTE(review): encrypt returns VK_ERR_INVALID_ARGUMENT for the same condition;
			 * kept as-is because callers may depend on this code — confirm intent. */
			ret = VK_ERR_CLIENT_NOT_INITIALIZED;
			goto out;
		}
		memcpy(key, shared_key, AES256_KEY_LEN);
	}

	/* Create and initialise the context */
	if (!(ctx = EVP_CIPHER_CTX_new())) {
		LOGE("%s: Failed EVP_CIPHER_CTX_new\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	if (!EVP_DecryptInit_ex(ctx, EVP_aes_256_gcm(), NULL, NULL, NULL)) {
		LOGE("%s: Failed EVP_DecryptInit_ex, gcm\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* The IV length must be configured before the IV itself is supplied. */
	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_IVLEN, AES_IV_LEN, NULL)) {
		LOGE("%s: Failed EVP_CIPHER_CTX_ctrl, salt\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	if (!EVP_DecryptInit_ex(ctx, NULL, NULL, key, iv)) {
		LOGE("%s: Failed EVP_DecryptInit_ex, key and iv \n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* Skip the update for empty input so a NULL buffer is never handed to OpenSSL. */
	if (ciphertext_len > 0 &&
		!EVP_DecryptUpdate(ctx, plaintext, &len, ciphertext, (int)ciphertext_len)) {
		LOGE("%s: Failed EVP_DecryptUpdate, ciphertext\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* The expected tag must be set before finalisation verifies it. */
	if (!EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, (int)tag_len, tag)) {
		LOGE("%s: Failed EVP_CIPHER_CTX_ctrl, tag\n", __func__);
		ret = VK_ERR_TZ_API_CRYPTO;
		goto out;
	}

	/* Finalisation checks the tag; EVP_DecryptFinal_ex returns 0 on mismatch.
	 * Avoid NULL + 0 pointer arithmetic when no output buffer was given. */
	tail = (plaintext != NULL) ? plaintext + len : NULL;
	if (EVP_DecryptFinal_ex(ctx, tail, &final_len) <= 0) {
		LOGE("%s: Failed to verify auth tag\n", __func__);
		ret = VK_ERR_INTEGRITY_FAILED;
		goto out;
	}

	/* Report a length only once the data has been authenticated. */
	*plaintext_len = (unsigned int)(len + final_len);
	ret = VK_SUCCESS;

out:
	/* EVP_CIPHER_CTX_free accepts NULL. */
	EVP_CIPHER_CTX_free(ctx);

	/* Never leave unauthenticated plaintext in the caller's buffer. len > 0 implies
	 * plaintext is non-NULL (checked above). */
	if (ret != VK_SUCCESS && len > 0) {
		OPENSSL_cleanse(plaintext, (size_t)len);
	}

	/* OPENSSL_cleanse instead of memset: a memset of a local that is about to go
	 * out of scope is a dead store the optimizer may remove, leaving key bytes on the stack. */
	OPENSSL_cleanse(key, sizeof(key));
	OPENSSL_cleanse(iv, sizeof(iv));

	return ret;
}
