/***************************************************************************
 *                                  _   _ ____  _
 *  Project                     ___| | | |  _ \| |
 *                             / __| | | | |_) | |
 *                            | (__| |_| |  _ <| |___
 *                             \___|\___/|_| \_\_____|
 *
 * Copyright (C) 1998 - 2014, Daniel Stenberg, <daniel@haxx.se>, et al.
 *
 * This software is licensed as described in the file COPYING, which
 * you should have received as part of this distribution. The terms
 * are also available at http://curl.haxx.se/docs/copyright.html.
 *
 * You may opt to use, copy, modify, merge, publish, distribute and/or sell
 * copies of the Software, and permit persons to whom the Software is
 * furnished to do so, under the terms of the COPYING file.
 *
 * This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
 * KIND, either express or implied.
 *
 ***************************************************************************/

/**
 * @brief Base64 encode/decode based on CURL implementation
 */

/* Base64 encoding/decoding */

#include "base64.h"

/* ---- Base64 Encoding/Decoding Table --- */
/*
 * Standard base64 alphabet (RFC 4648 section 4). The index of a character
 * is the 6-bit value it encodes. The implicit terminating NUL is relied on
 * by the decoders below as the end-of-table sentinel when searching.
 */
static const uint8_t base64[] =
	"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	"abcdefghijklmnopqrstuvwxyz"
	"0123456789"
	"+/";

/*
 * URL- and filename-safe alphabet (RFC 4648 section 5): identical to the
 * standard one except that '-' and '_' take the places of '+' and '/'.
 */
static const uint8_t base64url[] =
	"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
	"abcdefghijklmnopqrstuvwxyz"
	"0123456789"
	"-_";

/*
 * Decode one 4-character quantum of standard base64 from src into up to
 * three bytes at dest.
 *
 * Each '=' contributes six zero bits and reduces the output by one byte.
 * Only bytes that carry data are written: dest[0] always, dest[1] when at
 * most one '=' is present, dest[2] when there is none.
 *
 * Returns the number of bytes written (1..3). Returns 0 in these cases:
 *  - a character is not in the alphabet (this includes a NUL byte);
 *  - a data character follows a '=';
 *  - more than two '=' are present.
 */
static uint32_t decodeQuantum(uint8_t *dest, uint8_t *src)
{
	uint32_t padding = 0;
	uint32_t x = 0;
	uint32_t i;

	for(i = 0; i < 4; i++) {
		const uint8_t c = src[i];
		const uint8_t *p;

		if(c == '=') {
			x <<= 6;
			padding++;
			continue;
		}

		/* Padding may only terminate a quantum. */
		if(padding)
			return 0;

		/* Search the alphabet. Hitting the table's NUL terminator means
		   "not found", so a NUL input byte is rejected instead of being
		   matched against the terminator (which would yield 64). */
		for(p = base64; *p && *p != c; p++)
			;
		if(!*p)
			return 0;

		x = (x << 6) + (uint32_t)(p - base64);
	}

	/* At most two '=' are meaningful; more would make 3 - padding wrap. */
	if(padding > 2)
		return 0;

	if(padding < 1)
		dest[2] = (uint8_t)(x & 0x000000FF);

	x >>= 8;
	if(padding < 2)
		dest[1] = (uint8_t)(x & 0x000000FF);

	x >>= 8;
	dest[0] = (uint8_t)(x & 0x000000FF);

	return 3 - padding;
}

