/**
 * Copyright (C) 2011 Samsung Electronics Co., Ltd. All rights reserved.
 *
 * Mobile Communication Division,
 * Digital Media & Communications Business, Samsung Electronics Co., Ltd.
 *
 * This software and its documentation are confidential and proprietary
 * information of Samsung Electronics Co., Ltd.  No part of the software and
 * documents may be copied, reproduced, transmitted, translated, or reduced to
 * any electronic medium or machine-readable form without the prior written
 * consent of Samsung Electronics.
 *
 * Samsung Electronics makes no representations with respect to the contents,
 * and assumes no responsibility for any errors that might appear in the
 * software and documents. This publication and the contents hereof are subject
 * to change without notice.
 *
 */

/**
 * @file tz_hdcp2_crypto.c
 * @author
 * @date
 * @brief This file contains all the encryption-decryption function definitions used in HDCP authentication protocol.
 */

#include "tz_hdcp2.h"
#include "tlapi_secdrv.h"

#ifdef USE_MTK
#include "tee_internal_api.h"
#include "tlAsyncExampleDriverApi.h"
#endif /* USE_MTK */

#ifndef WITHOUT_KEYMANDRV
#include "tlapi_secdrv_keyman.h"
#endif /* WITHOUT_KEYMANDRV */

#ifdef USE_TOOLCHAIN_GNU
#define sprintf TEE_LogPrintf
#define vsprintf TEE_LogPrintf
#endif /* USE_TOOLCHAIN_GNU */

/**
 * @fn int TZ_RSA_PKCS1_OAEP_SHA256(unsigned char *to, int tlen,const unsigned char *from, int flen, unsigned char *param, int plen);
 * @brief Forward declaration of the EME-OAEP encoding step (SHA-256 hash,
 *        MGF1-SHA-256 mask) applied before raw RSA encryption. Defined later
 *        in this file.
 * @param to - output buffer for the OAEP-encoded (not yet encrypted) block; tlen bytes are written.
 * @param tlen - size of the encoded block (the RSA modulus size in bytes).
 * @param from - pointer to the message to encode.
 * @param flen - length of the message.
 * @param param - OAEP label; may be NULL (its hash is then the SHA-256 of the empty string).
 * @param plen - length of the label.
 * @return int returns tlen on success, HDCP2_ERR_CRYPTO on failure.
*/
int TZ_RSA_PKCS1_OAEP_SHA256(unsigned char *to, int tlen,
			const unsigned char *from, int flen, unsigned char *param, int plen);

/**
 * @fn int TZ_RSA_VERIFY_PKCS1_OAEP_SHA256(unsigned char *to, int tlen, const unsigned char *from, int flen, int num, const unsigned char *param, int plen)
 * @brief Forward declaration of the EME-OAEP decoding step (SHA-256 hash,
 *        MGF1-SHA-256 mask) applied to the output of raw RSA decryption.
 *        Called by TZ_RSA_OAEP_decrypt; defined later in this file.
 * @param to - output buffer receiving the recovered message.
 * @param tlen - capacity of the output buffer.
 * @param from - pointer to the decrypted but still OAEP-encoded block.
 * @param flen - length of the encoded block.
 * @param num - RSA modulus size in bytes (callers pass HDCP2_SIZE_RECEIVER_PUBKEY).
 * @param param - OAEP label; may be NULL.
 * @param plen - length of the label.
 * @return int returns HDCP2_OK in case of success, else a HDCP2_ERR_TRUSTZONE_BASE-relative error code.
 */
int TZ_RSA_VERIFY_PKCS1_OAEP_SHA256(unsigned char *to, int tlen,
			const unsigned char *from, int flen, int num,
			const unsigned char *param, int plen);

/**
 * Dump a byte buffer as hex to the trusted-application log.
 *
 * The body is compiled only when DEBUG is defined; otherwise this is a no-op.
 * On non-MTK builds a line break with an "HDCP:" prefix precedes every group
 * of 16 bytes.
 *
 * @param title  - label printed before the dump.
 * @param data   - bytes to print.
 * @param length - number of bytes to print.
 */
void TZ_LOG_HEX(const char *title, const unsigned char *data, int length)
{
#ifdef DEBUG
	const unsigned char *cursor = data;
	int offset = 0;

	LOGD("%s[%d] = ", title, length);

	while (offset < length) {
#ifndef USE_MTK
		/* start a new prefixed row every 16 bytes */
		if ((offset & 0xF) == 0)
			tlApiLogPrintf("\nHDCP:");
#endif /* !USE_MTK */
		tlApiLogPrintf("0x%02x ", *cursor);
		cursor++;
		offset++;
	}

	tlApiLogPrintf("\n");
#endif /* DEBUG */
}

/**
 * @fn int TZ_AES_ECB_encrypt(uint8_t *key, uint32_t key_len, uint8_t *pt, uint32_t pt_len, uint8_t *ct, uint32_t *ct_len)
 * @brief Encrypt data with AES-128 in ECB mode without padding, using the TEE crypto API.
 * @param key - pointer to the 16-byte AES key.
 * @param key_len - length of the key; must be 16.
 * @param pt - pointer to the plaintext.
 * @param pt_len - length of the plaintext; must be non-zero (no padding is
 *                 applied, so presumably a multiple of 16 - the crypto API decides).
 * @param ct - pointer to the ciphertext output buffer.
 * @param ct_len - length of the ct buffer on input; passed to
 *                 tlApiCipherDoFinal, which reports the bytes produced.
 * @return int TLAPI_OK on success, HDCP2_ERR_CRYPTO on invalid arguments,
 *         otherwise the error code returned by the tlApi crypto calls.
 */
int TZ_AES_ECB_encrypt(uint8_t *key, uint32_t key_len, uint8_t *pt,
			uint32_t pt_len, uint8_t *ct, uint32_t *ct_len)
{
	tlApiCrSession_t session_handle;
	tlApiResult_t ret;
	tlApiSymKey_t sym_key;
	tlApiKey_t aes_key;

	if (key == NULL || key_len != 16) {
		LOGE("Invalid key length\n");
		return HDCP2_ERR_CRYPTO;
	}

	if (pt == NULL || pt_len == 0) {
		LOGE("Invalid input data length\n");
		return HDCP2_ERR_CRYPTO;
	}

	/* ct_len is dereferenced by tlApiCipherDoFinal, so it must be valid too. */
	if (ct == NULL || ct_len == NULL) {
		LOGE("Invalid output data length\n");
		return HDCP2_ERR_CRYPTO;
	}

	memset(&sym_key, 0x00, sizeof(sym_key));
	memset(&aes_key, 0x00, sizeof(aes_key));
	sym_key.key = key;
	sym_key.len = key_len;
	aes_key.symKey = &sym_key;

	ret = tlApiCipherInit(&session_handle, TLAPI_ALG_AES_128_ECB_NOPAD,
			TLAPI_MODE_ENCRYPT, &aes_key);
	if (ret != TLAPI_OK) {
		LOGE("tlApiCipherInit error %d \n", ret);
		goto out;
	}

	ret = tlApiCipherDoFinal(session_handle, pt, pt_len, ct, ct_len);
	if (ret != TLAPI_OK)
		LOGE("tlApiCipherDoFinal error %d \n", ret);

out:
	/* Drop the key references on every exit path, not only on success. */
	memset(&sym_key, 0x00, sizeof(sym_key));
	memset(&aes_key, 0x00, sizeof(aes_key));

	return ret;
}

