/**
* \file CommandHandler.c
* \brief Command dispatcher: routes handleCmd() requests (PCR read/extend, certificate retrieval, service-key and signing commands) to their handlers.
* \author Roman Pasechnik (r.pasechnik@samsung.com)
* \version 0.1
* \date Created Feb 11, 2014
* \par In Samsung Ukraine R&D Center (SURC) under a contract between
* \par LLC "Samsung Electronics Ukraine Company" (Kiev, Ukraine) and
* \par "Samsung Electronics Co", Ltd (Seoul, Republic of Korea)
* \par Copyright: (c) Samsung Electronics Co, Ltd 2014. All rights reserved.
**/

#include "CommandHandler.h"
#include "CommLayerData.h"
#include "TLV.h"
#include "CryptoPlatform.h"
#include "log.h"
#ifdef USE_MOBICORE
#include "mobicore_utils.h"
#endif
#ifdef USE_QSEE
#include "SfsFileOperations.h"
#ifdef USE_ENCAPSULATED_TID
#include "qsee_message.h"
#endif
#endif
#ifdef USE_BLOWFISH
#include "blowfish_utils.h"
#endif

#ifdef USE_QSEE
/* SFS path of the persisted PCR value. In this file it is read and written by
 * readPcr() and extendPcr() when USE_QSEE_WRAP_WITH_SFS is defined. */
const char pcr_file[] = "/efs/prov_data/pcr/pcr.dat";
/* Names used by handleCmd(GET_ENCRYPTED_TID). SKM_NAME is passed as the first
 * (name) argument of qsee_encapsulate_inter_app_message(). TZ_APP_NAME is
 * passed as the payload; its length is given as sizeof(TZ_APP_NAME), so the
 * NUL terminator is included.
 * NOTE(review): presumably SKM_NAME is the destination trusted app; confirm
 * against the QSEE API. Both arrays are non-const and have external linkage.
 * Check whether other translation units reference them before narrowing. */
char SKM_NAME[] = {"skm"};
char TZ_APP_NAME[] = {"mldap"};
#endif

/**
 * Returns the current PCR value (SHA1_SIZE bytes) in outData.
 *
 * Without USE_QSEE_WRAP_WITH_SFS, the PCR is carried wrapped in inData.
 * An empty blob means "no PCR yet" and yields all zeros. Any other blob is
 * unwrapped with unwrap(), unwrapWithoutSFS() or loadPcr(), depending on the
 * build.
 *
 * With USE_QSEE_WRAP_WITH_SFS, inData is ignored and the PCR is read from
 * pcr_file in SFS. If that read fails, the file is (re)created holding a
 * zero PCR.
 *
 * \param inData      wrapped PCR blob (unused in the SFS build)
 * \param inDataLen   length of inData
 * \param outData     receives the PCR value
 * \param outDataLen  in: capacity of outData (must be >= SHA1_SIZE);
 *                    out: SHA1_SIZE on success, 0 on failure
 * \return NO_ERROR on success, WRONG_DATA on a size problem, otherwise the
 *         error returned by the unwrap / SFS helper
 */
