#include "TuiMnemonic.h"
#include "bip39_english.h"

#include "Vendor_Interface.h"
#include "TZ_Vendor_debug_tl.h"
#include "tl_tui_bc_error_msg.h"
#include "sha2.h"
#include "memzero.h"
#include <string.h>
#include <TuiLayout.h>

/* Static node pool backing the prefix trie; getNewNode() hands out slots in order. */
TreeNode gNodeList[MNEMONIC_LIST_NODE_SIZE];
/* Root of the prefix trie; (re)created by createMnemonicTree(). */
TreeNode* gRootNode;
/* Next free slot in gNodeList; reset to 0 by createMnemonicTree() and advanced by getNewNode(). */
uint32_t listIndex = 0;

/*
 * Build the prefix trie used for mnemonic word lookup.
 *
 * Resets the node pool, allocates a fresh root node and inserts the first
 * MNEMONIC_WORD_TREE_SIZE bytes of every word in `wordlist`, in list order.
 * A word whose prefix equals the previously inserted prefix is skipped,
 * because its path through the trie already exists.
 *
 * Returns TIMA_SUCCESS.
 */
uint32_t createMnemonicTree() {
	uint8_t prefix[MNEMONIC_WORD_TREE_SIZE];
	uint32_t word;

	memset(prefix, 0, sizeof(prefix));
	listIndex = 0;
	gRootNode = getNewNode(0);

	for (word = 0; word < MNEMONIC_LIST_TOTAL_SIZE; word++) {
		/* Same prefix as the last inserted word: nothing new to add. */
		if (memcmp(prefix, wordlist[word], sizeof(prefix)) == 0) {
			continue;
		}
		memcpy(prefix, wordlist[word], sizeof(prefix));
		insertNode(prefix, word);
	}
	return TIMA_SUCCESS;
}

TreeNode* getNewNode(uint8_t data) {
	gNodeList[listIndex].data = data;
	gNodeList[listIndex].mnemonicIndex = -1;
	for (uint32_t i = 0; i < ALPHABET_SIZE; i++) {
		gNodeList[listIndex].childNode[i] = NULL;
	}

	return &gNodeList[listIndex++];
}

uint32_t insertNode(uint8_t* key, uint32_t mnemonicIndex) {
	TreeNode *curNode = gRootNode;
	uint32_t index;

	for (uint32_t i = 0; i < MNEMONIC_WORD_TREE_SIZE; i++) {
		index = key[i] - 'a';
		if (curNode->childNode[index] == NULL) {
			curNode->childNode[index] = getNewNode(key[i]);
			curNode->childNode[index]->mnemonicIndex = mnemonicIndex;
		}
		curNode = curNode->childNode[index];
	}
	return TIMA_SUCCESS;
}

int32_t getMnemonicIndex(uint8_t* key) {
	TreeNode *curNode = gRootNode;
	uint32_t index;
	uint32_t keySize = (uint32_t)strlen((char *)key);

	if (keySize > MNEMONIC_WORD_MAX_SIZE || keySize == 0) {
		return -1;
	}
	
	for (uint32_t i = 0; i < ((keySize > MNEMONIC_WORD_TREE_SIZE) ? MNEMONIC_WORD_TREE_SIZE : keySize); i++) {
		index = key[i] - 'a';
		if (curNode->childNode[index] == NULL) {
			return -1;
		}
		curNode = curNode->childNode[index];
	}

	if (keySize <= MNEMONIC_WORD_TREE_SIZE) {
		return curNode->mnemonicIndex;
	} else {
		uint32_t listIndex = curNode->mnemonicIndex;
		do {
			if (memcmp(key, wordlist[listIndex], keySize) == 0) {
				return listIndex;
			}
			listIndex++;
		} while (listIndex < MNEMONIC_LIST_TOTAL_SIZE && memcmp(key, wordlist[listIndex], MNEMONIC_WORD_TREE_SIZE) == 0);
		return -1;
	}
}

/*
 * Collect autocompletion candidates for the partial word `key`.
 *
 * recommendList must provide RESTORE_MNEMONIC_RECOMMEND_LIST_COUNT + 1
 * slots. Slot i receives wordlist[start + i], where start is the index
 * returned by getMnemonicIndex(key), if that entry exists and starts with
 * `key`. Every other slot, including all of them when the lookup fails, is
 * set to NULL. Stored pointers point into wordlist itself; they are not
 * copies.
 *
 * Returns TIMA_SUCCESS.
 */
