#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <signal.h>
#include <string.h>
#include <unistd.h>

#include <sys/file.h>
#include <sys/stat.h>
#include <sys/un.h>
#include <sys/epoll.h>
#include <sys/signalfd.h>

#include <tee_client_api.h>
#include "hidl_proca.h"
#include "serialize.h"

#include "daemon_log.h"

/* Status codes returned by the daemon's internal helper functions. */
enum {
  /* Generic failure. */
  PA_FAILED = -1,
  /* Operation completed; the daemon keeps running. */
  PA_SUCCESS = 0,
  /* The main loop must stop (stop signal received or TEE ping gave up). */
  PA_TERMINATE = 1
};

/* Tunables of the daemon main loop. */
enum {
  /* Delay passed to alarm() between two TEE pings, i.e. in seconds. */
  kPaTzDrvPingPeriod = 60,
  /* Capacity of the event array handed to epoll_wait(). */
  kMaxEpollEvents = 10,
  /* Number of consecutive ping failures after which PingPaTEE gives up. */
  kMaxPingFailuresToRestart = 3
};

/*
 * The version string is built from two macros that are expected to be
 * supplied from outside this file (presumably by the build system).
 * Compilation is aborted when either of them is missing.
 */
#if defined(VERSION_NAME) && defined(VERSION_SUFFIX)
#define VERSION_STRING (VERSION_NAME VERSION_SUFFIX)
#else
#error "Please, define VERSION_NAME and VERSION_SUFFIX"
#endif

/* Version string reported by PrintVersion(). */
static const char kVersionString[] = VERSION_STRING;

/**
 * @brief Write the daemon version string to the log.
 */
static void PrintVersion(void);

/**
 * @brief Block the signals handled by the daemon and create a signalfd
 *        that delivers SIGHUP, SIGTERM, SIGINT and SIGALRM. Also resets the
 *        process umask to 0.
 * @return fd for signal handling on success, PA_FAILED on error
 */
static int Daemonize(void);

/**
 * @brief Callback registered with HidlServerRun: forwards a client request
 *        to the TEE driver.
 * @param [in] arg Argument that was passed to HidlServerRun; it is used as
 *                 a pointer to the TEE session.
 * @param [in,out] buffer,buffer_size Array that will be shared with Trustlet
 *                 and its size.
 * @return The value returned by SendCommandToTeeDriver: PA_SUCCESS if
 *         success, PA_FAILED for other cases.
 */
static int HandleClientRequest(void *arg, uint8_t *buffer, size_t *buffer_size);

/**
 * @brief Read incoming signal info from sig_poll_fd and handle it
 * @param [in] session Pointer to session TEE.
 * @param [in] sig_poll_fd signals file descriptor
 * @return PA_SUCCESS when the daemon should keep running (this includes a
 *         short or failed read from sig_poll_fd)
 *         PA_TERMINATE when daemon should terminate
 */
static int ProcessIncomingSignals(TEEC_Session *session, int sig_poll_fd);

/**
 * @brief Pings proca TEE to check whether it is alive
 * @param [in] session TrustZone session object
 * @return PA_SUCCESS if the daemon should keep running, PA_TERMINATE when
 *         the ping command cannot be encoded or the ping has failed
 *         kMaxPingFailuresToRestart times in a row
 */
static int PingPaTEE(TEEC_Session *session);

/**
 * @brief Create an epoll instance and register both fds for EPOLLIN
 * @param [in] hidl_fd fd of the HIDL server
 * @param [in] sig_poll_fd fd that delivers signals
 * @return fd of poll object on success, PA_FAILED on error
 */
static int InitEpollEvents(int hidl_fd, int sig_poll_fd);

/**
 * @brief Initialize the TEE context and open a session to the TA.
 * @param [out] context Pointer to context TEE.
 * @param [out] session Pointer to session TEE.
 * @return PA_SUCCESS if success, PA_FAILED for other cases.
 */
