#include "gaf.h"
#include "driver_log.h"
#include "memory.h"
#include "kernel_access.h"

/* Lowest GAF structure version accepted by LoadGafInfo(): older versions are
 * rejected, newer ones are accepted with an informational log message. */
static const uint16_t kGafVersion = 0x0600;

/**
 * @brief Read the GAF info structure located at the given kernel virtual
 * address into @p gaf and dump its fields to the verbose log
 * @param [in] gaf_virt Virtual address of GAF structure for kernel space
 * @param [out] gaf Output GAF structure
 * @return ::PA_TZ_SUCCESS in case of success, ::PA_TZ_GENERAL_ERROR if the
 * structure could not be read
 */
static PaTzResult ParseGafInfo(KernelAddress gaf_virt, GafInfo *gaf);

/**
 * @brief Verify the checksum stored in a GAF structure
 * @param [in] gaf GAF structure
 * @return 1 if the stored checksum matches the calculated one, 0 if it does
 * not
 */
static uint32_t IsChecksumGafOk(const GafInfo *gaf);

/**
 * @brief Calculate the 16-bit checksum of a byte buffer: for every byte the
 * running sum is rotated left by one bit and then XORed with that byte
 * @param [in] data Pointer to buffer
 * @param [in] size Size of data in bytes
 * @return Calculated checksum (0 for an empty buffer)
 */
static uint16_t CalcChecksum(const uint8_t *data, size_t size);

/**
 * @brief Dump the fields of a GAF structure to the verbose log
 * @param [in] gaf GAF structure; must not be NULL
 */
static void GafDump(const GafInfo *gaf);


