#include "base_command.h"

extern "C" {
#include "command.h"
}

#include <sys/stat.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>

#include <algorithm>
#include <cerrno>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include <gtest/gtest.h>

// Number of measurements each native benchmark collects.
constexpr int kSampleCount = 1000;
// Number of short-lived children spawned by PerformanceTestWithForkBomb.
constexpr int kForkCount = 20;

/**
 * @brief Upper bounds for the mean execution time of each measured operation.
 *        All thresholds are expressed in milliseconds.
 */
constexpr int kNativeAppAuthenticationThreshold = 50;
constexpr int kAndroidAppAuthenticationThreshold = 1000;
constexpr int kNativeAppReadTaskThreshold = 20;
constexpr int kNativeAppWriteTaskThreshold = 20;
constexpr int kAppHandlerCreateThreshold = 50;

// How many one-second polls the Android tests perform while waiting for the
// client application to produce its report file (i.e. a limit in seconds).
constexpr int kMaxAndroidTestExecTime = 20;

// Binary report files written by the Android PA client application.
const std::string kPerformanceReportFileAndroidTest = "/sdcard/PA/performance_result.bin";
const std::string kPerformanceCreateReportFileAndroidTest = "/sdcard/PA/performance_create.bin";

class PerformanceTestAndroid : public ::testing::Test {
public:
  void SetUp() {
    std::cout << "Running Android Pa Client ..." << std::endl;
    unlink(kPerformanceReportFileAndroidTest.c_str());
    unlink(kPerformanceCreateReportFileAndroidTest.c_str());
    system("PATH=/system/bin am start -n com.samsung.androidpaclient/.MainActivity");
  }

  void TearDown() {
    std::cout << "Stopping Android Pa Client ..." << std::endl;
    system("PATH=/system/bin am force-stop com.samsung.androidpaclient");
  }
};

/**
 * @brief Fixture for the native-client benchmarks.
 *
 * Adds no state or behavior of its own to BaseCommandTest, so it relies on
 * the implicitly generated constructor and destructor (Rule of Zero).
 */
class PerformanceTestNative : public BaseCommandTest {};

/*
 * By default, the pid of this process is one of the last in pid tree.
 * So linear algorithm is good for this case.
 * This test checks another variant when pid is not in the end.
 */
class PerformanceTestWithForkBomb: public BaseCommandTest {
public:
  PerformanceTestWithForkBomb() {
    // "Fork bomb"
    for (int i = 0; i < kForkCount; ++i) {
      if (fork() == 0) {
        exit(66);
      }
    }
  }
  virtual ~PerformanceTestWithForkBomb() {
  }
};

/**
 * @brief Precomputed performance statistics.
 *
 * The Android tests fill this struct by reading raw bytes from a report file,
 * so its binary layout (three uint32_t followed by three float, no padding)
 * must match what the report writer produces.
 */
struct StatBundle {
  uint32_t min, max, sample_size;
  float mean, stddev, median;
};

static_assert(sizeof(StatBundle) == 3 * sizeof(uint32_t) + 3 * sizeof(float),
              "StatBundle is read verbatim from disk; its layout must not change");

class PerformanceStat {
public:
  PerformanceStat():
    sample_size_(0), min_(0), max_(0),
    mean_(0), std_dev_(0), median_(0) {}
  PerformanceStat(StatBundle& stat):
    sample_size_(stat.sample_size), min_(stat.min), max_(stat.max),
    mean_(stat.mean), std_dev_(stat.stddev),
    median_(stat.median) {}

  void Push(uint32_t val) {time_set_.push_back(val); ++sample_size_;}
  float GetMean() {return mean_;}
  void CalculateStat() {
    calc_mean_stddev();
    calc_minmax();
  }

  void ShowSample();
  void ShowStat();

private:
  std::vector<uint32_t> time_set_;
  uint32_t sample_size_;
  uint32_t min_;
  uint32_t max_;
  float mean_;
  float std_dev_;
  float median_;

  void calc_mean_stddev();
  void calc_minmax() {
    auto mm = std::minmax_element(time_set_.begin(), time_set_.end());
    min_ = *mm.first;
    max_ = *mm.second;
  }

};

