#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include "pebble_sss.h"
#include <openssl/rand.h>

// Lookup table filled by p_inv(): for row r != 0, entry [r][v] holds the c with
// p_mul(r, c) == v, so [a][1] is the multiplicative inverse of a in GF(2^8).
// Row 0 only ever has [0][0] written (to 255); [0][1] keeps its zero initial
// value, so p_inv(0) returns 0. Kept non-static — presumably so pebble_sss.h
// can declare it; TODO confirm before changing its linkage.
uint8_t MULTIPLICATIVE_INVERSE_TABLE[256][256];


// Addition in GF(2^8). Coefficients of the field's polynomials live in GF(2),
// so adding two elements is a carry-less bitwise XOR (subtraction is identical).
uint8_t p_add(uint8_t a, uint8_t b) {
  const uint8_t sum = (uint8_t)(a ^ b);
  return sum;
}


// Multiply a field element by x (i.e. by 2) in GF(2^8).
// A left shift multiplies by x; when the top bit of a was set the product has
// degree 8 and is reduced by XOR-ing in IRREDUCTIBLE_POLY (not defined in this
// file — presumably from pebble_sss.h). The int result is truncated to 8 bits
// by the uint8_t return type.
uint8_t time_x(uint8_t a) {
  const int shifted = a << 1;
  if (!(a & 0x80)) {
    return shifted;
  }
  return shifted ^ IRREDUCTIBLE_POLY;
}


// Multiply a by x^x_power (i.e. by 2^x_power) in GF(2^8) by applying
// time_x() x_power times. x_power == 0 returns a unchanged.
uint8_t time_x_power(uint8_t a, uint8_t x_power) {
  uint8_t acc = a;
  for (unsigned step = 0; step < x_power; ++step) {
    acc = time_x(acc);
  }
  return acc;
}


// Multiply two elements of GF(2^8) by shift-and-add.
// For every set bit `bit` of b, the term a * x^bit is XOR-ed into the product.
// a * x^bit is maintained incrementally with one time_x() step per bit rather
// than being recomputed from a for each bit.
uint8_t p_mul(uint8_t a, uint8_t b) {
  uint8_t product = 0;
  uint8_t a_times_x_pow = a;  // a * x^bit for the current bit
  for (unsigned bit = 0; bit < 8; ++bit) {
    if ((b >> bit) & 1u) {
      product = p_add(product, a_times_x_pow);
    }
    a_times_x_pow = time_x(a_times_x_pow);
  }
  return product;
}


// Multiplicative inverse of a in GF(2^8), read from MULTIPLICATIVE_INVERSE_TABLE.
//
// The table maps [row][p_mul(row, col)] -> col, so entry [a][1] is the c with
// a * c == 1. Building it costs 65,536 p_mul calls. The previous version rebuilt
// it on EVERY call, which made p_div (and therefore poly_interpolate / join)
// extremely slow. The table contents are identical every time, so it is now
// built once, on first use.
//
// a == 0 has no inverse: row 0 only ever writes entry [0][0], so [0][1] keeps
// its zero initial value and 0 is returned.
//
// NOTE(review): the one-time initialisation is not thread-safe (neither was the
// unconditional rebuild); call p_inv once before using it from several threads,
// or guard it externally.
uint8_t p_inv(uint8_t a) {
  static int table_ready = 0;
  if (!table_ready) {
    for (int row = 0; row < 256; row++) {
      for (int col = 0; col < 256; col++) {
        MULTIPLICATIVE_INVERSE_TABLE[row][p_mul(row, col)] = col;
      }
    }
    table_ready = 1;
  }
  return MULTIPLICATIVE_INVERSE_TABLE[a][1];
}


// Divide a by b in GF(2^8), computed as a * b^-1.
// b == 0 has no inverse: p_inv(0) yields 0 in this file, so p_div(a, 0)
// returns 0 rather than signalling an error. Callers must not divide by zero.
uint8_t p_div(uint8_t a, uint8_t b) {
  const uint8_t b_inv = p_inv(b);
  return p_mul(a, b_inv);
}


// Return one byte from OpenSSL's CSPRNG (RAND_bytes).
//
// These bytes become share x-coordinates and the coefficients that hide the
// secret, so a predictable value is a security failure. RAND_bytes returns 1 on
// success. On any other result this function aborts instead of returning a fixed
// fallback. The old code silently returned 0xff, which would make every
// polynomial coefficient 0xff. With n >= 2 it would also make split()'s
// distinct-x loop spin forever, because every candidate x would be identical.
uint8_t rand_byte(void) {
  uint8_t byte = 0;
  if (RAND_bytes(&byte, 1) != 1) {
    fprintf(stderr, "rand_byte: RAND_bytes failed\n");
    abort();
  }
  return byte;
}


