#include "task_parser.h"

#include "config.h"
#include "crypto.h"
#include "driver_log.h"
#include "gaf.h"
#include "kernel_access.h"
#include "task_access.h"
#include "serialize.h"
#include "memory.h"
#include "pa_certificate.h"

#include <tee_internal_api.h>

#include <scl/arithmetic.h>

/** File-local size and limit constants. */
enum {
  kOutputXattrHeaderSize = sizeof(uint64_t), //!< Equals sizeof(uint64_t)
  kDentryMaxDepth = 16,                      //!< Dentry depth limit
  kDentryNameMaxSize = 256,                  //!< Dentry name size limit
  kGetFiveCertificateAttempts = 100,         //!< Polls of f_signature before giving up
};

/**
 * @brief Kind of FIVE signature stored in the signature header.
 */
typedef enum {
  kRsa = 3,  //!< RSA (asymmetric key) signature
  kHmac = 5, //!< HMAC signature
} FiveSignatureType;

/**
 * @brief Hash algorithm used by an HMAC FIVE signature.
 */
typedef enum {
  kFiveSha1 = 2,   //!< SHA-1
  kFiveSha256 = 4, //!< SHA-256
  kFiveSha512 = 6, //!< SHA-512
} FiveHmacAlgo;

/**
 * @brief Format of FIVE signature
 *
 * Every struct here is packed, so members follow each other with no padding.
 * The trailing zero-length arrays (a GNU C extension) mark where the
 * variable-length payload that follows the header begins; they add nothing
 * to sizeof.
 * NOTE(review): presumably @c type holds a ::FiveSignatureType value that
 * selects the @c hmac or @c rsa member - confirm against the kernel format.
 */
typedef struct __attribute__((packed)) {
  uint8_t type; //!< xattr type
  union {
    //! Header of an HMAC signature
    struct __attribute__((packed)) {
      uint8_t algo; //!< Hash algorithm; presumably a ::FiveHmacAlgo value
      uint16_t label_size; //!< Label size; NOTE(review): byte order not documented
      uint8_t digest[0]; //!< Start of the bytes that follow the HMAC header
    } hmac;
    struct __attribute__((packed)) {
      uint8_t version; //!< Signature format version
      uint8_t hash_algo; //!< Digest algorithm [enum pkey_hash_algo]
      uint32_t keyid; //!< IMA key identifier - not X509/PGP specific
      uint16_t sig_size; //!< Signature size in BigEndian format
      uint8_t sig[0]; //!<  Signature payload
    } rsa; //!< signature format v2 - for using with asymmetric keys
  };
} FiveSignature;

/**
 * @brief Length/value record: a 16-bit length immediately followed by the
 *        value bytes.
 *
 * Packed, so the value starts right after the 2-byte length field.
 */
typedef struct __attribute__((packed)) {
  uint16_t length; //!< Number of bytes in @c value
  uint8_t value[]; //!< Payload (flexible array member)
} lv;

/**
 * @brief Get kernel certificate structure which is located in kernel space
 *
 * Reads the app_name pointer and the app_name_size field of the kernel PROCA
 * certificate structure at @p address, using the GAF offsets from the current
 * configuration.
 *
 * @param [in] context Kernel access context struct used for the reads
 * @param [in] address Address in kernel address space
 * @param [out] out_cert Parsed certificate; written only when both reads
 *              succeed
 * @return ::PA_TZ_SUCCESS in case of success, otherwise an error code
 *         (e.g. ::PA_TZ_GENERAL_ERROR)
 */
static PaTzResult ParseKernelProcaCertificate(KernelAccessInfo *context,
                                     KernelAddress address,
                                     KernelProcaCertificate *out_cert);

/**
 * @brief Get PROCA identity structure which is located in kernel space
 *
 * Reads certificate_size, certificate, file and the embedded parsed
 * certificate of the kernel PROCA identity at @p address. It initializes and
 * deinitializes its own kernel access context.
 *
 * @param [in] address Address in kernel address space
 * @param [out] out_ident PROCA identity structure. Left untouched when the
 *              identity has no file, certificate or certificate size; the
 *              function can still return ::PA_TZ_SUCCESS in that case.
 * @return ::PA_TZ_SUCCESS in case of success, otherwise an error code
 *         (e.g. ::PA_TZ_GENERAL_ERROR)
 */