uint32_t getRecommendWordList(uint8_t* key, uint8_t* recommendList[]) {
	int32_t keyIndex = getMnemonicIndex(key);
	size_t keyLen = 0;

	/* getMnemonicIndex() takes strlen(key) and returns -1 for an empty key,
	 * so a successful lookup implies a non-empty string. Measure it once
	 * instead of on every iteration. */
	if (keyIndex != -1) {
		keyLen = strlen((char *)key);
	}

	for (uint32_t i = 0; i < RESTORE_MNEMONIC_RECOMMEND_LIST_COUNT + 1; i++) {
		uint32_t wordIndex = (uint32_t)keyIndex + i;

		recommendList[i] = NULL;
		/* strncmp stops at the NUL of a word shorter than the key, whereas
		 * memcmp could read past the end of that word. */
		if (keyLen != 0 && wordIndex < MNEMONIC_LIST_TOTAL_SIZE
				&& strncmp((const char *)key, wordlist[wordIndex], keyLen) == 0) {
			recommendList[i] = (uint8_t*)wordlist[wordIndex];
		}
	}
	return TIMA_SUCCESS;
}

/*
 * Decode a space-separated BIP39 mnemonic into its packed bit string.
 *
 * Each word is looked up by exact match in `wordlist`. Its 11-bit index is
 * appended MSB-first to a 33-byte buffer, which is then copied to
 * `entropy`. The output therefore holds the entropy followed by the
 * checksum bits. The checksum itself is not verified here; mnemonicCheck
 * does that.
 *
 * mnemonic: NUL-terminated string of words separated by single spaces.
 * entropy:  output buffer of at least 33 bytes; written only on success.
 *
 * Returns the number of decoded bits (word count * 11). Returns 0 if either
 * pointer is NULL, the word count is not 12/15/18/21/24, a word is too long,
 * or a word is not in the word list.
 *
 * The local buffers hold secret material and are wiped on every return
 * path.
 */
int mnemonicToEntropy(const uint8_t *mnemonic, uint8_t *entropy)
{
	char current_word[10] = {0};
	uint8_t bits[32 + 1] = {0};
	uint32_t i = 0, n = 0;
	uint32_t j, k, ki, bi = 0;
	int result = 0;

	if (!mnemonic || !entropy) {
		goto cleanup;
	}

	/* Word count is one more than the number of separators. */
	while (mnemonic[i]) {
		if (mnemonic[i] == ' ') {
			n++;
		}
		i++;
	}
	n++;

	// check number of words
	if (n != 12 && n != 15 && n != 18 && n != 21 && n != 24) {
		goto cleanup;
	}

	i = 0;
	while (mnemonic[i]) {
		j = 0;
		while (mnemonic[i] != ' ' && mnemonic[i] != 0) {
			/* Leave room for the terminator. */
			if (j >= sizeof(current_word) - 1) {
				goto cleanup;
			}
			current_word[j] = (char)mnemonic[i];
			i++; j++;
		}
		current_word[j] = 0;
		if (mnemonic[i] != 0) {
			i++;
		}

		/* Exact lookup; relies on wordlist ending with a NULL entry. */
		for (k = 0; ; k++) {
			if (!wordlist[k]) { // word not found
				goto cleanup;
			}
			if (strcmp(current_word, wordlist[k]) == 0) { // word found on index k
				break;
			}
		}

		/* Append the 11-bit word index, most significant bit first. */
		for (ki = 0; ki < 11; ki++) {
			if (k & (1u << (10 - ki))) {
				bits[bi / 8] |= (uint8_t)(1u << (7 - (bi % 8)));
			}
			bi++;
		}
	}

	/* Rejects inputs such as a trailing space, where fewer words were
	 * decoded than separators were counted. */
	if (bi != n * 11) {
		goto cleanup;
	}
	memcpy(entropy, bits, sizeof(bits));
	result = (int)(n * 11);

cleanup:
	memzero(current_word, sizeof(current_word));
	memzero(bits, sizeof(bits));
	return result;
}

/*
 * Validate a BIP39 mnemonic: word count, word-list membership and checksum.
 *
 * The mnemonic is decoded with mnemonicToEntropy. For a mnemonic of `words`
 * words:
 *   - the first words*4/3 bytes are the entropy;
 *   - the top words/3 bits of the following byte are the checksum.
 * The checksum must match the leading bits of SHA-256(entropy).
 *
 * Returns 1 if the mnemonic is valid, 0 otherwise. The decoded entropy and
 * its hash are wiped from the stack before returning.
 */
int mnemonicCheck(const uint8_t *mnemonic)
{
	uint8_t bits[32 + 1] = {0};
	int valid = 0;
	int seed_len = mnemonicToEntropy(mnemonic, bits);

	if (seed_len == (12 * 11) || seed_len == (15 * 11) || seed_len == (18 * 11)
			|| seed_len == (21 * 11) || seed_len == (24 * 11)) {
		int words = seed_len / 11;
		int entropy_len = words * 4 / 3;
		/* Checksum length is words/3 bits (4..8). This yields 0xF0, 0xF8,
		 * 0xFC, 0xFE or 0xFF for 12/15/18/21/24 words. */
		uint8_t mask = (uint8_t)(0xFF << (8 - words / 3));
		uint8_t checksum = bits[entropy_len];

		/* The hash overwrites the entropy in place; only bits[0] is compared. */
		sha256_Raw(bits, entropy_len, bits);
		valid = (bits[0] & mask) == (checksum & mask);
	}

	memzero(bits, sizeof(bits));
	return valid;
}
