
/*
 * =====================================================================================
 *
 *       Filename:  process_cmd.c
 *
 *    Description:  PEBBLE process command
 *
 *        Version:  1.0
 *        Created:  06/02/2020
 *       Revision:  none
 *       Compiler:  gcc
 *
 *        Company:  Samsung Electronics
 *        Copyright (c) 2020 by Samsung Electronics, All rights reserved.
 *
 * =====================================================================================
 */

/** Includes */
#include "process_cmd.h"

#include <stddef.h>

/* NOTE(review): LOOKUP_SIZE is not referenced in the code visible in this file. */
#define LOOKUP_SIZE 4096

/* Ground-truth buffer: allocated with OPENSSL_malloc in make_cost_layer_TA_params()
 * (inputs * batch floats) and filled by net_truth_TA_params(). */
float *netta_truth;
/* Index of the next free slot in netta.layers; each make_*_layer handler stores
 * its layer at this index and then increments it. */
int netnum = 0;
/* When 1, handlers log summary_array() statistics for the buffers they exchange. */
int debug_summary_com = 0;
/* NOTE(review): not referenced in the code visible in this file. */
int debug_summary_pass = 0;
/* When non-zero, net_output_return_TA_params() turns the network output into a
 * one-hot vector (1.0 at the arg-max index, 0.0 elsewhere) before returning it. */
int norm_output = 1;

/**
 * @brief Log mean, min, max and number of exact zeros of a float array.
 *
 * Used as a debugging aid for buffers crossing the TEE boundary.
 *
 * @param[in] print_name label printed in front of the statistics
 * @param[in] arr        array to summarize
 * @param[in] n          number of elements in @p arr; when n <= 0 only a
 *                       short "empty" line is logged
 */
void summary_array(char *print_name, float *arr, int n)
{
    /* With no elements, min/max would be read uninitialized (UB) and the mean
     * would be 0/0, so report the empty case explicitly instead. */
    if (n <= 0) {
        PEBBLE_LOG("%s || empty array, nothing to summarize \n", print_name);
        return;
    }

    float sum = 0.0f;
    float min = arr[0];
    float max = arr[0];
    int zeros = 0;

    for (int i = 0; i < n; i++) {
        sum += arr[i];
        if (arr[i] < min) {
            min = arr[i];
        }
        if (arr[i] > max) {
            max = arr[i];
        }
        if (arr[i] == 0) {
            zeros++;
        }
    }

    float mean = sum / n;

    char mean_char[20];
    char min_char[20];
    char max_char[20];
    char idxzero_char[20];
    ftoa(mean, mean_char, 5);
    ftoa(min, min_char, 5);
    ftoa(max, max_char, 5);
    /* Count in an int (a float counter stops incrementing past 2^24). */
    ftoa((float)zeros, idxzero_char, 5);

    PEBBLE_LOG("%s || mean = %s; min=%s; max=%s; number of zeros=%s \n", print_name, mean_char, min_char, max_char, idxzero_char);
}


/*
 * MAKE_NETWORK_CMD handler.
 *
 * Builds the in-TEE network description by forwarding the hyper-parameters
 * packed by the normal world to make_network_TA().  Array layouts:
 *
 *   passint   [0..15]: n, time_steps, notruth, batch, subdivisions, random,
 *                      adam, h, w, c, inputs, max_crop, min_crop, center,
 *                      burn_in, max_batches
 *   passfloat [0..14]: learning_rate, momentum, decay, B1, B2, eps,
 *                      max_ratio, min_ratio, clip, angle, aspect,
 *                      saturation, exposure, hue, power
 *
 * @param[in] cmd  request carrying payload.make_network_cmd
 * @param[in] resp unused
 * @return PEBBLE_STATUS_SUCCESS
 */
static pebble_return_code_t make_netowork_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    const int *iv = cmd->payload.make_network_cmd.passint;
    const float *fv = cmd->payload.make_network_cmd.passfloat;

    (void)resp;

    make_network_TA(iv[0],   /* n             */
                    fv[0],   /* learning_rate */
                    fv[1],   /* momentum      */
                    fv[2],   /* decay         */
                    iv[1],   /* time_steps    */
                    iv[2],   /* notruth       */
                    iv[3],   /* batch         */
                    iv[4],   /* subdivisions  */
                    iv[5],   /* random        */
                    iv[6],   /* adam          */
                    fv[3],   /* B1            */
                    fv[4],   /* B2            */
                    fv[5],   /* eps           */
                    iv[7],   /* h             */
                    iv[8],   /* w             */
                    iv[9],   /* c             */
                    iv[10],  /* inputs        */
                    iv[11],  /* max_crop      */
                    iv[12],  /* min_crop      */
                    fv[6],   /* max_ratio     */
                    fv[7],   /* min_ratio     */
                    iv[13],  /* center        */
                    fv[8],   /* clip          */
                    fv[9],   /* angle         */
                    fv[10],  /* aspect        */
                    fv[11],  /* saturation    */
                    fv[12],  /* exposure      */
                    fv[13],  /* hue           */
                    iv[14],  /* burn_in       */
                    fv[14],  /* power         */
                    iv[15]); /* max_batches   */

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * WORKSPACE_NETWORK_CMD handler: make the in-TEE network's workspace point at
 * the buffer carried in the command.
 *
 * NOTE(review): only the pointer is stored, so netta.workspace aliases the
 * command buffer after this returns; confirm that buffer outlives its use.
 * TODO: input/out shared buffer
 *
 * @param[in] cmd  request carrying payload.workspace_network_cmd
 * @param[in] resp unused
 * @return PEBBLE_STATUS_SUCCESS
 */