static PaTzResult ParsePaIdentityStruct(KernelAddress address,
                                        PaIdentity *out_ident);

/**
 * @brief Parse the basic fields of a kernel task_struct.
 *
 * Reads state and pid from the task_struct at @p virt, then follows task->mm
 * to read pgd, mmap and mm_rb. All offsets come from the GAF section of the
 * current configuration.
 *
 * @param [in] virt Kernel virtual address of the task_struct
 * @param [out] out_task Receives the parsed task; written only on success
 * @return ::PA_TZ_SUCCESS on success; ::PA_TZ_GENERAL_ERROR for invalid
 *         arguments or a task whose mm is NULL (kernel thread); otherwise the
 *         error reported by kernel access
 */
PaTzResult TaskParseBasicInfo(KernelAddress virt, TaskInfo *out_task) {
  if (out_task == NULL) {
    LOG_E("Can't be NULL pointer to out_task.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (!IsKernelVirtualAddress(virt)) {
    LOG_E("It is not kernel address.\n");
    LOGADDR_D("Address: ", virt);
    return PA_TZ_GENERAL_ERROR;
  }

  const GafInfo *const offsets = &GetConfig()->gaf;
  KernelAccessInfo access = {0};
  TaskInfo parsed = {0};
  PaTzResult deinit_status = PA_TZ_SUCCESS;

  parsed.virt = virt;

  PaTzResult status = KernelAccessInit(&access);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't initialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", status);
    return status;
  }

  status = KernelAccessGetUint32(&access,
                                 virt + offsets->task_struct_struct_state,
                                 &parsed.state);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.state.\n");
    LOG_D("Received result: 0x%08x\n", status);
    goto deinit;
  }

  status = KernelAccessGetUint32(&access,
                                 virt + offsets->task_struct_struct_pid,
                                 &parsed.pid);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.pid.\n");
    LOG_D("Received result: 0x%08x\n", status);
    goto deinit;
  }

  status = KernelAccessGetPointer(&access,
                                  virt + offsets->task_struct_struct_mm,
                                  &parsed.mm);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.mm.\n");
    LOG_D("Received result: 0x%08x\n", status);
    goto deinit;
  }

  LOGADDR_V("\tmm: ", parsed.mm);
  if (parsed.mm == 0) {
    // Without an mm_struct there is no pgd/mmap/mm_rb to read.
    LOG_V("This is kernel thread: task.mm == NULL\n");
    status = PA_TZ_GENERAL_ERROR;
    goto deinit;
  }

  status = KernelAccessGetPointer(&access,
                                  parsed.mm + offsets->mm_struct_struct_pgd,
                                  &parsed.pgd);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.mm_struct_struct_pgd\n");
    LOG_D("Received result: 0x%08x\n", status);
    goto deinit;
  }

  status = KernelAccessGetPointer(&access,
                                  parsed.mm + offsets->mm_struct_struct_mmap,
                                  &parsed.mmap);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.mmap\n");
    LOG_D("Received result: 0x%08x\n", status);
    goto deinit;
  }

  status = KernelAccessGetBytes(&access,
                                parsed.mm + offsets->mm_struct_struct_mm_rb,
                                sizeof(parsed.mm_rb),
                                &parsed.mm_rb);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.mm_rb\n");
    LOG_D("Received result: 0x%08x\n", status);
    goto deinit;
  }

  LOGADDR_V("\tpgd: ", parsed.pgd);
  LOGADDR_V("\tmmap: ", parsed.mmap);
  status = PA_TZ_SUCCESS;

deinit:
  // The context is always released, whatever happened above.
  deinit_status = KernelAccessDeinit(&access);
  if (deinit_status != PA_TZ_SUCCESS) {
    LOG_E("Can't deinitialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", deinit_status);
  }

  if (status != PA_TZ_SUCCESS) {
    return status;
  }

  *out_task = parsed;
  LOG_V("Task basic info: pid: %u state: %u\n", parsed.pid, parsed.state);
  LOGADDR_V("\ttask.virt: ", parsed.virt);

  return status;
}

