#include "crypto.h"

#include "driver_log.h"
#include "pa_tz_api.h"

#include <stdint.h>
#include <tee_internal_api.h>
#include <tees_kdf.h>


/**
 * @brief Prepare a TEE MAC operation for HMAC with a derived key.
 *
 * Allocates a MAC operation and a transient HMAC key object, then fills the
 * key via TEES_DeriveKeyKDF (called with no label/context). It installs the
 * key on the operation and calls TEE_MACInit, so the returned operation is
 * ready for TEE_MACComputeFinal / TEE_MACCompareFinal. The key object itself
 * is closed before returning. NOTE(review): the derived key is presumably
 * device-bound; confirm against the TEES_DeriveKeyKDF documentation.
 *
 * @param [in] hash_type Hash type; only ::kSha256 is currently supported
 * @param [out] op Receives the initialized operation on success and is left
 *                 untouched on failure; caller must free it with
 *                 TEE_FreeOperation after use
 * @return ::PA_TZ_SUCCESS if the operation was prepared, code of error in other case
 */
static PaTzResult PrepareHmacOperation(HashType hash_type,
                                       TEE_OperationHandle *op);


/**
 * @brief Compute the SHA-256 digest of a buffer inside the TEE.
 * @param [in] data Input bytes; may be NULL only when @p data_length is 0
 * @param [in] data_length Number of bytes in @p data
 * @param [out] out_md Destination for the digest; must hold kSha256Size bytes
 * @return ::PA_TZ_SUCCESS on success, ::PA_TZ_GENERAL_ERROR otherwise
 */
PaTzResult CryptoSha256(const uint8_t *data, size_t data_length, uint8_t *out_md) {
  if (out_md == NULL || (data == NULL && data_length != 0)) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  TEE_OperationHandle digest_op = NULL;
  TEE_Result rc = TEE_AllocateOperation(&digest_op, TEE_ALG_SHA256,
                                        TEE_MODE_DIGEST, 0);
  if (rc != TEE_SUCCESS) {
    LOG_E("Can't allocate operation.\n");
    LOG_D("Received result: 0x%x\n", rc);
    return PA_TZ_GENERAL_ERROR;
  }

  /* In/out: capacity of out_md on entry, digest length on return. */
  uint32_t digest_len = kSha256Size;
  rc = TEE_DigestDoFinal(digest_op, data, data_length, out_md, &digest_len);

  PaTzResult status = PA_TZ_SUCCESS;
  if (rc != TEE_SUCCESS) {
    LOG_E("Can't compute SHA256.\n");
    LOG_D("Received result: 0x%x\n", rc);
    status = PA_TZ_GENERAL_ERROR;
  }

  /* The operation is released on both the success and failure paths. */
  TEE_FreeOperation(digest_op);
  return status;
}

/**
 * @brief Verify an RSASSA-PKCS1-v1_5 / SHA-256 signature over @p data.
 *
 * The SHA-256 digest of @p data is computed in the TEE and then checked
 * against @p signature with the supplied RSA public key. The TEE key object
 * and operation are sized for 2048-bit keys. Moduli with more significant
 * bytes than that are rejected before they reach the TEE.
 *
 * @param [in] data Signed message; must not be NULL
 * @param [in] data_len Length of @p data in bytes
 * @param [in] signature Signature to check; must not be NULL
 * @param [in] singature_len Length of @p signature in bytes
 * @param [in] public_key RSA public key (modulus and public exponent)
 * @param [in] hash_type Hash type; only ::kSha256 is supported
 * @return ::PA_TZ_SUCCESS if the signature is valid;
 *         ::PA_TZ_GENERAL_ERROR on invalid arguments, unsupported hash type or
 *         unsupported modulus size;
 *         ::PA_TZ_SINGATURE_VALIDATION_FAILURE on any TEE failure or an
 *         invalid signature
 */
