/**************************************
        SV_AEC.c
***************************************/

#include <string.h>
#include <stdbool.h>
#include "SV_AEC.h"

#define BIT_NORM_BASED_OPT
#ifdef Debug_File_Write_C
#define DEBUG_AEC
#endif

#ifdef DEBUG_AEC
#include <stdio.h>
#include<stdlib.h>
extern FILE *fp_aec, *f_aec[4];
#endif
/*
 * Static pool from which the adaptive-filter coefficient arrays (W and Waux)
 * are carved by SV_AEC_Init() and SV_AEC_Init_for_InEar().
 * Each SV_AEC_Init() call takes 2 * Nfilter_AEC_Outer ints and each
 * SV_AEC_Init_for_InEar() call takes 2 * Nfilter_AEC_InEar ints, so the pool
 * as sized here holds two "outer" instances plus one "in-ear" instance.
 * NOTE(review): the init functions do not check the remaining capacity.
 */
static int aec_scratch_mem[Nfilter_AEC_Outer * 2 * 2 + Nfilter_AEC_InEar * 2] ;
/* Cursor to the next unused int in aec_scratch_mem; advanced by the init functions. */
static int* aec_scratch_mem_ptr = aec_scratch_mem;

#define Q_INV_N 14  // or 16

#ifdef USEOPT
extern INT64 SV_AEC_GetPower(int* a, int n);
#else
/*
 * Sum of squares of the first n samples of a, accumulated in 64 bits.
 *
 * a : input samples (read only).
 * n : number of samples to accumulate.
 *
 * Returns the sum of a[i]*a[i] for i in [0, n). For n <= 0 nothing is read
 * and 0 is returned. (The previous do/while form always consumed a[0], which
 * over-counted the first sample when a caller asked for an empty window.)
 */
static INT64 SV_AEC_GetPower(int* a, int n)
{
    INT64 acc = 0;
    int i;

    for (i = 0; i < n; i++)
    {
        const INT64 d = a[i];   /* widen before squaring so the product cannot overflow 32 bits */
        acc += d * d;
    }
    return acc;
}
#endif

/*
 * Approximate mean of the squared samples (i.e. RMS squared) of a[0..n-1].
 *
 * The division by n is replaced by a multiplication with the truncated
 * reciprocal floor(2^Q_INV_N / n) followed by a right shift of Q_INV_N bits,
 * so the result is slightly below the exact mean whenever n does not divide
 * 2^Q_INV_N.
 *
 * Preconditions visible from the arithmetic: n must be non-zero (division),
 * and for n > 2^Q_INV_N the reciprocal truncates to 0 and the result is 0.
 * The 64-bit result is narrowed to int on return.
 */
static int SV_AEC_GetSquaredRMS(int* a, int n)
{
    const INT64 energy = SV_AEC_GetPower(a, n);
    const int recip_q = (1 << Q_INV_N) / n;
    INT64 scaled = energy * recip_q;

    scaled >>= Q_INV_N;
    return (int)scaled;
}

// Processing
/*
 * Error-power based arbitration between the main filter (W) and the
 * auxiliary filter (Waux) after a frame has been processed.
 *
 * Pwr_Tx_in  : power of the unprocessed Tx frame (supplied by the caller).
 * tx_AEC     : frame output produced with W; may be overwritten with tx_AECaux.
 * tx_AECaux  : frame output produced with Waux (read only).
 * Nframe     : number of samples in both output frames.
 * p_struct   : AEC state; W, Waux and pwrDiff_AECinout_dB may be modified.
 *
 * Steps:
 *  1. Measure the mean-square of both outputs and smooth the in/out power
 *     ratio (0.9 old + 0.1 new) into pwrDiff_AECinout_dB.
 *  2. While cnt_NLMSupdate is below Start_WauxUpdate_sec, Waux just tracks W.
 *  3. Otherwise, when frmVAD_Rx_AEC is zero the smoothed ratio is reset to 0.
 *  4. Otherwise, if the smaller of the two output powers is below the input
 *     power: W is copied into Waux when W's output is the smaller one and the
 *     smoothed ratio exceeds Thr_auxAECupdate_PwrDiff; Waux (and its output)
 *     is copied into W/tx_AEC when Waux's output is the smaller one. With
 *     equal powers both conditions are evaluated, in that order.
 *
 * Returns the mean-square power of the frame finally left in tx_AEC.
 */