static int LoadTeeDriver(TEEC_Context *context, TEEC_Session *session);
/**
 * @brief Close the TEE session and finalize the TEE context.
 * @param [in] context Pointer to context TEE.
 * @param [in] session Pointer to session TEE.
 * @return PA_SUCCESS if success, PA_FAILED when an argument is NULL.
 */
static int UnLoadTeeDriver(TEEC_Context *context, TEEC_Session *session);

/**
 * @brief Send command data to TEE driver.
 * @param [in] session Pointer to session TEE.
 * @param [in] command_id Command identifier passed to TEEC_InvokeCommand.
 * @param [in,out] command_data Pointer to command data; the same buffer is
 *                 shared with the TEE as an in/out memory reference.
 * @param [in,out] command_data_len Size of command data on input; updated
 *                 with the size reported by the TEE when the driver result
 *                 is 0.
 * @return PA_SUCCESS if success, PA_FAILED for other cases.
 */
static int SendCommandToTeeDriver(TEEC_Session *session, uint32_t command_id,
                                  void *command_data,
                                  size_t *command_data_len);

/**
 * @brief Start HIDL server main loop in which it polls HIDL fd events
 * @param [in] session Pointer to session TEE
 * @param [in] sig_poll_fd fd to poll for incoming signals
 * @return The loop only returns when it stops: the negative value returned
 *         by HidlServerRun if the server could not be started, PA_FAILED in
 *         every other case (including termination requested by a signal).
 */
static int HidlServerMainLoop(TEEC_Session *session, int sig_poll_fd);

/**
 * @brief Daemon entry point.
 *
 * Sets up signal handling, opens the connection to the TEE, runs the HIDL
 * server loop until it stops and then releases every acquired resource.
 *
 * @param [in] argc Unused.
 * @param [in] argv Unused.
 * @return Does not return; the process exits with EXIT_FAILURE when the
 *         start-up (signal setup or TEE connection) fails and with
 *         EXIT_SUCCESS otherwise.
 */
int main(int argc, const char** argv) {
  // Command line arguments are not used by the daemon.
  (void)argc;
  (void)argv;

  PrintVersion();
  LOG_I("PA daemon is starting.\n");

  int sig_poll_fd = Daemonize();
  if (sig_poll_fd == PA_FAILED) {
    LOG_E("Daemonize failed.\n");
    exit(EXIT_FAILURE);
  }

  TEEC_Context context = {0};
  TEEC_Session session = {0};
  int exit_status = EXIT_SUCCESS;

  // Open connection to TA (TZ) or NWd driver (nonTZ)
  int checker = LoadTeeDriver(&context, &session);
  if (checker != PA_SUCCESS) {
    LOG_E("PA daemon has not sent the config to the driver.\n");
    LOG_D("Received error: %d.\n", checker);
    // LoadTeeDriver releases whatever it acquired before reporting a
    // failure, so the session/context must not be torn down again here.
    exit_status = EXIT_FAILURE;
  } else {
    LOG_I("PA daemon has sent the config to the driver.\n");

    // The loop returns only when the daemon has to stop; its status carries
    // no additional information for the shutdown sequence.
    (void)HidlServerMainLoop(&session, sig_poll_fd);

    LOG_W("PA daemon has exited from Server Mode.\n");

    if (UnLoadTeeDriver(&context, &session) != PA_SUCCESS) {
      LOG_E("Cannot close connection to TA/Driver.\n");
    }
  }

  close(sig_poll_fd);

  LOG_I("PA Daemon is stopped.\n");

  exit(exit_status);
}

/**
 * @brief Read one pending signal description from sig_poll_fd and react.
 *
 * SIGINT/SIGTERM request termination, SIGALRM triggers a TEE ping, SIGHUP
 * and any other signal are only logged.
 *
 * @param [in] session Pointer to session TEE, used for the ping.
 * @param [in] sig_poll_fd signalfd descriptor to read from.
 * @return PA_TERMINATE when the daemon should stop, the result of PingPaTEE
 *         for SIGALRM, PA_SUCCESS otherwise (including a short/failed read,
 *         which is logged and tolerated).
 */
