/*
 * app_main.c
 */
#include <stdio.h>
#include <stdint.h>

/* Application name string of this TA; presumably consumed by the TEE build/loader -- confirm. */
char TZ_APP_NAME[] = "tz_knoxai";
/* API level this trustlet targets; presumably read by the t-base runtime -- confirm. */
int  _TBASE_API_LEVEL = 7;
/* Exit code passed to tlApiExit() on fatal start-up errors: all 32 bits set. */
#define EXIT_ERROR ((uint32_t)UINT32_MAX)

/*
 * NOTE(review): this guard (and the matching one around the canary code
 * below) is #ifndef, i.e. the stack-protector support is compiled when
 * ENABLE_STACK_PROTECTION is NOT defined -- confirm the polarity is intended.
 */
#ifndef ENABLE_STACK_PROTECTION
/* Project macro; presumably declares the stack-protector symbols defined
 * later under the same condition -- confirm against its definition. */
DECLARE_STACK_PROTECTOR;
#endif

/* Trusted Application stack definition (65536 -- presumably bytes, i.e. 64 KiB; confirm against the macro) */
DECLARE_TRUSTLET_MAIN_STACK(65536)

/*
 * TA-private snapshots of the request and response tciMessage_t.
 * The entry points copy the shared-memory message into these before it is
 * processed, so the normal world cannot change the data while it is in use,
 * and zero them again after the response has been written back.
 */
static tciMessage_t sendMsgCopy;
static tciMessage_t rspMsgCopy;

#ifndef ENABLE_STACK_PROTECTION
  /*
   * Stack-protector canary consulted by compiler-generated function
   * prologue/epilogue checks. It starts at 0 and is seeded at runtime by
   * init_stack_canary().
   *
   * NOTE(review): GCC/Clang treat this symbol as pointer-sized (uintptr_t);
   * uint32_t is only the right width on a 32-bit target -- confirm for any
   * 64-bit build.
   */
  uint32_t __stack_chk_guard = 0;

  /*
   * Seed the stack-protector canary from the secure RNG.
   *
   * A failed or short RNG read is logged rather than silently ignored,
   * because in that case the canary may keep its predictable initial value.
   *
   * NOTE(review): a caller whose own frame is canary-instrumented and that
   * returns after this call would see a different canary in its epilogue
   * than in its prologue; call this only from frames that are not
   * instrumented or that never return.
   */
  static inline void init_stack_canary(void) {
      /* In/out length for the RNG call; local so no global symbol is exported. */
      size_t len = sizeof(__stack_chk_guard);
      /* 0 is TLAPI_OK (success) in the TlApi. */
      uint32_t rc = tlApiRandomGenerateData(TLAPI_ALG_SECURE_RANDOM,
                                            (uint8_t *)&__stack_chk_guard, &len);
      if ((rc != 0u) || (len != sizeof(__stack_chk_guard))) {
          KNOXAI_LOG("Stack canary initialization failed (0x%08x)", (unsigned)rc);
      }
  }

  /*
   * Invoked by compiler-generated code when a canary mismatch is detected:
   * logs the event and terminates via abort().
   */
  void __stack_chk_fail(void) {
      KNOXAI_LOG("Stack smashing detected");
      abort();
  }
#endif

/*
 * Check that the byte range [buffer, buffer + bufLen) lies entirely inside
 * [minAddr, maxAddr], where maxAddr is treated as one past the last usable
 * byte (so buffer + bufLen == maxAddr is accepted).
 *
 * Returns 1 when the range fits, 0 otherwise -- including when buffer is NULL
 * or minAddr > maxAddr.
 *
 * All comparisons are done on integer addresses (via uintptr_t): relational
 * comparison of pointers into unrelated objects is undefined behaviour in C,
 * and the end-of-range check is written so that it cannot wrap around.
 */
