#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

# ===-- klee-stats --------------------------------------------------------===##
#
#                      The KLEE Symbolic Virtual Machine
#
#  This file is distributed under the University of Illinois Open Source
#  License. See LICENSE.TXT for details.
#
# ===----------------------------------------------------------------------===##

"""Output statistics logged by Klee."""

import os
import sys
import argparse
import sqlite3
import collections

# Mapping of: (column head, explanation, internal klee name)
# column head must start with a capital letter
# The internal name must be unique per entry: it is used as the key when the
# table columns are renamed to their column heads.  Relative (percentage)
# columns therefore use the "Rel"-prefixed names that are computed from the
# absolute times, and the aggregated columns use the lower-case names that the
# statistics aggregation produces.
Legend = [
    ('Instrs', 'number of executed instructions', "Instructions"),
    ('Time(s)', 'total wall time (s)', "WallTime"),
    ('TUser(s)', 'total user time', "UserTime"),
    ('TUser(%)', 'relative user time wrt wall time', "RelUserTime"),
    ('ICov(%)', 'instruction coverage in the LLVM bitcode (%)', "ICov"),
    ('BCov(%)', 'branch coverage in the LLVM bitcode (%)', "BCov"),
    ('ICount', 'total static instructions in the LLVM bitcode', "ICount"),
    ('TSolver(s)', 'time spent in the constraint solver', "SolverTime"),
    ('TSolver(%)', 'relative time spent in the constraint solver wrt wall time', "RelSolverTime"),
    ('States', 'number of currently active states', "NumStates"),
    ('MaxStates', 'maximum number of active states', "maxStates"),
    ('AvgStates', 'average number of active states', "avgStates"),
    ('Mem(MB)', 'megabytes of memory currently used', "MallocUsage"),
    ('MaxMem(MB)', 'maximum megabytes of memory used', "maxMem"),
    ('AvgMem(MB)', 'average megabytes of memory used', "avgMem"),
    ('Queries', 'number of queries issued to STP', "NumQueries"),
    ('AvgQC', 'average number of query constructs per query', "AvgQC"),
    ('Tcex(s)', 'time spent in the counterexample caching code', "CexCacheTime"),
    ('Tcex(%)', 'relative time spent in the counterexample caching code wrt wall time', "RelCexCacheTime"),
    ('Tfork(s)', 'time spent forking', "ForkTime"),
    ('Tfork(%)', 'relative time spent forking wrt wall time', "RelForkTime"),
    ('TResolve(s)', 'time spent in object resolution', "ResolveTime"),
    ('TResolve(%)', 'time spent in object resolution wrt wall time', "RelResolveTime"),
    ('QCexCMisses', 'Counterexample cache misses', "QueryCexCacheMisses"),
    ('QCexCHits', 'Counterexample cache hits', "QueryCexCacheHits"),
]

def getInfoFile(path):
    """Build the location of the ``info`` file inside the directory *path*."""
    info_name = 'info'
    return os.path.join(path, info_name)

def getLogFile(path):
    """Build the location of the ``run.stats`` file inside the directory *path*."""
    stats_name = 'run.stats'
    return os.path.join(path, stats_name)

