#include "SV_Frame.h"
#include "SV_common_include.h"
#include "SV_basic_op.h"
#include "SV_UnbiasedMMSE.h"

#if defined (_WIN32) && defined(BSH_DEBUG)
#define DEBUG_DNN_INPUT_SYNTHESIS
#endif

#ifdef DEBUG_DNN_INPUT_SYNTHESIS
#include<stdlib.h>
#include<stdio.h>
extern FILE *f_mixer[4];
extern FILE *f_echo_mix[4];
#endif



/* 64-bit signed integer used to hold widened fixed-point products. */
typedef long long INT64;

/* Signed 32 x 32 -> 64-bit multiply: both operands are widened to INT64
 * before multiplying, so the product of two 32-bit ints cannot overflow. */
#define SMULL(a, b) ((INT64)(a) * (INT64)(b))

/*
 * State of the in-ear / beamformer mixer used to build the DNN inputs.
 *
 * Array members other than BandSNR point into the memory block handed to
 * SV_DNNInputSynthesis_Init and are read/updated on every frame, so that
 * memory must stay valid while the module is in use. BandSNR is a
 * caller-owned buffer registered by SV_DNNInputSynthesis_SetPar.
 *
 * Per-band arrays have bandSize entries. Per-bin arrays have InEarFreqSize
 * entries and are indexed by the bin offset from InEarLowerBound.
 */
typedef struct
{
	int *SNR;              // per-band smoothed SNR, Q12 (same scale as SNRThd); smoothed from BandSNR in MixBFOut

	short *OnSetCnt;       // per-band hysteresis counter driving InEarMask transitions
	char *InEarMask;       // per-band flag: 1 = band takes the in-ear signals and receives echo mixing

	int InEarFreqSize;     // InEarUpperBound - InEarLowerBound
	int InEarLowerBound;   // first in-ear bin
	int InEarUpperBound;   // one past the last in-ear bin (loops use i < InEarUpperBound)
	int reliableFreqBound; // bins below this may take the in-ear signals; recomputed in MixBFOut
	int bandSize;          // number of bands

	int *BandSNR;          // caller-owned per-band SNR input (Q12 presumably - confirm)

	int *echoMixGain; //Q16; per in-ear bin, smoothed echo mixing gain
	int *ResiEchoPsd;      // per in-ear bin, smoothed |in-ear BF spectrum|^2 (>> 15)
	int *EchoRefPsd;       // per in-ear bin, smoothed |Rx spectrum|^2 (>> 15)
}MixerForDNNInput;
// Single module instance. External linkage - NOTE(review): make static if no other file uses these names.
MixerForDNNInput InEarMixer_Vars ;
MixerForDNNInput *InEarMixer = &InEarMixer_Vars;

/*
 * Per-frame control inputs latched by SV_DNNInputSynthesis_SetPar and
 * consumed by SV_DNNInputSynthesis_Exe, plus the persistent non-speech
 * noise-reference snapshot used by RemoveSpeechLeakage.
 */
typedef struct
{
	int TxVAD;             // Tx voice-activity flag (0 = inactive)
	int RxVAD;             // Rx voice-activity flag (0 = inactive)
	int noiseOnset;        // stored and forwarded to MixBFOut, which does not use it
	int noiseOffset;       // stored and forwarded to MixBFOut, which does not use it
	int noiseLevel;        // noise-level class; the code distinguishes 0, 1 and > 1
	int *inEarGain;        // caller-owned per-bin in-ear gains, compared with a Q16 threshold in MixBFOut
	short *refMask;        // caller-owned per-bin mask; thresholds used are 9830/20000/16383 (presumably Q15 - confirm)

	// Interleaved re/im noise-reference bins saved on non-speech frames.
	// RemoveSpeechLeakage touches bins 1..254, i.e. entries up to 509, so
	// SVTX_MAX_WIN_SIZE is assumed to be >= 510 - TODO confirm.
	int NoiseRefNonSpeech[SVTX_MAX_WIN_SIZE];

}TWS_DNN_Input;
// Single module instance (zero-initialised static storage); DNNInVars has external linkage.
static TWS_DNN_Input DNNIn;
TWS_DNN_Input *DNNInVars = &DNNIn;

