"""
RheoJAX Performance Metrics Logging.
Utilities for timing, memory tracking, and iteration logging.
"""
import functools
import logging
import math
import time
import tracemalloc
from collections.abc import Callable
from contextlib import contextmanager
from typing import Any, TypeVar
from rheojax.logging.logger import RheoJAXLogger, get_logger
# Upper bound for the decorator type variable below: any callable at all.
_AnyCallable = Callable[..., Any]
# Type variable used by decorators such as ``timed`` so that type checkers see
# the decorated function with exactly the type it had before decoration.
F = TypeVar("F", bound=_AnyCallable)
# Accepted ``logger`` argument throughout this module: a stdlib logger, a
# project logger, or None to let each helper fall back to its own default.
LoggerType = logging.Logger | RheoJAXLogger | None
[docs]
def timed(
    logger: LoggerType = None, level: int = logging.DEBUG, include_args: bool = False
) -> Callable[[F], F]:
    """Decorator to log function execution time.

    Each call to the wrapped function emits one log record:

    * on success, a record at ``level`` with ``status="success"``,
      ``elapsed_seconds`` and ``elapsed_ms``;
    * on failure, an ERROR record with ``status="error"``, ``elapsed_seconds``
      and ``error_type``; the original exception is then re-raised unchanged.

    Both records carry ``function`` (the wrapped function's ``__name__``) and
    ``func_module`` (its ``__module__``). With ``include_args=True`` they also
    carry ``func_args`` and ``func_kwargs``: str/int/float/bool/None values are
    kept as-is, objects exposing ``shape`` become ``"array<shape>"`` and every
    other value is reduced to its type name.

    Note:
        The ``func_`` prefix is deliberate. ``module`` and ``args`` are built-in
        ``LogRecord`` attributes, and ``logging.Logger.makeRecord`` raises
        ``KeyError`` when ``extra`` tries to overwrite one. With a standard
        ``logging.Logger`` that would turn every wrapped exception into a
        ``KeyError``.

    Args:
        logger: Logger to use. If None, ``get_logger(func.__module__)`` is
            resolved on each call.
        level: Log level for the success record (default DEBUG).
        include_args: Include summarized call arguments in the log output.

    Returns:
        Decorator function.

    Example:
        >>> @timed()
        ... def compute_something(x, y):
        ...     return x + y
        >>> @timed(level=logging.INFO, include_args=True)
        ... def fit_model(data):
        ...     return model.fit(data)
    """

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Use a local to avoid mutating the shared closure cell (thread-safe)
            _logger = logger if logger is not None else get_logger(func.__module__)
            extra: dict[str, Any] = {
                "function": func.__name__,
                "func_module": func.__module__,
            }
            if include_args:
                # Same summarization for positional and keyword arguments so
                # large arrays never end up verbatim in the log.
                extra["func_args"] = [_summarize_arg(arg) for arg in args]
                extra["func_kwargs"] = {
                    key: _summarize_arg(value) for key, value in kwargs.items()
                }

            start = time.perf_counter()
            try:
                result = func(*args, **kwargs)
            except Exception as e:
                elapsed = time.perf_counter() - start
                _logger.error(
                    f"{func.__name__} failed after {elapsed:.3f}s: {e}",
                    extra={
                        **extra,
                        "elapsed_seconds": round(elapsed, 6),
                        "status": "error",
                        "error_type": type(e).__name__,
                    },
                )
                raise

            # Logged outside the try so a logging problem is never misreported
            # as a failure of the wrapped function.
            elapsed = time.perf_counter() - start
            _logger.log(
                level,
                f"{func.__name__} completed",
                extra={
                    **extra,
                    "elapsed_seconds": round(elapsed, 6),
                    "elapsed_ms": round(elapsed * 1000, 3),
                    "status": "success",
                },
            )
            return result

        return wrapper  # type: ignore

    return decorator


