#include <gtest/gtest.h>

#include <stdint.h>
#include <scl/string.h>

extern "C" {
  #include "mock.h"

  #include "authentication.h"
  #include "config.h"
  #include "crypto.h"
  #include "pa_certificate.h"
  #include "task.h"
  #include "tlb.h"
  #include "serialize.h"

  #include "PaFlagBits.h"
}

// Integrity value that matches none of the kIntegrity* constants; it is placed
// in the integrity slot of kInvalidIntegrityValueTask below.
static const uint32_t kInvalidIntegrationFlag = 4444;
// PIDs used by the TaskInfo fixtures below.
static const uint32_t kValidPid = 125;
static const uint32_t kUnauthenticatedPid = 196;
// NOTE(review): kPidVmaNotFound is not referenced by the tests visible here.
static const uint32_t kPidVmaNotFound = 302;
// Size of the deliberately unterminated name buffer used by
// AuthenticateHugeProcessName_Failed.
static const uint32_t kIncorrectProcessNameLength = 1050;
// Capacity of the concatenated-names buffer used by
// AuthenticateValidProcessNames_Success.
static const uint32_t kMaximalSizeNameRule = 1024;

// kValidProcessName is the paAppName of the valid certificates (set in SetUp).
static const char kValidProcessName[] = "secure_storage";
static const char kInvalidProcessName[] = "/system/bin/cool_application";
// Same length as kValidProcessName, different content.
static const char kInvalidProcessNamev2[] = "secure_stor_v2";

// The signature-verification mocks accept signed data whose first kPaIdLength
// bytes equal kValidPaId (all zeros). kBadPaId has 0xFF in its first byte and
// zeros elsewhere, so it never matches.
static const uint8_t kValidPaId[kPaIdLength] = {0};
static const uint8_t kBadPaId[kPaIdLength] = {0xFF};

static const PhysicalAddress kValidPhysicalAddress = 0x4000;
static const uint64_t kValidTrustletAddress = 0x8000;
// The TaskAddressToPhysical() mock only translates addresses in
// [kValidUserMemory.addr, kValidUserMemory.addr + kValidUserMemory.size]
// (0x9000..0xA388 inclusive); kInvalidUserAddress lies outside that range.
static const uint64_t kValidUserAddress = 0x9000;
static const uint64_t kInvalidUserAddress = 0xB000;

// 45 is cast to PaTzMemoryType to model an unknown memory type
// (AuthenticateInvalidMemoryType expects it to be rejected).
static const PaTzMemoryRange kMemoryWithIncorrectType = {0, 0, (PaTzMemoryType)45};
static const PaTzMemoryRange kValidPhysMemory = {kValidPhysicalAddress, 100, PA_MEMORY_PHYS};
static const PaTzMemoryRange kValidTrustletMemory = {kValidTrustletAddress, 100, PA_MEMORY_SWD_VA};
static const PaTzMemoryRange kValidUserMemory = {kValidUserAddress, 5000, PA_MEMORY_NWD_VA};

// Starts at 0xB000, entirely outside the mock's translatable range.
static const PaTzMemoryRange kInvalidUserMemoryFull = {kInvalidUserAddress, 5000, PA_MEMORY_NWD_VA};
// Spans 0x9000..0xB3E8. Assuming 4 KiB pages, the first two pages (0x9000 and
// 0xA000) are translatable by the mock but the last one (0xB000) is not.
static const PaTzMemoryRange kInvalidUserMemoryPart = {kValidUserAddress, kInvalidUserAddress - kValidUserAddress + 1000, PA_MEMORY_NWD_VA};

// Certificate fixtures. Their contents are (re)populated in
// ProcessAuthenticationTest::SetUp(); the PaCertificateValidate() mock
// recognises them by address.
static PaCertificate_t kValidRsaCertificate = {0};
static PaCertificate_t kBadRsaCertificate = {0};

static PaCertificate_t kValidHmacCertificate = {0};
static PaCertificate_t kBadHmacCertificate = {0};

