#include <rpmb.h>
#include <string.h>

#include "em_ta.h"

#ifdef USE_MTK_RPMB
#include "tlRpmbDriverApi.h"
#endif /* USE_MTK_RPMB */

/*
 * Read from RPMB into `data`, starting at block `block_index`.
 *
 * If the first read attempt fails it is retried exactly once before the
 * function gives up.
 *
 * block_index: zero-based block number; negative values are rejected.
 * data:        destination buffer, NULL-checked through EM_CHECK_NULL.
 * data_len:    on the MTK path it is multiplied by
 *              EM_RPMB_SINGLE_BLOCK_SIZE to form the byte count; on the
 *              other path it is passed unchanged to TEES_RPMBRead together
 *              with RPMB_TYPE_BLOCK.
 *              NOTE(review): presumably a block count, so `data` must hold
 *              data_len * EM_RPMB_SINGLE_BLOCK_SIZE bytes - this is not
 *              verified here; confirm against callers.
 *
 * Returns EM_SUCCESS on success, otherwise one of:
 *   EM_ERR_EM_TEEGRIS_RPMB_READ_BLOCK_INDEX - negative block_index
 *   EM_ERR_EM_TEEGRIS_RPMB_READ_OPEN_FAIL   - MTK session could not be opened
 *   EM_ERR_EM_TEEGRIS_RPMB_READ_SECTOR      - both read attempts failed
 *
 * NOTE(review): EM_CHECK_NULL is defined outside this file; presumably it
 * reports EM_ERR_EM_TEEGRIS_RPMB_READ when `data` is NULL - confirm.
 */
static int em_teegris_rpmb_read(int block_index, uint8_t *data, uint32_t data_len)
{
	/*
	 * Start from the generic read error so the function fails closed and
	 * `ret` never holds an indeterminate value (it is handed to the MTK
	 * API as an out-parameter and logged afterwards).
	 */
	int ret = EM_ERR_EM_TEEGRIS_RPMB_READ;
	TEE_Result lret;
#ifdef USE_MTK_RPMB
	uint32_t crSession;
#endif /* USE_MTK_RPMB */

	EM_CHECK_NULL(__func__, EM_ERR_EM_TEEGRIS_RPMB_READ, data);

	if (block_index < 0) {
		/* block_index is signed: print it as such so the bad value is readable. */
		LOGE("block index isn't normal(%d)\n", block_index);
		ret = EM_ERR_EM_TEEGRIS_RPMB_READ_BLOCK_INDEX;
		goto out;
	}

#ifdef USE_MTK_RPMB
	crSession = TEE_RpmbOpenSession(EM_RPMB_MTK_PARTITION_ID);
	if (crSession == 0xFFFFFFFF) {
		LOGE("No session permitted or wrong user ID(%d)\n", EM_RPMB_MTK_PARTITION_ID);
		ret = EM_ERR_EM_TEEGRIS_RPMB_READ_OPEN_FAIL;
		goto out;
	}

	/* The MTK API addresses by byte offset, hence the block-size scaling. */
	lret = TEE_RpmbReadDatabyOffset(crSession, block_index * EM_RPMB_SINGLE_BLOCK_SIZE, data, EM_RPMB_SINGLE_BLOCK_SIZE * data_len, &ret);
	if (lret != TEE_SUCCESS) {
		LOGE("Failed TEES_RPMBRead(index %d)(0x%08x/%d) - try again\n", block_index, lret, ret);
		lret = TEE_RpmbReadDatabyOffset(crSession, block_index * EM_RPMB_SINGLE_BLOCK_SIZE, data, EM_RPMB_SINGLE_BLOCK_SIZE * data_len, &ret);
		if (lret != TEE_SUCCESS) {
			LOGE("Failed TEES_RPMBRead(index %d)(0x%08x/%d) twice\n", block_index, lret, ret);
			/* Close the session before bailing out so it is not leaked. */
			TEE_RpmbCloseSession(crSession);
			ret = EM_ERR_EM_TEEGRIS_RPMB_READ_SECTOR;
			goto out;
		}
	}

	TEE_RpmbCloseSession(crSession);
#else
	lret = TEES_RPMBRead(EM_RPMB_EXYNOS_PARTITION_ID, block_index, data, data_len, RPMB_TYPE_BLOCK);
	if (lret != TEE_SUCCESS) {
		LOGE("Failed TEES_RPMBRead(index %d)(0x%08x) - try again\n", block_index, lret);
		lret = TEES_RPMBRead(EM_RPMB_EXYNOS_PARTITION_ID, block_index, data, data_len, RPMB_TYPE_BLOCK);
		if (lret != TEE_SUCCESS) {
			LOGE("Failed TEES_RPMBRead(index %d)(0x%08x) twice\n", block_index, lret);
			ret = EM_ERR_EM_TEEGRIS_RPMB_READ_SECTOR;
			goto out;
		}
	}
#endif /* USE_MTK_RPMB */

	/* The MTK out-parameter may have overwritten ret; report plain success. */
	ret = EM_SUCCESS;
out:

	return ret;
}

