#include "tlb.h"

#include "driver_log.h"

#include <tee_internal_api.h>

/* Width mask of the access-permission (AP) field of a leaf descriptor;
 * the field is extracted as (pte >> AP_SHIFT_POS) & AP_MASK. */
#define AP_MASK 0x3
/* AP field value decoded as read-only by FillAccessPermissionFlags. */
#define AP_RO 0x3
/* AP field value decoded as read/write by FillAccessPermissionFlags. */
#define AP_RW 0x1

/* Bit position of the AP field, i.e. descriptor bits [7:6].
 * NOTE(review): presumably the ARMv8-A stage-1 AP[2:1] field - confirm. */
#define AP_SHIFT_POS 6

/* Descriptor bit 54; when set, the page is reported as non-executable.
 * NOTE(review): presumably the ARMv8-A UXN/XN bit - confirm. */
#define NOT_EXECUTE_AP_MASK (1ULL << 54)

/**
 * Page table level in a four-level page table architecture.
 *
 * kPud, kPmd and kPte (values 0..2) are used directly as array indices:
 * as the row of kPageTableShifts, and as the per-level slot of the
 * TlbConverterInfo buffer[] and index_table[] arrays. kPgd has no row in
 * kPageTableShifts, so it must not be passed to GetPageTableAddress() or
 * GetPageTableIndex(); the PGD is tracked through TlbConverterInfo's
 * pgd/phys_pgd fields instead.
 */
typedef enum {
  kPud = 0,      //!< Page upper directory
  kPmd = 1,      //!< Page middle directory
  kPte = 2,      //!< Page table level (holds the leaf PTEs)
  kPgd = 3,      //!< Page global directory
  kLastType      //!< Last element; not a real level
} PageTableEntryType;

/**
 * Per-level constants used to locate page table entries.
 *
 * Row i belongs to PageTableEntryType value i (kPud, kPmd, kPte).
 * Column 0 is the PTRS_PER_* value handed to PTRS_LEVEL_MASK() when a
 * descriptor is masked down to the next table's address.
 * Column 1 is the right shift that brings this level's index to the low
 * bits of a virtual address.
 */
static const uint64_t kPageTableShifts[][2] = {
    /* kPud */ { PTRS_PER_PUD, PUD_LEVEL_SHIFT },
    /* kPmd */ { PTRS_PER_PMD, PMD_LEVEL_SHIFT },
    /* kPte */ { PTRS_PER_PTE, PTE_LEVEL_SHIFT },
};

/**
 * Forward declaration; the definition follows further down in this file.
 * Combines the output address held in a leaf PTE with the in-page offset
 * of a virtual address.
 * @param   [in]  pte   PTE value taken from the cached PTE table.
 * @param   [in]  vaddr Virtual address being translated.
 * @return  The physical address that the virtual one maps to.
 */
static inline PhysicalAddress PteToPhysical(PhysicalAddress pte, ProcessAddress vaddr);

/**
 * Derives the physical address of a page table from the value that points
 * to it (the PGD base for kPud, otherwise a descriptor of the level above).
 * The value is masked with PTRS_LEVEL_MASK() using the level's column-0
 * constant from kPageTableShifts, and the result goes through PageAddress().
 * @param   [in]  table Base address or descriptor referring to the table.
 * @param   [in]  type  Page table level; must be a valid kPageTableShifts row.
 * @return  Masked table address.
 */
static inline PhysicalAddress GetPageTableAddress(PhysicalAddress table,
                                                  PageTableEntryType type) {
  const uint64_t per_level = kPageTableShifts[type][0];
  return PageAddress(table & PTRS_LEVEL_MASK(per_level));
}

/**
 * Computes which entry of a table at the given level covers a virtual
 * address. The address is shifted right by the level's shift and masked
 * with ADDR_LEVEL_MASK.
 * @param   [in]  vaddr User space virtual REE address.
 * @param   [in]  type  Page table level; must be a valid kPageTableShifts row.
 * @return  Entry index inside the table of that level.
 */
