/*
 * Copyright (C) 2020, Samsung Electronics Co., Ltd.
 *
 * TUI LL common functions for all drivers
 */

#include <atomic.h>
#include <bsd_list.h>
#include <errno.h>
#include <fcntl.h>
#include <macros.h>
//#include <panic.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <tz_cred.h>
#include <unistd.h>

#include "tuill_log.h"
#include "tuill_socket_lib.h"

/*
 * Peer-credential buffer used by serv_accept(): it is filled through
 * getsockopt(SO_PEERCRED) and handed, together with its length, to the
 * accept callback. Wraps the platform's struct tz_cred and is declared packed.
 * NOTE(review): some libcs also declare a struct ucred in <sys/socket.h>
 * (e.g. glibc with _GNU_SOURCE) — confirm there is no clash on the target.
 */
struct ucred {
    struct tz_cred base;
} __attribute__((__packed__));

/* Kind of descriptor tracked by a tuill_socket_entity. */
enum tuill_entity_type {
    /* listening socket from tuill_socket_create_server(); EPOLLIN triggers serv_accept() */
    TYPE_SERVER,
    /* connected socket: from tuill_socket_create_client() or accepted by serv_accept() */
    TYPE_CLIENT,
    /* device descriptor opened by tuill_socket_create_driver() */
    TYPE_DRIVER
};

/* One descriptor watched by the event loop; linked into tuill_socket_ctx.queue. */
struct tuill_socket_entity {
    /* driver id assigned by tuill_socket_create_driver(); stays 0 (calloc) for servers and clients */
    uint32_t eid;
    /* descriptor registered with epoll; matched against epoll_event.data.fd in the loop */
    int32_t fd;
    /* one of enum tuill_entity_type */
    uint32_t type;
    /* not read or written by the code visible in this file */
    uint32_t events;
    /* address family passed to tuill_socket_create_server(); forwarded to the accept callback */
    int32_t socket_family;
    /* callbacks; accepted clients start with the server's set, which the accept callback may replace */
    struct socket_callbacks cb;
    /* caller-owned pointer passed back to the callbacks */
    void *user_data;
    /* linkage in tuill_socket_ctx.queue */
    TAILQ_ENTRY(tuill_socket_entity) node;
};

/* Per-instance state created by tuill_socket_init(). */
struct tuill_socket_ctx {
    int32_t efd;   /* epoll instance from epoll_create() */
    int32_t cancel_fdpair[2]; /* socketpair: [0] is watched for EPOLLHUP, closing [1] ends the event loop */
    TAILQ_HEAD(entity, tuill_socket_entity) queue; /* every registered server, client and driver */
};

/*
 * Initial value of the driver-id counter in tuill_socket_create_driver().
 * The counter is incremented before it is read, so the ids handed out are
 * presumably greater than this value.
 */
#define START_ENTITY_ID 1

static void delete_item(void *ctx, struct tuill_socket_entity *sock, bool fclose);
static void serv_accept(void *ctx, struct tuill_socket_entity *server);

/*
 * Allocate and initialise a socket context.
 *
 * Creates the epoll instance and a cancellation socketpair whose first end is
 * registered for EPOLLHUP, so that closing the second end wakes the loop.
 *
 * @param ctx  out: receives the new context on success; set to NULL when
 *             setting up epoll or the socketpair fails.
 * @return 0 on success, -TUILLE_BAD_PARAMETERS if ctx is NULL,
 *         -TUILLE_OUT_OF_MEMORY if the context cannot be allocated,
 *         -TUILLE_GENERIC if epoll/socketpair setup fails.
 */