/*
 * Write `data` to RPMB twice: once at block `block_index` (the "front"
 * copy in the log messages) and once at block `block_index + 1` (the
 * "back" copy).
 *
 * Each of the two writes is retried exactly once on failure. The back copy
 * is only attempted after the front copy has been written successfully.
 *
 * block_index: zero-based block number of the front copy; negative values
 *              are rejected.
 * data:        source buffer, NULL-checked through EM_CHECK_NULL.
 * data_len:    on the MTK path it is multiplied by
 *              EM_RPMB_SINGLE_BLOCK_SIZE to form the byte count; on the
 *              other path it is passed unchanged to TEES_RPMBWrite together
 *              with RPMB_TYPE_BLOCK.
 *              NOTE(review): the back copy is placed exactly one block after
 *              the front copy, so a data_len above 1 would make the two
 *              copies overlap - confirm callers only pass 1.
 *
 * Returns EM_SUCCESS on success, otherwise one of:
 *   EM_ERR_EM_TEEGRIS_RPMB_WRITE_BLOCK_INDEX   - negative block_index
 *   EM_ERR_EM_TEEGRIS_RPMB_WRITE_OPEN_FAIL     - MTK session could not be opened
 *   EM_ERR_EM_TEEGRIS_RPMB_WRITE_SECTOR        - front copy failed twice
 *   EM_ERR_EM_TEEGRIS_RPMB_WRITE_SECTOR_BACKUP - back copy failed twice
 *
 * NOTE(review): EM_CHECK_NULL is defined outside this file; presumably it
 * reports EM_ERR_EM_TEEGRIS_RPMB_WRITE when `data` is NULL - confirm.
 */
