Source code for logpyle.runalyzer

#! /usr/bin/env python
"""
Runalyzer Functions
--------------------------------
.. autofunction:: is_gathered
"""

import code
import sqlite3

try:
    import readline
    import rlcompleter  # noqa: F401
    HAVE_READLINE = True
except ImportError:
    HAVE_READLINE = False


import logging

logger = logging.getLogger(__name__)

from dataclasses import dataclass
from itertools import product
from sqlite3 import Connection, Cursor
from typing import (Any, Callable, Dict, Generator, List, Optional, Sequence, Set,
                    Tuple, Type, Union)

from pytools import Table


@dataclass(frozen=True)
class PlotStyle:
    """Immutable pairing of a dash pattern and a color name for one plot line.

    Instances are used to fill the ``dashes`` and ``color`` keyword arguments
    of matplotlib's ``plot`` when automatic styling is enabled.
    """

    # Sequence of integers forwarded as the ``dashes`` plot keyword; the
    # empty tuple is one of the values used by ``PLOT_STYLES``.
    dashes: Tuple[int, ...]

    # Color name forwarded as the ``color`` plot keyword.
    color: str


# Every (dash pattern, color) combination, in order: all colors are cycled
# through for the first dash pattern before moving on to the next pattern.
PLOT_STYLES = [
        PlotStyle(*dash_and_color)
        for dash_and_color in product(
            [(), (12, 2), (4, 2),  (2, 2), (2, 8)],
            ["blue", "green", "red", "magenta", "cyan"],
            )]


