/*****************************************************
                SamsungSVW_CRNNS.c
******************************************************/
#include "CRN_CircullantD.h"
#include "CRN_CircullantD_impl.h"
#include "SV_basic_op.h"


/* Number of frequency bins handled per frame. SV_CirCRNNSProc hands
 * 512-int packed spectra to Fx_logSpec, which produces bins 0..256 (= 257). */
#define CRN_INPUT_LEN 257

/* Debug dumps of features and mask are compiled in only for Windows builds
 * that also define BSH_DEBUG. */
#if defined (_WIN32) && defined(BSH_DEBUG)
#define CRN_DEBUG
#endif

#ifdef CRN_DEBUG
#include <stdio.h>
/* Debug output streams; defined elsewhere (declared extern here). */
extern FILE *fCRN_speechLogSpec, *fCRN_noiseLogSpec, *fCRN_estMask;
#endif

/* 64-bit integer type and a full-precision 32x32 -> 64-bit product. */
typedef long long INT64;
#define SMULL( a,   b )\
    ((INT64)(a)*(INT64)(b))

// Scratch memory: nine consecutive CRN_INPUT_LEN-sized arrays, carved out below:
//   speechLogSpec, noiseLogSpec, compensatedLogSpec - log spectra, Q8 dB
//   estMask                                         - DNN mask, Q15
//   interChannelVariance, interChannelVarianceEC    - smoothed log-power differences, Q8
//   bias (Q8), Gmin (Q15), Gbias (Q15)              - per-bin tuning set in SV_CirCRNNS_Init
// All state is file-scope static, so only one instance of this module can exist.
static short scratchMem[CRN_INPUT_LEN * 9] ;
static short *speechLogSpec = scratchMem;
static short *noiseLogSpec = scratchMem + CRN_INPUT_LEN;
static short *compensatedLogSpec = scratchMem + 2 * CRN_INPUT_LEN;
static short *estMask = scratchMem + 3 * CRN_INPUT_LEN;
static short *interChannelVariance = scratchMem + 4 * CRN_INPUT_LEN;
static short *bias = scratchMem + 5 * CRN_INPUT_LEN;
static short *Gmin = scratchMem + 6 * CRN_INPUT_LEN;
static short *Gbias = scratchMem + 7 * CRN_INPUT_LEN;
static short *interChannelVarianceEC = scratchMem + 8 * CRN_INPUT_LEN;
/* VAD flags latched by SV_CirCRNNS_SetPar and consumed by SV_CirCRNNS_Exe. */
static short TxVAD;
static short RxVAD;

/* Forward declarations of the file-local helpers defined further down. */
static int Fx_MultRegGain(int *spec_in, int *spec_out, short *gain_in);
static void Fx_logSpec(int *spec_in, short *logspec_out, short len);
static void normalizeInterChannelVariance(short *noiseLogSpec, short *compensatedLogSpec, int RxVAD);
static void compensateEstimatedMask(short *estMask);

void SV_CirCRNNS_Init()
{
	int i;

	CRN_CircullantD_Init();

	for (i = 0; i < CRN_INPUT_LEN; i++) {
		speechLogSpec[i] = 0;
		noiseLogSpec[i] = 0;
		compensatedLogSpec[i] = 0;
		estMask[i] = 0;
		interChannelVariance[i] = 0;
		interChannelVarianceEC[i] = 0;
	}

#if 0
	for (i = 0; i < 16; i++)		 bias[i] = (0 << 8); // ~0.5kHz (Q8)
	for (; i < 32; i++)			 bias[i] = (0 << 8); // 0.5~1kHz
	for (; i < 64; i++)			 bias[i] = (0 << 8); // 1~2kHz
	for (; i < 128; i++)			 bias[i] = (0 << 8); // 2~4kHz
	for (; i < CRN_INPUT_LEN; i++) bias[i] = (0 << 8);  // 4~8kHz
#else
	for (i = 0; i < 16; i++)		 bias[i] = (7 << 8); // ~0.5kHz (Q8)
	for (; i < 32; i++)			 bias[i] = (7 << 8); // 0.5~1kHz
	for (; i < 64; i++)			 bias[i] = (5 << 8); // 1~2kHz
	for (; i < 128; i++)			 bias[i] = (5 << 8); // 2~4kHz
	for (; i < CRN_INPUT_LEN; i++) bias[i] = (3 << 8);  // 4~8kHz
#endif
	for (i = 0; i < 16; i++)		 Gmin[i] = 327; // ~0.5kHz (Q15)
	for (; i < 32; i++)			 Gmin[i] = 327; // 0.5~1kHz
	for (; i < 64; i++)			 Gmin[i] = 327; // 1~2kHz
	for (; i < 128; i++)			 Gmin[i] = 800; // 2~4kHz
	for (; i < CRN_INPUT_LEN; i++) Gmin[i] = 800;  // 4~8kHz
	//for (; i < CRN_INPUT_LEN; i++) Gmin[i] = 3277;  // 4~8kHz

	for (i = 0; i < 16; i++)		 Gbias[i] = 0; // ~0.5kHz (Q15)
	for (; i < 32; i++)			 Gbias[i] = 0; // 0.5~1kHz
	for (; i < 64; i++)			 Gbias[i] = 0; // 1~2kHz
	for (; i < 128; i++)			 Gbias[i] = 0; // 2~4kHz
	for (; i < CRN_INPUT_LEN; i++) Gbias[i] = 1600;  // 4~8kHz
	//for (; i < CRN_INPUT_LEN; i++) Gbias[i] = 8192;  // 4~8kHz
	
}