/**
 * @fn int TZ_AES_ECB_decrypt(uint8_t *key, uint32_t key_len, uint8_t *ct, uint32_t ct_len, uint8_t *dt, uint32_t *dt_len)
 * @brief Decrypt data with AES-128 in ECB mode without padding, using the TEE crypto API.
 * @param key - pointer to the 16-byte AES key.
 * @param key_len - length of the key; must be 16.
 * @param ct - pointer to the ciphertext.
 * @param ct_len - length of the ciphertext; must be non-zero.
 * @param dt - pointer to the plaintext output buffer.
 * @param dt_len - length of the dt buffer on input; passed to
 *                 tlApiCipherDoFinal, which reports the bytes produced.
 * @return int TLAPI_OK on success, HDCP2_ERR_CRYPTO on invalid arguments,
 *         otherwise the error code returned by the tlApi crypto calls.
 */
int TZ_AES_ECB_decrypt(uint8_t *key, uint32_t key_len, uint8_t *ct,
			uint32_t ct_len, uint8_t *dt, uint32_t *dt_len)
{
	tlApiCrSession_t session_handle;
	tlApiResult_t ret;
	tlApiSymKey_t sym_key;
	tlApiKey_t aes_key;

	if (key == NULL || key_len != 16) {
		LOGE("Invalid key length\n");
		return HDCP2_ERR_CRYPTO;
	}

	if (ct == NULL || ct_len == 0) {
		LOGE("Invalid input data length\n");
		return HDCP2_ERR_CRYPTO;
	}

	/* dt_len is dereferenced by tlApiCipherDoFinal, so it must be valid too. */
	if (dt == NULL || dt_len == NULL) {
		LOGE("Invalid output data length\n");
		return HDCP2_ERR_CRYPTO;
	}

	memset(&sym_key, 0x00, sizeof(sym_key));
	memset(&aes_key, 0x00, sizeof(aes_key));
	sym_key.key = key;
	sym_key.len = key_len;
	aes_key.symKey = &sym_key;

	ret = tlApiCipherInit(&session_handle, TLAPI_ALG_AES_128_ECB_NOPAD,
				TLAPI_MODE_DECRYPT, &aes_key);
	if (ret != TLAPI_OK) {
		LOGE("tlApiCipherInit error %d \n", ret);
		goto out;
	}

	ret = tlApiCipherDoFinal(session_handle, ct, ct_len, dt, dt_len);
	if (ret != TLAPI_OK)
		LOGE("tlApiCipherDoFinal error %d \n", ret);

out:
	/* Drop the key references on every exit path, not only on success. */
	memset(&sym_key, 0x00, sizeof(sym_key));
	memset(&aes_key, 0x00, sizeof(aes_key));

	return ret;
}

/**
 * @fn rsa_encrypt(HDCP2_KEY *hdcp2_key, uint8_t *pt, uint32_t pt_len, uint8_t *ct)
 * @brief Raw (no padding) RSA public-key encryption with the certificate key.
 * @param hdcp2_key - key container; cert.publickey_n / cert.publickey_e are used.
 * @param pt - pointer to the (already padded) data to encrypt.
 * @param pt_len - length of the data; must be non-zero.
 * @param ct - pointer to the output buffer; assumed to hold
 *             HDCP2_SIZE_RECEIVER_PUBKEY bytes (the capacity passed to the API).
 * @return int the ciphertext length on success; HDCP2_ERR_CRYPTO on invalid
 *         arguments; HDCP2_ERR_TRUSTZONE_BASE - <tlApi error> if the crypto API
 *         fails (mapped like rsa_decrypt so an error is never mistaken for a
 *         positive length).
 */
int rsa_encrypt(HDCP2_KEY *hdcp2_key, uint8_t *pt, uint32_t pt_len, uint8_t *ct)
{
	tlApiCrSession_t session_handle;
	tlApiResult_t ret;
	uint32_t ct_len = HDCP2_SIZE_RECEIVER_PUBKEY;
	tlApiRsaKey_t tl_rsa_key;
	tlApiKey_t tl_key;
	int result;

	if (!hdcp2_key || !pt || pt_len == 0 || !ct) {
		LOGE("rsa_encrypt: invalid input...\n");
		return HDCP2_ERR_CRYPTO;
	}

	memset(&tl_rsa_key, 0x00, sizeof(tl_rsa_key));
	memset(&tl_key, 0x00, sizeof(tl_key));

	tl_rsa_key.modulus.value = hdcp2_key->cert.publickey_n;
	tl_rsa_key.modulus.len = sizeof(hdcp2_key->cert.publickey_n);
	tl_rsa_key.exponent.value = hdcp2_key->cert.publickey_e;
	tl_rsa_key.exponent.len = sizeof(hdcp2_key->cert.publickey_e);

	tl_key.rsaKey = &tl_rsa_key;

	ret = tlApiCipherInit(&session_handle, TLAPI_ALG_RSA_NOPAD,
				TLAPI_MODE_ENCRYPT, &tl_key);
	if (ret != TLAPI_OK) {
		LOGE("tlApiCipherInit error %d \n", ret);
		result = HDCP2_ERR_TRUSTZONE_BASE - (int)ret;
		goto out;
	}

	ret = tlApiCipherDoFinal(session_handle, pt, pt_len, ct, &ct_len);
	if (ret != TLAPI_OK) {
		LOGE("tlApiCipherDoFinal error %d \n", ret);
		result = HDCP2_ERR_TRUSTZONE_BASE - (int)ret;
		goto out;
	}

	result = (int)ct_len;

out:
	/* Clear the key descriptors on every exit path. */
	memset(&tl_rsa_key, 0x00, sizeof(tl_rsa_key));
	memset(&tl_key, 0x00, sizeof(tl_key));

	return result;
}

/**
 * @fn rsa_decrypt(HDCP2_KEY *hdcp2_key, uint8_t *ct, uint32_t ct_len, uint8_t *pt, uint32_t *pt_len)
 * @brief Raw (no padding) RSA-CRT decryption with the receiver private key.
 * @param hdcp2_key - key container; the certificate modulus/exponent and the
 *                    CRT components (p, q, dmp1, dmq1, iqmp) are used.
 * @param ct - pointer to the ciphertext to decrypt.
 * @param ct_len - length of the ciphertext.
 * @param pt - pointer to the output buffer; assumed to hold
 *             HDCP2_SIZE_RECEIVER_PUBKEY bytes.
 * @param pt_len - set to HDCP2_SIZE_RECEIVER_PUBKEY (the buffer capacity)
 *                 before the call; tlApiCipherDoFinal reports the bytes written.
 * @return int HDCP2_OK on success; HDCP2_ERR_CRYPTO on invalid arguments;
 *         HDCP2_ERR_TRUSTZONE_BASE - <tlApi error> if the crypto API fails.
 */
