/* Copyright 2016 Samsung Electronics Co., Ltd.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *  http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* The implementation is incomplete, but sufficient for the tests. */
#include "superlong.h"

#include <stdio.h>
#include <assert.h>
#include <limits.h>
#include <stdint.h>

/* Half-word base: 2^(bits in scl_umax_t / 2).  The arithmetic below treats
 * each scl_umax_t word as two digits in this base so that digit sums and
 * digit products fit in a single scl_umax_t.  "SCL_UMAX_MAX * 0 + 1" yields
 * the constant 1 in SCL_UMAX_MAX's type, so the shift is done in that type. */
#define SL_MOD ((SCL_UMAX_MAX * 0 + 1) << (sizeof(scl_umax_t) * CHAR_BIT / 2))

/* All-zero value: +0 with no overflow.  Every initializer is 0 ("0 > 1" is
 * this file's spelling of false); field order is defined in superlong.h. */
superlong_t const sl_zero = {0, 0, 0 > 1, 0 > 1};

/* [min, max] of each integer type, indexed by SCL_ITID_*; filled by sl_init(). */
sl_range_t sl_limits[SCL_ITID_COUNT];

/* Non-zero enables printing of each operation to stdout; set by sl_init(). */
static int verbose;

extern void sl_init(int verb) {
	verbose = verb;

	sl_limits[SCL_ITID_CHAR_SIGNED].min = sl_i(SCHAR_MIN);
	sl_limits[SCL_ITID_CHAR_SIGNED].max = sl_i(SCHAR_MAX);
	sl_limits[SCL_ITID_CHAR_UNSIGNED].min = sl_i(0);
	sl_limits[SCL_ITID_CHAR_UNSIGNED].max = sl_u(UCHAR_MAX);

	sl_limits[SCL_ITID_SHORT].min = sl_i(SHRT_MIN);
	sl_limits[SCL_ITID_SHORT].max = sl_i(SHRT_MAX);
	sl_limits[SCL_ITID_SHORT_UNSIGNED].min = sl_i(0);
	sl_limits[SCL_ITID_SHORT_UNSIGNED].max = sl_u(USHRT_MAX);

	sl_limits[SCL_ITID_INT].min = sl_i(INT_MIN);
	sl_limits[SCL_ITID_INT].max = sl_i(INT_MAX);
	sl_limits[SCL_ITID_UNSIGNED].min = sl_i(0);
	sl_limits[SCL_ITID_UNSIGNED].max = sl_u(UINT_MAX);

	sl_limits[SCL_ITID_LONG].min = sl_i(LONG_MIN);
	sl_limits[SCL_ITID_LONG].max = sl_i(LONG_MAX);
	sl_limits[SCL_ITID_LONG_UNSIGNED].min = sl_i(0);
	sl_limits[SCL_ITID_LONG_UNSIGNED].max = sl_u(ULONG_MAX);

	sl_limits[SCL_ITID_LONG_LONG].min = sl_i(SCL_IMAX_MIN);
	sl_limits[SCL_ITID_LONG_LONG].max = sl_i(SCL_IMAX_MAX);
	sl_limits[SCL_ITID_LONG_LONG_UNSIGNED].min = sl_i(0);
	sl_limits[SCL_ITID_LONG_LONG_UNSIGNED].max = sl_u(SCL_UMAX_MAX);
}

/*
 * Print s to stdout as "{overflow}" or as "{+ LOW : HIGH}" / "{- LOW : HIGH}"
 * (magnitude words in decimal, low word first).  Returns the sum of the
 * printf return values, i.e. the number of characters written.
 */
extern int sl_print(superlong_t s) {
	int l;
	if (s.overflow) {
		l = printf("{overflow}");
	} else {
		if (s.sign) {
			l = printf("{- ");
		} else {
			l = printf("{+ ");
		}
		/* unsigned long can be narrower than scl_umax_t (e.g. 32-bit long on
		 * LLP64); go through unsigned long long so no bits are dropped, as
		 * sl_i already does with long long for its trace. */
		l += printf("%llu : %llu}", (unsigned long long)s.low,
				(unsigned long long)s.high);
	}
	return l;
}

