#include "task_access.h"

#include "access_control.h"
#include "authentication.h"
#include "driver_log.h"
#include "tlb.h"

#include <tee_internal_api.h>

// Copies `size` bytes from the address space of `task`, starting at the
// process virtual address `address`, into the buffer `out`.
//
// The transfer is split at page boundaries. For every chunk the virtual
// address is translated with TaskAddressToPhysical() against task->pgd, the
// page must report read permission, and the physical range is mapped with
// kMemoryAccessRead, copied with TEE_MemMove() and unmapped again.
//
// Returns:
//   PA_TZ_SUCCESS             all `size` bytes were copied;
//   PA_TZ_VIRT_TO_PHYS_ERROR  a page could not be translated;
//   PA_TZ_MEMORY_PERM_ERROR   a page is not readable;
//   otherwise the error returned by PlatformSysMap()/PlatformSysUnmap().
// On failure `out` may have been partially written. `task` is only
// dereferenced when `size` is non-zero.
PaTzResult TaskAccessGetBytes(const TaskInfo *task,
                              ProcessAddress address, size_t size, void *out) {
  PaTzResult result = PA_TZ_SUCCESS;
  PhysicalAddress phys = 0;
  uintptr_t out_data = (uintptr_t)out;
  size_t remaining_size = size;
  // A single converter context is shared by every translation of this copy.
  TlbConverterInfo converter_context = {0};

  while (remaining_size) {
    uint32_t offset = PageOffset(address);
    size_t current_size = 0;

    // Split whole data buffer by page boundaries. Compare against the room
    // left in the current page rather than computing offset + remaining_size:
    // that sum can wrap for a huge size, which would produce a chunk spanning
    // many pages while only the first page's permissions were checked.
    // Relies on PageOffset() returning a value below PAGE_SIZE.
    if (remaining_size > (size_t)(PAGE_SIZE - offset)) {
      current_size = PAGE_SIZE - offset;
    } else {
      current_size = remaining_size;
    }

    LOGADDR_V("Address is: ", address);
    // current_size never exceeds PAGE_SIZE, so narrowing it for %x is lossless.
    LOG_V("\tLength is 0x%x.\n", (unsigned int)current_size);

    AccessPermissionFlags flags = {0};
    result = TaskAddressToPhysical(&converter_context, task->pgd, address, &phys, &flags);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Can not get physical from task virtual address.\n");
      result = PA_TZ_VIRT_TO_PHYS_ERROR;
      break;
    }

    if (flags.read == 0) {
      LOG_E("Memory should be readable!\n");
      result = PA_TZ_MEMORY_PERM_ERROR;
      break;
    }

    LOGADDR_V("Task phys is: ", phys);

    void *virt = NULL;
    result = PlatformSysMap(phys, current_size, kMemoryAccessRead, &virt);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("Can not map user memory.\n");
      LOG_D("Received result: 0x%x\n", result);
      break;
    }

    TEE_MemMove((void *)out_data, virt, current_size);

    result = PlatformSysUnmap(virt, current_size);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("PlatformSysUnmap failed.\n");
      LOG_D("Received result: 0x%x\n", result);
      // Stop here. Continuing would let the next iteration overwrite `result`,
      // silently dropping the failure unless it happened on the last chunk.
      break;
    }

    out_data += current_size;
    address += current_size;
    remaining_size -= current_size;
  }

  return result;
}