static int ProcessIncomingSignals(TEEC_Session *session, int sig_poll_fd) {
  struct signalfd_siginfo fdsi;
  ssize_t b_read = read(sig_poll_fd, &fdsi, sizeof(fdsi));
  if (b_read != (ssize_t)sizeof(fdsi)) {
    // Best effort: a bad read must not bring the daemon down.
    LOG_W("Read on signal fd returned less bytes then expected\n");
    LOG_I("Expected %zu, got %zd bytes\n", sizeof(fdsi), b_read);
    return PA_SUCCESS;
  }

  const int signo = (int)fdsi.ssi_signo;

  switch (signo) {
  case SIGHUP:
    LOG_I("Received %s signal.\n", strsignal(signo));
    break;
  case SIGINT:
  case SIGTERM:
    LOG_I("PA Daemon has got signal to stop.\n");
    return PA_TERMINATE;
  case SIGALRM:
    // Periodic alarm armed by the main loop / previous ping.
    return PingPaTEE(session);
  default:
    LOG_W("Unhandled signal %s.\n", strsignal(signo));
    break;
  }
  return PA_SUCCESS;
}

/**
 * @brief Send a ping command to the TEE and re-arm the ping alarm.
 *
 * Consecutive failures are counted in a function-local static counter that
 * is reset by the first successful ping.
 *
 * @param [in] session TrustZone session object.
 * @return PA_TERMINATE when the ping command cannot be encoded or when the
 *         ping has failed kMaxPingFailuresToRestart times in a row,
 *         PA_SUCCESS otherwise.
 */
static int PingPaTEE(TEEC_Session *session) {
  static unsigned int ping_failures_num = 0;

  PaDriverCommand_t command_request = {PaDriverCommand_PR_NOTHING};
  command_request.present = PaDriverCommand_PR_ping;

  uint8_t buff[kSerializedDataMaxSize];
  memset(buff, 0, sizeof(buff));

  // The encoder works with a 32-bit length (the original code passed a
  // uint32_t pointer). Use a real uint32_t object instead of aliasing a
  // size_t, which is wider on 64-bit targets.
  uint32_t encoded_size = (uint32_t)sizeof(buff);

  int result = PaEncodeDriverCommand(&command_request, buff, &encoded_size);
  if (result == -1) {
    LOG_E("Cannot encode command.\n");
    return PA_TERMINATE;
  }

  size_t buff_size = encoded_size;

  int checker = SendCommandToTeeDriver(session, 0, buff, &buff_size);
  if ((unsigned int)checker == TEEC_ERROR_BAD_PARAMETERS) {
    // NOTE(review): SendCommandToTeeDriver, as defined in this file, only
    // returns PA_SUCCESS or PA_FAILED, so this branch looks unreachable -
    // confirm whether the TEE result was meant to be propagated.
    // new alarms will not be triggered any more
    LOG_I("Got bad params error, probably ping command is not supported\n");
    return PA_SUCCESS;
  }

  if (checker != PA_SUCCESS) {
    ++ping_failures_num;

    LOG_D("Failed to ping tz driver for %u times in a row.\n",
          ping_failures_num);
    LOG_D("Received error: %d.\n", checker);

    if (ping_failures_num >= (unsigned int)kMaxPingFailuresToRestart) {
      LOG_E("Failed to ping tz driver for 3 times in a row, terminating.\n");
      return PA_TERMINATE;
    }
  } else {
    ping_failures_num = 0;
  }

  // Schedule the next ping.
  alarm(kPaTzDrvPingPeriod);

  return PA_SUCCESS;
}

