/*
 *	The RSA public-key cryptosystem
 *
 *	Based on XySSL: Copyright (C) 2006-2008	 Christophe Devine
 *
 *	Copyright (C) 2009	Paul Bakker <polarssl_maintainer at polarssl dot org>
 *
 *	All rights reserved.
 *
 *	Redistribution and use in source and binary forms, with or without
 *	modification, are permitted provided that the following conditions
 *	are met:
 *
 *	  * Redistributions of source code must retain the above copyright
 *		notice, this list of conditions and the following disclaimer.
 *	  * Redistributions in binary form must reproduce the above copyright
 *		notice, this list of conditions and the following disclaimer in the
 *		documentation and/or other materials provided with the distribution.
 *	  * Neither the names of PolarSSL or XySSL nor the names of its contributors
 *		may be used to endorse or promote products derived from this software
 *		without specific prior written permission.
 *
 *	THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 *	"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 *	LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 *	FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 *	OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 *	SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
 *	TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 *	PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 *	LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *	NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 *	SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
/*
 *	RSA was designed by Ron Rivest, Adi Shamir and Len Adleman.
 *
 *	http://theory.lcs.mit.edu/~rivest/rsapaper.pdf
 *	http://www.cacr.math.uwaterloo.ca/hac/about/chap8.pdf
 */

#include <string.h>

#include "CommLayerData.h"
#include "CryptoPlatform.h"
#include "bn/bn_wrapper.h"
#include "rsa/rsa.h"
#include "rand/rand.h"
#include "sec_alloc.h"
#include "log.h"

/*
 * CHK(f): evaluate f exactly once; if the result is not 1 (the value the
 * wrapped BN_* calls in this file are treated as returning on success), log
 * the failing expression and jump to the enclosing function's "cleanup"
 * label.
 * The do { } while (0) wrapper makes the macro a single statement, so it is
 * safe as the body of an un-braced if/else. The parenthesized argument keeps
 * expressions such as CHK(a & b) comparing the whole value against 1.
 */
#define CHK(f) do { if ((f) != 1) { LOGD(#f " failed"); goto cleanup; } } while (0)

/*
 * Stored public exponent, reported by rsa_get_exponent().
 * It starts at the build-time RSA_EXPONENT value and can be replaced with
 * rsa_set_exponent(). This is plain file-scope state with no locking.
 */
static unsigned long rsa_exponent = RSA_EXPONENT;

/* Replace the stored public exponent with exp. */
void rsa_set_exponent(unsigned long exp) { rsa_exponent = exp; }

/* Report the currently stored public exponent. */
unsigned long rsa_get_exponent(void) { return rsa_exponent; }

/*
 * Allocate an RSA structure from the secure allocator and fill it with zero
 * bytes, so that the key-component pointers RSA_free() tests all start out
 * zero.
 * Returns NULL when sec_malloc() fails.
 */
RSA *RSA_new()
{
	RSA *rsa = sec_malloc(sizeof *rsa);

	if (rsa != NULL)
		memset(rsa, 0, sizeof *rsa);

	return rsa;
}

/*
 * Release an RSA structure obtained from RSA_new().
 *
 * Each non-NULL key component (n, e, d, p, q and the CRT values) is released
 * with BN_clear_free(), presumably chosen so that secret material is wiped
 * before release. The structure itself is then returned with sec_free().
 * Passing NULL is a no-op, mirroring free(NULL), so error paths can call
 * this unconditionally.
 */
void RSA_free(RSA *r)
{
	if (r == NULL)
		return;

	if (r->n != NULL) BN_clear_free(r->n);
	if (r->e != NULL) BN_clear_free(r->e);
	if (r->d != NULL) BN_clear_free(r->d);
	if (r->p != NULL) BN_clear_free(r->p);
	if (r->q != NULL) BN_clear_free(r->q);
	if (r->dmp1 != NULL) BN_clear_free(r->dmp1);
	if (r->dmq1 != NULL) BN_clear_free(r->dmq1);
	if (r->iqmp != NULL) BN_clear_free(r->iqmp);
	sec_free(r);
}

/*
 * Delegate key generation (bits-sized modulus, public exponent e) to
 * generate_RSA_key() and translate its status.
 * A result of 1 becomes NO_ERROR; anything else becomes
 * PLATFORM_INTERNAL_ERROR.
 */
int rsa_gen_key_mldap(RSA *rsa, int bits, unsigned long e)
{
	int status = generate_RSA_key(rsa, bits, e);

	return (status == 1) ? NO_ERROR : PLATFORM_INTERNAL_ERROR;
}