class LazyEvalList:
    """Accessor for the ``stats`` table of one SQLite statistics file.

    Despite the historical name, nothing is cached or eval()'ed: only the
    file name is stored, and every query opens its own short-lived connection.
    """
    def __init__(self, fileName):
        # Path of the SQLite file; no connection is opened until a query runs.
        self.filename = fileName

    def conn(self):
        """Open and return a new connection; the caller must close it."""
        return sqlite3.connect(self.filename)

    def _fetch_one(self, sql):
        """Run *sql* on a fresh connection.

        :return: tuple (column names, first result row or None)
        """
        connection = self.conn()
        try:
            cursor = connection.execute(sql)
            names = [description[0] for description in cursor.description]
            return names, cursor.fetchone()
        finally:
            # Always release the database handle, even if the query failed.
            connection.close()

    def aggregateRecords(self):
        """Return maximum/average memory (in MiB) and state counts over all rows.

        :return: dict with the keys "maxMem", "avgMem", "maxStates" and
                 "avgStates"; a pair of values is None when its query fails
                 with sqlite3.OperationalError or when the table has no rows.
        """
        try:
            _, row = self._fetch_one("SELECT max(MallocUsage) / 1024 / 1024, avg(MallocUsage) / 1024 / 1024 from stats")
            maxMem, avgMem = row
        except sqlite3.OperationalError:
            maxMem, avgMem = None, None

        try:
            _, row = self._fetch_one("SELECT max(NumStates), avg(NumStates) from stats")
            maxStates, avgStates = row
        except sqlite3.OperationalError:
            maxStates, avgStates = None, None

        # "maxStates" (not "maxState") is the name the column legend and the
        # column presets look for.
        return {"maxMem": maxMem, "avgMem": avgMem, "maxStates": maxStates, "avgStates": avgStates}

    def getLastRecord(self):
        """Return the most recently written row as a dict {column name: value}.

        :return: the row, or None if the query fails with
                 sqlite3.OperationalError or the table contains no rows.
        """
        try:
            column_names, row = self._fetch_one("SELECT * FROM stats ORDER BY rowid DESC LIMIT 1")
        except sqlite3.OperationalError:
            return None
        if row is None:
            # Empty table: there is no last record yet.
            return None
        return dict(zip(column_names, row))


def stripCommonPathPrefix(paths):
    """Drop the leading path components that all *paths* have in common.

    Paths are normalised and compared component by component. Everything
    before the first differing component is removed. If no component differs
    within the shortest path, the last compared component is kept.

    :param paths: iterable of path strings
    :return: list of shortened '/'-joined paths, in input order
    """
    components = [os.path.normpath(p).split('/') for p in paths]
    cut = 0
    for position, column in enumerate(zip(*components)):
        cut = position
        if len(set(column)) > 1:
            break
    return ['/'.join(parts[cut:]) for parts in components]


def getKleeOutDirs(dirs):
    """Collect the directories that contain an ``info`` file.

    A given directory that has an ``info`` file is taken as is. Otherwise its
    whole subtree is walked and every subdirectory with an ``info`` file is
    collected.

    :param dirs: iterable of directory paths to inspect
    :return: list of matching directory paths
    """
    def has_info(candidate):
        return os.path.exists(os.path.join(candidate, 'info'))

    found = []
    for top in dirs:
        if has_info(top):
            found.append(top)
            continue
        for root, subdirs, _files in os.walk(top):
            candidates = (os.path.join(root, name) for name in subdirs)
            found.extend(c for c in candidates if has_info(c))
    return found


def select_columns(record, pr):
    """Reduce *record* to the columns belonging to the print preset *pr*.

    :param record: dict mapping internal column names to values
    :param pr: preset name: 'all', 'reltime', 'abstime', 'more'; any other
               value selects the default column set
    :return: *record* itself for 'all', otherwise a new dict holding the
             preset's columns that are present in *record*, in preset order
    """
    if pr == 'all':
        return record

    presets = {
        'reltime': ['Path', 'WallTime', 'RelUserTime', 'RelSolverTime',
                    'RelCexCacheTime', 'RelForkTime', 'RelResolveTime'],
        'abstime': ['Path', 'WallTime', 'UserTime', 'SolverTime',
                    'CexCacheTime', 'ForkTime', 'ResolveTime'],
        # The number of active states is recorded under the internal name
        # 'NumStates'; the former 'States' entry never matched a column.
        'more': ['Path', 'Instructions', 'WallTime', 'ICov', 'BCov', 'ICount',
                 'RelSolverTime', 'NumStates', 'maxStates', 'MallocUsage', 'maxMem'],
    }
    default_columns = ['Path', 'Instructions', 'WallTime', 'ICov',
                       'BCov', 'ICount', 'RelSolverTime']

    wanted = presets.get(pr, default_columns)
    return {column: record[column] for column in wanted if column in record}


