#include "knoxai_rsa.h"
#include "knoxai_logger.h"

#include <stdint.h>
#include <tee_internal_api.h>

/*
 * Encrypt `in` with an RSA public key using RSAES-PKCS1-v1_5
 * (TEE_ALG_RSAES_PKCS1_V1_5).
 *
 * keyMod/keyModLen       : RSA modulus bytes; 8*keyModLen is used as the key
 *                          size in bits.
 * keyPubExp/keyPubExpLen : public exponent bytes; must not be longer than
 *                          the modulus.
 * in/in_len              : data to encrypt.
 * out/pOut_len           : on entry *pOut_len is the capacity of `out`; after
 *                          TEE_AsymmetricEncrypt runs it is overwritten with
 *                          the length that call reported, whether it
 *                          succeeded or not.
 *
 * Returns TZ_API_OK on success, TZ_API_ERROR for invalid arguments or
 * allocation failure, otherwise the result code of the failing TEE call.
 */
uint32_t knoxai_rsa_encrypt(
    uint8_t *keyMod,
    uint32_t keyModLen,
    uint8_t *keyPubExp,
    uint32_t keyPubExpLen,
    uint8_t * in,
    uint32_t in_len,
    uint8_t * out,
    uint32_t * pOut_len
)
{
    TEE_OperationHandle opHandle = NULL;
    TEE_ObjectHandle keyHandle = NULL;
    TEE_Attribute attrs[2];
    uint32_t ret = TZ_API_OK;
    uint8_t *pPubKeyExp = NULL;
    size_t t_out_len = 0;

    KNOXAI_DEBUG_LOG("GLOBAL_PLAT: RSA key encrypt");

    KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s keyModLen=%u, keyPubExpLen=%u", __FUNCTION__, keyModLen, keyPubExpLen);

    /* 8*keyModLen is passed as a bit count below; reject lengths that would overflow it. */
    if (keyMod == NULL || keyModLen == 0 || keyModLen > UINT32_MAX / 8) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input key modulus", __FUNCTION__);
        return TZ_API_ERROR;
    }

    /* The exponent is right-aligned inside a keyModLen buffer, so it must fit. */
    if (keyPubExp == NULL || keyPubExpLen == 0 || keyPubExpLen > keyModLen) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input key public exponent", __FUNCTION__);
        return TZ_API_ERROR;
    }

    if (in == NULL || in_len == 0) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input data", __FUNCTION__);
        return TZ_API_ERROR;
    }

    if (out == NULL || pOut_len == NULL || *pOut_len == 0) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid output data", __FUNCTION__);
        return TZ_API_ERROR;
    }

    /*
     * Left-pad the public exponent with zeros to the modulus length.
     * GP bignum attributes are big-endian, so leading zeros keep the value.
     */
    pPubKeyExp = TEE_Malloc(keyModLen, 0);
    if (NULL == pPubKeyExp) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_Malloc failed for TEE_Malloc(keyModLen, 0)", __FUNCTION__);
        ret = TZ_API_ERROR;
        goto EXIT;
    }
    TEE_MemFill(pPubKeyExp, 0, keyModLen);
    TEE_MemMove(pPubKeyExp + keyModLen - keyPubExpLen, keyPubExp, keyPubExpLen);

    /* The attributes only reference the buffers; they must outlive the populate call. */
    TEE_InitRefAttribute(&attrs[0], TEE_ATTR_RSA_MODULUS, keyMod, keyModLen);
    TEE_InitRefAttribute(&attrs[1], TEE_ATTR_RSA_PUBLIC_EXPONENT, pPubKeyExp, keyModLen);

    ret = TEE_AllocateTransientObject(TEE_TYPE_RSA_PUBLIC_KEY, 8 * keyModLen, &keyHandle);
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_AllocateTransientObject failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    ret = TEE_PopulateTransientObject(keyHandle, attrs, sizeof(attrs) / sizeof(attrs[0]));
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_PopulateTransientObject failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    ret = TEE_AllocateOperation(&opHandle, TEE_ALG_RSAES_PKCS1_V1_5, TEE_MODE_ENCRYPT, 8 * keyModLen);
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_AllocateOperation failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s to call TEE_SetOperationKey()", __FUNCTION__);
    ret = TEE_SetOperationKey(opHandle, keyHandle);
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_SetOperationKey failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s to call TEE_AsymmetricEncrypt()", __FUNCTION__);

    t_out_len = *pOut_len;
    ret = TEE_AsymmetricEncrypt(opHandle, NULL, 0, in, in_len, out, &t_out_len);
    /* Report the length from the TEE even on failure (e.g. the size it needed). */
    *pOut_len = (uint32_t)t_out_len;

    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_AsymmetricEncrypt failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    ret = TZ_API_OK;

