#include "task_nwd_access.h"
#include "command.h"

#include <tees_log.h>
#include <tee_internal_api.h>

// Fallback definition, used only when no earlier header has already defined
// PAGE_SIZE. It sizes the staging buffer declared below.
#ifndef PAGE_SIZE
#define PAGE_SIZE (4096)
#endif

// Staging buffer holding 4 * PAGE_SIZE bytes. The big-buffer routines in this
// file pass it to PaTzReadFromNwdTask / PaTzWriteToNwdTask one chunk at a time.
static uint8_t g_inner_buffer[4 * PAGE_SIZE];

// Returns the arithmetic sum of the first `size` bytes of `buffer`, widened
// to 64 bits. An empty range (size == 0) yields 0 without touching `buffer`.
static uint64_t SimpleChecksum(const uint8_t *buffer, size_t size) {
  uint64_t total = 0;
  const uint8_t *cursor = buffer;
  size_t bytes_left = size;

  // Count down instead of indexing; the pointer only advances while bytes
  // remain, so no pointer arithmetic is done for an empty range.
  while (bytes_left-- > 0) {
    total += *cursor++;
  }

  return total;
}

// Forwards (handler, address, size, buffer) to PaTzReadFromNwdTask and, when
// that call succeeds, stores the byte-sum of buffer[0..size) in
// `*out_checksum`.
//
// Returns PA_TZ_GENERAL_ERROR without calling PaTzReadFromNwdTask when
// `out_checksum` is null; otherwise returns the status reported by
// PaTzReadFromNwdTask. On failure an error is logged and `*out_checksum` is
// left untouched.
PaTzResult ReadFromNwdTask(PaHandler handler,
                           PaTzUserSpaceVirtualAddress address, size_t size,
                           uint8_t *buffer, uint64_t *out_checksum) {
  if (!out_checksum) {
    return PA_TZ_GENERAL_ERROR;
  }

  const PaTzResult status = PaTzReadFromNwdTask(handler, address, size, buffer);

  if (status == PA_TZ_SUCCESS) {
    *out_checksum = SimpleChecksum(buffer, size);
  } else {
    MB_LOGE("Can't read from NWD task.\n");
  }

  return status;
}

// Reads `size` bytes starting at `address` through PaTzReadFromNwdTask in
// chunks of at most sizeof(g_inner_buffer) bytes, and accumulates the byte-sum
// of every chunk.
//
// Returns PA_TZ_GENERAL_ERROR when `out_checksum` is null. If any chunk read
// fails, an error is logged and that status is returned with `*out_checksum`
// left untouched. Otherwise `*out_checksum` receives the total sum (0 when
// `size` is 0) and PA_TZ_SUCCESS is returned.
PaTzResult ReadBigBufferToNwdTask(PaHandler handler,
                                  PaTzUserSpaceVirtualAddress address,
                                  size_t size, uint64_t *out_checksum) {
  if (!out_checksum) {
    return PA_TZ_GENERAL_ERROR;
  }

  const size_t chunk_capacity = sizeof(g_inner_buffer);
  uint64_t running_sum = 0;

  for (size_t bytes_left = size; bytes_left != 0;) {
    // Never request more than the staging buffer can hold.
    size_t chunk = bytes_left;
    if (chunk > chunk_capacity) {
      chunk = chunk_capacity;
    }

    const PaTzResult status =
        PaTzReadFromNwdTask(handler, address, chunk, g_inner_buffer);
    if (status != PA_TZ_SUCCESS) {
      MB_LOGE("Can't read from NWD task.\n");
      return status;
    }

    running_sum += SimpleChecksum(g_inner_buffer, chunk);

    address += chunk;
    bytes_left -= chunk;
  }

  *out_checksum = running_sum;
  return PA_TZ_SUCCESS;
}

// Walks `size` bytes starting at `address` in chunks of at most
// sizeof(g_inner_buffer) bytes. For every chunk it:
//   1. calls PaTzReadFromNwdTask for that range into g_inner_buffer,
//   2. fills the staging buffer with kUpdateInitValue via TEE_MemFill,
//   3. passes the filled chunk to PaTzWriteToNwdTask for the same address.
// The byte-sum of everything handed to PaTzWriteToNwdTask is accumulated.
//
// Returns PA_TZ_GENERAL_ERROR when `out_checksum` is null. If a read or a
// write fails, an error is logged, the loop stops and that status is returned
// with `*out_checksum` left untouched (chunks already written stay written).
// Otherwise `*out_checksum` receives the sum of the written bytes (0 when
// `size` is 0) and PA_TZ_SUCCESS is returned.
PaTzResult WriteBigBufferToNwdTask(PaHandler handler,
                                   PaTzUserSpaceVirtualAddress address,
                                   size_t size, uint64_t *out_checksum) {
  if (!out_checksum) {
    return PA_TZ_GENERAL_ERROR;
  }

  PaTzResult result = PA_TZ_SUCCESS;

  uint64_t checksum_write = 0;
  size_t remaining_size = size;

  while (remaining_size) {
    // Never handle more than the staging buffer can hold in one step.
    size_t current_size = (remaining_size > sizeof(g_inner_buffer) ?
        sizeof(g_inner_buffer) : remaining_size);

    // The data read here is overwritten immediately below; only the status is
    // used. NOTE(review): presumably this verifies the range is readable
    // before it is overwritten - confirm with the caller's expectations.
    result = PaTzReadFromNwdTask(handler, address, current_size, g_inner_buffer);
    if (result != PA_TZ_SUCCESS) {
      MB_LOGE("Can't read from NWD task.\n");
      break;
    }

    TEE_MemFill(g_inner_buffer, kUpdateInitValue, current_size);

    // Checksum the buffer as actually filled rather than deriving it from
    // kUpdateInitValue, so the result reflects the bytes that get written.
    checksum_write += SimpleChecksum(g_inner_buffer, current_size);

    result = PaTzWriteToNwdTask(handler, g_inner_buffer, current_size, address);
    if (result != PA_TZ_SUCCESS) {
      MB_LOGE("Cleaning error.\n");
      break;
    }

    remaining_size -= current_size;
    address += current_size;
  }

  if (result == PA_TZ_SUCCESS) {
    *out_checksum = checksum_write;
  }

  return result;
}