def add_artificial_columns(record):
    """Convert units and add derived columns to *record* in place.

    Times are converted from microseconds to seconds and MallocUsage from
    bytes to MiB. Depending on which input columns are present, the columns
    AvgQC, ICount, ICov, BCov and the "Rel"-prefixed relative times are added.
    A derived ratio is left out when its denominator is zero instead of
    raising ZeroDivisionError.

    :param record: dict mapping internal column names to numeric values
    :return: the same (modified) dict
    """
    # special case for straight-line code: report 100% branch coverage
    if record.get("NumBranches") == 0:
        record["FullBranches"] = 1
        record["NumBranches"] = 1

    # Convert recorded times from microseconds to seconds
    for key in ["UserTime", "WallTime", "QueryTime", "SolverTime", "CexCacheTime", "ForkTime", "ResolveTime"]:
        if key not in record:
            continue
        record[key] /= 1000000

    # Convert memory from byte to MiB
    if "MallocUsage" in record:
        record["MallocUsage"] /= (1024*1024)

    # Calculate avg. query construct
    if "NumQueryConstructs" in record and "NumQueries" in record:
        record["AvgQC"] = int(record["NumQueryConstructs"] / max(1, record["NumQueries"]))

    # Calculate total number of instructions
    if "CoveredInstructions" in record and "UncoveredInstructions" in record:
        record["ICount"] = (record["CoveredInstructions"] + record["UncoveredInstructions"])

    # Calculate relative instruction coverage (undefined without instructions)
    if "CoveredInstructions" in record and record.get("ICount"):
        record["ICov"] = 100 * record["CoveredInstructions"] / record["ICount"]

    # Calculate branch coverage
    if "FullBranches" in record and "PartialBranches" in record and "NumBranches" in record:
        record["BCov"] = 100 * ( 2 * record["FullBranches"] + record["PartialBranches"]) / ( 2 * record["NumBranches"])

    # Add relative times; they are undefined while no wall time has elapsed
    wall_time = record.get("WallTime")
    if wall_time:
        for key in ["SolverTime", "CexCacheTime", "ForkTime", "ResolveTime", "UserTime"]:
            if key in record:
                record["Rel"+key] = 100 * record[key] / wall_time

    return record