/*
 * Convert a signed native integer to a superlong (sign + magnitude in the
 * low word, high word 0, no overflow).  The magnitude of a negative value
 * is computed as -(i + 1) + 1 so that the most negative value does not
 * overflow on negation.  Traces the conversion when verbose is set.
 */
extern superlong_t sl_i(scl_imax_t i) {
	superlong_t res;

	res = sl_zero;
	if (i < 0) {
		res.sign = 0 < 1;
		/* -(i + 1) is representable even for the minimum value */
		res.low = (scl_umax_t)(-(i + 1)) + 1u;
	} else {
		res.low = (scl_umax_t)i;
	}

	if (verbose) {
		printf("sl_i(%lld) = ", (long long)i);
		sl_print(res);
		printf("\n");
	}
	return res;
}

/* Convert an unsigned native integer to a non-negative, non-overflowed
 * superlong whose low word is i and high word is 0. */
extern superlong_t sl_u(scl_umax_t i) {
	superlong_t res;

	res = sl_zero;
	res.low = i;
	return res;
}

/* Build a non-negative, non-overflowed superlong from its two magnitude
 * words: value = high * 2^(bits in scl_umax_t) + low. */
extern superlong_t sl_u_lh(scl_umax_t low, scl_umax_t high) {
	superlong_t res;

	res = sl_zero;
	res.high = high;
	res.low = low;
	return res;
}

/*
 * Return u * SL_MOD^l as a non-negative superlong, i.e. u shifted left by
 * l half-words, for l in {0, 1, 2}.  Any other l is a programming error:
 * it trips the assert, and when asserts are compiled out (NDEBUG) the
 * result is an overflowed zero instead of uninitialised words.
 */
static superlong_t sl_u_shift(scl_umax_t u, int l) {
	superlong_t s;

	/* start from +0 so every field is defined on every path */
	s = sl_zero;

	switch (l) {
	case 0:
		s.low = u;
		break;
	case 1:
		/* low half of u moves to the top of low, its high half into high */
		s.low = u % SL_MOD * SL_MOD;
		s.high = u / SL_MOD;
		break;
	case 2:
		s.high = u;
		break;
	default:
		assert(0);
		s.overflow = 0 < 1;
		break;
	}
	return s;
}

/*
 * Add the magnitudes of a1 and a2 (a2's sign is ignored); the result takes
 * a1's sign.  Both words are processed as two half-word digits in base
 * SL_MOD, low digit first, so every partial sum (two digits plus a carry)
 * fits in scl_umax_t without wrapping.  r holds the current digit sum; its
 * quotient by SL_MOD is the carry into the next digit.  overflow is set
 * when a carry leaves the top digit, i.e. the sum does not fit in
 * low/high.  The operands' own overflow flags are not consulted.
 */
static superlong_t sl_modulo_add(superlong_t a1, superlong_t a2) {
	scl_umax_t r;
	superlong_t s;
	s.sign = a1.sign;
	/* digit 0: low half of low */
	r = a1.low % SL_MOD + a2.low % SL_MOD;
	s.low = r % SL_MOD;
	/* digit 1: high half of low, plus carry */
	r = r / SL_MOD + a1.low / SL_MOD + a2.low / SL_MOD;
	s.low += r % SL_MOD * SL_MOD;

	/* digit 2: low half of high, plus carry */
	r = r / SL_MOD + a1.high % SL_MOD + a2.high % SL_MOD;
	s.high = r % SL_MOD;
	/* digit 3: high half of high, plus carry */
	r = r / SL_MOD + a1.high / SL_MOD + a2.high / SL_MOD;
	s.high += r % SL_MOD * SL_MOD;

	/* a carry out of digit 3 means the sum needs more than two words */
	s.overflow = r >= SL_MOD;
	return s;
}

/*
 * Subtract the magnitude of a2 from that of a1 (a2's sign is ignored); the
 * result takes a1's sign.  sl_add and sl_sub only call it with
 * |a1| > |a2|.  Works digit by digit in base SL_MOD, low digit first.  The
 * first digit is biased by SL_MOD, and each later digit by SL_MOD - 1 plus
 * the previous r / SL_MOD (1 = no borrow, 0 = borrow).  This keeps every
 * unsigned intermediate non-negative.  overflow is set when the top digit
 * still borrowed, i.e. |a1| < |a2|.
 */