/**
 * @brief Prints every stored sample: groups of 5 values, 20 values per row.
 *
 * If CalculateStat() has already run, the samples are printed in sorted order
 * (calc_mean_stddev sorts them in place).
 */
void PerformanceStat::ShowSample() {
  const size_t kValuesPerGroup = 5;
  const size_t kValuesPerRow = 20;

  std::cout << "Authentication time sample (ms):" << std::endl;
  size_t counter = 1;
  for (uint32_t v : time_set_) {
    std::cout << std::setw(3) << v << " ";
    if (counter % kValuesPerGroup == 0) {
      std::cout << "  ";
    }
    if (counter % kValuesPerRow == 0) {
      std::cout << std::endl;
    }
    counter++;
  }
  // Terminate a trailing partial row so later output starts on its own line.
  if (time_set_.size() % kValuesPerRow != 0) {
    std::cout << std::endl;
  }
}

/**
 * @brief Prints the summary statistics to stdout as an aligned two-column table.
 *
 * std::left/std::right and std::setprecision persist on a stream, so the
 * caller's std::cout formatting flags and precision are restored at the end.
 */
void PerformanceStat::ShowStat() {
  const int word_width = 15;
  const int figure_width = 5;

  const std::ios_base::fmtflags saved_flags = std::cout.flags();
  const std::streamsize saved_precision = std::cout.precision();

  std::cout << std::endl;
  std::cout << "Statistics:" << std::endl;
  std::cout << std::left << std::setw(word_width) << " Sample size"
            << std::right << std::setw(figure_width) << sample_size_ << std::endl;
  std::cout << std::left << std::setw(word_width) << " Mean"
            << std::right << std::setw(figure_width) << std::setprecision(4)
            << mean_ << std::endl;
  std::cout << std::left << std::setw(word_width) << " Std dev"
            << std::right << std::setw(figure_width) << std::setprecision(2)
            << std_dev_ << std::endl;
  std::cout << std::left << std::setw(word_width) << " Median"
            << std::right << std::setw(figure_width) << std::setprecision(4)
            << median_ << std::endl;
  std::cout << std::left << std::setw(word_width) << " Min"
            << std::right << std::setw(figure_width) << min_ << std::endl;
  std::cout << std::left << std::setw(word_width) << " Max"
            << std::right << std::setw(figure_width) << max_ << std::endl;
  std::cout << std::endl;

  std::cout.flags(saved_flags);
  std::cout.precision(saved_precision);
}

void PerformanceStat::calc_mean_stddev() {
  uint32_t sum = 0;
  uint32_t sum_of_squares = 0;

  std::sort(time_set_.begin(), time_set_.end());

  for (auto t : time_set_) {
    sum += t;
    sum_of_squares += t * t;
  }

  float second_moment = static_cast<float>(sum_of_squares) / time_set_.size();
  mean_ = static_cast<float>(sum) / time_set_.size();
  std_dev_ =  sqrt(second_moment - mean_ * mean_);

  size_t central = time_set_.size() / 2;
  if (time_set_.size() % 2) {
    // Odd size
    median_ = time_set_[central];
  } else {
    // Even size
    median_ = (time_set_[central] + time_set_[central + 1]) / 2;
  }
}

/**
 * @brief Elapsed time between two timespec readings, truncated to whole
 *        milliseconds.
 * @param [in] start Earlier reading
 * @param [in] end   Later reading
 * @return end - start in milliseconds
 */
static uint32_t DiffTimeMs(const timespec *start, const timespec *end) {
  // When the nanosecond part would go negative, borrow one second from the
  // seconds part so the nanosecond difference is normalized.
  const bool borrow = end->tv_nsec < start->tv_nsec;
  const time_t seconds = end->tv_sec - start->tv_sec - (borrow ? 1 : 0);
  const long nanoseconds = end->tv_nsec - start->tv_nsec + (borrow ? 1000000000L : 0L);

  return seconds * 1000 + nanoseconds / 1000000;
}

/**
 * @brief Checks that the buffer returned from the SWd was correctly rewritten,
 *        i.e. every byte equals kUpdateInitValue.
 * @param [in] buffer      Start of the buffer
 * @param [in] buffer_size Number of bytes to check
 * @return true if every byte equals kUpdateInitValue (vacuously true for an
 *         empty buffer)
 */