/*
 * Do an RSA public key operation: output = input ^ e mod n.
 *
 * input  - exactly BN_num_bytes(rsa->n) bytes, in the encoding BN_bin2bn()
 *          reads. The value must be numerically smaller than n.
 * output - receives BN_num_bytes(rsa->n) bytes produced by bn_dump().
 *
 * Returns:
 *   NO_ERROR on success;
 *   the bn_dump() error code if serialisation fails;
 *   PLATFORM_INTERNAL_ERROR for every other failure.
 *
 * BN_CTX_get() is only valid inside a BN_CTX_start()/BN_CTX_end() frame.
 * The frame is therefore opened right after the context is created, and the
 * cleanup path (which runs whenever ctx is non-NULL) closes it.
 */
static int rsa_public_mldap(RSA *rsa, unsigned char *input, unsigned char *output)
{
	int ret = PLATFORM_INTERNAL_ERROR;
	int len = 0;
	BN_CTX *ctx = NULL;
	BIGNUM *from = NULL;
	BIGNUM *to = NULL;

	ctx = BN_CTX_new();
	if (!ctx) {
		LOGE("assert (!ctx)");
		goto cleanup;
	}

	/* Open the temporary-variable frame that BN_CTX_end() below closes. */
	BN_CTX_start(ctx);

	from = BN_CTX_get(ctx);
	to = BN_CTX_get(ctx);
	if (!from || !to) {
		LOGE("assert (!from || !to)");
		goto cleanup;
	}

	len = BN_num_bytes(rsa->n);
	if (BN_bin2bn(input, len, from) == NULL) {
		LOGE("BN_bin2bn() failed");
		goto cleanup;
	}

	/* Reject inputs that are not reduced modulo n. */
	if (BN_ucmp(from, rsa->n) >= 0) {
		LOGE("BN_ucmp() failed");
		goto cleanup;
	}

	CHK(BN_mod_exp(to, from, rsa->e, rsa->n, ctx));

	if ((ret = bn_dump(to, output, len)) != NO_ERROR) {
		LOGE("bn_dump() failed : %d, len : %d", ret, len);
		goto cleanup;
	}

	ret = NO_ERROR;

cleanup:

	if (ctx) {
		BN_CTX_end(ctx);
		BN_CTX_free(ctx);
	}

	return ret;
}

/*
 * Do an RSA private key operation using the CRT components
 * (p, q, dmp1, dmq1, iqmp).
 *
 * input  - exactly BN_num_bytes(rsa->n) bytes, in the encoding BN_bin2bn()
 *          reads. The value must be numerically smaller than n.
 * output - receives BN_num_bytes(rsa->n) bytes produced by bn_dump().
 *          It may alias input, because input is fully consumed by
 *          BN_bin2bn() before anything is written.
 *
 * Returns NO_ERROR on success and PLATFORM_INTERNAL_ERROR on any failure.
 *
 * BN_CTX_get() is only valid inside a BN_CTX_start()/BN_CTX_end() frame.
 * The frame is opened right after the context is created and closed on the
 * cleanup path.
 *
 * NOTE(review): there is no base blinding and no check of the CRT result
 * against the public key. Consider both if this key is exposed to timing or
 * fault attacks.
 */
static int rsa_private_mldap(RSA *rsa, unsigned char *input, unsigned char *output)
{
	int ret = PLATFORM_INTERNAL_ERROR;
	int len = 0;
	BN_CTX *ctx = NULL;
	BIGNUM *T = NULL;
	BIGNUM *T1 = NULL;
	BIGNUM *T2 = NULL;
	BIGNUM *out = NULL;

	ctx = BN_CTX_new();
	if (!ctx) {
		goto cleanup;
	}

	/* Open the temporary-variable frame that BN_CTX_end() below closes. */
	BN_CTX_start(ctx);

	T = BN_CTX_get(ctx);
	T1 = BN_CTX_get(ctx);
	T2 = BN_CTX_get(ctx);
	out = BN_CTX_get(ctx);
	if (!T || !T1 || !T2 || !out) {
		goto cleanup;
	}

	len = BN_num_bytes(rsa->n);
	if (BN_bin2bn(input, len, T) == NULL) {
		goto cleanup;
	}

	/* Reject inputs that are not reduced modulo n. */
	if (BN_ucmp(T, rsa->n) >= 0) {
		goto cleanup;
	}

	/*
	 * faster decryption using the CRT
	 * T1 = input ^ dP mod P
	 * T2 = input ^ dQ mod Q
	 */
	CHK(BN_mod_exp(T1, T, rsa->dmp1, rsa->p, ctx));
	CHK(BN_mod_exp(T2, T, rsa->dmq1, rsa->q, ctx));

	/*
	 * T = (T1 - T2) * (Q^-1 mod P) mod P
	 */
	CHK(BN_sub(T, T1, T2));

	CHK(BN_mul(T1, T, rsa->iqmp, ctx));
	CHK(BN_mod(T, T1, rsa->p, ctx));

	/* T1 - T2 may be negative, so bring the remainder back into [0, P). */
	if (T->neg) {
		CHK(BN_add(out, T, rsa->p));
		if (BN_copy(T, out) == NULL) {
			LOGD("BN_copy() failed");
			goto cleanup;
		}
	}

	/*
	 * output = T2 + T * Q
	 */
	CHK(BN_mul(T1, T, rsa->q, ctx));
	CHK(BN_add(out, T2, T1));

	if (bn_dump(out, output, len) != NO_ERROR) {
		LOGD("bn_dump() failed");
		goto cleanup;
	}

	ret = NO_ERROR;

cleanup:

	if (ctx) {
		BN_CTX_end(ctx);
		BN_CTX_free(ctx);
	}

	return ret;
}