/**
 * @brief Read the integrity value of a task.
 *
 * Follows task_struct->integrity (a pointer to the kernel task_integrity
 * struct) of the task at out_task->virt and reads that struct's first 32-bit
 * field ("value", offset 0).
 *
 * @param [in,out] out_task Task to update. Only its @c virt field is used as
 *                 input. On success @c integrity is filled in; on failure
 *                 the struct is left unchanged.
 * @return ::PA_TZ_SUCCESS on success; ::PA_TZ_GENERAL_ERROR for a NULL
 *         argument or a NULL integrity pointer; otherwise the error reported
 *         by kernel access
 */
PaTzResult TaskParseIntegrity(TaskInfo *out_task) {
  if (out_task == NULL) {
    LOG_E("Can't be NULL pointer to out_task.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  const GafInfo *gaf = &GetConfig()->gaf;
  KernelAccessInfo context = {0};
  TaskInfo task = *out_task;

  PaTzResult result = KernelAccessInit(&context);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Can't initialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", result);
    return result;
  }

  do {
    // task_struct_integrity is pointer to kernel task_integrity struct
    KernelAddress integrity = 0;
    result = KernelAccessGetPointer(&context, task.virt + gaf->task_struct_integrity,
                                    &integrity);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.integrity.\n");
      LOG_D("Received result: 0x%08x\n", result);
      break;
    }

    // A NULL pointer cannot be followed. Fail explicitly instead of asking
    // kernel access to read from address 0.
    if (integrity == 0) {
      LOG_E("task.integrity pointer is NULL.\n");
      result = PA_TZ_GENERAL_ERROR;
      break;
    }

    // "integrity" variable has address of task_integrity struct
    // struct task_integrity {
    //   atomic_t value;
    //   atomic_t count;
    // .... };
    // The only "value" field is interesting for us. It has offset 0.
    result = KernelAccessGetUint32(&context, integrity, &task.integrity);
    if (result != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.integrity.value\n");
      LOG_D("Received result: 0x%08x\n", result);
      break;
    }

    LOG_V("\tintegrity: %u\n", (uint32_t)task.integrity);

    result = PA_TZ_SUCCESS;

  } while (0);

  PaTzResult result_deinit = KernelAccessDeinit(&context);
  if (result_deinit != PA_TZ_SUCCESS) {
    LOG_E("Can't deinitialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", result_deinit);
  }

  if (result == PA_TZ_SUCCESS) {
    *out_task = task;
  }

  return result;
}

/**
 * @brief Read a kernel vm_area_struct.
 *
 * Reads vm_start, vm_end, vm_next, vm_file, vm_flags and vm_rb, in that
 * order, from the vm_area_struct at @p virt using the GAF offsets of the
 * current configuration. Reading stops at the first failure.
 *
 * @param [in] virt Start kernel address of VMA structure
 * @param [out] out_vma Receives the parsed VMA; written only on success
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR if @p out_vma is
 *         NULL, otherwise the error reported by kernel access
 */
