#include "pa_provision.h"

#include "nwd_log.h"
#include "hidl_proca.h"
#include "xattr.h"
#include "serialize.h"

#include <fcntl.h>
#include <linux/limits.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// Name of the extended attribute that holds the signed PA certificate of a file.
// Its presence is what PaNewCertificate() checks to decide whether provisioning is needed.
static const char kPaXattrName[] = "user.pa";

/**
 * @brief Inner structure to contain information about mapped file
 *
 * Filled in by FileMap() on success and released with FileUnmap().
 */
typedef struct {
  int fd;             // Descriptor the mapping was created from
  void *mapped;       // Start of the read-only (PROT_READ) mapping returned by mmap()
  size_t mapped_size; // Length of the mapping in bytes (at most kPageSize)
} PaFileInfo;

enum {
  kPackageNameMaxSize = 1024,           // Max number of package name bytes sent to the driver
  kPageSize = 4096,                     // Upper bound on how much of the file FileMap() maps
  kProvisioningMaxAttempts = 3,         // How many times the daemon is asked for a certificate
  kProvisioningTimeoutUSec = 50 * 1000, // 50 ms; delay (usleep) between provisioning attempts
};

/**
 * @brief Send buffer with command to PA daemon and receive response
 *
 * The exchange is performed through HidlNewCertificate().
 *
 * @param [in] command,command_size Serialized request buffer
 * @param [out] response Buffer that receives the serialized response
 * @param [in,out] response_size Set to the number of bytes written to
 *        @p response on success. The caller initializes it to the capacity
 *        of @p response.
 * @return ::PA_SUCCESS code in case of success and error code otherwise
 */
static PaResult ExchangeCommandViaHidl(const void *command, size_t command_size,
                                         void *response, size_t *response_size);

/**
 * @brief Send command to PA driver via daemon
 *
 * Encodes @p command_request, exchanges it over HIDL and decodes the reply.
 *
 * @param [in] command_request Structure with command request
 * @param [out] command_response Receives the decoded response; the caller
 *        releases it with PaFreeDriverCommandResponse()
 * @return ::PA_SUCCESS code in case of success and error code otherwise
 */
static PaResult SendCommandToDaemon(const PaDriverCommand_t *command_request,
                                    PaDriverCommandResponse_t **command_response);

/**
 * @brief Map file to memory
 *
 * Only the beginning of the file is mapped (at most kPageSize bytes), read-only.
 *
 * @param [in] fd File descriptor
 * @param [out] file_info Structure with information about mapped file;
 *        release it with FileUnmap()
 * @return ::PA_SUCCESS code in case of success and error code otherwise
 */
static PaResult FileMap(int fd, PaFileInfo *file_info);

/**
 * @brief Unmap previously mapped file
 *
 * A munmap() failure is only logged; nothing is reported to the caller.
 *
 * @param [in] file_info Structure filled in by FileMap()
 */
static void FileUnmap(const PaFileInfo *file_info);

/**
 * @brief Write user.pa to xattr of file
 *
 * XattrUserPaFcntl() is tried first and XattrFdWrite() is used as a fallback.
 * After a successful write, the presence of the attribute is re-checked with
 * XattrFdHasAttr().
 *
 * @param [in] fd File descriptor
 * @param [in] certificate Signed PROCA certificate
 * @param [in] certificate_size Size of certificate buffer
 * @return ::PA_SUCCESS code in case of success and error code otherwise
 */
static PaResult WriteCertificateToXattr(int fd, const void *certificate, size_t certificate_size);