/*
 * Initialise the DNN-input-synthesis module.
 *
 * InEarLowerBound, InEarUpperBound: in-ear bin range [lower, upper).
 * bandSize:        number of bands.
 * scratch_mem_ptr: memory the per-band / per-bin state arrays are carved from.
 *                  The arrays are kept in InEarMixer and updated on every frame,
 *                  so this memory must stay valid while the module is used.
 *                  Assumed to be int-aligned - TODO confirm with the caller.
 *
 * Returns the number of bytes consumed from scratch_mem_ptr.
 */
int SV_DNNInputSynthesis_Init(int InEarLowerBound, int InEarUpperBound, int bandSize, char *scratch_mem_ptr)
{
	int i;
	int used_mem_size = 0, total_mem_size = 0;

	InEarMixer->bandSize = bandSize;
	InEarMixer->InEarLowerBound = InEarLowerBound;
	InEarMixer->InEarUpperBound = InEarUpperBound;
	InEarMixer->InEarFreqSize = InEarUpperBound - InEarLowerBound;
	InEarMixer->reliableFreqBound = 64; // 2KHz

	/*
	 * Carve the arrays in order of decreasing element size (int, short, char).
	 * The total size is the same as any other order, but every array now starts
	 * at a multiple of its element size (given an int-aligned base). Placing an
	 * int array after the short/char per-band arrays would misalign it whenever
	 * bandSize is not a multiple of 4 (e.g. 9 bands), which is undefined
	 * behaviour and can fault on strict-alignment targets.
	 */
	InEarMixer->SNR = (int*)scratch_mem_ptr;
	used_mem_size = bandSize * sizeof(int);
	total_mem_size += used_mem_size;
	scratch_mem_ptr += used_mem_size;

	InEarMixer->echoMixGain = (int*)scratch_mem_ptr;
	used_mem_size = InEarMixer->InEarFreqSize * sizeof(int);
	total_mem_size += used_mem_size;
	scratch_mem_ptr += used_mem_size;

	InEarMixer->ResiEchoPsd = (int*)scratch_mem_ptr;
	used_mem_size = InEarMixer->InEarFreqSize * sizeof(int);
	total_mem_size += used_mem_size;
	scratch_mem_ptr += used_mem_size;

	InEarMixer->EchoRefPsd = (int*)scratch_mem_ptr;
	used_mem_size = InEarMixer->InEarFreqSize * sizeof(int);
	total_mem_size += used_mem_size;
	scratch_mem_ptr += used_mem_size;

	InEarMixer->OnSetCnt = (short*)scratch_mem_ptr;
	used_mem_size = bandSize * sizeof(short);
	total_mem_size += used_mem_size;
	scratch_mem_ptr += used_mem_size;

	InEarMixer->InEarMask = (char*)scratch_mem_ptr;
	used_mem_size = bandSize * sizeof(char);
	total_mem_size += used_mem_size;
	scratch_mem_ptr += used_mem_size;

	/* Per-band state: SNR 0, hysteresis counters cleared, all bands unmasked. */
	for (i = 0; i < bandSize; i++)
	{
		InEarMixer->SNR[i] = 0;
		InEarMixer->OnSetCnt[i] = 0;
		InEarMixer->InEarMask[i] = 0;
	}

	/* Per-bin echo state: unity gain (1.0 in Q16) and zero PSD history. */
	for (i = 0; i < InEarMixer->InEarFreqSize; i++)
	{
		InEarMixer->echoMixGain[i] = ((int)1 << 16);
		InEarMixer->ResiEchoPsd[i] = 0;
		InEarMixer->EchoRefPsd[i] = 0;
	}

	memset(DNNInVars->NoiseRefNonSpeech, 0, sizeof(int) * SVTX_MAX_WIN_SIZE);

	return total_mem_size;
}

/*
 * Release module resources. Nothing to do: the module allocates nothing
 * itself. Its state lives in static storage or in the memory block the
 * caller passed to SV_DNNInputSynthesis_Init.
 *
 * Defined with a (void) prototype. An empty parameter list declares a
 * function with unspecified parameters before C23 (an obsolescent form).
 */