class RunDB:
    """Wrapper around an SQLite connection with query and plotting helpers.

    Parameters
    ----------
    db
        SQLite connection to wrap. It is closed when this object is
        finalized.
    interactive
        If true, the plotting helpers call matplotlib's ``show()`` after
        drawing.
    """

    def __init__(self, db: Connection, interactive: bool) -> None:
        self.db = db
        self.interactive = interactive
        # (quantity, aggregator) pairs whose temporary table already exists.
        self.rank_agg_tables: Set[Tuple[str, Callable[..., Any]]] = set()

    def __del__(self) -> None:
        # __del__ also runs for partially constructed instances, in which
        # case the 'db' attribute may never have been assigned.
        db = getattr(self, "db", None)
        if db is not None:
            db.close()

    def q(self, qry: str, *extra_args: Any) -> Cursor:
        """Execute *qry* (after :meth:`mangle_sql`) and return the cursor.

        Parameters
        ----------
        qry
            SQL query text.
        extra_args
            Values bound to the query's parameter placeholders.
        """
        return self.db.execute(self.mangle_sql(qry), extra_args)

    def mangle_sql(self, qry: str) -> str:
        """Return *qry* unchanged; subclasses may override to rewrite it."""
        return qry

    def get_rank_agg_table(self, qty: str,
                           rank_aggregator: Callable[..., Any]) -> str:
        """Return the name of a temporary table aggregating *qty* over ranks.

        The table holds ``run_id``, ``step`` and the aggregated ``value``,
        grouped by ``(run_id, step)``. It is created (together with an index
        on ``(run_id, step)``) on first request and reused afterwards.

        Parameters
        ----------
        qty
            Name of the table holding the quantity.
        rank_aggregator
            Aggregate to apply to the ``value`` column.
            NOTE(review): this is interpolated into the SQL text and the
            table name, so in practice it has to be the *name* of an SQL
            aggregate function despite the callable annotation -- confirm.
        """
        tbl_name = f"rankagg_{rank_aggregator}_{qty}"

        if (qty, rank_aggregator) in self.rank_agg_tables:
            return tbl_name

        logger.info("Building temporary rank aggregation table %s.", tbl_name)

        # Table and function names cannot be bound as SQL parameters, hence
        # the string interpolation.
        self.db.execute(
                f"create temporary table {tbl_name} as "
                f"select run_id, step, {rank_aggregator}(value) as value "
                f"from {qty} group by run_id,step")
        self.db.execute(
                f"create index {tbl_name}_run_step on {tbl_name} (run_id,step)")
        self.rank_agg_tables.add((qty, rank_aggregator))
        return tbl_name

    def scatter_cursor(self, cursor: Cursor, labels: Optional[List[str]] = None,
                       *args: Any, **kwargs: Any) -> None:
        """Draw a scatter plot from the rows of *cursor*.

        The columns of the result set are passed, in order, as the leading
        positional arguments of ``matplotlib.pyplot.scatter``, followed by
        *args* and *kwargs*.

        Parameters
        ----------
        cursor
            Cursor whose rows are plotted.
        labels
            Optional two-element list with the x and y axis labels.

        Raises
        ------
        TypeError
            If *labels* is given but is not a list of two elements. This is
            checked before anything is drawn.
        """
        import matplotlib.pyplot as plt

        if labels is not None and not (
                isinstance(labels, list) and len(labels) == 2):
            raise TypeError("The 'labels' parameter must be a list with two"
                            " elements.")

        data_args = tuple(zip(*list(cursor)))
        plt.scatter(*(data_args + args), **kwargs)

        if labels is not None:
            plt.xlabel(labels[0])
            plt.ylabel(labels[1])

        if self.interactive:
            plt.show()

    def plot_cursor(self, cursor: Cursor, labels: Optional[List[str]] = None,
                    *args: Any, **kwargs: Any) -> None:
        """Draw line plots from the rows of *cursor*.

        A two-column result set is drawn as a single ``(x, y)`` line. With
        more than two columns, the rows are grouped with ``split_cursor`` on
        the values of the extra columns and one labelled line is drawn per
        group.

        Parameters
        ----------
        cursor
            Cursor whose rows are plotted.
        labels
            Optional two-element list with the x and y axis labels. Only
            used (and validated) for two-column result sets.
        args, kwargs
            Forwarded to ``matplotlib.pyplot.plot``, except for these
            keyword options, which are consumed here:

            ``auto_style`` (default ``True``)
                Fill in ``dashes`` and ``color`` from ``PLOT_STYLES`` unless
                the caller supplied them.
            ``small_legend`` (default ``True``)
                Multi-column case only: add a legend in a small font.
            ``format_label``
                Multi-column case only: callable mapping a list of
                ``(column, value)`` pairs to the label of a line.

        Raises
        ------
        TypeError
            If *labels* is invalid for a two-column result set.
        ValueError
            If the result set has fewer than two columns.
        """
        from matplotlib.pyplot import legend, plot, show

        def default_format_label(kv_pairs: Sequence[Tuple[str, Any]]) -> str:
            return " ".join(f"{column}:{value}"
                        for column, value in kv_pairs)

        # Consume our own options up front so that none of them is ever
        # forwarded to plot(), whichever branch is taken below.
        auto_style = kwargs.pop("auto_style", True)
        small_legend = kwargs.pop("small_legend", True)
        format_label = kwargs.pop("format_label", default_format_label)

        n_columns = len(cursor.description)

        if n_columns == 2:
            if labels is not None and not (
                    isinstance(labels, list) and len(labels) == 2):
                raise TypeError("The 'labels' parameter must be a list with two"
                                " elements.")

            if auto_style:
                style = PLOT_STYLES[0]
                # setdefault: explicit caller choices win, as in the
                # multi-column branch.
                kwargs.setdefault("dashes", style.dashes)
                kwargs.setdefault("color", style.color)

            x, y = list(zip(*list(cursor)))
            p = plot(x, y, *args, **kwargs)

            if labels is not None:
                p[0].axes.set_xlabel(labels[0])
                p[0].axes.set_ylabel(labels[1])

        elif n_columns > 2:
            extra_columns = [col[0] for col in cursor.description[2:]]

            for line_idx, (my_x, my_y, rest) in enumerate(split_cursor(cursor)):
                my_kwargs = kwargs.copy()
                if auto_style:
                    style = PLOT_STYLES[line_idx % len(PLOT_STYLES)]
                    my_kwargs.setdefault("dashes", style.dashes)
                    my_kwargs.setdefault("color", style.color)

                my_kwargs.setdefault("label",
                        format_label(list(zip(extra_columns, rest or ()))))

                # No 'hold' keyword: it was removed from matplotlib, and
                # successive plot() calls draw onto the same axes anyway.
                plot(my_x, my_y, *args, **my_kwargs)

            if small_legend:
                from matplotlib.font_manager import FontProperties

                # 'labelspacing' is the current name of the former
                # 'labelsep'; the former 'pad' keyword is no longer accepted
                # and is left at matplotlib's default.
                legend(prop=FontProperties(size=8), loc="best",
                        labelspacing=0)
        else:
            raise ValueError("invalid number of columns")

        if self.interactive:
            show()

    def print_cursor(self, cursor: Cursor) -> None:
        """Print the rows of *cursor* as a table with a header row."""
        print(table_from_cursor(cursor))


