#include <stdio.h>
#include <string.h>

#include "vk_constants.h"
#include "vk_interface.h"
#include "vk_data_struct.h"
#include "vk_error.h"
#include "vk_log.h"
#include "crypto/vk_crypto.h"
#include "crypto/vk_crypto_cert.h"
#include "crypto/vk_crypto_aes.h"

#include "openssl/evp.h"
#include "openssl/err.h"
#include "openssl/hmac.h"
#include "openssl/x509.h"
#include "openssl/mem.h"

/*
 * Compute the SHA-256 digest of a buffer using the EVP digest API.
 *
 * data      input bytes; must not be NULL
 * data_len  number of bytes of data to hash
 * digest    output buffer; must have room for the SHA-256 output
 *           (SHA256_DIGEST_LEN bytes, presumably 32 — confirm in vk_constants.h)
 *
 * Returns VK_SUCCESS on success, VK_ERR_INVALID_ARGUMENT if data or digest is
 * NULL, or VK_ERR_TZ_API_CRYPTO if any EVP call fails.
 */
int vk_crypto_sha256(unsigned char* data, unsigned int data_len, unsigned char* digest)
{
	EVP_MD_CTX* md_ctx = NULL;
	unsigned int out_len = SHA256_DIGEST_LEN;
	int status = VK_ERR_TZ_API_CRYPTO;

	if (data == NULL) {
		LOGE("%s: data is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	if (digest == NULL) {
		LOGE("%s: digest is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	md_ctx = EVP_MD_CTX_create();
	if (md_ctx == NULL) {
		LOGE("%s: Failed EVP_MD_CTX_create\n", __func__);
		return VK_ERR_TZ_API_CRYPTO;
	}

	// Each stage runs only if the previous one succeeded; the first failure
	// is logged and leaves status at VK_ERR_TZ_API_CRYPTO.
	if (!EVP_DigestInit_ex(md_ctx, EVP_sha256(), NULL)) {
		LOGE("%s: Failed EVP_DigestInit_ex\n", __func__);
	} else if (!EVP_DigestUpdate(md_ctx, data, data_len)) {
		LOGE("%s: Failed EVP_DigestUpdate\n", __func__);
	} else if (!EVP_DigestFinal_ex(md_ctx, digest, &out_len)) {
		LOGE("%s: Failed EVP_DigestFinal_ex\n", __func__);
	} else {
		status = VK_SUCCESS;
	}

	// The context was successfully created above, so it is always released here.
	EVP_MD_CTX_destroy(md_ctx);

	return status;
}

/*
 * Compute HMAC-SHA256 over a buffer using a caller-supplied key.
 *
 * hmac      output buffer; must have room for HMAC_SHA256_LEN bytes
 * data      message bytes; must not be NULL
 * data_len  number of bytes of data to authenticate
 * key       HMAC key bytes; must not be NULL
 * key_len   number of key bytes
 *
 * Returns VK_SUCCESS on success, VK_ERR_INVALID_ARGUMENT if any pointer is
 * NULL (checked in the order hmac, data, key), or VK_ERR_TZ_API_CRYPTO if the
 * HMAC primitive reports failure.
 */
int vk_crypto_hmac_sha256(unsigned char* hmac, unsigned char* data, unsigned int data_len,
							unsigned char* key, unsigned int key_len)
{
	unsigned int out_len = HMAC_SHA256_LEN;
	int status;

	if (hmac == NULL) {
		LOGE("%s: hmac is null\n", __func__);
		status = VK_ERR_INVALID_ARGUMENT;
	} else if (data == NULL) {
		LOGE("%s: data is null\n", __func__);
		status = VK_ERR_INVALID_ARGUMENT;
	} else if (key == NULL) {
		LOGE("%s: key is null\n", __func__);
		status = VK_ERR_INVALID_ARGUMENT;
	} else if (HMAC(EVP_sha256(), key, key_len, data, data_len, hmac, &out_len) == NULL) {
		// HMAC() signals failure by returning NULL.
		LOGE("%s: HMAC error\n", __func__);
		status = VK_ERR_TZ_API_CRYPTO;
	} else {
		status = VK_SUCCESS;
	}

	return status;
}

/*
 * Derive key material from a password using PBKDF2 with HMAC-SHA1.
 *
 * pass      password bytes; exactly pass_len bytes are used (no NUL needed)
 * pass_len  password length in bytes; must be >= 0
 * salt      salt bytes; exactly salt_len bytes are used
 * salt_len  salt length in bytes; must be >= 0
 * key       output buffer; must have room for key_len bytes
 * key_len   number of key bytes to derive
 *
 * Returns VK_SUCCESS on success, VK_ERR_INVALID_ARGUMENT if a pointer is NULL
 * or a length is negative, or VK_ERR_TZ_API_CRYPTO if the PBKDF2 primitive
 * reports failure.
 */
int vk_crypto_pbkdf2(char* pass, int pass_len, unsigned char* salt, int salt_len,
						unsigned char* key, unsigned int key_len)
{
	// Iteration count for PBKDF2.
	// ref. https://pages.nist.gov/800-63-3/sp800-63b.html#sec5
	enum { PBKDF2_ITERATIONS = 10000 };

	if (pass == NULL) {
		LOGE("%s: pass is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	if (salt == NULL) {
		LOGE("%s: salt is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	if (key == NULL) {
		LOGE("%s: key is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	// The library treats these lengths as unsigned sizes (size_t in
	// BoringSSL), so a negative int would become a huge length and read far
	// past the caller's buffer. Reject them here. NOTE(review): this also
	// rejects OpenSSL's "pass_len == -1 means strlen" convention; confirm no
	// caller relies on it.
	if (pass_len < 0 || salt_len < 0) {
		LOGE("%s: wrong len(pass:%d, salt:%d)\n", __func__, pass_len, salt_len);
		return VK_ERR_INVALID_ARGUMENT;
	}

	if (!PKCS5_PBKDF2_HMAC_SHA1(pass, pass_len, salt, salt_len, PBKDF2_ITERATIONS, key_len, key)) {
		LOGE("%s: PBKDF2 error\n", __func__);
		return VK_ERR_TZ_API_CRYPTO;
	}

	return VK_SUCCESS;
}

/*
 * Verify that the SHA-256 digest of data matches an expected digest.
 *
 * data  input bytes; must not be NULL
 * len   number of bytes of data; must be > 0
 * hash  expected digest; SHA256_DIGEST_LEN bytes are compared
 *
 * The comparison uses CRYPTO_memcmp, which is a constant-time comparison.
 *
 * Returns VK_SUCCESS when the digests match, VK_ERR_INVALID_ARGUMENT for bad
 * arguments, VK_ERR_INTEGRITY_FAILED on mismatch, or the error returned by
 * vk_crypto_sha256 if hashing fails.
 */
int vk_crypto_check_integrity(unsigned char* data, int len, unsigned char* hash)
{
	unsigned char computed[SHA256_DIGEST_LEN] = {0};
	int rc;

	if (data == NULL) {
		LOGE("%s: data is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	if (hash == NULL) {
		LOGE("%s: hash is null\n", __func__);
		return VK_ERR_INVALID_ARGUMENT;
	}

	if (len <= 0) {
		LOGE("%s: wrong len(%d)\n", __func__, len);
		return VK_ERR_INVALID_ARGUMENT;
	}

	// len is known to be positive here, so the conversion is value-preserving.
	rc = vk_crypto_sha256(data, (unsigned int)len, computed);
	if (rc != VK_SUCCESS) {
		LOGE("%s: Failed sha256(%d)\n", __func__, rc);
		return rc;
	}

	if (CRYPTO_memcmp(computed, hash, SHA256_DIGEST_LEN) != 0) {
		LOGE("%s: Integrity is broken\n", __func__);
		return VK_ERR_INTEGRITY_FAILED;
	}

	return VK_SUCCESS;
}
