#include "memory.h"
#include "driver_log.h"
#include "kaslr.h"

#include <DrApi/DrApi.h>
#include <DrApi/DrApiMm.h>
#include <DrApi/DrApiMmExt.h>

#if TBASE_API_LEVEL < 5
#error "TBase driver should be built with TBASE_API_LEVEL >= 5 to use \
        extended memory layout"
#endif

extern taskid_t GetCurrentClientTaskId(void);

/*
 * Translates `virt` into a 64-bit physical address using drApiVirt2Phys64.
 *
 * virt - virtual address to translate. 0 is passed as the first argument of
 *        drApiVirt2Phys64 (NOTE(review): presumably "current task" — confirm).
 * phys - out: receives the physical address on success.
 *
 * Returns PA_TZ_SUCCESS when drApiVirt2Phys64 reports DRAPI_OK, otherwise
 * PA_TZ_GENERAL_ERROR.
 */
PaTzResult PlatformVirtToPhys64(const void *virt, PhysicalAddress *phys) {
  const drApiResult_t status = drApiVirt2Phys64(0, (addr_t)virt, phys);

  return (status == DRAPI_OK) ? PA_TZ_SUCCESS : PA_TZ_GENERAL_ERROR;
}

/*
 * Translates a MemoryAccessType into DrApi mapping flags.
 *
 * kMemoryAccessRead  -> MAP_READABLE | MAP_NOT_SECURE
 * kMemoryAccessWrite -> MAP_WRITABLE | MAP_READABLE | MAP_NOT_SECURE
 * any other value    -> 0 (no flags)
 */
uint32_t GetMappingFlags(MemoryAccessType type) {
  uint32_t flags = 0;

  if (type == kMemoryAccessWrite) {
    flags = MAP_WRITABLE | MAP_READABLE | MAP_NOT_SECURE;
  } else if (type == kMemoryAccessRead) {
    flags = MAP_READABLE | MAP_NOT_SECURE;
  }

  return flags;
}

/*
 * Maps a physical range into the driver's address space after confirming,
 * via PlatformCheckNonSecureRegion, that the range is not secure memory.
 *
 * phys - physical start address of the range.
 * size - length of the range in bytes.
 * type - requested access (kMemoryAccessRead or kMemoryAccessWrite).
 * virt - out: receives the mapped virtual address on success (non-NULL).
 *
 * Returns PA_TZ_SUCCESS on success; the result of
 * PlatformCheckNonSecureRegion if that check fails; PA_TZ_GENERAL_ERROR for
 * invalid arguments or a failed drApiMapPhysicalBuffer call.
 */
