#include "kernel_access.h"

#include "driver_log.h"

#include <tee_internal_api.h>

/**
 * @brief Log every map slot of the context via LOG_V/LOGADDR_V: its validity
 *        flag and, for valid slots, the size and the kernel, physical and
 *        secure virtual addresses.
 * @param [in] context Kernel access context
 */
static void KernelAccessDumpMaps(const KernelAccessInfo *context);

/**
 * @brief Check whether the page containing an address is already mapped in
 *        the context.
 *
 * The address is rounded down to its page before it is compared against the
 * kernel range recorded in each map slot.
 * @param [in] context Kernel access context (NULL is reported as "not mapped")
 * @param [in] address Kernel virtual address
 * @return 1 if the page is mapped, 0 if it is not
 */
static uint32_t KernelAccessIsPageMaped(const KernelAccessInfo *context, KernelAddress address);

/**
 * @brief Translate a kernel virtual address to the secure virtual address at
 *        which the context has it mapped.
 * @param [in] context Kernel access context
 * @param [in] address Kernel virtual address
 * @param [out] out Receives the secure virtual address on success
 * @return ::PA_TZ_SUCCESS if a valid map slot covers the address,
 *         ::PA_TZ_GENERAL_ERROR otherwise (including NULL arguments)
 */
static PaTzResult KernelAccessGetVirtual(const KernelAccessInfo *context, KernelAddress address, void **out);

/**
 * @brief Find a free map slot.
 * @param [in] context Kernel access context
 * @return Index of the first slot whose is_valid flag is clear, or -1 if the
 *         context is NULL or every slot is in use.
 */
static int32_t KernelAccessFindMapPlace(const KernelAccessInfo *context);

/**
 * @brief Map the physical memory behind a kernel virtual address and record
 *        the mapping in a free slot of the context.
 * @param [in,out] context Pointer to structure of kernel space access
 * @param [in] address Kernel virtual address of data
 * @param [in] size Size of data in bytes; must be non-zero
 * @return ::PA_TZ_SUCCESS on success; ::PA_TZ_GENERAL_ERROR if size is 0, no
 *         slot is free, or the translated physical address is invalid;
 *         otherwise the error code returned by PlatformMapRegion()
 */
static PaTzResult KernelAccessMapRegion(KernelAccessInfo *context, KernelAddress address, size_t size);


/**
 * @brief Reset a kernel access context so that no map slot is in use.
 * @param [out] context Context to initialise; must not be NULL
 * @return ::PA_TZ_SUCCESS, or ::PA_TZ_GENERAL_ERROR if context is NULL
 */
PaTzResult KernelAccessInit(KernelAccessInfo *context) {
  if (context != NULL) {
    /* Zeroing the whole structure clears every slot's is_valid flag. */
    TEE_MemFill(context, 0, sizeof(*context));
    return PA_TZ_SUCCESS;
  }

  LOG_E("Invalid argument.\n");
  return PA_TZ_GENERAL_ERROR;
}

/**
 * @brief Unmap every mapping recorded in the context and clear it.
 *
 * Unmapping is best effort: a failing PlatformUnmapRegion() is logged and
 * the sweep continues; the context is cleared regardless.
 * @param [in,out] context Context to tear down; must not be NULL
 * @return ::PA_TZ_SUCCESS, or ::PA_TZ_GENERAL_ERROR if context is NULL
 */