int32_t readPcr(uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int32_t ret;

	if (*outDataLen < SHA1_SIZE) {
		LOGE("Not enough space to initialize PCR\n");
		return WRONG_DATA;
	}

#if !defined(USE_QSEE_WRAP_WITH_SFS)
#if defined(USE_BLOWFISH)
	/*
	 * With Blowfish, a zero-length input leaves the shared buffer NULL and
	 * crashes. An uninitialised PCR is therefore signalled by an input of
	 * exactly SHA1_SIZE bytes instead of 0, and answered with SHA1_SIZE zero
	 * bytes.
	 */
	if (inDataLen == SHA1_SIZE) {
		LOGD("blowfish PCR\n");
		*outDataLen = SHA1_SIZE;
		memset(outData, 0, SHA1_SIZE);
		return NO_ERROR;
	}
#endif /* USE_BLOWFISH */

	if (inDataLen == 0) {
		/* No PCR has been stored yet: start from all zeros. */
		*outDataLen = SHA1_SIZE;
		memset(outData, 0, SHA1_SIZE);
		ret = NO_ERROR;
	} else {
		/* Remember the real buffer size: the helper overwrites *outDataLen. */
		uint32_t outCapacity = *outDataLen;

		/* `defined()` keeps this consistent with the #ifdef tests above. */
#if defined(USE_BLOWFISH)
		ret = unwrap(inData, inDataLen, outData, outDataLen);
#elif defined(USE_QSEE)   /* Use QSEE without SFS */
		ret = unwrapWithoutSFS(inData, inDataLen, outData, outDataLen);
#else
		ret = loadPcr(inData, inDataLen, outData, outDataLen);
#endif
		if (*outDataLen > SHA1_SIZE) {
			LOGE("Size outDataLen of PCR is invalid\n");
			/* Never clear past the caller's buffer, whatever length was reported. */
			memset(outData, 0, (*outDataLen < outCapacity) ? *outDataLen : outCapacity);
			*outDataLen = 0;
			return WRONG_DATA;
		}
		if (ret != NO_ERROR) {
			*outDataLen = 0;
			memset(outData, 0, SHA1_SIZE);
		} else if (*outDataLen < SHA1_SIZE) {
			/* Reporting SHA1_SIZE here would expose stale bytes past the unwrapped data. */
			LOGE("Size outDataLen of PCR is invalid\n");
			memset(outData, 0, SHA1_SIZE);
			*outDataLen = 0;
			ret = WRONG_DATA;
		} else {
			*outDataLen = SHA1_SIZE;
		}
	}
#else /* USE_QSEE use SFS*/
	(void)inData;
	(void)inDataLen;

	ret = readKeyFromSFS(pcr_file, outData, SHA1_SIZE, 0);
	if (ret != NO_ERROR) {
		/* There is no file yet: create it with a zero PCR. */
		memset(outData, 0, SHA1_SIZE);
		ret = saveKeyToSFS(pcr_file, outData, SHA1_SIZE, NULL, 0);
		if (ret != NO_ERROR) {
			*outDataLen = 0;
		} else {
			*outDataLen = SHA1_SIZE;
		}
	} else {
		*outDataLen = SHA1_SIZE;
	}
#endif

	return ret;
}

int32_t extendPcr(uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int32_t ret;
	uint8_t tmp[256];
	uint32_t tmpLen = sizeof(tmp);

#if !defined(USE_QSEE_WRAP_WITH_SFS)
	uint32_t wrappedLen;
	uint8_t* wrappedData;
	uint32_t extendDataLen;
	uint8_t* extendData;
	uint32_t TIDLen;
	uint8_t* TID;

	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_PCR, &wrappedLen, &wrappedData);

	if (ret != NO_ERROR) {
		LOGE("Error getting TLV_WRAPPED_PCR\n");
		return ret;
	}

	ret = tlvGet(inData, inDataLen, TLV_EXTEND_PCR_DATA, &extendDataLen, &extendData);

	if (ret != NO_ERROR) {
		LOGE("Error getting TLV_EXTEND_PCR_DATA\n");
		return ret;
	}

	ret = tlvGet(inData, inDataLen, TLV_TID, &TIDLen, &TID);

	if (ret != NO_ERROR) {
		LOGE("Error getting TLV_TID\n");
		return ret;
	}

	ret = readPcr(wrappedData, wrappedLen, tmp, &tmpLen);

	if (ret != NO_ERROR) {
		LOGE("readPcr(wrappedData, ...)\n");
		return ret;
	}

	ret = getSHA1Digest(extendData, extendDataLen, tmp + tmpLen);
	if (ret != NO_ERROR) {
		LOGE("getSHA1Digest 1 Failed ...\n");
		return ret;
	}

	ret = getSHA1Digest(tmp, SHA1_SIZE * 2, tmp);
	if (ret != NO_ERROR) {
		LOGE("getSHA1Digest 2 Failed ...\n");
		return ret;
	}

#ifdef USE_QSEE   /* Use QSEE without SFS */
	ret = wrapWithoutSFS(tmp, SHA1_SIZE, outData, outDataLen);
#else
	ret = wrap(tmp, SHA1_SIZE, outData, outDataLen, TID, TIDLen);
#endif
	if (ret != NO_ERROR) {
		LOGE("wrap(tmp, ...\n");
		return ret;
	}
#else /* USE_QSEE use SFS*/
	ret = readPcr(NULL, 0, tmp, &tmpLen);

	if (ret != NO_ERROR) {
		return ret;
	}

	ret = getSHA1Digest(inData, inDataLen, tmp + tmpLen);

	if (ret != NO_ERROR) {
		return ret;
	}

	ret = getSHA1Digest(tmp, SHA1_SIZE * 2, tmp);

	if (ret != NO_ERROR) {
		return ret;
	}

	ret = saveKeyToSFS(pcr_file, tmp, SHA1_SIZE, NULL, 0);
	*outDataLen = 0;