static int Daemonize(void) {
  sigset_t blockedSignals;
  memset(&blockedSignals, 0, sizeof(sigset_t));

  sigemptyset(&blockedSignals);
  sigaddset(&blockedSignals, SIGCHLD);
  sigaddset(&blockedSignals, SIGTSTP);
  sigaddset(&blockedSignals, SIGTTOU);
  sigaddset(&blockedSignals, SIGTTIN);

  sigaddset(&blockedSignals, SIGHUP);
  sigaddset(&blockedSignals, SIGTERM);
  sigaddset(&blockedSignals, SIGINT);
  sigaddset(&blockedSignals, SIGALRM);

  sigprocmask(SIG_BLOCK, &blockedSignals, NULL);

  sigset_t pollSignals;

  sigemptyset(&pollSignals);

  sigaddset(&pollSignals, SIGHUP);
  sigaddset(&pollSignals, SIGTERM);
  sigaddset(&pollSignals, SIGINT);
  sigaddset(&pollSignals, SIGALRM);

  int sfd = signalfd(-1, &pollSignals, 0);
  if (sfd == -1) {
    LOG_E("signalfd failed.\n");
    LOG_I("error - %s\n", strerror(errno));
    return PA_FAILED;
  }

  // set newly created file permissions
  umask(000);

  return sfd;
}

/**
 * @brief HIDL request callback: forward the client buffer to the TEE driver.
 * @param [in] arg Argument registered with HidlServerRun; used as a pointer
 *                 to the TEE session.
 * @param [in,out] buffer Request data, shared in place with the TEE.
 * @param [in,out] buffer_size Size of the data in buffer.
 * @return Result of SendCommandToTeeDriver: PA_SUCCESS if success,
 *         PA_FAILED for other cases (including NULL arguments).
 */
static int HandleClientRequest(void *arg, uint8_t *buffer, size_t *buffer_size) {
  TEEC_Session *session = (TEEC_Session *)arg;

  // Argument validation is done by SendCommandToTeeDriver; only avoid
  // dereferencing a NULL size for the log line.
  if (buffer_size) {
    LOG_D("Read bytes from client: %zu.\n", *buffer_size);
  }

  int checker = SendCommandToTeeDriver(session, 0, buffer, buffer_size);
  if (checker != PA_SUCCESS) {
    LOG_E("Failed to send data.\n");
    LOG_D("Received error: %d.\n", checker);
  } else {
    LOG_D("Status of sending data is %d.\n", checker);
  }

  return checker;
}

static int InitEpollEvents(int hidl_fd, int sig_poll_fd) {
  int epoll_fd = epoll_create1(0);
  if (epoll_fd == -1) {
    LOG_E("epoll_create1 failed.\n");
    LOG_I("Error is %s\n", strerror(errno));
    return PA_FAILED;
  }

  struct epoll_event ev;
  ev.events = EPOLLIN;
  ev.data.fd = hidl_fd;

  int ret = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, hidl_fd, &ev);
  if (ret == -1) {
    LOG_E("epoll_ctl for HIDL fd failed.\n");
    LOG_I("Error is %s\n", strerror(errno));
    close(epoll_fd);
    return PA_FAILED;
  }

  ev.events = EPOLLIN;
  ev.data.fd = sig_poll_fd;

  ret = epoll_ctl(epoll_fd, EPOLL_CTL_ADD, sig_poll_fd, &ev);
  if (ret == -1) {
    LOG_E("epoll_ctl for signal fd failed.\n");
    LOG_I("Error is %s\n", strerror(errno));
    close(epoll_fd);
    return PA_FAILED;
  }

  return epoll_fd;
}

/**
 * @brief Run the HIDL server and dispatch epoll events until the daemon
 *        has to stop.
 *
 * Starts the HIDL server, arms the first ping alarm and then waits on the
 * HIDL fd and the signal fd. The loop ends when epoll_wait fails (other
 * than EINTR) or when signal processing returns a non-zero status.
 *
 * @param [in] session Pointer to session TEE.
 * @param [in] sig_poll_fd fd to poll for incoming signals.
 * @return The negative value returned by HidlServerRun when the server
 *         cannot be started; PA_FAILED in every other case, because the
 *         function only returns once the loop has stopped.
 */