def grafana(dirs, host_address, port):
    """Serve the statistics of the first directory in *dirs* over HTTP.

    Starts a Flask application with three endpoints: '/' (health check),
    '/search' (column names of the stats table) and '/query' (time series of
    the requested columns). The request/response layout is presumably the one
    expected by a Grafana JSON datasource -- confirm against the datasource
    in use. Requires the third-party packages flask and dateutil.

    :param dirs: list of klee output directories; only dirs[0] is used
    :param host_address: address the web server binds to
    :param port: port the web server listens on
    :return: 0 once the server has stopped
    """
    dr = getLogFile(dirs[0])
    from flask import Flask, jsonify, request
    import datetime
    import re
    from dateutil import parser

    app = Flask(__name__)

    def getKleeStartTime():
        """Return the start time from the info file as a Unix timestamp."""
        with open(getInfoFile(dirs[0]), "r") as file:
            for line in file:
                m = re.match("Started: (.*)", line)
                if m:
                    dateString = m.group(1)
                    return parser.parse(dateString).timestamp()

        print("Error: Couldn't find klee's start time", file=sys.stderr)
        # Signal the failure through a non-zero exit status.
        sys.exit(1)

    def toEpoch(date_text):
        """Convert a '%Y-%m-%dT%H:%M:%S.%fZ' string to seconds since 1970-01-01."""
        dt = datetime.datetime.strptime(date_text, "%Y-%m-%dT%H:%M:%S.%fZ")
        epoch = (dt - datetime.datetime(1970, 1, 1)).total_seconds()
        return epoch

    @app.route('/')
    def status():
        return 'OK'

    @app.route('/search', methods=['GET', 'POST'])
    def search():
        conn = sqlite3.connect(dr)
        try:
            cursor = conn.execute('SELECT * FROM stats LIMIT 1')
            names = [description[0] for description in cursor.description]
        finally:
            conn.close()
        return jsonify(names)

    @app.route('/query', methods=['POST'])
    def query():
        jsn = request.get_json()
        interval = jsn["intervalMs"]
        limit = jsn["maxDataPoints"]
        frm = toEpoch(jsn["range"]["from"])
        to = toEpoch(jsn["range"]["to"])
        targets = [str(t["target"]) for t in jsn["targets"]]

        # One (initially empty) series per requested target.
        result = [ {"target": t, "datapoints": []} for t in targets ]
        # Column names cannot be bound as SQL parameters, so only purely
        # alphanumeric targets are put into the statement. The series list is
        # filtered the same way to keep query columns and series aligned.
        queried = [datastream for datastream in result if datastream["target"].isalnum()]
        if not queried:
            # Nothing valid to select; an empty field list would be bad SQL.
            return jsonify(result)

        startTime = getKleeStartTime()
        fromTime = frm - startTime if frm - startTime > 0 else 0
        toTime = to - startTime if to - startTime > fromTime else fromTime + 100
        #convert to microseconds
        startTime, fromTime, toTime = startTime*1000000, fromTime*1000000, toTime*1000000
        sqlTarget = ",".join(["AVG( {0} )".format(d["target"]) for d in queried])

        s = "SELECT WallTime + ? , {fields} " \
            + " FROM stats" \
            + " WHERE WallTime >= ? AND WallTime <= ?" \
            + " GROUP BY WallTime/? LIMIT ?"
        s = s.format(fields=sqlTarget) #can't use prepared statements for this one

        conn = sqlite3.connect(dr)
        try:
            #All times need to be in microseconds, interval is in milliseconds
            rows = conn.execute(s, ( startTime, fromTime, toTime, interval*1000, limit)).fetchall()
        finally:
            conn.close()

        for line in rows:
            unixtimestamp = int(line[0]) / 1000 #Convert from microsecond to milliseconds
            elapsed = line[0] - startTime
            for field, datastream in zip(line[1:], queried):
                name = datastream["target"]
                # Time columns other than wall/user time are reported as a
                # percentage of the elapsed time.
                if "Time" in name and "Wall" not in name and "User" not in name:
                    if field is None or elapsed == 0:
                        # Percentage is undefined: emit a null data point.
                        val = None
                    else:
                        val = (field/elapsed)*100
                    datastream["datapoints"].append([val, unixtimestamp])
                else:
                    datastream["datapoints"].append([field, unixtimestamp])

        ret = jsonify(result)
        return ret

    app.run(host=host_address, port=port)
    return 0


def write_csv(data):
    """Write every row of the first statistics database to stdout as CSV.

    The first output line holds the column names of the ``stats`` table.
    Only data[0] is exported; further entries are ignored.

    :param data: list of objects whose conn() method returns a SQLite
                 connection
    :raises SystemExit: with status 1 if *data* is empty
    """
    import csv
    if not data:
        print('Error: no run.stats file found', file=sys.stderr)
        sys.exit(1)

    connection = data[0].conn()
    try:
        sql3_cursor = connection.execute("SELECT * FROM stats")
        csv_out = csv.writer(sys.stdout)
        # write header
        csv_out.writerow([d[0] for d in sql3_cursor.description])
        # write data
        csv_out.writerows(sql3_cursor)
    finally:
        # Release the database handle even if writing fails.
        connection.close()