static INT64 fx_SV_AEC_ErrPower(float Pwr_Tx_in, int* tx_AEC, int* tx_AECaux, int Nframe, SV_AEC_T* p_struct)
{
    const int warmup_updates = p_struct->Start_WauxUpdate_sec;
    const float aux_copy_thr = p_struct->Thr_auxAECupdate_PwrDiff;
    const short rx_active = p_struct->frmVAD_Rx_AEC;
    int *main_coef = p_struct->W;
    int *aux_coef = p_struct->Waux;

    INT64 err_main = SV_AEC_GetSquaredRMS(tx_AEC, Nframe);
    INT64 err_aux = SV_AEC_GetSquaredRMS(tx_AECaux, Nframe);
    INT64 err_best;

    /* One-pole smoothing of the input/output power ratio. */
    p_struct->pwrDiff_AECinout_dB = 0.9*p_struct->pwrDiff_AECinout_dB + 0.1*(GetdB((Pwr_Tx_in + 1.0) / (err_main + 1.0)));

    /* Warm-up: the auxiliary filter simply mirrors the main one. */
    if (p_struct->cnt_NLMSupdate < warmup_updates)
    {
        memcpy(aux_coef, main_coef, sizeof(int)*(p_struct->filter_len));
        return err_main;
    }

    /* No far-end activity flagged: restart the smoothed ratio. */
    if (!rx_active)
    {
        p_struct->pwrDiff_AECinout_dB = 0;
        return err_main;
    }

    err_best = MIN(err_main, err_aux);

    /* Neither filter reduced the power below the input: leave both alone. */
    if (!(Pwr_Tx_in > err_best))
    {
        return err_main;
    }

    /* Main filter is (one of) the best: promote it to the backup if the gain is large enough. */
    if ((err_best == err_main) && (p_struct->pwrDiff_AECinout_dB > aux_copy_thr))
    {
        memcpy(aux_coef, main_coef, sizeof(int)*(p_struct->filter_len));
    }

    /* Backup filter is (one of) the best: restore it and use its output. */
    if (err_best == err_aux)
    {
        memcpy(main_coef, aux_coef, sizeof(int)*(p_struct->filter_len));
        memcpy(tx_AEC, tx_AECaux, sizeof(int)*(Nframe));
        err_main = err_aux;
    }

    return err_main;
}

/*
 * Per-sample gate for the coefficient update.
 *
 * RxPwr    : power of the reference (adaptive-filter input) window.
 * TxPwr    : power of the Tx signal.
 * p_struct : supplies Thr_AECupdate_PwrAdfin and Thr_AECupdate_PwrDiff.
 *
 * Returns 1 only when both hold, otherwise 0:
 *   - RxPwr is above Thr_AECupdate_PwrAdfin, and
 *   - (TxPwr + 1) is below (RxPwr + 1) * Thr_AECupdate_PwrDiff.
 */
static int fn_GetFlagAECUpdate(int RxPwr, int TxPwr, SV_AEC_T* p_struct)
{
    /* Reference too weak: no update. */
    if (!((float)RxPwr > p_struct->Thr_AECupdate_PwrAdfin))
    {
        return 0;
    }

    /* Reference strong enough: update only if Tx is sufficiently below it. */
    if ((TxPwr + 1.0) < ((float)RxPwr+ 1.0)*p_struct->Thr_AECupdate_PwrDiff)
    {
        return 1;
    }
    return 0;
}