static bool CheckSwdBuffer(const uint8_t *buffer, size_t buffer_size) {
  // std::all_of over a size_t-sized range avoids the signed/unsigned
  // comparison (and int overflow) of an int index against buffer_size.
  return std::all_of(buffer, buffer + buffer_size,
                     [](uint8_t byte) { return byte == kUpdateInitValue; });
}

/**
 * @brief Tells whether a filesystem entry exists.
 * @param [in] name Path to check
 * @return true if stat() succeeds on the path
 */
static bool Exist(const std::string& name) {
  struct stat info = {};
  const int rc = stat(name.c_str(), &info);
  return rc == 0;
}

TEST_F(PerformanceTestNative, NativeClient) {
  // Send kPerformanceNativeClient kSampleCount times, collect the reported
  // authentication_exec_time of each round trip and check the mean.
  TciCommand tci_command{};
  tci_command.cmdId = kPerformanceNativeClient;
  tci_command.handler = handler;

  PerformanceStat stat;
  for (int remaining = kSampleCount; remaining > 0; --remaining) {
    const TEEC_Result status = SendCommand(tci_command);
    ASSERT_EQ(TEEC_SUCCESS, status) << "Authentication SHOULD pass";
    ASSERT_EQ(PA_TZ_SUCCESS, tci_command.tz_result) << "Authentication SHOULD pass";
    stat.Push(tci_command.authentication_exec_time);
  }

  stat.CalculateStat();
  stat.ShowStat();
  EXPECT_LE(stat.GetMean(), kNativeAppAuthenticationThreshold);
}

TEST_F(PerformanceTestNative, PaHandlerCreate) {
  // Time kSampleCount PaHandlerCreate() calls and check the mean.
  // NOTE(review): CLOCK_THREAD_CPUTIME_ID only counts CPU time consumed by
  // this thread (time spent blocked is excluded), and DiffTimeMs truncates
  // to whole milliseconds — confirm that is the intended measurement.
  PerformanceStat stat;

  for (int i = 0; i < kSampleCount; ++i) {
    // Named local_handler so it does not shadow the fixture's `handler`.
    PaHandler local_handler;
    struct timespec start = {0}, end = {0};

    clock_gettime(CLOCK_THREAD_CPUTIME_ID, &start);
    PaResult result = PaHandlerCreate(&local_handler);
    clock_gettime(CLOCK_THREAD_CPUTIME_ID, &end);

    // Only destroy a handler that was actually created; on failure it may
    // hold indeterminate contents.
    ASSERT_EQ(PA_SUCCESS, result) << "Can not obtain handler";
    PaHandlerDestroy(&local_handler);

    stat.Push(DiffTimeMs(&start, &end));
  }

  stat.CalculateStat();
  stat.ShowStat();

  EXPECT_LE(stat.GetMean(), kAppHandlerCreateThreshold);
}

TEST_F(PerformanceTestNative, NativeClientRead1K) {
  // Send kPerformanceNativeClientRead with a buffer pre-filled with
  // kStartInitValue, kSampleCount times, and check the mean reported time.
  // NOTE(review): 0x1000 bytes is 4 KiB, not the 1 KiB the test name
  // suggests; also the time is read from authentication_exec_time — confirm
  // the TA reports the read time in that field.
  std::vector<uint8_t> buffer(0x1000, kStartInitValue);

  TciCommand tci_command{};
  tci_command.cmdId = kPerformanceNativeClientRead;
  tci_command.handler = handler;
  tci_command.address = reinterpret_cast<uint64_t>(buffer.data());
  tci_command.size = buffer.size();

  PerformanceStat stat;
  for (int remaining = kSampleCount; remaining > 0; --remaining) {
    const TEEC_Result status = SendCommand(tci_command);
    ASSERT_EQ(TEEC_SUCCESS, status) << "Read SHOULD pass";
    ASSERT_EQ(PA_TZ_SUCCESS, tci_command.tz_result) << "Read SHOULD pass";
    stat.Push(tci_command.authentication_exec_time);
  }

  stat.CalculateStat();
  stat.ShowStat();
  EXPECT_LE(stat.GetMean(), kNativeAppReadTaskThreshold);
}