PaResult PaNewCertificate(int fd, const char *package_name, const uint8_t *rsa, size_t rsa_size) {
  if ((fd < 0) || !package_name || !rsa || (rsa_size == 0)) {
    LOG_E("Invalid arguments.\n");
    return PA_INVALID_ARGUMENTS;
  }

  LOG_V("Fd is %d.\n", fd);
  LOG_V("Package name is %s.\n", package_name);

  PaResult result = XattrFdHasAttr(fd, kPaXattrName, NULL);
  if (result == PA_SUCCESS) {
    LOG_D("PA certificate is already present. Provisioning is not needed.\n");
    return PA_ALREADY_PROVISIONED;
  }

  PaHandler handler = {0};
  result = PaHandlerCreate(&handler);
  if (result != PA_SUCCESS) {
    LOG_E("Cannot obtain PA HANDLER.\n");
    return result;
  }

  PaFileInfo apk_file = {0};
  result = FileMap(fd, &apk_file);
  if (result != PA_SUCCESS) {
    LOG_E("Mapping failed.\n");
    LOG_D("fd: %d.\n", fd);
    PaHandlerDestroy(&handler);
    return result;
  }

  PaDriverCommand_t command_request = {PaDriverCommand_PR_NOTHING};

  command_request.present = PaDriverCommand_PR_newCertificate;

  PaDriverCommandNewCertificate_t *certificate_new = &command_request.choice
      .newCertificate;
  certificate_new->handler.buf = (void *)&handler;
  certificate_new->handler.size = sizeof(handler);

  certificate_new->apkAddress.buf = (void *)&apk_file.mapped;
  certificate_new->apkAddress.size = sizeof(&apk_file.mapped);

  certificate_new->packageName.buf = (void *)package_name;
  certificate_new->packageName.size = strnlen(package_name, kPackageNameMaxSize);

  certificate_new->rsaPublicKey.buf = (void *)rsa;
  certificate_new->rsaPublicKey.size = rsa_size;

  for (uint16_t attempt = 1; attempt <= kProvisioningMaxAttempts; attempt++) {
    PaDriverCommandResponse_t *command_response = NULL;
    result = SendCommandToDaemon(&command_request, &command_response);
    if (result != PA_SUCCESS ||
        command_response->present != PaDriverCommandResponse_PR_provisioningResponse ||
        command_response->choice.provisioningResponse.result != PaDriverCommandResult_paSuccess) {
      LOG_D("Process Authenticator driver returns error.\n");
      result = PA_GENERAL_ERROR;
    }

    if (result == PA_SUCCESS) {
      result = WriteCertificateToXattr(fd, 
                                  command_response->choice.provisioningResponse.signedCertificate.buf, 
                                  command_response->choice.provisioningResponse.signedCertificate.size);
      PaFreeDriverCommandResponse(command_response);
      break;
    }
 
    PaFreeDriverCommandResponse(command_response);
    usleep(kProvisioningTimeoutUSec);
  }

  FileUnmap(&apk_file);
  PaHandlerDestroy(&handler);

  if (result != PA_SUCCESS) {
    LOG_E("PaNewCertificate is failed.\n");
  }

  return result;
}

/**
 * @brief Encode a driver command, exchange it over HIDL and decode the reply
 *
 * @param [in] command_request Command to encode and send
 * @param [out] command_response Set to NULL on entry, then to the decoded
 *        response; the caller releases a non-NULL value with
 *        PaFreeDriverCommandResponse()
 * @return ::PA_SUCCESS on success, ::PA_INVALID_ARGUMENTS for NULL arguments,
 *         the transport's error code if the exchange fails and
 *         ::PA_GENERAL_ERROR on encoding/decoding failure
 */
static PaResult SendCommandToDaemon(
    const PaDriverCommand_t *command_request,
    PaDriverCommandResponse_t **command_response) {
  if (!command_request || !command_response) {
    LOG_E("Invalid arguments.\n");
    return PA_INVALID_ARGUMENTS;
  }

  // Give the out-parameter a defined value even if a step below fails.
  *command_response = NULL;

  uint8_t request_buffer[kSerializedDataMaxSize];
  uint8_t response_buffer[kSerializedDataMaxSize];
  uint32_t request_buffer_size = sizeof(request_buffer);
  // Capacity on input, number of received bytes after the exchange.
  size_t response_buffer_size = sizeof(response_buffer);

  memset(request_buffer, 0, sizeof(request_buffer));
  memset(response_buffer, 0, sizeof(response_buffer));

  // NOTE(review): assumes the encoder/decoder signal failure with a negative
  // value (the original code checked for -1 only) — confirm in serialize.h.
  if (PaEncodeDriverCommand(command_request, request_buffer,
                            &request_buffer_size) < 0) {
    LOG_E("Cannot encode command.\n");
    return PA_GENERAL_ERROR;
  }

  PaResult result = ExchangeCommandViaHidl(request_buffer, request_buffer_size,
                                           response_buffer, &response_buffer_size);
  if (result != PA_SUCCESS) {
    LOG_E("Command to driver was not sent.\n");
    return result;
  }
  LOG_V("Command to driver was sent SUCCESSFULLY.\n");

  if (PaDecodeDriverCommandResponse(response_buffer, response_buffer_size,
                                    command_response) < 0) {
    LOG_E("Cannot decode the response of command.\n");
    return PA_GENERAL_ERROR;
  }

  return PA_SUCCESS;
}

/**
 * @brief Exchange a serialized command with the PA daemon via HidlNewCertificate()
 *
 * The command is copied into a local buffer. HidlNewCertificate() replaces its
 * contents with the reply, which is then copied to @p response.
 *
 * @param [in] command,command_size Serialized request
 * @param [out] response Destination for the serialized reply
 * @param [in,out] response_size In: capacity of @p response.
 *        Out: number of reply bytes written.
 * @return ::PA_SUCCESS on success and ::PA_GENERAL_ERROR otherwise
 */