#ifdef USEOPT
extern void SV_AEC_vector_mult_dual(int *out1, int *out2, int *A, int *B1, int *B2, int n);
#else
/*
 * Two dot products sharing one input vector:
 *   *out1 = (sum A[i]*B1[i]) >> 31
 *   *out2 = (sum A[i]*B2[i]) >> 31
 * Products are accumulated in 64 bits and narrowed to int after the shift
 * (presumably Q31 coefficients - confirm against the coefficient format).
 *
 * The loop body runs at least once, so A[0], B1[0] and B2[0] are always read;
 * callers are expected to pass n >= 1.
 */
static void SV_AEC_vector_mult_dual(int *out1, int *out2, int *A, int *B1, int *B2, int n)
{
    INT64 sum_b1 = 0;
    INT64 sum_b2 = 0;
    int i = 0;

    do
    {
        const INT64 sample = A[i];  /* widened once, reused for both products */

        sum_b1 += sample * B1[i];
        sum_b2 += sample * B2[i];
        i++;
    } while (i < n);

    *out1 = (int)(sum_b1 >> 31);
    *out2 = (int)(sum_b2 >> 31);
}
#endif


#if defined(_WIN32)     // // predefined macro for VS
/*
 * Portable count-leading-zeros for a 32-bit value (software fallback used on
 * the _WIN32 build).
 *
 * v : value to inspect; its bit pattern is examined as unsigned, so any
 *     negative input yields 0.
 *
 * Returns the number of zero bits above the most significant set bit
 * (0..31), or the full bit width of int when v == 0.
 *
 * The search is done on an unsigned copy: left-shifting a signed int into or
 * past the sign bit is undefined behaviour in C.
 */
static int CountLeadZeros(int v)
{
	unsigned int bits = (unsigned int)v;
	int count = 0; /* number of zero bits found so far at the top (left) of the word */

	if (bits == 0u) return (int)(sizeof(int) * 8);

	/* Binary search: if the upper half of the remaining window is empty, skip it. */
	if ((bits & 0xFFFF0000u) == 0u) { count += 16; bits <<= 16; }
	if ((bits & 0xFF000000u) == 0u) { count += 8; bits <<= 8; }
	if ((bits & 0xF0000000u) == 0u) { count += 4; bits <<= 4; }
	if ((bits & 0xC0000000u) == 0u) { count += 2; bits <<= 2; }
	if ((bits & 0x80000000u) == 0u) { count += 1; }
	return count;
}
#elif (__GNUC__>=5)    // predefined macro for GCC armclang (__GNUC__==5)
/*
 * Count leading zeros of a 32-bit value (GCC-compatible compilers).
 *
 * v : value to inspect; its bit pattern is examined as unsigned, so any
 *     negative input yields 0.
 *
 * Returns 0..31 for non-zero input and 32 for v == 0, matching the result of
 * the ARM CLZ instruction on a 32-bit register.
 *
 * Implemented with the compiler builtin instead of hand-written "CLZ" inline
 * assembly: this preprocessor branch is selected on __GNUC__ alone, so the
 * assembly version did not build on non-ARM hosts, and on AArch64 the
 * unmodified operands name 64-bit registers. __builtin_clz is undefined for
 * 0, hence the explicit zero case.
 */
static inline int CountLeadZeros(int v)
{
	if (v == 0)
	{
		return 32;
	}
	return __builtin_clz((unsigned int)v);
}
#else    // predefined macro for ARMCC Compiler 5 (__GNUC__==4)
/*
 * Count leading zeros of v using the compiler's inline assembler block
 * syntax (this branch is reached when neither _WIN32 nor __GNUC__ >= 5 is
 * defined; the original author targets ARM Compiler 5 here).
 * Executes a single "clz" instruction with v as source and returns its result.
 * NOTE(review): the result for v == 0 is whatever the target's CLZ returns
 * (32 on ARM) - callers in this file rely on that; confirm for new targets.
 */
static int inline CountLeadZeros(int v)
{
	int res;
	/* res <- number of leading zero bits in v */
	__asm {
		clz res, v
	};
	return res;
	//    return __clz(v);     // alternative: the compiler intrinsic could be used instead of inline asm
}
#endif

