#include "task.h"

#include "config.h"
#include "crypto.h"
#include "driver_log.h"
#include "gaf.h"
#include "kernel_access.h"
#include "task_parser.h"
#include "task_descriptor.h"
#include "tlb.h"

#include <tee_internal_api.h>
#include <scl/string.h>

/**
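 * @brief Comparator for TaskFindVmaRbTree that locates the VMA containing a
 *        process virtual address
 * @param [in] task Pointer to task structure
 * @param [in] vma Pointer to VMA structure
 * @param [in] param Pointer to process address
 * @return ::kEqual if the address lies in [vm_start, vm_end), ::kBigger if the
 *         VMA lies above the address, ::kSmaller if it lies below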
 */
static FindResult ProcessAddressComparator(const TaskInfo *task,
                                           const VmaInfo *vma, const void* param);

/**
 * @brief Create task info structure
 * @param [in] task_addr Address of task_struct in kernel space
 * @param [in] pa_identity PROCA identity
 * @param [out] out_task Pointer to task structure which was parsed
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_AF_CERTIFICATE_IS_ABSENT if the
 *         task certificate cannot be parsed, ::PA_TZ_GENERAL_ERROR otherwise
 */
static PaTzResult CreateTaskInfo(const KernelAddress task_addr,
                                 const PaIdentity *pa_identity,
                                 TaskInfo *out_task);

/**
 * @brief Release the resources owned by a task structure
 *
 * Only the parsed certificate is destroyed; the TaskInfo storage itself is
 * not freed. A NULL argument is reported through LOG_E and otherwise ignored.
 *
 * NOTE(review): task->certificate is not reset after PaCertificateDestroy, so
 * calling this twice on the same TaskInfo hands the same certificate to the
 * destroy routine twice -- confirm callers never do that.
 *
 * @param [in,out] task Pointer to task structure to release
 */
void FreeTaskInfo(TaskInfo *task) {
  if (task == NULL) {
    LOG_E("Pointer of memory is NULL\n");
  } else {
    PaCertificateDestroy(task->certificate);
  }
}

/**
 * @brief Build a TaskInfo for the task_struct located at @p task_addr
 *
 * Parses the basic task fields and integrity data, attaches a copy of
 * @p pa_identity and, if the identity carries a certificate, parses it into
 * task.certificate. On failure all partially parsed state is released through
 * FreeTaskInfo and @p out_task is left untouched.
 *
 * @param [in] task_addr Address of task_struct in kernel space
 * @param [in] pa_identity PROCA identity (copied into the result)
 * @param [out] out_task Receives the parsed task on success
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_AF_CERTIFICATE_IS_ABSENT if the
 *         certificate cannot be parsed, ::PA_TZ_GENERAL_ERROR on invalid
 *         arguments or any other parse failure
 */
static PaTzResult CreateTaskInfo(const KernelAddress task_addr,
                                 const PaIdentity *pa_identity,
                                 TaskInfo *out_task) {
  /* Public callers forward out_task unchecked; reject NULL here, consistent
   * with the argument checks performed elsewhere in this file. */
  if (!pa_identity || !out_task) {
    LOG_E("Invalid argument\n");
    return PA_TZ_GENERAL_ERROR;
  }

  TaskInfo task = {0};
  PaTzResult status = PA_TZ_GENERAL_ERROR;

  LOGADDR_V("task_addr : ", task_addr);

  if (TaskParseBasicInfo(task_addr, &task) != PA_TZ_SUCCESS) {
    LOG_E("Cannot get task info.\n");
    goto fail;
  }

  if (TaskParseIntegrity(&task) != PA_TZ_SUCCESS) {
    LOG_E("Cannot get task extra.\n");
    goto fail;
  }

  task.pa_identity = *pa_identity;

  LOGADDR_V("\ttask.pa_identity.file: ",      task.pa_identity.file);
  LOGADDR_V("\ttask.pa_identity.certificate: ", task.pa_identity.certificate);
  LOG_V("\ttask.pa_identity.certificate_size: %u.\n",
              task.pa_identity.certificate_size);
  LOGADDR_V("\ttask.pa_identity.parsed_cert.app_name: ",
              task.pa_identity.parsed_cert.app_name);
  LOG_V("\ttask.pa_identity.parsed_cert.app_name_size: %u\n",
              task.pa_identity.parsed_cert.app_name_size);

  /* The certificate is optional; only parse it when the identity has one. */
  if (task.pa_identity.certificate &&
      TaskParsePaCertificate(&task, &task.certificate) != PA_TZ_SUCCESS) {
    LOG_E("Cannot parse task certificate.\n");
    status = PA_TZ_AF_CERTIFICATE_IS_ABSENT;
    goto fail;
  }

  /* Ownership of task.certificate moves to the caller with this copy. */
  *out_task = task;
  return PA_TZ_SUCCESS;

fail:
  FreeTaskInfo(&task);
  return status;
}

/**
 * @brief Locate a task by application name and build its TaskInfo
 *
 * The descriptor lookup is delegated to TaskDescriptorFindByAppName; on
 * success the task is parsed with CreateTaskInfo, otherwise the lookup error
 * is returned unchanged after logging the requested name at debug level.
 *
 * @param [in] app_name_str Application name bytes
 * @param [in] app_name_len Length of @p app_name_str
 * @param [out] out_task Receives the parsed task
 * @return Result of CreateTaskInfo, or the lookup error
 */
PaTzResult TaskFindByAppName(const char *app_name_str,
                             const size_t app_name_len,
                             TaskInfo *out_task) {
  TaskDescriptor descriptor;
  PaTzResult status;

  LOG_V("Attempting to find task by app name.\n");

  status = TaskDescriptorFindByAppName(app_name_str, app_name_len, &descriptor);
  if (status == PA_TZ_SUCCESS) {
    return CreateTaskInfo(descriptor.task_address, &descriptor.pa_identity,
                          out_task);
  }

  LOG_E("Failed to find task by application name.\n");
  LOG_D("Application name is: \n");
  LOGM_D(app_name_str, app_name_len);
  return status;
}

/**
 * @brief Locate a task by PID and build its TaskInfo
 *
 * @param [in] pid Process identifier
 * @param [out] out_task Receives the parsed task
 * @return Result of CreateTaskInfo, or the error from TaskDescriptorFindByPid
 */
PaTzResult TaskFindByPid(uint32_t pid, TaskInfo *out_task) {
  /* pid is unsigned: print with %u so values above INT_MAX are not shown as
   * negative (and the conversion matches the argument type). */
  LOG_V("Attempting to find task by PID %u.\n", (unsigned int)pid);

  TaskDescriptor descr;
  PaTzResult result = TaskDescriptorFindByPid(pid, &descr);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Failed to find task by PID.\n");
    LOG_D("PID %u.\n", (unsigned int)pid);
    return result;
  }

  return CreateTaskInfo(descr.task_address, &descr.pa_identity, out_task);
}

/**
 * @brief VMA predicate: does @p vma map the physical page containing *param?
 *
 * Only VMAs that are readable and writable but NOT executable are considered.
 * Every page of [vm_start, vm_end) is translated through task->pgd and
 * compared with the requested physical address aligned down to its page.
 *
 * @param [in] task Task owning @p vma (its pgd is used for translation)
 * @param [in] vma VMA to inspect
 * @param [in] param Pointer to a PhysicalAddress
 * @return ::kEqual if some page of @p vma translates to that physical page,
 *         ::kNonEqual otherwise (including invalid arguments)
 */
FindResult CheckPhysicalCommandBuffer(const TaskInfo *task, const VmaInfo *vma,
                                      const void* param) {
  if (!task || !vma || !param) {
    LOG_E("Invalid argument\n");
    return kNonEqual;
  }

  /* The buffer must be mapped read + write and must not be executable. */
  if (!(vma->vm_flags & VM_READ) || !(vma->vm_flags & VM_WRITE)
      || (vma->vm_flags & VM_EXEC)) {
    LOG_V("Flags aren't read/write/exec: 0x%08x%08x\n",
          (uint32_t )(vma->vm_flags >> 32), (uint32_t )vma->vm_flags);
    return kNonEqual;
  }

  const PhysicalAddress phys =
      AlignToPageDown(*(const PhysicalAddress *)param);
  KernelAddress current_vma = vma->vm_start;
  TlbConverterInfo converter_context = {0};

  while (current_vma < vma->vm_end) {
    PhysicalAddress translated_phys = 0;
    const PaTzResult result = TaskAddressToPhysical(&converter_context,
                                                    task->pgd, current_vma,
                                                    &translated_phys, NULL);

    if (result == PA_TZ_SUCCESS && translated_phys == phys) {
      LOGADDR_V("Find needed phys addr: ", translated_phys);
      return kEqual;
    }

    /* VMA bounds are read from kernel memory. Stop on the last page instead
     * of letting current_vma + PAGE_SIZE wrap past the top of the address
     * space, which would restart the walk from 0 and never terminate. */
    if (vma->vm_end - current_vma <= PAGE_SIZE) {
      break;
    }
    current_vma += PAGE_SIZE;
  }

  return kNonEqual;
}

/**
 * @brief Find a VMA of @p task by walking the vm_next linked list
 *
 * Starting at task->mmap, each VMA is parsed and passed to @p custom_find;
 * the first one for which it returns ::kEqual is copied to @p out_vma.
 * The walk stops at a NULL or non-kernel node address, after kMaxNumVma
 * steps, or when the list wraps back to task->mmap.
 *
 * @param [in] task Task whose VMA list is walked
 * @param [in] custom_find Predicate applied to each VMA
 * @param [in] param Opaque argument forwarded to @p custom_find
 * @param [out] out_vma Receives the matching VMA
 * @return ::PA_TZ_SUCCESS if a VMA matched, ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult TaskFindVmaLinear(const TaskInfo *task, const CustomFind custom_find,
                             const void *param, VmaInfo *out_vma) {
  if (!task || !custom_find || !out_vma) {
    LOG_E("Invalid argument\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Bound on visited VMAs: guards against corrupted or cyclic lists. */
  const uint32_t kMaxNumVma = 4096;

  PaTzResult result = PA_TZ_GENERAL_ERROR;
  uint32_t counter_vma = 0;
  VmaInfo vma = {0};
  KernelAddress current_virt = task->mmap;

  /* Validate every node address -- including the list head -- before
   * parsing kernel memory at it. */
  while (current_virt && IsKernelVirtualAddress(current_virt)) {
    if (TaskParseVmaInfo(current_virt, &vma) != PA_TZ_SUCCESS) {
      LOG_E("Can't get vma info.\n");
      break;
    }

    const FindResult flag_find = custom_find(task, &vma, param);
    if (flag_find == kEqual) {
      *out_vma = vma;
      result = PA_TZ_SUCCESS;
      LOG_V("Find needed VMA:\n");
      LOGADDR_V("\tvma.vm_start: ", vma.vm_start);
      LOGADDR_V("\tvma.vm_end: ", vma.vm_end);
      break;
    }

    current_virt = vma.vm_next;

    if (++counter_vma > kMaxNumVma) {
      LOG_E("Max count of VMA is reached.\n");
      break;
    }

    if (task->mmap == current_virt) {
      LOG_E("All VMAs are checked. Loop is detected.\n");
      break;
    }
  }

  return result;
}

/**
 * @brief Find a VMA of @p task by descending the mm_rb red-black tree
 *
 * At each node the VMA is parsed and passed to @p custom_find: ::kEqual stops
 * the search, ::kBigger descends to rb_left, anything else to rb_right.
 *
 * @param [in] task Task whose VMA tree is searched
 * @param [in] custom_find Comparator applied to each visited VMA
 * @param [in] param Opaque argument forwarded to @p custom_find
 * @param [out] out_vma Receives the matching VMA
 * @return ::PA_TZ_SUCCESS if a VMA matched, ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult TaskFindVmaRbTree(const TaskInfo *task, const CustomFind custom_find,
                             const void *param, VmaInfo *out_vma) {
  if (!task || !custom_find || !out_vma) {
    LOG_E("Invalid argument\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Bound on visited nodes: guards against corrupted or cyclic trees. */
  const uint32_t kMaxNumVma = 4096;

  PaTzResult result = PA_TZ_GENERAL_ERROR;
  uint32_t counter_vma = 0;
  VmaInfo vma = {0};
  KernelAddress node = task->mm_rb.rb_node;
  /* Tree links point at the embedded vm_rb member. Its offset is taken as
   * vm_next + two pointers, i.e. it assumes vm_rb directly follows vm_next
   * and vm_prev in vm_area_struct -- confirm for each supported kernel. */
  const size_t offset = GetConfig()->gaf.vm_area_struct_struct_vm_next +
      sizeof(KernelAddress) * 2;

  /* Check the node address BEFORE subtracting the offset: for an empty tree
   * (rb_node == 0) the subtraction would underflow to a bogus address. */
  while (node && IsKernelVirtualAddress(node)) {
    const KernelAddress vma_addr = node - offset;

    if (TaskParseVmaInfo(vma_addr, &vma) != PA_TZ_SUCCESS) {
      LOG_E("Can't get vma info.\n");
      break;
    }

    const FindResult flag_find = custom_find(task, &vma, param);
    if (flag_find == kEqual) {
      *out_vma = vma;
      result = PA_TZ_SUCCESS;
      LOG_V("Find needed VMA:\n");
      LOGADDR_V("\tvma.vm_start: ", vma.vm_start);
      LOGADDR_V("\tvma.vm_end: ", vma.vm_end);
      break;
    }

    /* Current VMA lies above the target: go left; otherwise go right. */
    node = (flag_find == kBigger) ? vma.vm_rb.rb_left : vma.vm_rb.rb_right;

    if (++counter_vma > kMaxNumVma) {
      LOG_E("Max count of VMA is reached.\n");
      break;
    }
  }

  return result;
}

/*
 * Three-way comparison of a VMA against the process address pointed to by
 * param, treating the VMA as the half-open range [vm_start, vm_end).
 * task is not consulted.
 */
static FindResult ProcessAddressComparator(const TaskInfo *task,
                                           const VmaInfo *vma, const void* param) {
  const ProcessAddress target = *(const ProcessAddress *)param;
  (void)task;

  /* VMA ends at or before the target: it is below the wanted range. */
  if (vma->vm_end <= target) {
    return kSmaller;
  }

  /* VMA starts after the target: it is above the wanted range. */
  if (vma->vm_start > target) {
    return kBigger;
  }

  return kEqual;
}

/**
 * @brief Find the VMA of @p task that maps physical address @p address
 *
 * Performs a linear walk of the VMA list with CheckPhysicalCommandBuffer as
 * the predicate.
 *
 * NOTE(review): @p size is ignored; only the page containing @p address is
 * matched. Confirm whether the whole [address, address + size) range should
 * be validated against the returned VMA.
 *
 * @return ::PA_TZ_SUCCESS if found, ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult TaskFindVmaWithPhysicalMemory(const TaskInfo *task,
    PhysicalAddress address, size_t size, VmaInfo *out_vma) {
  (void)size;
  const PhysicalAddress *needle = &address;
  return TaskFindVmaLinear(task, CheckPhysicalCommandBuffer, needle, out_vma);
}

/**
 * @brief Find the VMA of @p task that contains process address @p address
 *
 * Searches the mm_rb tree using ProcessAddressComparator.
 *
 * @return ::PA_TZ_SUCCESS if found, ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult TaskFindVmaWithProcessAddress(const TaskInfo *task,
    ProcessAddress address, VmaInfo *out_vma) {
  const PaTzResult status =
      TaskFindVmaRbTree(task, ProcessAddressComparator, &address, out_vma);
  return status;
}