void SV_CirCRNNSProc(int *DNNout, int *speechBFout, int *noiseBFout, int RxVAD) {


	// feature extraction
	Fx_logSpec(speechBFout, speechLogSpec, 512); // Log magnitude spectrum (Q8 format)
	Fx_logSpec(noiseBFout, noiseLogSpec, 512);
	normalizeInterChannelVariance(noiseLogSpec, compensatedLogSpec, RxVAD);

	// DNN inference 
	CRN_CircullantD_Exe(speechLogSpec, compensatedLogSpec, estMask);

	compensateEstimatedMask(estMask);

	Fx_MultRegGain(speechBFout, DNNout, estMask);

#ifdef CRN_DEBUG
	float ftmp;
	for (int i = 0; i < CRN_INPUT_LEN; i++)
	{
		ftmp = speechLogSpec[i] / (float)(1 << 8);
		fwrite(&ftmp, sizeof(float), 1, fCRN_speechLogSpec);
	}
	for (int i = 0; i < CRN_INPUT_LEN; i++)
	{
		ftmp = compensatedLogSpec[i] / (float)(1 << 8);
		fwrite(&ftmp, sizeof(float), 1, fCRN_noiseLogSpec);
	}

	for (int i = 0; i < CRN_INPUT_LEN; i++)
	{
		ftmp = estMask[i] / (float)(1 << 15);
		fwrite(&ftmp, sizeof(float), 1, fCRN_estMask);
	}
#endif
}

/*
 * One step of a Q15 first-order recursive average:
 * round(alpha * prev + (1 - alpha) * x) with alpha = alpha_q15 / 32768.
 * With prev and x both in [0, 32767] the result also stays in that range.
 */
static short crnRecursiveAvgQ15(short prev, int x, int alpha_q15)
{
	const int alpha_rev = 0x7fff - alpha_q15;

	return (short)((alpha_q15 * (int)prev + alpha_rev * x + 16384) >> 15);
}

/*
 * Update the smoothed inter-channel log-power differences from the current
 * frame's log spectra (speechLogSpec / noiseLogSpec, Q8).
 *
 * TxVAD == 0 && RxVAD == 0: interChannelVariance tracks noise - speech.
 * TxVAD == 0 && RxVAD == 1: interChannelVarianceEC tracks
 *                           (noise - interChannelVariance) - speech.
 * Only positive differences are averaged in. Differences are computed in int
 * and clamped to 32767, so they cannot wrap when narrowed to short.
 *
 * @param refMask not used yet (see TODO)
 */