int rsa_decrypt(HDCP2_KEY *hdcp2_key, uint8_t *ct, uint32_t ct_len,
			uint8_t *pt, uint32_t *pt_len)
{
	tlApiCrSession_t session_handle;
	tlApiResult_t ret;
	tlApiRsaKey_t tl_rsa_crt_key;
	tlApiKey_t tl_key;
	int result = HDCP2_OK;

	if (!hdcp2_key || !pt || !ct || !pt_len) {
		LOGE("rsa_decrypt: invalid input...\n");
		return HDCP2_ERR_CRYPTO;
	}

	*pt_len = HDCP2_SIZE_RECEIVER_PUBKEY;

	memset(&tl_rsa_crt_key, 0x00, sizeof(tl_rsa_crt_key));
	memset(&tl_key, 0x00, sizeof(tl_key));

	tl_rsa_crt_key.modulus.value = hdcp2_key->cert.publickey_n;
	tl_rsa_crt_key.modulus.len = sizeof(hdcp2_key->cert.publickey_n);
	tl_rsa_crt_key.exponent.value = hdcp2_key->cert.publickey_e;
	tl_rsa_crt_key.exponent.len = sizeof(hdcp2_key->cert.publickey_e);

	tl_rsa_crt_key.privateCrtKey.P.len = sizeof(hdcp2_key->private_key.p);
	tl_rsa_crt_key.privateCrtKey.Q.len = sizeof(hdcp2_key->private_key.q);
	tl_rsa_crt_key.privateCrtKey.DP.len = sizeof(hdcp2_key->private_key.dmp1);
	tl_rsa_crt_key.privateCrtKey.DQ.len = sizeof(hdcp2_key->private_key.dmq1);
	tl_rsa_crt_key.privateCrtKey.Qinv.len = sizeof(hdcp2_key->private_key.iqmp);
	tl_rsa_crt_key.privateCrtKey.P.value = hdcp2_key->private_key.p;
	tl_rsa_crt_key.privateCrtKey.Q.value = hdcp2_key->private_key.q;
	tl_rsa_crt_key.privateCrtKey.DP.value = hdcp2_key->private_key.dmp1;
	tl_rsa_crt_key.privateCrtKey.DQ.value = hdcp2_key->private_key.dmq1;
	tl_rsa_crt_key.privateCrtKey.Qinv.value = hdcp2_key->private_key.iqmp;

	tl_key.rsaKey = &tl_rsa_crt_key;

	ret = tlApiCipherInit(&session_handle, TLAPI_ALG_RSACRT_NOPAD,
				TLAPI_MODE_DECRYPT, &tl_key);
	if (ret != TLAPI_OK) {
		LOGE("tlApiCipherInit error %d \n", ret);
		result = HDCP2_ERR_TRUSTZONE_BASE - (int)ret;
		goto out;
	}

	ret = tlApiCipherDoFinal(session_handle, ct, ct_len, pt, pt_len);
	if (ret != TLAPI_OK) {
		LOGE("tlApiCipherDoFinal error %d \n", ret);
		result = HDCP2_ERR_TRUSTZONE_BASE - (int)ret;
	}

out:
	/* The descriptors point at private-key material; clear them on every path. */
	memset(&tl_rsa_crt_key, 0x00, sizeof(tl_rsa_crt_key));
	memset(&tl_key, 0x00, sizeof(tl_key));

	return result;
}

/**
 * @fn int TZ_Get_ContentKey(uint8_t *ks, uint8_t *lc128, uint8_t* pKey)
 * @brief Derive the 16-byte key for the AES module as ks XOR lc128.
 * @param ks - session key; its first 16 bytes are used.
 * @param lc128 - global constant lc128; its first 16 bytes are used.
 * @param pKey - receives the 16-byte result.
 * @return int HDCP2_OK on success, HDCP2_ERR_INVALID_INPUT if any pointer is NULL.
 */
int TZ_Get_ContentKey(uint8_t *ks, uint8_t *lc128, uint8_t* pKey)
{
	int idx;

	if (!ks || !lc128 || !pKey) {
		LOGE("TZ_Get_ContentKey: HDCP2_ERR_NULL_POINTER\n");
		return HDCP2_ERR_INVALID_INPUT;
	}

	for (idx = 0; idx < 16; ++idx)
		pKey[idx] = (uint8_t)(lc128[idx] ^ ks[idx]);

	return HDCP2_OK;
}


/**
 * @fn int TZ_rand(uint8_t *data, const int length)
 * @brief Fill a buffer with bytes from the TEE random generator
 *        (TLAPI_ALG_PSEUDO_RANDOM).
 * @param data - output buffer of at least length bytes.
 * @param length - number of random bytes requested; must be positive.
 * @return int length on success; HDCP2_ERR_CRYPTO on invalid arguments, an
 *         API failure, or a short fill.
 */
int TZ_rand(uint8_t *data, const int length)
{
	tlApiResult_t tlRet;
	/*
	 * The API takes a size_t * for the length. Use a real size_t instead of
	 * casting &length: writing through that cast modifies a const object (UB)
	 * and overruns the int wherever size_t is wider than int.
	 */
	size_t gen_len;

	if (data == NULL || length <= 0)
		return HDCP2_ERR_CRYPTO;

	gen_len = (size_t)length;
	tlRet = tlApiRandomGenerateData(TLAPI_ALG_PSEUDO_RANDOM, data, &gen_len);

	if (tlRet != TLAPI_OK) {
#ifdef TZ_DEBUG
		LOGE("tlApiRandomGenerateData error %d \n", tlRet);
#endif /* TZ_DEBUG */
		return HDCP2_ERR_CRYPTO;
	}

	/* A short fill would leave part of the buffer (e.g. an OAEP seed) non-random. */
	if (gen_len != (size_t)length)
		return HDCP2_ERR_CRYPTO;

	return length;
}

/**
 * @fn void TZ_IncreaseCtr(TZ_HDCP2_CTX *hdcp_ctx)
 * @brief Advance the 8-byte counter hdcp_ctx->ctr by one.
 *
 * ctr[7] is the least significant byte. A byte equal to MAX_INCREASE_COUNTER
 * (presumably 0xFF - confirm in tz_hdcp2.h) wraps to 0 and the increment
 * carries into the next lower index.
 *
 * @param hdcp_ctx - pointer to the HDCP context.
 * @return void
 */
void TZ_IncreaseCtr(TZ_HDCP2_CTX *hdcp_ctx)
{
	int pos = 8;

	while (pos-- > 0) {
		if (hdcp_ctx->ctr[pos] == MAX_INCREASE_COUNTER)
			hdcp_ctx->ctr[pos] = 0;
		else
			hdcp_ctx->ctr[pos] = hdcp_ctx->ctr[pos] + 1;

		/* no carry out of this byte: done */
		if (hdcp_ctx->ctr[pos] != 0)
			return;
	}
}

