#include <ctype.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <time.h>

#include "mbedtls/aes.h"
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/entropy.h"
#include "mbedtls/error.h"
#include "mbedtls/md.h"
#include "mbedtls/oid.h"
#include "mbedtls/pk.h"
#include "mbedtls/sha256.h"
#include "mbedtls/x509.h"
#include "mbedtls/x509_crt.h"

#include "em_common.h"
#include "em_crypto.h"
#include "em_crypto_cert_lite.h"

/*
 * Entropy-source callback (mbedtls_entropy_add_source() signature).
 *
 * Fills @output with @len bytes taken from the C library rand() and reports
 * the number of bytes produced through @olen. @data is unused. Always
 * returns 0.
 *
 * WARNING: this is NOT a cryptographically secure source. rand() is seeded
 * from time(NULL), so the output is predictable to anyone who can guess the
 * time. Replace it with a real platform entropy source for security use.
 *
 * rand() is seeded only on the first call so that successive calls continue
 * the sequence instead of repeating it whenever they fall within the same
 * second of time(). Like rand() itself, this is not thread-safe.
 */
static int entropy_dummy_source(void *data, unsigned char *output, size_t len, size_t *olen)
{
	static int seeded;

	(void)data;

	if (!seeded) {
		srand((unsigned int)time(NULL));
		seeded = 1;
	}

	/* size_t index: len is size_t and may exceed INT_MAX. */
	for (size_t i = 0; i < len; i++)
		output[i] = (unsigned char)(rand() % 256);
	*olen = len;

	return 0;
}

/*
 * Entropy callback passed as f_entropy to mbedtls_ctr_drbg_seed().
 *
 * Fills @output with @len bytes taken from the C library rand(). @data (the
 * p_entropy pointer given to mbedtls_ctr_drbg_seed()) is ignored. Always
 * returns 0.
 *
 * WARNING: this is NOT a cryptographically secure source. rand() is seeded
 * from time(NULL), so anything derived from it is predictable to anyone who
 * can guess the time.
 *
 * rand() is seeded only on the first call. Re-seeding on every call made two
 * DRBG instantiations within the same second receive identical seed material
 * and therefore produce identical "random" output. Like rand() itself, this
 * is not thread-safe.
 */
static int entropy_dummy_func(void *data, unsigned char *output, size_t len)
{
	static int seeded;

	(void)data;

	if (!seeded) {
		srand((unsigned int)time(NULL));
		seeded = 1;
	}

	/* size_t index: len is size_t and may exceed INT_MAX. */
	for (size_t i = 0; i < len; i++)
		output[i] = (unsigned char)(rand() % 256);

	return 0;
}

/*
 * Find the first attribute in the X.509 name list @name whose mbedtls short
 * name (as returned by mbedtls_oid_get_attr_short_name()) is exactly
 * @target_short_name, and copy its value into @value.
 *
 * At most @value_length - 1 bytes are copied and @value is always
 * NUL-terminated when an attribute is found. Non-printable bytes (< 32, 127,
 * and 129..159) are replaced by '?', the same test mbedtls_x509_dn_gets()
 * applies.
 *
 * Returns the full length of the attribute value (which may be larger than
 * what fitted in @value), or 0 if no attribute matched or the arguments are
 * unusable (NULL pointers or @value_length == 0). @value is untouched when 0
 * is returned because nothing matched.
 */
size_t em_crypto_find_oid_value_in_name(const mbedtls_x509_name *name, const char *target_short_name, char *value,
					size_t value_length)
{
	const mbedtls_x509_name *cur;

	/* Need room for at least the terminator; also guards value_length - 1 below. */
	if (target_short_name == NULL || value == NULL || value_length == 0)
		return 0;

	for (cur = name; cur != NULL; cur = cur->next) {
		const char *short_name = NULL;
		size_t bytes_to_write;

		/* Entries without an OID carry no attribute to match. */
		if (cur->oid.p == NULL)
			continue;

		if (mbedtls_oid_get_attr_short_name(&cur->oid, &short_name) != 0 || short_name == NULL)
			continue;

		/* Exact match: a prefix compare would let "C" match "CN" or "O" match "OU". */
		if (strcmp(short_name, target_short_name) != 0)
			continue;

		bytes_to_write = (cur->val.len >= value_length) ? value_length - 1 : cur->val.len;

		for (size_t i = 0; i < bytes_to_write; i++) {
			/* unsigned char so the filter does not depend on the platform's char signedness. */
			unsigned char c = cur->val.p[i];

			if (c < 32 || c == 127 || (c > 128 && c < 160))
				value[i] = '?';
			else
				value[i] = (char)c;
		}

		value[bytes_to_write] = '\0';
		return cur->val.len;
	}

	return 0;
}