/**
 * @brief Load a GAF structure from kernel memory and validate it
 *
 * Validation steps, in order: the structure is read via ParseGafInfo(), its
 * size field must equal sizeof(GafInfo), its version must not be lower than
 * kGafVersion (a higher version is only reported), and its checksum must
 * pass IsChecksumGafOk().
 *
 * @param [in] gaf_virt Virtual address of GAF structure for kernel space
 * @param [out] gaf Output GAF structure; must not be NULL
 * @return ::PA_TZ_SUCCESS if the structure was read and passed every check,
 * ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult LoadGafInfo(KernelAddress gaf_virt, GafInfo *gaf) {
  if (gaf == NULL) {
    LOG_E("Invalid arguments\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Fetch the raw structure before looking at any of its fields. */
  const PaTzResult parse_status = ParseGafInfo(gaf_virt, gaf);
  if (parse_status != PA_TZ_SUCCESS) {
    LOG_E("GAF structure can not be read.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* The self-described size has to match this build's layout exactly. */
  if (!(gaf->size == sizeof(GafInfo))) {
    LOG_E("GAF structure has invalid size.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Older versions are refused; newer ones are tolerated but reported. */
  if (gaf->ver < kGafVersion) {
    LOG_E("GAF structure has unsupported version.\n");
    return PA_TZ_GENERAL_ERROR;
  } else if (gaf->ver > kGafVersion) {
    LOG_I("GAF versions does not match.\n");
  }

  if (IsChecksumGafOk(gaf) == 0) {
    LOG_E("GAF structure has incorrect checksum.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/*
 * Fills *gaf from the kernel address gaf_virt and, on success, dumps the
 * result to the verbose log. Any failure reported by KernelAccessGetBytes()
 * is mapped to PA_TZ_GENERAL_ERROR.
 */
static PaTzResult ParseGafInfo(KernelAddress gaf_virt, GafInfo *gaf) {
  LOG_V("Start load GAF INFO\n");

  /* NOTE(review): presumably copies sizeof(GafInfo) bytes starting at
   * gaf_virt into gaf; the meaning of the NULL first argument is defined by
   * kernel_access.h - confirm there. */
  const PaTzResult read_status =
      KernelAccessGetBytes(NULL, gaf_virt, sizeof(GafInfo), gaf);

  if (read_status == PA_TZ_SUCCESS) {
    GafDump(gaf);
    return PA_TZ_SUCCESS;
  }

  LOG_E("Failed get info GAF\n");
  return PA_TZ_GENERAL_ERROR;
}

static void GafDump(const GafInfo *gaf) {
  LOG_V("GAF debug info:\n");
  LOG_V("sizeof(GafInfo): %d\n", sizeof(GafInfo));
  LOG_V("GAF structure:\n");
  LOG_V("ver:                         0x%04X\n", gaf->ver);
  LOG_V("size:                        0x%08X\n", gaf->size);
  LOG_V("state:                       0x%04X\n", gaf->task_struct_struct_state);
  LOG_V("comm:                        0x%04X\n", gaf->task_struct_struct_comm);
  LOG_V("tasks:                       0x%04X\n", gaf->task_struct_struct_tasks);
  LOG_V("pid:                         0x%04X\n", gaf->task_struct_struct_pid);
  LOG_V("mm_struct:                   0x%04X\n", gaf->task_struct_struct_mm);
  LOG_V("pgd:                         0x%04X\n", gaf->mm_struct_struct_pgd);
  LOG_V("mmap:                        0x%04X\n", gaf->mm_struct_struct_mmap);
  LOG_V("mm_rb:                       0x%04X\n", gaf->mm_struct_struct_mm_rb);
  LOG_V("vm_start:                    0x%04X\n", gaf->vm_area_struct_struct_vm_start);
  LOG_V("vm_end:                      0x%04X\n", gaf->vm_area_struct_struct_vm_end);
  LOG_V("vm_next:                     0x%04X\n", gaf->vm_area_struct_struct_vm_next);
  LOG_V("vm_file:                     0x%04X\n", gaf->vm_area_struct_struct_vm_file);
  LOG_V("vm_flags:                    0x%04X\n", gaf->vm_area_struct_struct_vm_flags);
  LOG_V("vm_rb:                       0x%04X\n", gaf->vm_area_struct_struct_vm_rb);
  LOG_V("file_struct_f_path:          0x%04X\n", gaf->file_struct_f_path);
  LOG_V("path_struct_mnt:             0x%04X\n", gaf->path_struct_mnt);
  LOG_V("path_struct_dentry:          0x%04X\n", gaf->path_struct_dentry);
  LOG_V("dentry_struct_d_parent:      0x%04X\n", gaf->dentry_struct_d_parent);
  LOG_V("dentry_struct_d_name:        0x%04X\n", gaf->dentry_struct_d_name);
  LOG_V("qstr_struct_name:            0x%04X\n", gaf->qstr_struct_name);
  LOG_V("qstr_struct_len:             0x%04X\n", gaf->qstr_struct_len);
  LOG_V("struct_mount_mnt_mountpoint: 0x%04X\n", gaf->struct_mount_mnt_mountpoint);
  LOG_V("struct_mount_mnt:            0x%04X\n", gaf->struct_mount_mnt);
  LOG_V("struct_mount_mnt_parent:     0x%04X\n", gaf->struct_mount_mnt_parent);
  LOG_V("list_head_struct_next:       0x%04X\n", gaf->list_head_struct_next);
  LOG_V("list_head_struct_prev:       0x%04X\n", gaf->list_head_struct_prev);
  LOG_V("is_kdp_ns_on:                %d\n", gaf->is_kdp_ns_on);
  LOG_V("task_struct_integrity:       0x%04X\n", gaf->task_struct_integrity);
  LOG_V("proca_task_descr_task:              0x%04X\n", gaf->proca_task_descr_task);
  LOG_V("proca_task_descr_proca_identity:    0x%04X\n", gaf->proca_task_descr_proca_identity);
  LOG_V("proca_task_descr_pid_map_node:      0x%04X\n", gaf->proca_task_descr_pid_map_node);
  LOG_V("proca_task_descr_app_name_map_node: 0x%04X\n", gaf->proca_task_descr_app_name_map_node);
  LOG_V("proca_identity_struct_certificate:      0x%04X\n", gaf->proca_identity_struct_certificate);
  LOG_V("proca_identity_struct_certificate_size: 0x%04X\n", gaf->proca_identity_struct_certificate_size);
  LOG_V("proca_identity_struct_parsed_cert:      0x%04X\n", gaf->proca_identity_struct_parsed_cert);
  LOG_V("proca_identity_struct_file:             0x%04X\n", gaf->proca_identity_struct_file);
  LOG_V("file_struct_f_signature:       0x%04X\n", gaf->file_struct_f_signature);
  LOG_V("proca_table_hash_tables_shift: 0x%04X\n", gaf->proca_table_hash_tables_shift);
  LOG_V("proca_table_pid_map:           0x%04X\n", gaf->proca_table_pid_map);
  LOG_V("proca_table_app_name_map:      0x%04X\n", gaf->proca_table_app_name_map);
  LOG_V("proca_certificate_struct_app_name:      0x%04X\n", gaf->proca_certificate_struct_app_name);
  LOG_V("proca_certificate_struct_app_name_size: 0x%04X\n", gaf->proca_certificate_struct_app_name_size);
  LOG_V("hlist_node_struct_next:      0x%04X\n", gaf->hlist_node_struct_next);
  LOG_V("struct_vfsmount_bp_mount:    0x%04X\n", gaf->struct_vfsmount_bp_mount);
}

/*
 * Folds `size` bytes from `data` into a 16-bit checksum. For each byte the
 * accumulator is rotated left by one bit (bit 15 wraps around to bit 0) and
 * then XORed with the byte. Returns 0 when size is 0; data is not
 * dereferenced in that case.
 */
static uint16_t CalcChecksum(const uint8_t *data, size_t size) {
  uint16_t acc = 0;

  for (size_t idx = 0; idx < size; ++idx) {
    /* Bit shifted out at the top re-enters at the bottom. */
    const uint16_t wrapped_bit = (uint16_t)(acc >> 15);
    acc = (uint16_t)(((acc << 1) | wrapped_bit) ^ data[idx]);
  }

  return acc;
}

/*
 * Recomputes the checksum over every byte of *gaf that precedes the
 * GAFINFOCheckSum field and compares it with the value stored in that field.
 *
 * Returns 1 when both values match and 0 otherwise. A matching checksum of
 * zero is still accepted, but it is reported in the error log because it
 * suggests the checksum may not have been calculated on the kernel side.
 */
static uint32_t IsChecksumGafOk(const GafInfo *gaf) {
  /* Keep the const qualifier: the structure is only read here. */
  uint16_t checksum = CalcChecksum((const uint8_t *)gaf,
                                   offsetof(GafInfo, GAFINFOCheckSum));

  LOG_V("Effective GAF structure checksum: 0x%04x\n", gaf->GAFINFOCheckSum);
  LOG_V("Calculated GAF structure checksum: 0x%04x\n", checksum);

  if (checksum != gaf->GAFINFOCheckSum) {
    LOG_E("GAF checksum isn't valid!\n");
    return 0;
  }

  if (checksum == 0) {
    /* Adjacent literals are concatenated as-is, so the separating space has
     * to be part of the first literal. */
    LOG_E("GAF structure checksum equals 0x00, please check if it is correctly "
          "calculated in this kernel\n");
  }

  LOG_V("GAF structure checksum is valid!\n");

  return 1;
}