void SV_DNNInputSynthesis_Deinit(void)
{
}
void SV_DNNInputSynthesis_SetPar(int TxVAD, int RxVAD, int noiseOnset, int noiseOffset, int noiseLevel, int *inEarGain, int *BandSNR, short *refMask)
{
	DNNInVars->TxVAD = TxVAD;
	DNNInVars->RxVAD = RxVAD;
	DNNInVars->noiseOnset = noiseOnset;
	DNNInVars->noiseOffset = noiseOffset;
	DNNInVars->noiseLevel = noiseLevel;
	DNNInVars->refMask = refMask;
	DNNInVars->inEarGain = inEarGain;

	InEarMixer->BandSNR = BandSNR;
}




/*
 * Fixed-point quotient a / b.
 *
 * Both inputs share the same Q format; the result is in Q16 (BF_Q).
 * Special cases: a == 0 -> 0, b == 0 -> 0x7fffffff (saturated).
 *
 * Method: normalise both operands and take their high 16 bits. The numerator
 * is shifted one bit less than full normalisation, so its mantissa stays below
 * the denominator's, as SV_div_s requires (assumes a, b > 0 - TODO confirm).
 * Then divide and shift the quotient from Q((expNum-1)-expDen+15) to Q16.
 */
static int SV_div_l2(int a, int b)
{
	short expNum, expDen;
	short mantNum, mantDen;
	short quotient;

	if (a == 0)
		return 0;
	if (b == 0)
		return 0x7fffffff;

	expNum = SV_norm_l(a);
	mantNum = SV_extract_h(SV_L_shl(a, expNum - 1));
	expDen = SV_norm_l(b);
	mantDen = SV_extract_h(SV_L_shl(b, expDen));
	quotient = SV_div_s(mantNum, mantDen);

	return SV_L_shl((int)quotient, 16 - ((expNum - 1) - expDen + 15));
}

//Input in the range [0.25,1) in Q16, output in Q14
/*
 * Core of SV_rsqrt2: approximates 1/sqrt(x) for x in [0.25, 1) (Q16).
 *
 * r is a quadratic-polynomial initial estimate (Q14) in n = x - 0.5.
 * y = r*r*x - 1 (Q15) measures the estimate's error. The result applies the
 * series correction r * (1 - y/2 + 3*y*y/8) ~= r / sqrt(1 + y) = 1/sqrt(x),
 * using 12288 = 0.375 and 16384 = 0.5 in Q15.
 *
 * NOTE(review): relies on the SV_* basic ops, presumably saturating
 * (ETSI-style). At x = 0.25 the exact result 2.0 is not representable in Q14
 * and is expected to saturate to 32767 - confirm.
 */
static short SV_rsqrt_limit(int x)
{
	short n, r, r2, y;
	n = x - 32768;  // x - 0.5 in Q16 (equivalently 2x - 1 in Q15); fits 16 bits for x in [16384, 65536)
	r = SV_add(23557, SV_mult(n, SV_add(-13490, SV_mult(n, 6713))));  // initial estimate, Q14
	r2 = SV_mult(r, r);  // r^2, Q13
	y = SV_shl(SV_sub(SV_add(SV_mult(r2, n), r2), 16384), 1);  // r^2 * x - 1, Q15
	return SV_add(r, SV_mult(r, SV_mult(y, SV_sub(SV_mult(y, 12288), 16384))));  // r * (1 - y/2 + 3y^2/8), Q14
}

/*
 * Reciprocal square root 1/sqrt(x). Input and output are in Q16 (BF_Q).
 *
 * x is shifted by an even amount n into [0.25, 1) (Q16). SV_rsqrt_limit
 * evaluates the core in Q14, and the final shift converts Q14 to Q16 while
 * undoing half of the normalisation (sqrt of 2^n is 2^(n/2)).
 *
 * x <= 0 has no real reciprocal square root; it is clamped to 1, the smallest
 * positive Q16 value, which yields the largest result (~256.0 in Q16). x == 0
 * is reachable: GetInstGain passes the Q16 quotient from SV_div_l2, which
 * underflows to 0 for ratios below 1/65536. Without the clamp, a zero input
 * went through SV_norm_l(0) and produced a small gain (~0.008 assuming
 * ETSI-style basic ops) instead of a large one.
 */
