#!/usr/bin/env python3
# -*- mode: python -*-
# =============================================================================
#
#  Copyright (c) 2017-2020 Qualcomm Technologies, Inc.
#  All Rights Reserved.
#  Confidential and Proprietary - Qualcomm Technologies, Inc.
#
# =============================================================================

import traceback
import os
import argparse
import logging

import sys
try:
    import qti.aisw
except ImportError as ie1:
    print("Failed to find necessary python package")
    print(str(ie1))
    print("Please ensure that $SNPE_ROOT/lib/python is in your PYTHONPATH")
    sys.exit(1)

from qti.aisw.converters.caffe2 import snpe_caffe2_to_dlc
from qti.aisw.converters.common.utils import converter_utils


def getArgs(argv=None):
    """Parse the command-line options of the Caffe2-to-DLC converter.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which
            preserves the behaviour of calling ``getArgs()`` with no
            arguments.

    Returns:
        argparse.Namespace holding the parsed options. ``dlc`` is always
        populated: when ``--dlc`` is omitted it is derived from the resolved
        ``predict_net`` path with its extension replaced by ``.dlc``.

    Note:
        argparse terminates the process (``SystemExit``) when a required
        option is missing or an option value is invalid.
    """
    logger = logging.getLogger()
    logger.debug("Parsing the arguments")

    parser = argparse.ArgumentParser(
        description='Script to convert caffe2 networks into a DLC file.')

    required = parser.add_argument_group('required arguments')
    required.add_argument('-p', '--predict_net', type=str, required=True,
                          help='Input caffe2 binary network definition protobuf')
    required.add_argument('-e', '--exec_net', type=str, required=True,
                          help='Input caffe2 binary file containing the weight data')
    # Each occurrence of --input_dim contributes one [name, dims] pair, so the
    # parsed value is a list of two-element lists.
    # NOTE(review): the help text describes a 4-value "B,C,H,W" format but also
    # states that inputs have 3 dimensions - confirm which one is intended.
    required.add_argument('-i', '--input_dim', nargs=2, action='append', required=True,
                          help='The names and dimensions of the network input layers specified in the format "input_name" B,C,H,W. Ex "data" 1,3,224,224. Note that the quotes should always be included in order to handle special characters, spaces, etc. For multiple inputs specify multiple --input_dim on the command line like: --input_dim "data1" 10,3,224,224 --input_dim "data2" 10,3,50,100 We currently assume that all inputs have 3 dimensions.')

    optional = parser.add_argument_group('optional arguments')
    optional.add_argument('-d', '--dlc', type=str,
                          help='Output DLC file containing the model. If not specified, the data will be written to a file with same name and location as the predict_net file with a .dlc extension')
    optional.add_argument('--copyright_file', type=str,
                          help='Path to copyright file. If provided, the content of the file will be added to the dlc.')
    # The "enable_preprocessing" option only works when an ImageInputOp is
    # specified. Otherwise preprocessing must occur prior to passing the input
    # to SNPE.
    optional.add_argument('--enable_preprocessing', action="store_true", default=False,
                          help="If specified, the converter will enable image mean subtraction and cropping specified by ImageInputOp. Do NOT enable if there is not a ImageInputOp present in the Caffe2 network.")
    optional.add_argument('--encoding', type=str, choices=['argb32', 'rgba', 'nv21', 'bgr'], default='bgr',
                          help='Image encoding of the source images. Default is bgr if not specified')
    optional.add_argument('--opaque_input', type=str, nargs='*', default=[],
                          help="A space separated list of input blob names which should be treated as opaque (non-image) data. These inputs will be consumed as-is by SNPE. Any input blob not listed will be assumed to be image data.")
    optional.add_argument('--model_version', type=str,
                          help='User-defined ASCII string to identify the model, only first 64 bytes will be stored')
    optional.add_argument('--reorder_list', nargs='+', default=[],
                          help='A list of external inputs or outputs that SNPE should automatically reorder to match the specified Caffe2 channel ordering. Note that this feature is only enabled for the GPU runtime.')
    optional.add_argument("--verbose", dest="verbose", action="store_true", default=False,
                          help="Verbose printing")

    args = parser.parse_args(argv)
    if args.dlc is None:
        # Default output: next to the (symlink-resolved) predict_net file,
        # same base name, ".dlc" extension.
        base_path = os.path.splitext(os.path.realpath(args.predict_net))[0]
        args.dlc = base_path + ".dlc"

    return args


if __name__ == '__main__':
    # Script entry point: parse the CLI, configure logging, run the conversion
    # and exit with 0 on success or 1 when the conversion raises.
    args = getArgs()
    converter_utils.setUpLogger(args.verbose)

    converter = snpe_caffe2_to_dlc.Caffe2SnapDnnConverter()
    exit_status = 0
    try:
        # The input/output path options are excluded from the sanitized
        # command that is handed to the converter.
        ignored_options = ['p', 'predict_net', 'e', 'exec_net', 'd', 'dlc']
        converter_command = converter_utils.sanitize_args(
            args, args_to_ignore=ignored_options)

        # convert() is called positionally, so the order of this tuple must
        # not change.
        convert_params = (
            args.predict_net,
            args.exec_net,
            args.dlc,
            args.copyright_file,
            args.encoding,
            args.input_dim,
            args.enable_preprocessing,
            args.model_version,
            converter_command,
            args.reorder_list,
            args.opaque_input,
        )
        converter.convert(*convert_params)
    except Exception as err:
        # Top-level boundary: report the failure and its stack trace, then
        # signal the error through the exit status.
        print('Encountered Error:', str(err))
        print()
        print('Stack Trace:')
        traceback.print_exc()
        exit_status = 1
    sys.exit(exit_status)