PaTzResult TaskParseVmaInfo(KernelAddress virt, VmaInfo *out_vma) {
  if (out_vma == NULL) {
    LOG_E("Can't be NULL pointer to out_vma\n");
    return PA_TZ_GENERAL_ERROR;
  }

  KernelAccessInfo access = {0};

  PaTzResult status = KernelAccessInit(&access);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't initialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", status);
    return status;
  }

  const GafInfo *const offsets = &GetConfig()->gaf;
  VmaInfo parsed = {0};

  // Each read below is attempted only while all previous reads succeeded.
  status = KernelAccessGetPointer(&access,
                                  virt + offsets->vm_area_struct_struct_vm_start,
                                  &parsed.vm_start);
  if (status != PA_TZ_SUCCESS) {
    LOG_E("Can't get task.vm_area_struct_struct_vm_start\n");
    LOG_D("Received result: 0x%08x\n", status);
  }

  if (status == PA_TZ_SUCCESS) {
    status = KernelAccessGetPointer(&access,
                                    virt + offsets->vm_area_struct_struct_vm_end,
                                    &parsed.vm_end);
    if (status != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.vm_area_struct_struct_vm_end\n");
      LOG_D("Received result: 0x%08x\n", status);
    }
  }

  if (status == PA_TZ_SUCCESS) {
    status = KernelAccessGetPointer(&access,
                                    virt + offsets->vm_area_struct_struct_vm_next,
                                    &parsed.vm_next);
    if (status != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.vm_area_struct_struct_vm_next\n");
      LOG_D("Received result: 0x%08x\n", status);
    }
  }

  if (status == PA_TZ_SUCCESS) {
    status = KernelAccessGetPointer(&access,
                                    virt + offsets->vm_area_struct_struct_vm_file,
                                    &parsed.vm_file);
    if (status != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.vm_area_struct_struct_vm_file\n");
      LOG_D("Received result: 0x%08x\n", status);
    }
  }

  if (status == PA_TZ_SUCCESS) {
    status = KernelAccessGetUint64(&access,
                                   virt + offsets->vm_area_struct_struct_vm_flags,
                                   &parsed.vm_flags);
    if (status != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.gaf->vm_area_struct_struct_vm_flags\n");
      LOG_D("Received result: 0x%08x\n", status);
    }
  }

  if (status == PA_TZ_SUCCESS) {
    status = KernelAccessGetBytes(&access,
                                  virt + offsets->vm_area_struct_struct_vm_rb,
                                  sizeof(parsed.vm_rb),
                                  &parsed.vm_rb);
    if (status != PA_TZ_SUCCESS) {
      LOG_E("Can't get task.gaf->vm_area_struct_struct_vm_rb.\n");
      LOG_D("Received result: 0x%08x.\n", status);
    }
  }

  if (status == PA_TZ_SUCCESS) {
    LOGADDR_V("VMA info: virt: ", virt);
    LOGADDR_V("\tvirt.vm_next: ", parsed.vm_next);
    LOGADDR_V("\tvma.vm_start: ", parsed.vm_start);
    LOGADDR_V("\tvma.vm_end: ", parsed.vm_end);
    LOG_V("\tvma.size: 0x%08x\n", (uint32_t)(parsed.vm_end - parsed.vm_start));
    LOG_V("\tvma.vm_flags: 0x%08x%08x\n",
          (uint32_t)(parsed.vm_flags >> 32), (uint32_t)parsed.vm_flags);
    LOGADDR_V("\tvma.vm_file: ", parsed.vm_file);
  }

  const PaTzResult deinit_status = KernelAccessDeinit(&access);
  if (deinit_status != PA_TZ_SUCCESS) {
    LOG_E("Can't deinitialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", deinit_status);
  }

  if (status == PA_TZ_SUCCESS) {
    *out_vma = parsed;
  }

  return status;
}

/**
 * @brief Read a task's PA certificate from kernel memory and decode it.
 *
 * Copies task->pa_identity.certificate_size bytes from the kernel address
 * task->pa_identity.certificate into a stack buffer of kMaxPaSignatureLength
 * bytes and passes them to PaDecodeCertificate().
 *
 * @param [in] task Task whose pa_identity has already been filled in
 * @param [out] certificate Receives the certificate decoded by
 *              PaDecodeCertificate()
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR otherwise
 *         (invalid arguments, missing or oversized certificate, read failure
 *         or decode failure)
 */
