#include "authentication.h"

#include "config.h"
#include "driver_log.h"
#include "memory.h"
#include "task.h"
#include "task_parser.h"
#include "tlb.h"
#include "crypto.h"
#include "pa_certificate.h"
#include <scl/string.h>

#include <tee_internal_api.h>

/*
 * Upper bound on the number of 1 ms wait-and-reparse rounds that
 * CheckIntegrityValue() performs while the task's FIVE integrity status is
 * still pending.
 */
static const uint16_t kGetFiveStatusAttempts = 100;

/**
 * @brief Check that the application name from the task's certificate appears
 *        in a list of allowed process names
 * @param [in] task Task info; task->certificate supplies the application name
 * @param [in] names Buffer of consecutive NUL-terminated application names
 * @param [in] names_size Total size of the names buffer in bytes, including
 *             the terminating NUL of every name
 * @return ::PA_TZ_SUCCESS if a name matches, ::PA_TZ_AUTHENTICATION_FAILED if
 *         none does, ::PA_TZ_GENERAL_ERROR if the list cannot be parsed
 */
static PaTzResult CheckProcessNames(const TaskInfo *task, const char *names,
                                    const size_t names_size);

/**
 * @brief Check that the task owns the memory range described by @p memory
 * @param [in] task Task info
 * @param [in] memory Memory range description (type, address and size)
 * @return ::PA_TZ_SUCCESS, ::PA_TZ_NON_SUPPORTED when the platform cannot
 *         translate the address, ::PA_TZ_GENERAL_ERROR, or an error code
 *         propagated from the address-translation / VMA lookup helpers
 */
static PaTzResult CheckProcessMemory(const TaskInfo *task, const PaTzMemoryRange *memory);

/**
 * @brief Check that every page of a range of task virtual addresses can be
 *        translated to a physical address through the task's page tables
 * @param [in] task Task info
 * @param [in] address,size User space task buffer (a zero size is accepted
 *             without any translation)
 * @return ::PA_TZ_SUCCESS, or the error returned by TaskAddressToPhysical()
 *         for the first page that fails to translate
 */
static PaTzResult IsTaskMemoryRangeValid(const TaskInfo *task, ProcessAddress address, size_t size);

/**
 * @brief Check whether the task's FIVE integrity level is one of the
 *        "...WithSign" levels
 * @param [in] task Task info (dereferenced without a NULL check)
 * @return ::PA_TZ_SUCCESS for kIntegrityPreloadWithSign,
 *         kIntegrityMixedWithSign and kIntegrityDmVerityWithSign,
 *         ::PA_TZ_GENERAL_ERROR for every other integrity value
 */
PaTzResult CheckIntegritySigningRights(const TaskInfo *task) {
  const int has_signature = (task->integrity == kIntegrityPreloadWithSign) ||
                            (task->integrity == kIntegrityMixedWithSign) ||
                            (task->integrity == kIntegrityDmVerityWithSign);

  return has_signature ? PA_TZ_SUCCESS : PA_TZ_GENERAL_ERROR;
}

/**
 * @brief Check whether the task's FIVE integrity level is one of the
 *        "mixed" levels
 * @param [in] task Task info (dereferenced without a NULL check)
 * @return ::PA_TZ_SUCCESS for kIntegrityMixed and kIntegrityMixedWithSign,
 *         ::PA_TZ_GENERAL_ERROR for every other integrity value
 */