void SV_CirCRNNSUpdate(short *refMask, short TxVAD, short RxVAD) {
	const int alpha = 31785;  /* smoothing factor, Q15 (~0.97) */
	int i;

	(void)refMask;  /* TODO: refMask-based update */

	if (TxVAD == 0 && RxVAD == 0) // noise only
	{
		for (i = 0; i < CRN_INPUT_LEN; i++) {
			int powerDiff = (int)noiseLogSpec[i] - (int)speechLogSpec[i];

			if (powerDiff > 0) {
				if (powerDiff > 32767)
					powerDiff = 32767;
				interChannelVariance[i] = crnRecursiveAvgQ15(interChannelVariance[i], powerDiff, alpha);
			}
		}
	}
	else if (TxVAD == 0 && RxVAD == 1)
	{
		for (i = 0; i < CRN_INPUT_LEN; i++) {
			int powerDiff = ((int)noiseLogSpec[i] - (int)interChannelVariance[i]) - (int)speechLogSpec[i];

			if (powerDiff > 0) {
				if (powerDiff > 32767)
					powerDiff = 32767;
				interChannelVarianceEC[i] = crnRecursiveAvgQ15(interChannelVarianceEC[i], powerDiff, alpha);
			}
		}
	}
}

// Fixed-point formats and constants for the log-spectrum computation (m10Log10 and callers).
// INT64 is typedef'd near the top of this file. ABS()/MAX() are not defined in this file;
// NOTE(review): presumably SV_basic_op.h provides them — confirm.
#define QOUTDATA_INTERMEDIATE (26)   // Q format of the intermediate log value inside m10Log10
#define QOUTDATA (8)                 // Q format of the dB outputs
#define QMAG2 (31)                   // mag2 is normalized into [0.5, 1) as a Q31 value
#define QDBPOWCOEFF 13               // Q format of log2coeffs
#define LOG10INV 4932        // round(1/log2(10.0)*2^14), i.e. round(log10(2)*2^14)
#define QLOG10INV 14
// Q26 constant removed after the log2 -> log10 conversion in m10Log10. The original
// note calls it fnLog10(1): it calibrates the output for an input of 1 to about 0.
// NOTE(review): unparenthesized negative literal; safe in its only use (d -= LOG10FIX).
#define LOG10FIX  /* fnLog10(1) */ -626068080
// Q13 coefficients {a, b, c} of a*x^2 + b*x + c, the polynomial approximation of
// log2(x) that m10Log10 evaluates for x normalized into [0.5, 1).
static const short log2coeffs[3] = { -11050, +32709, -21712 };
/*
 * CountLeadZeros(v): number of leading zero bits in the 32-bit pattern of v.
 * Returns 32 for v == 0 (the ARM CLZ result) and 0 for any negative v.
 */
#if defined(_WIN32)     // predefined macro for VS
static int CountLeadZeros(int v)
{
	/* Work on the unsigned bit pattern: shifting a signed int left can overflow (UB). */
	unsigned int u = (unsigned int)v;
	int c = 0;  /* number of leading (most-significant) zero bits */

	if (u == 0u) return (int)(sizeof(int) * 8);
	if ((u & 0xFFFF0000u) == 0u) { c += 16; u <<= 16; }
	if ((u & 0xFF000000u) == 0u) { c += 8;  u <<= 8; }
	if ((u & 0xF0000000u) == 0u) { c += 4;  u <<= 4; }
	if ((u & 0xC0000000u) == 0u) { c += 2;  u <<= 2; }
	if ((u & 0x80000000u) == 0u) { c += 1; }
	return c;
}
#elif defined(__clang__) || (defined(__GNUC__) && (__GNUC__ >= 5))    // GCC, clang, armclang
static inline int CountLeadZeros(int v)
{
	/* __builtin_clz(0) is undefined; return 32 like the ARM CLZ instruction.
	 * On ARM targets the compiler emits a single CLZ for this expression. */
	return (v == 0) ? 32 : __builtin_clz((unsigned int)v);
}
#else    // predefined macro for ARMCC Compiler 5 (__GNUC__==4)
static inline int CountLeadZeros(int v)
{
	int res;
	__asm {
		clz res, v
	};
	return res;
	//    return __clz(v);     // or you may use this intrinsic instead of inline asm
}
#endif

/*
 * Block-normalization shift for a[0 .. n-1].
 *
 * Returns CountLeadZeros(max_k |a[k]|) - 1: the largest left shift that keeps
 * the largest magnitude representable as a positive int. An all-zero array,
 * or n <= 0 (nothing is read), yields CountLeadZeros(0) - 1 = 31.
 *
 * |INT_MIN| is treated as INT_MAX (saturating abs), so computing the
 * magnitude never overflows.
 */
