"""
Log Quantity Abstract Interfaces
--------------------------------
.. autoclass:: LogQuantity
.. autoclass:: PostLogQuantity
.. autoclass:: MultiLogQuantity
.. autoclass:: MultiPostLogQuantity
Log Manager
-----------
.. autoclass:: LogManager
.. autofunction:: add_run_info
Built-in Log General-Purpose Quantities
---------------------------------------
.. autoclass:: IntervalTimer
.. autoclass:: LogUpdateDuration
.. autoclass:: EventCounter
.. autoclass:: TimestepCounter
.. autoclass:: StepToStepDuration
.. autoclass:: TimestepDuration
.. autoclass:: InitTime
.. autoclass:: WallTime
.. autoclass:: ETA
.. autoclass:: MemoryHwm
.. autoclass:: GCStats
.. autofunction:: add_general_quantities
Built-in Log Simulation-Related Quantities
------------------------------------------
.. autoclass:: SimulationTime
.. autoclass:: Timestep
.. autofunction:: set_dt
.. autofunction:: add_simulation_quantities
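Example Usage
-------------
A rough usage sketch (the file name, step count, and computed ``dt``
are placeholders for application code)::

    from logpyle import (LogManager, add_general_quantities,
        add_run_info, add_simulation_quantities, set_dt)

    logmgr = LogManager("mylog.sqlite", "wu")  # "wu": add unique suffix
    add_run_info(logmgr)
    add_general_quantities(logmgr)
    add_simulation_quantities(logmgr)
    logmgr.add_watches(["step.max", "t_step.max"])

    for step in range(100):
        logmgr.tick_before()
        # ... do the actual computation for this time step ...
        set_dt(logmgr, 0.05)
        logmgr.tick_after()

    logmgr.close()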
Internal stuff that is only here because the documentation tool wants it
------------------------------------------------------------------------
.. autoclass:: _SubTimer
"""
__copyright__ = "Copyright (C) 2009-2013 Andreas Kloeckner"
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
import logpyle.version
__version__ = logpyle.version.VERSION_TEXT
import logging
import sys
logger = logging.getLogger(__name__)
from dataclasses import dataclass
from sqlite3 import Connection
from time import monotonic as time_monotonic
from typing import (TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List,
Optional, Sequence, TextIO, Tuple, Type, Union, cast)
from pymbolic.compiler import CompiledExpression # type: ignore[import-untyped]
from pymbolic.primitives import Expression # type: ignore[import-untyped]
from pytools.datatable import DataTable
if TYPE_CHECKING and not getattr(sys, "_BUILDING_SPHINX_DOCS", False):
import mpi4py
# {{{ abstract logging interface
class LogQuantity:
"""A source of a loggable scalar that is gathered at the start of each time step.
Quantity values are gathered in :meth:`LogManager.tick_before`.
.. automethod:: __init__
.. automethod:: tick
.. autoproperty:: default_aggregator
.. automethod:: __call__
"""
sort_weight = 0
def __init__(self, name: str, unit: Optional[str] = None,
description: Optional[str] = None) -> None:
"""Create a new quantity.
Parameters
----------
name
Quantity name.
unit
Quantity unit.
description
Quantity description.
"""
self.name = name
self.unit = unit
self.description = description
@property
def default_aggregator(self) -> None:
"""Default rank aggregation function."""
return None
def tick(self) -> None:
"""Perform updates required at every :class:`LogManager` tick."""
pass
def __call__(self) -> Any:
"""Return the current value of the diagnostic represented by this
:class:`LogQuantity` or None if no value is available.
This is only called if the invocation interval calls for it.
"""
raise NotImplementedError
class PostLogQuantity(LogQuantity):
"""A source of a loggable scalar that is gathered after each time step.
Quantity values are gathered in :meth:`LogManager.tick_after`.
.. automethod:: __init__
.. automethod:: tick
.. autoproperty:: default_aggregator
.. automethod:: __call__
.. automethod:: prepare_for_tick
"""
sort_weight = 0
def prepare_for_tick(self) -> None:
"""Perform (optional) update at :meth:`LogManager.tick_before`."""
pass
class MultiLogQuantity:
"""A source of a list of loggable scalars gathered at the start of each time
step.
Quantity values are gathered in :meth:`LogManager.tick_before`.
.. automethod:: __init__
.. automethod:: tick
.. autoproperty:: default_aggregators
.. automethod:: __call__
"""
sort_weight = 0
def __init__(self, names: List[str],
units: Optional[Sequence[Optional[str]]] = None,
descriptions: Optional[Sequence[Optional[str]]] = None) -> None:
"""Create a new quantity.
Parameters
----------
names
List of quantity names.
units
List of quantity units.
descriptions
List of quantity descriptions.
"""
self.names = names
if units is None:
self.units: Sequence[Optional[str]] = len(names) * [None]
else:
self.units = units
if descriptions is None:
self.descriptions: Sequence[Optional[str]] = len(names) * [None]
else:
self.descriptions = descriptions
@property
def default_aggregators(self) -> List[None]:
"""List of default aggregators."""
return [None] * len(self.names)
def tick(self) -> None:
"""Perform updates required at every :class:`LogManager` tick."""
pass
def __call__(self) -> Iterable[Optional[float]]:
"""Return an iterable of the current values of the diagnostic represented
by this :class:`MultiLogQuantity`.
This is only called if the invocation interval calls for it.
"""
raise NotImplementedError
class MultiPostLogQuantity(MultiLogQuantity, PostLogQuantity):
"""A source of a list of loggable scalars gathered after each time step.
Quantity values are gathered in :meth:`LogManager.tick_after`.
.. automethod:: __init__
.. automethod:: tick
.. autoproperty:: default_aggregators
.. automethod:: __call__
.. automethod:: prepare_for_tick
"""
pass
class DtConsumer:
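"""Mixin for quantities that are fed the current simulation timestep
via :meth:`set_dt`."""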
def __init__(self) -> None:
self.dt: Optional[float] = None
def set_dt(self, dt: Optional[float]) -> None:
self.dt = dt
class TimeTracker(DtConsumer):
def __init__(self, start: float = 0) -> None:
DtConsumer.__init__(self)
self.t = start
def tick(self) -> None:
self.t += cast(float, self.dt)
class SimulationLogQuantity(PostLogQuantity, DtConsumer):
"""A source of loggable scalars that needs to know the simulation timestep."""
def __init__(self, name: str, unit: Optional[str] = None,
description: Optional[str] = None) -> None:
PostLogQuantity.__init__(self, name, unit, description)
DtConsumer.__init__(self)
class PushLogQuantity(LogQuantity):
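"""A quantity whose value is supplied by the application via
:meth:`push_value`. Reading the value resets it to *None*; at most one
value may be pushed per tick."""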
def __init__(self, name: str, unit: Optional[str] = None,
description: Optional[str] = None) -> None:
LogQuantity.__init__(self, name, unit, description)
self.value: Optional[float] = None
def push_value(self, value: float) -> None:
if self.value is not None:
raise RuntimeError("can't push two values per cycle")
self.value = value
def __call__(self) -> Optional[float]:
v = self.value
self.value = None
return v
class CallableLogQuantityAdapter(LogQuantity):
"""Adapt a 0-ary callable as a :class:`LogQuantity`."""
def __init__(self, callable: Callable[[], float], name: str,
unit: Optional[str] = None, description: Optional[str] = None) \
-> None:
self.callable = callable
LogQuantity.__init__(self, name, unit, description)
def __call__(self) -> float:
return self.callable()
# }}}
# {{{ manager functionality
@dataclass(frozen=True)
class _GatherDescriptor:
quantity: LogQuantity
interval: int
@dataclass(frozen=True)
class _QuantityData:
unit: Optional[str]
description: Optional[str]
default_aggregator: Optional[Callable[..., Any]]
def _join_by_first_of_tuple(list_of_iterables: List[Iterable[Any]]) \
-> Generator[Tuple[int, List[Any]], None, None]:
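"""Given *list_of_iterables*, each yielding key-sorted ``(key, value)``
pairs, yield ``(key, [value, ...])`` for every key that is present in
all iterables (i.e., an inner join on the first tuple entry)."""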
loi = [iter(i) for i in list_of_iterables]
if not loi:
return
# every iterator must have >= 1 object
try:
key_vals = [next(it) for it in loi]
except StopIteration:
return
keys = [kv[0] for kv in key_vals]
values = [kv[1] for kv in key_vals]
target_key = max(keys)
force_advance = False
i = 0
while True:
while keys[i] < target_key or force_advance:
try:
new_key, new_value = next(loi[i])
except StopIteration:
return
assert keys[i] < new_key
keys[i] = new_key
values[i] = new_value
if new_key > target_key:
target_key = new_key
force_advance = False
i += 1
if i >= len(loi):
i = 0
if min(keys) == target_key:
yield target_key, values[:]
force_advance = True
def _get_unique_id() -> str:
from uuid import uuid1
return uuid1().hex
def _get_unique_suffix() -> str:
from datetime import datetime
return "-" + datetime.utcnow().strftime("%Y%m%d-%H%M%S")
def _set_up_schema(db_conn: Connection) -> int:
# initialize new database
db_conn.execute("""
create table quantities (
name text,
unit text,
description text,
default_aggregator blob)""")
db_conn.execute("""
create table constants (
name text,
value blob)""")
# schema_version < 2 is missing the 'rank' field.
# schema_version < 3 is missing the 'unixtime' field.
db_conn.execute("""
create table warnings (
rank integer,
step integer,
unixtime integer,
message text,
category text,
filename text,
lineno integer
)""")
# schema_version < 3 does not have the logging table
db_conn.execute("""
create table logging (
rank integer,
step integer,
unixtime integer,
level text,
message text,
filename text,
lineno integer
)""")
schema_version = 3
return schema_version
@dataclass
class _DependencyData:
name: str
qdat: _QuantityData
agg_func: Callable[..., Any]
varname: str
expr: Expression
nonlocal_agg: bool
table: Optional[DataTable] = None
@dataclass
class _WatchInfo:
parsed: Expression
expr: Expression
dep_data: List[_DependencyData]
compiled: CompiledExpression
unit: Optional[str]
format: str
@dataclass(frozen=True)
class _LogWarningInfo:
tick_count: int
time: float
message: str
category: str
filename: str
lineno: int
class LogManager:
"""A distributed-memory-capable diagnostic time-series logging facility.
It is meant to log data from a computation, with certain log quantities
available before a cycle, and certain other ones afterwards. A timeline of
invocations looks as follows::
tick_before()
compute...
tick_after()
tick_before()
compute...
tick_after()
...
In a time-dependent simulation, each group of :meth:`tick_before`/
:meth:`tick_after` calls captures data for a single time state,
namely that in which the data may have been *before* the "compute"
step. However, some data (such as the length of the timestep taken
in a time-adaptive method) may only be available *after* the completion
of the "compute..." stage, which is why :meth:`tick_after` exists.
A :class:`LogManager` logs any number of named time series of floats to
a file. Non-time-series data, in the form of constants, is also
supported and saved.
If MPI parallelism is used, the "head rank" below always refers to
rank 0.
A command line tool called :command:`runalyzer` is available for
examining the data in a saved log.
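Data can also be read back programmatically; a small sketch (the file
name and quantity name are placeholders)::

    logmgr = LogManager("mylog.sqlite", "r")
    descr, unit, data = logmgr.get_expr_dataset("t_step.max")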
.. automethod:: __init__
.. automethod:: save
.. automethod:: close
.. rubric:: Data retrieval
.. automethod:: get_table
.. automethod:: get_warnings
.. automethod:: get_logging
.. automethod:: get_expr_dataset
.. automethod:: get_joint_dataset
.. rubric:: Configuration
.. automethod:: capture_warnings
.. automethod:: capture_logging
.. automethod:: add_watches
.. automethod:: set_watch_interval
.. automethod:: set_constant
.. automethod:: add_quantity
.. automethod:: enable_save_on_sigterm
.. rubric:: Time Loop
.. automethod:: tick_before
.. automethod:: tick_after
"""
def __init__(self, filename: Optional[str] = None, mode: str = "r",
mpi_comm: Optional["mpi4py.MPI.Comm"] = None,
capture_warnings: bool = True,
watch_interval: float = 1.0,
capture_logging: bool = True) -> None:
"""Initialize this log manager instance.
:arg filename: If given, the filename to which this log is bound.
If this database exists, the current state is loaded from it.
:arg mode: One of "w", "r" for write, read. "w" assumes that the
database is initially empty. May also be "wu" to indicate that
a unique filename should be chosen automatically. May also be "wo"
to indicate that the file should be overwritten.
:arg mpi_comm: An optional :class:`mpi4py.MPI.Comm` object.
If given, logs are periodically synchronized to the head node,
which then writes them out to disk.
:arg capture_warnings: Tap the Python :mod:`warnings` facility and save
warnings to the log file. Note that when multiple :class:`LogManager`
instances have warnings capture enabled, the warnings will be saved
to all instances.
:arg watch_interval: Print watches every *watch_interval* seconds.
:arg capture_logging: Tap the Python :mod:`logging` facility and save
logging messages to the log file. Note that when multiple
:class:`LogManager` instances have logging capture enabled, the
logging messages will be saved to all instances.
"""
assert isinstance(mode, str), "mode must be a string"
assert mode in ["w", "r", "wu", "wo"], "invalid mode"
self.quantity_data: Dict[str, _QuantityData] = {}
self.last_values: Dict[str, Optional[float]] = {}
self.before_gather_descriptors: List[_GatherDescriptor] = []
self.after_gather_descriptors: List[_GatherDescriptor] = []
self.tick_count = 0
self.constants: Dict[str, object] = {}
self.last_save_time = time_monotonic()
# self-timing
self.start_time = time_monotonic()
self.t_log: float = 0
# parallel support
self.head_rank = 0
self.mpi_comm = mpi_comm
self.is_parallel = mpi_comm is not None
if mpi_comm is None:
self.rank = 0
else:
self.rank = mpi_comm.rank
self.head_rank = 0
# weakref finalization
self.weakref_finalize: Callable[..., Any] = lambda: None
# watch stuff
self.watches: List[_WatchInfo] = []
self.have_nonlocal_watches = False
# Interval between printing watches, in seconds
self.set_watch_interval(watch_interval)
# database binding
import sqlite3 as sqlite
self.sqlite_filename: Optional[str] = None
if filename is None:
file_base = ":memory:"
file_extension = ""
else:
import os
file_base, file_extension = os.path.splitext(filename)
if self.is_parallel:
file_base += "-rank%d" % self.rank
while True:
suffix = ""
if mode == "wu" and not file_base == ":memory:":
if self.is_parallel:
assert self.mpi_comm
suffix = self.mpi_comm.bcast(_get_unique_suffix(),
root=self.head_rank)
else:
suffix = _get_unique_suffix()
filename = file_base + suffix + file_extension
if not file_base == ":memory:":
self.sqlite_filename = filename
if mode == "wo":
import os
try:
os.remove(filename)
except OSError:
pass
self.db_conn = sqlite.connect(filename, timeout=30)
self.mode = mode
try:
self.db_conn.execute("select * from quantities;")
except sqlite.OperationalError:
# we're building a new database
if mode == "r":
raise RuntimeError("Log database '%s' not found" % filename)
self.schema_version = _set_up_schema(self.db_conn)
self.set_constant("schema_version", self.schema_version)
self.set_constant("is_parallel", self.is_parallel)
# set globally unique run_id
if self.is_parallel:
assert self.mpi_comm
self.set_constant("unique_run_id",
self.mpi_comm.bcast(_get_unique_id(),
root=self.head_rank))
else:
self.set_constant("unique_run_id", _get_unique_id())
if self.is_parallel:
assert self.mpi_comm
self.set_constant("rank_count", self.mpi_comm.Get_size())
else:
self.set_constant("rank_count", 1)
else:
# we've opened an existing database
if mode == "w":
raise RuntimeError("Log database '%s' already exists" % filename)
if mode == "wu":
# try again with a new suffix
continue
if mode == "wo":
# try again, someone might have created a file with the same name
continue
self._load()
break
# {{{ warnings/logging capture
self.warning_data: List[_LogWarningInfo] = []
self.old_showwarning: Optional[Callable[..., Any]] = None
if capture_warnings and self.mode[0] == "w":
self.capture_warnings(True)
self.logging_data: List[_LogWarningInfo] = []
self.logging_handler: Optional[logging.Handler] = None
if capture_logging and self.mode[0] == "w":
self.capture_logging(True)
# }}}
# {{{ atexit handling
import weakref
# Make sure the database gets saved at exit.
# Note that this does not handle all possible exit modes:
# - SIGINT (i.e., Ctrl-C): automatically handled
# - SIGKILL (i.e., kill -9), os._exit(), Python fatal internal error:
# impossible to capture
# - SIGTERM (i.e., kill): Users must handle the signal explicitly
# (e.g. via 'logmgr.enable_save_on_sigterm()')
self.weakref_finalize = weakref.finalize(self, self.save)
# FIXME: The weakref keeps the log manager alive until close() is
# called or the application exits.
# }}}
def __del__(self) -> None:
self.weakref_finalize()
def enable_save_on_sigterm(self) -> Union[Callable[..., Any], int, None]:
"""Enable saving the log on SIGTERM.
:returns: The previous SIGTERM handler.
"""
# See
# https://mail.python.org/pipermail/python-ideas/2016-February/038471.html
# on why this only captures SIGTERM.
import signal
def sighndl(_signo: int, _stackframe: Any) -> None:
self.weakref_finalize()
sys.exit(_signo)
return signal.signal(signal.SIGTERM, sighndl)
def capture_warnings(self, enable: bool = True) -> None:
"""Enable or disable :mod:`warnings` capture."""
def _showwarning(message: Union[Warning, str], category: Type[Warning],
filename: str, lineno: int, file: Optional[TextIO] = None,
line: Optional[str] = None) -> None:
assert self.old_showwarning
self.old_showwarning(message, category, filename, lineno, file, line)
from time import time
self.warning_data.append(_LogWarningInfo(
tick_count=self.tick_count,
time=time(),
message=str(message),
category=str(category),
filename=filename,
lineno=lineno
))
import warnings
if enable:
if self.schema_version < 3:
raise ValueError("Warnings capture needs at least schema_version 3, "
f" got {self.schema_version}")
if self.old_showwarning is None:
self.old_showwarning = warnings.showwarning
warnings.showwarning = _showwarning
else:
from warnings import warn
warn("Warnings capture already enabled")
else:
if self.old_showwarning is None:
from warnings import warn
warn("Warnings capture already disabled")
else:
warnings.showwarning = self.old_showwarning
self.old_showwarning = None
def capture_logging(self, enable: bool = True) -> None:
"""Enable or disable :mod:`logging` capture."""
class LogpyleLogHandler(logging.Handler):
def __init__(self, mgr: LogManager) -> None:
logging.Handler.__init__(self)
self.mgr = mgr
def emit(self, record: logging.LogRecord) -> None:
from time import time
self.mgr.logging_data.append(
_LogWarningInfo(tick_count=self.mgr.tick_count,
time=time(),
message=record.getMessage(),
category=record.levelname,
filename=record.pathname,
lineno=record.lineno))
root_logger = logging.getLogger()
if enable:
if self.schema_version < 3:
raise ValueError("Logging capture needs at least schema_version 3, "
f" got {self.schema_version}")
if self.mode[0] == "w" and self.logging_handler is None:
self.logging_handler = LogpyleLogHandler(self)
root_logger.addHandler(self.logging_handler)
elif self.logging_handler:
from warnings import warn
warn("Logging capture already enabled")
else:
if self.logging_handler:
root_logger.removeHandler(self.logging_handler)
elif self.logging_handler is None:
from warnings import warn
warn("Logging capture already disabled")
self.logging_handler = None
def get_logging(self) -> DataTable:
"""Return a :class:`~pytools.datatable.DataTable` of :mod:`logging`
messages logged by this :class:`LogManager` instance."""
# Match the table set up by _set_up_schema
columns = ["rank", "step", "unixtime", "level", "message", "filename",
"lineno"]
result = DataTable(columns)
if self.schema_version < 3:
from warnings import warn
warn("This database lacks a 'logging' table")
return result
for row in self.db_conn.execute(
"select %s from logging" % (", ".join(columns))):
result.insert_row(row)
return result
def _load(self) -> None:
if self.mpi_comm and self.mpi_comm.rank != self.head_rank:
return
from pickle import loads
for name, value in self.db_conn.execute("select name, value from constants"):
self.constants[name] = loads(value)
self.schema_version = cast(int, self.constants.get("schema_version", 0))
self.is_parallel = bool(self.constants["is_parallel"])
for name, unit, description, def_agg in self.db_conn.execute(
"select name, unit, description, default_aggregator "
"from quantities"):
self.quantity_data[name] = _QuantityData(
unit, description, loads(def_agg))
def close(self) -> None:
"""Close this :class:`LogManager` instance."""
if self.old_showwarning is not None:
self.capture_warnings(False)
if self.logging_handler:
self.capture_logging(False)
self.weakref_finalize()
self.save()
self.db_conn.close()
def get_table(self, q_name: str) -> DataTable:
"""Return a :class:`~pytools.datatable.DataTable` of the data logged
for the quantity *q_name*."""
if q_name not in self.quantity_data:
raise KeyError("invalid quantity name '%s'" % q_name)
result = DataTable(
["step", "rank", "value"])
for row in self.db_conn.execute(
"select step, rank, value from %s" % q_name):
result.insert_row(row)
return result
def get_warnings(self) -> DataTable:
"""Return a :class:`~pytools.datatable.DataTable` of warnings logged by
this :class:`LogManager` instance."""
# Match the table set up by _set_up_schema
columns = ["step", "message", "category", "filename", "lineno"]
if self.schema_version >= 2:
columns.insert(0, "rank")
if self.schema_version >= 3:
columns.insert(2, "unixtime")
result = DataTable(columns)
for row in self.db_conn.execute(
"select %s from warnings" % (", ".join(columns))):
result.insert_row(row)
return result
def add_watches(self, watches: List[Union[str, Tuple[str, str]]]) -> None:
"""Add quantities that are printed after every time step.
:arg watches:
List of expressions to watch. Each element can either be
a string of the expression to watch, or a tuple of the expression
and a format string. In the format string, you can use the custom
fields ``{display}``, ``{value}``, and ``{unit}`` to indicate where the
watch expression, value, and unit should be printed. The default format
string for each watch is ``{display}={value:g}{unit} | ``.
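A small sketch (the watched quantities assume
:func:`add_general_quantities` was called)::

    logmgr.add_watches([
        "step.max",
        ("t_step.max", "t_step={value:g}{unit} | "),
    ])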
"""
default_format = "{display}={value:g}{unit} | "
for watch in watches:
if isinstance(watch, tuple):
expr, fmt = watch
else:
expr = watch
fmt = default_format
parsed = self._parse_expr(expr)
parsed, dep_data = self._get_expr_dep_data(parsed)
if len(dep_data) == 1:
unit = dep_data[0].qdat.unit
else:
unit = None
from pytools import any
self.have_nonlocal_watches = self.have_nonlocal_watches or \
any(dd.nonlocal_agg for dd in dep_data)
from pymbolic import compile # type: ignore[import-untyped]
compiled = compile(parsed, [dd.varname for dd in dep_data])
watch_info = _WatchInfo(parsed=parsed, expr=expr, dep_data=dep_data,
compiled=compiled, unit=unit, format=fmt)
self.watches.append(watch_info)
def set_watch_interval(self, interval: float) -> None:
"""Set the interval (in seconds) between the time watches are printed.
:arg interval: watch printing interval in seconds.
"""
self.watch_interval = interval
self.next_watch_tick = self.tick_count + 1
def set_constant(self, name: str, value: Any) -> None:
"""Make a named, constant value available in the log.
:arg name: the name of the constant.
:arg value: the value of the constant.
"""
existed = name in self.constants
self.constants[name] = value
from pickle import dumps
value = bytes(dumps(value))
if existed:
self.db_conn.execute("update constants set value = ? where name = ?",
(value, name))
else:
self.db_conn.execute("insert into constants values (?,?)",
(name, value))
def _insert_datapoint(self, name: str, value: Optional[float]) -> None:
if value is None:
return
self.last_values[name] = value
try:
self.db_conn.execute("insert into %s values (?,?,?)" % name,
(self.tick_count, self.rank, float(value)))
except Exception:
print("while adding datapoint for '%s':" % name)
raise
def _update_t_log(self, name: str, value: Optional[float]) -> None:
if value is None:
return
self.last_values[name] = value
try:
self.db_conn.execute(f"update {name} set value = {float(value)} \
where rank = {self.rank} and step = {self.tick_count}")
except Exception:
print("while adding datapoint for '%s':" % name)
raise
def _gather_for_descriptor(self, gd: _GatherDescriptor) -> None:
if self.tick_count % gd.interval == 0:
q_value = gd.quantity()
if isinstance(gd.quantity, MultiLogQuantity):
for name, value in zip(gd.quantity.names, q_value):
self._insert_datapoint(name, value)
else:
self._insert_datapoint(gd.quantity.name, q_value)
def tick_before(self) -> None:
"""Record data points from each added :class:`LogQuantity` that
is not an instance of :class:`PostLogQuantity`. Also, invoke
:meth:`PostLogQuantity.prepare_for_tick` on :class:`PostLogQuantity`
instances.
"""
tick_start_time = time_monotonic()
for gd in self.before_gather_descriptors:
self._gather_for_descriptor(gd)
for gd in self.after_gather_descriptors:
cast(PostLogQuantity, gd.quantity).prepare_for_tick()
# For the first three ticks, force saving the log.
if self.tick_count < 3:
self.save()
self.t_log = time_monotonic() - tick_start_time
def tick_after(self) -> None:
"""Record data points from each added :class:`LogQuantity` that
is an instance of :class:`PostLogQuantity`.
May also checkpoint data to disk.
"""
tick_start_time = time_monotonic()
for gd_lst in [self.before_gather_descriptors,
self.after_gather_descriptors]:
for gd in gd_lst:
gd.quantity.tick()
for gd in self.after_gather_descriptors:
self._gather_for_descriptor(gd)
save_interval_seconds = 10
if tick_start_time > self.last_save_time + save_interval_seconds:
self.save()
# print watches
if self.tick_count+1 >= self.next_watch_tick:
self._watch_tick()
self.t_log += time_monotonic() - tick_start_time
# Adjust log update time(s), t_log
for gd in self.after_gather_descriptors:
if isinstance(gd.quantity, LogUpdateDuration):
self._update_t_log(gd.quantity.name, gd.quantity())
self.tick_count += 1
def _save_logging(self) -> None:
for log in self.logging_data:
self.db_conn.execute(
"insert into logging values (?,?,?,?,?,?,?)",
(self.rank, log.tick_count, log.time,
log.category, log.message, log.filename,
log.lineno))
self.logging_data = []
def _save_warnings(self) -> None:
for w in self.warning_data:
self.db_conn.execute(
"insert into warnings values (?,?,?,?,?,?,?)",
(self.rank, w.tick_count, w.time, w.message,
w.category, w.filename, w.lineno))
self.warning_data = []
def save(self) -> None:
"""Commit the current state of the log."""
if self.mode[0] != "w":
# No need to save readonly files.
return
self._save_logging()
self._save_warnings()
from sqlite3 import OperationalError
try:
self.db_conn.commit()
except OperationalError as e:
# Even when encountering a commit error, we want to continue
# running the application.
from warnings import warn
warn("encountered sqlite error during commit: %s" % e)
self.last_save_time = time_monotonic()
def add_quantity(self, quantity: LogQuantity, interval: int = 1) -> None:
"""Add a :class:`LogQuantity` to this manager.
:arg quantity: add the specified :class:`LogQuantity`.
:arg interval: interval (in time steps) when to gather this quantity.
"""
def add_internal(name: str, unit: Optional[str], description: Optional[str],
def_agg: Optional[Callable[..., Any]]) -> None:
logger.debug("add log quantity '%s'" % name)
if name in self.quantity_data:
raise RuntimeError("cannot add the same quantity '%s' twice" % name)
self.quantity_data[name] = _QuantityData(unit, description, def_agg)
from pickle import dumps
self.db_conn.execute("""insert into quantities values (?,?,?,?)""", (
name, unit, description,
bytes(dumps(def_agg))))
self.db_conn.execute("""create table %s
(step integer, rank integer, value real)""" % name)
gd = _GatherDescriptor(quantity, interval)
if isinstance(quantity, PostLogQuantity):
gd_list = self.after_gather_descriptors
else:
gd_list = self.before_gather_descriptors
gd_list.append(gd)
gd_list.sort(key=lambda gd: gd.quantity.sort_weight)
if isinstance(quantity, MultiLogQuantity):
for name, unit, description, def_agg in zip(
quantity.names,
quantity.units,
quantity.descriptions,
quantity.default_aggregators):
add_internal(name, unit, description, def_agg)
else:
add_internal(quantity.name,
quantity.unit, quantity.description,
quantity.default_aggregator)
self.save()
def get_expr_dataset(self, expression: Expression,
description: Optional[str] = None,
unit: Optional[str] = None) \
-> Tuple[Union[str, Any], Union[str, Any, None],
List[Tuple[int, Any]]]:
"""Prepare a time-series dataset for a given expression.
:arg expression: A :mod:`pymbolic` expression that may involve
the time-series variables and the constants in this :class:`LogManager`.
If there is data from multiple ranks for a quantity occurring in
this expression, an aggregator may have to be specified.
:returns: ``(description, unit, table)``, where *table*
is a list of tuples ``(tick_nbr, value)``.
Aggregators are specified as follows:
- ``qty.min``, ``qty.max``, ``qty.avg``, ``qty.sum``, ``qty.norm2``,
``qty.median``
- ``qty[rank_nbr]``
- ``qty.loc``
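For example, with a quantity ``t_step`` gathered on multiple ranks
(a sketch; the names are illustrative)::

    descr, unit, data = logmgr.get_expr_dataset("t_step.max")
    steps = [step for step, _value in data]
    values = [value for _step, value in data]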
"""
parsed = self._parse_expr(expression)
parsed, dep_data = self._get_expr_dep_data(parsed)
# aggregate table data
for dd in dep_data:
table = self.get_table(dd.name)
table.sort(["step"])
dd.table = table.aggregated(["step"], # type: ignore
"value", dd.agg_func).data
# evaluate unit and description, if necessary
if unit is None:
from pymbolic import parse, substitute
unit_dict = {dd.varname: dd.qdat.unit for dd in dep_data}
from pytools import all
if all(v is not None for v in unit_dict.values()):
unit_dict = {k: parse(v) for k, v in unit_dict.items()}
unit = substitute(parsed, unit_dict)
else:
unit = None
if description is None:
description = expression
# compile and evaluate
from pymbolic import compile
compiled = compile(parsed, [dd.varname for dd in dep_data])
data = []
for key, values in _join_by_first_of_tuple(
[dd.table for dd in dep_data if dd.table]):
try:
data.append((key, compiled(*values)))
except ZeroDivisionError:
pass
return (description, unit, data)
def get_joint_dataset(self, expressions: Sequence[Expression]) -> List[Any]:
"""Return a joint data set for a list of expressions.
:arg expressions: a list of either strings representing
expressions directly, or triples (descr, unit, expr).
In the former case, the description and the unit are
found automatically, if possible. In the latter case,
they are used as specified.
:returns: A triple ``(descriptions, units, table)``, where
*table* is a list of ``[(tstep, (val_expr1, val_expr2, ...)), ...]``.
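A small sketch (assuming the quantities added by
:func:`add_simulation_quantities`)::

    descriptions, units, table = logmgr.get_joint_dataset(
        ["t_sim.max", ("dt", "s", "dt.max")])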
"""
# dubs is a list of (desc, unit, table) triples as
# returned by get_expr_dataset
dubs = []
for expr in expressions:
if isinstance(expr, str):
dub = self.get_expr_dataset(expr)
else:
expr_descr, expr_unit, expr_str = expr
dub = self.get_expr_dataset(
expr_str,
description=expr_descr,
unit=expr_unit)
dubs.append(dub)
zipped_dubs = list(zip(*dubs))
zipped_dubs[2] = list(
_join_by_first_of_tuple(zipped_dubs[2]))
return zipped_dubs
def get_plot_data(self, expr_x: Expression, expr_y: Expression,
min_step: Optional[int] = None,
max_step: Optional[int] = None) \
-> Tuple[Tuple[Any, str, str], Tuple[Any, str, str]]:
"""Generate plot-ready data.
:returns: ``(data_x, descr_x, unit_x), (data_y, descr_y, unit_y)``
"""
(descr_x, descr_y), (unit_x, unit_y), data = \
self.get_joint_dataset([expr_x, expr_y])
if min_step is not None:
data = [(step, tup) for step, tup in data if min_step <= step]
if max_step is not None:
data = [(step, tup) for step, tup in data if step <= max_step]
stepless_data = [tup for _step, tup in data]
if stepless_data:
data_x, data_y = list(zip(*stepless_data))
else:
data_x = ()
data_y = ()
return (data_x, descr_x, unit_x), \
(data_y, descr_y, unit_y)
def write_datafile(self, filename: str, expr_x: Expression,
expr_y: Expression) -> None:
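"""Write a tab-separated data file of *expr_x* vs. *expr_y* to *filename*."""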
(data_x, label_x, _), (data_y, label_y, _) = self.get_plot_data(
expr_x, expr_y)
with open(filename, "w") as outf:
outf.write(f"# {label_x} vs. {label_y}\n")
for dx, dy in zip(data_x, data_y):
outf.write(f"{dx!r}\t{dy!r}\n")
def plot_matplotlib(self, expr_x: Expression, expr_y: Expression) -> None:
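"""Plot *expr_y* against *expr_x* using :mod:`matplotlib`."""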
from matplotlib.pyplot import plot, xlabel, ylabel
(data_x, descr_x, unit_x), (data_y, descr_y, unit_y) = \
self.get_plot_data(expr_x, expr_y)
xlabel(f"{descr_x} [{unit_x}]")
ylabel(f"{descr_y} [{unit_y}]")
plot(data_x, data_y)
# {{{ private functionality
def _parse_expr(self, expr: Expression) -> Any:
from pymbolic import parse, substitute
parsed = parse(expr)
# substitute in global constants
parsed = substitute(parsed, self.constants)
return parsed
def _get_expr_dep_data(self, parsed: Expression) \
-> Tuple[Expression, List[_DependencyData]]:
class Nth:
def __init__(self, n: int) -> None:
self.n = n
def __call__(self, lst: List[Any]) -> Any:
return lst[self.n]
import pymbolic.mapper.dependency as pmd # type: ignore[import-untyped]
deps = pmd.DependencyMapper(include_calls=False)(parsed)
# gather information on aggregation expressions
dep_data = []
from pymbolic.primitives import Lookup, Subscript, Variable
for dep_idx, dep in enumerate(deps):
nonlocal_agg = True
if isinstance(dep, Variable):
name = dep.name
if name == "math":
continue
agg_func = self.quantity_data[name].default_aggregator
if agg_func is None:
if self.is_parallel:
raise ValueError(
"must specify explicit aggregator for '%s'" % name)
agg_func = lambda lst: lst[0]
elif isinstance(dep, Lookup):
assert isinstance(dep.aggregate, Variable)
name = dep.aggregate.name
agg_name = dep.name
if agg_name == "loc":
agg_func = Nth(self.rank)
nonlocal_agg = False
elif agg_name == "min":
agg_func = min
elif agg_name == "max":
agg_func = max
elif agg_name == "avg":
try:
from statistics import fmean
agg_func = fmean
except ImportError:
# fmean is Python 3.8+ only
from statistics import mean
agg_func = mean
elif agg_name == "median":
from statistics import median
agg_func = median
elif agg_name == "sum":
agg_func = sum
elif agg_name == "norm2":
from math import sqrt
agg_func = lambda iterable: sqrt(
sum(entry**2 for entry in iterable))
else:
raise ValueError("invalid rank aggregator '%s'" % agg_name)
elif isinstance(dep, Subscript):
assert isinstance(dep.aggregate, Variable)
name = dep.aggregate.name
from pymbolic import evaluate
agg_func = Nth(evaluate(dep.index))
qdat = self.quantity_data[name]
assert agg_func
this_dep_data = _DependencyData(name=name, qdat=qdat, agg_func=agg_func,
varname="logvar%d" % dep_idx, expr=dep,
nonlocal_agg=nonlocal_agg)
dep_data.append(this_dep_data)
# substitute in the "logvar" variable names
from pymbolic import substitute, var
parsed = substitute(parsed,
{dd.expr: var(dd.varname) for dd in dep_data})
return parsed, dep_data
def _calculate_next_watch_tick(self) -> None:
ticks_per_interval = (self.tick_count
/ max(1, time_monotonic()-self.start_time)
* self.watch_interval)
self.next_watch_tick = self.tick_count + int(max(1, ticks_per_interval))
def _watch_tick(self) -> None:
"""Print the watches after a tick."""
if not self.have_nonlocal_watches and self.rank != self.head_rank:
return
data_block = {qname: self.last_values.get(qname, 0)
for qname in self.quantity_data.keys()}
if self.mpi_comm is not None and self.have_nonlocal_watches:
gathered_data = self.mpi_comm.gather(data_block, self.head_rank)
else:
gathered_data = [data_block]
if self.rank == self.head_rank:
assert gathered_data
values: Dict[str, List[Optional[float]]] = {}
for data_block in gathered_data:
for name, value in data_block.items():
values.setdefault(name, []).append(value)
def compute_watch_str(watch: _WatchInfo) -> str:
display = watch.expr
unit = watch.unit if watch.unit not in ["1", None] else ""
value = watch.compiled(
*[dd.agg_func(values[dd.name])
for dd in watch.dep_data])
try:
return f"{watch.format}".format(display=display, value=value,
unit=unit)
except ZeroDivisionError:
return f"{display}:div0"
if self.watches:
print("".join(
compute_watch_str(watch) for watch in self.watches),
flush=True)
self._calculate_next_watch_tick()
if self.mpi_comm is not None and self.have_nonlocal_watches:
self.next_watch_tick = self.mpi_comm.bcast(
self.next_watch_tick, self.head_rank)
# }}}
# }}}
# {{{ actual data loggers
class _SubTimer:
def __init__(self, itimer: "IntervalTimer") -> None:
self.itimer = itimer
self.elapsed = 0.0
def start(self) -> "_SubTimer":
self.start_time = time_monotonic()
return self
def stop(self) -> "_SubTimer":
self.elapsed += time_monotonic() - self.start_time
del self.start_time
return self
def __enter__(self) -> None:
self.start()
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.stop()
self.submit()
def submit(self) -> None:
self.itimer.add_time(self.elapsed)
self.elapsed = 0
class IntervalTimer(PostLogQuantity):
"""Records elapsed times supplied by the user either through
sub-timers, or by explicitly calling :meth:`add_time`.
.. automethod:: __init__
.. automethod:: get_sub_timer
.. automethod:: start_sub_timer
.. automethod:: add_time
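A small sketch (``visualize`` stands in for application code)::

    vis_timer = IntervalTimer("t_vis", "Time spent visualizing")
    logmgr.add_quantity(vis_timer)

    with vis_timer.get_sub_timer():
        visualize()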
"""
def __init__(self, name: str, description: Optional[str] = None) -> None:
LogQuantity.__init__(self, name, "s", description)
self.elapsed: float = 0
def get_sub_timer(self) -> _SubTimer:
return _SubTimer(self)
def start_sub_timer(self) -> _SubTimer:
sub_timer = _SubTimer(self)
sub_timer.start()
return sub_timer
def add_time(self, t: float) -> None:
self.start_time = time_monotonic()
self.elapsed += t
def __call__(self) -> float:
result = self.elapsed
self.elapsed = 0
return result
class LogUpdateDuration(PostLogQuantity):
"""Records how long the last log update in :class:`LogManager` took.
.. automethod:: __init__
"""
def __init__(self, mgr: LogManager, name: str = "t_log") -> None:
LogQuantity.__init__(self, name, "s", "Time spent updating the log")
self.log_manager = mgr
def __call__(self) -> float:
return self.log_manager.t_log
class EventCounter(PostLogQuantity):
"""Counts events signaled by :meth:`add`.
.. automethod:: __init__
.. automethod:: add
.. automethod:: transfer
.. automethod:: pop
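A small sketch (the quantity name is illustrative)::

    flop_counter = EventCounter("n_flops", "Floating point operations")
    logmgr.add_quantity(flop_counter)

    flop_counter.add(17)  # record 17 events during this time step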
"""
def __init__(self, name: str = "interval",
description: Optional[str] = None) -> None:
PostLogQuantity.__init__(self, name, "1", description)
self.events = 0
def add(self, n: int = 1) -> None:
self.events += n
def transfer(self, counter: "EventCounter") -> None:
self.events += counter.pop()
def pop(self) -> int:
events = self.events
self.events = 0
return events
def prepare_for_tick(self) -> None:
self.events = 0
def __call__(self) -> int:
result = self.events
return result
def time_and_count_function(f: Callable[..., Any], timer: IntervalTimer,
counter: Optional[EventCounter] = None,
increment: int = 1) -> Callable[..., Any]:
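"""Wrap *f* so that each call is timed via a sub-timer of *timer* and,
if *counter* is given, counted by adding *increment*."""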
def inner_f(*args: Any, **kwargs: Any) -> Any:
if counter is not None:
counter.add(increment)
sub_timer = timer.start_sub_timer()
try:
return f(*args, **kwargs)
finally:
sub_timer.stop().submit()
return inner_f
class TimestepCounter(LogQuantity):
"""Counts the number of times :class:`LogManager` ticks."""
def __init__(self, name: str = "step") -> None:
LogQuantity.__init__(self, name, "1", "Timesteps")
self.steps = 0
def __call__(self) -> int:
result = self.steps
self.steps += 1
return result
class StepToStepDuration(PostLogQuantity):
"""Records the wall time between the starts of consecutive time steps, i.e.,
the wall time between :meth:`LogManager.tick_before` of step x and
:meth:`LogManager.tick_before` of step x+1. The value stored is the value for
step x+1.
.. note::
In most cases, this quantity should approximately match ``t_step`` +
``t_log``. If it does not, it might indicate that the application
performs operations outside :meth:`LogManager.tick_before` and
:meth:`LogManager.tick_after`, or that some other time is not being
accounted for.
.. automethod:: __init__
"""
def __init__(self, name: str = "t_2step") -> None:
PostLogQuantity.__init__(self, name, "s", "Step-to-step duration")
self.last_start_time: Optional[float] = None
self.last2_start_time: Optional[float] = None
def prepare_for_tick(self) -> None:
self.last2_start_time = self.last_start_time
self.last_start_time = time_monotonic()
def __call__(self) -> Optional[float]:
if self.last2_start_time is None or self.last_start_time is None:
return None
else:
return self.last_start_time - self.last2_start_time
class TimestepDuration(PostLogQuantity):
"""Records the wall time between invocations of :meth:`LogManager.tick_before`
and :meth:`LogManager.tick_after`, i.e., the duration of the time step.
.. automethod:: __init__
"""
# We would like to run last, so that if log gathering takes any
# significant time, we catch that, too. (CUDA sync-on-time-taking,
# I'm looking at you.)
sort_weight = 1000
def __init__(self, name: str = "t_step") -> None:
PostLogQuantity.__init__(self, name, "s", "Time step duration")
def prepare_for_tick(self) -> None:
self.last_start = time_monotonic()
def __call__(self) -> float:
now = time_monotonic()
assert hasattr(self, "last_start"), "tick_after called without tick_before"
result = now - self.last_start
del self.last_start
return result
class InitTime(LogQuantity):
"""Stores the time it took for the application to initialize.
Measures the time from process start to the start of the first time step.
.. automethod:: __init__
"""
def __init__(self, name: str = "t_init") -> None:
LogQuantity.__init__(self, name, "s", "Init time")
try:
import psutil
except ModuleNotFoundError:
from warnings import warn
warn("Measuring the init time requires the 'psutil' module.")
self.done = True
else:
self.create_time = psutil.Process().create_time()
self.done = False
def __call__(self) -> Optional[float]:
if self.done:
return None
self.done = True
from time import time
# Can't use time_monotonic() here since that does *not* return
# the time since the UNIX epoch (like time() and
# psutil.Process.create_time() do), but from another (undefined)
# reference point.
return time() - self.create_time
class WallTime(LogQuantity):
"""Records (monotonically increasing) wall time since the quantity was
initialized.
.. automethod:: __init__
"""
def __init__(self, name: str = "t_wall") -> None:
LogQuantity.__init__(self, name, "s", "Wall time")
self.start = time_monotonic()
def __call__(self) -> float:
return time_monotonic()-self.start
class ETA(LogQuantity):
"""Records an estimate of how long the computation will still take.
.. automethod:: __init__
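For example, for a run known to take 1000 steps::

    logmgr.add_quantity(ETA(1000))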
"""
def __init__(self, total_steps: int, name: str = "t_eta") -> None:
LogQuantity.__init__(self, name, "s", "Estimated remaining duration")
self.steps = 0
self.total_steps = total_steps
self.start = time_monotonic()
def __call__(self) -> float:
fraction_done = self.steps/self.total_steps
self.steps += 1
time_spent = time_monotonic()-self.start
if fraction_done > 1e-9:
return time_spent/fraction_done-time_spent
else:
return 0
def add_general_quantities(mgr: LogManager) -> None:
"""Add generally applicable :class:`LogQuantity` objects to *mgr*."""
mgr.add_quantity(TimestepDuration())
mgr.add_quantity(StepToStepDuration())
mgr.add_quantity(WallTime())
mgr.add_quantity(LogUpdateDuration(mgr))
mgr.add_quantity(TimestepCounter())
mgr.add_quantity(InitTime())
mgr.add_quantity(MemoryHwm())
class SimulationTime(TimeTracker, LogQuantity):
"""Record (monotonically increasing) simulation time."""
def __init__(self, name: str = "t_sim", start: float = 0) -> None:
LogQuantity.__init__(self, name, "s", "Simulation Time")
TimeTracker.__init__(self, start)
def __call__(self) -> float:
return self.t
class Timestep(SimulationLogQuantity):
"""Record the magnitude of the simulated time step."""
def __init__(self, name: str = "dt", unit: str = "s") -> None:
SimulationLogQuantity.__init__(self, name, unit, "Simulation Timestep")
def __call__(self) -> Optional[float]:
return self.dt
def set_dt(mgr: LogManager, dt: float) -> None:
"""Set the simulation timestep on :class:`LogManager` ``mgr`` to ``dt``.
:arg mgr: the :class:`LogManager` instance.
:arg dt: the simulation timestep.
"""
for gd_lst in [mgr.before_gather_descriptors,
mgr.after_gather_descriptors]:
for gd in gd_lst:
if isinstance(gd.quantity, DtConsumer):
gd.quantity.set_dt(dt)
def add_simulation_quantities(mgr: LogManager) -> None:
"""Add :class:`LogQuantity` objects relating to simulation time.
:arg mgr: the :class:`LogManager` instance.
"""
mgr.add_quantity(SimulationTime())
mgr.add_quantity(Timestep())
def add_run_info(mgr: LogManager) -> None:
"""Add generic run metadata, such as command line, host, and time."""
try:
import psutil
except ModuleNotFoundError:
mgr.set_constant("cmdline", " ".join(sys.argv))
else:
mgr.set_constant("cmdline", " ".join(psutil.Process().cmdline()))
from socket import gethostname
mgr.set_constant("machine", gethostname())
from time import localtime, strftime, time
mgr.set_constant("date", strftime("%a, %d %b %Y %H:%M:%S %Z", localtime()))
mgr.set_constant("unixtime", time())
class MemoryHwm(PostLogQuantity):
"""Record (monotonically increasing) memory high water mark (HWM) in MBytes."""
def __init__(self, name: str = "memory_usage_hwm") -> None:
PostLogQuantity.__init__(self, name, "MByte", "Memory High Water Mark")
import os
if os.uname().sysname == "Linux":
self.fac = 1024
elif os.uname().sysname == "Darwin":
self.fac = 1024*1024
else:
raise ValueError("MemoryHwm is only supported on Linux/Mac.")
def __call__(self) -> float:
from resource import RUSAGE_SELF, getrusage
res = getrusage(RUSAGE_SELF)
return res.ru_maxrss / self.fac
class GCStats(MultiPostLogQuantity):
"""Record Garbage Collection statistics.
Information regarding the meaning of these values can be found at:
- https://docs.python.org/3/library/gc.html
- https://alex.dzyoba.com/blog/arc-vs-gc
.. # noqa: E501
- https://stackoverflow.com/questions/64561488/pythons-gc-get-objects-from-get-count
- https://github.com/python/cpython/blob/main/Modules/gcmodule.c
"""
def __init__(self) -> None:
names = [ # gc.isenabled():
"gc_isenabled",
# gc.get_count():
"gc_count_gen0", "gc_count_gen1", "gc_count_gen2",
# gc.get_stats():
"gc_collections_gen0", "gc_collected_gen0",
"gc_uncollectable_gen0",
"gc_collections_gen1", "gc_collected_gen1",
"gc_uncollectable_gen1",
"gc_collections_gen2", "gc_collected_gen2",
"gc_uncollectable_gen2",
]
units = ["bool",
"1", "1", "1",
"1", "1", "1", "1", "1", "1", "1", "1", "1"]
descriptions = ["Is automatic GC enabled?",
"GC count gen0", "GC count gen1", "GC count gen2",
"GC collections gen0", "GC objects collected gen0",
"GC objects uncollectable gen0",
"GC collections gen1", "GC objects collected gen1",
"GC objects uncollectable gen1",
"GC collections gen2", "GC objects collected gen2",
"GC objects uncollectable gen2",
]
assert len(names) == len(units) == len(descriptions) == 13
super().__init__(names, cast(List[Optional[str]], units),
cast(List[Optional[str]], descriptions))
def __call__(self) -> Iterable[Optional[float]]:
import gc
enabled = gc.isenabled()
counts = gc.get_count()
stats = gc.get_stats()
return [enabled,
counts[0], counts[1], counts[2],
stats[0]["collections"], stats[0]["collected"],
stats[0]["uncollectable"],
stats[1]["collections"], stats[1]["collected"],
stats[1]["uncollectable"],
stats[2]["collections"], stats[2]["collected"],
stats[2]["uncollectable"]
]
# }}}
# vim: foldmethod=marker