PaTzResult CheckIntegrityWeak(const TaskInfo *task) {
  if (task->integrity != kIntegrityMixed &&
      task->integrity != kIntegrityMixedWithSign) {
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/**
 * @brief Check the FIVE integrity status of a task, polling while it is pending
 *
 * While the status is kIntegrityPending, the function waits 1 ms and re-reads
 * the integrity through TaskParseIntegrity(), at most kGetFiveStatusAttempts
 * times. The value produced by every re-read is inspected, including the last
 * one.
 *
 * @param [in] task Task info (dereferenced without a NULL check); its
 *        integrity information may be refreshed in place while pending
 * @return ::PA_TZ_SUCCESS if FIVE authenticated the task,
 *         ::PA_TZ_AF_INTEGRITY_IS_NONE if it is not authenticated or the
 *         integrity flag is unknown,
 *         ::PA_TZ_AF_TASK_IS_NOT_FOUND if the integrity info could not be
 *         re-read,
 *         ::PA_TZ_AF_INTEGRITY_IS_NOT_READY if still pending after all attempts
 */
PaTzResult CheckIntegrityValue(const TaskInfo *task) {
  uint16_t attempts_done = 0;

  for (;;) {
    switch (task->integrity) {
      case kIntegrityNone: {
        LOG_E("Task is not authenticated by FIVE.\n");
        LOG_D("Task pid: %d.\n", task->pid);
        return PA_TZ_AF_INTEGRITY_IS_NONE;
      }
      case kIntegrityPreload:
      case kIntegrityPreloadWithSign:
      case kIntegrityMixed:
      case kIntegrityMixedWithSign:
      case kIntegrityDmVerity:
      case kIntegrityDmVerityWithSign: {
        LOG_V("Task %d is authenticated (%x) by FIVE.\n",
            task->pid, task->integrity);
        return PA_TZ_SUCCESS;
      }
      case kIntegrityPending: {
        // Give up only once every re-read has been inspected; the previous
        // loop shape re-parsed on the last attempt and then ignored the result.
        if (attempts_done >= kGetFiveStatusAttempts) {
          LOG_E("Task integrity is still pending.\n");
          LOG_D("Task pid: %d.\n", task->pid);
          return PA_TZ_AF_INTEGRITY_IS_NOT_READY;
        }
        attempts_done++;

        TEE_Wait(1);

        // NOTE(review): const is cast away because TaskParseIntegrity()
        // presumably refreshes task->integrity in place (the loop re-reads it).
        // This is only well-defined if the caller's TaskInfo is not a truly
        // const object — confirm at call sites.
        PaTzResult res = TaskParseIntegrity((TaskInfo *)task);
        if (PA_TZ_SUCCESS != res) {
          LOG_E("Could not parse extra info.\n");
          LOG_D("Task pid: %d.\n", task->pid);
          return PA_TZ_AF_TASK_IS_NOT_FOUND;
        }
        break;  // re-evaluate the refreshed integrity value
      }
      default: {
        LOG_E("Unknown integrity flag for task!\n");
        LOG_D("Task pid: %d.\n", task->pid);
        return PA_TZ_AF_INTEGRITY_IS_NONE;
      }
    }
  }
}

/**
 * @brief Authenticate a task against its certificate and optional rules
 *
 * Checks, in order: the task certificate, the FIVE integrity status (skipped
 * on failure when SKIP_FIVE_AUTH is defined), the process name list (only when
 * @p process_names is non-NULL and @p process_names_size is non-zero), and
 * memory ownership (only when @p memory is non-NULL).
 *
 * @param [in] task Task info
 * @param [in] process_names Buffer of NUL-terminated allowed application names
 * @param [in] process_names_size Total size of @p process_names in bytes
 * @param [in] memory Memory range the task must own, or NULL
 * @return ::PA_TZ_SUCCESS when every applicable rule passes,
 *         ::PA_TZ_GENERAL_ERROR for a NULL task,
 *         ::PA_TZ_AF_CERTIFICATE_IS_INCORRECT, the CheckIntegrityValue() error,
 *         ::PA_TZ_AF_APPNAME_IS_INCORRECT or ::PA_TZ_AUTHENTICATION_FAILED
 *         for the first failing rule
 */
PaTzResult ProcessAuthentication(const TaskInfo *task, const char *process_names,
                                 const size_t process_names_size,
                                 const PaTzMemoryRange *memory) {
  if (!task) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (PA_TZ_SUCCESS != PaCertificateValidate(task->certificate)) {
    LOG_E("PaSignature is invalid.\n");
    return PA_TZ_AF_CERTIFICATE_IS_INCORRECT;
  }

  PaTzResult result = CheckIntegrityValue(task);
  if (PA_TZ_SUCCESS != result) {
#ifndef SKIP_FIVE_AUTH
    LOG_E("Task integrity check failed.\n");
    return result;
#else
    LOG_D("Five integrity check is disabled.\n");
#endif
  }

  // Use process name as one of authentication rule
  if (process_names && process_names_size) {
    if (PA_TZ_SUCCESS != CheckProcessNames(task, process_names, process_names_size)) {
      LOG_E("Task has incorrect process name.\n");
      LOG_D("Task pid: %d.\n", task->pid);
      return PA_TZ_AF_APPNAME_IS_INCORRECT;
    }
    LOG_D("Task %d has correct process name.\n", task->pid);
  }

  // Check memory owning as one of authentication rule
  if (memory) {
    result = CheckProcessMemory(task, memory);
    if (PA_TZ_SUCCESS == result) {
      LOG_D("Task %d owns provided memory.\n", task->pid);
    } else if (PA_TZ_NON_SUPPORTED == result) {
      // Deliberate: a non-supported ownership check is logged and does not
      // fail authentication. NOTE(review): this weakens the memory rule on
      // such platforms — confirm this is the intended policy.
      LOG_I("Memory ownership check is not supported on this platform.\n");
    } else {
      LOG_E("Task does not own provided memory.\n");
      LOG_D("Task pid: %d.\n", task->pid);
      return PA_TZ_AUTHENTICATION_FAILED;
    }
  }

  return PA_TZ_SUCCESS;
}

static PaTzResult CheckProcessNames(const TaskInfo *task, const char *name,
                                    const size_t names_size) {
  size_t size = 0;
  size_t remained_names_size = names_size;
  while (remained_names_size) {
    size_t current_name_size = 0;
    if (!scl_strlen(name + size, kPaAppNameMaxLength, (scl_size_t *)&current_name_size)) {
      return PA_TZ_GENERAL_ERROR;
    }

    if ((current_name_size > 0) &&
        TEE_MemCompare(name + size, task->certificate->paData.paAppName.buf, current_name_size) == 0) {
      return PA_TZ_SUCCESS;
    }

    size += current_name_size + 1;
    remained_names_size -= (current_name_size + 1);
  }

  LOG_E("There is no such name in list.\n");
  return PA_TZ_AUTHENTICATION_FAILED;
}

/**
 * @brief Check that the task owns the memory range described by @p memory
 *
 * - PA_MEMORY_SWD_VA: @p memory->addr is a trustlet virtual address. It is
 *   mapped, translated to a physical address, unmapped again, and the task is
 *   searched for a VMA backed by that physical memory.
 * - PA_MEMORY_NWD_VA: @p memory->addr is a task virtual address, validated
 *   page by page with IsTaskMemoryRangeValid().
 * - PA_MEMORY_PHYS: @p memory->addr is a physical address; the task is
 *   searched for a VMA backed by it.
 *
 * @param [in] task Task info
 * @param [in] memory Memory range description
 * @return ::PA_TZ_SUCCESS, ::PA_TZ_NON_SUPPORTED if physical translation is not
 *         available, ::PA_TZ_GENERAL_ERROR, or an error propagated from the
 *         translation / VMA lookup helpers
 */
static PaTzResult CheckProcessMemory(const TaskInfo *task, const PaTzMemoryRange *memory) {
  if (!task || !memory) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  PaTzResult result = PA_TZ_SUCCESS;

  switch (memory->type) {
    case PA_MEMORY_SWD_VA: {
      PhysicalAddress phys_address = 0;
      void *translated_address = NULL;

      result = PlatformSysMapTrustlet((void *)(uintptr_t)memory->addr, memory->size,
                                      kMemoryAccessRead, &translated_address);
      if (PA_TZ_SUCCESS != result) {
        LOG_E("Translation of Trustlet memory return error.\n");
        // addr may be wider than unsigned int; print it as two 32-bit halves so
        // the arguments always match the %x conversions.
        LOG_D("Result: 0x%x, trustlet memory addr 0x%x%08x.\n", result,
              (unsigned int)((uint64_t)memory->addr >> 32),
              (unsigned int)((uint64_t)memory->addr & 0xFFFFFFFFu));
        result = PA_TZ_GENERAL_ERROR;
        break;
      }

      result = PlatformVirtToPhys64(translated_address, &phys_address);

      // Unmap trustlet address in any case
      PlatformSysUnmapTrustlet(translated_address, memory->size);

      if (PA_TZ_NON_SUPPORTED == result) {
        break;
      } else if (PA_TZ_SUCCESS != result) {
        LOG_E("Can not get physical address.\n");
        result = PA_TZ_GENERAL_ERROR;
        break;
      }

      // NOTE(review): only the physical address of the start of the mapping is
      // passed on together with the full size; this presumably assumes the
      // range is physically contiguous — confirm for multi-page ranges.
      VmaInfo vma = {0};
      result = TaskFindVmaWithPhysicalMemory(task, phys_address, memory->size, &vma);
      break;
    }
    case PA_MEMORY_NWD_VA: {
      result = IsTaskMemoryRangeValid(task, memory->addr, memory->size);
      break;
    }
    case PA_MEMORY_PHYS: {
      VmaInfo vma = {0};
      result = TaskFindVmaWithPhysicalMemory(task, memory->addr, memory->size, &vma);
      break;
    }
    default: {
      LOG_E("Incorrect memory type.\n");
      LOG_D("Memory type 0x%x.\n", memory->type);
      result = PA_TZ_GENERAL_ERROR;
      break;
    }
  }

  return result;
}

/**
 * @brief Check that every page of a task virtual address range translates to
 *        a physical address through the task's page tables (task->pgd)
 *
 * The range is walked page by page. A zero @p size is accepted without any
 * translation.
 *
 * @param [in] task Task info
 * @param [in] address,size User space task buffer
 * @return ::PA_TZ_SUCCESS, or the error returned by TaskAddressToPhysical()
 *         for the first page that fails to translate
 */
static PaTzResult IsTaskMemoryRangeValid(const TaskInfo *task, ProcessAddress address, size_t size) {
  PaTzResult result = PA_TZ_SUCCESS;
  size_t remaining_size = size;
  TlbConverterInfo converter_context = {0};

  while (remaining_size) {
    const uint32_t offset = PageOffset(address);
    // Bytes from address to the end of its page. Comparing remaining_size with
    // this instead of computing "offset + remaining_size" avoids an unsigned
    // wrap for a size near SIZE_MAX. With the wrap, the loop would translate
    // only the first page and then report the whole range as valid.
    const size_t page_room = (size_t)PAGE_SIZE - offset;
    const size_t current_size = (remaining_size > page_room) ? page_room : remaining_size;
    PhysicalAddress phys = 0;

    result = TaskAddressToPhysical(&converter_context, task->pgd, address, &phys, NULL);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Translation of user process memory return error.\n");
      LOG_D("\tResult: 0x%x.\n", result);
      LOGADDR_D("\tUser process memory: ", address);
      break;
    }

    address += current_size;
    remaining_size -= current_size;
  }

  return result;
}