TA_EXPORT int32_t tuill_socket_init(void **ctx)
{
    struct tuill_socket_ctx *new_ctx;
    struct epoll_event cancel_ev;

    TUILL_CALL_TRACE();
    if (!ctx) {
        syslog(LOG_ERR, "bad parameter\n");
        return -TUILLE_BAD_PARAMETERS;
    }

    new_ctx = calloc(1, sizeof(*new_ctx));
    if (!new_ctx) {
        syslog(LOG_ERR, "can't allocate context\n");
        return -TUILLE_OUT_OF_MEMORY;
    }

    syslog(LOG_DEBUG, "_ctx=%p, _ctx->efd=%d\n", new_ctx, new_ctx->efd);
    new_ctx->efd = epoll_create(1);
    syslog(LOG_DEBUG, "_ctx->efd=%d\n", new_ctx->efd);
    if (new_ctx->efd == -1) {
        syslog(LOG_ERR, "can't create epoll errno=%d\n", errno);
        goto fail_free;
    }

    TAILQ_INIT(&new_ctx->queue);

    if (socketpair(AF_LOCAL, SOCK_SEQPACKET, PF_LOCAL, new_ctx->cancel_fdpair) != 0) {
        syslog(LOG_ERR, "socketpair returned %d\n", errno);
        goto fail_epoll;
    }

    /* only the hang-up of the peer end matters; nothing is ever sent on it */
    cancel_ev.events = EPOLLHUP;
    cancel_ev.data.fd = new_ctx->cancel_fdpair[0];
    if (epoll_ctl(new_ctx->efd, EPOLL_CTL_ADD, new_ctx->cancel_fdpair[0], &cancel_ev) != 0) {
        syslog(LOG_ERR, "epoll_ctl returned %d\n", errno);
        goto fail_pair;
    }

    *ctx = new_ctx;
    return 0;

fail_pair:
    ALWAYS_ZERO(close(new_ctx->efd));
    ALWAYS_ZERO(close(new_ctx->cancel_fdpair[0]));
    ALWAYS_ZERO(close(new_ctx->cancel_fdpair[1]));
    goto fail_free;
fail_epoll:
    ALWAYS_ZERO(close(new_ctx->efd));
fail_free:
    free(new_ctx);
    *ctx = NULL;
    return -TUILLE_GENERIC;
}

/*
 * Ask tuill_socket_run_event_loop() to return.
 *
 * Closing the write end of the cancellation socketpair makes epoll report
 * EPOLLHUP on cancel_fdpair[0], which the event loop treats as its exit
 * signal. The slot is set to -1 *before* close(), so tuill_socket_uninit()
 * (or a repeated call) never closes the same number twice, even once the loop
 * has woken up.
 *
 * @param ctx  context from tuill_socket_init(); NULL is rejected with a log.
 */
TA_EXPORT void tuill_socket_stop_event_loop(void *ctx)
{
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    if (_ctx == NULL) {
        syslog(LOG_ERR, "bad parameter\n");
        return;
    }

    int32_t fd = _ctx->cancel_fdpair[1];
    if (fd < 0) {
        /* already stopped: closing again would fail or hit a reused descriptor */
        return;
    }
    _ctx->cancel_fdpair[1] = -1;
    ALWAYS_ZERO(close(fd));
}

/*
 * Destroy a context created by tuill_socket_init().
 *
 * Closes the cancellation socketpair (including the write end if the loop was
 * never stopped), removes every registered entity through delete_item() (its
 * cleanup callback runs and its descriptor is closed), closes the epoll
 * instance and frees the context.
 *
 * @param ctx  context to destroy; NULL is rejected with a log.
 */
TA_EXPORT void tuill_socket_uninit(void *ctx)
{
    struct tuill_socket_entity *tmp;
    struct tuill_socket_entity *var;

    TUILL_CALL_TRACE();
    if (ctx == NULL) {
        syslog(LOG_ERR, "bad parameter\n");
        return;
    }

    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;

    ALWAYS_ZERO(close(_ctx->cancel_fdpair[0]));
    /* the write end is still open when tuill_socket_stop_event_loop() was never called */
    if (_ctx->cancel_fdpair[1] >= 0) {
        ALWAYS_ZERO(close(_ctx->cancel_fdpair[1]));
        _ctx->cancel_fdpair[1] = -1;
    }

    TAILQ_FOREACH_SAFE(var, &_ctx->queue, node, tmp) {
        delete_item(_ctx, var, true);
    }

    ALWAYS_ZERO(close(_ctx->efd));

    free(_ctx);
}