int isBufferInRange(const void *buffer, const uint32_t bufLen, const void *minAddr, const void *maxAddr) {
    const uint64_t buf64    = (uint64_t)(uintptr_t)buffer;
    const uint64_t bufLen64 = (uint64_t)bufLen;
    const uint64_t min64    = (uint64_t)(uintptr_t)minAddr;
    const uint64_t max64    = (uint64_t)(uintptr_t)maxAddr;

    if (buffer == NULL) {           /* null buffer is not allowed */
        return 0;
    }
    if (min64 > max64) {            /* the bounds themselves must be ordered */
        return 0;
    }
    if (buf64 < min64 || buf64 > max64) {
        return 0;
    }
    /* Equivalent to buf64 + bufLen64 > max64; buf64 <= max64 here, so no wrap. */
    if (bufLen64 > max64 - buf64) {
        return 0;
    }
    return 1;
}

/* ----------------------------------------------------------------------------
 *   Service Entry Points
 * ---------------------------------------------------------------------------- */
#ifdef GP_ENTRY
#include <tee_internal_api.h>
#include <tee_internal_api_ext.h>

/*
 * TA instance creation. Besides logging, the only setup performed is seeding
 * the stack-protector canary (when that support is compiled in).
 * Always reports TEE_SUCCESS.
 */
TEE_Result TA_EXPORT TA_CreateEntryPoint(void)
{
    TEE_Result result = TEE_SUCCESS;

    KNOXAI_LOG("TA_CreateEntry Point");

#ifndef ENABLE_STACK_PROTECTION
    /* Replace the zero-initialised canary with secure random data. */
    KNOXAI_LOG("Trustlet init_stack_canary");
    init_stack_canary();
#endif

    return result;
}

/* TA instance teardown: nothing was allocated at creation, so only log. */
void TA_EXPORT TA_DestroyEntryPoint(void)
{
    KNOXAI_LOG("TA_DestroyEntryPoint");
}
/*
 * Session open. This TA keeps no per-session state and takes no session
 * parameters, so every argument is ignored and the session is accepted
 * (TEE_SUCCESS).
 */
TEE_Result TA_EXPORT TA_OpenSessionEntryPoint(uint32_t nParamTypes,
                                              IN OUT TEE_Param pParams[4],
                                              OUT void** ppSessionContext)
{
    KNOXAI_LOG("TA_OpenSessionEntryPoint");

    /* Arguments are intentionally unused; silence compiler warnings. */
    S_VAR_NOT_USED(ppSessionContext);
    S_VAR_NOT_USED(pParams);
    S_VAR_NOT_USED(nParamTypes);

    return TEE_SUCCESS;
}
/*
 * Session close. TA_OpenSessionEntryPoint never sets a session context, so
 * there is nothing to release; the context pointer is ignored.
 */
void TA_EXPORT TA_CloseSessionEntryPoint(IN OUT void* pSessionContext)
{
    /* Mark as intentionally unused, matching TA_OpenSessionEntryPoint. */
    S_VAR_NOT_USED(pSessionContext);

    KNOXAI_LOG("TA_CloseSessionEntryPoint");
}
/*
 * GlobalPlatform command entry point.
 *
 * Parameter layout (anything else is rejected with TEE_ERROR_BAD_PARAMETERS):
 *   pParams[0]  MEMREF_INOUT   request tciMessage_t from the client
 *   pParams[1]  MEMREF_OUTPUT  receives the response tciMessage_t
 * Both buffers must be non-NULL, hold at least sizeof(tciMessage_t) bytes and
 * pass the read/write access-rights checks. Returns TEE_ERROR_SECURITY when
 * checkIccc() reports failure.
 *
 * The request is copied into TA-private memory (sendMsgCopy) before
 * process_cmd() sees it, so the client cannot modify it mid-processing.
 * The result of process_cmd() is reported in the response header
 * (id = RSP_ID(commandID), status = result). The function itself returns
 * TEE_SUCCESS once the response has been written to pOutput.
 *
 * NOTE(review): pParams[1].memref.size is not updated to the number of bytes
 * written; confirm the client does not rely on it.
 */