/*
 * Extract one subject attribute from an X.509 certificate into @out.
 *
 * @cert      certificate buffer, in a format accepted by mbedtls_x509_crt_parse()
 * @len_cert  size of @cert in bytes; must be > 0
 * @type      exactly EM_CERT_SUBJECT_UID selects EM_CERT_SUBJECT_UID_FOR_LITE;
 *            any other string selects EM_CERT_SUBJECT_CN_FOR_LITE
 * @out       destination buffer; NUL-terminated on success (value may be truncated)
 * @out_len   size of @out in bytes; must be > 0
 *
 * Returns EM_SUCCESS, or one of the EM_ERR_EM_CRYPTO_GET_SUBJECT_FROM_CERT_LITE*
 * codes (_PARSE for an unusable certificate, _LEN when the output buffer is
 * unusable or the attribute is missing/empty).
 */
int em_crypto_get_subject_from_cert(const unsigned char *cert, const int len_cert, const char *type, char *out,
				    int out_len)
{
	int ret;
	size_t found_len;
	const char *converted_type = NULL;

	mbedtls_x509_crt crt;
	mbedtls_x509_crt_init(&crt);

	EM_CHECK_NULL(__func__, EM_ERR_EM_CRYPTO_GET_SUBJECT_FROM_CERT_LITE, cert, type, out);

	/* Both lengths are converted to size_t below; a negative value would become huge. */
	if (len_cert <= 0) {
		LOGE("%s : Unexpected cert length(%d)\n", __func__, len_cert);
		ret = EM_ERR_EM_CRYPTO_GET_SUBJECT_FROM_CERT_LITE_PARSE;
		goto out;
	}

	if (out_len <= 0) {
		LOGE("%s : Unexpected output length(%d)\n", __func__, out_len);
		ret = EM_ERR_EM_CRYPTO_GET_SUBJECT_FROM_CERT_LITE_LEN;
		goto out;
	}

	ret = mbedtls_x509_crt_parse(&crt, cert, (size_t)len_cert);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to parse certificate(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_GET_SUBJECT_FROM_CERT_LITE_PARSE;
		goto out;
	}

	/*
	 * Exact comparison: the former memcmp over strlen(type) treated "" or any
	 * prefix of EM_CERT_SUBJECT_UID as UID and could read past the constant
	 * when type was longer than it.
	 */
	if (strcmp(type, EM_CERT_SUBJECT_UID) == 0)
		converted_type = EM_CERT_SUBJECT_UID_FOR_LITE;
	else
		converted_type = EM_CERT_SUBJECT_CN_FOR_LITE;

	found_len = em_crypto_find_oid_value_in_name(&crt.subject, converted_type, out, (size_t)out_len);
	if (found_len == 0) {
		LOGE("Failed to get %s in subject of cert(%d)\n", converted_type, (int)found_len);
		ret = EM_ERR_EM_CRYPTO_GET_SUBJECT_FROM_CERT_LITE_LEN;
		goto out;
	}

	ret = EM_SUCCESS;
out:
	mbedtls_x509_crt_free(&crt);
	return ret;
}

int em_crypto_hmac(unsigned char *hmac_buf, unsigned char *data, unsigned int data_len, unsigned char *key,
		   unsigned int key_len)
{
	int ret;

	EM_CHECK_NULL(__func__, EM_ERR_EM_CRYPTO_HMAC_MBEDTLS, hmac_buf, data, key);

	ret = mbedtls_md_hmac(mbedtls_md_info_from_type(MBEDTLS_MD_SHA256), key, key_len, data, data_len, hmac_buf);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to calculate hmac(0x%08x)", ret);
		ret = EM_ERR_EM_CRYPTO_HMAC_MBEDTLS_MD_HMAC;
		goto out;
	}

	ret = EM_SUCCESS;
out:
	return ret;
}