/**
 * @brief Decode padded base64 text that uses the standard alphabet.
 *
 * @param[in]     src    Encoded input. Only the first @p srclen bytes are
 *                       read, so the buffer need not be NUL-terminated.
 * @param[in]     srclen Length of @p src. Must be a non-zero multiple of 4
 *                       and not exceed BASE64_DECODE_MAX_LEN.
 * @param[out]    outptr Receives the decoded bytes followed by a NUL.
 * @param[in,out] outlen In: size of @p outptr. Out: number of decoded bytes
 *                       (excluding the NUL). Written on success only.
 *
 * @return BASE64_OK on success.
 * @return BASE64_BAD_CONTENT_ENCODING if the length is invalid, if '='
 *         appears anywhere but the end (at most two), if an embedded NUL is
 *         present, or if a quantum fails to decode.
 * @return BASE64_OUT_OF_MEMORY if @p outptr cannot hold the decoded data
 *         plus the terminating NUL.
 */
uint32_t base64_decode(uint8_t *src, uint32_t srclen,
		uint8_t *outptr, uint32_t *outlen)
{
	uint32_t length = 0;
	uint32_t padding = 0;
	uint32_t i;
	uint32_t numQuantums;
	uint32_t rawlen;
	uint8_t *pos;

	/* Check the length of the input string is valid */
	if(!srclen || srclen % 4 || srclen > BASE64_DECODE_MAX_LEN)
		return BASE64_BAD_CONTENT_ENCODING;

	/* Find the first '=' (or NUL). The bound is tested first so that
	   src[srclen] and beyond are never read. */
	while(length < srclen && src[length] != '=' && src[length])
		length++;

	/* A maximum of two = padding characters is allowed */
	if(length < srclen && src[length] == '=') {
		padding++;
		if(length + 1 < srclen && src[length + 1] == '=')
			padding++;
	}

	/* The padding must end the input exactly. This also rejects '=' part
	   way through the input and any embedded NUL byte. */
	if(length + padding != srclen)
		return BASE64_BAD_CONTENT_ENCODING;

	/* Calculate the number of quantums */
	numQuantums = srclen / 4;

	/* Calculate the size of the decoded string */
	rawlen = (numQuantums * 3) - padding;

	/* Room is needed for the decoded bytes plus the terminating NUL. */
	if (*outlen < rawlen + 1)
		return BASE64_OUT_OF_MEMORY;

	pos = outptr;

	/* Decode the quantums */
	for(i = 0; i < numQuantums; i++) {
		uint32_t result = decodeQuantum(pos, src);
		if(!result) {
			return BASE64_BAD_CONTENT_ENCODING;
		}

		pos += result;
		src += 4;
	}

	/* Zero terminate */
	*pos = '\0';

	/* Return the decoded data */
	*outlen = rawlen;

	return BASE64_OK;
}

/*
 * Decode one 4-character quantum of base64url (RFC 4648 section 5
 * alphabet, '=' padded) from src into up to three bytes at dest.
 *
 * Each '=' contributes six zero bits and reduces the output by one byte.
 * Only bytes that carry data are written: dest[0] always, dest[1] when at
 * most one '=' is present, dest[2] when there is none.
 *
 * Returns the number of bytes written (1..3). Returns 0 in these cases:
 *  - a character is not in the base64url alphabet (this includes a NUL
 *    byte);
 *  - a data character follows a '=';
 *  - more than two '=' are present.
 */
static uint32_t decodeQuantumUrl(uint8_t *dest, uint8_t *src)
{
	uint32_t padding = 0;
	uint32_t x = 0;
	uint32_t i;

	for(i = 0; i < 4; i++) {
		const uint8_t c = src[i];
		const uint8_t *p;

		if(c == '=') {
			x <<= 6;
			padding++;
			continue;
		}

		/* Padding may only terminate a quantum. */
		if(padding)
			return 0;

		/* Search the alphabet. Hitting the table's NUL terminator means
		   "not found", so a NUL input byte is rejected instead of being
		   matched against the terminator (which would yield 64). */
		for(p = base64url; *p && *p != c; p++)
			;
		if(!*p)
			return 0;

		x = (x << 6) + (uint32_t)(p - base64url);
	}

	/* At most two '=' are meaningful; more would make 3 - padding wrap. */
	if(padding > 2)
		return 0;

	if(padding < 1)
		dest[2] = (uint8_t)(x & 0x000000FF);

	x >>= 8;
	if(padding < 2)
		dest[1] = (uint8_t)(x & 0x000000FF);

	x >>= 8;
	dest[0] = (uint8_t)(x & 0x000000FF);

	return 3 - padding;
}