static int SV_rsqrt2(int x)
{
	short tmp;
	int n;

	if (x <= 0)
		x = 1;

	n = SV_norm_l(x) - 15;
	if (n & 0x1) n -= 1;  // keep the shift even so it can be halved exactly
	tmp = SV_rsqrt_limit(SV_L_shl(x, n)); // Q14
	return SV_L_shl((int)tmp, 16 + (16 >> 1) - 22 + (n >> 1));
}


/*
 * Instantaneous amplitude gain sqrt(A / B), in Q16 (BF_Q).
 * Evaluated as 1 / sqrt(B / A): SV_div_l2 forms the Q16 ratio B / A and
 * SV_rsqrt2 takes its reciprocal square root.
 */
static int GetInstGain(int A, int B) // A/B
{
	return SV_rsqrt2(SV_div_l2(B, A));
}


/*
 * Mix a gain-scaled copy of the Rx spectrum into the noise-reference spectrum
 * over the in-ear bins [InEarLowerBound, InEarUpperBound).
 *
 * outSpec, inSpec, RxSpec, refSpec: interleaved re/im spectra indexed by
 *   absolute bin (bin k at [2k], [2k + 1]); outSpec may equal inSpec.
 * RxVAD:      mixing runs only when non-zero.
 * SingleTalk: when 1, the per-bin gains are re-estimated first.
 * inEarMixingMask: per-band flags; only bins whose band flag is 1 are mixed.
 *
 * Per-bin state in InEarMixer (indexed by offset from InEarLowerBound):
 *   EchoRefPsd  - smoothed |RxSpec|^2 (>> 15)
 *   ResiEchoPsd - smoothed |refSpec|^2 (>> 15)
 *   echoMixGain - smoothed sqrt(ResiEchoPsd / EchoRefPsd), Q16
 */
static void MixEchoAtNoiseRef(int *outSpec, int *inSpec, int *RxSpec, int *refSpec,int RxVAD, int SingleTalk, char *inEarMixingMask)
{
	int i;
	const int lowerBin = InEarMixer->InEarLowerBound;
	int alpha = 26214;              // PSD smoothing, ~0.8 in Q15
	int alpha_rev = 0x7fff - alpha;

	int beta_fast = 16384;          // gain smoothing while the gain rises, 0.5 in Q15
	int beta_slow = 26214;          // gain smoothing while the gain falls, ~0.8 in Q15
	int beta, beta_rev;
	INT64 RxPSD_thd = 1000;         // below this smoothed Rx PSD the bin keeps its previous gain

	// Work on the in-ear bins only; below, i is the bin offset from InEarLowerBound.
	outSpec += lowerBin * 2;
	refSpec += lowerBin * 2;
	inSpec += lowerBin * 2;
	RxSpec += lowerBin * 2;

#ifdef DEBUG_ECHOMIX
	short mask[257];
	for (i = 0; i < InEarMixer->InEarFreqSize; i++) mask[i] = 0;
#endif

	// Gain estimation (single talk only).
	if (SingleTalk == 1)
	{
		INT64 Etmp;
		int InstGain;
		for (i = 0; i < InEarMixer->InEarFreqSize; i++)
		{
			int real = 2 * i, imag = real + 1;

			Etmp = SMULL(RxSpec[real], RxSpec[real]) + SMULL(RxSpec[imag], RxSpec[imag]);
			InEarMixer->EchoRefPsd[i] = (int)((SMULL(InEarMixer->EchoRefPsd[i], alpha) + SMULL((int)(Etmp >> 15), alpha_rev)) >> 15);

			if (InEarMixer->EchoRefPsd[i] < RxPSD_thd)
			{
				continue;
			}
#ifdef DEBUG_ECHOMIX
			mask[i] = 1;
#endif

			Etmp = SMULL(refSpec[real], refSpec[real]) + SMULL(refSpec[imag], refSpec[imag]);
			InEarMixer->ResiEchoPsd[i] = (int)((SMULL(InEarMixer->ResiEchoPsd[i], alpha) + SMULL((int)(Etmp >> 15), alpha_rev)) >> 15);

			// Rising gains are tracked quickly (beta_fast), falling gains slowly (beta_slow).
			InstGain = GetInstGain(InEarMixer->ResiEchoPsd[i], InEarMixer->EchoRefPsd[i]);
			beta = (InstGain > InEarMixer->echoMixGain[i]) ? beta_fast : beta_slow;
			beta_rev = 0x7fff - beta;
			InEarMixer->echoMixGain[i] = (int)((SMULL(InEarMixer->echoMixGain[i], beta) + SMULL(InstGain, beta_rev)) >> 15);
		}
	}

	// Mixing. RxVAD is treated as a boolean, consistent with the SingleTalk
	// flag that SV_DNNInputSynthesis_Exe derives from it with &&.
	if (RxVAD != 0)
	{
		for (i = 0; i < InEarMixer->InEarFreqSize; i++)
		{
			int real = 2 * i, imag = real + 1;
			// The mask is built per band of the absolute bin number (as in
			// MixBFOut), so look up the band of bin InEarLowerBound + i,
			// not of the offset i.
			int BandIdex = SV_UnbiasedMMSE_Get_BandIndex(lowerBin + i);
			if (inEarMixingMask[BandIdex] == 1) {
				outSpec[real] = inSpec[real] + (int)(SMULL(InEarMixer->echoMixGain[i], RxSpec[real]) >> 16);
				outSpec[imag] = inSpec[imag] + (int)(SMULL(InEarMixer->echoMixGain[i], RxSpec[imag]) >> 16);
			}
		}
	}

#ifdef DEBUG_ECHOMIX
	for (i = 0; i < InEarMixer->InEarFreqSize; i++)
	{
		float ftmp;
		ftmp = InEarMixer->ResiEchoPsd[i] / (float)((int)1 << 15);
		fwrite(&ftmp, sizeof(float), 1, f_echo_mix[0]);
		ftmp = InEarMixer->EchoRefPsd[i] / (float)((int)1 << 15);
		fwrite(&ftmp, sizeof(float), 1, f_echo_mix[1]);
		ftmp = (float)mask[i];
		fwrite(&ftmp, sizeof(float), 1, f_echo_mix[2]);
		ftmp = InEarMixer->echoMixGain[i] / (float)((int)1 << 16);
		fwrite(&ftmp, sizeof(float), 1, f_echo_mix[3]);
	}
#endif
}