TEE_Result TA_EXPORT TA_InvokeCommandEntryPoint(IN OUT void* pSessionContext,
                                                uint32_t commandID,
                                                uint32_t paramTypes,
                                                TEE_Param pParams[4])
{
    TEE_Result ret;
    uint8_t* pInput;
    uint32_t nInputSize;
    uint8_t* pOutput;
    uint32_t nOutputSize;
    tciMessage_t* sendmsg = &sendMsgCopy;
    tciMessage_t* respmsg = &rspMsgCopy;

    /* No per-session state is used. */
    S_VAR_NOT_USED(pSessionContext);

    KNOXAI_LOG("TA_InvokeCommandEntryPoint");

    if (paramTypes != (uint32_t)TEE_PARAM_TYPES(TEE_PARAM_TYPE_MEMREF_INOUT, TEE_PARAM_TYPE_MEMREF_OUTPUT, TEE_PARAM_TYPE_NONE, TEE_PARAM_TYPE_NONE)) {
        KNOXAI_LOG("Bad Parameters in TA_InvokeCommandEntryPoint \n");
        return TEE_ERROR_BAD_PARAMETERS;
    }
    pInput      = pParams[0].memref.buffer;
    nInputSize  = pParams[0].memref.size;
    pOutput     = pParams[1].memref.buffer;
    nOutputSize = pParams[1].memref.size;

    KNOXAI_LOG("%s pInput     =0x%p", __func__, pInput);
    KNOXAI_LOG("%s nInputSize =%u",   __func__, (unsigned)nInputSize);
    KNOXAI_LOG("%s pOutput    =0x%p", __func__, pOutput);
    KNOXAI_LOG("%s nOutputSize=%u",   __func__, (unsigned)nOutputSize);

    if ( pInput == NULL || nInputSize < sizeof(tciMessage_t) ||
         pOutput == NULL || nOutputSize < sizeof(tciMessage_t)) {
        KNOXAI_LOG("Shared memory buffer size check error \n");
        return TEE_ERROR_BAD_PARAMETERS;
    }
    /* Client memory may be owned by another party: ANY_OWNER is required. */
    ret = TEE_CheckMemoryAccessRights(TEE_MEMORY_ACCESS_READ | TEE_MEMORY_ACCESS_ANY_OWNER, pInput, nInputSize);
    if (ret != TEE_SUCCESS) {
        KNOXAI_LOG("wrong input access rights!");
        return ret;
    }
    ret = TEE_CheckMemoryAccessRights(TEE_MEMORY_ACCESS_WRITE | TEE_MEMORY_ACCESS_ANY_OWNER, pOutput, nOutputSize);
    if (ret != TEE_SUCCESS) {
        KNOXAI_LOG("wrong output access rights!");
        return ret;
    }

    if ( checkIccc() < 1 ) {
        KNOXAI_LOG("TA_InvokeCommandEntryPoint security error");
        return TEE_ERROR_SECURITY;
    }

    /*
     * Local buffers to prevent race conditions: the response starts zeroed,
     * and the request is snapshotted in full (overwriting all of sendMsgCopy).
     */
    TEE_MemFill(&rspMsgCopy, 0, sizeof(tciMessage_t));
    TEE_MemMove(&sendMsgCopy, pInput, sizeof(tciMessage_t));

    ret = process_cmd(commandID, sendmsg, respmsg);

    respmsg->header.id = RSP_ID(commandID);
    respmsg->header.status = ret;

    TEE_MemMove(pOutput, respmsg, sizeof(tciMessage_t));

    /* Scrub the private copies so request/response data does not linger. */
    TEE_MemFill(&sendMsgCopy, 0, sizeof(tciMessage_t));
    TEE_MemFill(&rspMsgCopy, 0, sizeof(tciMessage_t));
    return TEE_SUCCESS;
}
#else
/* Trustlet entry. */
#include "taStd.h"
#include <TlApi/TlApiLogging.h>
/*
 * Trustlet entry point (non-GP build).
 *
 * Validates the TCI once, seeds the stack canary, checks ICCC, then serves
 * notifications forever. For each notification:
 *   1. snapshot the request tciMessage_t at the start of the TCI into
 *      sendMsgCopy (TA-private, so normal-world writes cannot race parsing);
 *   2. take request.header.len as the byte offset of the response
 *      tciMessage_t inside the TCI, bounds-check it, and snapshot that
 *      region into rspMsgCopy;
 *   3. if the request id is a command, run process_cmd() and write the
 *      response back to the validated offset;
 *   4. scrub both private copies and notify the normal-world client.
 * Malformed messages are logged and answered with a notification only.
 *
 * @param tciBuffer    TCI buffer shared with the normal-world client (TLC)
 * @param tciBufferLen size of tciBuffer in bytes
 */