/**
 * @brief Decode padded base64url text (RFC 4648 section 5 alphabet).
 *
 * @param[in]     src    Encoded input. Only the first @p srclen bytes are
 *                       read, so the buffer need not be NUL-terminated.
 * @param[in]     srclen Length of @p src. Must be a non-zero multiple of 4
 *                       and not exceed BASE64_DECODE_MAX_LEN.
 * @param[out]    outptr Receives the decoded bytes followed by a NUL.
 * @param[in,out] outlen In: size of @p outptr. Out: number of decoded bytes
 *                       (excluding the NUL). Written on success only.
 *
 * @return BASE64_OK on success.
 * @return BASE64_BAD_CONTENT_ENCODING if the length is invalid, if '='
 *         appears anywhere but the end (at most two), if an embedded NUL is
 *         present, or if a quantum fails to decode.
 * @return BASE64_OUT_OF_MEMORY if @p outptr cannot hold the decoded data
 *         plus the terminating NUL.
 */
uint32_t base64url_decode(uint8_t *src, uint32_t srclen,
		uint8_t *outptr, uint32_t *outlen)
{
	uint32_t length = 0;
	uint32_t padding = 0;
	uint32_t i;
	uint32_t numQuantums;
	uint32_t rawlen;
	uint8_t *pos;

	/* Check the length of the input string is valid */
	if(!srclen || srclen % 4 || srclen > BASE64_DECODE_MAX_LEN)
		return BASE64_BAD_CONTENT_ENCODING;

	/* Find the first '=' (or NUL). The bound is tested first so that
	   src[srclen] and beyond are never read. */
	while(length < srclen && src[length] != '=' && src[length])
		length++;

	/* A maximum of two = padding characters is allowed */
	if(length < srclen && src[length] == '=') {
		padding++;
		if(length + 1 < srclen && src[length + 1] == '=')
			padding++;
	}

	/* The padding must end the input exactly. This also rejects '=' part
	   way through the input and any embedded NUL byte. */
	if(length + padding != srclen)
		return BASE64_BAD_CONTENT_ENCODING;

	/* Calculate the number of quantums */
	numQuantums = srclen / 4;

	/* Calculate the size of the decoded string */
	rawlen = (numQuantums * 3) - padding;

	/* Room is needed for the decoded bytes plus the terminating NUL. */
	if (*outlen < rawlen + 1)
		return BASE64_OUT_OF_MEMORY;

	pos = outptr;

	/* Decode the quantums */
	for(i = 0; i < numQuantums; i++) {
		uint32_t result = decodeQuantumUrl(pos, src);
		if(!result) {
			return BASE64_BAD_CONTENT_ENCODING;
		}

		pos += result;
		src += 4;
	}

	/* Zero terminate */
	*pos = '\0';

	/* Return the decoded data */
	*outlen = rawlen;

	return BASE64_OK;
}

/*
 * Encode insize bytes from inputbuff as padded base64 into outptr, using
 * the 64-character alphabet table64, and NUL-terminate the result.
 *
 * Input is consumed in groups of three bytes, and each group becomes four
 * output characters. A final group of one or two bytes is zero-extended,
 * and the characters it cannot fill are written as '='.
 *
 * The function fails with BASE64_OUT_OF_MEMORY in either of these cases:
 *  - insize is 0 or exceeds BASE64_ENCODE_MAX_LEN;
 *  - *outlen is smaller than insize*4/3+4.
 * On success it stores the encoded length (excluding the NUL) in *outlen
 * and returns BASE64_OK.
 *
 * NOTE: bytes are encoded as-is. No host-to-network character-set
 * translation is performed (an open TODO inherited from upstream curl).
 */