def split_cursor(cursor: Cursor) -> Generator[
        Tuple[List[Any], List[Any], Optional[Tuple[Any, ...]]], None, None]:
    """Group consecutive rows of *cursor* by their trailing columns.

    Each row is split into ``x`` (column 0), ``y`` (column 1) and "rest"
    (all remaining columns, as a tuple). Consecutive rows with an equal
    "rest" form one group.

    Parameters
    ----------
    cursor
        Iterable of rows, each with at least two columns.

    Yields
    ------
    tuple
        ``(x_values, y_values, rest)`` for each group, in row order. Every
        group gets its own freshly built lists, so results stay valid after
        the generator advances (e.g. ``list(split_cursor(cursor))``).
        Nothing is yielded for an empty cursor.
    """
    from itertools import groupby

    for rest, group in groupby(map(tuple, cursor), key=lambda row: row[2:]):
        rows = list(group)
        yield [row[0] for row in rows], [row[1] for row in rows], rest


def table_from_cursor(cursor: Cursor) -> Table:
    """Build a ``pytools.Table`` from *cursor*.

    The first row holds the column names taken from ``cursor.description``;
    every row produced by the cursor is appended after it.
    """
    result = Table()

    header = tuple(description[0] for description in cursor.description)
    result.add_row(header)

    for record in cursor:
        result.add_row(record)

    return result


class MagicRunDB(RunDB):
    """:class:`RunDB` variant that expands "magic" ``$quantity`` columns.

    In a query, ``$name`` is replaced by ``name.value AS name`` and
    ``$name.agg`` by ``agg_name.value AS name``, where the latter reads from
    a rank-aggregation table obtained from :meth:`get_rank_agg_table`. A
    matching ``from runs inner join ...`` clause is generated and inserted
    into the query.
    """

    def mangle_sql(self, qry: str) -> str:
        """Return *qry* with magic columns expanded and a FROM clause added.

        Queries that already contain ``FROM`` and no ``$$`` marker are
        returned unchanged. Otherwise the generated FROM clause replaces
        every ``$$`` marker if one is present; failing that, it is inserted
        before the first recognized clause keyword (or ``;``), or appended
        if there is none.
        """
        import re

        up_qry = qry.upper()
        if "FROM" in up_qry and "$$" not in up_qry:
            return qry

        # A dict serves as an insertion-ordered set, so the joins are
        # emitted in order of first appearance and the generated SQL is the
        # same from run to run (set iteration order is not).
        magic_columns: Dict[Tuple[str, Any], None] = {}

        # should be: re.Match[Any]
        def replace_magic_column(match: Any) -> str:
            qty_name = match.group(1)
            rank_aggregator = match.group(2)

            if rank_aggregator is not None:
                # strip the leading '.'
                rank_aggregator = rank_aggregator[1:]
                magic_columns[qty_name, rank_aggregator] = None
                return f"{rank_aggregator}_{qty_name}.value AS {qty_name}"

            magic_columns[qty_name, None] = None
            return f"{qty_name}.value AS {qty_name}"

        magic_column_re = re.compile(r"\$([a-zA-Z][A-Za-z0-9_]*)(\.[a-z]*)?")
        qry = magic_column_re.sub(replace_magic_column, qry)

        from_clause = "from runs "
        last_tbl = None
        for tbl, rank_aggregator in magic_columns:
            if rank_aggregator is not None:
                full_tbl = f"{rank_aggregator}_{tbl}"
                full_tbl_src = "{} as {}".format(
                        self.get_rank_agg_table(tbl, rank_aggregator),
                        full_tbl)

                # Rank-aggregated tables are joined on the step only.
                if last_tbl is not None:
                    addendum = f" and {last_tbl}.step = {full_tbl}.step"
                else:
                    addendum = ""
            else:
                full_tbl = tbl
                full_tbl_src = tbl

                if last_tbl is not None:
                    addendum = " and {}.step = {}.step and {}.rank={}.rank".format(
                            last_tbl, full_tbl, last_tbl, full_tbl)
                else:
                    addendum = ""

            from_clause += " inner join {} on ({}.run_id = runs.id{}) ".format(
                    full_tbl_src, full_tbl, addendum)
            last_tbl = full_tbl

        def get_clause_indices(qry: str) -> Dict[str, int]:
            keyword_clauses = ["UNION",  "INTERSECT", "EXCEPT", "WHERE", "GROUP",
                    "HAVING", "ORDER", "LIMIT"]

            result = {}
            up_qry = qry.upper()
            for clause in keyword_clauses:
                clause_match = re.search(r"\b%s\b" % clause, up_qry)
                if clause_match is not None:
                    result[clause] = clause_match.start()

            # ';' is not a word character, so a r"\b;\b" pattern would miss
            # a statement-terminating semicolon; search for it literally.
            semicolon_idx = up_qry.find(";")
            if semicolon_idx != -1:
                result[";"] = semicolon_idx

            return result

        # add 'from'
        if "$$" in qry:
            qry = qry.replace("$$", " %s " % from_clause)
        else:
            clause_indices = get_clause_indices(qry)

            if not clause_indices:
                qry = qry+" "+from_clause
            else:
                first_clause_idx = min(clause_indices.values())
                # The extra space keeps the clause from being glued to the
                # preceding token (e.g. when it is inserted right before ';').
                qry = (
                        qry[:first_clause_idx]
                        + " "
                        + from_clause
                        + qry[first_clause_idx:])

        return qry