TA_EXPORT void tuill_socket_run_event_loop(void *ctx)
{
    struct tuill_socket_entity *tmp;
    struct tuill_socket_entity *var;
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    const size_t max_waiting_events = 3;
    struct epoll_event events[max_waiting_events];
    openlog(log_tag, LOG_PID, 0);

    do {
        int events_count = epoll_wait(_ctx->efd, events, max_waiting_events, -1);
        if (events_count == -1) {
            syslog(LOG_ERR, "epoll_wait returned -1, errno=%d\n", errno);
            continue;
        }
        syslog(LOG_DEBUG, "events_count=%d\n", events_count);
        for (int i = 0; i < events_count; i++) {
            if (events[i].data.fd == _ctx->cancel_fdpair[0]) {
                if (events[i].events != EPOLLHUP) {
                    syslog(LOG_ERR, "unexpected event on cancellation socket\n");
                    TEE_Panic(EINVAL);
                }
                syslog(LOG_DEBUG, "exiting event loop\n");
                return;
            }
            TAILQ_FOREACH_SAFE(var, &_ctx->queue, node, tmp) {
                if (var->fd != events[i].data.fd) {
                    continue;
                }
                if (events[i].events & EPOLLIN) {
                    syslog(LOG_DEBUG, "EPOLLIN\n");
                    if (var->type == TYPE_SERVER) {
                        syslog(LOG_DEBUG, "server\n");
                        serv_accept(_ctx, var);
                    } else if (var->cb.income) {
                        syslog(LOG_DEBUG, "var->type=%d var->fd=%d\n", var->type, var->fd);
                        struct tuill_buffer buff;
                        buff.data_len =
                            recv(events[i].data.fd, buff.data,
                                 sizeof(struct tuill_internal_command), 0);
                        var->cb.income(var->user_data, &buff);
                    }
                } else if (events[i].events & EPOLLOUT) {
                    syslog(LOG_DEBUG,
                           "EPOLLOUT var->cb.outcome=%p var->user_data=%p\n",
                           var->cb.outcome,
                           var->user_data);
                    if (var->cb.outcome) {
                        var->cb.outcome(var->user_data);
                    }
                } else if (events[i].events & EPOLLERR) {
                    syslog(LOG_DEBUG, "EPOLLERR\n");
                    if (var->cb.error) {
                        var->cb.error(var->user_data);
                        delete_item(_ctx, var, true);
                    }
                } else if (events[i].events & EPOLLHUP) {
                    syslog(LOG_DEBUG, "EPOLLHUP\n");
                    if (var->cb.hangup) {
                        var->cb.hangup(var->user_data);
                        delete_item(_ctx, var, false);
                    }
                } else {
                    syslog(LOG_ERR, "Unknown event 0x%X\n", events[i].events);
                }

                break;
            }
        }
    } while (true);
}

/*
 * Unregister and free one entity.
 *
 * Runs the entity's cleanup callback (if set), removes its descriptor from the
 * epoll set, optionally closes the descriptor, unlinks the entity from
 * ctx->queue and frees it. var is dangling after return.
 *
 * @param ctx     socket context whose queue contains var
 * @param var     entity to remove
 * @param fclose  true to close var->fd; false only removes it from epoll
 *                (used by the event loop after a hang-up)
 */
static void delete_item(void *ctx, struct tuill_socket_entity *var, bool fclose)
{
    TUILL_CALL_TRACE();
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    if (var->cb.cleanup) {
        var->cb.cleanup(var->user_data);
    }

    /* 0 is a valid descriptor number; only negative values mean "no fd" */
    if (_ctx->efd >= 0 && var->fd >= 0) {
        syslog(LOG_DEBUG, "deleting from epoll var->fd=%d\n", var->fd);
        ALWAYS_ZERO(epoll_ctl(_ctx->efd, EPOLL_CTL_DEL, var->fd, NULL));
    }

    if (fclose) {
        syslog(LOG_DEBUG, "closing var->fd=%d\n", var->fd);
        ALWAYS_ZERO(close(var->fd));
    }

    TAILQ_REMOVE(&_ctx->queue, var, node);
    free(var);
}