static pebble_return_code_t update_net_agrv_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    float *shared_workspace = (float *)cmd->payload.workspace_network_cmd.workspace.buf;

    (void)resp;
    netta.workspace = shared_workspace;

    return PEBBLE_STATUS_SUCCESS;
}


/*
 * MAKE_CONV_CMD handler: build a convolutional layer from host-supplied
 * parameters and store it in netta.layers[netnum].
 *
 * passint layout: [0] batch [1] h [2] w [3] c [4] n [5] groups [6] size
 *                 [7] stride [8] padding [9] batch_normalize [10] binary
 *                 [11] xnor [12] adam [13] flipped
 * passflo: dot.  acti.buf: activation name passed to get_activation_TA().
 *
 * Also raises netta.workspace_size to this layer's workspace_size if larger.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if netta.layers has no
 *         free slot left
 */
static pebble_return_code_t make_convolutional_layer_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    (void)resp;

    /* Other handlers treat netta.layers[netta.n - 1] as the last layer, so the
     * array holds netta.n entries; a host sending extra layer commands must not
     * be able to write past it. */
    if (netnum < 0 || netnum >= netta.n) {
        PEBBLE_LOG("make_convolutional_layer: layer index %d out of range\n", netnum);
        return PEBBLE_STATUS_FAIL;
    }

    int *passint = cmd->payload.make_conv_cmd.passint;
    float dot = cmd->payload.make_conv_cmd.passflo;
    /* NOTE(review): the activation name comes from a host buffer; confirm it is
     * NUL-terminated before get_activation_TA() compares it. */
    char *acti = (char *)cmd->payload.make_conv_cmd.acti.buf;

    int batch = passint[0];
    int h = passint[1];
    int w = passint[2];
    int c = passint[3];
    int n = passint[4];
    int groups = passint[5];
    int size = passint[6];
    int stride = passint[7];
    int padding = passint[8];
    int batch_normalize = passint[9];
    int binary = passint[10];
    int xnor = passint[11];
    int adam = passint[12];
    int flipped = passint[13];

    ACTIVATION_TA activation = get_activation_TA(acti);

    layer_TA lta = make_convolutional_layer_TA_new(batch, h, w, c, n, groups, size, stride, padding, activation, batch_normalize, binary, xnor, adam, flipped, dot);
    netta.layers[netnum] = lta;
    if (lta.workspace_size > netta.workspace_size) {
        netta.workspace_size = lta.workspace_size;
    }
    netnum++;

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * MAKE_MAX_CMD handler: build a max-pool layer and store it in
 * netta.layers[netnum].
 *
 * passint layout: [0] batch [1] h [2] w [3] c [4] size [5] stride [6] padding
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if netta.layers has no
 *         free slot left
 */
static pebble_return_code_t make_maxpool_layer_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    (void)resp;

    /* netta.layers holds netta.n entries (the last is used as
     * netta.layers[netta.n - 1] elsewhere); never write past it. */
    if (netnum < 0 || netnum >= netta.n) {
        PEBBLE_LOG("make_maxpool_layer: layer index %d out of range\n", netnum);
        return PEBBLE_STATUS_FAIL;
    }

    int *passint = cmd->payload.make_max_cmd.passint;

    int batch = passint[0];
    int h = passint[1];
    int w = passint[2];
    int c = passint[3];
    int size = passint[4];
    int stride = passint[5];
    int padding = passint[6];

    layer_TA lta = make_maxpool_layer_TA(batch, h, w, c, size, stride, padding);
    netta.layers[netnum] = lta;
    netnum++;

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * MAKE_DROP_CMD handler: build a dropout layer and store it in
 * netta.layers[netnum].
 *
 * passint layout: [0] batch [1] inputs [2] w [3] h [4] c;  passfloat: [0] probability
 *
 * If this is the first TEE layer (netnum == 0), its output/delta are seeded
 * from the host buffers net_prev_output / net_prev_delta.  Otherwise it
 * shares output/delta with the previous TEE layer.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if netta.layers has no
 *         free slot left
 */