/**
 * @fn int TZ_Derivate_dkey(TZ_HDCP2_CTX *hdcp_ctx)
 * @brief Derive the next dkey: dkey = AES-ECB(km XOR rn, input).
 *
 * When the context version and both peers' VERSION fields indicate
 * HDCP 2.2+, input = rtx || (rrx XOR ctr); otherwise input = rtx || ctr.
 * rn is XORed into the low 8 bytes of km. On success the result is written
 * to hdcp_ctx->dkey and the counter is advanced.
 *
 * @param hdcp_ctx - pointer to the HDCP context.
 * @return int HDCP2_OK on success, (HDCP2_ERR_TRUSTZONE_BASE - 61) if the AES
 *         encryption fails or produces an unexpected length.
 */
int TZ_Derivate_dkey(TZ_HDCP2_CTX *hdcp_ctx)
{
	int i = 0;
	int ret = HDCP2_OK;
	uint8_t key[16];
	uint8_t input[16] = {0, };
	/* Must match the uint32_t * parameter of TZ_AES_ECB_encrypt. */
	uint32_t ct_len = 16;

	if (hdcp_ctx->version >= HDCP2_VERSION_2_2 && hdcp_ctx->transmitter_info.VERSION >= 0x02 && hdcp_ctx->receiver_info.VERSION >= 0x02) {
		/* input = rtx || (rrx XOR ctr) */
		u8 temp[8] = {0};
		for (i = 7; i >= 0; i--)
			temp[i] = hdcp_ctx->rrx[i] ^ hdcp_ctx->ctr[i];

		TZ_LOG_HEX("hdcp_ctx->rrx", hdcp_ctx->rrx, sizeof(hdcp_ctx->rrx));
		TZ_LOG_HEX("hdcp_ctx->ctr", hdcp_ctx->ctr, sizeof(hdcp_ctx->ctr));
		TZ_LOG_HEX("temp", temp, sizeof(temp));
		memcpy(input, hdcp_ctx->rtx, sizeof(hdcp_ctx->rtx));
		memcpy(input + sizeof(hdcp_ctx->rtx), temp, sizeof(temp));
	} else {
		/* input = rtx || ctr */
		memcpy(input, hdcp_ctx->rtx, sizeof(hdcp_ctx->rtx));
		memcpy(input + 8, hdcp_ctx->ctr, sizeof(hdcp_ctx->ctr));
	}

	/* key = km XOR rn (rn covers the low 8 bytes) */
	memcpy(key, hdcp_ctx->pairing_info.km, sizeof(key));
	for (i = 8; i < 16; i++)
		key[i] ^= hdcp_ctx->rn[i - 8];

	/*
	 * Check the call result: on an early failure ct_len keeps its initial
	 * value of 16, so checking ct_len alone would accept a garbage dkey.
	 */
	if (TZ_AES_ECB_encrypt(key, 16, input, 16, hdcp_ctx->dkey, &ct_len) != TLAPI_OK ||
	    ct_len != 16)
		ret = HDCP2_ERR_TRUSTZONE_BASE - 61;

	/* key holds derived secret material; do not leave it on the stack. */
	memset(key, 0x00, sizeof(key));
	memset(input, 0x00, sizeof(input));

	if (ret != HDCP2_OK)
		return ret;

	TZ_IncreaseCtr(hdcp_ctx);

	return HDCP2_OK;
}

/**
 * @fn int TZ_SHA256(uint8_t *md, uint8_t *in, const int inlen)
 * @brief Compute the SHA-256 digest of a buffer.
 * @param md - output buffer receiving the 32-byte digest.
 * @param in - pointer to the data to hash; NULL is treated as empty input.
 * @param inlen - length of the data; must not be negative.
 * @return int 32 on success, HDCP2_ERR_CRYPTO on failure.
 */
int TZ_SHA256(uint8_t *md, uint8_t *in, const int inlen)
{
	/* SHA-256 of the empty message. */
	static const unsigned char emptyDigest[32] = {
		0xE3, 0xB0, 0xC4, 0x42, 0x98, 0xFC, 0x1C, 0x14,
		0x9A, 0xFB, 0xF4, 0xC8, 0x99, 0x6F, 0xB9, 0x24,
		0x27, 0xAE, 0x41, 0xE4, 0x64, 0x9B, 0x93, 0x4C,
		0xA4, 0x95, 0x99, 0x1B, 0x78, 0x52, 0xB8, 0x55
	};
	tlApiCrSession_t crSessionHandle = 0;
	tlApiResult_t tlRet;
	size_t outlen = sizeof(emptyDigest);

	/* A negative length would become a huge size_t at the API boundary. */
	if (md == NULL || inlen < 0)
		return HDCP2_ERR_CRYPTO;

	/* No data (e.g. an absent OAEP label): the digest is a known constant. */
	if (in == NULL || inlen == 0) {
		memcpy(md, emptyDigest, sizeof(emptyDigest));
		return 32;
	}

	tlRet = tlApiMessageDigestInit(&crSessionHandle, TLAPI_ALG_SHA256);
	if (TLAPI_OK != tlRet) {
		LOGE("TZ_SHA256 : error, tlApiMessageDigestInit ret=0x%08X, exit\n", tlRet);
		return HDCP2_ERR_CRYPTO;
	}

	tlRet = tlApiMessageDigestDoFinal(crSessionHandle, in, inlen, md, &outlen);
	if (TLAPI_OK != tlRet) {
		LOGE("TZ_SHA256 : tlApiMessageDigestDoFinal ret=0x%08X, exit\n", tlRet);
		return HDCP2_ERR_CRYPTO;
	}

	if (outlen != sizeof(emptyDigest))
		return HDCP2_ERR_CRYPTO;

	return 32;
}

/**
 * @fn int TZ_HMAC_SHA256(uint8_t *md, uint8_t *key, const int keylen, uint8_t *in, const int inlen)
 * @brief This is hashing function used to calculate the hash value of the data given, *in, using hashing key *md.
 * @param md - output pointer
 * @param key - pointer to the key used in hash function.
 * @param keylen - length of the key
 * @param in - data to be hashed
 * @param inlen - length of the data
 * @return int returns HDCP2_ERR_CRYPTO
 */
int TZ_HMAC_SHA256(uint8_t *md, uint8_t *key, const int keylen, uint8_t *in,
			const int inlen)
{
	int ret = HDCP2_ERR_CRYPTO;
	size_t outlen = 32;

	tlApiCrSession_t crSessionHandle;
	tlApiResult_t tlRet;

	tlApiSymKey_t symKey;
	tlApiKey_t tlKey;

	symKey.key = key;
	symKey.len = keylen;
	tlKey.symKey = &symKey;

	tlRet = tlApiSignatureInit(&crSessionHandle, &tlKey, TLAPI_MODE_SIGN,
					TLAPI_ALG_HMAC_SHA_256);
	if (tlRet != TLAPI_OK) {
		LOGE("Error : tlApiSignatureInit ret = 0x%x \n", tlRet);
		return ret;
	}

	tlRet = tlApiSignatureSign(crSessionHandle, in, inlen, md, &outlen);
	if (tlRet != TLAPI_OK) {
		LOGE("Error : tlApiSignaureSign ret = 0x%x \n", tlRet);
		return ret;
	}

	memset(&tlKey, 0x00, sizeof(tlKey));
	return ret;
}