/*
 * Write a buffer to a driver opened with tuill_socket_create_driver().
 *
 * @param ctx   context from tuill_socket_init()
 * @param id    driver id returned by tuill_socket_create_driver(). 0 is rejected:
 *              servers and clients keep eid == 0, so matching it would write
 *              to a socket instead of a driver.
 * @param data  bytes to write (may be NULL only when len is 0)
 * @param len   number of bytes; must be non-negative
 * @return 0 when exactly len bytes were written, -TUILLE_COMMUNICATION on a
 *         failed or short write, -TUILLE_BAD_PARAMETERS for invalid arguments
 *         or an unknown id.
 */
TA_EXPORT int32_t tuill_socket_write_driver(void *ctx, uint32_t id, void *data, ssize_t len)
{
    struct tuill_socket_entity *tmp;
    struct tuill_socket_entity *var;
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    TUILL_CALL_TRACE();
    syslog(LOG_DEBUG, "ctx=%p, id=%u, data=%p, len=%zd\n", ctx, id, data, len);
    if (_ctx == NULL || id == 0 || len < 0 || (data == NULL && len != 0)) {
        syslog(LOG_ERR, "bad parameter\n");
        return -TUILLE_BAD_PARAMETERS;
    }

    TAILQ_FOREACH_SAFE(var, &_ctx->queue, node, tmp) {
        if (var->eid != id) {
            continue;
        }
        ssize_t count = write(var->fd, data, (size_t)len);
        if (count != len) {
            syslog(LOG_ERR, "return count=%zd errno=%d\n", count, errno);
            return -TUILLE_COMMUNICATION;
        }
        return 0;
    }
    syslog(LOG_DEBUG, "ID %u doesn't exist\n", id);
    return -TUILLE_BAD_PARAMETERS;
}

/*
 * Remove a driver created by tuill_socket_create_driver().
 *
 * The driver's cleanup callback runs, it is removed from epoll and its
 * descriptor is closed.
 *
 * @param ctx  context from tuill_socket_init()
 * @param id   driver id. 0 is the value tuill_socket_create_driver() leaves in
 *             *id on failure; it is accepted as a no-op.
 * @return 0 if the driver was removed or id is 0, -TUILLE_BAD_PARAMETERS if
 *         ctx is NULL or no entity carries this id.
 */
TA_EXPORT int32_t tuill_socket_delete_driver(void *ctx, uint32_t id)
{
    struct tuill_socket_entity *tmp;
    struct tuill_socket_entity *var;
    TUILL_CALL_TRACE();
    syslog(LOG_DEBUG, "ctx=%p, id=%u\n", ctx, id);
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    if (id == 0) {
        /* also guarantees servers/clients (eid == 0) are never removed here */
        return 0;
    }
    if (_ctx == NULL) {
        syslog(LOG_ERR, "bad parameter\n");
        return -TUILLE_BAD_PARAMETERS;
    }

    TAILQ_FOREACH_SAFE(var, &_ctx->queue, node, tmp) {
        if (var->eid == id) {
            delete_item(_ctx, var, true);
            return 0;
        }
    }
    syslog(LOG_DEBUG, "ID %u doesn't exist\n", id);
    return -TUILLE_BAD_PARAMETERS;
}

/*
 * Open a driver device node and register it with the event loop.
 *
 * The node is opened O_RDWR | O_NONBLOCK and watched (level-triggered) for
 * EPOLLIN; the event loop hands received data to cb->income.
 *
 * @param ctx        context from tuill_socket_init()
 * @param name       path passed to open()
 * @param cb         callbacks; copied into the new entity
 * @param user_data  pointer passed back to the callbacks
 * @param id         out: id of the new driver for tuill_socket_write_driver()
 *                   and tuill_socket_delete_driver(); set to 0 on failure
 * @return 0 on success, -TUILLE_BAD_PARAMETERS for NULL arguments,
 *         -TUILLE_OUT_OF_MEMORY, or -TUILLE_GENERIC if open()/epoll_ctl() fails.
 */