#ifdef USEOPT
extern void SV_AEC_fn_filterUpdate_BasedOnBitNorm(int *W, int in, int *ref, int refPwr, int stepSize, int n);

#else
/*
 * Coefficient update with the power normalisation approximated by a shift.
 *
 * Intended operation:  W[i] += (stepSize * in / refPwr) * ref[i]
 * Implemented as:      W[i] += ((stepSize * in) >> s) * ref[i]
 * where s = 32 - CountLeadZeros(refPwr - 1), i.e. the division by refPwr is
 * replaced by a division by the smallest power of two that is >= refPwr
 * (s == 0 for refPwr == 1). For refPwr <= 0 the argument is negative, the
 * leading-zero count is 0 and the shift is 32.
 *
 * W        : coefficients, updated in place (64-bit sum narrowed to int).
 * in       : error sample driving the update.
 * ref      : reference window aligned with W.
 * refPwr   : normalising power of the reference window.
 * stepSize : adaptation step (fixed point; format per the caller).
 * n        : number of taps; the loop always processes at least one tap.
 */
static void SV_AEC_fn_filterUpdate_BasedOnBitNorm(int *W, int in, int *ref, int refPwr, int stepSize, int n)
{
	/* CountLeadZeros() should map to a one-cycle intrinsic/instruction on the target. */
	const int lead_zeros = CountLeadZeros(refPwr - 1);
	const INT64 scaled_err = ((INT64)stepSize * in) >> (32 - lead_zeros);
	int tap = 0;

	do {
		W[tap] += scaled_err * ref[tap]; //Q.Q_W
		tap++;
	} while (tap < n);
}
#endif

//#ifdef USEOPT
#if 0
extern void SV_AEC_fn_filterUpdate(int *W, int in, int *ref, int refPwr, int stepSize, int n);
#else
/*
 * Coefficient update with exact power normalisation:
 *   W[i] += ((stepSize * in) / refPwr) * ref[i]
 * The common factor is computed once in 64 bits; each updated coefficient is
 * narrowed back to int when stored.
 *
 * refPwr must be non-zero (it is the divisor). The loop always processes at
 * least one tap, so callers are expected to pass n >= 1.
 */
static void SV_AEC_fn_filterUpdate(int *W, int in, int *ref, int refPwr, int stepSize, int n)
{
	const INT64 step_over_pwr = ((INT64)stepSize * in) / refPwr; //Q.Q_W
	int *coef = W;
	int *x = ref;
	int remaining = n;

	for (;;) {
		*coef += step_over_pwr * *x; //Q.Q_W
		coef++;
		x++;
		if (--remaining <= 0) {
			break;
		}
	}
}
#endif

#ifdef DEBUG_AEC
int ch_cnt = 0;
#endif
/*
 * Run the echo canceller on one frame.
 *
 * tx       : Tx samples, Nframe entries; overwritten with the echo-cancelled
 *            output at the end of the call.
 * ref      : reference samples feeding the adaptive filters. Sample k uses
 *            ref[k .. k + filter_len - 1], so at least
 *            Nframe + filter_len - 1 entries are read.
 * Nframe   : number of samples in the frame.
 * p_struct : AEC state. W may be updated; p_struct->tx_AEC and
 *            p_struct->tx_AECaux are used as per-frame output buffers of at
 *            least Nframe entries; cnt_NLMSupdate is incremented per update.
 *
 * Per sample, both filters (W and Waux) are applied to the reference window
 * and subtracted from tx[k]; W is then adapted if the frame-level and
 * sample-level gates allow it. After the loop fx_SV_AEC_ErrPower() arbitrates
 * between W and Waux and the selected output is copied back into tx.
 *
 * Returns GetdB() of (output power + 1) as reported by fx_SV_AEC_ErrPower().
 */