/*
 * Suppress speech leaking into the noise reference and keep a snapshot of
 * the noise reference from non-speech frames.
 * (Older version of this logic: SV_TWS_GEVBF_CompensateBM in SV_TWS_GEVBF.c.)
 *
 * BMSpec:      interleaved re/im noise-reference spectrum, modified in place.
 * refMask:     per-bin mask compared against fixed thresholds.
 * TxVAD/RxVAD: voice-activity flags (0 = inactive).
 * noiseLevel:  0 selects the full-band rule, otherwise the high-band rule.
 *
 * Only bins 1..254 are touched.
 *
 * Replacement - a bin is overwritten by the stored non-speech bin when its
 * mask exceeds the threshold and its energy exceeds the stored bin's energy:
 *   noiseLevel == 0                   : bins 1..254,   mask > 9830
 *   noiseLevel != 0, TxVAD && !RxVAD  : bins 128..254, mask > 20000
 *
 * Snapshot update (only when RxVAD == 0) - the (possibly replaced) BMSpec bin
 * is stored for every bin if TxVAD == 0, otherwise only where
 * mask < (32767 >> 1).
 */
static void RemoveSpeechLeakage(int *BMSpec, short *refMask, int TxVAD, int RxVAD, int noiseLevel)
{
	int *stored = DNNInVars->NoiseRefNonSpeech;
	int bin;
	int firstBin = 0;
	int maskThd = 0;
	int replace;

	if (noiseLevel == 0) {
		replace = 1;
		firstBin = 1;
		maskThd = 9830;
	} else {
		replace = (TxVAD != 0 && RxVAD == 0);
		firstBin = 128;
		maskThd = 20000;
	}

	if (replace) {
		for (bin = firstBin; bin < 255; bin++) {
			int *cur = BMSpec + 2 * bin;
			const int *ref = stored + 2 * bin;
			const INT64 eStored = SMULL(ref[0], ref[0]) + SMULL(ref[1], ref[1]);
			const INT64 eCur = SMULL(cur[0], cur[0]) + SMULL(cur[1], cur[1]);

			if (refMask[bin] > maskThd && eCur > eStored) {
				cur[0] = ref[0];
				cur[1] = ref[1];
			}
		}
	}

	if (RxVAD != 0)
		return;

	for (bin = 1; bin < 255; bin++) {
		if (TxVAD == 0 || refMask[bin] < (32767 >> 1)) {
			stored[2 * bin] = BMSpec[2 * bin];
			stored[2 * bin + 1] = BMSpec[2 * bin + 1];
		}
	}
}