static int SV_DNN_Get_CLZofMax(int* a, int n)
{
    unsigned int peak = 0u;

    for (; n > 0; n--) {
        const int d = *a++;
        unsigned int mag = (d < 0) ? 0u - (unsigned int)d : (unsigned int)d;

        if (mag > 0x7FFFFFFFu)
            mag = 0x7FFFFFFFu;      /* only d == INT_MIN reaches this */
        if (mag > peak)
            peak = mag;
    }
    return CountLeadZeros((int)peak) - 1;
}
/*
 * Fixed-point approximation of 10*log10(mag2 + 1), returned in Q(QOUTDATA) = Q8.
 *
 * Normalizes mag2 into [0.5, 1) as a Q31 value. Evaluates the quadratic log2
 * approximation a*x^2 + b*x + c (coefficients in Q(QDBPOWCOEFF)). Undoes the
 * normalization shift and converts log2 -> log10 with LOG10INV. Removes the
 * LOG10FIX calibration, then scales by 10.
 *
 * @param mag2   linear value, expected >= 0. NOTE(review): mag2 == INT_MAX would
 *               overflow the +1 below, and negative inputs give meaningless results.
 * @param a,b,c  polynomial coefficients, Q13 (callers in this file pass log2coeffs)
 * @return       approximately 10*log10(mag2 + 1), Q8
 */
static inline int m10Log10(int mag2, int a, int b, int c)
{
	mag2 += 1; // avoid log(0) problem
	// shift into [0.5..1): sh is the left shift that puts the MSB at bit 30 (Q31)
	int sh = CountLeadZeros(mag2) - (32 - QMAG2);
	mag2 = (sh >= 0) ? mag2 << sh : mag2 >> (-sh);
	// log2 of the normalized value via the quadratic approximation
	int x2 = (int)(((INT64)mag2*mag2) >> QMAG2);
	// result: log2(normalized mag2) in Q26 (QOUTDATA_INTERMEDIATE).
	// NOTE(review): c is negative for log2coeffs, and left-shifting a negative value is UB in
	// ISO C; multiplying by ((INT64)1 << QMAG2) would be the portable form.
	int d = (int)((a*(INT64)x2 + b * (INT64)mag2 + (((INT64)c) << QMAG2)) >> (QDBPOWCOEFF + (QMAG2 - QOUTDATA_INTERMEDIATE)));
	// undo the normalization: subtract sh (in Q26)
	d -= sh << QOUTDATA_INTERMEDIATE;
	// convert log2 -> log10 (multiply by log10(2) in Q14)
	d = (((INT64)d*LOG10INV) >> (QLOG10INV));
	// remove the fixed offset left by the Q31 normalization (calibrated constant, see LOG10FIX)
	d -= LOG10FIX;
	// 10*log10, Q26 -> Q8
	d = (((INT64)d * 10) >> (QOUTDATA_INTERMEDIATE - QOUTDATA));
	return d;
}


/*
 * Log-power (Q8 dB) of n interleaved (re, im) int pairs.
 *
 * Each part is scaled by 2^blknrm (arithmetic right shift when blknrm < 0). The
 * value m10Log10(re^2 + im^2) - offset is then written to out, saturated to
 * the short range. Without saturation, near-zero bins combined with a large
 * block offset would wrap to large positive values.
 *
 * blknrm is expected to bring every |re|, |im| to at most 2^15; Fx_logSpec
 * derives it from SV_DNN_Get_CLZofMax. n <= 0 writes nothing.
 *
 * @param out     n Q8 dB values
 * @param in      2*n ints, (re, im) interleaved
 * @param coeffs  {a, b, c} log2 polynomial coefficients, Q13 (not modified)
 * @param offset  Q8 dB value subtracted from every output
 * @param blknrm  normalization shift (left if >= 0, right if < 0)
 * @param n       number of complex pairs
 */