TA_EXPORT int32_t tuill_socket_create_driver(void *ctx, char *name, const struct socket_callbacks *cb,
                                void *user_data, uint32_t *id)
{
    TUILL_CALL_TRACE();
    static atomic_t entity_id = ATOMIC_INIT(START_ENTITY_ID);
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    int saved_errno;

    if (id != NULL) {
        *id = 0;
    }
    if (id == NULL || _ctx == NULL || name == NULL || cb == NULL) {
        syslog(LOG_ERR, "bad parameter\n");
        return -TUILLE_BAD_PARAMETERS;
    }
    syslog(LOG_DEBUG, "ctx=%p, name=%s, cb=%p, user_data=%p\n", ctx, name, cb, user_data);

    struct tuill_socket_entity *driver = calloc(1, sizeof(*driver));
    if (driver == NULL) {
        syslog(LOG_ERR, "can't allocate memory\n");
        return -TUILLE_OUT_OF_MEMORY;
    }

    /* NOTE(review): increment and read are two separate operations, so two
     * concurrent callers could observe the same id; an atomic
     * increment-and-fetch would close that window if <atomic.h> provides one. */
    atomic_inc(entity_id);
    driver->eid = (uint32_t)atomic_read(entity_id);
    driver->fd = open(name, O_RDWR | O_NONBLOCK, 0);
    if (driver->fd < 0) {
        syslog(LOG_ERR, "can't open %s errno=%d\n", name, errno);
        free(driver);
        return -TUILLE_GENERIC;
    }

    struct epoll_event event;
    event.events = EPOLLIN;
    event.data.fd = driver->fd;

    syslog(LOG_DEBUG, "_ctx->efd=%d, driver->fd=%d\n", _ctx->efd, driver->fd);
    if (epoll_ctl(_ctx->efd, EPOLL_CTL_ADD, driver->fd, &event) < 0) {
        syslog(LOG_ERR, "epoll_ctl returned %d\n", errno);
        goto error;
    }

    driver->user_data = user_data;
    driver->cb = *cb;
    driver->type = TYPE_DRIVER;
    *id = driver->eid;

    syslog(LOG_DEBUG,
           "including in list driver->fd=%d driver->type=%d\n",
           driver->fd,
           driver->type);
    TAILQ_INSERT_TAIL(&_ctx->queue, driver, node);

    return 0;

error:
    /* close()/free() may overwrite errno; keep the epoll_ctl() failure code */
    saved_errno = errno;
    ALWAYS_ZERO(close(driver->fd));
    free(driver);
    syslog(LOG_ERR, "Can't create driver errno=%d\n", saved_errno);
    return -TUILLE_GENERIC;
}

/*
 * Connect to a named server socket and register the connection.
 *
 * The socket path is built from TUILL_SERVER_TEMPLATE and server_name; a name
 * that would be truncated is rejected rather than connecting to a different
 * path. The connection is watched for EPOLLOUT | EPOLLIN (edge-triggered) and,
 * once it is queued, cb->handshake (if set) is called with the client fd.
 *
 * @param ctx          context from tuill_socket_init()
 * @param server_name  server name substituted into TUILL_SERVER_TEMPLATE
 * @param cb           callbacks; copied into the new entity
 * @param user_data    pointer passed back to the callbacks
 * @return 0 on success, -TUILLE_BAD_PARAMETERS for NULL arguments,
 *         -TUILLE_OUT_OF_MEMORY, or -TUILLE_GENERIC on any socket failure.
 */