/*
 * Per-band SNR thresholds in Q12 (4096 = 1.0 = 0 dB), indexed by band and
 * compared with InEarMixer->SNR in MixBFOut. The values are power ratios,
 * 10^(dB/10): 409600 = 100 (20 dB), 129527 ~ 31.6 (15 dB), 40960 = 10 (10 dB),
 * 12953 ~ 3.16 (5 dB), 8172 ~ 2.0 (3 dB), 4096 = 1 (0 dB).
 *
 * Declared "static const": a storage-class specifier placed after a type
 * qualifier ("const static") is obsolescent in ISO C (C11 6.11.5).
 */
static const int SNRThd[9] = {
	409600,  //20dB
	409600,  //20dB
	409600,  //20dB
	129527,  //15dB
	40960,   //10dB
	12953,   //5dB
	8172,    //3dB
	4096,    //0dB
	4096,    //0dB
};
static void MixBFOut(
	int *DNNin1, int *DNNin2,
	int *BFout, int *BMout,
	int *InEarout, int *InEarBMout,
	int *InEarGain,
	int TxVAD,
	int RxVAD,
	int noiseOnset,
	int noiseOffset,
	int noiseLevel)
{
	int i;
	int alpha, alpha_rev;
	int highestMixingBand = 0;
	int InEarfreqIdexBias = 1;
	//int GainThd = 1165413;  // 25dB (Q16)
	int GainThd = 2072430;  // 30dB (Q16)
	//int GainThd = 3685360;  // 35dB (Q16)

	int *BandSNR = InEarMixer->BandSNR;

	if (TxVAD != 0 && RxVAD == 0)
	{
		alpha = 27853; // 0.7
		alpha_rev = 32767 - alpha;

		for (i = 0; i < InEarMixer->bandSize; i++) {
			InEarMixer->SNR[i] = (int)((SMULL(alpha, InEarMixer->SNR[i]) + SMULL(alpha_rev, BandSNR[i])) >> 15);
			//InEarMixer->SNR[i] = BandSNR[i];
		}
	}

	// state machine for mixer
	for (i = 0; i < InEarMixer->bandSize; i++) {
		if (InEarMixer->InEarMask[i] == 0)
		{
			if ((InEarMixer->SNR[i] < SNRThd[i] && noiseLevel > 0 && TxVAD != 0)
				|| noiseLevel > 1)
			{
				InEarMixer->OnSetCnt[i]++;
			}
			else
			{
				InEarMixer->OnSetCnt[i] = 0;
			}

			if (InEarMixer->OnSetCnt[i] > 10)
			{
				InEarMixer->InEarMask[i] = 1;  // 0 -> 1
				InEarMixer->OnSetCnt[i] = 0;
			}
		}
		else if (InEarMixer->InEarMask[i] == 1)
		{
			//if (InEarMixer->SNR[i] > SNRThd[i] && noiseLevel < 1 && TxVAD != 0)
			if (noiseLevel == 0)
			{
				InEarMixer->OnSetCnt[i] += 2;
			}
			else if (InEarMixer->SNR[i] > SNRThd[i] && noiseLevel <= 1 && TxVAD != 0)
			{
				InEarMixer->OnSetCnt[i]++;
			}
			else
			{
				InEarMixer->OnSetCnt[i] = 0;
			}

			if (InEarMixer->OnSetCnt[i] > 20)
			{
				InEarMixer->InEarMask[i] = 0; // 1 -> 0
				InEarMixer->OnSetCnt[i] = 0;
			}
		}
	}


	// Á¦ÀÏ ³ôÀº band ¾Æ·¡·Î padding
	for (i = InEarMixer->bandSize - 1; i >= 0; i--) {
		if (InEarMixer->InEarMask[i] == 1)
		{
			highestMixingBand = i;
			break;
		}
	}
	for (i = 0; i < highestMixingBand; i++) {
		InEarMixer->InEarMask[i] = 1;
	}


	// Calculate reliable frequency range of in-ear mic based on gain value
	for (i = 32; i < InEarMixer->InEarUpperBound; i++) {
		if (InEarGain[i] > GainThd) break;
	}
	InEarMixer->reliableFreqBound = i + InEarfreqIdexBias;

	// Mixing
	for (i = 0; i < InEarMixer->InEarLowerBound; i++) {
		int real = 2 * i;
		int imag = real + 1;
		DNNin1[real] = BFout[real];
		DNNin1[imag] = BFout[imag];
		DNNin2[real] = BMout[real];
		DNNin2[imag] = BMout[imag];
	}

	//if (RxVAD == 0)
	if (1)
	{
		for (; i < InEarMixer->reliableFreqBound; i++) {
			int real = 2 * i;
			int imag = real + 1;
			int BandIdex = SV_UnbiasedMMSE_Get_BandIndex(i);

			if (InEarMixer->InEarMask[BandIdex] == 1)
			{
				DNNin1[real] = InEarout[real];
				DNNin1[imag] = InEarout[imag];
				DNNin2[real] = InEarBMout[real];
				DNNin2[imag] = InEarBMout[imag];
			}
			else
			{
				DNNin1[real] = BFout[real];
				DNNin1[imag] = BFout[imag];
				DNNin2[real] = BMout[real];
				DNNin2[imag] = BMout[imag];
			}
		}
	}
	else
	{
		for (; i < InEarMixer->reliableFreqBound; i++) {
			int real = 2 * i;
			int imag = real + 1;
			DNNin1[real] = BFout[real];
			DNNin1[imag] = BFout[imag];
			DNNin2[real] = BMout[real];
			DNNin2[imag] = BMout[imag];
		}
	}

	for (; i < 256; i++) {
		int real = 2 * i;
		int imag = real + 1;
		DNNin1[real] = BFout[real];
		DNNin1[imag] = BFout[imag];
		DNNin2[real] = BMout[real];
		DNNin2[imag] = BMout[imag];
	}

#ifdef DEBUG_DNN_INPUT_SYNTHESIS
	float ftmp;

	for (i = 0; i < InEarMixer->InEarFreqSize; i++)
	{
		float ftmp;
		int real = BFout[2 * i];
		int imag = BFout[2 * i + 1];
		int BandIdex = SV_UnbiasedMMSE_Get_BandIndex(i);

		ftmp = ((float)real*real + (float)imag*imag) / (float)((INT64)1 << 30);
		fwrite(&ftmp, sizeof(float), 1, f_mixer[0]);


		ftmp = InEarMixer->SNR[BandIdex] / (float)((INT64)1 << 12); //(SNR_Q)
		fwrite(&ftmp, sizeof(float), 1, f_mixer[1]);

		ftmp = InEarMixer->InEarMask[BandIdex];
		fwrite(&ftmp, sizeof(float), 1, f_mixer[2]);
	}

	ftmp = (float)InEarMixer->reliableFreqBound;
	fwrite(&ftmp, sizeof(float), 1, f_mixer[3]);
#endif
}