/*
 * Do an RSA operation to sign the message digest (PKCS#1 v1.5 only).
 *
 * sig is first filled with the encoded block, exactly BN_num_bytes(rsa->n)
 * bytes:
 *     0x00 || RSA_SIGN || 0xFF x nb_pad || 0x00 || DigestInfo prefix || hash
 * The block is then replaced in place by the private-key operation.
 *
 * hash_id - RSA_SHA1, RSA_SHA256 or RSA_SHA512; anything else returns
 *           WRONG_DATA. These are the same three digests
 *           rsa_pkcs1_verify_mldap() accepts.
 * hashlen - not used; the digest length is implied by hash_id.
 * hash    - the digest: 20, 32 or 64 bytes respectively.
 * sig     - output buffer of at least BN_num_bytes(rsa->n) bytes.
 *
 * Returns WRONG_DATA if the modulus cannot hold at least 8 padding bytes.
 * Otherwise returns the result of rsa_private_mldap().
 */
int rsa_pkcs1_sign_mldap(RSA *rsa, int hash_id, int hashlen, unsigned char *hash, unsigned char *sig)
{
	const void *prefix = NULL;
	int prefix_len = 0, digest_len = 0;
	int nb_pad = 0, olen = 0;
	unsigned char *p = sig;

	(void)hashlen;

	olen = BN_num_bytes(rsa->n);

	/* DigestInfo prefix and digest size per supported algorithm. */
	switch (hash_id)
	{
		case RSA_SHA1:
			prefix = ASN1_HASH_SHA1;
			prefix_len = 15;
			digest_len = 20;
			break;
		case RSA_SHA256:
			prefix = ASN1_HASH_SHA256;
			prefix_len = 19;
			digest_len = 32;
			break;
		case RSA_SHA512:
			prefix = ASN1_HASH_SHA512;
			prefix_len = 19;
			digest_len = 64;
			break;
		default:
			return WRONG_DATA;
	}

	/* 3 fixed bytes (0x00, RSA_SIGN, 0x00) plus the DigestInfo. */
	nb_pad = olen - 3 - prefix_len - digest_len;
	if (nb_pad < 8)
		return WRONG_DATA;

	*p++ = 0;
	*p++ = RSA_SIGN;
	memset(p, 0xFF, nb_pad);
	p += nb_pad;
	*p++ = 0;
	memcpy(p, prefix, prefix_len);
	memcpy(p + prefix_len, hash, digest_len);

	return rsa_private_mldap(rsa, sig, sig);
}

/*
 * Do an RSA operation and check the message digest (PKCS#1 v1.5).
 *
 * Applies the public key to sig (BN_num_bytes(rsa->n) bytes). The result
 * must be
 *     0x00 || RSA_SIGN || at least 8 x 0xFF || 0x00 || DigestInfo prefix || hash
 * and the DigestInfo part must be exactly as long as hash_id requires.
 *
 * hash_id - RSA_SHA1, RSA_SHA256 or RSA_SHA512.
 * hashlen - not used; the digest length is implied by hash_id.
 * hash    - the expected digest: 20, 32 or 64 bytes respectively.
 *
 * Returns:
 *   NO_ERROR on a match;
 *   the rsa_public_mldap() error code if the public operation fails;
 *   WRONG_DATA otherwise, including moduli shorter than 16 bytes or longer
 *   than the 512-byte work buffer.
 */