static inline uint32_t GetPageTableIndex(ProcessAddress vaddr,
                                         PageTableEntryType type) {
  const uint64_t shift = kPageTableShifts[type][1];

  // The index is currently 9 bits wide (512 entries of 8 bytes = one
  // 4096-byte page), so narrowing to 32 bits cannot lose information.
  return (uint32_t)((vaddr >> shift) & ADDR_LEVEL_MASK);
}

/**
 * Builds the physical address for a virtual address from its leaf PTE.
 * The result is the frame address taken from the PTE (masked with
 * PTRS_LEVEL_MASK(PTRS_PER_PTE), then passed through PageAddress()) plus the
 * in-page offset of the virtual address.
 * @param   [in]  pte   Leaf PTE value.
 * @param   [in]  vaddr User space virtual REE address.
 * @return  Physical address.
 */
static inline PhysicalAddress PteToPhysical(PhysicalAddress pte,
                                            ProcessAddress vaddr) {
  const PhysicalAddress frame = PageAddress(pte & PTRS_LEVEL_MASK(PTRS_PER_PTE));
  return frame + PageOffset(vaddr);
}

/**
 * Decodes the access permissions of a leaf descriptor into @p flags.
 *
 * The 2-bit AP field is read at AP_SHIFT_POS:
 * - AP_RW gives read + write;
 * - AP_RO gives read only;
 * - any other value gives no access.
 *
 * exec is 0 when the NOT_EXECUTE_AP_MASK bit is set and 1 otherwise.
 * Nothing is written when @p flags is NULL.
 * @param  [in]   pte   Leaf descriptor value.
 * @param  [out]  flags Permission flags to fill; may be NULL.
 */
static inline void FillAccessPermissionFlags(PhysicalAddress pte, AccessPermissionFlags *flags) {
  if (!flags) {
    return;
  }

  switch ((uint8_t)((pte >> AP_SHIFT_POS) & AP_MASK)) {
    case AP_RW:
      flags->write = 1;
      flags->read = 1;
      break;
    case AP_RO:
      flags->write = 0;
      flags->read = 1;
      break;
    default:
      flags->write = 0;
      flags->read = 0;
      break;
  }

  const int no_exec = (pte & NOT_EXECUTE_AP_MASK) == NOT_EXECUTE_AP_MASK;
  flags->exec = no_exec ? 0 : 1;
}

/**
 * Loads the page table of level @p type into context->buffer[type].
 *
 * The table's location comes from the level above it:
 * - for kPud, it is derived from context->phys_pgd;
 * - otherwise, from the descriptor that context->index_table[type - 1]
 *   selects in context->buffer[type - 1].
 *
 * @param   [in,out] context TLB converter context.
 * @param   [in]     type    Level of the table to load (kPud, kPmd or kPte).
 * @return ::PA_TZ_SUCCESS in case of success, ::PA_TZ_GENERAL_ERROR otherwise.
 */