TA_EXPORT int32_t tuill_socket_create_client(void *ctx, char *server_name, struct socket_callbacks *cb,
                                void *user_data)
{
    /* called from tuill library */
    TUILL_CALL_TRACE();
    struct sockaddr_un address = {0};
    struct epoll_event event = {0};
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    syslog(LOG_DEBUG, "ctx=%p\n", ctx);

    if (_ctx == NULL || server_name == NULL || cb == NULL) {
        syslog(LOG_ERR, "bad parameter\n");
        return -TUILLE_BAD_PARAMETERS;
    }

    struct tuill_socket_entity *client = calloc(1, sizeof(*client));
    if (client == NULL) {
        syslog(LOG_ERR, "can't allocate memory\n");
        return -TUILLE_OUT_OF_MEMORY;
    }

    /* we don't need entity_id in this case */
    client->fd = socket(AF_LOCAL, SOCK_SEQPACKET, 0);
    if (client->fd < 0) {
        syslog(LOG_ERR, "Failed to create socket. errno: %d\n", errno);
        free(client);
        return -TUILLE_GENERIC;
    }

    address.sun_family = AF_LOCAL;
    int32_t res = snprintf(address.sun_path, UNIX_PATH_MAX, TUILL_SERVER_TEMPLATE, server_name);
    if (res < 0 || res >= UNIX_PATH_MAX) {
        /* a truncated path would silently target a different socket */
        syslog(LOG_ERR, "snprintf failed. res=%d\n", res);
        goto error;
    }

    if (connect(client->fd, (struct sockaddr *)&address, sizeof(struct sockaddr_un)) < 0) {
        syslog(LOG_ERR, "Failed to connect server. errno: %d\n", errno);
        goto error;
    }

    event.events = EPOLLOUT | EPOLLIN | EPOLLET;
    event.data.fd = client->fd;

    syslog(LOG_DEBUG, "_ctx->efd=%d, client->fd=%d\n", _ctx->efd, client->fd);
    if (epoll_ctl(_ctx->efd, EPOLL_CTL_ADD, client->fd, &event) < 0) {
        syslog(LOG_ERR, "epoll_ctl returned %d\n", errno);
        goto error;
    }

    client->user_data = user_data;
    client->cb = *cb;
    client->type = TYPE_CLIENT;

    syslog(LOG_DEBUG,
           "including in list client->fd=%d client->type=%d\n",
           client->fd,
           client->type);
    TAILQ_INSERT_TAIL(&_ctx->queue, client, node);

    syslog(LOG_DEBUG, "client=%p _ctx=%p\n", client, _ctx);
    if (client->cb.handshake) {
        client->cb.handshake(client->fd, user_data);
    }

    return 0;

error:
    ALWAYS_ZERO(close(client->fd));
    free(client);
    return -TUILLE_GENERIC;
}

/*
 * Create, bind and listen on a named server socket and register it.
 *
 * The bind path is built from TUILL_SERVER_TEMPLATE and name (truncation is
 * rejected). The listening socket is watched for EPOLLOUT | EPOLLIN
 * (edge-triggered); the event loop accepts incoming connections through
 * serv_accept(), which uses cb->accept.
 *
 * @param ctx            context from tuill_socket_init()
 * @param name           server name substituted into TUILL_SERVER_TEMPLATE
 * @param socket_family  address family for socket() and sun_family
 * @param cb             callbacks; copied into the server entity and inherited
 *                       by accepted clients
 * @param user_data      pointer passed back to the callbacks
 * @return 0 on success, -TUILLE_BAD_PARAMETERS for NULL arguments,
 *         -TUILLE_OUT_OF_MEMORY, or -TUILLE_GENERIC on any socket failure.
 */
TA_EXPORT int32_t tuill_socket_create_server(void *ctx, char *name, int32_t socket_family,
                                struct socket_callbacks *cb, void *user_data)
{
    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;
    struct sockaddr_un listener_addr = {0};

    if (_ctx == NULL || name == NULL || cb == NULL) {
        syslog(LOG_ERR, "bad parameter\n");
        return -TUILLE_BAD_PARAMETERS;
    }

    struct tuill_socket_entity *server = calloc(1, sizeof(*server));
    if (server == NULL) {
        syslog(LOG_ERR, "can't allocate memory\n");
        return -TUILLE_OUT_OF_MEMORY;
    }

    /* we don't need entity_id for server */
    server->fd = socket(socket_family, SOCK_SEQPACKET, 0);
    if (server->fd < 0) {
        syslog(LOG_ERR, "Failed to create socket. errno: %d\n", errno);
        /* nothing to close yet: close(-1) returns -1, which ALWAYS_ZERO()
         * presumably treats as fatal */
        free(server);
        return -TUILLE_GENERIC;
    }