static uint32_t base64_encode_with_table(const uint8_t *table64,
		uint8_t *inputbuff, uint32_t insize,
		uint8_t *outptr, uint32_t *outlen)
{
	const uint8_t *in = inputbuff;
	uint8_t *out = outptr;
	uint32_t remaining = insize;

	if(insize == 0 || insize > BASE64_ENCODE_MAX_LEN)
		return BASE64_OUT_OF_MEMORY;

	if (*outlen < insize*4/3+4)
		return BASE64_OUT_OF_MEMORY;

	while(remaining > 0) {
		/* Bytes available in this group (1..3); missing ones read as 0. */
		uint32_t take = (remaining < 3) ? remaining : 3;
		uint8_t b0 = in[0];
		uint8_t b1 = (take > 1) ? in[1] : 0;
		uint8_t b2 = (take > 2) ? in[2] : 0;

		out[0] = table64[b0 >> 2];
		out[1] = table64[((b0 & 0x03) << 4) | (b1 >> 4)];
		if(take > 1)
			out[2] = table64[((b1 & 0x0F) << 2) | (b2 >> 6)];
		else
			out[2] = '=';
		if(take > 2)
			out[3] = table64[b2 & 0x3F];
		else
			out[3] = '=';

		in += take;
		remaining -= take;
		out += 4;
	}

	*out = '\0';
	*outlen = (uint32_t)(out - outptr);

	return BASE64_OK;
}

/**
 * @brief Encode a buffer as padded base64 using the standard alphabet
 *        (A-Z, a-z, 0-9, '+', '/').
 *
 * @param[in]     inputbuff Bytes to encode.
 * @param[in]     insize    Number of bytes in @p inputbuff. Must be non-zero
 *                          and at most BASE64_ENCODE_MAX_LEN.
 * @param[out]    outptr    Receives the NUL-terminated encoded text.
 * @param[in,out] outlen    In: size of @p outptr, which must be at least
 *                          insize*4/3+4. Out: length of the encoded text,
 *                          excluding the NUL.
 *
 * @return BASE64_OK on success. BASE64_OUT_OF_MEMORY if @p insize is out of
 *         range or @p outptr is too small.
 */
uint32_t base64_encode(uint8_t *inputbuff, uint32_t insize,
		uint8_t *outptr, uint32_t *outlen)
{
	const uint8_t *alphabet = base64;

	return base64_encode_with_table(alphabet, inputbuff, insize,
			outptr, outlen);
}

/**
 * @brief Encode a buffer as padded base64url, i.e. using the URL- and
 *        filename-safe alphabet where '-' and '_' replace '+' and '/'.
 *
 * @param[in]     inputbuff Bytes to encode.
 * @param[in]     insize    Number of bytes in @p inputbuff. Must be non-zero
 *                          and at most BASE64_ENCODE_MAX_LEN.
 * @param[out]    outptr    Receives the NUL-terminated encoded text.
 * @param[in,out] outlen    In: size of @p outptr, which must be at least
 *                          insize*4/3+4. Out: length of the encoded text,
 *                          excluding the NUL.
 *
 * @return BASE64_OK on success. BASE64_OUT_OF_MEMORY if @p insize is out of
 *         range or @p outptr is too small.
 */
uint32_t base64url_encode(uint8_t *inputbuff, uint32_t insize,
		uint8_t *outptr, uint32_t *outlen)
{
	const uint8_t *alphabet = base64url;

	return base64_encode_with_table(alphabet, inputbuff, insize,
			outptr, outlen);
}

/* test only */