def make_runalyzer_symbols(db: RunDB) \
        -> Dict[str, Union[RunDB, str, None, Callable[..., Any]]]:
    """Return the symbol table exposing *db* and its helpers by short names.

    Besides the ``__name__``/``__doc__`` entries, the mapping contains the
    database wrapper itself (``db``), its bound query, plot and print
    methods, and the module-level cursor helpers.
    """
    symbols: Dict[str, Union[RunDB, str, None, Callable[..., Any]]] = {
            "__name__": "__console__",
            "__doc__": None,
            "db": db,
            }

    # Bound methods of the wrapped database.
    symbols.update(
            mangle_sql=db.mangle_sql,
            q=db.q,
            dbplot=db.plot_cursor,
            dbscatter=db.scatter_cursor,
            dbprint=db.print_cursor,
            )

    # Free functions operating on cursors.
    symbols.update(
            split_cursor=split_cursor,
            table_from_cursor=table_from_cursor,
            )

    return symbols


class RunalyzerConsole(code.InteractiveConsole):
    """Interactive Python console with ``.command`` shortcuts for a run DB.

    The console namespace is seeded from ``make_runalyzer_symbols(db)``.
    Lines starting with ``.`` are dispatched to :meth:`execute_magic`;
    everything else is handled as regular Python input.

    Parameters
    ----------
    db
        The database wrapper the console operates on.
    """

    def __init__(self, db: RunDB) -> None:
        self.db = db
        code.InteractiveConsole.__init__(self,
                make_runalyzer_symbols(db))

        # numpy and matplotlib are optional conveniences: star-import them
        # into the console namespace only when they are usable.
        try:
            import numpy  # noqa: F401
            self.runsource("from numpy import *")
        except ImportError:
            pass

        try:
            import matplotlib.pyplot  # noqa
            self.runsource("from matplotlib.pyplot import *")
        except (ImportError, RuntimeError):
            pass

        if HAVE_READLINE:
            import atexit
            import os

            # expanduser() does not raise when HOME is unset, unlike a
            # direct os.environ["HOME"] lookup.
            histfile = os.path.join(os.path.expanduser("~"), ".runalyzerhist")
            if os.access(histfile, os.R_OK):
                readline.read_history_file(histfile)
            atexit.register(readline.write_history_file, histfile)
            readline.parse_and_bind("tab: complete")

        self.last_push_result = False

    def push(self, cmdline: str) -> bool:
        """Process one input line.

        Lines starting with ``.`` are run as magic commands; an exception
        raised by one is printed as a traceback instead of propagating. Any
        other line goes to ``code.InteractiveConsole.push``.

        Returns
        -------
        bool
            The result of the most recent regular (non-magic) push.
        """
        if not cmdline.startswith("."):
            self.last_push_result = code.InteractiveConsole.push(self, cmdline)
            return self.last_push_result

        try:
            self.execute_magic(cmdline)
        except Exception:
            import traceback
            traceback.print_exc()

        return self.last_push_result

    def execute_magic(self, cmdline: str) -> None:
        """Execute the magic command in *cmdline* (which starts with ``.``).

        The command name is the text between the leading ``.`` and the first
        space; everything after that space is its argument string.
        """
        cmd, _, args = cmdline[1:].partition(" ")

        if cmd == "help":
            print("""
Commands:
 .help        show this help message
 .q SQL       execute a (potentially mangled) query
 .constants   show a list of (constant) run properties
 .quantities  show a list of time-dependent quantities
 .warnings    show a list of warnings
 .logging     show a list of logging messages

Plotting:
 .plot SQL    plot results of (potentially mangled) query.
              result sets can be (x,y) or (x,y,descr1,descr2,...),
              in which case a new plot will be started for each
              tuple (descr1, descr2, ...)
 .scatter SQL make scatterplot results of (potentially mangled) query.
              result sets can have between two and four columns
              for (x,y,size,color).

SQL mangling, if requested ("MagicSQL"):
    select $quantity where pred(feature)

Custom SQLite aggregates:
    stddev, var, norm1, norm2

Available Python symbols:
    db: the SQLite database
    mangle_sql(query_str): mangle the SQL query string query_str
    q(query_str): get db cursor for mangled query_str
    dbplot(cursor): plot result of cursor
    dbscatter(cursor): make scatterplot result of cursor
    dbprint(cursor): print result of cursor
    split_cursor(cursor): x,y,data gather that .plot uses internally
    table_from_cursor(cursor): Create a printable table from a cursor
""")
        elif cmd == "q":
            self.db.print_cursor(self.db.q(args))

        elif cmd in ("runprops", "constants"):
            # The run properties are the columns of the 'runs' table.
            cursor = self.db.db.execute("select * from runs")
            for col in sorted(column[0] for column in cursor.description):
                print(col)
        elif cmd == "quantities":
            self.db.print_cursor(self.db.q("select * from quantities order by name"))
        elif cmd == "warnings":
            self.db.print_cursor(self.db.q("select * from warnings"))
        elif cmd == "logging":
            self.db.print_cursor(self.db.q("select * from logging"))
        elif cmd == "title":
            from pylab import title
            title(args)
        elif cmd == "plot":
            cursor = self.db.q(args)
            columnnames = [column[0] for column in cursor.description]
            self.db.plot_cursor(cursor, labels=columnnames)
        elif cmd == "scatter":
            cursor = self.db.q(args)
            columnnames = [column[0] for column in cursor.description]
            # Only the first two columns are axes; scatter_cursor rejects a
            # label list of any other length, which would otherwise break
            # the three- and four-column (size, color) result sets.
            self.db.scatter_cursor(cursor, labels=columnnames[:2])
        else:
            print("invalid magic command")