static superlong_t sl_modulo_sub(superlong_t a1, superlong_t a2) {
	scl_umax_t r;
	superlong_t s;
	s.sign = a1.sign;
	/* digit 0: bias by SL_MOD so the difference cannot go below zero */
	r = SL_MOD + a1.low % SL_MOD - a2.low % SL_MOD;
	s.low = r % SL_MOD;
	/* digit 1: bias SL_MOD - 1 plus r / SL_MOD == bias SL_MOD minus borrow */
	r = (SL_MOD  - 1) + r / SL_MOD + a1.low / SL_MOD - a2.low / SL_MOD;
	s.low += r % SL_MOD * SL_MOD;

	/* digit 2: low half of high */
	r = (SL_MOD - 1) + r / SL_MOD + a1.high % SL_MOD - a2.high % SL_MOD;
	s.high = r % SL_MOD;
	/* digit 3: high half of high */
	r = (SL_MOD - 1) + r / SL_MOD + a1.high / SL_MOD - a2.high / SL_MOD;
	s.high += r % SL_MOD * SL_MOD;
	/* r < SL_MOD: the top digit borrowed, so |a1| < |a2| */
	s.overflow = r < SL_MOD;
	return s;
}

/*
 * Compare magnitudes only (signs and overflow flags are ignored).
 * Returns -1, 0 or +1 as |a1| is less than, equal to or greater than |a2|.
 */
static int sl_modulo_cmp(superlong_t a1, superlong_t a2) {
	/* the high word decides unless the high words are equal */
	if (a1.high != a2.high) {
		return (a1.high < a2.high) ? -1 : +1;
	}
	if (a1.low != a2.low) {
		return (a1.low < a2.low) ? -1 : +1;
	}
	return 0;
}

/*
 * Signed three-way comparison: returns -1, 0 or +1 as a1 <, ==, > a2.
 * Overflow flags are not examined.  Traces the call when verbose is set.
 */
extern int sl_cmp(superlong_t a1, superlong_t a2) {
	int res;

	if (a1.sign != a2.sign) {
		/* opposite signs: whichever is negative is the smaller value */
		res = a1.sign ? -1 : +1;
	} else {
		res = sl_modulo_cmp(a1, a2);
		/* both negative: the larger magnitude is the smaller value */
		res = a1.sign ? -res : res;
	}

	if (verbose) {
		printf("sl_cmp(");
		sl_print(a1);
		printf(", ");
		sl_print(a2);
		printf(") = %d\n", res);
	}
	return res;
}

/*
 * Locate a relative to the closed interval [r.min, r.max], assuming
 * r.min <= r.max (as sl_range asserts).  Returns -1 if a is below r.min,
 * 0 if it lies inside the interval, and +1 if it is above r.max.  Traces
 * the result when verbose is set.
 */
extern int sl_in_range(superlong_t a, sl_range_t r) {
	int pos;
	int upper;

	pos = sl_cmp(a, r.min);
	if (pos > 0) {
		/* strictly above the lower bound: only the upper bound matters */
		upper = sl_cmp(a, r.max);
		pos = (upper > 0) ? upper : 0;
	}

	if (verbose) {
		printf("sl_in_range %d\n", pos);
	}
	return pos;
}

/* Negate a by flipping its sign.  Zero is left unchanged, so a negative
 * zero is never produced.  The overflow flag is kept as is. */
extern superlong_t sl_neg(superlong_t a) {
	int is_zero;

	is_zero = (a.low == 0) && (a.high == 0);
	if (!is_zero) {
		a.sign = !a.sign;
	}
	return a;
}

/* Absolute value: clear the sign; magnitude and overflow flag are kept. */
extern superlong_t sl_abs(superlong_t a) {
	superlong_t res;

	res = a;
	res.sign = 0;
	return res;
}

/* Build the closed range [min, max].  min must not exceed max; this is
 * checked with assert (and traced via sl_cmp when verbose is set). */
extern sl_range_t sl_range(superlong_t min, superlong_t max) {
	sl_range_t res;

	assert(sl_cmp(min, max) <= 0);
	res.max = max;
	res.min = min;
	return res;
}