def rename_columns(row, name_mapping):
    """
    Rename the keys of a row in place according to a mapping.
    Keys without an entry in the mapping (or mapped to themselves) are kept.
    :param row: dict whose keys are column names
    :param name_mapping: dict from old column name to new column name
    :return: the same row object, updated
    """
    # Iterate over a snapshot because the dict is modified inside the loop.
    for old_name in tuple(row):
        target = name_mapping.get(old_name, old_name)
        if target != old_name:
            row[target] = row.pop(old_name)
    return row


def write_table(args, data, dirs, pr):
    """Print a table with one row per run and, for several runs, a total row.

    Requires the third-party package tabulate.

    :param args: parsed command-line arguments; only args.tableFormat is read
    :param data: list of objects providing aggregateRecords() and
                 getLastRecord(), one per run
    :param dirs: list of directory paths, paired by position with *data*
    :param pr: print preset name, passed on to select_columns
    """
    from tabulate import TableFormat, Line, DataRow, tabulate

    KleeTable = TableFormat(lineabove=Line("-", "-", "-", "-"),
                            linebelowheader=Line("-", "-", "-", "-"),
                            linebetweenrows=None,
                            linebelow=Line("-", "-", "-", "-"),
                            headerrow=DataRow("|", "|", "|"),
                            datarow=DataRow("|", "|", "|"),
                            padding=0,
                            with_header_hide=None)

    if len(data) > 1:
        dirs = stripCommonPathPrefix(dirs)
    # attach the stripped path
    data = list(zip(dirs, data))

    # build the main body of the table
    table = dict()
    for i, (path, records) in enumerate(data):
        stats = records.aggregateRecords()
        # Get raw row
        single_row = records.getLastRecord()
        if single_row is None:
            # Unreadable or empty statistics: still show the path.
            single_row = {}
        single_row['Path'] = path
        single_row.update(stats)

        # Extend row with additional entries
        single_row = add_artificial_columns(single_row)
        single_row = select_columns(single_row, pr)

        for key in single_row:
            # Get an existing column or add an empty one
            column = table.setdefault(key, [])

            # Before the value of row i is appended the column must hold
            # exactly i entries; pad it if earlier rows lacked this column.
            missing_entries = i - len(column)
            if missing_entries > 0:
                # Append empty entries for this column
                column.extend([None] * missing_entries)

            # add the value
            column.append(single_row[key])

    # Add missing entries if needed
    max_len = len(data)
    for c_name in table:
        c_length = len(table[c_name])
        if c_length < max_len:
            table[c_name].extend([None] * (max_len - c_length))

    # Rename columns
    name_mapping = dict()
    for entry in Legend:
        name_mapping[entry[2]] = entry[0]
    table = rename_columns(table, name_mapping)

    # Add a summary row
    if max_len > 1:
        # calculate the total
        for k in table:
            if k == "Path": # Skip path
                continue
            values = [e for e in table[k] if e is not None]
            # TODO: this is a bit bad but ... . In a nutshell, if the name of a column starts or ends with certain
            #  pattern change the summary function.
            if k.startswith("Avg") or k.endswith("(%)"):
                total = sum(values)/max_len
            elif k.startswith("Max"):
                # A column without any value has no maximum.
                total = max(values, default=None)
            else:
                total = sum(values)
            table[k].append(total)

        table['Path'].append('Total ({0})'.format(max_len))

    # Prepare the order of the header: start to order entries according to the order in legend and add the unknown entries at the end
    headers = ["Path"]
    available_headers = list(table.keys())
    for entry in Legend:
        l_name = entry[0]
        if l_name in available_headers:
            headers.append(l_name)
            available_headers.remove(l_name)
    headers += available_headers

    # Make sure we keep the correct order of the column entries
    final_table = collections.OrderedDict()
    for column in headers:
        final_table[column] = table[column]
    table = final_table

    # Output table
    if args.tableFormat != 'klee':
        print(tabulate(
            table, headers='keys',
            tablefmt=args.tableFormat,
            floatfmt='.{p}f'.format(p=2),
            numalign='right', stralign='center'))
    else:
        stream = tabulate(
            table, headers='keys',
            tablefmt=KleeTable,
            floatfmt='.{p}f'.format(p=2),
            numalign='right', stralign='center')
        # add a line separator before the total line
        if len(data) > 1:
            stream = stream.splitlines()
            stream.insert(-2, stream[-1])
            stream = '\n'.join(stream)
        print(stream)