/**
 * @fn int TZ_RSA_OAEP_encrypt(TZ_HDCP2_CTX *hdcp_ctx, HDCP2_KEY *hdcp2_key, uint8_t *in, uint8_t *out)
 * @brief RSAES-OAEP (SHA-256) encryption of a 16-byte secret with the certificate public key.
 * @param hdcp_ctx - pointer to the HDCP context (unused; kept for interface compatibility).
 * @param hdcp2_key - key container holding the public key used for encryption.
 * @param in - pointer to the 16-byte plaintext.
 * @param out - output buffer of HDCP2_SIZE_RECEIVER_PUBKEY bytes; zeroed first.
 * @return int the ciphertext length on success, a negative error code otherwise.
 */
int TZ_RSA_OAEP_encrypt(TZ_HDCP2_CTX *hdcp_ctx, HDCP2_KEY *hdcp2_key,
			uint8_t *in, uint8_t *out)
{
	int length = 0;
	uint8_t padded[HDCP2_SIZE_RECEIVER_PUBKEY] = {0, };

	(void)hdcp_ctx;

	if (hdcp2_key == NULL || in == NULL || out == NULL)
		return HDCP2_ERR_CRYPTO;

	memset(out, 0, HDCP2_SIZE_RECEIVER_PUBKEY);

	length = TZ_RSA_PKCS1_OAEP_SHA256(padded, HDCP2_SIZE_RECEIVER_PUBKEY, in, 16, NULL, 0);
	if (length >= 0)
		length = rsa_encrypt(hdcp2_key, padded, HDCP2_SIZE_RECEIVER_PUBKEY, out);

	/* padded is a reversible encoding of the secret input; wipe it. */
	memset(padded, 0x00, sizeof(padded));

	return length;
}


/**
 * @fn int TZ_RSA_OAEP_decrypt(TZ_HDCP2_CTX *hdcp_ctx, HDCP2_KEY *hdcp2_key, uint8_t *in, uint8_t *out)
 * @brief RSAES-OAEP (SHA-256) decryption of a HDCP2_SIZE_RECEIVER_PUBKEY-byte ciphertext.
 * @param hdcp_ctx - pointer to the HDCP context (unused; kept for interface compatibility).
 * @param hdcp2_key - key container holding the private key used for decryption.
 * @param in - pointer to the ciphertext (HDCP2_SIZE_RECEIVER_PUBKEY bytes).
 * @param out - receives the recovered message, at most 16 bytes. The matching
 *              encrypt path (TZ_RSA_OAEP_encrypt) always encodes exactly 16 bytes.
 * @return int HDCP2_OK on success; (HDCP2_ERR_TRUSTZONE_BASE - 51) on invalid
 *         arguments or RSA failure; otherwise the OAEP decoding error.
 */
int TZ_RSA_OAEP_decrypt(TZ_HDCP2_CTX *hdcp_ctx, HDCP2_KEY *hdcp2_key,
			uint8_t *in, uint8_t *out)
{
	/*
	 * Bound the decoded message by the real size of the output (16 bytes),
	 * not by the modulus size. Otherwise a peer-supplied, validly padded
	 * longer message would be copied past the end of out.
	 */
	const int out_capacity = 16;
	int ret = HDCP2_OK;
	uint32_t length = 0;
	uint8_t padded[HDCP2_SIZE_RECEIVER_PUBKEY] = {0, };

	(void)hdcp_ctx;

	if (hdcp2_key == NULL || in == NULL || out == NULL)
		return HDCP2_ERR_TRUSTZONE_BASE - 51;

	if (rsa_decrypt(hdcp2_key, in, HDCP2_SIZE_RECEIVER_PUBKEY, padded, &length) != HDCP2_OK)
		ret = HDCP2_ERR_TRUSTZONE_BASE - 51;
	else
		ret = TZ_RSA_VERIFY_PKCS1_OAEP_SHA256(out, out_capacity,
						padded, HDCP2_SIZE_RECEIVER_PUBKEY, HDCP2_SIZE_RECEIVER_PUBKEY,
						NULL, 0);

	/* padded holds the decrypted (encoded) secret; wipe it. */
	memset(padded, 0x00, sizeof(padded));

	return ret;
}

/**
 * @fn int TZ_PKCS1_MGF1(unsigned char *mask, long len, const unsigned char *seed, long seedlen)
 * @brief MGF1 mask generation (PKCS #1) built on TZ_SHA256.
 *
 * mask = SHA256(seed || C(0)) || SHA256(seed || C(1)) || ..., truncated to
 * len bytes. C(i) is the 4-byte big-endian counter.
 *
 * @param mask - output buffer of len bytes.
 * @param len - number of mask bytes to produce; must not be negative.
 * @param seed - seed input.
 * @param seedlen - length of the seed; at most sizeof(buffer) - 4 (124) bytes.
 * @return int len on success, HDCP2_ERR_CRYPTO on failure.
 */
int TZ_PKCS1_MGF1(unsigned char *mask, long len, const unsigned char *seed,
			long seedlen)
{
	long i, outlen = 0;
	const long mdlen = 32;
	unsigned char md[32];
	unsigned char buffer[128] = {0, };
	int ret = HDCP2_ERR_CRYPTO;

	/* buffer must hold seed || 4-byte counter; reject anything that would overflow it. */
	if (len < 0 || seedlen < 0 || seedlen > (long)sizeof(buffer) - 4 ||
	    (mask == NULL && len > 0) || (seed == NULL && seedlen > 0))
		return HDCP2_ERR_CRYPTO;

	/* The seed part of buffer never changes; only the counter suffix does. */
	if (seedlen > 0)
		memcpy(buffer, seed, seedlen);

	for (i = 0; outlen < len; i++) {
		buffer[seedlen] = (unsigned char)((i >> 24) & 0xFF);
		buffer[seedlen + 1] = (unsigned char)((i >> 16) & 0xFF);
		buffer[seedlen + 2] = (unsigned char)((i >> 8) & 0xFF);
		buffer[seedlen + 3] = (unsigned char)(i & 0xFF);

		if (outlen + mdlen <= len) {
			if (TZ_SHA256(mask + outlen, buffer, (int)(seedlen + 4)) < 0)
				goto out;
			outlen += mdlen;
		} else {
			/* Final partial block: hash to scratch, copy what is needed. */
			if (TZ_SHA256(md, buffer, (int)(seedlen + 4)) < 0)
				goto out;
			memcpy(mask + outlen, md, len - outlen);
			outlen = len;
		}
	}

	ret = (int)len;

out:
	/* buffer/md carry secret-derived data (OAEP seed or DB). */
	memset(md, 0x00, sizeof(md));
	memset(buffer, 0x00, sizeof(buffer));
	return ret;
}