/*
 * Signed addition a1 + a2.
 * - Either operand overflowed: the result is an overflowed zero.
 * - Same signs: add the magnitudes and keep the common sign.
 * - Opposite signs: subtract the smaller magnitude from the larger and
 *   take the larger operand's sign.  Equal magnitudes give exact +0.
 * Traces the call when verbose is set.
 */
extern superlong_t sl_add(superlong_t a1, superlong_t a2) {
	superlong_t sum;
	int order;

	sum = sl_zero;
	if (a1.overflow || a2.overflow) {
		sum.overflow = 0 < 1;
	} else if (a1.sign != a2.sign) {
		order = sl_modulo_cmp(a1, a2);
		if (order > 0) {
			sum = sl_modulo_sub(a1, a2);
		} else if (order < 0) {
			sum = sl_modulo_sub(a2, a1);
		}
		/* order == 0: the operands cancel and sum stays +0 */
	} else {
		sum = sl_modulo_add(a1, a2);
	}

	if (verbose) {
		printf("sl_add(");
		sl_print(a1);
		printf(", ");
		sl_print(a2);
		printf(") = ");
		sl_print(sum);
		printf("\n");
	}
	return sum;
}

/*
 * Signed subtraction a1 - a2.
 * - Either operand overflowed: the result is an overflowed zero.
 * - Opposite signs: add the magnitudes and keep a1's sign.
 * - Same signs: subtract the smaller magnitude from the larger.  If |a2|
 *   is the larger, the result's sign is the opposite of the operands'.
 *   Equal values give exact +0.
 * Traces the call when verbose is set.
 */
extern superlong_t sl_sub(superlong_t a1, superlong_t a2) {
	superlong_t diff;
	int order;

	diff = sl_zero;
	if (a1.overflow || a2.overflow) {
		diff.overflow = 0 < 1;
	} else if (a1.sign == a2.sign) {
		order = sl_modulo_cmp(a1, a2);
		if (order > 0) {
			diff = sl_modulo_sub(a1, a2);
		} else if (order < 0) {
			diff = sl_modulo_sub(a2, a1);
			diff.sign = !diff.sign;
		}
		/* order == 0: a1 == a2, diff stays +0 */
	} else {
		diff = sl_modulo_add(a1, a2);
	}

	if (verbose) {
		printf("sl_sub(");
		sl_print(a1);
		printf(", ");
		sl_print(a2);
		printf(") = ");
		sl_print(diff);
		printf("\n");
	}
	return diff;
}

/*
 * Signed multiplication m1 * m2.
 *
 * Each magnitude is split into four half-word digits in base SL_MOD
 * (low % SL_MOD, low / SL_MOD, high % SL_MOD, high / SL_MOD).  The digits
 * are multiplied schoolbook-style into an eight-digit product.  Every step
 * computes digit * digit + stored digit + carry.  That is at most
 * (SL_MOD-1)^2 + 2*(SL_MOD-1) = SL_MOD^2 - 1, so it never wraps in
 * scl_umax_t.
 *
 * Overflow is set when either operand already overflowed, or when the
 * product needs more than two words (any of the top four digits is
 * non-zero).  In that case low/high are zero.  A non-zero product is
 * negative iff the operand signs differ.  Traces the call when verbose
 * is set.
 */