static int em_teegris_rpmb_write(int block_index, uint8_t *data, uint32_t data_len)
{
	int ret = EM_SUCCESS;
	TEE_Result lret;
#ifdef USE_MTK_RPMB
	uint32_t crSession;
#endif /* USE_MTK_RPMB */

	EM_CHECK_NULL(__func__, EM_ERR_EM_TEEGRIS_RPMB_WRITE, data);

	if (block_index < 0) {
		/* block_index is signed: print it as such so the bad value is readable. */
		LOGE("block index isn't normal(%d)\n", block_index);
		ret = EM_ERR_EM_TEEGRIS_RPMB_WRITE_BLOCK_INDEX;
		goto out;
	}

#ifdef USE_MTK_RPMB
	crSession = TEE_RpmbOpenSession(EM_RPMB_MTK_PARTITION_ID);
	if (crSession == 0xFFFFFFFF) {
		LOGE("No session permitted or wrong user ID(%d)\n", EM_RPMB_MTK_PARTITION_ID);
		ret = EM_ERR_EM_TEEGRIS_RPMB_WRITE_OPEN_FAIL;
		goto out;
	}

	/* Front copy. The MTK API addresses by byte offset, hence the scaling. */
	lret = TEE_RpmbWriteDatabyOffset(crSession, block_index * EM_RPMB_SINGLE_BLOCK_SIZE, data, EM_RPMB_SINGLE_BLOCK_SIZE * data_len, &ret);
	if (lret != TEE_SUCCESS) {
		LOGE("Failed to write front cvault (index %d)(0x%08x/%d) - try again\n", block_index, lret, ret);
		lret = TEE_RpmbWriteDatabyOffset(crSession, block_index * EM_RPMB_SINGLE_BLOCK_SIZE, data, EM_RPMB_SINGLE_BLOCK_SIZE * data_len, &ret);
		if (lret != TEE_SUCCESS) {
			LOGE("Failed to write front cvault (index %d)(0x%08x/%d) twice\n", block_index, lret, ret);
			/* Close the session before bailing out so it is not leaked. */
			TEE_RpmbCloseSession(crSession);
			ret = EM_ERR_EM_TEEGRIS_RPMB_WRITE_SECTOR;
			goto out;
		}
	}

	/* Back copy, one block further on. */
	lret = TEE_RpmbWriteDatabyOffset(crSession, (block_index + 1) * EM_RPMB_SINGLE_BLOCK_SIZE, data, EM_RPMB_SINGLE_BLOCK_SIZE * data_len, &ret);
	if (lret != TEE_SUCCESS) {
		LOGE("Failed to write back cvault (index %d)(0x%08x/%d) - try again\n", block_index + 1, lret, ret);
		lret = TEE_RpmbWriteDatabyOffset(crSession, (block_index + 1) * EM_RPMB_SINGLE_BLOCK_SIZE, data, EM_RPMB_SINGLE_BLOCK_SIZE * data_len, &ret);
		if (lret != TEE_SUCCESS) {
			LOGE("Failed to write back cvault (index %d)(0x%08x/%d) twice\n", block_index + 1, lret, ret);
			TEE_RpmbCloseSession(crSession);
			ret = EM_ERR_EM_TEEGRIS_RPMB_WRITE_SECTOR_BACKUP;
			goto out;
		}
	}

	TEE_RpmbCloseSession(crSession);
#else
	/* Front copy. */
	lret = TEES_RPMBWrite(EM_RPMB_EXYNOS_PARTITION_ID, block_index, data, data_len, RPMB_TYPE_BLOCK);
	if (lret != TEE_SUCCESS) {
		LOGE("Failed to write front cvault (index %d)(0x%08x) - try again\n", block_index, lret);
		lret = TEES_RPMBWrite(EM_RPMB_EXYNOS_PARTITION_ID, block_index, data, data_len, RPMB_TYPE_BLOCK);
		if (lret != TEE_SUCCESS) {
			LOGE("Failed to write front cvault (index %d)(0x%08x) twice\n", block_index,
			     lret);
			ret = EM_ERR_EM_TEEGRIS_RPMB_WRITE_SECTOR;
			goto out;
		}
	}

	/* Back copy, one block further on. */
	lret = TEES_RPMBWrite(EM_RPMB_EXYNOS_PARTITION_ID, block_index + 1, data, data_len, RPMB_TYPE_BLOCK);
	if (lret != TEE_SUCCESS) {
		LOGE("Failed to write back cvault (index %d)(0x%08x) - try again\n", block_index + 1,
		     lret);
		lret = TEES_RPMBWrite(EM_RPMB_EXYNOS_PARTITION_ID, block_index + 1, data, data_len, RPMB_TYPE_BLOCK);
		if (lret != TEE_SUCCESS) {
			LOGE("Failed to write back cvault (index %d)(0x%08x) twice\n", block_index + 1,
			     lret);
			ret = EM_ERR_EM_TEEGRIS_RPMB_WRITE_SECTOR_BACKUP;
			goto out;
		}
	}
#endif /* USE_MTK_RPMB */

	/* The MTK out-parameter may have overwritten ret; report plain success. */
	ret = EM_SUCCESS;
out:
	return ret;
}