/**
 * @fn int TZ_RSA_PKCS1_OAEP_SHA256(unsigned char *to, int tlen,const unsigned char *from, int flen, unsigned char *param, int plen);
 * @brief EME-OAEP encoding (SHA-256, MGF1-SHA-256) as used before raw RSA encryption.
 *
 * Produces EM = 0x00 || maskedSeed || maskedDB, where
 * DB = lHash || PS (zeros) || 0x01 || M.
 *
 * @param to - output buffer for the encoded block (tlen bytes). It is zeroed on failure.
 * @param tlen - size of the encoded block (the RSA modulus size in bytes).
 * @param from - pointer to the message to encode.
 * @param flen - length of the message; at most tlen - 2*32 - 2.
 * @param param - OAEP label; may be NULL.
 * @param plen - length of the label.
 * @return int tlen on success, HDCP2_ERR_CRYPTO on failure.
*/
int TZ_RSA_PKCS1_OAEP_SHA256(unsigned char *to, int tlen,
			const unsigned char *from, int flen, unsigned char *param, int plen)
{
	int i, digestlen = 32, emlen;
	int ret = HDCP2_ERR_CRYPTO;
	unsigned char *db, *seed;
	unsigned char dbmask[256] = {0, }, seedmask[32] = {0, };

	if (to == NULL || flen < 0 || (from == NULL && flen > 0))
		return HDCP2_ERR_CRYPTO;

	/* Need room for 0x00 || seed || lHash || 0x01 at minimum. */
	if (tlen < 2 * digestlen + 2)
		return HDCP2_ERR_CRYPTO;
	emlen = tlen - 1;

	/* dbmask must cover the whole DB (emlen - digestlen bytes). */
	if (emlen - digestlen > (int)sizeof(dbmask))
		return HDCP2_ERR_CRYPTO;

	/* Message too long for this modulus. */
	if (flen > emlen - 2 * digestlen - 1)
		return HDCP2_ERR_CRYPTO;

	to[0] = 0;
	seed = to + 1;
	db = to + digestlen + 1;

	/* lHash */
	if (TZ_SHA256(db, param, plen) < 0)
		goto out;

	/* PS || 0x01 || M */
	memset(db + digestlen, 0, emlen - flen - 2 * digestlen - 1);
	db[emlen - flen - digestlen - 1] = 0x01;
	if (flen > 0)
		memcpy(db + emlen - flen - digestlen, from, (unsigned int)flen);

	if (TZ_rand(seed, digestlen) <= 0)
		goto out;

	/* maskedDB = DB XOR MGF1(seed) */
	if (TZ_PKCS1_MGF1(dbmask, emlen - digestlen, seed, digestlen) < 0)
		goto out;

	for (i = 0; i < emlen - digestlen; i++)
		db[i] ^= dbmask[i];

	/* maskedSeed = seed XOR MGF1(maskedDB) */
	if (TZ_PKCS1_MGF1(seedmask, digestlen, db, emlen - digestlen) < 0)
		goto out;

	for (i = 0; i < digestlen; i++)
		seed[i] ^= seedmask[i];

	ret = tlen;

out:
	/* Do not hand back a partial encoding that still contains the plaintext. */
	if (ret < 0)
		memset(to, 0x00, tlen);
	memset(dbmask, 0x00, sizeof(dbmask));
	memset(seedmask, 0x00, sizeof(seedmask));
	return ret;
}

/**
 * @fn int TZ_RSA_VERIFY_PKCS1_OAEP_SHA256(unsigned char *to, int tlen, const unsigned char *from, int flen, int num, const unsigned char *param, int plen)
 * @brief EME-OAEP decoding (SHA-256, MGF1-SHA-256) of a block produced by raw
 *        RSA decryption. Called by TZ_RSA_OAEP_decrypt.
 *
 * The encoded block is expected to start with its leading byte at from[0].
 * That byte is skipped (memcpy from from + 1) and is NOT checked for 0x00.
 * NOTE(review): RFC 8017 requires Y == 0x00 to be verified; confirm whether
 * this is acceptable here.
 *
 * NOTE(review): the distinct error codes below (-5 label-hash mismatch,
 * -6 missing 0x01 separator, -7 message too long) and the early returns could
 * act as a padding oracle if they reach the peer. Consider collapsing them.
 *
 * @param to - output buffer receiving the recovered message.
 * @param tlen - capacity of the output buffer; a longer message is rejected with -7.
 * @param from - the decrypted but still OAEP-encoded block.
 * @param flen - length of the encoded block.
 * @param num - RSA modulus size in bytes (callers pass HDCP2_SIZE_RECEIVER_PUBKEY).
 * @param param - OAEP label; may be NULL.
 * @param plen - length of the label.
 * @return int HDCP2_OK on success. On failure, HDCP2_ERR_TRUSTZONE_BASE minus:
 *         1 (size checks), 2/3 (MGF1 failure), 4 (label hash failure),
 *         5 (label hash mismatch), 6 (no 0x01 separator), 7 (message too long).
 */