PaTzResult KernelAccessDeinit(KernelAccessInfo *context) {
  if (context == NULL) {
    LOG_E("Invalid argument.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  for (size_t slot = 0; slot < kMaxMapsCount; ++slot) {
    /* Short-circuit: only slots holding a live mapping are unmapped. */
    if (context->maps[slot].is_valid &&
        PlatformUnmapRegion(context->maps[slot].virt,
                            context->maps[slot].size) != PA_TZ_SUCCESS) {
      LOG_E("Failed destroying kernel access.\n");
    }
  }

  TEE_MemFill(context, 0, sizeof(*context));
  return PA_TZ_SUCCESS;
}

/**
 * @brief Copy a range of kernel memory into a secure-world buffer.
 *
 * The range is copied page by page. Each page is mapped on first use, and the
 * mapping is kept in @p context so later reads through the same context can
 * reuse it. If mapping a page fails, the whole context is flushed with
 * KernelAccessDeinit(), which also drops mappings cached by the caller. The
 * mapping is then retried once.
 *
 * @param [in,out] context Kernel access context used as a mapping cache, or
 *                 NULL to use a temporary context released before returning
 * @param [in] address Kernel virtual address of the first byte to copy
 * @param [in] size Number of bytes to copy; 0 succeeds without reading
 * @param [out] out Destination buffer of at least @p size bytes
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult KernelAccessGetBytes(KernelAccessInfo *context,
                                KernelAddress address, size_t size, void *out) {
  KernelAddress start = address;
  size_t remaining_size = size;

  // Backing storage used when the caller passes no context. Otherwise it
  // stays zeroed (no valid maps), so the final deinit below is a no-op.
  KernelAccessInfo inner_context = {0};

  if (!out) {
    LOG_E("Invalid argument.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (!IsKernelVirtualAddress(address)) {
    LOG_E("It is not kernel address!\n");
    LOGADDR_D("Address: ", address);
    return PA_TZ_GENERAL_ERROR;
  }

  if (!context) {
    if (KernelAccessInit(&inner_context) != PA_TZ_SUCCESS) {
      LOG_E("Cannot init Kernel Access context.\n");
      return PA_TZ_GENERAL_ERROR;
    }
    context = &inner_context;
  }

  PaTzResult result = PA_TZ_SUCCESS;

  // Copying data from Kernel page by page
  while (remaining_size) {
    // Validate every page, not only the first one. Pages of a range that runs
    // past the end of kernel space must not be translated with
    // KernelVirtToPhys(). An address that wrapped past the top of the address
    // space is presumably rejected here as well.
    if (!IsKernelVirtualAddress(start)) {
      LOG_E("It is not kernel address!\n");
      LOGADDR_D("Address: ", start);
      result = PA_TZ_GENERAL_ERROR;
      break;
    }

    size_t offset = PageOffset(start);
    size_t current_size = 0;

    // Compare against the room left in the page rather than computing
    // offset + remaining_size, which could wrap for very large sizes.
    if (remaining_size > PAGE_SIZE - offset) {
      // Copy only up to the end of the current page
      current_size = PAGE_SIZE - offset;
    } else {
      current_size = remaining_size;
    }

    // Cache of page mapping: reuse an existing mapping of this page if any
    if (!KernelAccessIsPageMaped(context, start)) {
      if (KernelAccessMapRegion(context, start, current_size) != PA_TZ_SUCCESS) {
        // Mapping fails e.g. when every slot is in use: flush the whole cache
        // and retry once.
        LOG_D("Deinit context and try map again.\n");
        KernelAccessDumpMaps(context);
        KernelAccessDeinit(context);
        if (KernelAccessMapRegion(context, start, current_size) != PA_TZ_SUCCESS) {
          LOG_E("Cannot map Kernel page.\n");
          result = PA_TZ_GENERAL_ERROR;
          break;
        }
      }
    }

    void *virt = NULL;
    if (KernelAccessGetVirtual(context, start, &virt) != PA_TZ_SUCCESS || !virt) {
      LOG_E("Cannot obtain correct virt for kernel.\n");
      // Report the page that actually failed, not the start of the range.
      LOGADDR_D("Address: ", start);
      KernelAccessDumpMaps(context);
      result = PA_TZ_GENERAL_ERROR;
      break;
    }

    TEE_MemMove(out, virt, current_size);

    out = (uint8_t *)out + current_size;
    start += current_size;
    remaining_size -= current_size;
  }

  // Release the temporary context (no-op when the caller supplied one).
  KernelAccessDeinit(&inner_context);

  return result;
}

/**
 * @brief Read one byte of kernel memory.
 * @param [in,out] context Kernel access context, or NULL for a one-shot read
 * @param [in] address Kernel virtual address to read from
 * @param [out] out Receives the byte read
 * @return Result of KernelAccessGetBytes()
 */
PaTzResult KernelAccessGetUint8(KernelAccessInfo *context,
                                KernelAddress address, uint8_t *out) {
  const size_t width = sizeof(uint8_t);
  return KernelAccessGetBytes(context, address, width, out);
}

/**
 * @brief Read a 32-bit value from kernel memory.
 *
 * The bytes are copied as-is; no byte-order conversion is performed.
 * @param [in,out] context Kernel access context, or NULL for a one-shot read
 * @param [in] address Kernel virtual address to read from
 * @param [out] out Receives the value read
 * @return Result of KernelAccessGetBytes()
 */
PaTzResult KernelAccessGetUint32(KernelAccessInfo *context,
                                 KernelAddress address, uint32_t *out) {
  const size_t width = sizeof(uint32_t);
  return KernelAccessGetBytes(context, address, width, out);
}

/**
 * @brief Read a 64-bit value from kernel memory.
 *
 * The bytes are copied as-is; no byte-order conversion is performed.
 * @param [in,out] context Kernel access context, or NULL for a one-shot read
 * @param [in] address Kernel virtual address to read from
 * @param [out] out Receives the value read
 * @return Result of KernelAccessGetBytes()
 */
PaTzResult KernelAccessGetUint64(KernelAccessInfo *context,
                                 KernelAddress address, uint64_t *out) {
  const size_t width = sizeof(uint64_t);
  return KernelAccessGetBytes(context, address, width, out);
}

/**
 * @brief Read a kernel address value (sizeof(KernelAddress) bytes) from
 *        kernel memory.
 * @param [in,out] context Kernel access context, or NULL for a one-shot read
 * @param [in] address Kernel virtual address to read from
 * @param [out] out Receives the address value read
 * @return Result of KernelAccessGetBytes()
 */
PaTzResult KernelAccessGetPointer(KernelAccessInfo *context,
                                  KernelAddress address, KernelAddress *out) {
  const size_t width = sizeof(KernelAddress);
  return KernelAccessGetBytes(context, address, width, out);
}

/*
 * Map the physical memory behind a kernel virtual address and record the
 * mapping in the first free slot of the context. The kernel, physical and
 * secure virtual bases are stored rounded down to their page; size is stored
 * exactly as requested.
 */
static PaTzResult KernelAccessMapRegion(KernelAccessInfo *context, KernelAddress address, size_t size) {
  if (size == 0) {
    LOG_E("Size of data can't be 0\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Reserve a slot before doing any translation or mapping work. */
  const int32_t slot = KernelAccessFindMapPlace(context);
  if (slot == -1) {
    LOG_E("The kernel access context is full.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  PhysicalAddress phys = KernelVirtToPhys(address);
  if (!IsPhysicalAddress(phys)) {
    LOG_E("Physical address is invalid.\n");
    LOGADDR_D("Address: ", phys);
    return PA_TZ_GENERAL_ERROR;
  }

  void *mapped = NULL;
  const PaTzResult status = PlatformMapRegion(phys, size, kMemoryAccessRead, &mapped);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Failed to map.\n");
    /* Propagate the platform's own error code unchanged. */
    return status;
  }

  context->maps[slot].kernel = AlignToPageDown(address);
  context->maps[slot].physical = AlignToPageDown(phys);
  context->maps[slot].virt = (void *)(uintptr_t)AlignToPageDown((uintptr_t)mapped);
  context->maps[slot].size = size;
  context->maps[slot].is_valid = 1;

  return PA_TZ_SUCCESS;
}

/*
 * Log every map slot of the context: its validity flag and, for valid slots,
 * the size and the kernel, physical and secure virtual addresses.
 */
static void KernelAccessDumpMaps(const KernelAccessInfo *context) {
  if (!context) {
    return;
  }

  for (size_t i = 0; i < kMaxMapsCount; ++i) {
    // Arguments are cast to match their conversion specifiers. Assuming LOG_V
    // follows printf conventions, passing a size_t to %d or a pointer to %x is
    // undefined behaviour.
    LOG_V("Maps[%d] is valid: %d.\n", (int)i, (int)context->maps[i].is_valid);
    if (!context->maps[i].is_valid) {
      continue;
    }

    LOG_V("\tsize: %lu.\n", (unsigned long)context->maps[i].size);
    LOGADDR_V("\tkern: ", context->maps[i].kernel);
    LOGADDR_V("\tphys: ", context->maps[i].physical);
    LOG_V("\tvirt: 0x%lx.\n", (unsigned long)(uintptr_t)context->maps[i].virt);
  }
}

/*
 * Return 1 if a valid map slot of the context covers the page containing
 * address, 0 otherwise (also 0 for a NULL context).
 */
static uint32_t KernelAccessIsPageMaped(const KernelAccessInfo *context, KernelAddress address) {
  if (!context) {
    return 0;
  }

  const KernelAddress page = AlignToPageDown(address);

  for (size_t i = 0; i < kMaxMapsCount; ++i) {
    // Skip free slots, consistent with KernelAccessGetVirtual(); a slot's
    // stale kernel/size fields must never count as a mapping.
    if (!context->maps[i].is_valid) {
      continue;
    }

    // Overflow-safe form of "page < kernel + size": the sum can wrap for a
    // mapping of the topmost page of the address space.
    if (page >= context->maps[i].kernel &&
        page - context->maps[i].kernel < context->maps[i].size) {
      return 1;
    }
  }

  return 0;
}

/*
 * Translate a kernel virtual address into the secure virtual address of the
 * valid map slot covering its page. Returns PA_TZ_GENERAL_ERROR for NULL
 * arguments or when no valid slot covers the address.
 */
static PaTzResult KernelAccessGetVirtual(const KernelAccessInfo *context, KernelAddress address, void **out) {
  if (!context || !out) {
    return PA_TZ_GENERAL_ERROR;
  }

  const KernelAddress page = AlignToPageDown(address);

  for (size_t i = 0; i < kMaxMapsCount; ++i) {
    if (!context->maps[i].is_valid) {
      continue;
    }

    // Overflow-safe form of "page < kernel + size": the sum can wrap for a
    // mapping of the topmost page of the address space.
    if (page >= context->maps[i].kernel &&
        page - context->maps[i].kernel < context->maps[i].size) {
      // kernel and virt are stored as page-aligned bases, so the byte offset
      // from the kernel base applies unchanged to the secure mapping.
      size_t offset = address - context->maps[i].kernel;
      *out = (uint8_t *)context->maps[i].virt + offset;
      return PA_TZ_SUCCESS;
    }
  }

  return PA_TZ_GENERAL_ERROR;
}

/*
 * Return the index of the first slot whose is_valid flag is clear, or -1 when
 * the context is NULL or every slot is occupied.
 */
static int32_t KernelAccessFindMapPlace(const KernelAccessInfo *context) {
  int32_t found = -1;

  if (context != NULL) {
    /* Stop scanning as soon as a free slot has been found. */
    for (size_t slot = 0; slot < kMaxMapsCount && found == -1; ++slot) {
      if (!context->maps[slot].is_valid) {
        found = (int32_t)slot;
      }
    }
  }

  return found;
}