#endif

	return ret;
}

/**
 * Handles the DAP_GET_*_CERT commands: fetches the MLDAP certificate of the
 * given type via getMLDAPCert().
 *
 * Without USE_QSEE_WRAP_WITH_SFS, inData must carry a non-empty
 * TLV_WRAPPED_KEY. With it, nothing is parsed and getMLDAPCert() receives
 * NULL/0 for the key.
 *
 * \return result of getMLDAPCert(), the tlvGet() error, or WRONG_DATA if the
 *         lookup succeeded but produced no key data
 */
int32_t handleGetMLDAPCert(certType_t type, uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int32_t ret;
	uint8_t* wrappedKey = NULL;
	uint32_t wrappedKeyLen = 0;

#if !defined(USE_QSEE_WRAP_WITH_SFS)
	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_KEY, &wrappedKeyLen, &wrappedKey);

	if (ret != NO_ERROR || wrappedKey == NULL || wrappedKeyLen == 0) {
		LOGE("Error getting TLV_WRAPPED_KEY\n");
		/* An empty key must not be reported as success (ret may be NO_ERROR here). */
		return (ret != NO_ERROR) ? ret : WRONG_DATA;
	}
#else
	(void)inData;
	(void)inDataLen;
#endif

	ret = getMLDAPCert(type, outData, outDataLen, wrappedKey, wrappedKeyLen);
	return ret;
}

/**
 * Handles the OTA_GET_*_CERT commands: fetches the SKM certificate of the
 * given type via getSKMCert().
 *
 * Without USE_QSEE_WRAP_WITH_SFS, inData must carry a non-empty
 * TLV_WRAPPED_KEY. With it, nothing is parsed and getSKMCert() receives
 * NULL/0 for the key.
 *
 * \return result of getSKMCert(), the tlvGet() error, or WRONG_DATA if the
 *         lookup succeeded but produced no key data
 */
int32_t handleGetSKMCert(certType_t certType, uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int32_t ret;
	uint8_t* wrappedKey = NULL;
	uint32_t wrappedKeyLen = 0;

#if !defined(USE_QSEE_WRAP_WITH_SFS)
	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_KEY, &wrappedKeyLen, &wrappedKey);

	if (ret != NO_ERROR || wrappedKey == NULL || wrappedKeyLen == 0) {
		LOGE("Error getting TLV_WRAPPED_KEY\n");
		/* An empty key must not be reported as success (ret may be NO_ERROR here). */
		return (ret != NO_ERROR) ? ret : WRONG_DATA;
	}
#else
	(void)inData;
	(void)inDataLen;
#endif

	ret = getSKMCert(certType, outData, outDataLen, wrappedKey, wrappedKeyLen);
	return ret;
}

int32_t handleVerifyMLCert(uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int ret;
	uint8_t* wrappedKey = NULL;
	uint32_t wrappedKeyLen = 0;
	uint8_t* timeData = NULL;
	uint32_t timeDataLen = 0;

#if !defined(USE_QSEE_WRAP_WITH_SFS)
	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_KEY, &wrappedKeyLen, &wrappedKey);

	if (ret != NO_ERROR || wrappedKey == NULL || wrappedKeyLen == 0) {
		LOGE("Error getting TLV_WRAPPED_KEY\n");
		return ret;
	}
#endif
	ret = tlvGet(inData, inDataLen, TLV_TIMESTAMP, &timeDataLen, &timeData);

	if (ret != NO_ERROR || timeData == NULL || timeDataLen == 0) {
		LOGE("Error getting TLV_TIMESTAMP\n");
		return ret;
	}

	ret = verifySDKey(wrappedKey, wrappedKeyLen, timeData, timeDataLen);
	return ret;
}

#ifdef USE_QSEE
/**
 * Handles STORE_SERVICE_KEY_CMD (QSEE builds only): passes the
 * TLV_WRAPPED_KEY from inData to storeServiceKey().
 *
 * \return result of storeServiceKey(), the tlvGet() error, or WRONG_DATA if
 *         the lookup succeeded but produced no key data
 */