int rsa_pkcs1_verify_mldap(RSA *rsa, int hash_id, int hashlen, unsigned char *hash, unsigned char *sig)
{
	int ret = 0, len = 0, siglen = 0;
	unsigned char *p = NULL;
	unsigned char *pad = NULL;
	unsigned char buf[512] = {0};

	(void)hashlen;

	siglen = BN_num_bytes(rsa->n);

	if (siglen < 16 || siglen > (int)sizeof(buf))
		return WRONG_DATA;

	ret = rsa_public_mldap(rsa, sig, buf);
	if (ret != NO_ERROR) {
		LOGE("ERROR rsa_public : %d", ret);
		return (ret);
	}

	p = buf;

	if (*p++ != 0 || *p++ != RSA_SIGN) {
		LOGE("rsa_pkcs1_verify_mldap() RSA_SIGN tag missing");
		return WRONG_DATA;
	}

	/*
	 * Walk the 0xFF padding up to the 0x00 separator. The walk stays within
	 * the siglen bytes written by the public operation.
	 */
	pad = p;
	while (*p != 0) {
		if (p >= buf + siglen - 1 || *p != 0xFF) {
			LOGE("rsa_pkcs1_verify_mldap() signature length is wrong");
			return WRONG_DATA;
		}
		p++;
	}

	/* PKCS#1 v1.5 requires at least 8 padding bytes; signing enforces the same. */
	if (p - pad < 8) {
		LOGE("rsa_pkcs1_verify_mldap() padding is too short");
		return WRONG_DATA;
	}
	p++;	/* skip the 0x00 separator */

	/* Whatever remains must be exactly the DigestInfo for hash_id. */
	len = siglen - (int)(p - buf);

	if (len == 35 && hash_id == RSA_SHA1) {
		if (memcmp(p, ASN1_HASH_SHA1, 15) == 0 && memcmp(p + 15, hash, 20) == 0) {
			return NO_ERROR;
		} else {
			LOGE("SHA1 hash is wrong");
			return WRONG_DATA;
		}
	}

	if (len == 51 && hash_id == RSA_SHA256) {
		if (memcmp(p, ASN1_HASH_SHA256, 19) == 0 && memcmp(p + 19, hash, 32) == 0) {
			return NO_ERROR;
		} else {
			LOGE("SHA256 hash is wrong");
			return WRONG_DATA;
		}
	}

	if (len == 83 && hash_id == RSA_SHA512) {
		if (memcmp(p, ASN1_HASH_SHA512, 19) == 0 && memcmp(p + 19, hash, 64) == 0) {
			return NO_ERROR;
		} else {
			LOGE("SHA512 hash is wrong");
			return WRONG_DATA;
		}
	}

	return WRONG_DATA;
}

/*
 * Sanity-check a private key with a round trip. A SHA-256-sized digest
 * (random, or a fixed string if RAND_bytes() fails) is signed with
 * rsa_pkcs1_sign_mldap() and then verified with rsa_pkcs1_verify_mldap().
 *
 * Returns:
 *   NO_ERROR when the round trip succeeds;
 *   WRONG_PRIV_KEY when rsa or a required component is NULL, when the
 *   modulus is larger than the 512-byte signature buffer, or when
 *   verification fails;
 *   PLATFORM_INTERNAL_ERROR when signing itself fails.
 */
int rsa_check_keypair(RSA *rsa)
{
	/* SHA256 digest to check */
	unsigned char digest[32] = {0};
	/* Signing writes BN_num_bytes(n) bytes here: moduli up to 4096 bits fit. */
	unsigned char sign[512] = {0};
	/* Fixed 32-byte digest (31 characters + NUL) used if RAND_bytes() fails. */
	static const unsigned char fallback_digest[32] = "1234567890qwertyuiop[]asdfghjkl";
	int ret = 0;

	if (!rsa || !rsa->n || !rsa->e || !rsa->p || !rsa->q || !rsa->dmp1 || !rsa->dmq1 || !rsa->iqmp) {
		LOGD("rsa_check_keypair: some number is NULL");
		return WRONG_PRIV_KEY;
	}

	/*
	 * rsa_pkcs1_sign_mldap() does not bound its output. Reject moduli that
	 * would overflow sign[] before signing. The verify step would reject
	 * them anyway, but only after the overflow had already happened.
	 */
	if (BN_num_bytes(rsa->n) > (int)sizeof(sign)) {
		LOGD("rsa_check_keypair: modulus too large");
		return WRONG_PRIV_KEY;
	}

	if (RAND_bytes(digest, sizeof(digest)) != 1) {
		/* use hardcoded string if random fails */
		memcpy(digest, fallback_digest, sizeof(digest));
	}

	ret = rsa_pkcs1_sign_mldap(rsa, RSA_SHA256, SHA256_DIGEST_LEN, digest, sign);
	if (ret != NO_ERROR) {
		LOGE("rsa_pkcs1_sign_mldap() failed: ret = %d", ret);
		return PLATFORM_INTERNAL_ERROR;
	}

	ret = rsa_pkcs1_verify_mldap(rsa, RSA_SHA256, SHA256_DIGEST_LEN, digest, sign);
	if (ret != NO_ERROR) {
		LOGE("rsa_pkcs1_verify_mldap() failed: ret = %d", ret);
		return WRONG_PRIV_KEY;
	}

	return NO_ERROR;
}