// Fill poly[0..degree] with a random polynomial whose constant term is `secret`.
// poly[degree] down to poly[1] are drawn from rand_byte() in that order, then
// poly[0] = secret. poly must have room for degree + 1 entries. rand_byte() can
// return 0, so the leading coefficient (and thus the true degree) may be lower.
void make_random_poly(int degree, uint8_t secret, uint8_t *poly) {
  int idx = degree;
  while (idx >= 1) {
    poly[idx] = rand_byte();
    --idx;
  }
  poly[0] = secret;
}


// Evaluate the polynomial poly[0] + poly[1]*x + ... + poly[degree]*x^degree
// at x in GF(2^8), using Horner's rule:
//   ((poly[d] * x + poly[d-1]) * x + ...) * x + poly[0].
// The previous version rebuilt x^i from scratch for every term, which costs
// O(degree^2) multiplications. Horner needs degree + 1 steps. Field arithmetic
// is exact, so the result is unchanged. A negative degree returns 0.
uint8_t poly_eval(uint8_t *poly, int degree , uint8_t x) {
  uint8_t res = 0;
  for (int i = degree; i >= 0; i--) {
    res = p_add(p_mul(res, x), poly[i]);
  }
  return res;
}


// Lagrange-interpolate the polynomial of degree < k through the points
// (xs[j], ys[j]), j = 0..k-1, and return its value at x = 0.
//
// Subtraction in GF(2^8) is XOR (p_add), so each basis weight at 0 is
//   L_j(0) = prod_{m != j} xs[m] / (xs[m] + xs[j]).
// Numerator and denominator products are accumulated separately, so each point
// needs one division (field inversion) instead of k - 1. The result is the same
// because inv(d1 * d2) == inv(d1) * inv(d2) in a field.
// The xs must be pairwise distinct. A repeated x makes a denominator 0, and
// since p_div by 0 yields 0 in this file, that point's term silently becomes 0.
uint8_t poly_interpolate(uint8_t *xs, uint8_t *ys, int k) {
  uint8_t res = 0;
  for (int j = 0; j < k; j++) {
    uint8_t num = 0x01;
    uint8_t den = 0x01;
    for (int m = 0; m < k; m++) {
      if (m == j) {
        continue;
      }
      num = p_mul(num, xs[m]);
      den = p_mul(den, p_add(xs[m], xs[j]));
    }
    res = p_add(res, p_mul(ys[j], p_div(num, den)));
  }
  return res;
}


void split(uint8_t *secret, size_t secret_size, int n, int k, uint8_t shares[][SSS_SHARE_MAX_LEN], size_t * share_size) {
  if (n < k || k < 1 || secret_size < 1 || secret_size > SSS_KEY_MAX_LEN) {
    *share_size = 0;  // invalid input
    return;
  }
  *share_size = secret_size + 1;
  for (int i = 0; i < n; i++) {  // pick x points
    while (1) {
      shares[i][0] = rand_byte();
      while (shares[i][0] == 0) {  // x cannot be 0
        shares[i][0] = rand_byte();
      }
      int j = 0;
      for (; j < i; j++) {  // cannot use the same x with previous shares
        if (shares[i][0] == shares[j][0]) break; 
      }
      if (j == i) break; // found a valid x
    }
  }
  for (int secret_idx = 0; secret_idx < secret_size; secret_idx++) {
    uint8_t poly[k];
    make_random_poly(k-1, secret[secret_idx], poly);
    // Evaluate poly on every one of the n x points
    for (int i = 0; i < n; i++) {
      shares[i][secret_idx + 1] = poly_eval(poly, k-1, shares[i][0]);
    }
  }
}


void join(uint8_t shares[][SSS_SHARE_MAX_LEN], size_t share_size, int k, uint8_t secret[SSS_KEY_MAX_LEN], size_t * secret_size) {
  if (share_size < 2 || share_size > SSS_SHARE_MAX_LEN || k < 1 || k > MAX_BUDDIES) {
	  *secret_size =  0; // invalid input
	  return;
  }
  *secret_size = share_size - 1;
  for (int secret_idx = 1; secret_idx <= share_size - 1; secret_idx++) {
    uint8_t xs[k];
    uint8_t ys[k];
    for (int i = 0; i < k; i++) {
      xs[i] = shares[i][0];
      ys[i] = shares[i][secret_idx];
    }
    secret[secret_idx-1] = poly_interpolate(xs, ys, k);
    //printf("join %c", (char) secret[secret_idx-1]);
  }
}