int32_t handleStoreServiceKey(uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int32_t ret;
	uint8_t* wrappedKey = NULL;
	uint32_t wrappedKeyLen = 0;

	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_KEY, &wrappedKeyLen, &wrappedKey);

	if (ret != NO_ERROR || wrappedKey == NULL || wrappedKeyLen == 0) {
		LOGE("Error getting TLV_WRAPPED_KEY\n");
		/* An empty key must not be reported as success (ret may be NO_ERROR here). */
		return (ret != NO_ERROR) ? ret : WRONG_DATA;
	}

	ret = storeServiceKey(wrappedKey, wrappedKeyLen, outData, outDataLen);
	if (ret != NO_ERROR) {
		LOGE("Cannot store Service key\n");
	}
	return ret;
}
#endif

int32_t handleVerifyServiceKey(uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int ret = NO_ERROR;
	uint8_t* timeData = NULL;
	uint32_t timeDataLen = 0;
	uint8_t* wrappedKey = NULL;
	uint32_t wrappedKeyLen = 0;

#if !defined(USE_QSEE_WRAP_WITH_SFS)
	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_KEY, &wrappedKeyLen, &wrappedKey);

	if (ret != NO_ERROR || wrappedKey == NULL || wrappedKeyLen == 0) {
		LOGE("Error getting TLV_WRAPPED_KEY\n");
		return ret;
	}
#endif

	ret = tlvGet(inData, inDataLen, TLV_TIMESTAMP, &timeDataLen, &timeData);

	if (ret != NO_ERROR || timeData == NULL || timeDataLen == 0) {
		LOGE("Error getting TLV_TIMESTAMP\n");
		return ret;
	}
	ret = verifyServiceKey(wrappedKey, wrappedKeyLen, timeData, timeDataLen);
	return ret;
}

/**
 * Parses the common TLV set of key-based commands and forwards it to func.
 *
 * TLVs read from inData:
 *   - TLV_KEY_INFO: required, must be exactly sizeof(KeyInfo_t) bytes;
 *   - TLV_WRAPPED_KEY: optional; only a WRONG_DATA lookup result aborts;
 *   - TLV_TID: required;
 *   - TLV_ATTRS: optional; only a WRONG_DATA lookup result aborts.
 * An optional TLV whose lookup fails with any other code is passed to func
 * as NULL/0.
 *
 * \param func  callback that performs the actual command (e.g. getOTA_SD_PK)
 * \return func's result, the tlvGet() error, or WRONG_DATA if the key info is
 *         missing or has the wrong size
 */
int32_t handleComplexCommand(genSKCallback_t func, uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int32_t ret;

	uint8_t* keyInfo = NULL;
	uint32_t keyInfoLen = 0;

	uint8_t* wrappedKey = NULL;
	uint32_t wrappedKeyLen = 0;

	uint8_t* TID = NULL;
	uint32_t TIDLen = 0;

	uint8_t* attrs = NULL;
	uint32_t attrsLen = 0;

	ret = tlvGet(inData, inDataLen, TLV_KEY_INFO, &keyInfoLen, &keyInfo);

	if (ret != NO_ERROR) {
		LOGE("Error getting TLV_KEY_INFO\n");
		return ret;
	}

	/* ret is NO_ERROR here, so a size mismatch needs an explicit error code. */
	if (keyInfo == NULL || keyInfoLen != sizeof(KeyInfo_t)) {
		LOGE("keyInfoLen != sizeof(KeyInfo_t)\n");
		return WRONG_DATA;
	}

	ret = tlvGet(inData, inDataLen, TLV_WRAPPED_KEY, &wrappedKeyLen, &wrappedKey);

	if (ret == WRONG_DATA) {
		LOGD("Error getting TLV_WRAPPED_KEY, wrong data!\n");
		return ret;
	}
	if (ret != NO_ERROR) {
		/* Optional TLV not available: make sure func sees "absent". */
		wrappedKey = NULL;
		wrappedKeyLen = 0;
	}

	ret = tlvGet(inData, inDataLen, TLV_TID, &TIDLen, &TID);

	if (ret != NO_ERROR) {
		LOGE("Error getting TLV_TID\n");
		return ret;
	}

	ret = tlvGet(inData, inDataLen, TLV_ATTRS, &attrsLen, &attrs);

	if (ret == WRONG_DATA) {
		LOGD("Error getting TLV_ATTRS, wrong data!\n");
		return ret;
	}
	if (ret != NO_ERROR) {
		/* Optional TLV not available: make sure func sees "absent". */
		attrs = NULL;
		attrsLen = 0;
	}

	/* NOTE(review): keyInfo points into the TLV byte stream; confirm it is
	 * suitably aligned for KeyInfo_t on the target, or copy it first. */
	return func((KeyInfo_t*)keyInfo, outData, outDataLen, wrappedKey, wrappedKeyLen, TID, TIDLen, attrs, attrsLen);
}