# {{{ custom aggregates

from pytools import VarianceAggregator  # noqa: E402


class Variance(VarianceAggregator):
    """``VarianceAggregator`` constructed with ``entire_pop=True``.

    The zero-argument constructor lets the class be instantiated without
    arguments, as ``sqlite3``'s ``create_aggregate`` does.
    NOTE(review): ``entire_pop=True`` presumably selects the population
    (rather than sample) variance -- confirm against pytools.
    """

    def __init__(self) -> None:
        super().__init__(entire_pop=True)  # type: ignore[no-untyped-call]


class StdDeviation(Variance):
    """Aggregate returning the square root of what :class:`Variance` yields."""

    def finalize(self) -> Optional[float]:
        """Return ``sqrt`` of the variance, or ``None`` if there is none."""
        from math import sqrt

        variance = super().finalize()  # type: ignore[no-untyped-call]
        return None if variance is None else sqrt(variance)


class Norm1:
    """Aggregate accumulating the sum of absolute values (the 1-norm)."""

    def __init__(self) -> None:
        # Running total of |value| over all steps seen so far.
        self.abs_sum = 0.0

    def step(self, value: float) -> None:
        """Add the magnitude of *value* to the running total."""
        magnitude = abs(value)
        self.abs_sum = self.abs_sum + magnitude

    def finalize(self) -> float:
        """Return the accumulated sum of absolute values."""
        return self.abs_sum


class Norm2:
    """Aggregate computing the square root of the sum of squares (2-norm)."""

    def __init__(self) -> None:
        # Running total of value**2 over all steps seen so far.
        self.square_sum = 0.0

    def step(self, value: float) -> None:
        """Add the square of *value* to the running total."""
        squared = value**2
        self.square_sum = self.square_sum + squared

    def finalize(self) -> float:
        """Return the square root of the accumulated sum of squares."""
        import math

        return math.sqrt(self.square_sum)


