#include "five_app.h"
#include "handlers.h"
#include "version.h"

#include <tees_log.h>
#include <tci/tci.h>

#include <tee_internal_api.h>
#include <tees_kdf.h>
#ifdef TEEGRIS30
#include <tees_extension.h>
#endif

/* Size in bytes of the TA-private copy of an incoming tci_msg request. */
#define MSGBUF_SIZE 4096

/* MAC operation requested from sign_verify(). */
enum operation {
  OPERATION_SIGN = 0,   /* compute the HMAC and store it in msg->signature */
  OPERATION_VERIFY = 1, /* compare the HMAC against msg->signature */
};

/*
 * Compute (OPERATION_SIGN) or check (OPERATION_VERIFY) an HMAC over msg->hash.
 *
 * The HMAC key is not taken from the message: it is produced by
 * TEES_DeriveKeyKDF() from the optional label carried in the message. Its
 * length equals the digest length of the hash selected by msg->hash_type.
 *
 * msg:       request. msg->hash_type selects SHA-1 / SHA-256 / SHA-512. That
 *            choice fixes the number of bytes read from msg->hash (the MAC
 *            input) and the length of msg->signature (the MAC).
 *            msg->label / msg->label_len are passed to the KDF; a NULL label
 *            is passed when label_len is 0. For OPERATION_SIGN the MAC is
 *            written to msg->signature. For OPERATION_VERIFY it is read from
 *            there.
 * operation: OPERATION_SIGN or OPERATION_VERIFY.
 *
 * The caller must guarantee that msg->label holds label_len readable bytes.
 *
 * Returns:
 *  - TEE_SUCCESS on success;
 *  - TEE_ERROR_BAD_PARAMETERS for an unknown operation;
 *  - TEE_ERROR_NOT_SUPPORTED for an unknown hash type;
 *  - otherwise the error of the first failing TEE call. For OPERATION_VERIFY
 *    a MAC mismatch is reported by TEE_MACCompareFinal() (TEE_ERROR_MAC_INVALID
 *    per the GP specification).
 *
 * The return type is TEE_Result rather than int because GP error codes
 * (0xFFFFxxxx) are not representable in a signed int.
 */
static TEE_Result sign_verify(struct tci_msg *msg, enum operation operation) {
  const void *data = msg->hash;
  uint32_t data_len;
  void *signature = msg->signature;
  uint32_t signature_len;
  uint32_t key_len;
  uint16_t label_len = msg->label_len;
  void *label = label_len ? msg->label : NULL;
  TEE_OperationHandle h_operation = TEE_HANDLE_NULL;
  TEE_ObjectHandle h_object = TEE_HANDLE_NULL;
  uint32_t obj_type;
  uint32_t alg;
  TEE_Result rc;

  if (operation != OPERATION_VERIFY && operation != OPERATION_SIGN)
    return TEE_ERROR_BAD_PARAMETERS;

  MB_LOGD("Hash type: %u\n", msg->hash_type);

  /* Input length, MAC length and key type all follow the selected digest. */
  switch (msg->hash_type) {
    case HASH_SHA1: {
      data_len = HASH_SHA1_LEN;
      signature_len = HASH_SHA1_LEN;
      obj_type = TEE_TYPE_HMAC_SHA1;
      alg = TEE_ALG_HMAC_SHA1;
      break;
    }
    case HASH_SHA256: {
      data_len = HASH_SHA256_LEN;
      signature_len = HASH_SHA256_LEN;
      obj_type = TEE_TYPE_HMAC_SHA256;
      alg = TEE_ALG_HMAC_SHA256;
      break;
    }
    case HASH_SHA512: {
      data_len = HASH_SHA512_LEN;
      signature_len = HASH_SHA512_LEN;
      obj_type = TEE_TYPE_HMAC_SHA512;
      alg = TEE_ALG_HMAC_SHA512;
      break;
    }
    default: {
      MB_LOGE("Unknown hash type %u\n", msg->hash_type);
      return TEE_ERROR_NOT_SUPPORTED;
    }
  }

  /* The HMAC key is as long as the digest it produces. */
  key_len = signature_len;

  /* Key sizes are given to the GP API in bits. */
  rc = TEE_AllocateOperation(&h_operation, alg, TEE_MODE_MAC, key_len * 8);
  if (rc)
    return rc;

  rc = TEE_AllocateTransientObject(obj_type, key_len * 8, &h_object);
  if (rc)
    goto end;

  rc = TEES_DeriveKeyKDF(label, label_len, NULL, 0, key_len, h_object);
  if (rc)
    goto end;

  rc = TEE_SetOperationKey(h_operation, h_object);
  if (rc)
    goto end;

  MB_LOGD("Hash:\n");
  MB_LOGMEMD(data, data_len);
  if (label_len) {
    MB_LOGD("Label:\n");
    MB_LOGMEMD(label, label_len);
  }

  /* HMAC takes no IV, so none is passed to TEE_MACInit. */
  TEE_MACInit(h_operation, NULL, 0);
  if (operation == OPERATION_VERIFY) {
    MB_LOGD("Signature:\n");
    MB_LOGMEMD(signature, signature_len);

    rc = TEE_MACCompareFinal(h_operation, data, data_len, signature, signature_len);
    if (rc)
      goto end;
  } else {
    /* signature_len is in/out: buffer capacity in, produced MAC length out. */
    rc = TEE_MACComputeFinal(h_operation, data, data_len, signature, &signature_len);
    if (rc)
      goto end;

    MB_LOGD("Signature:\n");
    MB_LOGMEMD(signature, signature_len);
  }

  rc = TEE_SUCCESS;

end:
  /* Both calls accept TEE_HANDLE_NULL, so partial setup is released safely. */
  TEE_CloseObject(h_object);
  TEE_FreeOperation(h_operation);
  return rc;
}