int32_t handleCmd(int32_t cmdId, uint8_t* inData, uint32_t inDataLen, uint8_t* outData, uint32_t* outDataLen)
{
	int ret;
	ret = getOemFlag();
	if (ret != NO_ERROR) {
		LOGE("Integrity check error : %d\n", ret);
		return ret;
	}

	LOGD("======================================\n");
	LOGD("handleCmd id 0x%x, inLen %d, outLen %d\n", cmdId, inDataLen, *outDataLen);
	switch (cmdId)
	{
		case DAP_PCR_READ:
			ret = readPcr(inData, inDataLen, outData, outDataLen);
			break;
		case DAP_PCR_EXTEND:
			ret = extendPcr(inData, inDataLen, outData, outDataLen);
			break;
		case DAP_GET_SD_CERT_CMD:
			ret = handleGetMLDAPCert(CERT_SD, inData, inDataLen, outData, outDataLen);
			break;
		case DAP_GET_SM0_CERT_CMD:
			ret = handleGetMLDAPCert(CERT_SM0, inData, inDataLen, outData, outDataLen);
			break;
		case DAP_GET_SM1_CERT_CMD:
			ret = handleGetMLDAPCert(CERT_SM1, inData, inDataLen, outData, outDataLen);
			break;
		case DAP_GET_OEM_FLAG:
			*outDataLen = 0;
			ret = getOemFlag();
			break;
		case OTA_GET_ML_CERT_CMD:
			ret = handleGetSKMCert(CERT_ML, inData, inDataLen, outData, outDataLen);
			break;
		case OTA_GET_DRK_CERT_CMD:
			ret = handleGetSKMCert(CERT_DRK, inData, inDataLen, outData, outDataLen);
			break;
		case OTA_VERIFY_SD_CERT_CMD:
			ret = handleVerifyMLCert(inData, inDataLen, outData, outDataLen);
			*outDataLen = 0;
			break;
		case VERIFY_SERVICE_KEY_CMD:
			ret = handleVerifyServiceKey(inData, inDataLen, outData, outDataLen);
			*outDataLen = 0;
			break;
#ifdef USE_QSEE
		case STORE_SERVICE_KEY_CMD:
			ret = handleStoreServiceKey(inData, inDataLen, outData, outDataLen);
			break;
#endif
		case OTA_GET_SD_PUB_KEY_CMD:
			ret = handleComplexCommand(getOTA_SD_PK, inData, inDataLen, outData, outDataLen);
			break;
		case OTA_STORE_ML_AND_SD_CERT_CMD:
			ret = handleComplexCommand(storeOTACerts, inData, inDataLen, outData, outDataLen);
			break;
		case SIGN_DATA_CMD:
			ret = handleComplexCommand(signDataWithKey, inData, inDataLen, outData, outDataLen);
			break;
#ifdef USE_QSEE
#ifdef USE_ENCAPSULATED_TID
		case GET_ENCRYPTED_TID:
			ret = qsee_encapsulate_inter_app_message((char*)SKM_NAME, (uint8*)TZ_APP_NAME, sizeof(TZ_APP_NAME), outData, (uint32*)outDataLen);
			if (ret != 0) {
				LOGE("qsee_encapsulate_inter_app_message failed: 0x%x\n", ret);
				ret = PLATFORM_INTERNAL_ERROR;
			} else {
				ret = NO_ERROR;
			}
			break;
#endif
		case OTA_DELETE_SD_FILE:
			*outDataLen = 0;
			ret = removeFile("/efs/prov_data/sd/sd.dat");
			break;
#endif /* #ifdef USE_QSEE */
		default:
			LOGE("Unknown command ID: 0x%x\n", cmdId);
			ret = UNSUPPORTED_CMD;
			break;
	}

	LOGD("handleCmd ret %d outLen %d\n", ret, *outDataLen);
	LOGD("======================================\n");
	return ret;
}
