#include <stdlib.h>
#include <string.h>
#include "Log.h"
#include "commonConfig.h"
#include "Tlv.h"
//#include "TLVext.h"

// Byte count that encodeLength()/decodeLength() report for a length field
// whenever the container's start tag is not TLV_START (the form handled with
// SET_UINT32/GET_UINT32): twice the regular LENGTH_FIELD_SIZE.
#define LENGTH_FIELD_SIZE_EX (LENGTH_FIELD_SIZE * 2)

namespace vendor {
namespace samsung {
namespace hardware {
namespace security {
namespace drk {

/** Builds an empty item: TLV_START tag, zero length and no value buffer. */
ITEM::ITEM()
    : Tag(TLV_START),
      Length(0),
      Value(NULL)
{
}

/**
 * Copy constructor: starts from the empty state and then duplicates the
 * source item's tag, length and value bytes through set().
 * The status returned by set() is not inspected here.
 */
ITEM::ITEM(const ITEM& in)
    : Tag(TLV_START),
      Length(0),
      Value(NULL)
{
    const TLVTAG   src_tag = in.tag();
    const uint32_t src_len = in.length();
    uint8_t       *src_val = in.value();

    set(src_tag, src_len, src_val);
}

/**
 * Destructor: overwrites the stored value bytes with zeros, releases the
 * buffer and puts the members back into the empty state.
 */
ITEM::~ITEM()
{
    if (Value != NULL) {
        // Scrub the payload before handing the memory back to the allocator.
        memset(Value, 0x0, Length);
        free(Value);
    }
    Value  = NULL;
    Length = 0;
    Tag    = TLV_START;
}

/**
 * Replaces the item's content with a private copy of the given bytes.
 *
 * The copy is allocated with one extra zeroed byte after the value. The new
 * buffer is fully prepared before the previous one is wiped and released, so
 * the item is left untouched when the call fails and `v` may even point into
 * the item's current buffer.
 *
 * @param t  tag to store.
 * @param l  number of bytes to copy from `v`.
 * @param v  source bytes; may be NULL only when `l` is 0.
 * @return 0 on success, -1 when `l` is too large to allocate `l + 1` bytes,
 *         when `v` is NULL with a non-zero `l`, or when allocation fails.
 */
int32_t ITEM::set(TLVTAG t, uint32_t l, uint8_t *v)
{
    const uint32_t alloc_len = l + 1;   // value bytes plus one spare zero byte
    uint8_t *fresh = NULL;

    if (alloc_len == 0) {
        // l == 0xFFFFFFFF: l + 1 wrapped around, the buffer would be too small.
        return -1;
    }
    if (v == NULL && l != 0) {
        return -1;
    }

    fresh = (uint8_t *)calloc(alloc_len, sizeof(uint8_t));
    if (fresh == NULL) {
        return -1;
    }
    if (l != 0) {
        memcpy(fresh, v, l);
    }

    // Only now discard the previous content (scrubbed before release).
    if (Value) {
        memset(Value, 0x0, Length);
        free(Value);
    }
    Value  = fresh;
    Length = l;
    Tag    = t;
    return 0;
}

/** Returns the item's internal value buffer (not a copy); NULL when none is held. */
uint8_t *ITEM::value() const { return Value; }
/** Returns the number of value bytes recorded for this item. */
uint32_t ITEM::length() const { return Length; }
/** Returns the tag recorded for this item. */
TLVTAG ITEM::tag() const { return Tag; }

/**=========================================
  * TLV Class
  ===========================================*/
/**
 * Writes a length field at `in` in the form selected by the start tag.
 *
 * With start tag TLV_START the value is narrowed to uint16_t and stored via
 * SET_UINT16_FROM_U16; otherwise the full value is stored via SET_UINT32.
 *
 * @param in        destination for the length field.
 * @param in_value  length to store.
 * @return number of bytes the field occupies: LENGTH_FIELD_SIZE or
 *         LENGTH_FIELD_SIZE_EX.
 */
int32_t TLV::encodeLength(uint8_t *in, uint32_t in_value)
{
    if (start_tag != TLV_START) {
        SET_UINT32(in, 0, in_value);
        return LENGTH_FIELD_SIZE_EX;
    }

    // NOTE(review): values above 0xFFFF are silently narrowed in this mode.
    uint16_t narrow = (uint16_t)in_value;
    SET_UINT16_FROM_U16(in, 0, narrow);
    return LENGTH_FIELD_SIZE;
}

/**
 * Reads a length field from `in` in the form selected by the start tag.
 *
 * @param in         position of the length field.
 * @param out_value  receives the decoded length (GET_UINT16 when the start
 *                   tag is TLV_START, GET_UINT32 otherwise).
 * @return number of bytes the field occupies: LENGTH_FIELD_SIZE or
 *         LENGTH_FIELD_SIZE_EX.
 */
int32_t TLV::decodeLength(uint8_t *in, uint32_t *out_value)
{
    if (start_tag != TLV_START) {
        *out_value = GET_UINT32(in, 0);
        return LENGTH_FIELD_SIZE_EX;
    }

    *out_value = GET_UINT16(in, 0);
    return LENGTH_FIELD_SIZE;
}

/**
 * Appends a copy of (tag, value, length) to the item list without touching
 * start_length.
 *
 * @return the status of ITEM::set(); the item is appended only when it is 0.
 */
int32_t TLV::setTlv(TLVTAG tag, uint8_t *value, uint32_t length)
{
    ITEM entry;
    const int32_t status = entry.set(tag, length, value);

    if (status != 0) {
        return status;
    }
    tlv.push_back(entry);
    return status;
}

/** Creates an empty container whose start tag is TLV_START. */
TLV::TLV()
    : start_tag(TLV_START),
      start_length(0)
{
}
/** Creates an empty container using `mode` as the start tag (stored as given, not validated). */
TLV::TLV(TLVTAG mode)
    : start_tag(mode),
      start_length(0)
{
}
/** Destructor: drops every stored item. */
TLV::~TLV()
{
    tlv.clear();
}

int32_t TLV::encode(uint8_t **out, uint32_t *out_len)
{
    uint32_t total_length = 0,
             pos = 0;

    total_length = getTotalLength();
    *out = (uint8_t *)calloc(total_length + 1, sizeof(uint8_t));
    if (*out == NULL) {
        return ERR_TLV_ENCODE_FAILED;
    }

    *out[pos++] = (uint8_t)start_tag;
    pos += encodeLength(*out + pos, start_length);
    vector<ITEM>::iterator it;
    for (it = tlv.begin(); it != tlv.end() && total_length >= pos; it++) {
        (*out)[pos++] = (uint8_t)it->tag();
        pos += encodeLength(*out + pos, it->length());
        memcpy(*out + pos, it->value(), it->length());
        pos += it->length();
    }
    *out_len = pos;
    return NOT_ERROR;
}

/**
 * Serializes the container and hands the bytes to `out` via Bytes::set().
 * The temporary buffer is zeroed and freed before returning.
 *
 * @return NOT_ERROR on success, otherwise the status of the raw encode().
 */
int32_t TLV::encode(Bytes& out)
{
    uint8_t  *raw = NULL;
    uint32_t  raw_len = 0;

    const int32_t status = encode(&raw, &raw_len);
    if (status != NOT_ERROR) {
        return status;
    }

    out.set(raw, raw_len);

    if (raw != NULL) {
        // Scrub the scratch copy before releasing it.
        memset(raw, 0x0, raw_len);
        free(raw);
        raw = NULL;
    }
    return NOT_ERROR;
}

int32_t TLV::decode(uint8_t *in, uint32_t in_len)
{
    int32_t  ret = NOT_ERROR;
    uint32_t typeLen = 0;
    uint32_t pos = 0;
    uint32_t total_length = 0;
    TLVTAG      t;
    uint32_t l = 0;

    if (in == NULL || in_len < 1) {
        return ERR_INVALID_ARGUMENT;
    }

    if (setMode((TLVTAG)in[pos++]) != NOT_ERROR) {
        return ERR_INVALID_ARGUMENT;
    }

    if (start_tag == TLV_START) {
        typeLen = TAGLENGTH_FIELD_SIZE;
    } else {
        typeLen = TAGLENGTH_FIELD_SIZE_EX;
    }

    if (in_len < typeLen) {
        return ERR_INVALID_ARGUMENT;
    }

    pos += decodeLength(in + pos, &start_length);
    total_length = getTotalLength();

    if (total_length > in_len) {
        start_length = 0;
        ret = ERR_TLV_DECODE_FAILED;
        goto end;
    }

    while ((in_len - typeLen) > pos) {
        t = (TLVTAG)in[pos++];
        pos += decodeLength(in + pos, &l);
        if ((l > in_len) || (pos > (in_len - l))) {
            ret = ERR_TLV_DECODE_FAILED;
            break;
        }
        setTlv(t, in + pos, l);
        pos += l;
        t = (TLVTAG)0; l = 0;
    }
end:
    return ret;
}

/**
 * Parses the content of `in` by forwarding its pointer and length to the
 * raw-buffer decode().
 *
 * @return the status of the raw-buffer decode().
 */
int32_t TLV::decode(Bytes& in)
{
    const int32_t status = decode((uint8_t *)in, in.length());
    return (status != NOT_ERROR) ? status : NOT_ERROR;
}

/**
 * Looks up the first item with the given tag and returns a freshly allocated
 * copy of its value (with one extra zeroed byte after it).
 *
 * @param tag     tag to search for.
 * @param value   receives the copy; the caller releases it with free().
 *                Left untouched when the tag is not found.
 * @param length  receives the value length on success.
 * @return 0 on success, -1 when an out-pointer is NULL, the tag is not
 *         present or allocation fails.
 */
int32_t TLV::get(TLVTAG tag, uint8_t **value, uint32_t *length)
{
    vector<ITEM>::iterator it;

    if (value == NULL || length == NULL) {
        return -1;
    }

    for (it = tlv.begin(); it != tlv.end(); ++it) {
        if (tag != it->tag()) {
            continue;
        }

        // Only the first matching item is considered.
        *value = (uint8_t *)calloc(it->length() + 1, sizeof(uint8_t));
        if (*value == NULL) {
            return -1;
        }
        if (it->length() != 0) {
            memcpy(*value, it->value(), it->length());
        }
        *length = it->length();
        return 0;
    }
    return -1;
}

/**
 * Copies the value of the first item with the given tag into a caller
 * supplied buffer.
 *
 * @param tag     tag to search for.
 * @param value   destination buffer.
 * @param length  in: capacity of `value`, which must be strictly greater than
 *                the item's length; out: the item's length on success.
 * @return 0 on success, -1 when the tag is not present, a pointer is NULL or
 *         the capacity is not larger than the item's length.
 */
int32_t TLV::get(TLVTAG tag, uint8_t *value, uint32_t *length)
{
    vector<ITEM>::iterator it;

    for (it = tlv.begin(); it != tlv.end(); it++) {
        if (it->tag() != tag) {
            continue;
        }

        // First match decides the outcome; later duplicates are ignored.
        if (value == NULL || length == NULL || *length <= it->length()) {
            return -1;
        }
        memcpy(value, it->value(), it->length());
        *length = it->length();
        return 0;
    }
    return -1;
}

/**
 * Fetches the value of the first item with the given tag into `value` via
 * Bytes::set(). The intermediate copy is zeroed and freed before returning.
 *
 * @return the status of the allocating get() overload (0 on success).
 */
int32_t TLV::get(TLVTAG tag, Bytes& value)
{
    uint8_t  *copy = NULL;
    uint32_t  copy_len = 0;

    const int32_t status = get(tag, &copy, &copy_len);
    if (status == 0) {
        value.set(copy, copy_len);
    }

    if (copy != NULL) {
        // Scrub the scratch copy before releasing it.
        memset(copy, 0x0, copy_len);
        free(copy);
        copy = NULL;
        copy_len = 0;
    }

    return status;
}

/**
 * Appends a copy of (tag, value, length) to the container and grows
 * start_length by the item's header size plus its value length.
 *
 * start_length is only updated once the item has actually been stored, so a
 * failed call leaves the container unchanged.
 *
 * @param tag     tag of the new item.
 * @param value   bytes to copy; must not be NULL.
 * @param length  number of bytes at `value`; must not be 0.
 * @return 0 on success, -2 for a zero length or NULL value, -1 when the
 *         running length would overflow, otherwise the non-zero status of
 *         ITEM::set().
 */
int32_t TLV::add(TLVTAG tag, uint8_t *value, uint32_t length)
{
    uint32_t header_len = 0;
    uint32_t entry_len = 0;
    uint32_t new_start_length = 0;
    int32_t  ret = 0;

    if (length == 0 || value == NULL) {
        return -2;
    }

    if (start_tag == TLV_START) {
        header_len = TAGLENGTH_FIELD_SIZE;
    } else {
        header_len = TAGLENGTH_FIELD_SIZE_EX;
    }

    // Unsigned wrap-around checks for header + value and for the new total.
    entry_len = header_len + length;
    new_start_length = start_length + entry_len;
    if (entry_len < length || new_start_length < start_length) {
        return -1;
    }

    ITEM temp;
    ret = temp.set(tag, length, value);
    if (ret != 0) {
        return ret;
    }

    tlv.push_back(temp);
    start_length = new_start_length;
    return ret;
}

/**
 * Appends the content of `lv` under `tag` by forwarding its pointer and
 * length to the raw-buffer add().
 *
 * @return the status of the raw-buffer add().
 */
int32_t TLV::add(TLVTAG tag, Bytes& lv)
{
    const int32_t status = add(tag, (uint8_t *)lv, lv.length());
    return status;
}

/**
 * Selects the container mode by storing `tag` as the start tag.
 *
 * @param tag  TLV_START or TLV_START_EX.
 * @return NOT_ERROR when the tag was accepted, -1 (start tag unchanged) for
 *         any other value.
 */
int32_t TLV::setMode(TLVTAG tag)
{
    const bool known_mode = (tag == TLV_START) || (tag == TLV_START_EX);

    if (!known_mode) {
        return -1;
    }
    start_tag = tag;
    return NOT_ERROR;
}

/**
 * Returns start_length plus the size of one tag + length header for the
 * current mode (TAGLENGTH_FIELD_SIZE for TLV_START, TAGLENGTH_FIELD_SIZE_EX
 * otherwise).
 */
uint32_t TLV::getTotalLength()
{
    if (start_tag == TLV_START) {
        return start_length + TAGLENGTH_FIELD_SIZE;
    }
    return start_length + TAGLENGTH_FIELD_SIZE_EX;
}

}  // namespace drk
}  // namespace security
}  // namespace hardware
}  // namespace samsung
}  // namespace vendor