int TZ_RSA_VERIFY_PKCS1_OAEP_SHA256(unsigned char *to, int tlen,
			const unsigned char *from, int flen, int num,
			const unsigned char *param, int plen)
{
	int i, digestlen = 32, dblen, mlen = -1;
	const unsigned char *maskeddb;
	int lzero;
	unsigned char db[256] = {0, }, seed[32], phash[32];
	unsigned char *padded_from;
	int bad = 0;

	/* num becomes the encoded length without the leading byte. */
	if (--num < 2 * digestlen + 1)
		return HDCP2_ERR_TRUSTZONE_BASE - 1;

	/*
	 * lzero < 0 whenever flen > num; that includes the normal call with
	 * flen == original num. flen is then clamped to num.
	 */
	lzero = num - flen;
	if (lzero < 0) {
		bad = 1;
		lzero = 0;
		flen = num;
	}

	dblen = num - digestlen;
	/* MAX_HASH_BUFFER presumably equals sizeof(db) - confirm in tz_hdcp2.h. */
	if (dblen > MAX_HASH_BUFFER - 1) {
		return HDCP2_ERR_TRUSTZONE_BASE-1;
	}
	/* The encoded block is staged in db right after the dblen bytes reserved for DB. */
	padded_from = db + dblen;
	if (lzero > MAX_HASH_BUFFER - dblen || flen > MAX_HASH_BUFFER - dblen) {
		return HDCP2_ERR_TRUSTZONE_BASE-1;
	}
	memset(padded_from, 0, lzero);

	/*
	 * Load-bearing: this reset discards the "bad" flag set above, which the
	 * normal call path (flen == num + 1) always triggers. Removing it would
	 * reject every message.
	 * NOTE(review): the memcpy below does not offset by lzero, so
	 * left-padding for flen < num would misalign the data.
	 */
	bad = 0;
	memcpy(padded_from, from + 1, flen);

	/* seed = maskedSeed XOR MGF1(maskedDB) */
	maskeddb = padded_from + digestlen;
	if (TZ_PKCS1_MGF1(seed, digestlen, maskeddb, dblen) < 0)
		return HDCP2_ERR_TRUSTZONE_BASE - 2;

	for (i = 0; i < digestlen; i++)
		seed[i] ^= padded_from[i];

	/* DB = maskedDB XOR MGF1(seed); written to db[0..dblen), before padded_from. */
	if (TZ_PKCS1_MGF1(db, dblen, seed, digestlen) < 0)
		return HDCP2_ERR_TRUSTZONE_BASE - 3;

	for (i = 0; i < dblen; i++)
		db[i] ^= maskeddb[i];

	/* Expected lHash of the label. */
	if (TZ_SHA256(phash, (u8 *) param, plen) < 0)
		return HDCP2_ERR_TRUSTZONE_BASE - 4;

	if (memcmp(db, phash, digestlen) != 0 || bad) {
		return HDCP2_ERR_TRUSTZONE_BASE - 5;
	} else {
		/* Skip the zero padding string PS. */
		for (i = digestlen; i < dblen; i++)
			if (db[i] != 0x00)
				break;

		if (i == dblen || db[i] != 0x01) {
			return HDCP2_ERR_TRUSTZONE_BASE - 6;
		} else {
			/* Padding is valid; the message follows the 0x01 separator. */
			mlen = dblen - ++i;
			if (tlen < mlen || mlen > HDCP2_SIZE_RECEIVER_PUBKEY) {
				return HDCP2_ERR_TRUSTZONE_BASE - 7;
			} else {
				memcpy(to, db + i, mlen);
			}
		}
	}

	return HDCP2_OK;
}

/**
 * @fn int TZ_SW_AES_CTR(u8 *in, u32 inlen, u8 *out, u32 *outlen, uint8_t *key, uint8_t *pP)
 * @brief Software AES-128-CTR: out = in XOR AES-ECB(key, counter block) keystream.
 *
 * CTR mode is symmetric, so the same routine encrypts and decrypts.
 *
 * @param in - pointer to the input data.
 * @param inlen - length of the input in bytes; need not be a multiple of 16.
 * @param out - output buffer of at least inlen bytes.
 * @param outlen - receives the number of bytes written.
 * @param key - 16-byte AES key.
 * @param pP - 16-byte counter block: (riv XOR streamCtr) || inputCtr. It is
 *             updated in place; its low 64 bits (pP[8..15], big-endian) are
 *             incremented once per processed block.
 * @return int HDCP2_OK on success; HDCP2_ERR_CRYPTO on invalid arguments;
 *         HDCP2_ERR_TRUSTZONE_BASE - <error> if an AES block operation fails.
 */
int TZ_SW_AES_CTR(u8 *in, u32 inlen, u8 *out, u32 *outlen, uint8_t *key, uint8_t *pP)
{
	int ret = HDCP2_OK;
	int i;
	u32 blkidx, blklen, j;
	/* ceil(inlen / 16) without the (inlen + 15) overflow */
	u32 blknum = inlen / 16 + (inlen % 16 != 0 ? 1 : 0);
	u32 ct_len;
	u8 ct[16] = {0};

	if (outlen == NULL || key == NULL || pP == NULL ||
	    (inlen != 0 && (in == NULL || out == NULL)))
		return HDCP2_ERR_CRYPTO;

	*outlen = 0;

	for (blkidx = 0; blkidx < blknum; blkidx++) {
		ct_len = sizeof(ct);
		ret = TZ_AES_ECB_encrypt(key, 16, pP, 16, ct, &ct_len);
		if (ret != HDCP2_OK) {
			ret = HDCP2_ERR_TRUSTZONE_BASE - ret;
			goto out;
		}

		/* Only the final block may be partial. */
		blklen = 16;
		if (blkidx == blknum - 1 && inlen % 16 != 0)
			blklen = inlen % 16;
		for (j = 0; j < blklen; j++)
			out[blkidx * 16 + j] = ct[j] ^ in[blkidx * 16 + j];

		*outlen += blklen;

		/*
		 * Increment the 64-bit big-endian counter in pP[8..15]. A byte at
		 * 0xFF wraps to 0 and carries into the next more significant byte.
		 * Any other byte simply increments. The previous "+ carry < 255"
		 * test wrongly wrapped a 0xFE byte to 0 during a carry.
		 */
		for (i = 15; i >= 8; i--) {
			if (pP[i] != 0xFF) {
				pP[i]++;
				break;
			}
			pP[i] = 0;
		}
	}

	ret = HDCP2_OK;

out:
	/* ct holds keystream material; do not leave it on the stack. */
	memset(ct, 0x00, sizeof(ct));
	return ret;
}

#ifdef USE_MTK
//int data_index = 0;

/**
 * AES-CTR encryption via the MTK secure driver call tlApi_Encrypt_AES_CTR.
 *
 * @param in     - source buffer handle, forwarded to the driver as a 64-bit value.
 * @param inlen  - number of bytes to encrypt.
 * @param out    - destination buffer.
 * @param outlen - set to inlen, then passed to the driver.
 * @param key    - AES key.
 * @param pP     - counter block: (riv XOR streamCtr) || inputCtr.
 * @return HDCP2_OK if the driver returns 0, HDCP2_ERR_CRYPTO otherwise.
 */
int TZ_Cipher_AES_CTR_Encrypt(int64_t in, u32 inlen, uint8_t *out, u32 *outlen,
			uint8_t *key, uint8_t *pP)
{
	tlApiResult_t status;

	LOGW("tlApi_Encrypt_AES_CTR(s)\n");

	*outlen = inlen;

	status = tlApi_Encrypt_AES_CTR(in, out, pP, key, inlen, outlen);
	if (status != 0) {
		LOGE("tlApi_Encrypt_AES_CTR - error (%d)\n", status);
		return HDCP2_ERR_CRYPTO;
	}

	LOGW("tlApi_Encrypt_AES_CTR(e)\n");
	return HDCP2_OK;
}

/**
 * AES-CTR decryption via the MTK secure driver call tlApi_Decrypt_AES_CTR.
 *
 * @param in     - source buffer.
 * @param inlen  - number of bytes to decrypt.
 * @param out    - destination buffer handle, forwarded to the driver as a 64-bit value.
 * @param outlen - set to inlen, then passed to the driver.
 * @param key    - AES key.
 * @param pP     - counter block: (riv XOR streamCtr) || inputCtr.
 * @return HDCP2_OK if the driver returns 0, HDCP2_ERR_CRYPTO otherwise.
 */