static void SV_DNN_DBpow(short* out, int* in, const short* coeffs, int offset, int blknrm, int n)
{
    const int a = coeffs[0];
    const int b = coeffs[1];
    const int c = coeffs[2];

    for (; n > 0; n--) {
        int re = *in++;
        int im = *in++;
        long long mag2;
        int d;

        // normalize (to be within 16 bit); multiply instead of << so negative parts are not UB
        if (blknrm >= 0) {
            re *= 1 << blknrm;
            im *= 1 << blknrm;
        } else {
            re >>= -blknrm;
            im >>= -blknrm;
        }

        // magnitude square, Q30. re = im = -32768 is reachable through the arithmetic right
        // shift, and 2^31 does not fit in int; m10Log10 also adds 1, so cap one below INT_MAX.
        mag2 = (long long)re * re + (long long)im * im;
        if (mag2 > 0x7FFFFFFE)
            mag2 = 0x7FFFFFFE;

        d = m10Log10((int)mag2, a, b, c) - offset;
        if (d > 32767)
            d = 32767;
        else if (d < -32768)
            d = -32768;
        *out++ = (short)d;
    }
}

#define QDB_1BIT_OFFSET 14
#define DB_1BIT_OFFSET (98642)   // round(db(2)*2^14)

/* Clamp a Q8 dB value to the range of the 16-bit output buffer. */
static short FxLogSpec_Sat16(int v)
{
	if (v > 32767)
		return 32767;
	if (v < -32768)
		return -32768;
	return (short)v;
}

/* |x| as unsigned; well defined for INT_MIN too. */
static unsigned int FxLogSpec_AbsU(int x)
{
	return (x < 0) ? 0u - (unsigned int)x : (unsigned int)x;
}

/*
 * Log-magnitude spectrum (Q8 dB) of one packed buffer.
 *
 * spec_in holds len ints. spec_in[0] is bin 0 and spec_in[1] is bin len/2, both
 * single real values. They are followed by (re, im) pairs for bins 1 .. len/2 - 1.
 * logspec_out receives len/2 + 1 values indexed by bin.
 *
 * The buffer is block-normalized by its largest magnitude (lz).
 * offset = DB_1BIT_OFFSET * (lz - 1), in Q8, removes that normalization from
 * the dB values. Outputs are saturated to the short range. Zero or tiny bins
 * combined with a large offset would otherwise wrap to large positive values.
 */
static void Fx_logSpec(int *spec_in, short *logspec_out, short len)
{
	const int len2 = len >> 1;                              /* index of the last bin */
	const int lz = SV_DNN_Get_CLZofMax(spec_in, len);       /* blknorm */
	const int deltaq = 15 - 16 + lz;                        /* diff from Q30 (or Q14) */
	const int offset = (DB_1BIT_OFFSET * deltaq) >> (QDB_1BIT_OFFSET - QOUTDATA);
	const int a = log2coeffs[0];
	const int b = log2coeffs[1];
	const int c = log2coeffs[2];

	/* The single-value bins: |x| is normalized to 16 bits and m10Log10 gives about
	 * 10*log10|x|, so doubling yields 10*log10(x^2). Multiply, not <<, because
	 * m10Log10 can return a slightly negative value for a zero input. */
	const int firstIn = (int)((FxLogSpec_AbsU(spec_in[0]) << lz) >> 16);
	const int lastIn  = (int)((FxLogSpec_AbsU(spec_in[1]) << lz) >> 16);

	logspec_out[0]    = FxLogSpec_Sat16(m10Log10(firstIn, a, b, c) * 2 - offset);
	logspec_out[len2] = FxLogSpec_Sat16(m10Log10(lastIn, a, b, c) * 2 - offset);

	/* Complex bins 1 .. len2-1, normalized by the same block shift. */
	SV_DNN_DBpow(logspec_out + 1, spec_in + 2, log2coeffs, offset, lz - 16, len2 - 1);
}


/*
 * Multiply Regression Gain: apply a per-bin real gain to a packed spectrum.
 *
 * Layout (same as the Fx_logSpec input): spec[0] = bin 0,
 * spec[1] = bin CRN_INPUT_LEN-1, then (re, im) pairs for bins
 * 1 .. CRN_INPUT_LEN-2 at spec[2k], spec[2k+1].
 * Each value is multiplied by gain_in[bin] in 64 bits and shifted right by 15.
 *
 * @param spec_in   2*(CRN_INPUT_LEN-1) ints, packed as above
 * @param spec_out  same layout; may be the same buffer as spec_in
 * @param gain_in   CRN_INPUT_LEN gains, Q15, indexed by bin (as produced for estMask)
 * @return 0
 */