// TaskInfo fixtures. Judging by the values used, the third field is the PID,
// the seventh the integrity value and the eighth the certificate pointer.
static const TaskInfo kValidTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPreload, &kValidRsaCertificate};
static const TaskInfo kUnauthenticatedTask = {0, 0, kUnauthenticatedPid, 0, 0, 0, kIntegrityNone, &kValidRsaCertificate};
// NOTE(review): unlike every other initializer here, the sixth field is
// kIntegrityNone rather than 0 -- confirm whether this is intentional.
static const TaskInfo kInvalidIntegrityValueTask = {0, 0, kUnauthenticatedPid, 0, 0, kIntegrityNone, kInvalidIntegrationFlag, &kValidRsaCertificate};
static const TaskInfo kWithoutCertTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPreload, NULL};
static const TaskInfo kBadCertTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPreload, &kBadRsaCertificate};

static const TaskInfo kValidHmacTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPreload, &kValidHmacCertificate};
static const TaskInfo kBadHmacTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPreload, &kBadHmacCertificate};

// Non-const: the TaskParseIntegrity() mock writes the integrity field of the
// Good/Bad tasks. All four start in the kIntegrityPending state.
static TaskInfo kPendingGoodTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPending, &kValidRsaCertificate};
static TaskInfo kPendingBadTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPending, &kValidRsaCertificate};
static TaskInfo kPendingAlwaysTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPending, &kValidRsaCertificate};
static TaskInfo kPendingDeadTask = {0, 0, kValidPid, 0, 0, 0, kIntegrityPending, &kValidRsaCertificate};

// Mock certificate validation, keyed on the fixture's address:
//   - the two "valid" fixtures  -> PA_TZ_SUCCESS
//   - the two "bad" fixtures    -> PA_TZ_AF_CERTIFICATE_IS_INCORRECT
//   - anything else (incl. NULL) -> PA_TZ_GENERAL_ERROR
extern "C" PaTzResult PaCertificateValidate(const PaCertificate_t *certificate) {
  const bool known_good = (certificate == &kValidRsaCertificate) ||
                          (certificate == &kValidHmacCertificate);
  if (known_good) {
    return PA_TZ_SUCCESS;
  }

  const bool known_bad = (certificate == &kBadRsaCertificate) ||
                         (certificate == &kBadHmacCertificate);
  return known_bad ? PA_TZ_AF_CERTIFICATE_IS_INCORRECT : PA_TZ_GENERAL_ERROR;
}

// Mock of the integrity re-parse, keyed on the task fixture's address:
//   - kPendingDeadTask: parsing fails with PA_TZ_GENERAL_ERROR;
//   - kPendingGoodTask: integrity becomes kIntegrityPreload;
//   - kPendingBadTask:  integrity becomes kIntegrityNone;
//   - any other task:   left untouched (e.g. kPendingAlwaysTask stays pending).
extern "C" PaTzResult TaskParseIntegrity(TaskInfo *out_task) {
  if (out_task == &kPendingDeadTask) {
    return PA_TZ_GENERAL_ERROR;
  }

  if (out_task == &kPendingGoodTask) {
    out_task->integrity = kIntegrityPreload;
  } else if (out_task == &kPendingBadTask) {
    out_task->integrity = kIntegrityNone;
  }

  return PA_TZ_SUCCESS;
}

// Mock RSA verification. The signature, key and hash type are ignored: the
// signed data is accepted iff its first kPaIdLength bytes equal kValidPaId.
// NOTE(review): data_len is not checked; kPaIdLength bytes are always read.
extern "C" PaTzResult RsaSignatureVerification(const uint8_t *data,
                                               size_t data_len,
                                               const uint8_t *signature,
                                               size_t singature_len,
                                               const RsaPublicKey *public_key,
                                               HashType hash_type) {
  const bool pa_id_matches = memcmp(data, kValidPaId, kPaIdLength) == 0;
  if (!pa_id_matches) {
    return PA_TZ_GENERAL_ERROR;
  }
  return PA_TZ_SUCCESS;
}

// Mock HMAC verification. The hash type and signature are ignored: the data is
// accepted iff its first kPaIdLength bytes equal kValidPaId.
// NOTE(review): data_length is not checked; kPaIdLength bytes are always read.
extern "C" PaTzResult CryptoHmacSignatureVerification(const uint8_t *data,
                                                      size_t data_length,
                                                      HashType hash_type,
                                                      const uint8_t *signature,
                                                      size_t signature_length) {
  if (memcmp(data, kValidPaId, kPaIdLength) != 0) {
    return PA_TZ_GENERAL_ERROR;
  }
  return PA_TZ_SUCCESS;
}