_TLAPI_ENTRY void tlMain(const addr_t tciBuffer, const uint32_t tciBufferLen)
{
    uint32_t ret;
    uint32_t commandId;
    size_t rspOffset;   /* response offset inside the TCI, latched before validation */
    uint8_t *rspArea;   /* tciBuffer + rspOffset, only computed once validated */
    tciMessage_t *sendmsg = &sendMsgCopy;
    tciMessage_t *respmsg = &rspMsgCopy;

    /* The TCI must hold at least one message; otherwise end the Trustlet. */
    if ((NULL == tciBuffer) || (sizeof(tciMessage_t) > tciBufferLen)) {
        KNOXAI_LOG("tlMain buffer error");
        tlApiExit(EXIT_ERROR);
    }
    KNOXAI_DEBUG_LOG("tlMain loaded");
#ifndef ENABLE_STACK_PROTECTION
    KNOXAI_LOG("Trustlet init_stack_canary");
    init_stack_canary();
#endif
    if ( checkIccc() < 1 ) {
        KNOXAI_LOG("tlMain security error");
        tlApiExit(EXIT_ERROR);
    }

    /* The Trustlet main loop running infinitely */
    for (;;) {
        /* Block until notified (INFINITE timeout -> no polling). */
        tlApiWaitNotification(TLAPI_INFINITE_TIMEOUT);

        if (!isBufferInRange(tciBuffer, sizeof(tciMessage_t), tciBuffer, (uint8_t *) tciBuffer + tciBufferLen)) {
            KNOXAI_LOG("invalid tciBuffer or tciBufferLen!");
            goto scrub_and_notify;
        }
        TEE_MemMove(&sendMsgCopy, tciBuffer, sizeof(tciMessage_t));

        /*
         * Latch the response offset in a local before validating it.
         * process_cmd() receives a writable pointer to sendMsgCopy, so
         * re-reading header.len after the call could yield an unvalidated
         * offset for the final write-back into the TCI.
         */
        rspOffset = sendMsgCopy.header.len;
        if (rspOffset > tciBufferLen - sizeof(tciMessage_t)) {
            KNOXAI_LOG("invalid header len!");
            goto scrub_and_notify;
        }
        rspArea = (uint8_t *) tciBuffer + rspOffset;
        if (!isBufferInRange(rspArea, sizeof(tciMessage_t), rspArea, (uint8_t *) tciBuffer + tciBufferLen)) {
            KNOXAI_LOG("invalid tciBuffer or tciBufferLen! with sendMsgCopy.header.len");
            goto scrub_and_notify;
        }
        TEE_MemMove(&rspMsgCopy, rspArea, sizeof(tciMessage_t));

        /* Dereference the command id once for further use. */
        commandId = sendmsg->header.id;

        /* Not a command (the message is still a response): just tell the NWd a response is pending. */
        if (!IS_CMD(commandId)) {
            goto scrub_and_notify;
        }

        ret = process_cmd(commandId, sendmsg, respmsg);

        /* Set up response header -> mask response ID and set return code */
        respmsg->header.id = RSP_ID(commandId);
        respmsg->header.status = ret;

        /* Write back to the offset validated above, not a re-read of header.len. */
        TEE_MemMove(rspArea, respmsg, sizeof(tciMessage_t));

scrub_and_notify:
        /* Clear the private copies on every path so no message data lingers. */
        TEE_MemFill(&sendMsgCopy, 0, sizeof(tciMessage_t));
        TEE_MemFill(&rspMsgCopy, 0, sizeof(tciMessage_t));

        /* Notify back the TLC */
        tlApiNotify();
    }
}
#endif