float SV_AEC_Exe(int* tx, int* ref, int Nframe, SV_AEC_T* p_struct)
{
    // AEC parameter & buffer
    int k;
    int Step_aec = p_struct->Step_aec;

    short flag_FrmDT = p_struct->flag_FrmDT;
    short flag_ringbakctone = p_struct->flag_ringbakctone;

    int *W = p_struct->W;
    int *Waux = p_struct->Waux;
    int *tx_AEC = p_struct->tx_AEC;
    int *tx_AECaux = p_struct->tx_AECaux;
    int *ADF_IN;
    // NOTE(review): filteredIn is declared but never used in this function.
    int filteredIn;
    INT64 pwr_adfin;
    int FlagSampleADFupate;

    INT64 Pwr_TX_out;

    float Pwr_Tx_in = p_struct->Pwr_Tx_in;

    // Seed the sliding-window power with the first filter_len-1 reference samples;
    // the loop below adds the newest sample and removes the oldest one on each step.
    // inv_n is the truncated reciprocal of filter_len in Q(Q_INV_N), used instead of a division.
    int inv_n=(1<<Q_INV_N)/p_struct->filter_len;
    INT64 pwr_adfin_64=SV_AEC_GetPower(ref,p_struct->filter_len-1);
    // Sample that left the window on the previous step (nothing to remove for k == 0).
    int prev_sample=0;

    // Frame-level gate, evaluated once: Tx power above threshold and both
    // flag_ringbakctone and flag_FrmDT cleared.
    bool UpdateFlag=(Pwr_Tx_in > p_struct->Thr_AECupdate_PwrTx) && (flag_ringbakctone == 0) && (flag_FrmDT == 0);

    // ADF filtering & update
    for (k = 0; k < Nframe; k++)
    {
        // Reference window for this output sample starts at ref[k].
        ADF_IN = ref + k;

        // adf filter convolution(W) and convolution(Waux)
        int fltout1, fltout2;
        SV_AEC_vector_mult_dual(&fltout1, &fltout2,
                        ADF_IN, W, Waux, p_struct->filter_len);
        // Subtract each filter's echo estimate from the Tx sample.
        tx_AEC[k] = tx[k] - fltout1;
        tx_AECaux[k] = tx[k] - fltout2;

#ifdef DEBUG_AEC
        // Debug dump: every third call (ch_cnt == 2) writes the echo estimates
        // and residuals, truncated to 16 bits, to the f_aec[] files.
        short Stmp = 0;
        if (ch_cnt == 2)
        {
            Stmp = (short)fltout1;
            fwrite(&Stmp, sizeof(short), 1, f_aec[0]);
            Stmp = (short)tx_AEC[k];
            fwrite(&Stmp, sizeof(short), 1, f_aec[1]);
        }

        Stmp = 0;
        if (ch_cnt == 2)
        {
            Stmp = (short)fltout2;
            fwrite(&Stmp, sizeof(short), 1, f_aec[2]);
            Stmp = (short)tx_AECaux[k];
            fwrite(&Stmp, sizeof(short), 1, f_aec[3]);
        }
#endif

        // filter update
        // Reference-window power, maintained incrementally instead of the full
        // recomputation shown in the disabled line below:
//        pwr_adfin = (int)GetPower(ADF_IN, p_struct->filter_len) + 33;
        // remove the square of the sample that left the window ...
        INT64 prev2=(INT64)prev_sample*prev_sample;
        // ... and add the square of the sample that just entered it.
        int curr_sample=ADF_IN[p_struct->filter_len-1];
        INT64 curr2=(INT64)curr_sample*curr_sample;
        pwr_adfin_64=pwr_adfin_64-prev2+curr2;
        // The current first sample is the one that leaves on the next step.
        prev_sample=*ADF_IN;
        // Mean power via reciprocal multiply, narrowed to int, plus a constant offset of 33
        // (presumably keeps the normaliser away from zero - confirm).
        pwr_adfin=(int)((pwr_adfin_64*inv_n)>>Q_INV_N)+33;

        if (UpdateFlag)
        {
            // NOTE(review): fn_GetFlagAECUpdate takes (int, int); pwr_adfin (INT64) and
            // Pwr_Tx_in (float) are implicitly converted to int at this call.
            FlagSampleADFupate = fn_GetFlagAECUpdate(pwr_adfin, Pwr_Tx_in, p_struct);

            if (FlagSampleADFupate)
            {
                // Count performed updates, saturating at 1000000.
                // NOTE(review): stray second ';' on the increment line (harmless empty statement).
                if (p_struct->cnt_NLMSupdate < 1000000)
                    p_struct->cnt_NLMSupdate++;;
                // BIT_NORM_BASED_OPT is defined at the top of this file, so the
                // shift-normalised update is the one compiled in.
#ifdef BIT_NORM_BASED_OPT
				SV_AEC_fn_filterUpdate_BasedOnBitNorm(W, tx_AEC[k], ADF_IN, pwr_adfin, Step_aec, p_struct->filter_len);
#else
				SV_AEC_fn_filterUpdate(W, tx_AEC[k], ADF_IN, pwr_adfin, Step_aec, p_struct->filter_len);
#endif
            }
        }

#ifdef Debug_File_Write_C
        //for (int i=0;i<80;i++)
        //    fprintf(fp_aec, "%d\n ", ADF_IN[i]);
        //fprintf(fp_aec, "%lld, %f\n", Step_aec / pwr_adfin, (float)Step_aec / pwr_adfin);
        //fprintf(fp_aec, "%ld\n", ADF_IN[k]);
#endif
    }

    // ErrPower based filter correction: may swap W/Waux and replace tx_AEC with tx_AECaux.
    Pwr_TX_out = fx_SV_AEC_ErrPower(Pwr_Tx_in, tx_AEC, tx_AECaux, Nframe, p_struct);

    // Hand the selected echo-cancelled frame back to the caller in place.
    memcpy(tx, tx_AEC, sizeof(int)*(Nframe));

#ifdef DEBUG_AEC
    // Cycle the debug call counter 0,1,2,0,...
    ch_cnt++;
    if (ch_cnt == 3) ch_cnt = 0;
#endif
    return GetdB(Pwr_TX_out + 1.0);

}