// TODO CHANGE
/*
 * Read the core block from RPMB and hand its decrypted form back in `buf`.
 *
 * Block 0 is read first; if that read fails, block 1 is read instead. The
 * block is treated as three consecutive regions:
 *   - the magic, strlen(EM_MAGIC_EM_CORE) bytes
 *   - the GCM tag, EM_LEN_GCM_TAG_CORE_V20 bytes
 *   - the ciphertext, which fills the rest of the block
 * Magic and tag are copied to `buf` as stored; the ciphertext is decrypted
 * with em_crypto_aes_256_gcm_decrypt into `buf` right after them, using the
 * key and IV produced by em_crypto_kdf.
 *
 * A block that is all zero as read, or whose decrypted region is all zero,
 * is rejected with EM_ERR_EM_READ_CORE_ALL_ZERO.
 *
 * buf:     output buffer, NULL-checked through EM_CHECK_NULL.
 * buf_len: size of `buf`; must be at least EM_RPMB_SINGLE_BLOCK_SIZE.
 *
 * Returns EM_SUCCESS on success; otherwise an error code set here or the
 * code returned by the failing helper. The contents of `buf` are not
 * meaningful when an error is returned.
 *
 * NOTE(review): EM_CHECK_NULL is defined outside this file; presumably it
 * reports EM_ERR_EM_TEEGRIS_READ_CORE when `buf` is NULL - confirm.
 */
int em_read_core(uint8_t *buf, uint32_t buf_len)
{
	uint32_t ret;
	uint8_t core[EM_RPMB_SINGLE_BLOCK_SIZE] = {0,};
	uint8_t key[EM_LEN_KEY_CORE_V20] = {0,};
	uint8_t iv[EM_LEN_IV] = {0,};
	uint32_t plaintext_len = 0;
	uint32_t encrypt_len = 0;
	uint32_t unencrypt_len = 0;
	uint32_t magic_len = 0;

	EM_CHECK_NULL(__func__, EM_ERR_EM_TEEGRIS_READ_CORE, buf);

	if (buf_len < EM_RPMB_SINGLE_BLOCK_SIZE) {
		LOGE("buffer length isn't normal(%u)\n", buf_len);
		ret = EM_ERR_EM_TEEGRIS_READ_CORE_LEN_BUF;
		goto out;
	}

	/* Primary copy lives in block 0, the backup copy in block 1. */
	ret = em_teegris_rpmb_read(0, core, 1);
	if (ret != EM_SUCCESS) {
		LOGI("Failed to read core, try read backup(0x%08x)\n", ret);
		ret = em_teegris_rpmb_read(1, core, 1);
		if (ret != EM_SUCCESS) {
			LOGE("Failed to read backup core(0x%08x)\n", ret);
			goto out;
		}
	}

	if (em_is_all_zero((uint8_t *)core, EM_RPMB_SINGLE_BLOCK_SIZE) == EM_SUCCESS) {
		LOGE("%s: The block is all zero\n", __func__);
		ret = EM_ERR_EM_READ_CORE_ALL_ZERO;
		goto out;
	}

	ret = em_crypto_kdf(key, EM_LEN_KEY_CORE_V20, iv, EM_LEN_IV);
	if (ret != EM_SUCCESS) {
		LOGI("Failed to derive key and iv(0x%08x)\n", ret);
		goto out;
	}

	/*
	 * Copy exactly one block: that is the size `core` was declared with
	 * and the minimum size `buf` was validated against above, so the copy
	 * cannot run past either buffer.
	 */
	memcpy(buf, core, EM_RPMB_SINGLE_BLOCK_SIZE);
	magic_len = strlen(EM_MAGIC_EM_CORE);
	unencrypt_len = magic_len + EM_LEN_GCM_TAG_CORE_V20;
	encrypt_len = EM_RPMB_SINGLE_BLOCK_SIZE - unencrypt_len;

	/* The tag sits directly after the magic; ciphertext follows the tag. */
	ret = em_crypto_aes_256_gcm_decrypt(core + unencrypt_len, encrypt_len, buf + unencrypt_len, &plaintext_len, key,
					    EM_LEN_KEY_CORE_V20, iv, EM_LEN_IV, core + magic_len,
					    EM_LEN_GCM_TAG_CORE_V20);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to decrypt core(0x%08x)\n", ret);
		goto out;
	}

	if (plaintext_len != encrypt_len) {
		LOGE("Core is decrypted, but the output length is unexpected(%u/%u)", plaintext_len,
		     encrypt_len);
		ret = EM_ERR_EM_TEEGRIS_READ_CORE_DECRYPT;
		goto out;
	}

	if (em_is_all_zero((uint8_t *)buf + unencrypt_len, EM_RPMB_SINGLE_BLOCK_SIZE - unencrypt_len) == EM_SUCCESS) {
		LOGE("%s: The block is all zero\n", __func__);
		ret = EM_ERR_EM_READ_CORE_ALL_ZERO;
		goto out;
	}
