#!/usr/bin/env python3
# -*- mode: python -*-
# =============================================================================
#
#  Copyright (c) 2018-2020 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import argparse
import json
import logging
import os
import webbrowser
import sys
import hashlib
import tempfile

# Python 2 only: make UTF-8 the interpreter-wide default string encoding.
# reload() is a builtin on Python 2 and is what makes sys.setdefaultencoding
# callable again. Python 3 needs none of this, and reloading sys there is
# discouraged by the importlib documentation, so it is skipped entirely.
if sys.version_info[0] == 2:
    reload(sys)  # noqa: F821 - builtin on Python 2
    sys.setdefaultencoding('utf8')

# The DLC utilities ship with the SDK; fail with a readable hint instead of a
# traceback when they are not importable.
try:
    from qti.aisw.dlc_utils import snpe_dlc_utils
except ImportError as ie:
    print("Failed to find necessary package:")
    print(str(ie))
    print("Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)

# Module-level alias so the formatter can be called without the package prefix
get_si_notation = snpe_dlc_utils.get_si_notation


def model_info_to_graph_info(rows_all, total_macs, total_params):
    """
    Assigns layer names as nodes (or vertices) in the model.
    Assigns connections between layer names and their input/output layers
    as links (or edges).

    :param rows_all: iterable of layer rows. Each row must expose the attributes
                     name, id, type, input_names, output_names, output_dims_list
                     and layer_affinity, plus the accessors get_parm_list(),
                     get_input_list(), get_output_list(), get_parm(index),
                     get_num_params() and get_macs().
    :param total_macs: model-wide MAC count, forwarded to get_si_notation
                       together with each layer's own MAC count
    :param total_params: model-wide parameter count, forwarded to
                         get_si_notation together with each layer's own count
    :return: tuple (nodes, links). nodes is a list of dicts with the keys
             'name', 'parameters' (list of HTML fragments), 'type' and 'color';
             links is a list of dicts with 'source' and 'target' (links built
             from output names additionally carry 'dummy').
    """

    # List that will store dictionaries of layer names and their parameters
    nodes = []

    # List that will store dictionaries of links between layers and their input/output layers
    links = []

    for row in rows_all:
        # Only the first output's dimensions are displayed. A layer without any
        # recorded output dims gets an empty dims string instead of raising.
        dims_list = row.output_dims_list
        first_dims = dims_list[0] if len(dims_list) > 0 else ()
        dims = 'x'.join(str(dim) for dim in first_dims)

        parms = [
            'ID: ' + str(row.id),
            '<br/>Output Dims: ' + dims,
            '<br/>Layer Type:' + row.type,
        ]

        # Probe as many parameter slots as the longest of the three per-layer
        # lists; slots for which get_parm() returns nothing are skipped.
        slot_count = max(len(row.get_parm_list()),
                         len(row.get_input_list()),
                         len(row.get_output_list()))
        for index in range(slot_count):
            parm = row.get_parm(index)
            if parm:
                parms.append('<br/>' + parm.replace('\n', ''))

        # Get param_count per inference if it exists
        param_count = row.get_num_params()
        if param_count > 0:
            parms.append('<br/>param count: ' + get_si_notation(param_count, total_params))

        # Get MACs per inference if it exists
        macs = row.get_macs()
        if macs > 0:
            parms.append('<br/>MACs per inference: ' + get_si_notation(macs, total_macs))

        # Insert Affinity here
        if row.layer_affinity != 'UNSET':
            parms.append('<br/>Layer Affinity: ' + row.layer_affinity)

        # Create a node for the layer. The colour index (10..14) is derived from
        # a hash of the layer type, so it is the same for every layer of a type.
        colour = int(hashlib.md5(row.type.encode()).hexdigest(), 16) % 5 + 10
        nodes.append({'name': row.name, 'parameters': parms, 'type': row.type, 'color': colour})

        # Names equal to the layer's own name are skipped to avoid self loops
        links.extend({"source": row.name, "target": out_name, 'dummy': row.name}
                     for out_name in row.output_names if out_name != row.name)
        links.extend({"source": in_name, "target": row.name}
                     for in_name in row.input_names if in_name != row.name)

    return nodes, links

def _value_after_label(text):
    """Returns the part of a 'label: value' string after the first colon (all of it if there is none)."""
    _, separator, value = text.partition(':')
    return value if separator else text


def _build_aix_table(model_info):
    """Renders the HTML table rows describing the AIP records of the model."""
    if not model_info.is_aix_record_present():
        return "<tr><td> No AIP Records Found in model. </td></tr>"

    headers = ["AIP Record Name", "nnc_version", "record_version", "hta_blob_id", "record_size", "Subnets"]
    # Layer ids are printed first and in this order, ahead of the other subnet entries
    leading_keys = ("start_layer_Id", "end_layer_Id")
    error_prefix = "</table><p class='error'>Error querying AIP data:\n"

    # Bound before the try block so the error path can always append to it
    aix_table = ""
    try:
        aix_records = model_info.get_aix_records()
        warning_msgs = ""
        aix_table = "<tr>"
        for header in headers:
            aix_table += "<th class='info_headers'>" + header + "</th>"
        aix_table += "</tr>"

        for aix_record_name, aix_meta_info in aix_records.items():
            aix_table += "<tr><td>" + aix_record_name + "</td>"  # add the record name column

            # add everything after name but before Subnets (since Subnets have further info)
            for header in headers[1:-1]:
                aix_table += "<td>" + str(aix_meta_info[header]) + "</td>"

            aix_table += "<td>" + "num_of_subnets: " + str(aix_meta_info['num_of_subnets'])

            # Add subnets meta info for record
            if aix_meta_info['compatibility']:
                for j in range(aix_meta_info['num_of_subnets']):
                    subnet_name = "subnet_" + str(j)
                    subnet_info = aix_meta_info[subnet_name]
                    aix_table += "<p style='margin:0'>" + subnet_name + ':</p>'
                    for key in leading_keys:
                        if key in subnet_info:
                            aix_table += "<p style='padding-left:15px; margin:0'> " + key + ": " \
                                         + str(subnet_info[key]) + "</p>"

                    # Remaining entries; the record itself is left unmodified
                    for subnet_key, subnet_value in subnet_info.items():
                        if subnet_key in leading_keys:
                            continue
                        if isinstance(subnet_value, list):
                            aix_table += "<p style='padding-left:15px; margin:0'>" + str(subnet_key) + ":</p>"
                            for value in subnet_value:
                                aix_table += "<p style='padding-left:25px; margin:0'>" + str(value) + "</p>"
                        else:
                            aix_table += "<p style='padding-left:15px; margin:0'>" + str(subnet_key) + ": " \
                                         + str(subnet_value) + "</p>"
            else:
                # add warning message if record is not compatible with current version of snpe
                warning_msgs += "- Record " + aix_record_name + " is incompatible with the latest version of SNPE\n"

            # Close the Subnets cell and the row for compatible and incompatible records alike
            aix_table += "</td></tr>"

        if warning_msgs:
            aix_table += error_prefix + warning_msgs + "</p>"

    except Exception as e:
        # Best effort: surface the failure inside the generated page rather than
        # aborting. str(e) is used because exceptions have no .message on Python 3.
        aix_table += error_prefix + str(e) + "</p>"

    return aix_table


def _resolve_output_path(save, modelname):
    """Returns the absolute path of the HTML file to write for the given --save value."""
    html_name = '%s.html' % modelname
    if not save:
        # No specifications provided; store file at a temporary location for rendering
        return os.path.join(os.path.abspath(tempfile.gettempdir()), html_name)
    if save.endswith('.html'):
        # File name, with or without a directory part
        return os.path.abspath(save)
    # Only path specified; append file name taken from model name
    return os.path.join(os.path.abspath(save), html_name)


def main():
    """
    Command-line entry point.

    Parses the arguments, extracts the layer and meta-data information of the
    given DLC, fills the viewer HTML template from the share directory with it,
    writes the result to the requested (or a temporary) location and opens it
    in the web browser.

    Exits with status -1 when the DLC or the share directory cannot be found or
    when the HTML file cannot be written.
    """
    from pathlib import Path

    parser = argparse.ArgumentParser()
    required = parser.add_argument_group('required arguments')
    required.add_argument('-i', '--input_dlc', required=True, type=str, help="Path to a dl container archive")
    parser.add_argument("-s", "--save", type=str, help="Save HTML file. Specify a file name and/or target save path")
    args = parser.parse_args()

    snpe_dlc_utils.setUpLogger(True)

    logger = logging.getLogger()
    if not os.path.exists(args.input_dlc):
        logger.error("Cannot find archive DLC file " + args.input_dlc)
        sys.exit(-1)

    # Load input dlc file and extract model information layer by layer
    m = snpe_dlc_utils.ModelInfo()
    rows_all = m.extract_model_info(args.input_dlc)

    total_params = m.get_total_params()
    total_macs = m.get_total_macs()

    # Turn the per-layer rows into graph nodes and links
    nodes, links = model_info_to_graph_info(rows_all, total_macs, total_params)
    # get meta-data from ModelInfo class; the copyright entry is not used here
    # because the copyright text is queried separately further down
    (model_version, total_params, total_macs, converter_command, quantizer_command,
     converter_version, _) = m.get_meta_data(total_params, total_macs, args.input_dlc)
    input_file_path = os.path.abspath(args.input_dlc)

    # Set path to location of the script, rather than current working directory
    filepath = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
    # Set path of the shared directory: first existing candidate wins
    candidates = (
        os.path.abspath(os.path.join(filepath, '..', '..', 'share', 'dlcviewer')),
        os.path.abspath(os.path.join(filepath, '..', 'share', 'dlcviewer')),
    )
    sharedpath = next((path for path in candidates if os.path.exists(path)), None)
    if sharedpath is None:
        logger.error("Failed to access the dependency packages in share location.")
        sys.exit(-1)

    # Retrieve name of DLC file
    modelname = os.path.splitext(os.path.basename(args.input_dlc))[0]

    # Read in viewer template
    template_path = os.path.abspath(os.path.join(sharedpath, 'snpe_dlc_viewer_template.html'))
    with open(template_path, 'r', encoding='utf8') as template_file:
        filedata = template_file.read()

    # Replace target strings with nodes and links info
    filedata = filedata.replace('??nodes??', json.dumps(nodes))
    filedata = filedata.replace('??links??', json.dumps(links))
    # Replace target strings with meta-data. Note: Only need values here since we have
    # the keys for each of these already defined in the html template. Everything after
    # the first colon is kept so that values which contain colons are not cut short.
    filedata = filedata.replace('??total_params??', _value_after_label(total_params))
    filedata = filedata.replace('??total_macs??', _value_after_label(total_macs))
    filedata = filedata.replace('??model??', input_file_path)
    filedata = filedata.replace('??model_version??', _value_after_label(model_version))
    filedata = filedata.replace('??converter_command??', _value_after_label(converter_command))
    # NOTE(review): this placeholder contains a space while all others use
    # underscores - confirm against the template before changing it.
    filedata = filedata.replace('??quantizer command??', _value_after_label(quantizer_command))
    filedata = filedata.replace('??converter_version??', _value_after_label(converter_version))

    # get model copyright info and emit one table row per line of the copyright
    # string; class no_border so that we don't add lines in the copyright statement
    # NOTE(review): the text is inserted into the page without HTML escaping.
    model_copyright_table = "".join(
        "<tr><td class='no_border'>" + line + "</td></tr>"
        for line in m.get_model_copyright().split('\n')
    )

    # add to the html template
    filedata = filedata.replace('??copyright_table??', model_copyright_table)

    if m.is_aix_enabled():
        # add AIX info to the html template and enable the AIX info tab
        filedata = filedata.replace('??aix_table??', _build_aix_table(m))
        filedata = filedata.replace('??is_use_aix_set??', "block")
    else:
        # Disable displaying AIX info tab in HTML output
        filedata = filedata.replace('??is_use_aix_set??', "none")

    html_out = _resolve_output_path(args.save, modelname)

    # Write out to an HTML file, specific to DLC name
    try:
        with open(html_out, 'w', encoding='utf8') as f:
            f.write(filedata)
    except IOError:
        logger.error("IOError: Cannot write HTML file " + html_out)
        sys.exit(-1)
    # Reported only once the file has actually been written
    print("Network Model HTML file saved at %s" % html_out)

    # Open the HTML file in the browser; as_uri() builds a properly quoted file URI
    webbrowser.open(Path(html_out).as_uri())

# Run the viewer only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