def _summarize_arg(value: Any) -> Any:
    """Return a compact, log-friendly stand-in for an argument passed to ``timed``.

    Objects with a ``shape`` attribute (arrays) are summarized by shape so large
    buffers never reach the log; str/int/float/bool/None pass through unchanged;
    anything else is replaced by its type name.
    """
    if hasattr(value, "shape"):
        return f"array{value.shape}"
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return type(value).__name__
[docs]
@contextmanager
def log_memory(
    logger: LoggerType = None,
    operation: str = "operation",
    level: int = logging.DEBUG,
    trace_lines: bool = False,
):
    """Context manager that logs tracemalloc figures for the enclosed block.

    When the block exits, normally or by exception, one record is logged with
    ``operation``, ``current_mb`` and ``peak_mb`` (bytes / 1024 / 1024, rounded to
    two decimals). Optionally, ``top_allocations`` is added: the five largest
    per-line allocation sites, each formatted as ``"<location>: <size> KB"``.

    tracemalloc is started only if it is not already tracing, and stopped
    again only in that case. Pre-existing tracing, and its peak value, are left
    untouched. In that case the reported figures are tracemalloc's running
    totals rather than values isolated to this block.

    Args:
        logger: Logger to use (defaults to ``get_logger("rheojax.metrics")``).
        operation: Name of operation being measured.
        level: Log level (default DEBUG).
        trace_lines: Include top memory-allocating lines.

    Yields:
        None

    Example:
        >>> with log_memory(logger, "large_computation"):
        ...     result = compute_large_array()
        DEBUG | rheojax.core | large_computation memory | current_mb=45.2 | peak_mb=128.5
    """
    target = get_logger("rheojax.metrics") if logger is None else logger
    # Only take ownership of tracing when nobody else has it switched on.
    owns_tracing = not tracemalloc.is_tracing()
    if owns_tracing:
        tracemalloc.start()
    try:
        yield
    finally:
        current_bytes, peak_bytes = tracemalloc.get_traced_memory()
        bytes_per_mb = 1024 * 1024
        payload = {
            "operation": operation,
            "current_mb": round(current_bytes / bytes_per_mb, 2),
            "peak_mb": round(peak_bytes / bytes_per_mb, 2),
        }
        if trace_lines:
            heaviest = tracemalloc.take_snapshot().statistics("lineno")[:5]
            payload["top_allocations"] = [
                f"{entry.traceback.format()[0]}: {entry.size / 1024:.1f} KB"
                for entry in heaviest
            ]
        if owns_tracing:
            tracemalloc.stop()
        target.log(level, f"{operation} memory usage", extra=payload)
[docs]
class IterationLogger:
    """Rate-limited progress logger for long-running iterative loops.

    Every call to :meth:`log` counts one iteration, but a record is emitted only
    every ``log_every`` iterations (or when forced). This keeps optimizer
    progress visible without flooding the log.

    Attributes:
        logger: Logger instance.
        log_every: Log every N iterations.
        level: Log level for per-iteration records.
        operation: Operation name used in the final summary message.
        iteration: Current iteration count.
        start_time: ``time.perf_counter()`` value when counting (re)started.

    Example:
        >>> iter_logger = IterationLogger(logger, log_every=100)
        >>> for i in range(1000):
        ...     cost = optimizer.step()
        ...     iter_logger.log(cost=cost, grad_norm=grad_norm)
        DEBUG | rheojax.opt | Iteration 100 | cost=0.0234 | grad_norm=0.001
        DEBUG | rheojax.opt | Iteration 200 | cost=0.0189 | grad_norm=0.0008
    """

    def __init__(
        self,
        logger: LoggerType = None,
        log_every: int = 10,
        level: int = logging.DEBUG,
        operation: str = "optimization",
    ) -> None:
        """Set up counters and validate the logging interval.

        Args:
            logger: Logger instance (a falsy value selects
                ``get_logger("rheojax.optimization")``).
            log_every: Log every N iterations (default 10).
            level: Log level (default DEBUG).
            operation: Operation name for log messages.

        Raises:
            ValueError: If ``log_every`` is not positive.
        """
        if log_every <= 0:
            raise ValueError(f"log_every must be a positive integer, got {log_every!r}")
        self.logger = logger or get_logger("rheojax.optimization")
        self.log_every = log_every
        self.level = level
        self.operation = operation
        self.iteration = 0
        self.start_time = time.perf_counter()
        # Cost passed to the most recent log() call (None if it passed none).
        self._last_cost: float | None = None

    def _elapsed_and_rate(self) -> tuple[float, float]:
        """Return seconds since ``start_time`` and iterations per second (0 if no time elapsed)."""
        seconds = time.perf_counter() - self.start_time
        if seconds > 0:
            return seconds, self.iteration / seconds
        return seconds, 0

    def log(self, cost: float | None = None, force: bool = False, **metrics) -> None:
        """Count one iteration and emit a record when the interval is reached.

        Args:
            cost: Current cost/loss value (omitted from the record when None).
            force: Emit a record regardless of the interval.
            **metrics: Additional fields merged into the record's ``extra``.
        """
        self.iteration += 1
        self._last_cost = cost
        if not force and self.iteration % self.log_every != 0:
            return
        seconds, rate = self._elapsed_and_rate()
        fields = {
            "iteration": self.iteration,
            "elapsed_seconds": round(seconds, 3),
            "iterations_per_second": round(rate, 2),
        }
        if cost is not None:
            fields["cost"] = cost
        fields.update(metrics)
        self.logger.log(self.level, f"Iteration {self.iteration}", extra=fields)

    def log_final(self, **metrics) -> None:
        """Emit an INFO-level summary covering every iteration counted so far.

        Args:
            **metrics: Final metrics merged into the record's ``extra``.
        """
        seconds, rate = self._elapsed_and_rate()
        summary = {
            "total_iterations": self.iteration,
            "total_elapsed_seconds": round(seconds, 3),
            "average_iterations_per_second": round(rate, 2),
        }
        if self._last_cost is not None:
            summary["final_cost"] = self._last_cost
        summary.update(metrics)
        self.logger.info(f"{self.operation} completed", extra=summary)

    def reset(self) -> None:
        """Zero the iteration counter, restart the timer and forget the last cost."""
        self.iteration = 0
        self.start_time = time.perf_counter()
        self._last_cost = None