static int HidlServerMainLoop(TEEC_Session *session, int sig_poll_fd) {
  int hidl_fd = HidlServerRun(HandleClientRequest, session);
  if (hidl_fd < 0) {
    // Keep errno before logging, the logger may overwrite it.
    const int saved_errno = errno;
    LOG_E("Failed to get polling fd of HIDL server.\n");
    LOG_I("Error is %s\n", strerror(saved_errno));
    return hidl_fd;
  }

  int epoll_fd = InitEpollEvents(hidl_fd, sig_poll_fd);
  if (epoll_fd == PA_FAILED) {
    goto hidl_server_cleanup;
  }

  // Arm the first TEE ping; PingPaTEE re-arms it afterwards.
  alarm(kPaTzDrvPingPeriod);

  struct epoll_event epoll_events[kMaxEpollEvents];

  for (;;) {
    int nfds = epoll_wait(epoll_fd, epoll_events, kMaxEpollEvents, -1);
    if (nfds == -1) {
      const int saved_errno = errno;
      if (saved_errno == EINTR) {
        continue;
      }
      LOG_E("epoll_wait failed.\n");
      LOG_I("Error failed %s\n", strerror(saved_errno));
      goto epoll_cleanup;
    }

    for (int i = 0; i < nfds; ++i) {
      const int ready_fd = epoll_events[i].data.fd;

      if (ready_fd == hidl_fd) {
        // A failed client request is logged but does not stop the daemon.
        if (HidlServerProcessRequest(hidl_fd)) {
          LOG_W("Unable to process HIDL client request\n");
        }
      } else if (ready_fd == sig_poll_fd) {
        // Any non-zero status (PA_TERMINATE or an error) ends the loop.
        if (ProcessIncomingSignals(session, sig_poll_fd)) {
          LOG_I("Exiting from PROCA daemon main loop after signal.\n");
          goto epoll_cleanup;
        }
      } else {
        LOG_W("Got epoll events from unknown fd.\n");
        LOG_I("fd is %d, event is %u\n",
                ready_fd, (unsigned int)epoll_events[i].events);
      }
    }
  }

epoll_cleanup:
  close(epoll_fd);

hidl_server_cleanup:
  HidlServerStop(hidl_fd);

  return PA_FAILED;
}

/**
 * @brief Initialize the TEE context and open a session to the proca TA.
 *
 * The TA UUID is selected at compile time by the QSEE / TEEGRIS / TBASE
 * macro. On failure everything acquired by this function is released
 * before returning, so the caller must not unload again.
 *
 * @param [out] context Pointer to context TEE.
 * @param [out] session Pointer to session TEE.
 * @return PA_SUCCESS if success, PA_FAILED for other cases.
 */
static int LoadTeeDriver(TEEC_Context *context, TEEC_Session *session) {
  if (!context || !session) {
    LOG_E("Invalid arguments.\n");
    return PA_FAILED;
  }

#if defined(QSEE)
  const TEEC_UUID kClientTaUuid = {.timeLow = 0x70726f63, .timeMid = 0x6100,
      .timeHiAndVersion = 0x0000, .clockSeqAndNode = {0x00, 0x00, 0x00, 0x00,
          0x00, 0x00, 0x00, 0x00}};
#elif defined(TEEGRIS)
  const TEEC_UUID kClientTaUuid = {.timeLow = 0x00000000, .timeMid = 0x0000,
      .timeHiAndVersion = 0x0000, .clockSeqAndNode = {0x00, 0x00, 0x00, 0x50,
          0x52, 0x4f, 0x43, 0x41}};
#elif defined(TBASE)
  const TEEC_UUID kClientTaUuid = {.timeLow = 0xffffffff, .timeMid = 0xD000,
      .timeHiAndVersion = 0x0000, .clockSeqAndNode = {0x00, 0x00, 0x00, 0x00,
          0x00, 0x00, 0x00, 0x62}};
#else
#  error "Unknown TEE!"
#endif

  TEEC_Result tee_result = TEEC_InitializeContext(NULL, context);

  if (tee_result != TEEC_SUCCESS) {
    LOG_E("Can't initialize context.\n");
    LOG_D("Received error: 0x%x.\n", tee_result);
    return PA_FAILED;
  }

  uint32_t return_origin = 0;
  // Zero the whole operation: params[1] is declared as an in/out value and
  // would otherwise hand indeterminate stack contents to the TEE.
  TEEC_Operation operation = {0};

  operation.paramTypes = TEEC_PARAM_TYPES(TEEC_MEMREF_TEMP_INOUT,
                                          TEEC_VALUE_INOUT, TEEC_NONE,
                                          TEEC_NONE);

  operation.params[0].tmpref.buffer = NULL;
  operation.params[0].tmpref.size = 0;

  tee_result = TEEC_OpenSession(context, session, &kClientTaUuid, 0, NULL,
                                &operation, &return_origin);
  if (tee_result != TEEC_SUCCESS) {
    LOG_E("Can't open session.\n");
    LOG_D("Received error: 0x%x.\n", tee_result);
    // Undo the successful context initialization above.
    TEEC_FinalizeContext(context);
    return PA_FAILED;
  }

  return PA_SUCCESS;
}