int TZ_Cipher_AES_CTR_Decrypt(u8* in, u32 inlen, int64_t out, u32 *outlen, uint8_t *key,
			uint8_t *pP)
{
	tlApiResult_t status;

	LOGW("tlApi_Decrypt_AES_CTR(s)\n");

	*outlen = inlen;

	status = tlApi_Decrypt_AES_CTR(in, out, pP, key, inlen, outlen);
	if (status != 0) {
		LOGE("tlApi_Decrypt_AES_CTR - error (%d)\n", status);
		return HDCP2_ERR_CRYPTO;
	}

	LOGW("tlApi_Decrypt_AES_CTR(e)\n");
	return HDCP2_OK;
}
#else
/**
 * @fn int TZ_Cipher_AES_CTR_Encrypt(addr_s in, u32 inlen, addr_s out, u32 *outlen, uint8_t *key, uint8_t *pP)
 * @brief Encrypt data with AES in CTR mode using the secure SSS/crypto driver.
 * @param in - input buffer address. With the keyman driver it is marked as a
 *             PHYSICAL, NON_CACHEABLE address. The WITHOUT_KEYMANDRV
 *             semantics are defined by secSss_t (confirm there).
 * @param inlen - length of input data to be encrypted.
 * @param out - output buffer address (same addressing notes as in).
 * @param outlen - set to inlen before the driver call.
 * @param key - pointer to the 16-byte key.
 * @param pP - 16-byte counter block: (riv XOR streamCtr) || inputCtr.
 * @return int - the driver result (TLAPI_OK / HDCP2_OK on success).
 */
int TZ_Cipher_AES_CTR_Encrypt(addr_s in, u32 inlen, addr_s out, u32 *outlen, uint8_t *key,
			uint8_t *pP)
{
	tlApiResult_t ret = TLAPI_OK;

#ifdef USE_64BIT_ADDR
	struct secSss64_t sss;
#else
	struct secSss_t sss;
#endif /* USE_64BIT_ADDR */

	/*
	 * Start from a fully zeroed request. Any driver field this function does
	 * not set explicitly then has a defined value instead of stack garbage.
	 */
	memset(&sss, 0x00, sizeof(sss));

	*outlen = inlen;

#ifdef WITHOUT_KEYMANDRV
	sss.key = (uint32_t *)key;
	sss.key_len = 16;
	sss.iv = (uint32_t *)pP;
	sss.iv_len = 16;
	sss.input = in;
	sss.input_len = inlen;
	sss.output = out;
	sss.output_len = outlen;
	sss.mode = CTR_MODE;
	sss.cipher_mode = ENCRYPT;
	sss.content_mode = VIDEO;

	ret = tlApiSecSssRun(&sss);
#else
	sss.key = (uint32_t *)key;
	sss.key_len = 16;
	sss.iv = (uint32_t *)pP;
	sss.iv_len = 16;
	sss.block_offset = 0;
	sss.input.addr = in;
	sss.input.len = inlen;
	sss.input.space = PHYSICAL;
	sss.input.cache = NON_CACHEABLE;
	sss.output.addr = out;
	sss.output.len = *outlen;
	sss.output.space = PHYSICAL;
	sss.output.cache = NON_CACHEABLE;
	sss.mode = CTR_MODE;
	sss.cipher_mode = ENCRYPT;
	sss.pad = NO_PADDING;

#ifdef USE_64BIT_ADDR
	sss.ldfw_mode = CM_MODE;
	ret = tlApiRunAES64(&sss);
#else
	ret = tlApiRunAES(&sss);
#endif /* USE_64BIT_ADDR */
#endif /* WITHOUT_KEYMANDRV */

	/* The request carried key/IV pointers; clear it before returning. */
	memset(&sss, 0x00, sizeof(sss));
	return ret;
}

/**
 * @fn int TZ_Cipher_AES_CTR_Decrypt(addr_s in, u32 inlen, addr_s out, u32 *outlen, uint8_t *key, uint8_t *pP)
 * @brief Decrypt data with AES in CTR mode using the secure SSS/crypto driver.
 * @param in - input buffer address. With the keyman driver it is marked as a
 *             PHYSICAL, NON_CACHEABLE address. The WITHOUT_KEYMANDRV
 *             semantics are defined by secSss_t (confirm there).
 * @param inlen - length of input data to be decrypted.
 * @param out - output buffer address (same addressing notes as in).
 * @param outlen - set to inlen before the driver call.
 * @param key - pointer to the 16-byte key.
 * @param pP - 16-byte counter block: (riv XOR streamCtr) || inputCtr.
 * @return int - the driver result (TLAPI_OK / HDCP2_OK on success).
 */
int TZ_Cipher_AES_CTR_Decrypt(addr_s in, u32 inlen, addr_s out, u32 *outlen, uint8_t *key,
			uint8_t *pP)
{
	tlApiResult_t ret = TLAPI_OK;

#ifdef USE_64BIT_ADDR
	struct secSss64_t sss;
#else
	struct secSss_t sss;
#endif /* USE_64BIT_ADDR */

	/*
	 * Start from a fully zeroed request. Any driver field this function does
	 * not set explicitly then has a defined value instead of stack garbage.
	 */
	memset(&sss, 0x00, sizeof(sss));

	*outlen = inlen;

#ifdef WITHOUT_KEYMANDRV
	sss.key = (uint32_t *)key;
	sss.key_len = 16;
	sss.iv = (uint32_t *)pP;
	sss.iv_len = 16;
	sss.input = in;
	sss.input_len = inlen;
	sss.output = out;
	sss.output_len = outlen;
	sss.mode = CTR_MODE;
	sss.cipher_mode = DECRYPT;
	sss.content_mode = VIDEO;

	ret = tlApiSecSssRun(&sss);
#else
	sss.key = (uint32_t *)key;
	sss.key_len = 16;
	sss.iv = (uint32_t *)pP;
	sss.iv_len = 16;
	sss.block_offset = 0;
	sss.input.addr = in;
	sss.input.len = inlen;
	sss.input.space = PHYSICAL;
	sss.input.cache = NON_CACHEABLE;
	sss.output.addr = out;
	sss.output.len = *outlen;
	sss.output.space = PHYSICAL;
	sss.output.cache = NON_CACHEABLE;
	sss.mode = CTR_MODE;
	/*
	 * CTR mode runs the forward AES cipher in both directions, so ENCRYPT
	 * gives the same result as a decryption here.
	 * NOTE(review): the WITHOUT_KEYMANDRV path uses DECRYPT; confirm whether
	 * the keyman driver requires ENCRYPT before changing either.
	 */
	sss.cipher_mode = ENCRYPT;
	sss.pad = NO_PADDING;

#ifdef USE_64BIT_ADDR
	sss.ldfw_mode = CM_MODE;
	ret = tlApiRunAES64(&sss);
#else
	ret = tlApiRunAES(&sss);
#endif /* USE_64BIT_ADDR */
#endif /* WITHOUT_KEYMANDRV */

	/* The request carried key/IV pointers; clear it before returning. */
	memset(&sss, 0x00, sizeof(sss));
	return ret;
}
#endif /* USE_MTK */