PaTzResult TaskParsePaCertificate(const TaskInfo *task,
                                  PaCertificate_t **certificate) {
  if (!task || !certificate) {
    LOG_E("Invalid argument.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // The address and size were read from kernel memory and are untrusted.
  // The size is used as the copy length into a fixed-size stack buffer, so it
  // must be validated before the read to avoid a stack buffer overflow.
  const uint64_t xattr_size = task->pa_identity.certificate_size;
  if (task->pa_identity.certificate == 0 || xattr_size == 0 ||
      xattr_size > kMaxPaSignatureLength) {
    LOG_E("Invalid PA certificate location or size.\n");
    LOG_D("Certificate size: 0x%08x%08x\n",
          (uint32_t)(xattr_size >> 32), (uint32_t)xattr_size);
    return PA_TZ_GENERAL_ERROR;
  }

  uint8_t memory[kMaxPaSignatureLength];
  PaTzResult result = KernelAccessGetBytes(NULL, task->pa_identity.certificate,
      xattr_size, memory);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Can not read extra data from task.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  int checker = PaDecodeCertificate(memory, xattr_size, certificate);
  // Validate the decoder's output (*certificate), not an unrelated field of
  // the input task.
  if (checker == -1 || *certificate == NULL) {
    LOG_E("Failed decoding certificate.\n");
    LOG_D("Received result: %d\n", checker);
    return PA_TZ_GENERAL_ERROR;
  }

  return PA_TZ_SUCCESS;
}

/**
 * @brief Read the serialized FIVE signature attached to a kernel file struct.
 * @param [in] file_struct Kernel address of the file struct to inspect
 * @param [out] signature Destination buffer for the serialized signature.
 *              NOTE(review): no capacity is passed in, so the buffer is
 *              assumed to hold at least kSerializedDataMaxSize bytes (the
 *              upper bound on what is copied) - confirm with callers.
 * @param [out] signature_len Number of bytes written to @p signature
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR otherwise
 *         (including a signature whose records do not fit the read buffer)
 * @note The kernel sets the FIVE signature in the file struct asynchronously,
 * from a special kworker, so f_signature may not be set yet when this runs.
 * To reduce the chance of this problem, the f_signature pointer is polled up
 * to kGetFiveCertificateAttempts times, waiting 1 ms (TEE_Wait) between
 * attempts, and an error is returned only after that timeout. The maximum
 * delay observed on Star2 / Android P was 20 ms, so the threshold is set to
 * 100 ms. That is enough for authentication; provisioning should be retried
 * by Normal World code in case of error.
 */
PaTzResult TaskParseFiveSignature(KernelAddress file_struct, void *signature, size_t *signature_len) {
  if (!signature || !signature_len) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // Number of consecutive ::lv records that make up a serialized signature.
  enum { kFiveSignatureLvCount = 4 };

  KernelAddress f_signature = 0;
  PaTzResult result = PA_TZ_GENERAL_ERROR;

  for (uint16_t attempt = 1; attempt <= kGetFiveCertificateAttempts; attempt++) {
    result = KernelAccessGetPointer(NULL,
        file_struct + GetConfig()->gaf.file_struct_f_signature,
        &f_signature);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Cannot read pointer on FIVE signature.\n");
      return PA_TZ_GENERAL_ERROR;
    }

    if (f_signature != 0) {
      break;
    }
    LOG_D("f_signature is NULL. Try again.\n");
    TEE_Wait(1);
  }

  LOGADDR_V("file_struct.f_signature: ", f_signature);
  if (f_signature == 0) {
    LOG_D("f_signature is NULL.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  uint8_t buffer[kSerializedDataMaxSize];
  result = KernelAccessGetBytes(NULL, f_signature, sizeof(buffer), buffer);
  if (PA_TZ_SUCCESS != result) {
    LOG_E("Cannot get FIVE signature.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // The buffer holds kFiveSignatureLvCount back-to-back lv records. Their
  // lengths come from kernel memory, so each header and value must be checked
  // to lie inside the buffer before it is consumed. Invariant:
  // buf_size <= sizeof(buffer), so the subtractions below cannot wrap.
  const size_t header_size = sizeof(lv);
  size_t buf_size = 0;
  for (size_t i = 0; i < kFiveSignatureLvCount; i++) {
    if (sizeof(buffer) - buf_size < header_size) {
      LOG_E("FIVE signature is malformed.\n");
      return PA_TZ_GENERAL_ERROR;
    }

    // Copy the length out instead of dereferencing a cast lv pointer into the
    // byte buffer (strict-aliasing / alignment hazard). lv.length is the first
    // member of the record.
    uint16_t value_len = 0;
    TEE_MemMove(&value_len, buffer + buf_size, sizeof(value_len));
    buf_size += header_size;

    if (sizeof(buffer) - buf_size < value_len) {
      LOG_E("FIVE signature is malformed.\n");
      return PA_TZ_GENERAL_ERROR;
    }
    buf_size += value_len;
  }

  TEE_MemMove(signature, buffer, buf_size);
  *signature_len = buf_size;

  // size_t must not be passed for %d; cast to a type matching the specifier.
  LOG_V("Five signature size is: %u\n", (uint32_t)*signature_len);
  LOG_V("Five signature is:\n");
  LOGM_V(signature, *signature_len);

  return PA_TZ_SUCCESS;
}

/**
 * @brief Read the app_name fields of a kernel PROCA certificate structure.
 * @param [in] context Kernel access context passed through to the readers
 * @param [in] address Kernel address of the certificate structure
 * @param [out] out_cert Receives app_name and app_name_size; written only
 *              when both reads succeed
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR for a NULL
 *         @p out_cert, otherwise the error reported by kernel access
 */
static PaTzResult ParseKernelProcaCertificate(KernelAccessInfo *context,
                                     KernelAddress address,
                                     KernelProcaCertificate *out_cert) {
  if (out_cert == NULL) {
    LOG_E("Can't be NULL pointer to out_cert.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  const GafInfo *gaf = &GetConfig()->gaf;

  KernelProcaCertificate cert = {0};

  PaTzResult result = KernelAccessGetPointer(context,
                          address + gaf->proca_certificate_struct_app_name,
                          &cert.app_name);
  if (PA_TZ_SUCCESS != result) {
    LOG_E("Cannot read pointer on cert.app_name\n");
    return result;
  }

  // Read into a genuine uint64_t and assign it. Writing 8 bytes through a
  // (uint64_t *) cast would corrupt memory if the field is declared narrower
  // (e.g. a 32-bit size_t).
  uint64_t app_name_size = 0;
  result = KernelAccessGetUint64(context,
                address + gaf->proca_certificate_struct_app_name_size,
                &app_name_size);
  if (PA_TZ_SUCCESS != result) {
    LOG_E("Cannot read uint64 on cert.app_name_size\n");
    return result;
  }
  cert.app_name_size = app_name_size;

  *out_cert = cert;

  return result;
}

/**
 * @brief Read a kernel PROCA identity structure.
 *
 * Reads certificate_size, certificate, file and the embedded parsed
 * certificate of the PROCA identity at @p address, using a kernel access
 * context owned by this function.
 *
 * @param [in] address Kernel address of the PROCA identity
 * @param [out] out_ident Receives the identity. Written only when every read
 *              succeeded AND file, certificate and certificate_size are all
 *              non-zero; otherwise left untouched.
 * @return ::PA_TZ_SUCCESS when all reads succeeded (even if the identity was
 *         empty and @p out_ident was not written); ::PA_TZ_GENERAL_ERROR for
 *         a NULL @p out_ident; otherwise the error reported by kernel access
 */
static PaTzResult ParsePaIdentityStruct(KernelAddress address,
                                        PaIdentity *out_ident) {
  if (out_ident == NULL) {
    LOG_E("Can't be NULL pointer to out_ident.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  const GafInfo *gaf = &GetConfig()->gaf;

  PaIdentity pa_identity = {0};

  KernelAccessInfo context = {0};
  PaTzResult result = KernelAccessInit(&context);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Can't initialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", result);
    return result;
  }

  do {
    // Read into a genuine uint64_t instead of writing through a cast pointer
    // into a field whose declared width may differ.
    uint64_t certificate_size = 0;
    result = KernelAccessGetUint64(&context,
                  address + gaf->proca_identity_struct_certificate_size,
                  &certificate_size);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Cannot read pa_identity.certificate_size uint64\n");
      break;
    }
    pa_identity.certificate_size = certificate_size;

    result = KernelAccessGetPointer(&context,
                  address + gaf->proca_identity_struct_certificate,
                  &pa_identity.certificate);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Cannot read pa_identity.certificate pointer\n");
      break;
    }

    result = KernelAccessGetPointer(&context,
                  address + gaf->proca_identity_struct_file,
                  &pa_identity.file);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Cannot read pa_identity.file pointer\n");
      break;
    }

    result = ParseKernelProcaCertificate(&context,
                  address + gaf->proca_identity_struct_parsed_cert,
                  &pa_identity.parsed_cert);
    if (PA_TZ_SUCCESS != result) {
      LOG_E("Cannot read pa_identity.app_name struct\n");
      break;
    }
  } while (0);

  PaTzResult result_deinit = KernelAccessDeinit(&context);
  if (result_deinit != PA_TZ_SUCCESS) {
    LOG_E("Can't deinitialize kernel access context.\n");
    LOG_D("Received result: 0x%08x\n", result_deinit);
  }

  if (pa_identity.file == 0 ||
      pa_identity.certificate == 0 ||
      pa_identity.certificate_size == 0) {
    LOG_D("Can't get proca_dentity.file or certificate. "
          "Received result: 0x%08x\n", result);
  } else if (result == PA_TZ_SUCCESS) {
    // Never publish a partially parsed identity after a failed read.
    *out_ident = pa_identity;
  }

  return result;
}

/**
 * @brief Parse a kernel PROCA task descriptor.
 *
 * Reads the pid-map and app-name-map links, the task pointer, the task's pid
 * and state, and the embedded PROCA identity of the descriptor at
 * @p address, using the GAF offsets from the current configuration.
 *
 * @param [in] address Kernel address of the PROCA task descriptor
 * @param [out] out_descr Receives the parsed descriptor; written only on
 *              success
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR otherwise
 *         (including when the descriptor's task pointer is NULL)
 */
PaTzResult TaskParsePaTaskDescriptorStruct(KernelAddress address,
                                           TaskDescriptor *out_descr) {
  if (!out_descr) {
    LOG_W("out_descr is null.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  TaskDescriptor descr = {0};

  descr.task_descr_address = address;

  const GafInfo *gaf = &GetConfig()->gaf;

  PaTzResult result = KernelAccessGetPointer(NULL,
      address + gaf->proca_task_descr_pid_map_node,
      &descr.map_next_ptrs[kTaskDescriptorPidMapNext]);
  if (result != PA_TZ_SUCCESS) {
    LOGADDR_V("Cannot get descr.map_next_ptrs[kTaskDescriptorPidMapNext]: ",
              address + gaf->proca_task_descr_pid_map_node);
    return PA_TZ_GENERAL_ERROR;
  }

  result = KernelAccessGetPointer(NULL,
      address + gaf->proca_task_descr_app_name_map_node,
      &descr.map_next_ptrs[kTaskDescriptorAppNameMapNext]);
  if (result != PA_TZ_SUCCESS) {
    LOGADDR_V("Cannot get descr.map_next_ptrs[kTaskDescriptorAppNameMapNext]: ",
              address + gaf->proca_task_descr_app_name_map_node);
    return PA_TZ_GENERAL_ERROR;
  }

  result = KernelAccessGetPointer(NULL, address + gaf->proca_task_descr_task,
                                  &descr.task_address);
  if (result != PA_TZ_SUCCESS) {
    LOGADDR_V("Cannot get descr.task_address: ",
              address + gaf->proca_task_descr_task);
    return PA_TZ_GENERAL_ERROR;
  }

  if (!descr.task_address) {
    LOG_V("Task address is descr.task_address is NULL.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  // pid and state are 32-bit in the kernel task_struct. Read them into real
  // uint32_t locals and assign, instead of writing through (uint32_t *) casts
  // of fields whose declared types differ (aliasing and width hazards).
  uint32_t pid = 0;
  result = KernelAccessGetUint32(NULL,
                                 descr.task_address + gaf->task_struct_struct_pid,
                                 &pid);
  if (result != PA_TZ_SUCCESS) {
    LOGADDR_V("Cannot get descr.pid: ",
              descr.task_address + gaf->task_struct_struct_pid);
    return PA_TZ_GENERAL_ERROR;
  }
  descr.pid = pid;

  uint32_t state = 0;
  result = KernelAccessGetUint32(NULL,
                                 descr.task_address + gaf->task_struct_struct_state,
                                 &state);
  if (result != PA_TZ_SUCCESS) {
    LOGADDR_V("Cannot get descr.state: ",
              descr.task_address + gaf->task_struct_struct_state);
    return PA_TZ_GENERAL_ERROR;
  }
  descr.state = state;

  result = ParsePaIdentityStruct(
              address + gaf->proca_task_descr_proca_identity,
              &descr.pa_identity);
  if (result != PA_TZ_SUCCESS) {
    LOGADDR_V("Cannot get descr.pa_identity: ",
              address + gaf->proca_task_descr_proca_identity);
    return PA_TZ_GENERAL_ERROR;
  }

  *out_descr = descr;

  return result;
}