// Mock PaData encoder: the "encoded" output is just the raw paId bytes.
// Copies paId.size bytes into `memory`, reports that count through
// `size_buffer`, and always returns 0. The capacity of `memory` is not checked.
extern "C" int PaTzEncoderPaData(const PaData_t *command, void *memory,
                                 uint32_t *size_buffer) {
  const auto &pa_id = command->paId;
  memcpy(memory, pa_id.buf, pa_id.size);
  *size_buffer = pa_id.size;
  return 0;
}

// Mock VMA lookup: reports success for every physical range without writing
// anything to out_vma.
extern "C" PaTzResult TaskFindVmaWithPhysicalMemory(const TaskInfo *task,
                                                    PhysicalAddress address,
                                                    size_t size,
                                                    VmaInfo *out_vma) {
  (void)task;
  (void)address;
  (void)size;
  (void)out_vma;
  return PA_TZ_SUCCESS;
}

// Mock configuration accessor: every call returns the address of the same
// zero-initialized PaConfig.
extern "C" const PaConfig *GetConfig() {
  static PaConfig s_zero_config = {0};
  return &s_zero_config;
}

// Mock VA->PA translation using an identity mapping, but only for addresses in
// [kValidUserMemory.addr, kValidUserMemory.addr + kValidUserMemory.size]
// (upper bound inclusive). Any other address yields PA_TZ_GENERAL_ERROR and
// leaves *phys untouched. context, pgd and flags are ignored.
extern "C" PaTzResult TaskAddressToPhysical(TlbConverterInfo* context, KernelAddress pgd,
                                            ProcessAddress virt,
                                            PhysicalAddress *phys,
                                            AccessPermissionFlags* flags) {
  const bool out_of_range =
      virt < kValidUserMemory.addr ||
      virt > kValidUserMemory.addr + kValidUserMemory.size;
  if (out_of_range) {
    return PA_TZ_GENERAL_ERROR;
  }

  *phys = virt;
  return PA_TZ_SUCCESS;
}

// Mock: always succeeds and never writes *phys.
extern "C" PaTzResult PlatformVirtToPhys64(const void *virt,
                                           PhysicalAddress *phys) {
  (void)virt;
  (void)phys;
  return PA_TZ_SUCCESS;
}

// Mock trustlet mapping: the "driver" mapping is the trustlet address itself;
// size and access type are ignored. Always succeeds.
extern "C" PaTzResult PlatformSysMapTrustlet(const void *virt_trustlet,
                                             size_t size, MemoryAccessType type,
                                             void **virt_driver) {
  (void)size;
  (void)type;
  *virt_driver = const_cast<void *>(virt_trustlet);
  return PA_TZ_SUCCESS;
}

// Mock trustlet unmapping: nothing to release; always succeeds.
extern "C" PaTzResult PlatformSysUnmapTrustlet(const void *virt_driver, size_t size) {
  (void)virt_driver;
  (void)size;
  return PA_TZ_SUCCESS;
}

class ProcessAuthenticationTest : public ::testing::Test {
protected:
  virtual void SetUp() {
    InitMocks();
    kValidRsaCertificate.paData.paId.buf = (uint8_t *)kValidPaId;
    kValidRsaCertificate.paData.paId.size = kPaIdLength;
    kValidRsaCertificate.paData.paFlags = (1 << PaFlagBits_bitAndroid);
    kValidRsaCertificate.paData.paAppName.size = sizeof(kValidProcessName);
    kValidRsaCertificate.paData.paAppName.buf = (uint8_t *)kValidProcessName;

    kValidHmacCertificate.paData.paId.buf = (uint8_t *)kValidPaId;
    kValidHmacCertificate.paData.paId.size = kPaIdLength;
    kValidHmacCertificate.paData.paFlags = (1 << PaFlagBits_bitHmac) | (1 << PaFlagBits_bitAndroid);
    kValidHmacCertificate.paData.paAppName.size = sizeof(kValidProcessName);
    kValidHmacCertificate.paData.paAppName.buf = (uint8_t *)kValidProcessName;

    kBadRsaCertificate.paData.paId.buf = (uint8_t *)kBadPaId;
    kBadRsaCertificate.paData.paId.size = kPaIdLength;

    kBadHmacCertificate.paData.paId.buf = (uint8_t *)kBadPaId;
    kBadHmacCertificate.paData.paId.size = kPaIdLength;
    kBadHmacCertificate.paData.paFlags = (1 << PaFlagBits_bitHmac) | (1 << PaFlagBits_bitAndroid);
  }