extern superlong_t sl_mult(superlong_t m1, superlong_t m2) {
	superlong_t r;
	scl_umax_t x[4], y[4], p[8];
	scl_umax_t t, carry;
	int i, j;

	r = sl_zero;
	r.overflow = m1.overflow || m2.overflow;
	if (!r.overflow) {
		x[0] = m1.low % SL_MOD;
		x[1] = m1.low / SL_MOD;
		x[2] = m1.high % SL_MOD;
		x[3] = m1.high / SL_MOD;

		y[0] = m2.low % SL_MOD;
		y[1] = m2.low / SL_MOD;
		y[2] = m2.high % SL_MOD;
		y[3] = m2.high / SL_MOD;

		for (i = 0; i < 8; i++) {
			p[i] = 0;
		}

		for (i = 0; i < 4; i++) {
			carry = 0;
			for (j = 0; j < 4; j++) {
				t = x[i] * y[j] + p[i + j] + carry;
				p[i + j] = t % SL_MOD;
				carry = t / SL_MOD;
			}
			/* earlier rows only reached p[i + 3], so p[i + 4] is fresh */
			p[i + 4] = carry;
		}

		/* anything in digits 4..7 does not fit in the low/high words */
		r.overflow = (p[4] | p[5] | p[6] | p[7]) != 0;
		if (!r.overflow) {
			r.low = p[0] + p[1] * SL_MOD;
			r.high = p[2] + p[3] * SL_MOD;
			r.sign = ((r.low != 0) || (r.high != 0)) && (m1.sign != m2.sign);
		}
	}

	if (verbose) {
		printf("sl_mult(");
		sl_print(m1);
		printf(", ");
		sl_print(m2);
		printf(") = ");
		sl_print(r);
		printf("\n");
	}
	return r;
}

/*
 * Signed division a / d, truncating toward zero like C's "/".
 * Division by zero, or an overflowed operand, gives an overflowed zero.
 * These cases are computed exactly:
 *   - |d| == 1: the quotient is a (sign adjusted);
 *   - |a| < |d|: the quotient is 0, whatever the word sizes;
 *   - single-word a and d: native division of the low words.
 * Other multi-word cases are not implemented and are reported as overflow.
 * A non-zero quotient is negative iff the operand signs differ.  Traces
 * the call when verbose is set.
 */
extern superlong_t sl_div(superlong_t a, superlong_t d) {
	superlong_t r;
	r = sl_zero;
	r.overflow = a.overflow || d.overflow
			  || ((d.low == 0) && (d.high == 0));
	if (!r.overflow) {
		if ((d.low == 1) && (d.high == 0)) {
			r = a;
		} else if (sl_modulo_cmp(a, d) < 0) {
			/* |a| < |d|: truncation toward zero gives 0 */
			r = sl_zero;
		} else if ((a.high != 0) || (d.high != 0)) {
			/* TODO full division */
			r.overflow = 0 < 1;
		} else {
			r.low = a.low / d.low;
			r.high = 0;
		}
		r.sign = ((r.low != 0) || (r.high != 0))
			  && (a.sign != d.sign);
	}

	if (verbose) {
		printf("sl_div(");
		sl_print(a);
		printf(", ");
		sl_print(d);
		printf(") = ");
		sl_print(r);
		printf("\n");
	}
	return r;
}

/*
 * Signed remainder a % d with C99 semantics: division truncates toward
 * zero, so the remainder is |a| % |d| carrying the sign of a.  This file
 * already requires C99 (<stdint.h>, long long).
 * Division by zero, or an overflowed operand, gives an overflowed zero.
 * These cases are computed exactly:
 *   - |d| == 1: the remainder is 0;
 *   - |a| < |d|: the quotient is 0, so the remainder is a itself;
 *   - single-word a and d: native % of the low words.
 * Other multi-word cases are not implemented and are reported as overflow.
 * Traces the call when verbose is set.
 */
extern superlong_t sl_mod(superlong_t a, superlong_t d) {
	superlong_t r;
	r = sl_zero;
	r.overflow = a.overflow || d.overflow
			  || ((d.low == 0) && (d.high == 0));
	if (!r.overflow) {
		if ((d.low == 1) && (d.high == 0)) {
			r = sl_zero;
		} else if (sl_modulo_cmp(a, d) < 0) {
			/* |a| < |d|: the remainder magnitude is |a| */
			r.low = a.low;
			r.high = a.high;
		} else if ((a.high != 0) || (d.high != 0)) {
			/* TODO full division by module */
			r.overflow = 0 < 1;
		} else {
			r.low = a.low % d.low;
			r.high = 0;
		}
		/* C99 6.5.5: (a/b)*b + a%b == a with truncating '/', so a non-zero
		 * remainder has the dividend's sign; zero stays non-negative. */
		r.sign = ((r.low != 0) || (r.high != 0)) && a.sign;
	}

	if (verbose) {
		printf("sl_mod(");
		sl_print(a);
		printf(", ");
		sl_print(d);
		printf(") = ");
		sl_print(r);
		printf("\n");
	}
	return r;
}