static pebble_return_code_t make_dropout_layer_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    (void)resp;

    /* netta.layers holds netta.n entries (the last is used as
     * netta.layers[netta.n - 1] elsewhere); never write past it. */
    if (netnum < 0 || netnum >= netta.n) {
        PEBBLE_LOG("make_dropout_layer: layer index %d out of range\n", netnum);
        return PEBBLE_STATUS_FAIL;
    }

    int *passint = cmd->payload.make_drop_cmd.passint;
    float *passfloat = cmd->payload.make_drop_cmd.passfloat;
    float *net_prev_output = (float *)cmd->payload.make_drop_cmd.net_prev_output.buf;
    float *net_prev_delta = (float *)cmd->payload.make_drop_cmd.net_prev_delta.buf;

    int batch = passint[0];
    int inputs = passint[1];
    int w = passint[2];
    int h = passint[3];
    int c = passint[4];
    float probability = passfloat[0];

    layer_TA lta = make_dropout_layer_TA_new(batch, inputs, probability, w, h, c, netnum);

    if (netnum == 0) {
        /* Bound each copy by its own host buffer and by the layer's capacity;
         * previously both copies used the output buffer's host-supplied length. */
        size_t out_count = cmd->payload.make_drop_cmd.net_prev_output.len / sizeof(float);
        size_t delta_count = cmd->payload.make_drop_cmd.net_prev_delta.len / sizeof(float);
        /* NOTE(review): assumes make_dropout_layer_TA_new() sizes output/delta
         * as batch * inputs floats (as darknet's dropout layer does) — confirm. */
        size_t capacity = (batch > 0 && inputs > 0) ? (size_t)batch * (size_t)inputs : 0;
        size_t out_copy = out_count < capacity ? out_count : capacity;
        size_t delta_copy = delta_count < out_copy ? delta_count : out_copy;

        for (size_t z = 0; z < out_copy; z++) {
            lta.output[z] = net_prev_output[z];
        }
        for (size_t z = 0; z < delta_copy; z++) {
            lta.delta[z] = net_prev_delta[z];
        }
    } else {
        lta.output = netta.layers[netnum - 1].output;
        lta.delta = netta.layers[netnum - 1].delta;
    }

    netta.layers[netnum] = lta;
    netnum++;

    return PEBBLE_STATUS_SUCCESS;
}


/*
 * MAKE_CONNECTED_CMD handler: build a fully-connected layer and store it in
 * netta.layers[netnum].
 *
 * passarg layout: [0] batch [1] inputs [2] outputs [3] batch_normalize [4] adam
 * actv.buf: activation name passed to get_activation_TA().
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if netta.layers has no
 *         free slot left
 */
static pebble_return_code_t make_connected_layer_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    (void)resp;

    /* netta.layers holds netta.n entries (the last is used as
     * netta.layers[netta.n - 1] elsewhere); never write past it. */
    if (netnum < 0 || netnum >= netta.n) {
        PEBBLE_LOG("make_connected_layer: layer index %d out of range\n", netnum);
        return PEBBLE_STATUS_FAIL;
    }

    int *passarg = cmd->payload.make_connected_cmd.passarg;
    int batch = passarg[0];
    int inputs = passarg[1];
    int outputs = passarg[2];
    int batch_normalize = passarg[3];
    int adam = passarg[4];

    /* NOTE(review): host-supplied name; confirm it is NUL-terminated. */
    char *acti = (char *)cmd->payload.make_connected_cmd.actv.buf;
    ACTIVATION_TA activation = get_activation_TA(acti);

    layer_TA lta = make_connected_layer_TA_new(batch, inputs, outputs, activation, batch_normalize, adam);
    netta.layers[netnum] = lta;
    netnum++;

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * MAKE_SOFTMAX_CMD handler: build a softmax layer and store it in
 * netta.layers[netnum].
 *
 * passint layout: [0] batch [1] inputs [2] groups [3] w [4] h [5] c
 *                 [6] spatial [7] noloss;  passflo: temperature
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if netta.layers has no
 *         free slot left
 */
static pebble_return_code_t make_softmax_layer_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    (void)resp;

    /* netta.layers holds netta.n entries (the last is used as
     * netta.layers[netta.n - 1] elsewhere); never write past it. */
    if (netnum < 0 || netnum >= netta.n) {
        PEBBLE_LOG("make_softmax_layer: layer index %d out of range\n", netnum);
        return PEBBLE_STATUS_FAIL;
    }

    int *passint = cmd->payload.make_softmax_cmd.passint;
    int batch = passint[0];
    int inputs = passint[1];
    int groups = passint[2];
    int w = passint[3];
    int h = passint[4];
    int c = passint[5];
    int spatial = passint[6];
    int noloss = passint[7];
    float temperature = cmd->payload.make_softmax_cmd.passflo;

    layer_TA lta = make_softmax_layer_TA_new(batch, inputs, groups, temperature, w, h, c, spatial, noloss);
    netta.layers[netnum] = lta;
    netnum++;

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * MAKE_COST_CMD handler: build a cost layer, store it in netta.layers[netnum]
 * and (re)allocate the ground-truth buffer netta_truth (inputs * batch floats)
 * that net_truth_TA_params() later fills.
 *
 * passint layout: [0] batch [1] inputs
 * passflo layout: [0] scale [1] ratio [2] noobject_scale [3] thresh
 * passcost.buf:   cost type name passed to get_cost_type_TA()
 *
 * @return PEBBLE_STATUS_SUCCESS; PEBBLE_STATUS_FAIL if there is no free layer
 *         slot, the truth size is invalid, or the allocation fails (in those
 *         cases no network state is modified)
 */