// DNN input spectrum »ý¼º
// ÀÔ·Â BF Ãâ·Â, inEar ÆÄ¿ö º¸»óµÈ °Í, Rx
void SV_DNNInputSynthesis_Exe(int *SpeechRefSpec, int *NoiseRefSpec, int *BFSpec, int *BMSpec, int *inEarBFSpec, int *inEarBMSpec, int *RxSpec)
{

	MixBFOut(SpeechRefSpec, NoiseRefSpec,
		BFSpec, BMSpec,
		inEarBFSpec, inEarBMSpec,
		DNNInVars->inEarGain,
		DNNInVars->TxVAD,
		DNNInVars->RxVAD,
		DNNInVars->noiseOnset,
		DNNInVars->noiseOffset,
		DNNInVars->noiseLevel);

	RemoveSpeechLeakage(NoiseRefSpec, 
		DNNInVars->refMask,
		DNNInVars->TxVAD, 
		DNNInVars->RxVAD,
		DNNInVars->noiseLevel);

	MixEchoAtNoiseRef(NoiseRefSpec,
		NoiseRefSpec,
		RxSpec,
		inEarBFSpec,
		DNNInVars->RxVAD,
		DNNInVars->RxVAD && (DNNInVars->TxVAD == 0),
		InEarMixer->InEarMask);
}