// return Rx settings for AEC
/* Returns the DelayRx setting stored in the AEC state (converted to short). */
short AECStatus_Get_DelayRx(SV_AEC_T* p_struct)
{
    return p_struct->DelayRx;
}
/* Returns the adaptive-filter length (filter_len) of the AEC state, converted to short. */
short AECStatus_Get_Nfilter_AEC(SV_AEC_T* p_struct)
{
    return p_struct->filter_len;
}
/* Returns the ValClipp setting stored in the AEC state (converted to short). */
short AECStatus_Get_ValClipp(SV_AEC_T* p_struct)
{
    return p_struct->ValClipp;
}

// Set parameters
/*
 * Store the per-frame inputs that SV_AEC_Exe() and the error-power stage read
 * from the state:
 *   Pwr_Mic1       -> p_struct->Pwr_Tx_in
 *   frmVAD_Rx_AEC  -> p_struct->frmVAD_Rx_AEC
 */
extern void SV_AEC_SetPar(float Pwr_Mic1, short frmVAD_Rx_AEC, SV_AEC_T* p_struct )
{
    SV_AEC_T *const state = p_struct;

    state->frmVAD_Rx_AEC = frmVAD_Rx_AEC;
    state->Pwr_Tx_in = Pwr_Mic1;
}

// Init
/*
 * Initialise an AEC instance with the "outer" configuration.
 *
 * Fs       : not used by this function.
 * p_struct : state to initialise.
 *
 * Sets filter_len to Nfilter_AEC_Outer, takes two consecutive filter_len-sized
 * coefficient arrays (W, then Waux) from the file-scope scratch pool, zeroes
 * them, and loads the fixed delay/clipping/step/threshold settings.
 * NOTE(review): this function does not check how much of the scratch pool is
 * still free before advancing the pool cursor.
 */