out:
	/*
	 * Scrub every local that held block data or key material. Sizes come
	 * from the arrays themselves so the whole of each buffer is cleared.
	 * NOTE(review): a plain memset on a dying local may be elided by the
	 * optimizer; consider a wipe primitive that cannot be removed.
	 */
	memset(core, 0, sizeof(core));
	memset(key, 0, sizeof(key));
	memset(iv, 0, sizeof(iv));

	return ret;
}

/*
 * Encrypt the core held in `buf` and store it in RPMB.
 *
 * The first EM_RPMB_SINGLE_BLOCK_SIZE bytes of `buf` are taken as the block
 * image. Its leading strlen(EM_MAGIC_EM_CORE) bytes (the magic) are kept as
 * they are; the following EM_LEN_GCM_TAG_CORE_V20 bytes receive the tag
 * produced by em_crypto_aes_256_gcm_encrypt; everything after that is
 * replaced by the ciphertext. Key and IV come from em_crypto_kdf. The
 * finished block is passed to em_teegris_rpmb_write with block index 0.
 *
 * buf:     input buffer, NULL-checked through EM_CHECK_NULL; not modified
 *          by this function.
 * buf_len: size of `buf`; must be at least EM_RPMB_SINGLE_BLOCK_SIZE.
 *
 * Returns EM_SUCCESS on success; otherwise an error code set here or the
 * code returned by the failing helper.
 */
int em_write_core(uint8_t *buf, uint32_t buf_len)
{
	int ret;
	uint8_t block[EM_RPMB_SINGLE_BLOCK_SIZE] = {0};
	uint8_t key[EM_LEN_KEY_CORE_V20] = {0};
	uint8_t iv[EM_LEN_IV] = {0};
	uint32_t magic_len = 0;
	uint32_t header_len = 0;
	uint32_t payload_len = 0;
	uint32_t produced_len = 0;

	EM_CHECK_NULL(__func__, EM_ERR_EM_TEEGRIS_WRITE_CORE, buf);

	/* The caller must supply at least one full block of input. */
	if (buf_len < EM_RPMB_SINGLE_BLOCK_SIZE) {
		LOGE("buf_len is not enough(%u)\n", buf_len);
		ret = EM_ERR_EM_TEEGRIS_WRITE_CORE_LEN_BUF;
		goto out;
	}

	ret = em_crypto_kdf(key, EM_LEN_KEY_CORE_V20, iv, EM_LEN_IV);
	if (ret != EM_SUCCESS) {
		LOGI("Failed to derive key and iv(0x%08x)\n", ret);
		goto out;
	}

	/* Start from a verbatim copy so the magic is carried over untouched. */
	memcpy(block, buf, EM_RPMB_SINGLE_BLOCK_SIZE);

	/* Header = magic followed by the tag; the payload is the remainder. */
	magic_len = strlen(EM_MAGIC_EM_CORE);
	header_len = magic_len + EM_LEN_GCM_TAG_CORE_V20;
	payload_len = EM_RPMB_SINGLE_BLOCK_SIZE - header_len;

	ret = em_crypto_aes_256_gcm_encrypt(buf + header_len, payload_len,
					    block + header_len, &produced_len,
					    key, EM_LEN_KEY_CORE_V20,
					    iv, EM_LEN_IV,
					    block + magic_len, EM_LEN_GCM_TAG_CORE_V20);
	if (ret != EM_SUCCESS) {
		LOGE("Failed to encrypt core(0x%08x)\n", ret);
		goto out;
	}

	/* The ciphertext is expected to be exactly as long as the plaintext. */
	if (produced_len != payload_len) {
		LOGE("Core is encrypted, but the output length is unexpected(%u/%u)", produced_len,
		     payload_len);
		ret = EM_ERR_EM_TEEGRIS_WRITE_CORE_ENCRYPT;
		goto out;
	}

	/* On success ret is already EM_SUCCESS and simply falls through. */
	ret = em_teegris_rpmb_write(0, block, 1);
	if (ret != EM_SUCCESS) {
		LOGI("Failed to write core(0x%08x)\n", ret);
	}

out:
	/* Scrub the working block and the key material on every exit path. */
	memset(block, 0, sizeof(block));
	memset(key, 0, sizeof(key));
	memset(iv, 0, sizeof(iv));

	return ret;
}