    int32_t res = snprintf(listener_addr.sun_path, UNIX_PATH_MAX, TUILL_SERVER_TEMPLATE, name);
    if (res < 0 || res >= UNIX_PATH_MAX) {
        syslog(LOG_ERR, "snprintf failed. res=%d\n", res);
        goto error;
    }
    listener_addr.sun_family = socket_family;

    res = bind(server->fd, (struct sockaddr *)&listener_addr, sizeof(listener_addr));
    if (res < 0) {
        syslog(LOG_ERR, "bind(%s) failed, errno: %d\n", listener_addr.sun_path, errno);
        goto error;
    }
    syslog(LOG_DEBUG, "bind(%s) successfull\n", listener_addr.sun_path);

    res = listen(server->fd, 0);
    if (res < 0) {
        syslog(LOG_ERR, "listen() failed, errno: %d\n", errno);
        goto error;
    }

    struct epoll_event ev;
    ev.events = EPOLLOUT | EPOLLIN | EPOLLET;
    ev.data.fd = server->fd;

    syslog(LOG_DEBUG, "_ctx->efd=%d, server->fd=%d\n", _ctx->efd, server->fd);
    res = epoll_ctl(_ctx->efd, EPOLL_CTL_ADD, server->fd, &ev);
    if (res < 0) {
        syslog(LOG_ERR, "epoll_ctl failed, errno: %d\n", errno);
        goto error;
    }

    server->user_data = user_data;
    server->cb = *cb;
    server->type = TYPE_SERVER;
    server->socket_family = socket_family;

    syslog(LOG_DEBUG, "including in list server->fd=%d server->type=%d\n", server->fd,
           server->type);
    TAILQ_INSERT_TAIL(&_ctx->queue, server, node);

    return 0;

error:
    ALWAYS_ZERO(close(server->fd));
    free(server);
    return -TUILLE_GENERIC;
}

/*
 * Accept one pending connection on a listening server entity and register it.
 *
 * Steps:
 *  1. Enable SO_PASSCRED on the new socket.
 *  2. Try to read the peer credentials via SO_PEERCRED. A failure is
 *     tolerated and reported to the callback as length 0.
 *  3. Let server->cb.accept fill in the client's user_data and callbacks.
 *  4. Add the client to epoll (EPOLLOUT | EPOLLIN | EPOLLET) and to ctx->queue.
 *
 * On any failure the accepted descriptor is closed. If registration fails
 * after the accept callback succeeded, cb.cleanup releases user_data, as
 * delete_item() does for registered entities.
 *
 * @param ctx     socket context
 * @param server  TYPE_SERVER entity that reported EPOLLIN
 */