TEST_F(PerformanceTestNative, NativeClientWrite1K) {
  // Send kPerformanceNativeClientWrite kSampleCount times and verify after
  // each call that the SWd rewrote the whole buffer with kUpdateInitValue.
  // NOTE(review): 0x1000 bytes is 4 KiB, not the 1 KiB the test name suggests.
  TciCommand tci_command = {0};

  std::vector<uint8_t> buffer(0x1000, kStartInitValue);

  tci_command.cmdId = kPerformanceNativeClientWrite;
  tci_command.handler = handler;
  tci_command.address = reinterpret_cast<uint64_t>(buffer.data());
  tci_command.size = buffer.size();

  PerformanceStat stat;

  for (int i = 0; i < kSampleCount; ++i) {
    // Reset the buffer so every iteration verifies a fresh write; otherwise
    // each check after the first would just see the previous iteration's data.
    std::fill(buffer.begin(), buffer.end(), kStartInitValue);

    TEEC_Result status = SendCommand(tci_command);

    ASSERT_EQ(TEEC_SUCCESS, status) << "Write SHOULD pass";
    ASSERT_EQ(PA_TZ_SUCCESS, tci_command.tz_result) << "Write SHOULD pass";
    ASSERT_TRUE(CheckSwdBuffer(buffer.data(), buffer.size()))
        << "SWd SHOULD correct re-write buffer";

    stat.Push(tci_command.authentication_exec_time);
  }

  stat.CalculateStat();
  stat.ShowStat();

  EXPECT_LE(stat.GetMean(), kNativeAppWriteTaskThreshold);
}

TEST_F(PerformanceTestNative, NativeClientRead10K) {
  // Same as NativeClientRead1K, but with a MAX_BUFFER_SIZE buffer pre-filled
  // with kStartInitValue.
  std::vector<uint8_t> buffer(MAX_BUFFER_SIZE, kStartInitValue);

  TciCommand tci_command{};
  tci_command.cmdId = kPerformanceNativeClientRead;
  tci_command.handler = handler;
  tci_command.address = reinterpret_cast<uint64_t>(buffer.data());
  tci_command.size = buffer.size();

  PerformanceStat stat;
  for (int remaining = kSampleCount; remaining > 0; --remaining) {
    const TEEC_Result status = SendCommand(tci_command);
    ASSERT_EQ(TEEC_SUCCESS, status) << "Read SHOULD pass";
    ASSERT_EQ(PA_TZ_SUCCESS, tci_command.tz_result) << "Read SHOULD pass";
    stat.Push(tci_command.authentication_exec_time);
  }

  stat.CalculateStat();
  stat.ShowStat();
  EXPECT_LE(stat.GetMean(), kNativeAppReadTaskThreshold);
}

TEST_F(PerformanceTestNative, NativeClientWrite10K) {
  // Send kPerformanceNativeClientWrite with a MAX_BUFFER_SIZE buffer
  // kSampleCount times and verify the SWd rewrote it after each call.
  TciCommand tci_command = {0};

  // Heap storage instead of a MAX_BUFFER_SIZE array on the test's stack.
  std::vector<uint8_t> buffer(MAX_BUFFER_SIZE, kStartInitValue);

  tci_command.cmdId = kPerformanceNativeClientWrite;
  tci_command.handler = handler;
  tci_command.address = reinterpret_cast<uint64_t>(buffer.data());
  tci_command.size = buffer.size();

  PerformanceStat stat;

  for (int i = 0; i < kSampleCount; ++i) {
    // Reset the buffer so every iteration verifies a fresh write; otherwise
    // each check after the first would just see the previous iteration's data.
    std::fill(buffer.begin(), buffer.end(), kStartInitValue);

    TEEC_Result status = SendCommand(tci_command);

    ASSERT_EQ(TEEC_SUCCESS, status) << "Write SHOULD pass";
    ASSERT_EQ(PA_TZ_SUCCESS, tci_command.tz_result) << "Write SHOULD pass";
    ASSERT_TRUE(CheckSwdBuffer(buffer.data(), buffer.size()))
        << "SWd SHOULD correct re-write buffer";

    stat.Push(tci_command.authentication_exec_time);
  }

  stat.CalculateStat();
  stat.ShowStat();

  EXPECT_LE(stat.GetMean(), kNativeAppWriteTaskThreshold);
}