PaTzResult PlatformSysMap(PhysicalAddress phys, size_t size,
                          MemoryAccessType type, void **virt) {
  if (!virt) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // PlatformCheckNonSecureRegion() takes a uint32_t size. Where size_t is
  // wider, an implicit truncation would validate only a prefix of the range
  // that is then mapped in full, so refuse sizes that do not fit.
  if ((size_t)(uint32_t)size != size) {
    LOG_E("Mapping size is too large.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  PaTzResult result = PlatformCheckNonSecureRegion(phys, (uint32_t)size);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Failed CheckNonSecureRegion.\n");
    LOG_D("Received result: 0x%x.\n", result);
    return result;
  }

  // GetMappingFlags() returns 0 for an unrecognised access type. Mapping with
  // 0 would request neither read access nor MAP_NOT_SECURE, so refuse it.
  uint32_t mask_operation = GetMappingFlags(type);
  if (mask_operation == 0) {
    LOG_E("Unsupported memory access type.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (drApiMapPhysicalBuffer(phys, size, mask_operation, virt) != DRAPI_OK) {
    LOG_E("Failed mapping the size.\n");
    // size_t does not match %d; print through an explicitly typed cast.
    LOG_D("The size: %lu.\n", (unsigned long)size);
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/*
 * Unmaps a driver mapping with drApiUnmapBuffer, then cleans and invalidates
 * the entire data cache with drApiCacheDataCleanInvalidateAll.
 *
 * virt - virtual address of the mapping (presumably one obtained from
 *        PlatformSysMap — confirm against callers).
 * size - length of the mapping; only used for diagnostic logging.
 *
 * Returns PA_TZ_SUCCESS, or PA_TZ_GENERAL_ERROR if the unmap or the cache
 * maintenance fails. The cache operation is skipped if the unmap fails.
 */
PaTzResult PlatformSysUnmap(void *virt, size_t size) {
  if (drApiUnmapBuffer(virt) != DRAPI_OK) {
    LOG_E("Failed unmapping the size.\n");
    // size_t does not match %d; print through an explicitly typed cast.
    LOG_D("The size: %lu.\n", (unsigned long)size);
    return PA_TZ_GENERAL_ERROR;
  }

  if (drApiCacheDataCleanInvalidateAll() != DRAPI_OK) {
    LOG_E("drApiCacheDataCleanInvalidateAll failed.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/*
 * Maps a buffer of the current client task into the driver's address space
 * using drApiMapTaskBuffer. The mapping flags are GetMappingFlags(type)
 * combined with MAP_ALLOW_NONSECURE.
 *
 * virt_trustlet - buffer address in the client task (non-NULL).
 * size          - buffer length in bytes.
 * type          - requested access (kMemoryAccessRead or kMemoryAccessWrite).
 * virt_driver   - out: receives the driver-side address (non-NULL).
 *
 * Returns PA_TZ_SUCCESS on success, PA_TZ_GENERAL_ERROR otherwise.
 */
PaTzResult PlatformSysMapTrustlet(const void *virt_trustlet, size_t size,
                                  MemoryAccessType type, void **virt_driver) {
  if (!virt_trustlet || !virt_driver) {
    LOG_E("Invalid arguments.\n");
    // Pointers must be printed with %p; passing them for %x is undefined
    // behaviour and truncates the value on 64-bit targets.
    LOG_D("virt_trustlet: %p, virt_driver: %p.\n", virt_trustlet,
          (void *)virt_driver);
    return PA_TZ_GENERAL_ERROR;
  }

  // GetMappingFlags() returns 0 for an unrecognised access type; refuse it
  // rather than requesting a mapping with only MAP_ALLOW_NONSECURE set.
  const uint32_t access_flags = GetMappingFlags(type);
  if (access_flags == 0) {
    LOG_E("Unsupported memory access type.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  drApiResult_t result = drApiMapTaskBuffer(
      GetCurrentClientTaskId(), (addr_t)virt_trustlet, size,
      access_flags | MAP_ALLOW_NONSECURE, virt_driver);
  if (DRAPI_OK != result) {
    LOG_E("Fail to translate task buffer.\n");
    LOG_D("Received result: 0x%x.\n", result);
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/*
 * Counterpart to PlatformSysMapTrustlet. Performs no work and always
 * returns PA_TZ_SUCCESS; the mapping is expected to be released
 * automatically (NOTE(review): confirm when and by whom).
 */
PaTzResult PlatformSysUnmapTrustlet(const void *virt_driver, size_t size) {
  // Parameters are intentionally unused; nothing to release here.
  (void)virt_driver;
  (void)size;

  return PA_TZ_SUCCESS;
}

/*
 * Reads OEM data at `offset` into `out` via drApiReadOemData.
 *
 * offset - passed unchanged to drApiReadOemData.
 * size   - capacity of `out` in bytes.
 * out    - destination buffer (non-NULL).
 *
 * Returns PA_TZ_SUCCESS when drApiReadOemData reports DRAPI_OK, otherwise
 * PA_TZ_GENERAL_ERROR (also for invalid arguments).
 */
PaTzResult PlatformReadOemBuffer(uint32_t offset, size_t size, void *out) {
  // drApiReadOemData() takes no length, so `size` was previously ignored and
  // the callee could write past a short buffer. NOTE(review): per the DrApi
  // header it stores a single 32-bit word — confirm. If so, only that one
  // word is filled even when `size` is larger.
  if (!out || size < sizeof(uint32_t)) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  drApiResult_t result = drApiReadOemData(offset, out);

  return (result == DRAPI_OK ? PA_TZ_SUCCESS : PA_TZ_GENERAL_ERROR);
}

/*
 * Reads the Kaslr structure from the physical address returned by
 * KernelGetKaslrStructAddress() and reports its offset.
 *
 * kaslr_offset - out (non-NULL): set to the structure's `offset` field when
 *                its `magic` equals kKaslrMagicCode, otherwise set to 0.
 *
 * Returns PA_TZ_GENERAL_ERROR for a NULL argument, the PhysicalGetBytes
 * error if the read fails, and PA_TZ_SUCCESS otherwise (including when the
 * magic code is absent).
 */
PaTzResult KernelGetKaslrOffset(uint32_t *kaslr_offset) {
  if (!kaslr_offset) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  Kaslr kaslr_info = {0};
  const PaTzResult read_status = PhysicalGetBytes(
      KernelGetKaslrStructAddress(), sizeof(kaslr_info), &kaslr_info);
  if (read_status != PA_TZ_SUCCESS) {
    LOG_E("Can not read KASLR.\n");
    return read_status;
  }

  if (kaslr_info.magic == kKaslrMagicCode) {
    LOG_I("KASLR Magic code is found.\n");
    LOG_D("KASLR Offset is 0x%08x.\n", kaslr_info.offset);
    *kaslr_offset = kaslr_info.offset;
    return PA_TZ_SUCCESS;
  }

  // No magic code: report a zero offset but still succeed.
  LOG_I("KASLR Magic code is NOT found.\n");
  *kaslr_offset = 0;
  return PA_TZ_SUCCESS;
}

/*
 * Queries drApiGetPhysMemType64 for the range [phys_addr, phys_addr +
 * region_size) and rejects it if the reported type is
 * DRAPI_PHYS_MEM_TYPE_SECURE.
 *
 * Returns PA_TZ_SUCCESS when the query succeeds and the type is not
 * DRAPI_PHYS_MEM_TYPE_SECURE; PA_TZ_GENERAL_ERROR if the query fails or the
 * region is reported secure.
 */
PaTzResult PlatformCheckNonSecureRegion(PhysicalAddress phys_addr, uint32_t region_size) {
  uint32_t mem_type = 0;

  if (drApiGetPhysMemType64(&mem_type, phys_addr, region_size) != DRAPI_OK) {
    LOG_E("drApiGetPhysMemType64() failed.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // NOTE(review): only the exact DRAPI_PHYS_MEM_TYPE_SECURE value is
  // rejected. Confirm whether other values drApiGetPhysMemType64 can report
  // (e.g. for ranges that are only partly secure) must be rejected too.
  if (mem_type == DRAPI_PHYS_MEM_TYPE_SECURE) {
    LOG_E("Memory region is secure.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}