PaTzResult RsaSignatureVerification(const uint8_t *data, size_t data_len,
                                    const uint8_t *signature,
                                    size_t singature_len,
                                    const RsaPublicKey *public_key,
                                    HashType hash_type) {
  /* Integer constant expressions: a `const size_t` array bound is a VLA in C. */
  enum { kNumRsaAttributes = 2, kRsaSize = 2048 };

  if (!data || !signature || !public_key) {
    LOG_E("Wrong input parameters \n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (hash_type != kSha256) {
    LOG_E("Hash type isn't SHA256.\n");
    LOG_D("Current hash type: %d\n", hash_type);
    return PA_TZ_GENERAL_ERROR;
  }

  /*
   * The key object below is allocated for kRsaSize bits. The GlobalPlatform
   * API lists "attribute too big for the object's maximum size" as a panic
   * condition of TEE_PopulateTransientObject. An oversized key would
   * therefore abort the TA instead of producing an error, so reject it here.
   * Leading zero bytes (e.g. from DER integer encoding) do not add to the
   * key size.
   */
  size_t leading_zeros = 0;
  while (leading_zeros < public_key->modulus_size &&
         public_key->modulus[leading_zeros] == 0) {
    ++leading_zeros;
  }
  const size_t modulus_bytes = public_key->modulus_size - leading_zeros;
  if (modulus_bytes == 0 || modulus_bytes > kRsaSize / 8) {
    LOG_E("Unsupported RSA modulus size.\n");
    LOG_D("Current modulus size: %u\n",
          (unsigned int)public_key->modulus_size);
    return PA_TZ_GENERAL_ERROR;
  }

  TEE_Attribute attributes[kNumRsaAttributes];
  TEE_OperationHandle operation_handle = TEE_HANDLE_NULL;
  TEE_ObjectHandle object_handle = TEE_HANDLE_NULL;

  TEE_Result tee_result = TEE_AllocateOperation(
      &operation_handle, TEE_ALG_RSASSA_PKCS1_V1_5_SHA256, TEE_MODE_VERIFY,
      kRsaSize);
  if (tee_result != TEE_SUCCESS) {
    LOG_E("Can't allocate operation.\n");
    LOG_D("Received result: 0x%x\n", tee_result);
    return PA_TZ_SINGATURE_VALIDATION_FAILURE;
  }

  TEE_InitRefAttribute(&attributes[0], TEE_ATTR_RSA_MODULUS,
                       public_key->modulus, public_key->modulus_size);
  TEE_InitRefAttribute(&attributes[1], TEE_ATTR_RSA_PUBLIC_EXPONENT,
                       public_key->public_exponent, public_key->exponent_size);

  PaTzResult result = PA_TZ_SINGATURE_VALIDATION_FAILURE;

  /* Single-pass block: every failure breaks out to the common cleanup. */
  do {
    tee_result = TEE_AllocateTransientObject(TEE_TYPE_RSA_PUBLIC_KEY, kRsaSize,
                                             &object_handle);
    if (tee_result != TEE_SUCCESS) {
      LOG_E("Can't allocate transient object.\n");
      LOG_D("Received result: 0x%x\n", tee_result);
      break;
    }

    tee_result = TEE_PopulateTransientObject(object_handle, attributes,
                                             kNumRsaAttributes);
    if (tee_result != TEE_SUCCESS) {
      LOG_E("Can't pass attributes.\n");
      LOG_D("Received result: 0x%x\n", tee_result);
      break;
    }

    size_t hash_size = CryptoGetHashSize(hash_type);
    if (!hash_size || hash_size > kHashMaxSize) {
      LOG_E("Wrong hash size.\n");
      /* Cast: %u expects unsigned int, hash_size is size_t. */
      LOG_D("Current hash size: %u\n", (unsigned int)hash_size);
      break;
    }

    uint8_t hash[kHashMaxSize];
    if (CryptoSha256(data, data_len, hash) != PA_TZ_SUCCESS) {
      LOG_E("Can't calculate SHA256.\n");
      break;
    }

    tee_result = TEE_SetOperationKey(operation_handle, object_handle);
    if (tee_result != TEE_SUCCESS) {
      LOG_E("Can't set operation key.\n");
      LOG_D("Received result: 0x%x\n", tee_result);
      break;
    }

    /* PKCS#1 v1.5 takes no extra parameters, so pass none (NULL, 0). */
    tee_result = TEE_AsymmetricVerifyDigest(operation_handle, NULL, 0,
                                            hash, hash_size, signature,
                                            singature_len);
    if (tee_result != TEE_SUCCESS) {
      LOG_E("Failed verify of signature.\n");
      LOG_D("Received result: 0x%x\n", tee_result);
      break;
    }

    result = PA_TZ_SUCCESS;
  } while (0);

  TEE_FreeOperation(operation_handle);
  /* object_handle may still be TEE_HANDLE_NULL here. */
  TEE_CloseObject(object_handle);

  return result;
}

/**
 * @brief Size in bytes of an RSA signature for the given key type.
 *
 * The returned value is the key size in bits divided by 8.
 *
 * @param [in] type RSA key type
 * @return Signature length in bytes, or 0 for an unknown key type
 */
size_t CryptoGetRsaSignatureSize(RsaKeyType type) {
  switch (type) {
    case kRSA1024:
      return 1024 / 8;
    case kRSA2048:
      return 2048 / 8;
    case kRSA3072:
      return 3072 / 8;
    default:
      return 0;
  }
}

/**
 * @brief Digest length in bytes for the given hash type.
 * @param [in] type Hash type
 * @return Digest length in bytes, or 0 for an unknown hash type
 */
size_t CryptoGetHashSize(HashType type) {
  switch (type) {
    case kSha1:
      return kSha1Size;
    case kSha224:
      return kSha224Size;
    case kSha256:
      return kSha256Size;
    case kSha384:
      return kSha384Size;
    case kSha512:
      return kSha512Size;
    default:
      return 0;
  }
}

/**
 * @brief Compute an HMAC over @p data with a TEE-derived key.
 *
 * The key and operation come from PrepareHmacOperation, so only the hash
 * types it supports succeed.
 *
 * @param [in] data Message bytes; must not be NULL
 * @param [in] data_length Length of @p data in bytes
 * @param [in] hash_type Hash type of the HMAC
 * @param [out] signature Destination for the MAC
 * @param [in,out] signature_length In: capacity of @p signature, which must be
 *        at least the digest size of @p hash_type. Out (on success only):
 *        number of MAC bytes written.
 * @return ::PA_TZ_SUCCESS on success, an error code otherwise
 */
PaTzResult CryptoHmacSignatureGenerate(const uint8_t *data, size_t data_length,
                                       HashType hash_type,
                                       uint8_t *signature, size_t *signature_length) {
  const size_t key_length = CryptoGetHashSize(hash_type);

  if (!data || !signature || !signature_length || key_length == 0) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (*signature_length < key_length) {
    LOG_E("Signature buffer is too small.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  TEE_OperationHandle hmac_operation = TEE_HANDLE_NULL;
  PaTzResult result = PrepareHmacOperation(hash_type, &hmac_operation);
  if (result != PA_TZ_SUCCESS) {
    LOG_E("Can not allocate operation for HMAC.\n");
    return result;
  }

  /*
   * TEE_MACComputeFinal takes a uint32_t * length, so copy through a local
   * instead of casting the size_t *. The cast violates strict aliasing. On
   * 64-bit big-endian targets it also reads and writes the wrong half of
   * *signature_length. The capacity is clamped rather than truncated.
   */
  uint32_t mac_length = (*signature_length > UINT32_MAX)
                            ? UINT32_MAX
                            : (uint32_t)*signature_length;

  TEE_Result ret = TEE_MACComputeFinal(hmac_operation,
                                       data, data_length,
                                       signature, &mac_length);
  if (ret != TEE_SUCCESS) {
    LOG_E("Can not compute HMAC.\n");
    result = PA_TZ_GENERAL_ERROR;
  } else {
    *signature_length = mac_length;
    result = PA_TZ_SUCCESS;
  }

  TEE_FreeOperation(hmac_operation);

  return result;
}

/**
 * @brief Check an HMAC over @p data against @p signature using a TEE-derived key.
 * @param [in] data Message bytes; must not be NULL
 * @param [in] data_length Length of @p data in bytes
 * @param [in] hash_type Hash type of the HMAC
 * @param [in] signature Expected MAC; must not be NULL
 * @param [in] signature_length Must equal the digest size of @p hash_type
 * @return ::PA_TZ_SUCCESS if the MAC matches, an error code otherwise
 */
PaTzResult CryptoHmacSignatureVerification(const uint8_t *data, size_t data_length,
                                           HashType hash_type,
                                           const uint8_t *signature, size_t signature_length) {
  const size_t expected_mac_len = CryptoGetHashSize(hash_type);

  if (data == NULL || signature == NULL || expected_mac_len == 0) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Only full-length MACs are accepted; truncated MACs are rejected. */
  if (signature_length != expected_mac_len) {
    LOG_E("Signature length is incorrect.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  TEE_OperationHandle mac_op = TEE_HANDLE_NULL;
  const PaTzResult prep = PrepareHmacOperation(hash_type, &mac_op);
  if (prep != PA_TZ_SUCCESS) {
    LOG_E("Can not allocate operation for HMAC.\n");
    return prep;
  }

  const TEE_Result cmp = TEE_MACCompareFinal(mac_op, data, data_length,
                                             signature, expected_mac_len);

  PaTzResult verdict = PA_TZ_SUCCESS;
  if (cmp != TEE_SUCCESS) {
    verdict = PA_TZ_GENERAL_ERROR;
    LOG_E("Can not compare HMAC.\n");
  }

  TEE_FreeOperation(mac_op);
  return verdict;
}

/*
 * Allocates an HMAC operation, derives its key with TEES_DeriveKeyKDF,
 * installs the key and calls TEE_MACInit. On success the ready-to-use
 * operation is stored in *op. On failure *op is not written and every TEE
 * resource acquired here is released. The transient key object is always
 * closed before returning.
 */
static PaTzResult PrepareHmacOperation(HashType hash_type,
                                       TEE_OperationHandle *op) {
  if (op == NULL) {
    LOG_E("Invalid arguments.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  /* Only HMAC-SHA256 is wired up so far. */
  if (hash_type != kSha256) {
    LOG_E("Incorrect type of hash.\n");
    return PA_TZ_GENERAL_ERROR;
  }
  const TEE_CRYPTO_ALGORITHMS hmac_algorithm = TEE_ALG_HMAC_SHA256;
  const TEE_OBJECT_TYPES object_type = TEE_TYPE_HMAC_SHA256;

  /* The HMAC key is as long as the digest; the TEE APIs take it in bits. */
  const size_t key_length = CryptoGetHashSize(hash_type);
  const size_t key_bits = key_length * 8;

  TEE_OperationHandle mac_op = TEE_HANDLE_NULL;
  TEE_ObjectHandle key_object = TEE_HANDLE_NULL;
  PaTzResult status = PA_TZ_GENERAL_ERROR;

  if (TEE_AllocateOperation(&mac_op, hmac_algorithm, TEE_MODE_MAC,
                            key_bits) != TEE_SUCCESS) {
    LOG_E("Can not allocate operation for HMAC.\n");
    return PA_TZ_GENERAL_ERROR;
  }

  if (TEE_AllocateTransientObject(object_type, key_bits, &key_object) !=
      TEE_SUCCESS) {
    LOG_E("Can not allocate object for key.\n");
    goto cleanup;
  }

  if (TEES_DeriveKeyKDF(NULL, 0, NULL, 0, key_length, key_object) !=
      TEE_SUCCESS) {
    LOG_E("Can not derive key.\n");
    goto cleanup;
  }

  if (TEE_SetOperationKey(mac_op, key_object) != TEE_SUCCESS) {
    LOG_E("Can not set operation key for HMAC.\n");
    goto cleanup;
  }

  TEE_MACInit(mac_op, NULL, 0);
  status = PA_TZ_SUCCESS;

cleanup:
  /* The key was copied into the operation, so the object can go in every case. */
  TEE_CloseObject(key_object);

  if (status != PA_TZ_SUCCESS) {
    TEE_FreeOperation(mac_op);
    return status;
  }

  *op = mac_op;
  return PA_TZ_SUCCESS;
}