/*
 * Encode insize bytes from inputbuff as unpadded base64 into outptr, using
 * the 64-character alphabet table64, and NUL-terminate the result.
 *
 * Each full group of three input bytes produces four characters. A final
 * group of one byte produces two characters, and a final group of two bytes
 * produces three. No '=' is emitted.
 *
 * The function fails with BASE64_OUT_OF_MEMORY in either of these cases:
 *  - insize is 0 or exceeds BASE64_ENCODE_MAX_LEN;
 *  - *outlen is smaller than insize*4/3+4.
 * On success it stores the encoded length (excluding the NUL) in *outlen
 * and returns BASE64_OK.
 *
 * NOTE: bytes are encoded as-is. No host-to-network character-set
 * translation is performed (an open TODO inherited from upstream curl).
 */
static uint32_t base64_encode_with_table_nopad(const uint8_t *table64,
		uint8_t *inputbuff, uint32_t insize,
		uint8_t *outptr, uint32_t *outlen)
{
	const uint8_t *in = inputbuff;
	uint8_t *out = outptr;
	uint32_t remaining = insize;

	if(insize == 0 || insize > BASE64_ENCODE_MAX_LEN)
		return BASE64_OUT_OF_MEMORY;

	if (*outlen < insize*4/3+4)
		return BASE64_OUT_OF_MEMORY;

	while(remaining > 0) {
		/* Bytes available in this group (1..3); missing ones read as 0. */
		uint32_t take = (remaining < 3) ? remaining : 3;
		uint8_t b0 = in[0];
		uint8_t b1 = (take > 1) ? in[1] : 0;
		uint8_t b2 = (take > 2) ? in[2] : 0;

		/* The first two characters always carry data. */
		*out++ = table64[b0 >> 2];
		*out++ = table64[((b0 & 0x03) << 4) | (b1 >> 4)];
		/* The remaining ones are emitted only if their input byte exists. */
		if(take > 1)
			*out++ = table64[((b1 & 0x0F) << 2) | (b2 >> 6)];
		if(take > 2)
			*out++ = table64[b2 & 0x3F];

		in += take;
		remaining -= take;
	}

	*out = '\0';
	*outlen = (uint32_t)(out - outptr);

	return BASE64_OK;
}

/**
 * @brief Encode a buffer as base64url (the '-'/'_' alphabet) without any
 *        trailing '=' padding.
 *
 * @param[in]     inputbuff Bytes to encode.
 * @param[in]     insize    Number of bytes in @p inputbuff. Must be non-zero
 *                          and at most BASE64_ENCODE_MAX_LEN.
 * @param[out]    outptr    Receives the NUL-terminated encoded text.
 * @param[in,out] outlen    In: size of @p outptr, which must be at least
 *                          insize*4/3+4. Out: length of the encoded text,
 *                          excluding the NUL.
 *
 * @return BASE64_OK on success. BASE64_OUT_OF_MEMORY if @p insize is out of
 *         range or @p outptr is too small.
 */
uint32_t base64url_encode_nopad(uint8_t *inputbuff, uint32_t insize,
		uint8_t *outptr, uint32_t *outlen)
{
	const uint8_t *alphabet = base64url;

	return base64_encode_with_table_nopad(alphabet, inputbuff, insize,
			outptr, outlen);
}

/*
 * Decode one unpadded base64url quantum of `width` characters from src
 * into up to three bytes at dest.
 *
 * Only the first `width` characters of src are read. The missing positions
 * up to four are treated as six zero bits each, like '=' padding. Only
 * bytes that carry data are written.
 *
 * width must be 2, 3 or 4. The function returns the number of bytes written
 * (width - 1, or 3 for a full quantum). It returns 0 if width is out of
 * range, or if one of the characters read is not in the base64url alphabet
 * (this includes a NUL byte and '=').
 */