static PaResult ExchangeCommandViaHidl(const void *command, size_t command_size,
    void *response, size_t *response_size) {
  uint8_t buffer[kSerializedDataMaxSize];

  // Guard the copy into the fixed-size stack buffer.
  if (command_size > sizeof(buffer)) {
    LOG_E("Command is too big for HIDL buffer.\n");
    LOG_D("command size: %zu, max: %zu.\n", command_size, sizeof(buffer));
    return PA_GENERAL_ERROR;
  }

  size_t size = command_size;
  memcpy(buffer, command, command_size);

  int res = HidlNewCertificate(buffer, &size);
  if (res != 0) {
    LOG_D("HidlNewCertificate = %d.\n", res);
    LOG_E("HIDL communication error.\n");
    return PA_GENERAL_ERROR;
  }

  // Do not trust the reported size blindly: it must fit both the local buffer
  // and the caller's response buffer.
  if (size > sizeof(buffer) || size > *response_size) {
    LOG_E("HIDL response is too big.\n");
    LOG_D("response size: %zu, capacity: %zu.\n", size, *response_size);
    return PA_GENERAL_ERROR;
  }

  memcpy(response, buffer, size);
  *response_size = size;

  return PA_SUCCESS;
}

/**
 * @brief Map at most the first kPageSize bytes of a file, read-only
 *
 * @param [in] fd File descriptor
 * @param [out] file_info Filled in with the descriptor, the mapping address
 *        and its length on success; release it with FileUnmap()
 * @return ::PA_SUCCESS on success, ::PA_INVALID_ARGUMENTS if @p file_info is
 *         NULL and ::PA_GENERAL_ERROR if the file cannot be stat'ed, is empty
 *         or cannot be mapped
 */
static PaResult FileMap(int fd, PaFileInfo *file_info) {
  if (!file_info) {
    LOG_E("Invalid arguments.\n");
    return PA_INVALID_ARGUMENTS;
  }

  struct stat file_stat = {0};
  if (fstat(fd, &file_stat) != 0) {
    LOG_E("Can not stat file.\n");
    LOG_D("File with fd: %d.\n", fd);
    return PA_GENERAL_ERROR;
  }

  // mmap() rejects a zero length (EINVAL), so an empty file cannot be mapped.
  if (file_stat.st_size <= 0) {
    LOG_E("Mapping failed.\n");
    LOG_D("File with fd %d is empty.\n", fd);
    return PA_GENERAL_ERROR;
  }

  // We do not need to map the whole file, just 1 page or less.
  const size_t size_to_map = (file_stat.st_size < kPageSize)
                                 ? (size_t)file_stat.st_size
                                 : (size_t)kPageSize;

  void *mapped = mmap(NULL, size_to_map, PROT_READ, MAP_PRIVATE, fd, 0);
  if (mapped == MAP_FAILED) {
    LOG_E("Mapping failed.\n");
    LOG_D("Can not map fd %d size %zu.\n", fd, size_to_map);
    return PA_GENERAL_ERROR;
  }

  file_info->fd = fd;
  file_info->mapped = mapped;
  file_info->mapped_size = size_to_map;

  return PA_SUCCESS;
}

static void FileUnmap(const PaFileInfo *file_info) {
  if (!file_info) {
    LOG_E("Invalid arguments.\n");
    return;
  }

  int ret = munmap(file_info->mapped, file_info->mapped_size);
  if (ret != 0) {
    LOG_D("munmap return error for pointer %p, size %d.\n",
        file_info->mapped, file_info->mapped_size);
  }

  return;
}

/**
 * @brief Store a signed certificate in the user.pa xattr of a file
 *
 * XattrUserPaFcntl() is tried first and XattrFdWrite() is used as a fallback.
 * After a successful write, the presence of the attribute is re-checked with
 * XattrFdHasAttr().
 *
 * @param [in] fd File descriptor
 * @param [in] certificate Signed PROCA certificate (non-NULL)
 * @param [in] certificate_size Size of certificate buffer, 1..XATTR_SIZE_MAX
 * @return ::PA_SUCCESS on success and an error code otherwise
 */
static PaResult WriteCertificateToXattr(int fd, const void *certificate, size_t certificate_size) {
  // Refuse to write an empty value: PaNewCertificate() skips provisioning when
  // the attribute is present, so an empty user.pa would presumably mark the
  // file as provisioned without a certificate.
  if (!certificate || certificate_size == 0) {
    LOG_E("Certificate is empty.\n");
    return PA_GENERAL_ERROR;
  }

  // Linux accepts xattr values up to and including XATTR_SIZE_MAX bytes.
  if (certificate_size > XATTR_SIZE_MAX) {
    LOG_E("Size of user.pa is bigger than max xattr size.\n");
    return PA_GENERAL_ERROR;
  }

  PaResult result = XattrUserPaFcntl(fd, certificate, certificate_size);
  if (result != PA_SUCCESS) {
    LOG_D("Try set xattr via standard API.\n");
    result = XattrFdWrite(fd, kPaXattrName, certificate, certificate_size);
  }

  if (result != PA_SUCCESS) {
    LOG_E("Can not write xattr.\n");
    LOG_D("xattr name: %s, fd: %d.\n", kPaXattrName, fd);
    return result;
  }

  // Confirm the attribute is now visible on the file.
  result = XattrFdHasAttr(fd, kPaXattrName, NULL);
  if (result != PA_SUCCESS) {
    LOG_E("Writing xattr check is failed.\n");
  }

  return result;
}