#include "base64_decode.h"

/*
 * Sentinel codes stored in the decode table for input bytes that are not
 * base64 digits.  Real digits map to 0..63, so the sentinels start at 64 and
 * can never collide with a decoded 6-bit value.
 */
#define WHITESPACE 64 /* byte is skipped by the decoder */
#define EQUALS     65 /* '=' pad character: marks the end of the data */
#define INVALID    66 /* byte is not allowed in the input: decoding fails */

/*
 * Decode table, indexed by input byte value (0..255), 16 entries per row.
 *
 *   0..63  six-bit value of a base64 digit ('A'-'Z', 'a'-'z', '0'-'9', '+', '/')
 *   64     whitespace (only '\n')
 *   65     the '=' pad character
 *   66     any other byte (invalid)
 */
static const unsigned char d[256] = {
  /* 0x00 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 64, 66, 66, 66, 66, 66,
  /* 0x10 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0x20 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 62, 66, 66, 66, 63,
  /* 0x30 */ 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 66, 66, 66, 65, 66, 66,
  /* 0x40 */ 66,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14,
  /* 0x50 */ 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 66, 66, 66, 66, 66,
  /* 0x60 */ 66, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40,
  /* 0x70 */ 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 66, 66, 66, 66, 66,
  /* 0x80 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0x90 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0xA0 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0xB0 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0xC0 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0xD0 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0xE0 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66,
  /* 0xF0 */ 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66, 66
};

/*
 * Decode base64 text into raw bytes.
 *
 *   in      input characters; processed by length, no NUL terminator needed
 *   inLen   number of characters available at `in`
 *   out     destination buffer for the decoded bytes
 *   outLen  on entry: capacity of `out` in bytes;
 *           on success: number of bytes actually written
 *
 * Returns 0 on success.  Returns 1 if the input contains a byte the decode
 * table marks INVALID, or if `out` is too small.  On failure *outLen is left
 * untouched, but `out` may already hold the groups decoded before the error.
 *
 * Leniencies: bytes classified as WHITESPACE are skipped; decoding stops at
 * the first '=' and anything after it is ignored; missing padding is
 * accepted; a single leftover digit (fewer than 8 bits) is silently dropped.
 */
int base64decode(const char *in, size_t inLen, unsigned char *out,
                 size_t *outLen) {
  const char *end = in + inLen;
  char iter = 0;            /* digits accumulated in the current group (0..3) */
  size_t buf = 0, len = 0;  /* bit accumulator / bytes produced so far */

  while (in < end) {
    /*
     * Index through unsigned char: plain char may be signed, and converting a
     * negative value straight to unsigned would yield an index far outside
     * the 256-entry table for any input byte >= 0x80.
     */
    unsigned char c = d[(unsigned char) *in++];

    switch (c) {
      case WHITESPACE:
        continue; /* skip whitespace */
      case INVALID:
        return 1; /* invalid input, return error */
      case EQUALS: /* pad character, end of data */
        in = end;
        continue;
      default:
        /* c is a 6-bit value: append it to the accumulator */
        buf = buf << 6 | c;
        iter++;
        /* Four digits carry 24 bits: flush them as three bytes */
        if (iter == 4) {
          if ((len += 3) > *outLen)
            return 1; /* buffer overflow */
          *(out++) = (buf >> 16) & 255;
          *(out++) = (buf >> 8) & 255;
          *(out++) = buf & 255;
          buf = 0;
          iter = 0;
        }
    }
  }

  /* Flush a trailing partial group, discarding the unused low-order bits */
  if (iter == 3) {
    /* 18 bits collected: two whole bytes plus 2 spare bits */
    if ((len += 2) > *outLen)
      return 1; /* buffer overflow */
    *(out++) = (buf >> 10) & 255;
    *(out++) = (buf >> 2) & 255;
  } else if (iter == 2) {
    /* 12 bits collected: one whole byte plus 4 spare bits */
    if (++len > *outLen)
      return 1; /* buffer overflow */
    *(out++) = (buf >> 4) & 255;
  }

  *outLen = len; /* modify to reflect the actual output size */
  return 0;
}