static uint32_t decodeQuantumUrlNopad(uint8_t *dest, uint8_t *src, uint8_t width)
{
	uint32_t padding = 0;
	uint32_t x = 0;
	uint32_t i;

	/* Fewer than two characters cannot form a byte, and width 0 would make
	   3 - padding wrap around. */
	if(width < 2 || width > 4)
		return 0;

	for(i = 0; i < 4; i++) {
		const uint8_t *p;

		if(i >= width) {
			x <<= 6;
			padding++;
			continue;
		}

		/* Search the alphabet. Hitting the table's NUL terminator means
		   "not found", so a NUL input byte is rejected instead of being
		   matched against the terminator (which would yield 64). */
		for(p = base64url; *p && *p != src[i]; p++)
			;
		if(!*p)
			return 0;

		x = (x << 6) + (uint32_t)(p - base64url);
	}

	if(padding < 1)
		dest[2] = (uint8_t)(x & 0x000000FF);

	x >>= 8;
	if(padding < 2)
		dest[1] = (uint8_t)(x & 0x000000FF);

	x >>= 8;
	dest[0] = (uint8_t)(x & 0x000000FF);

	return 3 - padding;
}

/**
 * @brief Decode base64url text that carries no '=' padding.
 *
 * @param[in]     src    Encoded input. Exactly @p srclen bytes are read.
 * @param[in]     srclen Length of @p src. Must be non-zero, at most
 *                       BASE64_DECODE_MAX_LEN, and not leave a remainder
 *                       of 1 when divided by 4.
 * @param[out]    outptr Receives the decoded bytes followed by a NUL.
 * @param[in,out] outlen In: size of @p outptr. Out: number of decoded bytes
 *                       (excluding the NUL). Written on success only.
 *
 * @return BASE64_OK on success.
 * @return BASE64_BAD_CONTENT_ENCODING for an invalid length, or when any
 *         quantum, including the trailing partial one, fails to decode.
 * @return BASE64_OUT_OF_MEMORY if @p outptr cannot hold the decoded data
 *         plus the terminating NUL.
 */
uint32_t base64url_decode_nopad(uint8_t *src, uint32_t srclen,
		uint8_t *outptr, uint32_t *outlen)
{
	uint32_t i;
	uint32_t numQuantums;
	uint32_t rawlen;
	uint8_t *pos;
	uint8_t last_width;
	uint32_t result;

	/* Check the length of the input string is valid */
	if(!srclen || srclen > BASE64_DECODE_MAX_LEN)
		return BASE64_BAD_CONTENT_ENCODING;

	/* Full 4-character quantums, plus a trailing group of 0, 2 or 3. */
	numQuantums = srclen / 4;
	last_width = (uint8_t)(srclen % 4);
	/* A single leftover character holds only 6 bits, not a whole byte. */
	if (last_width == 1)
		return BASE64_BAD_CONTENT_ENCODING;

	/* Each full quantum yields 3 bytes; a trailing group of 2 or 3
	   characters yields 1 or 2 bytes respectively. */
	rawlen = numQuantums * 3;
	if (last_width != 0)
		rawlen += last_width - 1;

	/* Room is needed for the decoded bytes plus the terminating NUL. */
	if (*outlen < rawlen + 1)
		return BASE64_OUT_OF_MEMORY;

	pos = outptr;

	/* Decode the full quantums */
	for(i = 0; i < numQuantums; i++) {
		result = decodeQuantumUrlNopad(pos, src, 4);
		if(!result) {
			return BASE64_BAD_CONTENT_ENCODING;
		}

		pos += result;
		src += 4;
	}

	/* The trailing partial group must be validated like the full ones.
	   Otherwise an invalid character there would be silently accepted,
	   leaving the corresponding output bytes unwritten. */
	if (last_width != 0) {
		result = decodeQuantumUrlNopad(pos, src, last_width);
		if(!result) {
			return BASE64_BAD_CONTENT_ENCODING;
		}
		pos += result;
	}

	/* Zero terminate */
	*pos = '\0';

	/* Return the decoded data */
	*outlen = rawlen;

	return BASE64_OK;
}

/* ---- End of Base64 Encoding ---- */
