
#include <stdint.h>
#include <string.h>

#include "pd_rand/pd_rand.h"
#include "drbg/ctr_drbg.h"

#include "cc_rand.h"

/*
 * Decode the first four bytes of buf as an unsigned 32-bit integer stored
 * least-significant byte first (little-endian).
 */
static uint32_t __get_le32(uint8_t *buf)
{
	uint32_t value = 0;
	int i;

	/* Start at the most significant byte; each step shifts prior bytes up. */
	for (i = 3; i >= 0; i--)
	{
		value = (value << 8) | (uint32_t)buf[i];
	}

	return value;
}

/*
 * Reseed an instantiated CTR_DRBG context with fresh entropy obtained from
 * pd_rand(). No additional input is supplied to the reseed.
 *
 * ctx: DRBG context, updated in place by CTR_DRBG_Reseed().
 * Returns 1 on success, 0 if pd_rand() fails or CTR_DRBG_Reseed() does not
 * return DRBG_NO_ERROR. The local entropy buffer is wiped on every path.
 */
static int cc_drbg_reseed_helper(CTR_DRBG_CTX *ctx)
{
	/*
	 * Calling memset through a volatile function pointer prevents the
	 * compiler from treating the final wipe as a removable dead store.
	 */
	void *(*volatile wipe)(void *, int, size_t) = memset;
	uint8_t entropy[CTR_DRBG_SEEDLEN_BYTES] = {0};
	int ok = 0;

	if (pd_rand(entropy, sizeof(entropy)) &&
	    CTR_DRBG_Reseed(ctx, entropy, sizeof(entropy), NULL, 0) == DRBG_NO_ERROR)
	{
		ok = 1;
	}

	/* Raw seed material must not linger on the stack. */
	wipe(entropy, 0, sizeof(entropy));
	return ok;
}

/*
 * Instantiate a new CTR_DRBG from pd_rand() entropy and nonce and store its
 * serialized state in buf.
 *
 * buf: destination for the serialized context.
 * len: capacity of buf in bytes.
 * Returns 1 on success, 0 if buf is NULL, the entropy source fails,
 * instantiation or serialization fails, or the serialized context does not
 * fit in len bytes. buf is written only on success.
 *
 * Stack copies of entropy, nonce, the live context and the serialization
 * scratch buffer are wiped before returning on every path.
 */
int cc_drbg_init(uint8_t *buf, uint32_t len)
{
	/*
	 * Calling memset through a volatile function pointer prevents the
	 * compiler from treating the cleanup wipes as removable dead stores.
	 */
	void *(*volatile wipe)(void *, int, size_t) = memset;
	uint8_t tmp[CTR_DRBG_CTX_RAW_SIZE] = {0};
	uint8_t entropy[CTR_DRBG_KEYLEN_BYTES] = {0};
	uint8_t nonce[CTR_DRBG_KEYLEN_BYTES / 2] = {0};
	uint32_t ctx_len;
	CTR_DRBG_CTX ctx;
	int ok = 0;

	memset(&ctx, 0, sizeof(ctx));

	if (buf == NULL)
	{
		goto done;
	}

	if (!pd_rand(entropy, sizeof(entropy)) || !pd_rand(nonce, sizeof(nonce)))
	{
		goto done;
	}

	if (CTR_DRBG_Instantiate(&ctx, entropy, sizeof(entropy), nonce,
				 sizeof(nonce), NULL, 0, CTR_DRBG_KEYLEN) != DRBG_NO_ERROR)
	{
		goto done;
	}

	/* Serialize into scratch first so buf is untouched if it is too small. */
	ctx_len = sizeof(tmp);
	if (CTR_DRBG_Serialization(&ctx, tmp, &ctx_len) != DRBG_NO_ERROR)
	{
		goto done;
	}

	if (ctx_len > len)
	{
		goto done;
	}

	memcpy(buf, tmp, ctx_len);
	ok = 1;

done:
	/* Never leave seed or DRBG key material behind on the stack. */
	wipe(tmp, 0, sizeof(tmp));
	wipe(entropy, 0, sizeof(entropy));
	wipe(nonce, 0, sizeof(nonce));
	wipe(&ctx, 0, sizeof(ctx));
	return ok;
}

/*
 * Reseed a serialized CTR_DRBG state in place with fresh pd_rand() entropy.
 *
 * buf: on entry, a serialized context of len bytes; on success it is
 *      overwritten with the reseeded, re-serialized context.
 * len: length of the serialized context in buf, also the capacity available
 *      for the re-serialized context.
 * Returns 1 on success, 0 if buf is NULL or deserialization, reseeding,
 * serialization or the size check fails. buf is written only on success.
 *
 * The live context and the serialization scratch buffer are wiped before
 * returning on every path.
 */