EXIT:
    if (pPubKeyExp != NULL) {
        TEE_Free(pPubKeyExp);
    }
    if (keyHandle != NULL) {
        TEE_FreeTransientObject(keyHandle);
    }
    if (opHandle != NULL) {
        TEE_FreeOperation(opHandle);
    }
    return ret;
}

/*
 * Decrypt `in` with an RSA key pair using RSAES-PKCS1-v1_5
 * (TEE_ALG_RSAES_PKCS1_V1_5).
 *
 * keyMod/keyModLen       : RSA modulus bytes; 8*keyModLen is used as the key
 *                          size in bits.
 * keyPubExp/keyPubExpLen : public exponent bytes; must not be longer than
 *                          the modulus.
 * keyPriExp/keyPriExpLen : private exponent bytes.
 * in/in_len              : data to decrypt.
 * out/pOut_len           : on entry *pOut_len is the capacity of `out`; after
 *                          TEE_AsymmetricDecrypt runs it is overwritten with
 *                          the length that call reported.
 *
 * Returns TZ_API_OK on success. Returns TZ_API_ERROR for invalid arguments,
 * allocation failure, or a result that does not fit in `out`. Otherwise
 * returns the result code of the failing TEE call.
 */
uint32_t knoxai_rsa_decrypt(
    uint8_t *keyMod,
    uint32_t keyModLen,
    uint8_t *keyPubExp,
    uint32_t keyPubExpLen,
    uint8_t *keyPriExp,
    uint32_t keyPriExpLen,
    uint8_t * in,
    uint32_t in_len,
    uint8_t * out,
    uint32_t * pOut_len
)
{
    TEE_OperationHandle opHandle = NULL;
    TEE_ObjectHandle keyHandle = NULL;
    TEE_Attribute attrs[3];
    uint32_t ret = TZ_API_OK;
    uint8_t *pPubKeyExp = NULL;
    uint8_t *pTmpOut = NULL;   /* scratch destination when `out` is shorter than the modulus */
    uint8_t *pDecOut = NULL;   /* buffer actually handed to TEE_AsymmetricDecrypt */
    uint32_t outCap = 0;       /* caller's capacity of `out`, captured before it is overwritten */
    size_t t_out_len = 0;

    KNOXAI_DEBUG_LOG("GLOBAL_PLAT: RSA key decrypt");

    /* 8*keyModLen is passed as a bit count below; reject lengths that would overflow it. */
    if (keyMod == NULL || keyModLen == 0 || keyModLen > UINT32_MAX / 8) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input key modulus", __FUNCTION__);
        return TZ_API_ERROR;
    }

    /*
     * The exponent is right-aligned inside a keyModLen buffer. Without the
     * length bound the destination pointer would underflow and the copy would
     * write outside the heap allocation.
     */
    if (keyPubExp == NULL || keyPubExpLen == 0 || keyPubExpLen > keyModLen) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input key public exponent", __FUNCTION__);
        return TZ_API_ERROR;
    }

    if (keyPriExp == NULL || keyPriExpLen == 0) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input key private exponent", __FUNCTION__);
        return TZ_API_ERROR;
    }

    if (in == NULL || in_len == 0) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid input data", __FUNCTION__);
        return TZ_API_ERROR;
    }

    if (out == NULL || pOut_len == NULL || *pOut_len == 0) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s invalid output data", __FUNCTION__);
        return TZ_API_ERROR;
    }
    outCap = *pOut_len;

    /*
     * Left-pad the public exponent with zeros to the modulus length.
     * GP bignum attributes are big-endian, so leading zeros keep the value.
     */
    pPubKeyExp = TEE_Malloc(keyModLen, 0);
    if (NULL == pPubKeyExp) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_Malloc failed for TEE_Malloc(keyModLen, 0)", __FUNCTION__);
        ret = TZ_API_ERROR;
        goto EXIT;
    }
    TEE_MemFill(pPubKeyExp, 0, keyModLen);
    TEE_MemMove(pPubKeyExp + keyModLen - keyPubExpLen, keyPubExp, keyPubExpLen);

    TEE_InitRefAttribute(&attrs[0], TEE_ATTR_RSA_MODULUS, keyMod, keyModLen);
    TEE_InitRefAttribute(&attrs[1], TEE_ATTR_RSA_PUBLIC_EXPONENT, pPubKeyExp, keyModLen);
    TEE_InitRefAttribute(&attrs[2], TEE_ATTR_RSA_PRIVATE_EXPONENT, keyPriExp, keyPriExpLen);

    ret = TEE_AllocateTransientObject(TEE_TYPE_RSA_KEYPAIR, 8 * keyModLen, &keyHandle);
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_AllocateTransientObject failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    ret = TEE_PopulateTransientObject(keyHandle, attrs, sizeof(attrs) / sizeof(attrs[0]));
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_PopulateTransientObject failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    ret = TEE_AllocateOperation(&opHandle, TEE_ALG_RSAES_PKCS1_V1_5, TEE_MODE_DECRYPT, 8 * keyModLen);
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_AllocateOperation failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    ret = TEE_SetOperationKey(opHandle, keyHandle);
    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_SetOperationKey failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    /*
     * Always offer the TEE a keyModLen-byte destination. This keeps the earlier
     * workaround for a 0xffff0016 error from TEE_AsymmetricDecrypt, which
     * presumably rejects destinations shorter than the modulus
     * (NOTE(review): confirm).
     * If the caller's buffer is smaller, decrypt into a scratch buffer rather
     * than overstate the size of `out`, which could let the TEE write past it.
     */
    t_out_len = keyModLen;
    if (outCap >= keyModLen) {
        pDecOut = out;
    } else {
        pTmpOut = TEE_Malloc(keyModLen, 0);
        if (NULL == pTmpOut) {
            KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_Malloc failed for TEE_Malloc(keyModLen, 0)", __FUNCTION__);
            ret = TZ_API_ERROR;
            goto EXIT;
        }
        pDecOut = pTmpOut;
    }

    /* The key is already bound by TEE_SetOperationKey; no extra operation params. */
    ret = TEE_AsymmetricDecrypt(opHandle, NULL, 0, in, in_len, pDecOut, &t_out_len);
    *pOut_len = (uint32_t)t_out_len;

    if (TEE_SUCCESS != ret) {
        KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s TEE_AsymmetricDecrypt failed with ret=0x%x", __FUNCTION__, ret);
        goto EXIT;
    }

    if (pTmpOut != NULL) {
        if (t_out_len > outCap) {
            /* *pOut_len already holds the size that would have been needed. */
            KNOXAI_DEBUG_LOG("GLOBAL_PLAT: %s output buffer too small (%u < %u)", __FUNCTION__, outCap, (uint32_t)t_out_len);
            ret = TZ_API_ERROR;
            goto EXIT;
        }
        TEE_MemMove(out, pTmpOut, t_out_len);
    }

    ret = TZ_API_OK;

EXIT:
    if (pTmpOut != NULL) {
        /* The scratch buffer may hold decrypted data; wipe it before releasing. */
        TEE_MemFill(pTmpOut, 0, keyModLen);
        TEE_Free(pTmpOut);
    }
    if (pPubKeyExp != NULL) {
        TEE_Free(pPubKeyExp);
    }
    if (keyHandle != NULL) {
        TEE_FreeTransientObject(keyHandle);
    }
    if (opHandle != NULL) {
        TEE_FreeOperation(opHandle);
    }
    return ret;
}