  virtual void TearDown() {
    DeinitMocks();
  }
};

// A preloaded task with a valid RSA certificate authenticates when no process
// name and no memory range are supplied.
TEST_F(ProcessAuthenticationTest, AuthenticateValidPid) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, NULL);

  // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both values on failure.
  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// Same as AuthenticateValidPid, but the certificate is not flagged as an
// Android app. SetUp() sets the flag again before the next test.
TEST_F(ProcessAuthenticationTest, AuthenticateValidPidNative) {
  kValidRsaCertificate.paData.paFlags &= ~(1 << PaFlagBits_bitAndroid);

  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// The process name equals the certificate's paAppName (kValidProcessName).
TEST_F(ProcessAuthenticationTest, AuthenticateValidProcessName) {
  PaTzResult result = ProcessAuthentication(&kValidTask, kValidProcessName,
                                            sizeof(kValidProcessName), NULL);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// A process name unrelated to the certificate's paAppName must be rejected.
TEST_F(ProcessAuthenticationTest, AuthenticateInvalidProcessName) {
  EXPECT_EQ(PA_TZ_AF_APPNAME_IS_INCORRECT,
            ProcessAuthentication(&kValidTask, kInvalidProcessName,
                                  sizeof(kInvalidProcessName), NULL))
      << "AppNameIsIncorrect SHOULD be returned if name is mismatched";
}

// kInvalidProcessNamev2 has exactly the same length as kValidProcessName but
// different content, so a length-only comparison would wrongly accept it.
TEST_F(ProcessAuthenticationTest, AuthenticateInvalidProcessNameWithCorrectLength_Failed) {
  PaTzResult result = ProcessAuthentication(&kValidTask, kInvalidProcessNamev2,
                                            sizeof(kInvalidProcessNamev2), NULL);

  // The previous message claimed the *length* was mismatched, which is the
  // opposite of what this test sets up.
  EXPECT_EQ(PA_TZ_AF_APPNAME_IS_INCORRECT, result) <<
      "AppNameIsIncorrect SHOULD be returned if name differs even though its length matches";
}

// A name buffer filled entirely with 'x' and NOT NUL-terminated; the expected
// failure is attributed to scl_strlen() not finding a terminator.
TEST_F(ProcessAuthenticationTest, AuthenticateHugeProcessName_Failed) {
  static char huge_process_name[kIncorrectProcessNameLength] = "";
  // size_t index over sizeof() avoids the int vs uint32_t comparison
  // (-Wsign-compare) of the previous loop.
  for (size_t i = 0; i < sizeof(huge_process_name); i++) {
    huge_process_name[i] = 'x';
  }
  PaTzResult result = ProcessAuthentication(&kValidTask, huge_process_name,
                                            sizeof(huge_process_name), NULL);

  EXPECT_EQ(PA_TZ_AF_APPNAME_IS_INCORRECT, result) <<
      "AppNameIsIncorrect SHOULD be returned if scl_strlen() failed";
}

// Several NUL-separated names are passed at once; one of them
// ("secure_storage") matches the certificate, so authentication succeeds.
TEST_F(ProcessAuthenticationTest, AuthenticateValidProcessNames_Success) {
  // Each iteration copies this string with one more leading character dropped.
  static const char kPrefixedName[] = "ssssecure_storage";

  uint32_t size = 0;
  char process_names[kMaximalSizeNameRule] = "";

  // process_names will consist of the following names:
  // ssssecure_storage\0sssecure_storage\0ssecure_storage\0secure_storage\0...
  for (uint32_t number = 0; number < kMaximalNameNumbers; number++) {
    // Keep the source pointer inside kPrefixedName (the last valid index is
    // its terminator, i.e. an empty name).
    ASSERT_LT(number, sizeof(kPrefixedName));
    ASSERT_LT(size, kMaximalSizeNameRule);
    // Only the tail after the names already written is free. Passing the full
    // buffer size at a non-zero offset would defeat the bounds checking.
    const uint32_t remaining = kMaximalSizeNameRule - size;

    scl_bool result = scl_strcpy(process_names + size, remaining,
                                 &kPrefixedName[number]);
    ASSERT_TRUE(result);

    // Use scl_size_t directly instead of casting a size_t* to scl_size_t*.
    scl_size_t current_name_size = 0;
    result = scl_strlen(process_names + size, remaining, &current_name_size);
    ASSERT_TRUE(result);

    size += current_name_size + 1;
  }

  PaTzResult result = ProcessAuthentication(&kValidTask, process_names, size, NULL);
  ASSERT_EQ(PA_TZ_SUCCESS, result);
}

// A memory range whose type (45) is not a known PaTzMemoryType is rejected.
TEST_F(ProcessAuthenticationTest, AuthenticateInvalidMemoryType) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, &kMemoryWithIncorrectType);

  EXPECT_EQ(PA_TZ_AUTHENTICATION_FAILED, result);
}