// Copies `size` bytes from the buffer `in` into the address space of `task`,
// starting at the process virtual address `address`.
//
// The transfer is split at page boundaries. For every chunk the virtual
// address is translated with TaskAddressToPhysical() against task->pgd, the
// page must be writable and must not be executable, and the physical range
// is mapped with kMemoryAccessWrite, filled with TEE_MemMove() and unmapped.
//
// Returns:
//   PA_TZ_SUCCESS             all `size` bytes were written;
//   PA_TZ_VIRT_TO_PHYS_ERROR  a page could not be translated;
//   PA_TZ_MEMORY_PERM_ERROR   a page is not writable or is executable;
//   otherwise the error returned by PlatformSysMap()/PlatformSysUnmap().
// On failure the task memory may have been partially written. `task` is only
// dereferenced when `size` is non-zero.
PaTzResult TaskAccessPutBytes(const TaskInfo *task,
                              const void *in, size_t size, ProcessAddress address) {
  PaTzResult result = PA_TZ_SUCCESS;
  PhysicalAddress phys = 0;
  uintptr_t in_data = (uintptr_t)in;
  size_t remaining_size = size;
  // A single converter context is shared by every translation of this copy.
  TlbConverterInfo converter_context = {0};

  while (remaining_size) {
    uint32_t offset = PageOffset(address);
    size_t current_size = 0;

    // Split whole data buffer by page boundaries. Compare against the room
    // left in the current page rather than computing offset + remaining_size:
    // that sum can wrap for a huge size, which would produce a chunk spanning
    // many pages while only the first page's permissions were checked.
    // Relies on PageOffset() returning a value below PAGE_SIZE.
    if (remaining_size > (size_t)(PAGE_SIZE - offset)) {
      current_size = PAGE_SIZE - offset;
    } else {
      current_size = remaining_size;
    }

    LOGADDR_V("Address is: ", address);
    // current_size never exceeds PAGE_SIZE, so narrowing it for %x is lossless.
    LOG_V("\tLength is 0x%x.\n", (unsigned int)current_size);

    AccessPermissionFlags flags = {0};
    result = TaskAddressToPhysical(&converter_context, task->pgd, address, &phys, &flags);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("Can not get physical from task virtual address.\n");
      result = PA_TZ_VIRT_TO_PHYS_ERROR;
      break;
    }

    // Test exec against zero rather than 1: if `exec` is a 1-bit plain-int
    // bit-field it may be signed and read back as -1 when set, and `== 1`
    // would then let writes into executable pages through.
    if (flags.exec != 0 || flags.write == 0) {
      LOG_E("Memory should be writable and not executable!\n");
      result = PA_TZ_MEMORY_PERM_ERROR;
      break;
    }

    LOGADDR_V("Task phys is: ", phys);

    void *virt = NULL;
    result = PlatformSysMap(phys, current_size, kMemoryAccessWrite, &virt);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("Can not map user memory.\n");
      LOG_D("Received result: 0x%x\n", result);
      break;
    }

    TEE_MemMove(virt, (void *)in_data, current_size);

    result = PlatformSysUnmap(virt, current_size);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("PlatformSysUnmap failed.\n");
      LOG_D("Received result: 0x%x\n", result);
      // Stop here. Continuing would let the next iteration overwrite `result`,
      // silently dropping the failure unless it happened on the last chunk.
      break;
    }

    in_data += current_size;
    address += current_size;
    remaining_size -= current_size;
  }

  return result;
}

// Reports whether `address` falls inside a readable VMA of `task`.
//
// Returns 1 when the VMA found for `address` has VM_READ set, and 0 when it
// does not. Returns -1 when `task` is NULL or when
// TaskFindVmaWithProcessAddress() does not succeed.
int32_t IsProcessAddressReadable(const TaskInfo *task, ProcessAddress address) {
  VmaInfo region = {0};

  if (task == NULL) {
    LOG_E("Invalid arguments.\n");
    return -1;
  }

  if (TaskFindVmaWithProcessAddress(task, address, &region) != PA_TZ_SUCCESS) {
    LOG_E("Can not find VMA with needed address.\n");
    return -1;
  }

  // The comparison yields exactly 0 or 1.
  return (region.vm_flags & VM_READ) != 0;
}

// Reports whether `address` falls inside a writable VMA of `task`.
//
// Returns 1 when the VMA found for `address` has VM_WRITE set, and 0 when it
// does not. Returns -1 when `task` is NULL or when
// TaskFindVmaWithProcessAddress() does not succeed.
int32_t IsProcessAddressWritable(const TaskInfo *task, ProcessAddress address) {
  VmaInfo region = {0};

  if (task == NULL) {
    LOG_E("Invalid arguments.\n");
    return -1;
  }

  if (TaskFindVmaWithProcessAddress(task, address, &region) != PA_TZ_SUCCESS) {
    LOG_E("Can not find VMA with needed address.\n");
    return -1;
  }

  // The comparison yields exactly 0 or 1.
  return (region.vm_flags & VM_WRITE) != 0;
}