static pebble_return_code_t make_cost_layer_TA_params(tci_message_t *cmd,
                                       tci_message_t *resp)
{
    (void)resp;

    /* netta.layers holds netta.n entries (the last is used as
     * netta.layers[netta.n - 1] elsewhere); never write past it. */
    if (netnum < 0 || netnum >= netta.n) {
        PEBBLE_LOG("make_cost_layer: layer index %d out of range\n", netnum);
        return PEBBLE_STATUS_FAIL;
    }

    int *passint = cmd->payload.make_cost_cmd.passint;
    int batch = passint[0];
    int inputs = passint[1];

    float *passflo = cmd->payload.make_cost_cmd.passflo;
    float scale = passflo[0];
    float ratio = passflo[1];
    float noobject_scale = passflo[2];
    float thresh = passflo[3];

    char *cost_t = (char *)cmd->payload.make_cost_cmd.passcost.buf;

    /* inputs * batch is host-controlled and used as an allocation size: reject
     * non-positive values and products that would overflow size_t (the old
     * int multiplication could overflow, which is UB). */
    if (batch <= 0 || inputs <= 0 ||
        (size_t)inputs > ((size_t)-1 / sizeof(float)) / (size_t)batch) {
        PEBBLE_LOG("make_cost_layer: invalid truth size (batch=%d, inputs=%d)\n", batch, inputs);
        return PEBBLE_STATUS_FAIL;
    }

    /* Allocate before touching network state so a failure leaves it unchanged. */
    float *truth = OPENSSL_malloc((size_t)inputs * (size_t)batch * sizeof(float));
    if (truth == NULL) {
        PEBBLE_LOG("make_cost_layer: out of memory\n");
        return PEBBLE_STATUS_FAIL;
    }

    COST_TYPE_TA cost_type = get_cost_type_TA(cost_t);

    layer_TA lta = make_cost_layer_TA_new(batch, inputs, cost_type, scale, ratio, noobject_scale, thresh);
    netta.layers[netnum] = lta;
    netnum++;

    /* Replace (rather than leak) a truth buffer from an earlier MAKE_COST_CMD,
     * and do not leave netta.truth pointing at the freed block; it is set
     * again by the next NET_TRUTH_CMD. */
    if (netta.truth == netta_truth) {
        netta.truth = NULL;
    }
    OPENSSL_free(netta_truth);
    netta_truth = truth;

    return PEBBLE_STATUS_SUCCESS;
}


/*
 * TRANS_WEI_CMD handler: hand a host-supplied weight vector to
 * load_weights_TA() for layer layer_i.
 *
 * passint layout: [0] length (float count) [1] layer_i [2] additional
 * type: weight kind selector forwarded unchanged to load_weights_TA().
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL when length is negative
 *         or exceeds the number of floats actually present in vec
 */