// A PA_MEMORY_PHYS range (kValidPhysMemory) is accepted.
TEST_F(ProcessAuthenticationTest, AuthenticateValidPhysicalMemory) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, &kValidPhysMemory);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// A PA_MEMORY_SWD_VA range (kValidTrustletMemory) is accepted.
TEST_F(ProcessAuthenticationTest, AuthenticateValidTrustletMemory) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, &kValidTrustletMemory);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// A PA_MEMORY_NWD_VA range lying entirely inside the region the
// TaskAddressToPhysical() mock can translate is accepted.
TEST_F(ProcessAuthenticationTest, AuthenticateValidUserMemory) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, &kValidUserMemory);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// A PA_MEMORY_NWD_VA range starting outside the mock's translatable region is
// rejected.
TEST_F(ProcessAuthenticationTest, AuthenticateInvalidUserMemoryFull) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, &kInvalidUserMemoryFull);

  EXPECT_EQ(PA_TZ_AUTHENTICATION_FAILED, result);
}

// A PA_MEMORY_NWD_VA range that starts inside the mock's translatable region
// but runs past its end is rejected.
TEST_F(ProcessAuthenticationTest, AuthenticateInvalidUserMemoryPart) {
  PaTzResult result = ProcessAuthentication(&kValidTask, NULL, 0, &kInvalidUserMemoryPart);

  EXPECT_EQ(PA_TZ_AUTHENTICATION_FAILED, result);
}

// A task whose integrity is kIntegrityNone is rejected even though its
// certificate is valid.
TEST_F(ProcessAuthenticationTest, AuthenticateUnauthenticatedPid) {
  const PaTzResult actual = ProcessAuthentication(&kUnauthenticatedTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_AF_INTEGRITY_IS_NONE, actual)
      << "IntegrityIsNone SHOULD be returned if five is clear integrity flag";
}

// An integrity value outside the known constants (kInvalidIntegrationFlag) is
// treated the same as "no integrity".
TEST_F(ProcessAuthenticationTest, AuthenticateInvalidIntegrityValuePid) {
  EXPECT_EQ(PA_TZ_AF_INTEGRITY_IS_NONE,
            ProcessAuthentication(&kInvalidIntegrityValueTask, NULL, 0, NULL))
      << "IntegrityIsNone SHOULD be returned if five integrity flag is unknown";
}

// A NULL task is reported as a general error.
TEST_F(ProcessAuthenticationTest, AuthenticateNullTask) {
  PaTzResult result = ProcessAuthentication(NULL, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_GENERAL_ERROR, result);
}

// A preloaded task whose certificate pointer is NULL must be rejected as
// having an incorrect certificate.
TEST_F(ProcessAuthenticationTest, AuthenticateWithoutCertificate) {
  const PaTzResult actual = ProcessAuthentication(&kWithoutCertTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_AF_CERTIFICATE_IS_INCORRECT, actual)
      << "CertificateIsIncorrect SHOULD be returned if handler does not have pointer to certificate";
}

// kBadRsaCertificate is rejected by the PaCertificateValidate() mock with
// PA_TZ_AF_CERTIFICATE_IS_INCORRECT.
TEST_F(ProcessAuthenticationTest, AuthenticateBadCertificate) {
  EXPECT_EQ(PA_TZ_AF_CERTIFICATE_IS_INCORRECT,
            ProcessAuthentication(&kBadCertTask, NULL, 0, NULL))
      << "CertificateIsIncorrect SHOULD be returned if certificate has incorrect signature";
}