/*
 * GP TA instance creation hook.
 *
 * Calls FiveRegisterKernelHandler() (presumably registers the kernel-side
 * command handler declared in handlers.h; confirm there) and returns its
 * result unchanged.
 */
TEE_Result TA_CreateEntryPoint(void) {
  TEE_Result res = FiveRegisterKernelHandler();

  /* TEE_Result is unsigned; print it in hex like the other rc logs. */
  MB_LOGD("%s = 0x%x\n", __func__, res);
  return res;
}

/*
 * GP TA instance destruction hook. Intentionally empty: this file releases
 * nothing here. NOTE(review): confirm whether the handler registered in
 * TA_CreateEntryPoint() needs unregistering.
 */
void TA_DestroyEntryPoint(void)
{
  return;
}

/*
 * GP session open hook. Admits exactly one session per TA instance.
 *
 * The first call prints the TA version and succeeds. Every later call is
 * refused with TEE_ERROR_ACCESS_DENIED. The open parameters and the session
 * context are ignored.
 *
 * NOTE(review): session_exists is never cleared (TA_CloseSessionEntryPoint
 * does not touch it). Once the first session is opened, no other session can
 * ever be opened for this TA instance, even after the first one closes.
 * Presumably intentional (first-client lock); confirm.
 */
TEE_Result TA_OpenSessionEntryPoint(uint32_t paramTypes, TEE_Param params[4],
                                    void **sessionContext) {
  static bool session_exists = false;

  (void) paramTypes;
  (void) params;
  (void) sessionContext;

  if (session_exists) {
    MB_LOGD("Session already exists\n");
    return TEE_ERROR_ACCESS_DENIED;
  }

  session_exists = true;
  print_version();
  return TEE_SUCCESS;
}

/*
 * GP session close hook. No per-session state is kept, so there is nothing
 * to release; the context pointer is ignored.
 */
void TA_CloseSessionEntryPoint(void *sessionContext)
{
  (void)sessionContext;
}

/*
 * GP command entry point. Every command is forwarded unchanged to
 * FiveEntryPointHandler(), whose result is returned as-is. The session
 * context is not used.
 */
TEE_Result TA_InvokeCommandEntryPoint(void *sessionContext, uint32_t commandID,
                                      uint32_t paramTypes, TEE_Param params[4]) {
  TEE_Result result;

  (void) sessionContext;
  result = FiveEntryPointHandler(commandID, paramTypes, params);
  return result;
}

TEE_Result FiveCommandHandler(uint32_t commandID,
                    uint32_t paramTypes, TEE_Param params[4]) {
  uint8_t msg_buf[MSGBUF_SIZE] = {};
  const size_t msgbuf_size = sizeof(msg_buf);
  struct tci_msg *msg = (struct tci_msg *)msg_buf;
  TEE_Result rc = TEE_ERROR_BAD_PARAMETERS;

  if (TEE_PARAM_TYPE_GET(paramTypes, 0) == TEE_PARAM_TYPE_MEMREF_INPUT
                                       && commandID == COMMAND_VERIFY) {
    size_t in_size = params[0].memref.size;
    rc = TEE_ERROR_OVERFLOW;

    if (!params[0].memref.buffer) {
      MB_LOGE("Buffer is NULL\n");
      return TEE_ERROR_BAD_PARAMETERS;
    }

#ifdef TEEGRIS30
    if (TEES_IsREESharedMemory(
        TEE_MEMORY_ACCESS_READ,
        params[0].memref.buffer, params[0].memref.size) != TEE_SUCCESS) {
      MB_LOGE("Memory access check error\n");
      return TEE_ERROR_ACCESS_DENIED;
    }
#endif

    MB_LOGD("START COMMAND_VERIFY\n");

    if (in_size <= msgbuf_size) {
      TEE_MemMove(msg_buf, params[0].memref.buffer, in_size);

      if (msgbuf_size - sizeof(*msg) > msg->label_len) {
        rc = sign_verify(msg, OPERATION_VERIFY);
      }
    }
    MB_LOGD("END COMMAND_VERIFY rc=0x%x\n", rc);
  } else if (TEE_PARAM_TYPE_GET(paramTypes, 0) == TEE_PARAM_TYPE_MEMREF_INOUT
                                                && commandID == COMMAND_SIGN) {
    size_t in_size = params[0].memref.size;
    rc = TEE_ERROR_OVERFLOW;

    if (!params[0].memref.buffer) {
      MB_LOGE("Buffer is NULL\n");
      return TEE_ERROR_BAD_PARAMETERS;
    }

#ifdef TEEGRIS30
    if (TEES_IsREESharedMemory(
        TEE_MEMORY_ACCESS_READ | TEE_MEMORY_ACCESS_WRITE,
        params[0].memref.buffer, params[0].memref.size) != TEE_SUCCESS) {
      MB_LOGE("Memory access check error\n");
      return TEE_ERROR_ACCESS_DENIED;
    }
#endif

    MB_LOGD("START COMMAND_SIGN\n");

    if (in_size <= msgbuf_size) {
      TEE_MemMove(msg_buf, params[0].memref.buffer, in_size);
      if (msgbuf_size - sizeof(*msg) > msg->label_len) {
        rc = sign_verify(msg, OPERATION_SIGN);
        if (rc == TEE_SUCCESS)
          TEE_MemMove(params[0].memref.buffer, msg_buf, in_size);
      }
    }

    MB_LOGD("END COMMAND_SIGN rc=0x%x\n", rc);
  } else {
    MB_LOGE("Unknown command: cmd=0x%x type=%u\n",
            commandID, TEE_PARAM_TYPE_GET(paramTypes, 0));
  }

  return rc;
}