/**
 * @brief Tear down the TEE connection: close the session first, then
 *        finalize the context it belongs to.
 * @param [in] context Pointer to context TEE.
 * @param [in] session Pointer to session TEE.
 * @return PA_SUCCESS if success, PA_FAILED when an argument is NULL.
 */
static int UnLoadTeeDriver(TEEC_Context *context, TEEC_Session *session) {
  const int have_both = (context != NULL) && (session != NULL);

  if (have_both) {
    TEEC_CloseSession(session);
    TEEC_FinalizeContext(context);
    return PA_SUCCESS;
  }

  LOG_E("Invalid arguments.\n");
  return PA_FAILED;
}

/**
 * @brief Send command data to TEE driver.
 *
 * The command buffer is shared with the TEE as an in/out temporary memory
 * reference; the driver status is read back from the second (value output)
 * parameter.
 *
 * @param [in] session Pointer to session TEE.
 * @param [in] command_id Command identifier passed to TEEC_InvokeCommand.
 * @param [in,out] command_data Pointer to command data.
 * @param [in,out] command_data_len Size of command data on input; updated
 *                 with the size reported by the TEE when the driver status
 *                 is 0.
 * @return PA_SUCCESS when the command was invoked, PA_FAILED on NULL
 *         arguments or when TEEC_InvokeCommand fails.
 */
static int SendCommandToTeeDriver(TEEC_Session *session, uint32_t command_id,
                                  void *command_data,
                                  size_t *command_data_len) {
  if (!session || !command_data || !command_data_len) {
    LOG_E("Invalid arguments.\n");
    return PA_FAILED;
  }

  // Zero the whole operation so that no field (e.g. the cancellation state
  // or the unused parameters) carries indeterminate stack contents.
  TEEC_Operation operation = {0};
  operation.paramTypes = TEEC_PARAM_TYPES(TEEC_MEMREF_TEMP_INOUT,
                                          TEEC_VALUE_OUTPUT, TEEC_NONE,
                                          TEEC_NONE);
  operation.params[0].tmpref.buffer = command_data;
  operation.params[0].tmpref.size = *command_data_len;

  uint32_t return_origin = 0;

  TEEC_Result tee_result = TEEC_InvokeCommand(session, command_id, &operation,
                                              &return_origin);
  if (tee_result != TEEC_SUCCESS) {
    LOG_E("Can't send command to TEE driver.\n");
    LOG_D("Received error: 0x%x.\n", tee_result);
    return PA_FAILED;
  }

  uint32_t result = operation.params[1].value.a;

  LOG_V("Result command from driver: %u.\n", result);

  // NOTE(review): a non-zero driver status is only logged and the call is
  // still reported as PA_SUCCESS with the length left untouched - confirm
  // that callers rely on this.
  if (result == 0) {
    *command_data_len = operation.params[0].tmpref.size;
  }

  return PA_SUCCESS;
}

/**
 * @brief Log the daemon version string (kVersionString) through TEES_Log
 *        at the critical log level.
 */
static void PrintVersion(void) {
  TEES_Log(TEES_LOG_LEVEL_CRITICAL,
           __func__,
           __LINE__,
           PLATFORM_LOG_TAG,
           "Version: %s",
           kVersionString);
}