TEST_F(PerformanceTestWithForkBomb, NativeClientPidIsNotInTheEnd) {
  // Same measurement as PerformanceTestNative.NativeClient, but run while the
  // fixture's forked children hold pids newer than this process's pid.
  TciCommand tci_command{};
  tci_command.cmdId = kPerformanceNativeClient;
  tci_command.handler = handler;

  PerformanceStat stat;
  for (int remaining = kSampleCount; remaining > 0; --remaining) {
    const TEEC_Result status = SendCommand(tci_command);
    ASSERT_EQ(TEEC_SUCCESS, status) << "Authentication SHOULD pass";
    ASSERT_EQ(PA_TZ_SUCCESS, tci_command.tz_result) << "Authentication SHOULD pass";
    stat.Push(tci_command.authentication_exec_time);
  }

  stat.CalculateStat();
  stat.ShowStat();
  EXPECT_LE(stat.GetMean(), kNativeAppAuthenticationThreshold);
}

TEST_F(PerformanceTestAndroid, Authenticate) {
  // Ask the Android PA client (started in SetUp) to run benchmark testid 0;
  // it publishes its statistics as a raw StatBundle in a report file.
  system("PATH=/system/bin am broadcast -a com.samsung.androidpaclient.MEASURE_PERFORMANCE "
         "-n com.samsung.androidpaclient/.IntentReceiver --el testid 0");

  // Poll once per second, at most kMaxAndroidTestExecTime times, for the report.
  int count = 0;
  while (!Exist(kPerformanceReportFileAndroidTest) && count < kMaxAndroidTestExecTime) {
    count++;
    sleep(1);
  }

  std::ifstream pa_result_file(kPerformanceReportFileAndroidTest,
                               std::ifstream::in | std::ifstream::binary);

  ASSERT_TRUE(pa_result_file.good()) << "Performance report was not produced";

  // Zero-initialized so a failed read can never expose indeterminate values.
  StatBundle stat_report = {};
  pa_result_file.read(reinterpret_cast<char*>(&stat_report), sizeof(stat_report));
  // The poll above only waits for the file to exist, so it may still be
  // incomplete; reject a short read instead of using a partially filled struct.
  ASSERT_EQ(static_cast<std::streamsize>(sizeof(stat_report)), pa_result_file.gcount())
      << "Performance report is truncated";
  pa_result_file.close();

  PerformanceStat stat(stat_report);
  stat.ShowStat();

  EXPECT_LE(stat_report.mean, kAndroidAppAuthenticationThreshold);
}

TEST_F(PerformanceTestAndroid, PaHandlerCreate) {
  // Ask the Android PA client (started in SetUp) to run benchmark testid 1;
  // it publishes its statistics as a raw StatBundle in a report file.
  system("PATH=/system/bin am broadcast -a com.samsung.androidpaclient.MEASURE_PERFORMANCE "
         "-n com.samsung.androidpaclient/.IntentReceiver --el testid 1");

  // Poll once per second, at most kMaxAndroidTestExecTime times, for the report.
  int count = 0;
  while (!Exist(kPerformanceCreateReportFileAndroidTest) && count < kMaxAndroidTestExecTime) {
    count++;
    sleep(1);
  }

  std::ifstream pa_result_file(kPerformanceCreateReportFileAndroidTest,
                               std::ifstream::in | std::ifstream::binary);

  ASSERT_TRUE(pa_result_file.good()) << "Performance report was not produced";

  // Zero-initialized so a failed read can never expose indeterminate values.
  StatBundle stat_report = {};
  pa_result_file.read(reinterpret_cast<char*>(&stat_report), sizeof(stat_report));
  // The poll above only waits for the file to exist, so it may still be
  // incomplete; reject a short read instead of using a partially filled struct.
  ASSERT_EQ(static_cast<std::streamsize>(sizeof(stat_report)), pa_result_file.gcount())
      << "Performance report is truncated";
  pa_result_file.close();

  PerformanceStat stat(stat_report);
  stat.ShowStat();

  EXPECT_LE(stat_report.mean, kAppHandlerCreateThreshold);
}