static int Fx_MultRegGain(int *spec_in, int *spec_out, short *gain_in)
{
	int bin;

	spec_out[0] = (int)(SMULL(spec_in[0], gain_in[0]) >> 15);
	spec_out[1] = (int)(SMULL(spec_in[1], gain_in[CRN_INPUT_LEN - 1]) >> 15);

	for (bin = 1; bin < CRN_INPUT_LEN - 1; bin++)
	{
		/* Bin k occupies spec[2k], spec[2k+1] and is scaled by its own gain, gain_in[k]. */
		const short g = gain_in[bin];

		spec_out[2 * bin]     = (int)(SMULL(spec_in[2 * bin], g) >> 15);
		spec_out[2 * bin + 1] = (int)(SMULL(spec_in[2 * bin + 1], g) >> 15);
	}

	return 0;
}

/*
 * Build the compensated noise feature passed to the DNN (all values Q8).
 *
 *   RxVAD == 1 : compensatedLogSpec[i] = noiseLogSpec[i] - interChannelVarianceEC[i]
 *   otherwise  : compensatedLogSpec[i] = noiseLogSpec[i] - (interChannelVariance[i] - bias[i])
 *
 * Only the selected formula is evaluated; each result is narrowed to short.
 */
static void normalizeInterChannelVariance(short *noiseLogSpec, short *compensatedLogSpec, int RxVAD) {
	int i;

	if (RxVAD == 1) {
		/* NOTE(review): neither bias nor interChannelVariance is applied here, although
		 * SV_CirCRNNSUpdate learns interChannelVarianceEC from
		 * (noise - interChannelVariance) - speech. Confirm this is intended. */
		for (i = 0; i < CRN_INPUT_LEN; i++) {
			compensatedLogSpec[i] = (short)(noiseLogSpec[i] - interChannelVarianceEC[i]);
		}
		return;
	}

	for (i = 0; i < CRN_INPUT_LEN; i++) {
		const short powerDiff = (short)(interChannelVariance[i] - bias[i]);
		compensatedLogSpec[i] = (short)(noiseLogSpec[i] - powerDiff);
	}
}


/*
 * Post-process the DNN mask in place (Q15). For every bin, add Gbias and
 * clamp the sum to [Gmin, 0x7fff]. The upper clamp is checked first.
 */
static void compensateEstimatedMask(short *estMask)
{
	short *m = estMask;
	const short *floorGain = Gmin;
	const short *offsetGain = Gbias;
	int remaining;

	for (remaining = CRN_INPUT_LEN; remaining > 0; remaining--, m++, floorGain++, offsetGain++) {
		const int boosted = *m + *offsetGain;

		*m = (boosted > 32767)      ? (short)0x7fff
		   : (boosted < *floorGain) ? *floorGain
		   :                          (short)boosted;
	}
}

/*
 * Latch the voice-activity flags used by the next SV_CirCRNNS_Exe call:
 * rxVad goes to SV_CirCRNNSProc and SV_CirCRNNSUpdate, txVad to SV_CirCRNNSUpdate.
 * (Parameter names avoid the reserved _Uppercase form; function linkage is
 * external by default, so no 'extern' is needed on the definition.)
 */
void SV_CirCRNNS_SetPar(short txVad, short rxVad)
{
	TxVAD = txVad;
	RxVAD = rxVad;
}

/*
 * Run one frame with the VAD flags latched by SV_CirCRNNS_SetPar. First
 * produce DNNout, then update the inter-channel variance estimates from this
 * frame's log spectra.
 */
void SV_CirCRNNS_Exe(int *DNNout, int *speechBFout, int *noiseBFout)
{
	short *noRefMask = (short *)0;  /* SV_CirCRNNSUpdate does not use refMask yet */

	SV_CirCRNNSProc(DNNout, speechBFout, noiseBFout, RxVAD);
	SV_CirCRNNSUpdate(noRefMask, TxVAD, RxVAD);
}
/*
 * Intentionally empty: the state in this file is static storage, so there is
 * nothing to free.
 * NOTE(review): SV_CirCRNNS_Init calls CRN_CircullantD_Init but no matching
 * teardown is called here. Confirm the DNN module needs none.
 */
void SV_CirCRNNS_Deinit()
{
	return;
}