int cc_drbg_seed(uint8_t *buf, uint32_t len)
{
	/*
	 * Calling memset through a volatile function pointer prevents the
	 * compiler from treating the cleanup wipes as removable dead stores.
	 */
	void *(*volatile wipe)(void *, int, size_t) = memset;
	uint8_t tmp[CTR_DRBG_CTX_RAW_SIZE] = {0};
	uint32_t ctx_len;
	CTR_DRBG_CTX ctx;
	int ok = 0;

	memset(&ctx, 0, sizeof(ctx));

	if (buf == NULL)
	{
		goto done;
	}

	if (CTR_DRBG_Deserialization(&ctx, buf, len) != DRBG_NO_ERROR)
	{
		goto done;
	}

	if (!cc_drbg_reseed_helper(&ctx))
	{
		goto done;
	}

	/* Serialize into scratch first so buf is untouched on failure. */
	ctx_len = sizeof(tmp);
	if (CTR_DRBG_Serialization(&ctx, tmp, &ctx_len) != DRBG_NO_ERROR)
	{
		goto done;
	}

	if (ctx_len > len)
	{
		goto done;
	}

	memcpy(buf, tmp, ctx_len);
	ok = 1;

done:
	/* Never leave DRBG key/state material behind on the stack. */
	wipe(tmp, 0, sizeof(tmp));
	wipe(&ctx, 0, sizeof(ctx));
	return ok;
}

/*
 * Generate random bytes from a DRBG whose serialized state is embedded in
 * buf, then write the advanced state back.
 *
 * Layout of buf (len bytes in total):
 *   [0, 4)                                 n = requested output size,
 *                                          little-endian uint32
 *   [4, 4 + CTR_DRBG_CTX_RAW_SIZE)         serialized CTR_DRBG context (in/out)
 *   [4 + CTR_DRBG_CTX_RAW_SIZE, ... + n)   output: n random bytes
 *
 * If CTR_DRBG_Generate() reports DRBG_E_RESEED_NEEDED, the context is
 * reseeded from pd_rand() and generation is retried once.
 *
 * buf: buffer with the layout above.
 * len: total size of buf in bytes; it must cover the header and all n
 *      output bytes.
 * Returns 1 on success, 0 on failure:
 *   - buf is NULL;
 *   - len is too small for the header plus n output bytes;
 *   - n * 8 does not fit in uint32_t;
 *   - any DRBG or entropy error.
 * The embedded context is re-serialized only after generation succeeds.
 */
int cc_drbg_bytes(uint8_t *buf, uint32_t len)
{
	/*
	 * Calling memset through a volatile function pointer prevents the
	 * compiler from treating the cleanup wipe as a removable dead store.
	 */
	void *(*volatile wipe)(void *, int, size_t) = memset;
	const uint32_t hdr_len = 4u + (uint32_t)CTR_DRBG_CTX_RAW_SIZE;
	uint8_t *state;
	uint8_t *dst;
	uint32_t bytes;
	uint32_t ctx_len;
	CTR_DRBG_CTX ctx;
	int res;
	int ok = 0;

	memset(&ctx, 0, sizeof(ctx));

	/* Validate the buffer before reading the header or writing output. */
	if (buf == NULL || len < hdr_len)
	{
		goto done;
	}

	bytes = __get_le32(buf);

	/* Output must fit in buf, and the bit count must not wrap around. */
	if (bytes > len - hdr_len || bytes > UINT32_MAX / 8u)
	{
		goto done;
	}

	state = buf + 4;
	dst = buf + hdr_len;

	if (CTR_DRBG_Deserialization(&ctx, state, CTR_DRBG_CTX_RAW_SIZE) != DRBG_NO_ERROR)
	{
		goto done;
	}

	res = CTR_DRBG_Generate(&ctx, dst, bytes * 8u, NULL, 0);
	if (res == DRBG_E_RESEED_NEEDED)
	{
		if (!cc_drbg_reseed_helper(&ctx))
		{
			goto done;
		}
		res = CTR_DRBG_Generate(&ctx, dst, bytes * 8u, NULL, 0);
	}
	if (res != DRBG_NO_ERROR)
	{
		goto done;
	}

	/* Persist the advanced state; it must occupy exactly the reserved slot. */
	ctx_len = CTR_DRBG_CTX_RAW_SIZE;
	if (CTR_DRBG_Serialization(&ctx, state, &ctx_len) != DRBG_NO_ERROR ||
	    ctx_len != CTR_DRBG_CTX_RAW_SIZE)
	{
		goto done;
	}

	ok = 1;

done:
	/* Never leave DRBG key/state material behind on the stack. */
	wipe(&ctx, 0, sizeof(ctx));
	return ok;
}