def main():
    """Parse the command line and print statistics of klee output directories.

    Depending on the options this starts the grafana web server, writes CSV,
    or prints a summary table (which needs the tabulate package).

    :return: the result of grafana() when --grafana is given, otherwise None
    :raises SystemExit: with status 1 if no usable klee output directory is
                        found or tabulate is needed but not installed
    """
    tabulate_available = False
    epilog = ""

    # tabulate is optional: without it only --grafana and --to-csv work.
    try:
        from tabulate import tabulate, _table_formats
        epilog = 'LEGEND\n' + tabulate([(f[:2]) for f in Legend])

        tabulate_available = True
    except ImportError:
        pass

    parser = argparse.ArgumentParser(
        description='output statistics logged by klee',
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('dir', nargs='+', help='klee output directory')

    if tabulate_available:
        parser.add_argument('--table-format',
                              choices=['klee'] + list(_table_formats.keys()),
                              dest='tableFormat', default='klee',
                              help='Table format for the summary.')

    parser.add_argument('--to-csv',
                          action='store_true', dest='toCsv',
                          help='Output stats as comma-separated values (CSV)')
    parser.add_argument('--grafana',
                          action='store_true', dest='grafana',
                          help='Start a grafana web server')
    parser.add_argument('--grafana-host', dest='grafana_host',
                          help='IP address grafana web server should listen to',
                          default="127.0.0.1")
    parser.add_argument('--grafana-port', dest='grafana_port', type=int,
                          help='Port grafana web server should listen to',
                          default=5000)

    # argument group for controlling output verboseness
    pControl = parser.add_mutually_exclusive_group(required=False)
    pControl.add_argument('--print-all',
                          action='store_true', dest='pAll',
                          help='Print all available information.')
    pControl.add_argument('--print-rel-times',
                          action='store_true', dest='pRelTimes',
                          help='Print only values of measured times. '
                          'Values are relative to the measured system '
                          'execution time.')
    pControl.add_argument('--print-abs-times',
                          action='store_true', dest='pAbsTimes',
                          help='Print only values of measured times. '
                          'Absolute values (in seconds) are printed.')
    pControl.add_argument('--print-more',
                          action='store_true', dest='pMore',
                          help='Print extra information (needed when '
                          'monitoring an ongoing run).')

    args = parser.parse_args()

    # get print controls
    pr = 'NONE'
    if args.pAll:
        pr = 'all'
    elif args.pRelTimes:
        pr = 'reltime'
    elif args.pAbsTimes:
        pr = 'abstime'
    elif args.pMore:
        pr = 'more'

    dirs = getKleeOutDirs(args.dir)
    if len(dirs) == 0:
        print('no klee output dir found', file=sys.stderr)
        sys.exit(1)

    if args.grafana:
        return grafana(dirs, args.grafana_host, args.grafana_port)

    # Filter non-existing files, useful for star operations. The directory
    # list itself is filtered so that directories and statistics stay paired.
    dirs = [d for d in dirs if os.path.isfile(getLogFile(d))]
    if not dirs:
        print('no run.stats file found', file=sys.stderr)
        sys.exit(1)

    # one accessor per run.stats file
    data = [LazyEvalList(getLogFile(d)) for d in dirs]

    if args.toCsv:
        write_csv(data)
        return

    if tabulate_available:
        write_table(args, data, dirs, pr)
        return

    print('Error: Package "tabulate" required for table formatting. '
          'Please install it using "pip" or your package manager. '
          'You can still use --grafana and --to-csv without tabulate.',
          file=sys.stderr)
    sys.exit(1)



# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