static void serv_accept(void *ctx, struct tuill_socket_entity *server)
{
    int32_t client_fd = -1;

    struct ucred cred;
    socklen_t length = sizeof(struct ucred);

    const int enable = 1;
    struct tuill_socket_entity *client = NULL;

    struct tuill_socket_ctx *_ctx = (struct tuill_socket_ctx *)ctx;

    TUILL_CALL_TRACE();
    /* never hand uninitialised stack bytes to the accept callback */
    memset(&cred, 0, sizeof(cred));
    syslog(LOG_DEBUG, "_ctx=%p server=%p\n", _ctx, server);
    if (server->cb.accept == NULL) {
        syslog(LOG_ERR, "accept callback is not set\n");
        return;
    }

    syslog(LOG_DEBUG, "server->fd=%d\n", server->fd);
    if ((client_fd = accept(server->fd, NULL, NULL)) < 0) {
        syslog(LOG_ERR, "accept failed %d\n", errno);
        return;
    }

    syslog(LOG_DEBUG, "client_fd=%d\n", client_fd);
    if (setsockopt(client_fd, SOL_SOCKET, SO_PASSCRED, &enable, sizeof(enable)) < 0) {
        syslog(LOG_ERR, "cannot enable client credentials errno=%d\n", errno);
        goto client_error;
    }

    if (getsockopt(client_fd, SOL_UNSPEC, SO_PEERCRED, &cred, &length) < 0) {
        /* iwd client doesn't have cred data */
        syslog(LOG_ERR, "cannot obtain client credentials: %d\n", errno);
        length = 0;
    }

    client = calloc(1, sizeof(*client));
    if (client == NULL) {
        syslog(LOG_ERR, "can't allocate memory errno=%d\n", errno);
        goto client_error;
    }

    /* we don't need entity_id for accepted client */
    client->fd = client_fd;
    client->type = TYPE_CLIENT;
    client->cb = server->cb;
    syslog(LOG_DEBUG, "server->cb.accept=%p\n", server->cb.accept);
    if (server->cb.accept(server->socket_family, client->fd, &client->user_data, &client->cb,
                          &cred, length) < 0) {
        syslog(LOG_ERR, "accept callback returned error\n");
        goto client_error;
    }

    struct epoll_event ev;
    ev.events = EPOLLOUT | EPOLLIN | EPOLLET;
    ev.data.fd = client->fd;

    syslog(LOG_DEBUG, "_ctx->efd=%d, client->fd=%d\n", _ctx->efd, client->fd);
    if (epoll_ctl(_ctx->efd, EPOLL_CTL_ADD, client->fd, &ev) < 0) {
        syslog(LOG_ERR, "epoll_ctl failed, errno: %d\n", errno);
        goto accepted_error;
    }

    syslog(LOG_DEBUG,
           "including in list client->fd=%d client->type=%d\n",
           client->fd,
           client->type);
    TAILQ_INSERT_TAIL(&_ctx->queue, client, node);

    return;

accepted_error:
    /* the accept callback succeeded, so user_data may own resources */
    if (client->cb.cleanup) {
        client->cb.cleanup(client->user_data);
    }
client_error:
    ALWAYS_ZERO(close(client_fd));
    free(client);
    return;
}

/*
 * Send one command to the server (library side).
 *
 * The command is sent with send() to every entity in ctx->queue. On the
 * library side that queue is expected to hold only the connection to the
 * server. Processing stops at the first failing send().
 *
 * @param ctx  context from tuill_socket_init()
 * @param cmd  command to transmit (sizeof(struct tuill_internal_command) bytes)
 * @return byte count of the last send() (0 when the queue is empty), or
 *         -TUILLE_COMMUNICATION if a send() fails.
 */
TA_EXPORT int32_t tuill_socket_client_send_to_server(void *ctx, struct tuill_internal_command *cmd)
{
    /* called from tuill library */
    struct tuill_socket_entity *entry;
    struct tuill_socket_entity *next_entry;
    int32_t sent = 0;

    TUILL_CALL_TRACE();
    syslog(LOG_DEBUG, "ctx=%p, cmd=%p\n", ctx, cmd);

    struct tuill_socket_ctx *sctx = (struct tuill_socket_ctx *)ctx;
    TAILQ_FOREACH_SAFE(entry, &sctx->queue, node, next_entry) {
        syslog(LOG_DEBUG, "var=%p\n", entry);
        sent = (int32_t)send(entry->fd, cmd, sizeof(*cmd), 0);
        if (sent >= 0) {
            continue;
        }
        syslog(LOG_ERR, "nbytes=%d\n", sent);
        return -TUILLE_COMMUNICATION;
    }

    syslog(LOG_DEBUG, "nbytes=%d\n", sent);
    return sent;
}

/*
 * Send one command from the server to a connected client.
 *
 * @param fd   client socket descriptor
 * @param cmd  command to transmit (sizeof(struct tuill_internal_command) bytes)
 * @return number of bytes sent, or -TUILLE_COMMUNICATION if send() fails.
 */
TA_EXPORT int32_t tuill_socket_server_send_to_client(int32_t fd, struct tuill_internal_command *cmd)
{
    /* called on server side */
    TUILL_CALL_TRACE();

    const int32_t sent = (int32_t)send(fd, cmd, sizeof(*cmd), 0);
    if (sent >= 0) {
        syslog(LOG_DEBUG, "nbytes=%d\n", sent);
        return sent;
    }

    syslog(LOG_ERR, "nbytes=%d\n", sent);
    return -TUILLE_COMMUNICATION;
}