// A preloaded task with a valid HMAC-flagged certificate authenticates.
TEST_F(ProcessAuthenticationTest, AuthenticateValidHmacTask) {
  PaTzResult result = ProcessAuthentication(&kValidHmacTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// kBadHmacCertificate (HMAC flag set, bad paId) must be rejected.
TEST_F(ProcessAuthenticationTest, AuthenticateBadHmacCertificate) {
  const PaTzResult actual = ProcessAuthentication(&kBadHmacTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_AF_CERTIFICATE_IS_INCORRECT, actual)
      << "CertificateIsIncorrect SHOULD be returned if certificate has incorrect signature";
}

// Pending integrity that the TaskParseIntegrity() mock resolves to
// kIntegrityPreload: the task authenticates.
TEST_F(ProcessAuthenticationTest, AuthenticatePendingGoodPid) {
  EXPECT_EQ(PA_TZ_SUCCESS,
            ProcessAuthentication(&kPendingGoodTask, NULL, 0, NULL))
      << "PA_TZ_SUCCESS should be returned (authenticated app) if FIVE returns "
         "correct integrity status";
}

// Pending integrity that the TaskParseIntegrity() mock resolves to
// kIntegrityNone: the task is not authenticated.
TEST_F(ProcessAuthenticationTest, AuthenticatePendingBadPid) {
  const PaTzResult actual = ProcessAuthentication(&kPendingBadTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_AF_INTEGRITY_IS_NONE, actual)
      << "PA_TZ_AF_INTEGRITY_IS_NONE should be returned (NOT authenticated app) "
         "if FIVE returns incorrect integrity status";
}

// The TaskParseIntegrity() mock succeeds but leaves this task pending, so
// authentication reports that integrity is not ready yet.
TEST_F(ProcessAuthenticationTest, AuthenticatePendingAlwaysPid) {
  const PaTzResult actual = ProcessAuthentication(&kPendingAlwaysTask, NULL, 0, NULL);

  EXPECT_EQ(PA_TZ_AF_INTEGRITY_IS_NOT_READY, actual)
      << "PA_TZ_AF_INTEGRITY_IS_NOT_READY should be returned (app in not processed by FIVE) "
         "if FIVE does not return certain status";
}

// The TaskParseIntegrity() mock fails for this task, which authentication
// reports as "task not found".
TEST_F(ProcessAuthenticationTest, AuthenticatePendingDeadPid) {
  EXPECT_EQ(PA_TZ_AF_TASK_IS_NOT_FOUND,
            ProcessAuthentication(&kPendingDeadTask, NULL, 0, NULL))
      << "PA_TZ_AF_TASK_IS_NOT_FOUND should be returned if not passible to parse task info";
}

// kIntegrityPreloadWithSign grants signing rights.
TEST_F(ProcessAuthenticationTest, CheckIntegritySigningRights_Success) {
  // Value-initialize so no field of the task holds an indeterminate value.
  TaskInfo task = {};
  task.integrity = kIntegrityPreloadWithSign;
  PaTzResult result = CheckIntegritySigningRights(&task);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// Plain kIntegrityPreload does not grant signing rights.
TEST_F(ProcessAuthenticationTest, CheckIntegritySigningRights_Fail) {
  // Value-initialize so no field of the task holds an indeterminate value.
  TaskInfo task = {};
  task.integrity = kIntegrityPreload;
  PaTzResult result = CheckIntegritySigningRights(&task);

  EXPECT_EQ(PA_TZ_GENERAL_ERROR, result);
}

// kIntegrityMixed passes the weak integrity check.
TEST_F(ProcessAuthenticationTest, CheckIntegrityWeak_Success) {
  // Value-initialize so no field of the task holds an indeterminate value.
  TaskInfo task = {};
  task.integrity = kIntegrityMixed;
  PaTzResult result = CheckIntegrityWeak(&task);

  EXPECT_EQ(PA_TZ_SUCCESS, result);
}

// kIntegrityDmVerity does not pass the weak integrity check.
TEST_F(ProcessAuthenticationTest, CheckIntegrityWeak_Fail) {
  // Value-initialize so no field of the task holds an indeterminate value.
  TaskInfo task = {};
  task.integrity = kIntegrityDmVerity;
  PaTzResult result = CheckIntegrityWeak(&task);

  EXPECT_EQ(PA_TZ_GENERAL_ERROR, result);
}