def my_sprintf(format: str, arg: str) -> str:
    """Apply printf-style ``%`` formatting of *arg* to the *format* string."""
    formatted = format % arg
    return formatted

# }}}


def is_gathered(conn: sqlite3.Connection) -> bool:
    """
    Returns whether a connection to an existing database has been gathered.

    A database counts as gathered when it contains a table named ``runs``.

    Parameters
    ----------
    conn
      SQLite3 connection object
    """
    # get a list of tables with the name of 'runs'
    res = list(conn.execute("""
                      SELECT name
                      FROM sqlite_master
                      WHERE type='table' AND name='runs'
                      """))
    # Sanity check only: at most one table of a given name is expected.
    assert len(res) <= 1
    return len(res) == 1
def auto_gather(filenames: List[str]) -> sqlite3.Connection:
    """Return a connection to a gathered database for *filenames*.

    If any of the files is already gathered (see ``is_gathered``), it must
    be the only file given, and a connection to it is returned. Otherwise
    the files are gathered into a new in-memory database.

    Parameters
    ----------
    filenames
        Paths of the database files to open.

    Raises
    ------
    Exception
        If more than one file is given and at least one of them is gathered.
    """
    # allow for creating ungathered files.
    # Check if database has been gathered, if not, create one in memory

    # until files have been checked, assume none have been gathered
    gathered = False
    # check if any of the provided files have been gathered
    # NOTE(review): sqlite3.connect() creates a missing file, so after this
    # loop every name in 'filenames' exists on disk -- confirm this is
    # intended, given the exists() filter further down.
    for f in filenames:
        probe = sqlite3.connect(f)
        try:
            if is_gathered(probe):
                gathered = True
        finally:
            # The probe connection is only needed for the check above.
            probe.close()

    if gathered:
        # gathered files should only have one file
        if len(filenames) > 1:
            raise Exception("Runalyzing multiple gathered files is not supported!!!")

        return sqlite3.connect(filenames[0])

    # create in memory database of files to be gathered
    from logpyle.runalyzer_gather import (FeatureGatherer, gather_multi_file,
                                          make_name_map, scan)
    print("Creating an in memory database from provided files")
    from os.path import exists
    infiles = [f for f in filenames if exists(f)]

    # list of run features as {name: sql_type}
    fg = FeatureGatherer(False, None)
    features, dbname_to_run_id = scan(fg, infiles)

    fmap = make_name_map("")
    qmap = make_name_map("")

    connection = gather_multi_file(":memory:", infiles, fmap, qmap, fg, features,
                                   dbname_to_run_id)
    return connection


# {{{ main program

def make_wrapped_db(
        filenames: List[str], interactive: bool,
        mangle: bool, gather: bool = True
        ) -> RunDB:
    """Open *filenames* and return the database wrapped in a ``RunDB``.

    The custom aggregates ``stddev``, ``var``, ``norm1`` and ``norm2`` and
    the SQL functions ``sprintf``, ``sqrt`` and ``pow`` are registered on
    the connection before it is wrapped.

    Parameters
    ----------
    filenames
        Paths of the database files to open.
    interactive
        Passed on to the wrapper.
    mangle
        If true, wrap the connection in ``MagicRunDB``, otherwise in
        ``RunDB``.
    gather
        If true, open the files through ``auto_gather``. Otherwise exactly
        one file must be given and it is opened directly.
    """
    if gather:
        db = auto_gather(filenames)
    else:
        assert len(filenames) == 1, \
            "Enable autogather to support multiple input files"
        db = sqlite3.connect(filenames[0])

    db.create_aggregate("stddev", 1, StdDeviation)  # type: ignore[arg-type]
    db.create_aggregate("var", 1, Variance)
    db.create_aggregate("norm1", 1, Norm1)  # type: ignore[arg-type]
    db.create_aggregate("norm2", 1, Norm2)  # type: ignore[arg-type]

    db.create_function("sprintf", 2, my_sprintf)
    from math import pow, sqrt
    db.create_function("sqrt", 1, sqrt)
    db.create_function("pow", 2, pow)

    if mangle:
        db_wrap_class: Type[RunDB] = MagicRunDB
    else:
        db_wrap_class = RunDB

    return db_wrap_class(db, interactive=interactive)

# }}}

# vim: foldmethod=marker