/*
 * Report whether RPMB is provisioned.
 *
 * MTK build: opens a session and reads EM_RPMB_SINGLE_BLOCK_SIZE bytes at
 * offset 0; a failed read is reported as "not provisioned", a session that
 * cannot be opened as an unknown error.
 * Other builds: asks TEES_RPMBCheckEnable() and maps its result.
 *
 * Returns:
 *   EM_SUCCESS                              - the check succeeded
 *   EM_ERR_EM_CHECK_PROVISION_NOT_PROVISION - read failed (MTK), or the
 *       check returned TEE_ERROR_NOT_IMPLEMENTED, TEE_ERROR_ITEM_NOT_FOUND
 *       or TEE_ERROR_NOT_SUPPORTED
 *   EM_ERR_EM_CHECK_PROVISION_UNKNOWN_ERROR - session open failed (MTK), or
 *       any other failure code from the check
 */
int em_check_provision(void)
{
	int ret;
	TEE_Result lret = 0;
#ifdef USE_MTK_RPMB
	uint32_t crSession;
	uint8_t data[EM_RPMB_SINGLE_BLOCK_SIZE] = {0};
#endif /* USE_MTK_RPMB */

	LOGW("Checking provision start...\n");

#ifdef USE_MTK_RPMB
	crSession = TEE_RpmbOpenSession(EM_RPMB_MTK_PARTITION_ID);
	if (crSession == 0xFFFFFFFF) {
		LOGE("No session permitted or wrong user ID(%d)\n", EM_RPMB_MTK_PARTITION_ID);
		ret = EM_ERR_EM_CHECK_PROVISION_UNKNOWN_ERROR;
	} else {
		/* Only the status of the read matters; the data itself is discarded. */
		lret = TEE_RpmbReadDatabyOffset(crSession, 0, data, EM_RPMB_SINGLE_BLOCK_SIZE, &ret);
		TEE_RpmbCloseSession(crSession);

		if (lret == TEE_SUCCESS) {
			ret = EM_SUCCESS;
		} else {
			LOGI("rpmb is not provisioned(0x%08x)\n", lret);
			ret = EM_ERR_EM_CHECK_PROVISION_NOT_PROVISION;
		}
	}
#else
	lret = TEES_RPMBCheckEnable();
	if (lret == TEE_SUCCESS) {
		ret = EM_SUCCESS;
	} else if (lret == TEE_ERROR_NOT_IMPLEMENTED || lret == TEE_ERROR_ITEM_NOT_FOUND) {
		/* A missing check API is reported the same way as "not provisioned". */
		LOGI("check provision api not implemented(0x%08x)\n", lret);
		ret = EM_ERR_EM_CHECK_PROVISION_NOT_PROVISION;
	} else if (lret == TEE_ERROR_NOT_SUPPORTED) {
		LOGI("rpmb is not provisioned(0x%08x)\n", lret);
		ret = EM_ERR_EM_CHECK_PROVISION_NOT_PROVISION;
	} else {
		LOGE("Failed to check provision(0x%08x)\n", lret);
		ret = EM_ERR_EM_CHECK_PROVISION_UNKNOWN_ERROR;
	}
#endif /* USE_MTK_RPMB */

	LOGW("Checking provision done...(0x%08x)\n", ret);
	return ret;
}