[docs]
class ConvergenceTracker:
    """Track and log convergence metrics for optimization.

    Every cost passed to :meth:`update` is appended to ``history``. Once at least
    ``min_iterations`` costs have been recorded, each new cost is compared with
    the previous one. The absolute change (an increase counts just like a
    decrease) is compared against ``tolerance``.

    ``patience`` consecutive small changes signal convergence. A non-finite
    change (e.g. from a NaN cost) or a large change resets the streak.

    Example:
        >>> tracker = ConvergenceTracker(logger, tolerance=1e-6)
        >>> for i in range(1000):
        ...     cost = optimizer.step()
        ...     if tracker.update(cost):
        ...         print("Converged!")
        ...         break
    """

    def __init__(
        self,
        logger: LoggerType = None,
        tolerance: float = 1e-6,
        patience: int = 5,
        min_iterations: int = 10,
    ) -> None:
        """Initialize the convergence tracker.

        Args:
            logger: Logger instance (a falsy value selects
                ``get_logger("rheojax.optimization")``).
            tolerance: Convergence tolerance for the absolute cost change.
            patience: Number of consecutive small changes before converged.
            min_iterations: Minimum recorded costs before convergence is checked.
        """
        self.logger = logger or get_logger("rheojax.optimization")
        self.tolerance = tolerance
        self.patience = patience
        self.min_iterations = min_iterations
        self.history: list[float] = []
        self._small_improvement_count = 0

    def update(self, cost: float) -> bool:
        """Update with new cost and check for convergence.

        When convergence is reached, an INFO record "Convergence achieved" is
        logged with ``final_cost``, ``iterations``, ``last_improvement`` and
        ``tolerance``.

        Args:
            cost: Current cost/loss value.

        Returns:
            True if convergence criteria met.
        """
        self.history.append(cost)
        if len(self.history) < self.min_iterations:
            return False
        # Stays None while only one cost exists. It can still reach the log
        # below when patience <= 0 and min_iterations <= 1. Previously that
        # path raised UnboundLocalError.
        improvement: float | None = None
        # Check improvement (NaN costs never count as small improvement)
        if len(self.history) >= 2:
            improvement = abs(self.history[-2] - self.history[-1])
            if math.isfinite(improvement) and improvement < self.tolerance:
                self._small_improvement_count += 1
            else:
                self._small_improvement_count = 0
        if self._small_improvement_count >= self.patience:
            self.logger.info(
                "Convergence achieved",
                extra={
                    "final_cost": cost,
                    "iterations": len(self.history),
                    "last_improvement": improvement,
                    "tolerance": self.tolerance,
                },
            )
            return True
        return False

    def reset(self) -> None:
        """Clear the cost history and the small-improvement streak."""
        self.history.clear()
        self._small_improvement_count = 0

    @property
    def improvement_rate(self) -> float | None:
        """Calculate average improvement rate.

        Returns:
            ``(first_cost - last_cost) / (len(history) - 1)``, i.e. the average
            cost reduction per iteration, or None if fewer than two costs exist.
        """
        if len(self.history) < 2:
            return None
        return (self.history[0] - self.history[-1]) / (len(self.history) - 1)