extern void SV_AEC_Init(int Fs, SV_AEC_T* p_struct)
{
    int *pool_pos = aec_scratch_mem_ptr;

    /* Coefficient storage: W followed directly by Waux, both cleared. */
    p_struct->filter_len = Nfilter_AEC_Outer;
    p_struct->W = pool_pos;
    pool_pos += p_struct->filter_len;
    p_struct->Waux = pool_pos;
    pool_pos += p_struct->filter_len;
    aec_scratch_mem_ptr = pool_pos;

    memset(p_struct->W, 0, sizeof(int)* p_struct->filter_len);
    memset(p_struct->Waux, 0, sizeof(int)* p_struct->filter_len);

    /* Fixed settings. */
    p_struct->DelayRx = 50;
    p_struct->ValClipp = 30000;
    p_struct->Step_aec = 4294967; //Q.Q_W
    p_struct->Start_WauxUpdate_sec = 32000;

    /*
     * Thresholds. The formulas below are carried over from the original
     * comments. NOTE(review): 10^((8525/2^8)/10)-1 evaluates to about 2137.35
     * rather than 302.08, so at least that note looks stale - confirm the
     * intended derivations.
     */
    p_struct->Thr_AECupdate_PwrTx = 302.0803855f;      /* original note: 10^((8525/(2^8))/10)-1 */
    p_struct->Thr_AECupdate_PwrAdfin = 4468.09632f;    /* original note: 10^((9984/(2^8)+3.05)/10)-1 */
    p_struct->Thr_AECupdate_PwrDiff = 0.011482f;       /* original note: 10^((614/(2^8)-3.02)/10) */
    p_struct->Thr_auxAECupdate_PwrDiff = 3.162278f;    /* original note: (1280/(2^8)) */
}
/*
 * Initialise an AEC instance with the "in-ear" configuration.
 *
 * Fs       : not used by this function.
 * p_struct : state to initialise.
 *
 * Sets filter_len to Nfilter_AEC_InEar, takes two consecutive filter_len-sized
 * coefficient arrays (W, then Waux) from the file-scope scratch pool, zeroes
 * them, and loads the fixed delay/clipping/step/threshold settings.
 * NOTE(review): this function does not check how much of the scratch pool is
 * still free before advancing the pool cursor.
 */
extern void SV_AEC_Init_for_InEar(int Fs, SV_AEC_T* p_struct)
{
    int *pool_pos = aec_scratch_mem_ptr;

    /* Coefficient storage: W followed directly by Waux, both cleared. */
    p_struct->filter_len = Nfilter_AEC_InEar;
    p_struct->W = pool_pos;
    pool_pos += p_struct->filter_len;
    p_struct->Waux = pool_pos;
    pool_pos += p_struct->filter_len;
    aec_scratch_mem_ptr = pool_pos;

    memset(p_struct->W, 0, sizeof(int)* p_struct->filter_len);
    memset(p_struct->Waux, 0, sizeof(int)* p_struct->filter_len);

    /* Fixed settings. */
    p_struct->DelayRx = 50;
    p_struct->ValClipp = 30000;
    p_struct->Step_aec = 4294967; //Q.Q_W
    p_struct->Start_WauxUpdate_sec = 32000;

    /*
     * Thresholds. The formula notes are carried over from the original
     * comments. NOTE(review): only the first one (about 2137.35) reproduces
     * its literal; the notes on the last two do not evaluate to 10 and 40 -
     * confirm the intended derivations.
     */
    p_struct->Thr_AECupdate_PwrTx = 2137.34672094338;  /* original note: 10^((8525/(2^8))/10)-1 */
    p_struct->Thr_AECupdate_PwrAdfin = 4468.09632;     /* original note: 10^((9984/(2^8)+3.05)/10)-1 */
    p_struct->Thr_AECupdate_PwrDiff = 10;              /* original note: 10^((614/(2^8)-3.02)/10) */
    p_struct->Thr_auxAECupdate_PwrDiff = 40;           /* original note: (1280/(2^8)) */
}

extern void  SV_AEC_Deinit()
{
}