/*
 * Verify that @cert chains to the global root_ca certificate, passing
 * EM_CERT_VALID_CN_VALUE as the expected CN to mbedtls_x509_crt_verify().
 *
 * @cert       certificate buffer handed to mbedtls_x509_crt_parse()
 * @cert_size  size of @cert in bytes
 *
 * Returns EM_SUCCESS when verification passes, otherwise one of the
 * EM_ERR_EM_CRYPTO_VERIFY_CERT_LITE* codes. The verification flags are
 * collected but not reported to the caller.
 */
int em_crypto_verify_cert(unsigned char *cert, unsigned int cert_size)
{
	mbedtls_x509_crt leaf_crt;
	mbedtls_x509_crt ca_chain;
	uint32_t verify_flags = 0;
	int ret;

	mbedtls_x509_crt_init(&leaf_crt);
	mbedtls_x509_crt_init(&ca_chain);

	EM_CHECK_NULL(__func__, EM_ERR_EM_CRYPTO_VERIFY_CERT_LITE, cert);

	/* Certificate supplied by the caller. */
	ret = mbedtls_x509_crt_parse(&leaf_crt, cert, cert_size);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to parse server cert(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_VERIFY_CERT_LITE_SERVER_CERT_PARSE;
		goto out;
	}

	/* Trust anchor: the global root_ca buffer. */
	ret = mbedtls_x509_crt_parse(&ca_chain, root_ca, sizeof(root_ca));
	if (ret != EM_SUCCESS) {
		LOGE("Failed to parse ca cert(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_VERIFY_CERT_LITE_CA_CERT_PARSE;
		goto out;
	}

	/* No CRL and no per-certificate verification callback. */
	ret = mbedtls_x509_crt_verify(&leaf_crt, &ca_chain, NULL, EM_CERT_VALID_CN_VALUE, &verify_flags, NULL, NULL);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to verify cert(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_VERIFY_CERT_LITE_VERIFY;
		goto out;
	}

	ret = EM_SUCCESS;
out:
	mbedtls_x509_crt_free(&leaf_crt);
	mbedtls_x509_crt_free(&ca_chain);

	return ret;
}

int em_crypto_verify_rsa_signature(unsigned char *cert, unsigned int cert_size, unsigned char *sig,
				   unsigned int sig_size, unsigned char *data, unsigned int data_size)
{
	int ret;

	uint8_t decrypt[EM_LEN_SHA256 + 1] = {};
	uint8_t digest[EM_LEN_SHA256 + 1] = {};
	uint32_t decrypt_size = 0;

	mbedtls_x509_crt crt;
	mbedtls_rsa_context *rsa_context = NULL;
	mbedtls_entropy_context entropy;
	mbedtls_ctr_drbg_context ctr_drbg;

	const char *personalization = "mbedtls rsa_encrypt for EM lite";
	char mbedtls_error[512] = {};

	EM_CHECK_NULL(__func__, EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE, cert, sig, data);

	mbedtls_x509_crt_init(&crt);
	mbedtls_ctr_drbg_init(&ctr_drbg);
	mbedtls_entropy_init(&entropy);

	ret = mbedtls_entropy_add_source(&entropy, entropy_dummy_source, NULL, 16, 1);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to add soruce for entropy(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_ET_ADD;
		goto out;
	}

	ret = mbedtls_ctr_drbg_seed(&ctr_drbg, entropy_dummy_func, &entropy, (const unsigned char *)personalization,
				    strlen(personalization));
	if (ret != 0) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to set seed(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_DRBG;
		goto out;
	}

	ret = mbedtls_x509_crt_parse(&crt, cert, cert_size);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to parse cert(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_CERT;
		goto out;
	}

	ret = em_crypto_sha256(data, data_size, digest);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to make digest(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_SHA256;
		goto out;
	}

	rsa_context = mbedtls_pk_rsa(crt.pk);
	if (rsa_context == NULL) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to get rsa context from pk(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_PK;
		goto out;
	}

	//	rsa_context->padding = 0; // PKCS1 V21
	//	rsa_context->hash_id = MBEDTLS_MD_SHA256;
	ret = mbedtls_rsa_pkcs1_decrypt(rsa_context, mbedtls_ctr_drbg_random, &ctr_drbg, MBEDTLS_RSA_PUBLIC,
					&decrypt_size, sig, decrypt, EM_LEN_SHA256);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("Failed to verify signature(%d/%s)", ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_DECRYPT;
		goto out;
	}

	if (memcmp(digest, decrypt, EM_LEN_SHA256) != 0) {
		LOGE("Failed to compare signature\n");
		ret = EM_ERR_EM_CRYPTO_VERIFY_RSA_SIGNATURE_LITE_FAIL;
		goto out;
	}

	ret = EM_SUCCESS;
out:
	return ret;
}

int em_get_random(unsigned char *buf, int required_len)
{
	int ret;

	mbedtls_entropy_context entropy;
	mbedtls_ctr_drbg_context ctr_drbg;

	const char *personalization = "mbedtls for LSI bootloader";
	char mbedtls_error[512] = {};

	EM_CHECK_NULL(__func__, EM_ERR_EM_GET_RANDOM_LITE, buf);

	if (required_len < 0) {
		LOGE("%s : Unexpected length(%d)\n", __func__, required_len);
		ret = EM_ERR_EM_GET_RANDOM_LITE_LEN;
		goto out;
	}

	mbedtls_ctr_drbg_init(&ctr_drbg);
	mbedtls_entropy_init(&entropy);

	ret = mbedtls_entropy_add_source(&entropy, entropy_dummy_source, NULL, 16, 1);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to add soruce for entropy(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_GET_RANDOM_LITE_ET_ADD;
		goto out;
	}

	ret = mbedtls_ctr_drbg_seed(&ctr_drbg, entropy_dummy_func, &entropy, (const unsigned char *)personalization,
				    strlen(personalization));
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to set seed(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_GET_RANDOM_LITE_DRBG;
		goto out;
	}

	ret = mbedtls_ctr_drbg_random(&ctr_drbg, buf, required_len);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to generate random bytes(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_GET_RANDOM_LITE_DRBG_RANDOM;
		goto out;
	}

	ret = required_len;
out:
	mbedtls_entropy_free(&entropy);
	mbedtls_ctr_drbg_free(&ctr_drbg);
	return ret;
}

/*
 * Compute the SHA-256 digest of @data (@data_len bytes) into @digest, which
 * must have room for a 32-byte SHA-256 digest.
 *
 * Always returns EM_SUCCESS; the outcome of mbedtls_sha256() is not checked.
 */
int em_crypto_sha256(unsigned char *data, unsigned int data_len, unsigned char *digest)
{
	/* The final mbedtls_sha256() argument selects SHA-224 when non-zero. */
	const int is224 = 0;

	mbedtls_sha256(data, data_len, digest, is224);

	return EM_SUCCESS;
}

/*
 * AES-256-CTR encryption of @plaintext into @ciphertext.
 *
 * @plaintext       input, @plaintext_len bytes (must be >= 0)
 * @key             32-byte AES-256 key
 * @iv              16-byte initial counter block; copied, NOT modified
 * @ciphertext      output buffer of at least @plaintext_len bytes
 * @ciphertext_len  receives @plaintext_len on success
 *
 * The counter always starts at @iv with a fresh keystream offset, so each call
 * is a one-shot operation. Because CTR is symmetric, the same call also
 * decrypts.
 *
 * Returns EM_SUCCESS, or one of the EM_ERR_EM_CRYPTO_AES_256_CTR_ENCRYPT_LITE* codes.
 */
int em_crypto_aes_256_ctr_encrypt(unsigned char *plaintext, int plaintext_len, unsigned char *key, unsigned char *iv,
				  unsigned char *ciphertext, int *ciphertext_len)
{
	int ret;
	size_t nc_off = 0;

	unsigned char nonce_counter[16] = {};
	unsigned char stream_block[16] = {};
	char mbedtls_error[512] = {};

	mbedtls_aes_context aes_ctx;
	mbedtls_aes_init(&aes_ctx);

	EM_CHECK_NULL(__func__, EM_ERR_EM_CRYPTO_AES_256_CTR_ENCRYPT_LITE, plaintext, key, iv, ciphertext,
		      ciphertext_len);

	/* The length is converted to size_t below; a negative value would become huge. */
	if (plaintext_len < 0) {
		LOGE("%s : Unexpected length(%d)\n", __func__, plaintext_len);
		ret = EM_ERR_EM_CRYPTO_AES_256_CTR_ENCRYPT_LITE;
		goto out;
	}

	/*
	 * mbedtls_aes_crypt_ctr() advances the counter block in place. Work on a
	 * copy so the caller's IV is still the one used, e.g. when it is later
	 * sent with the ciphertext or reused to decrypt.
	 */
	memcpy(nonce_counter, iv, sizeof(nonce_counter));

	ret = mbedtls_aes_setkey_enc(&aes_ctx, key, 256);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to set aes key(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_AES_256_CTR_ENCRYPT_LITE_SET_ENC;
		goto out;
	}

	ret = mbedtls_aes_crypt_ctr(&aes_ctx, (size_t)plaintext_len, &nc_off, nonce_counter, stream_block, plaintext,
				    ciphertext);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to aes encrypt(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_AES_256_CTR_ENCRYPT_LITE_CRYPT;
		goto out;
	}

	*ciphertext_len = plaintext_len;

	ret = EM_SUCCESS;
out:
	mbedtls_aes_free(&aes_ctx);
	return ret;
}

/*
 * AES-256-CTR decryption of @ciphertext into @plaintext.
 *
 * CTR mode XORs the data with the same keystream in both directions, so this
 * simply forwards to em_crypto_aes_256_ctr_encrypt() with the ciphertext as
 * input. Arguments and return codes are exactly those of that function.
 */
int em_crypto_aes_256_ctr_decrypt(unsigned char *ciphertext, int ciphertext_len, unsigned char *key, unsigned char *iv,
				  unsigned char *plaintext, int *plaintext_len)
{
	const int rc = em_crypto_aes_256_ctr_encrypt(ciphertext, ciphertext_len, key, iv, plaintext, plaintext_len);

	return rc;
}

int em_crypto_rsa_encrypt(uint8_t *cert, int cert_size, uint8_t *in, int in_len, uint8_t *out, int *out_len)
{
	int ret;

	mbedtls_x509_crt crt;
	mbedtls_rsa_context *rsa_context = NULL;
	mbedtls_entropy_context entropy;
	mbedtls_ctr_drbg_context ctr_drbg;

	const char *personalization = "mbedtls rsa_encrypt for EM lite";

	char mbedtls_error[512] = {};

	EM_CHECK_NULL(__func__, EM_ERR_EM_CRYPTO_RSA_ENCRYPT_LITE, cert, in, out, out_len);

	mbedtls_x509_crt_init(&crt);
	mbedtls_ctr_drbg_init(&ctr_drbg);
	mbedtls_entropy_init(&entropy);

	ret = em_crypto_verify_cert(cert, cert_size);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to verify cert(0x%08x)\n", ret);
		goto out;
	}

	ret = mbedtls_ctr_drbg_seed(&ctr_drbg, entropy_dummy_func, &entropy, (const unsigned char *)personalization,
				    strlen(personalization));
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to set seed(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_RSA_ENCRYPT_LITE_DRBG;
		goto out;
	}

	ret = mbedtls_x509_crt_parse(&crt, cert, cert_size);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to parse cert(0x%08x)\n", ret);
		ret = EM_ERR_EM_CRYPTO_RSA_ENCRYPT_LITE_CRT;
		goto out;
	}

	rsa_context = mbedtls_pk_rsa(crt.pk);
	if (rsa_context == NULL) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to get rsa context from pk(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = -2;
		goto out;
	}

	rsa_context->padding = MBEDTLS_RSA_PKCS_V21;
	rsa_context->hash_id = MBEDTLS_MD_SHA1;
	ret = mbedtls_rsa_pkcs1_encrypt(rsa_context, mbedtls_ctr_drbg_random, &ctr_drbg, MBEDTLS_RSA_PUBLIC, in_len, in,
					out);
	if (ret != EM_SUCCESS) {
		mbedtls_strerror(ret, mbedtls_error, sizeof(mbedtls_error));
		LOGE("%s : Failed to rsa encrypt(%08x)%s\n", __func__, ret, mbedtls_error);
		ret = EM_ERR_EM_CRYPTO_RSA_ENCRYPT_LITE_PKCS_ENCRYPT;
		goto out;
	}
	*out_len = rsa_context->len;

	ret = EM_SUCCESS;
out:
	mbedtls_ctr_drbg_free(&ctr_drbg);
	mbedtls_entropy_free(&entropy);
	mbedtls_rsa_free(rsa_context);
	mbedtls_x509_crt_free(&crt);

	return ret;
}