static PaTzResult FillMapTableAddress(TlbConverterInfo *context,
                                      PageTableEntryType type) {
  if (!context) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  size_t idx = (size_t)type;
  if (idx > kPte) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // Base physical address from sequence PGD->PUD->PMD->PTE in page table
  PhysicalAddress base_phys;
  if (type == kPud) {
    base_phys = context->phys_pgd;
  } else {
    uint32_t parent_index = context->index_table[idx - 1];
    // The parent index is normally range-checked by UpdateTlbContext before
    // being stored; re-check it so a corrupted context cannot cause an
    // out-of-bounds read of the parent table.
    if (parent_index >= kSizeTable) {
      LOG_E("Out of range index_table.\n");
      return PA_TZ_GENERAL_ERROR;
    }
    base_phys = context->buffer[idx - 1][parent_index];
  }

  PhysicalAddress phys = GetPageTableAddress(base_phys, type);

  // idx is a size_t; cast explicitly so it matches the %u conversion.
  LOG_V("TLB info: idx: %u\n", (unsigned)idx);
  LOGADDR_V("\tbase_phys: ", base_phys);
  LOGADDR_V("\tphys: ", phys);

  // Copy the whole table (sizeof one buffer row) from physical memory.
  PaTzResult result = PhysicalGetBytes(phys, sizeof(context->buffer[idx]),
                                       &context->buffer[idx][0]);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Cant't get bytes.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/**
 * Selects the entry covering @p user_addr in the cached table of level @p type.
 *
 * The entry index is computed from @p user_addr. On success it is stored in
 * context->index_table[type]. The context is left untouched on failure,
 * i.e. when the index is out of range or the selected entry of
 * context->buffer[type] is 0.
 * @param   [in,out] context   TLB converter context.
 * @param   [in]     user_addr Virtual user address.
 * @param   [in]     type      Page table level (kPud, kPmd or kPte).
 * @return ::PA_TZ_SUCCESS in case of success, ::PA_TZ_GENERAL_ERROR otherwise.
 */
static PaTzResult UpdateTlbContext(TlbConverterInfo *context,
                                   ProcessAddress user_addr,
                                   PageTableEntryType type) {
  const size_t level = (size_t)type;
  if (!context || level > kPte) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  const uint32_t slot = GetPageTableIndex(user_addr, type);
  if (slot >= kSizeTable) {
    LOG_E("Out of range index_table.\n");
    LOG_D("Current index_table: %u\n", slot);
    return PA_TZ_GENERAL_ERROR;
  }

  // An empty (zero) entry means there is nothing to descend into.
  if (context->buffer[level][slot] != 0) {
    context->index_table[level] = slot;
    return PA_TZ_SUCCESS;
  }

  return PA_TZ_GENERAL_ERROR;
}
/**
 * Translates @p user_virt through the PTE currently selected by
 * context->index_table[kPte] in the cached PTE table. When @p flags is not
 * NULL it is also filled from that PTE.
 * @param   [in]  context   TLB converter context.
 * @param   [in]  user_virt User space virtual REE address.
 * @param   [out] addr      Receives the physical address; written only on success.
 * @param   [out] flags     Access permission flags; may be NULL.
 * @return ::PA_TZ_SUCCESS in case of success, ::PA_TZ_GENERAL_ERROR otherwise
 *         (bad arguments, out-of-range index, zero PTE, or zero result).
 */
static PaTzResult GetPhysAddressFromPte(TlbConverterInfo *context,
                                        ProcessAddress user_virt,
                                        PhysicalAddress *addr,
                                        AccessPermissionFlags* flags) {
  if (!context || !addr) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  const uint32_t slot = context->index_table[kPte];
  if (slot >= kSizeTable) {
    LOG_E("Invalid arguments index_table.\n");
    LOG_D("Current index_table: %u\n", slot);
    return PA_TZ_GENERAL_ERROR;
  }

  const PhysicalAddress pte = context->buffer[kPte][slot];
  if (pte == 0) {
    return PA_TZ_GENERAL_ERROR;
  }

  FillAccessPermissionFlags(pte, flags);

  const PhysicalAddress translated = PteToPhysical(pte, user_virt);
  if (translated == 0) {
    return PA_TZ_GENERAL_ERROR;
  }

  *addr = translated;
  return PA_TZ_SUCCESS;
}

/**
 * Counts how many leading steps of the walk cached in @p context can be
 * reused for translating @p user_virt under @p pgd_address.
 * @return 0 if the cached PGD differs; 1 if only the PGD matches; 2 if the
 *         PUD index also matches; 3 if the PMD index also matches; 4 if the
 *         PTE index matches as well.
 */
static uint32_t CountReusableTlbLevels(const TlbConverterInfo *context,
                                       KernelAddress pgd_address,
                                       ProcessAddress user_virt) {
  if (context->pgd != pgd_address) {
    return 0;
  }
  if (GetPageTableIndex(user_virt, kPud) != context->index_table[kPud]) {
    return 1;
  }
  if (GetPageTableIndex(user_virt, kPmd) != context->index_table[kPmd]) {
    return 2;
  }
  if (GetPageTableIndex(user_virt, kPte) != context->index_table[kPte]) {
    return 3;
  }
  return 4;
}

/**
 * Translates a user-space virtual REE address to a physical address by
 * walking the page tables rooted at @p pgd_address.
 *
 * @p context caches the tables and indices of the previous walk. The walk
 * resumes at the first level whose cached state does not match. On any
 * failure the cache is invalidated, so the next call walks from the PGD.
 *
 * @param   [in,out] context     TLB converter context (walk cache).
 * @param   [in]     pgd_address Kernel virtual address of the PGD; non-zero.
 * @param   [in]     user_virt   User space virtual REE address; non-zero.
 * @param   [out]    phys        Receives the physical address on success.
 * @param   [out]    flags       Access permission flags; may be NULL.
 * @return ::PA_TZ_SUCCESS in case of success, ::PA_TZ_GENERAL_ERROR otherwise.
 */
PaTzResult TaskAddressToPhysical(TlbConverterInfo* context, KernelAddress pgd_address,
                                 ProcessAddress user_virt, PhysicalAddress *phys,
                                 AccessPermissionFlags* flags) {
  if (!context || pgd_address == 0 || user_virt == 0 || !phys) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  PaTzResult result = PA_TZ_SUCCESS;
  uint32_t conv_match_counter =
      CountReusableTlbLevels(context, pgd_address, user_virt);

  LOG_V("TLB converter step counter is %u\n", conv_match_counter);

  /*
   * Resume the walk at the first level whose cached state doesn't match:
   * 0: PGD changed      - reset the cache and load the PUD table;
   * 1: PUD index differs - select the PUD entry and load the PMD table;
   * 2: PMD index differs - select the PMD entry and load the PTE table;
   * 3: PTE index differs - select the PTE entry;
   * 4: all cached       - translate through the selected PTE.
   * Cases fall through on success; break is only used to stop on error.
   */
  switch (conv_match_counter) {
    case 0: {
      TEE_MemFill(context, 0, sizeof(*context));

      PhysicalAddress pgd = KernelVirtToPhys(pgd_address);
      if (!IsPhysicalAddress(pgd)) {
        LOG_E("Physical address PGD isn't valid.\n");
        result = PA_TZ_GENERAL_ERROR;
        break;
      }

      context->pgd = pgd_address;
      context->phys_pgd = pgd;

      result = FillMapTableAddress(context, kPud);
      if (result != PA_TZ_SUCCESS) {
        LOG_E("The PUD is not valid pgd\n");
        break;
      }
    }
    /* no break */
    case 1: {
      result = UpdateTlbContext(context, user_virt, kPud);
      if (result != PA_TZ_SUCCESS) {
        LOG_V("The PUD is not valid\n");
        break;
      }

      result = FillMapTableAddress(context, kPmd);
      if (result != PA_TZ_SUCCESS) {
        LOG_V("The PMD is not valid\n");
        break;
      }
    }
    /* no break */
    case 2: {
      result = UpdateTlbContext(context, user_virt, kPmd);
      if (result != PA_TZ_SUCCESS) {
        LOG_V("The PMD is not valid\n");
        break;
      }
      // Also it must be updated PTE table
      result = FillMapTableAddress(context, kPte);
      if (result != PA_TZ_SUCCESS) {
        LOG_V("The PTE is not valid\n");
        break;
      }
    }
    /* no break */
    case 3: {
      result = UpdateTlbContext(context, user_virt, kPte);
      if (result != PA_TZ_SUCCESS) {
        LOG_V("The PTE is not valid\n");
        break;
      }
    }
    /* no break */
    case 4: {
      result = GetPhysAddressFromPte(context, user_virt, phys, flags);
      if (result != PA_TZ_SUCCESS) {
        LOG_V("The PTE is not valid\n");
      }
      break;
    }
    default: {
      LOG_E("Unknown case.\n");
      LOG_D("case: %u\n", conv_match_counter);
      result = PA_TZ_GENERAL_ERROR;
    }
  }

  if (result != PA_TZ_SUCCESS) {
    /*
     * A failed step may have committed a new index for one level without
     * reloading the tables below it. That leaves cached lower-level tables
     * that no longer belong to the cached indices, and a later lookup that
     * "matches" them would translate through stale tables. Drop the cache:
     * pgd_address is never 0 here, so the next call restarts from case 0.
     */
    context->pgd = 0;
  }

  return result;
}