static pebble_return_code_t transfer_weights_TA_params(tci_message_t *cmd,
                                             tci_message_t *resp)
{
    float *vec = (float *)cmd->payload.trans_wei_cmd.vec.buf;
    size_t vec_count = cmd->payload.trans_wei_cmd.vec.len / sizeof(float);

    int *passint = cmd->payload.trans_wei_cmd.passint;
    int length = passint[0];
    int layer_i = passint[1];
    int additional = passint[2];

    char type = cmd->payload.trans_wei_cmd.type;

    (void)resp;

    /* length and the buffer are both host-controlled; make sure the claimed
     * element count is actually backed by the supplied buffer.
     * NOTE(review): assumes load_weights_TA() reads `length` floats from vec. */
    if (length < 0 || (size_t)length > vec_count) {
        PEBBLE_LOG("transfer_weights: length %d exceeds supplied buffer\n", length);
        return PEBBLE_STATUS_FAIL;
    }

    load_weights_TA(vec, length, layer_i, type, additional);

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * SAVE_WEI_CMD handler: obtain `length` weight values for layer layer_i from
 * save_weights_TA() and return them in resp->payload.save_wei_resp.weights_back.
 *
 * passint layout: [0] length (float count) [1] layer_i
 * type: weight kind selector forwarded unchanged to save_weights_TA().
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL for an invalid length or
 *         when the staging buffer cannot be allocated
 */
static pebble_return_code_t save_weights_TA_params(tci_message_t *cmd,
                                             tci_message_t *resp)
{
    float *vec = (float *)resp->payload.save_wei_resp.weights_back.buf;

    int *passint = cmd->payload.save_wei_cmd.passint;
    int length = passint[0];
    int layer_i = passint[1];

    char type = cmd->payload.save_wei_cmd.type;

    /* Negative lengths would turn into huge sizes once converted to size_t. */
    if (length < 0 || (size_t)length > (size_t)-1 / sizeof(float)) {
        PEBBLE_LOG("save_weights: invalid length %d\n", length);
        return PEBBLE_STATUS_FAIL;
    }
    if (length == 0) {
        resp->payload.save_wei_resp.weights_back.len = 0;
        return PEBBLE_STATUS_SUCCESS;
    }

    /* Stage the values in TEE-private memory and copy only the finished result
     * into the response buffer.
     * NOTE(review): presumably so that no pre-encryption value ever lands in
     * the response buffer — confirm against save_weights_TA().
     * NOTE(review): the capacity of weights_back is not checked here. */
    float *weights_encrypted = OPENSSL_malloc(sizeof(float) * (size_t)length);
    if (weights_encrypted == NULL) {
        PEBBLE_LOG("save_weights: out of memory\n");
        return PEBBLE_STATUS_FAIL;
    }
    save_weights_TA(weights_encrypted, length, layer_i, type);

    for (int z = 0; z < length; z++) {
        vec[z] = weights_encrypted[z];
    }
    resp->payload.save_wei_resp.weights_back.len = sizeof(float) * length;

    OPENSSL_free(weights_encrypted);
    return PEBBLE_STATUS_SUCCESS;
}



/*
 * FORWARD_CMD handler: bind the host-supplied input buffer and train flag to
 * the in-TEE network and run forward_network_TA().
 *
 * NOTE(review): netta.input aliases the command buffer (no copy is made);
 * confirm the buffer remains valid for the whole forward pass.
 *
 * @param[in] cmd  request carrying payload.forward_cmd
 * @param[in] resp unused
 * @return PEBBLE_STATUS_SUCCESS
 */
static pebble_return_code_t forward_network_TA_params(tci_message_t *cmd,
                                          tci_message_t *resp)
{
    int train_flag = cmd->payload.forward_cmd.net_train;

    (void)resp;

    netta.input = (float *)cmd->payload.forward_cmd.net_input.buf;
    netta.train = train_flag;

    if (debug_summary_com == 1) {
        int input_count = cmd->payload.forward_cmd.net_input.len / sizeof(float);
        summary_array("forward_network / net.input", netta.input, input_count);
    }

    forward_network_TA();
    return PEBBLE_STATUS_SUCCESS;
}

//
// static pebble_return_code_t forward_network_TA_params(tci_message_t *cmd,
//                                           tci_message_t *resp)
// {
//     uint32_t exp_param_types = TEE_PARAM_TYPES( TEE_PARAM_TYPE_MEMREF_INPUT,
//                                                TEE_PARAM_TYPE_VALUE_INPUT,
//                                                TEE_PARAM_TYPE_NONE,
//                                                TEE_PARAM_TYPE_NONE);
//     //TEE_PARAM_TYPE_VALUE_INPUT
//
//     //DMSG("has been called");
//
//     if (param_types != exp_param_types)
//     return TEE_ERROR_BAD_PARAMETERS;
//
//     float *net_input = params[0].memref.buffer;
//     int net_train = params[1].value.a;
//
//     netta.input = net_input;
//     netta.train = net_train;
//
//     forward_network_TA();
//
//     return PEBBLE_STATUS_SUCCESS;
// }


/*
 * FORWARD_BACK_CMD handler: copy the output of the last TEE layer
 * (netta.layers[netta.n - 1]) into the response buffer.
 *
 * The element count is taken from the request's net_input_back length.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if the network has no
 *         layers
 */
static pebble_return_code_t forward_network_back_TA_params(tci_message_t *cmd,
                                           tci_message_t *resp)
{
    float *out = (float *)resp->payload.forward_back_resp.net_input_back.buf;
    int buffersize = cmd->payload.forward_back_cmd.net_input_back.len / sizeof(float);

    /* With no layers the "last layer" index would be -1. */
    if (netta.n <= 0) {
        PEBBLE_LOG("forward_network_back: network has no layers\n");
        return PEBBLE_STATUS_FAIL;
    }
    float *last_output = netta.layers[netta.n - 1].output;

    /* NOTE(review): buffersize is host-controlled and not checked against the
     * size of the last layer's output buffer. */
    for (int z = 0; z < buffersize; z++) {
        out[z] = last_output[z];
    }
    resp->payload.forward_back_resp.net_input_back.len = sizeof(float) * buffersize;

    if (debug_summary_com == 1) {
        summary_array("forward_network_back / l_pp2.output", last_output, buffersize);
    }
    return PEBBLE_STATUS_SUCCESS;
}


//
// static pebble_return_code_t backward_network_TA_params(tci_message_t *cmd,
//                                            tci_message_t *resp)
// {
//     uint32_t exp_param_types = TEE_PARAM_TYPES( TEE_PARAM_TYPE_MEMREF_OUTPUT,
//                                                TEE_PARAM_TYPE_MEMREF_OUTPUT,
//                                                TEE_PARAM_TYPE_NONE,
//                                                TEE_PARAM_TYPE_NONE);
//     if (param_types != exp_param_types)
//         return TEE_ERROR_BAD_PARAMETERS;
//     //float *ltaoutput_diff = diff_private(lta.output, lta.outputs*lta.batch, 4.0f, 4.0f);
//     //float *ltadelta_diff = diff_private(lta.delta, lta.outputs*lta.batch, 4.0f, 4.0f);
//     //IMSG("diff");
//
//
//     float *params0 = params[0].memref.buffer;
//     float *params1 = params[1].memref.buffer;
//     float *buffersize = params[0].memref.size / sizeof(float);
//     for(int z=0; z<buffersize; z++){
//         params0[z] = ta_net_input[z];
//         params1[z] = ta_net_delta[z];
//     }
//
//     //free(ltaoutput_diff);
//     //free(ltadelta_diff);
//     return PEBBLE_STATUS_SUCCESS;
// }



/*
 * BACKWARD_CMD handler: set the train flag and run backward_network_TA() on
 * the host-supplied buffer (logged as the l_pp1 output when debugging).
 *
 * @param[in] cmd  request carrying payload.backward_cmd
 * @param[in] resp unused
 * @return PEBBLE_STATUS_SUCCESS
 */
static pebble_return_code_t backward_network_TA_params(tci_message_t *cmd,
                                           tci_message_t *resp)
{
    float *host_input = (float *)cmd->payload.backward_cmd.net_input.buf;
    int train_flag = cmd->payload.backward_cmd.net_train;

    (void)resp;
    netta.train = train_flag;

    if (debug_summary_com == 1) {
        summary_array("backward_network / l_pp1.output", host_input,
                      cmd->payload.backward_cmd.net_input.len / sizeof(float));
    }

    backward_network_TA(host_input);
    return PEBBLE_STATUS_SUCCESS;
}



/*
 * BACKWARD_ADD_CMD handler: return ta_net_input and ta_net_delta to the host
 * in the response buffers net_input_back / net_delta_back.
 *
 * The element count is taken from the request's net_input_back length.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if either source
 *         buffer is not available
 */
static pebble_return_code_t backward_network_TA_addidion_params(tci_message_t *cmd,
                                           tci_message_t *resp)
{
    float *out_input = (float *)resp->payload.backward_add_resp.net_input_back.buf;
    float *out_delta = (float *)resp->payload.backward_add_resp.net_delta_back.buf;
    int buffersize = cmd->payload.backward_add_cmd.net_input_back.len / sizeof(float);  /* from host */

    /* Guard against a request that arrives before the buffers have been
     * produced: reading through a NULL pointer would crash the TA. */
    if (ta_net_input == NULL || ta_net_delta == NULL) {
        PEBBLE_LOG("backward_network_addidion: no backward results available\n");
        return PEBBLE_STATUS_FAIL;
    }

    /* NOTE(review): buffersize is host-controlled and not checked against the
     * sizes of ta_net_input / ta_net_delta. */
    for (int z = 0; z < buffersize; z++) {
        out_input[z] = ta_net_input[z];
        out_delta[z] = ta_net_delta[z];
    }
    resp->payload.backward_add_resp.net_input_back.len = buffersize * sizeof(float);
    resp->payload.backward_add_resp.net_delta_back.len = buffersize * sizeof(float);

    if (debug_summary_com == 1) {
        summary_array("backward_network_addidion / l_pp1.output", ta_net_input, buffersize);
        summary_array("backward_network_addidion / l_pp1.delta", ta_net_delta, buffersize);
    }
    return PEBBLE_STATUS_SUCCESS;
}


/*
 * BACKWARD_BACK_CMD handler: overwrite the last TEE layer's output and delta
 * (netta.layers[netta.n - 1]) with the host-supplied net_input / net_delta.
 *
 * Each copy is bounded by its own host buffer length.  The delta copy is
 * additionally limited to the output count, as before.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if the network has no
 *         layers
 */
static pebble_return_code_t backward_network_back_TA_params(tci_message_t *cmd,
                                           tci_message_t *resp)
{
    float *host_output = (float *)cmd->payload.backward_back_cmd.net_input.buf;
    float *host_delta = (float *)cmd->payload.backward_back_cmd.net_delta.buf;
    size_t output_count = cmd->payload.backward_back_cmd.net_input.len / sizeof(float);
    size_t delta_count = cmd->payload.backward_back_cmd.net_delta.len / sizeof(float);
    /* The delta used to be read with the output buffer's length, over-reading
     * net_delta whenever it was shorter. */
    size_t delta_copy = delta_count < output_count ? delta_count : output_count;

    (void)resp;

    /* With no layers the "last layer" index would be -1. */
    if (netta.n <= 0) {
        PEBBLE_LOG("backward_network_back: network has no layers\n");
        return PEBBLE_STATUS_FAIL;
    }
    float *last_output = netta.layers[netta.n - 1].output;
    float *last_delta = netta.layers[netta.n - 1].delta;

    /* NOTE(review): the counts are host-controlled and not checked against the
     * size of the last layer's output/delta buffers. */
    for (size_t z = 0; z < output_count; z++) {
        last_output[z] = host_output[z];
    }
    for (size_t z = 0; z < delta_copy; z++) {
        last_delta[z] = host_delta[z];
    }

    if (debug_summary_com == 1) {
        summary_array("backward_network_back / l_pp2.output", last_output, (int)output_count);
        summary_array("backward_network_back / l_pp2.delta", last_delta, (int)output_count);
    }

    return PEBBLE_STATUS_SUCCESS;
}



/*
 * BACKWARD_BACK_ADD_CMD handler: copy the last TEE layer's output
 * (netta.layers[netta.n - 1].output) into the response buffer net_input_back.
 *
 * The element count is taken from the request's net_input_back length.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if the network has no
 *         layers
 */
static pebble_return_code_t backward_network_back_TA_addidion_params(tci_message_t *cmd,
                                           tci_message_t *resp)
{
    float *out = (float *)resp->payload.backward_back_add_resp.net_input_back.buf;
    int buffersize = cmd->payload.backward_back_add_cmd.net_input_back.len / sizeof(float);

    /* With no layers the "last layer" index would be -1. */
    if (netta.n <= 0) {
        PEBBLE_LOG("backward_network_back_addidion: network has no layers\n");
        return PEBBLE_STATUS_FAIL;
    }
    float *last_output = netta.layers[netta.n - 1].output;

    /* NOTE(review): buffersize is host-controlled and not checked against the
     * size of the last layer's output buffer. */
    for (int z = 0; z < buffersize; z++) {
        out[z] = last_output[z];
    }
    /* Report the length on the response message, like every other handler.
     * The original wrote it through cmd->payload, which only reaches the
     * caller if request and response happen to share storage. */
    resp->payload.backward_back_add_resp.net_input_back.len = sizeof(float) * buffersize;

    if (debug_summary_com == 1) {
        summary_array("backward_network_back_addidion / l_pp2.output", last_output, buffersize);
    }
    return PEBBLE_STATUS_SUCCESS;
}
//
// static pebble_return_code_t backward_network_back_TA_params(tci_message_t *cmd,
//                                            tci_message_t *resp)
// {
//
//
//     uint32_t exp_param_types = TEE_PARAM_TYPES( TEE_PARAM_TYPE_MEMREF_INPUT,
//                                                TEE_PARAM_TYPE_MEMREF_INPUT,
//                                                TEE_PARAM_TYPE_VALUE_INPUT,
//                                                TEE_PARAM_TYPE_NONE);
//     //TEE_PARAM_TYPE_VALUE_INPUT
//
//     //DMSG("has been called");
//
//     if (param_types != exp_param_types)
//     return TEE_ERROR_BAD_PARAMETERS;
//
//     float *ca_net_input = params[0].memref.buffer;
//     float *ca_net_delta = params[1].memref.buffer;
//     int net_train = params[2].value.a;
//
//     netta.train = net_train;
//
//     backward_network_TA(ca_net_input, ca_net_delta);
//
//     return PEBBLE_STATUS_SUCCESS;
// }

/*
 * UPDATE_CMD handler: pack the optimizer settings sent by the host into an
 * update_args_TA and apply one update step via update_network_TA().
 *
 * passint layout: [0] batch [1] adam [2] t
 * passflo layout: [0] learning_rate [1] momentum [2] decay [3] B1 [4] B2 [5] eps
 *
 * @param[in] cmd  request carrying payload.update_cmd
 * @param[in] resp unused
 * @return PEBBLE_STATUS_SUCCESS
 */
static pebble_return_code_t update_network_TA_params(tci_message_t *cmd,
                                         tci_message_t *resp)
{
    const int *ints = cmd->payload.update_cmd.passint;
    const float *floats = cmd->payload.update_cmd.passflo;
    update_args_TA args;

    (void)resp;

    /* Integer settings. */
    args.batch = ints[0];
    args.adam = ints[1];
    args.t = ints[2];

    /* Floating-point optimizer coefficients. */
    args.learning_rate = floats[0];
    args.momentum = floats[1];
    args.decay = floats[2];
    args.B1 = floats[3];
    args.B2 = floats[4];
    args.eps = floats[5];

    update_network_TA(args);
    return PEBBLE_STATUS_SUCCESS;
}

/*
 * NET_TRUTH_CMD handler: copy the host-supplied ground-truth values into
 * netta_truth and point netta.truth at it.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if netta_truth has not
 *         been allocated yet (it is allocated by make_cost_layer_TA_params())
 */
static pebble_return_code_t net_truth_TA_params(tci_message_t *cmd,
                                         tci_message_t *resp)
{
    (void)resp;

    /* A truth message sent before MAKE_COST_CMD used to dereference NULL. */
    if (netta_truth == NULL) {
        PEBBLE_LOG("net_truth: truth buffer not allocated (no cost layer yet)\n");
        return PEBBLE_STATUS_FAIL;
    }

    size_t count = cmd->payload.net_truth_cmd.net_truth.len / sizeof(float);
    float *src = (float *)cmd->payload.net_truth_cmd.net_truth.buf;

    /* NOTE(review): count is host-controlled and is not checked against the
     * inputs * batch floats allocated for netta_truth in
     * make_cost_layer_TA_params(); a capacity check still needs to be added. */
    for (size_t z = 0; z < count; z++) {
        netta_truth[z] = src[z];
    }
    netta.truth = netta_truth;

    return PEBBLE_STATUS_SUCCESS;
}

/*
 * CALC_LOSS_CMD handler: forward the two integers from the request to
 * calc_network_loss_TA().
 *
 * passint layout: [0] n [1] batch
 *
 * @param[in] cmd  request carrying payload.calc_loss_cmd
 * @param[in] resp unused
 * @return PEBBLE_STATUS_SUCCESS
 */
static pebble_return_code_t calc_network_loss_TA_params(tci_message_t *cmd,
                                         tci_message_t *resp)
{
    const int *ints = cmd->payload.calc_loss_cmd.passint;

    (void)resp;
    calc_network_loss_TA(ints[0] /* n */, ints[1] /* batch */);

    return PEBBLE_STATUS_SUCCESS;
}


/*
 * OUTPUT_RETURN_CMD handler: copy ta_net_output into the response buffer and
 * release it.
 *
 * When norm_output is set, the output is first turned into a one-hot vector:
 * 1.0 at the index of the largest value above 0.00001 (index 0 if none is),
 * 0.0 everywhere else.
 *
 * @return PEBBLE_STATUS_SUCCESS, or PEBBLE_STATUS_FAIL if there is no output
 *         to return (e.g. it was already returned)
 */
static pebble_return_code_t net_output_return_TA_params(tci_message_t *cmd,
                                              tci_message_t *resp)
{
    float *out = (float *)resp->payload.output_return_resp.net_output_back.buf;
    int buffersize = cmd->payload.output_return_cmd.net_output_back.len / sizeof(float);

    /* ta_net_output is freed (and cleared) below; without this guard a
     * repeated request read freed memory and freed it a second time. */
    if (ta_net_output == NULL) {
        PEBBLE_LOG("net_output_return: no network output available\n");
        return PEBBLE_STATUS_FAIL;
    }

    if (norm_output && buffersize > 0) {
        /* Remove confidence scores: keep only the arg-max as 1.0. */
        float maxconf = 0.00001f;
        int maxidx = 0;

        for (int z = 0; z < buffersize; z++) {
            if (ta_net_output[z] > maxconf) {
                maxconf = ta_net_output[z];
                maxidx = z;
            }
            ta_net_output[z] = 0.0f;
        }
        ta_net_output[maxidx] = 1.00f;
    }

    /* NOTE(review): buffersize is host-controlled and not checked against the
     * size of ta_net_output. */
    for (int z = 0; z < buffersize; z++) {
        out[z] = ta_net_output[z];
    }
    resp->payload.output_return_resp.net_output_back.len = sizeof(float) * buffersize;

    OPENSSL_free(ta_net_output);
    ta_net_output = NULL;

    return PEBBLE_STATUS_SUCCESS;
}

/**
 * @brief
 * process_cmd
 * Dispatch a command received from the normal world to its handler.
 *
 * @param[in]  commandId - command id (one of the *_CMD constants)
 * @param[in]  tci_req   - request message carrying the command payload
 * @param[out] tci_resp  - response message, filled in by handlers that return data
 *
 * @return the handler's PEBBLE status code, or PEBBLE_STATUS_FAIL for an
 *         unknown command id
 */
pebble_return_code_t process_cmd(uint32_t commandId, tci_message_t *tci_req, tci_message_t *tci_resp) {
        PEBBLE_LOG("process_cmd()");

        pebble_return_code_t ret = PEBBLE_STATUS_SUCCESS;
#if 0
        // Device integrity check
        device_status = pebble_ICCC_check();
        if (device_status != PEBBLE_DEVICE_OK) {
                PEBBLE_LOG("pebble_ICCC_check FAIL");
        }
#endif
        switch (commandId) {
                case MAKE_NETWORK_CMD:
                        return make_netowork_TA_params(tci_req, tci_resp);

                case WORKSPACE_NETWORK_CMD:
                        return update_net_agrv_TA_params(tci_req, tci_resp);

                case MAKE_CONV_CMD:
                        return make_convolutional_layer_TA_params(tci_req, tci_resp);

                case MAKE_MAX_CMD:
                        return make_maxpool_layer_TA_params(tci_req, tci_resp);

                case MAKE_DROP_CMD:
                        return make_dropout_layer_TA_params(tci_req, tci_resp);

                case MAKE_CONNECTED_CMD:
                        return make_connected_layer_TA_params(tci_req, tci_resp);

                case MAKE_SOFTMAX_CMD:
                        return make_softmax_layer_TA_params(tci_req, tci_resp);

                case MAKE_COST_CMD:
                        return make_cost_layer_TA_params(tci_req, tci_resp);

                case TRANS_WEI_CMD:
                        return transfer_weights_TA_params(tci_req, tci_resp);

                case SAVE_WEI_CMD:
                        return save_weights_TA_params(tci_req, tci_resp);

                case FORWARD_CMD:
                        return forward_network_TA_params(tci_req, tci_resp);

                case BACKWARD_CMD:
                        return backward_network_TA_params(tci_req, tci_resp);

                case BACKWARD_ADD_CMD:
                        return backward_network_TA_addidion_params(tci_req, tci_resp);

                case UPDATE_CMD:
                        return update_network_TA_params(tci_req, tci_resp);

                case NET_TRUTH_CMD:
                        return net_truth_TA_params(tci_req, tci_resp);

                case CALC_LOSS_CMD:
                        return calc_network_loss_TA_params(tci_req, tci_resp);

                case OUTPUT_RETURN_CMD:
                        return net_output_return_TA_params(tci_req, tci_resp);

                case FORWARD_BACK_CMD:
                        return forward_network_back_TA_params(tci_req, tci_resp);

                case BACKWARD_BACK_CMD:
                        return backward_network_back_TA_params(tci_req, tci_resp);

                case BACKWARD_BACK_ADD_CMD:
                        return backward_network_back_TA_addidion_params(tci_req, tci_resp);

                default:
                        PEBBLE_LOG("received unknown command!");
                        /* commandId is uint32_t: %d was a format/type mismatch. */
                        PEBBLE_LOG_DEBUG("received unknown command: %u", (unsigned int)commandId);
                        /* Unknown command ID */
                        ret = PEBBLE_STATUS_FAIL;
                        break;
        }

        return ret;
}
