"""
RheoJAX Logging Context Managers.
Context managers for automatically logging operation start/end,
timing, and exception handling.
Includes specialized context managers for:
- General operations (log_operation)
- Model fitting (log_fit)
- Bayesian inference (log_bayesian)
- Data transforms (log_transform)
- File I/O (log_io)
- Pipeline stages (log_pipeline_stage)
- GUI user actions (log_gui_action)
"""
import logging
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any
from rheojax.logging.logger import RheoJAXLogger
[docs]
@contextmanager
def log_operation(
logger: logging.Logger | RheoJAXLogger,
operation: str,
level: int = logging.INFO,
**context,
) -> Generator[dict[str, Any], None, None]:
"""Context manager for logging operation start/end with timing.
Automatically logs when an operation starts and completes,
including elapsed time and any exceptions that occur.
Args:
logger: Logger instance to use.
operation: Name of the operation being performed.
level: Log level for start/end messages (default INFO).
**context: Additional context to include in log messages.
Yields:
Dictionary that can be used to add additional context
to the completion log message.
Example:
>>> with log_operation(logger, "fitting", model="Maxwell"):
... result = model.fit(x, y)
14:32:05 | INFO | rheojax.models | fitting started | model=Maxwell
14:32:07 | INFO | rheojax.models | fitting completed | model=Maxwell | elapsed_seconds=2.15
Example with additional context:
>>> with log_operation(logger, "fitting", model="Maxwell") as ctx:
... result = model.fit(x, y)
... ctx["R2"] = result.r_squared
14:32:05 | INFO | rheojax.models | fitting started | model=Maxwell
14:32:07 | INFO | rheojax.models | fitting completed | model=Maxwell | R2=0.9987 | elapsed_seconds=2.15
"""
start_time = time.perf_counter()
completion_context: dict[str, Any] = {}
# Log start
logger.log(
level,
f"{operation} started",
extra={"operation": operation, "phase": "start", **context},
)
try:
yield completion_context
elapsed = time.perf_counter() - start_time
# Log successful completion
logger.log(
level,
f"{operation} completed",
extra={
"operation": operation,
"phase": "end",
"elapsed_seconds": round(elapsed, 4),
"status": "success",
**context,
**completion_context,
},
)
except Exception as e:
elapsed = time.perf_counter() - start_time
# Log failure (wrapped to avoid masking the original exception)
try:
logger.error(
f"{operation} failed: {e}",
extra={
"operation": operation,
"phase": "end",
"elapsed_seconds": round(elapsed, 4),
"status": "error",
"error_type": type(e).__name__,
"error_message": str(e),
**context,
**completion_context,
},
)
except Exception:
pass # Never mask the original exception
raise
[docs]
@contextmanager
def log_fit(
    logger: logging.Logger | RheoJAXLogger,
    model: str,
    data_shape: tuple[int, ...] | None = None,
    test_mode: str = "unknown",
    level: int = logging.INFO,
    **kwargs,
) -> Generator[dict[str, Any], None, None]:
    """Log a model-fitting run as a ``"model_fit"`` operation.

    Thin wrapper around ``log_operation``. It attaches the model name, the
    test mode and (when given) the input data shape to every emitted record.

    Args:
        logger: Logger instance to use.
        model: Model name or class name.
        data_shape: Shape of the input data. It is left out of the context
            entirely when ``None``.
        test_mode: Test mode (relaxation, creep, oscillation, flow).
        level: Log level for the start/success messages (default INFO).
        **kwargs: Additional context fields.

    Yields:
        Dictionary for completion context (e.g. R2, fitted parameters).

    Example:
        >>> with log_fit(logger, "Maxwell", data_shape=(100,), test_mode="relaxation") as ctx:
        ...     result = model._fit(x, y)
        ...     ctx["R2"] = result.r_squared
        ...     ctx["n_iterations"] = result.iterations
    """
    # Only include the shape when the caller actually supplied one; it goes
    # last so the field order matches model, test_mode, extras, data_shape.
    shape_field: dict[str, Any] = {} if data_shape is None else {"data_shape": data_shape}
    fit_context: dict[str, Any] = {
        "model": model,
        "test_mode": test_mode,
        **kwargs,
        **shape_field,
    }
    with log_operation(logger, "model_fit", level=level, **fit_context) as completion:
        yield completion
[docs]
@contextmanager
def log_bayesian(
    logger: logging.Logger | RheoJAXLogger,
    model: str,
    num_warmup: int,
    num_samples: int,
    num_chains: int = 1,
    level: int = logging.INFO,
    **kwargs,
) -> Generator[dict[str, Any], None, None]:
    """Log an MCMC sampling run as a ``"bayesian_inference"`` operation.

    Thin wrapper around ``log_operation``. It records the sampler settings
    (warmup, samples, chains) alongside the model name on every message.

    Args:
        logger: Logger instance to use.
        model: Model name.
        num_warmup: Number of warmup samples.
        num_samples: Number of posterior samples.
        num_chains: Number of MCMC chains (default 1).
        level: Log level for the start/success messages (default INFO).
        **kwargs: Additional context fields.

    Yields:
        Dictionary for completion context (e.g. R-hat, ESS, divergences).

    Example:
        >>> with log_bayesian(logger, "Maxwell", num_warmup=1000, num_samples=2000) as ctx:
        ...     result = model.fit_bayesian(x, y)
        ...     ctx["r_hat_max"] = compute_rhat(result)
        ...     ctx["ess_min"] = compute_ess(result)
        ...     ctx["divergences"] = result.divergences
    """
    sampler_context: dict[str, Any] = dict(
        model=model,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
    )
    sampler_context.update(kwargs)
    with log_operation(
        logger, "bayesian_inference", level=level, **sampler_context
    ) as completion:
        yield completion
[docs]
@contextmanager
def log_io(
    logger: logging.Logger | RheoJAXLogger,
    operation: str,
    filepath: str | None = None,
    level: int = logging.INFO,
    **kwargs,
) -> Generator[dict[str, Any], None, None]:
    """Log a file I/O step as an ``"io_<operation>"`` operation.

    Thin wrapper around ``log_operation``. The I/O verb is stored under the
    ``io_operation`` key rather than ``operation``, so it cannot collide with
    ``log_operation``'s own ``operation`` argument.

    Args:
        logger: Logger instance to use.
        operation: I/O operation type (read, write, load, save).
        filepath: Path to the file being accessed. It is converted with
            ``str()`` before logging and omitted when ``None``.
        level: Log level for the start/success messages (default INFO).
        **kwargs: Additional context fields.

    Yields:
        Dictionary for completion context (e.g. records, file_size).

    Example:
        >>> with log_io(logger, "read", filepath="data.csv") as ctx:
        ...     data = read_csv(filepath)
        ...     ctx["records"] = len(data)
        ...     ctx["columns"] = list(data.columns)
    """
    io_context: dict[str, Any] = {"io_operation": operation}
    io_context.update(kwargs)
    if filepath is not None:
        # Normalise to text (presumably so path objects can be passed too).
        io_context["filepath"] = str(filepath)
    operation_name = f"io_{operation}"
    with log_operation(logger, operation_name, level=level, **io_context) as completion:
        yield completion
[docs]
@contextmanager
def log_pipeline_stage(
    logger: logging.Logger | RheoJAXLogger,
    stage: str,
    pipeline_id: str | None = None,
    level: int = logging.INFO,
    **kwargs,
) -> Generator[dict[str, Any], None, None]:
    """Log one pipeline stage as a ``"pipeline_<stage>"`` operation.

    Thin wrapper around ``log_operation``. It tags every record with the
    stage name and, when provided, the pipeline identifier.

    Args:
        logger: Logger instance to use.
        stage: Pipeline stage name.
        pipeline_id: Optional pipeline identifier; omitted when ``None``.
        level: Log level for the start/success messages (default INFO).
        **kwargs: Additional context fields.

    Yields:
        Dictionary for completion context.

    Example:
        >>> with log_pipeline_stage(logger, "fit", pipeline_id="pipe_001") as ctx:
        ...     result = pipeline.fit()
        ...     ctx["model"] = result.model_name
    """
    stage_context: dict[str, Any] = dict(stage=stage, **kwargs)
    if pipeline_id is not None:
        stage_context["pipeline_id"] = pipeline_id
    stage_name = f"pipeline_{stage}"
    with log_operation(logger, stage_name, level=level, **stage_context) as completion:
        yield completion
@contextmanager
def log_gui_action(
    logger: logging.Logger | RheoJAXLogger,
    action: str,
    widget: str | None = None,
    page: str | None = None,
    level: int = logging.DEBUG,
    **kwargs,
) -> Generator[dict[str, Any], None, None]:
    """Log a GUI user interaction as a ``"gui_<action>"`` operation.

    Thin wrapper around ``log_operation``. It defaults to DEBUG because GUI
    interactions tend to be high-volume.

    Args:
        logger: Logger instance to use.
        action: Type of action (click, select, navigate, etc.).
        widget: Widget identifier or class name; omitted when ``None``.
        page: Page where the action occurred; omitted when ``None``.
        level: Log level for the start/success messages (default DEBUG).
        **kwargs: Additional context fields (button_id, value, etc.).

    Yields:
        Dictionary for completion context.

    Example:
        >>> with log_gui_action(logger, "button_click", widget="FitButton", page="FitPage") as ctx:
        ...     self._perform_fit()
        ...     ctx["result"] = "success"

    Example without the context manager (for simple actions). This assumes a
    logger that accepts context as keyword arguments; a plain
    ``logging.Logger`` needs ``extra={...}`` instead:
        >>> logger.debug("Button clicked", action="click", widget="FitButton", page="FitPage")
    """
    optional_fields = (("widget", widget), ("page", page))
    gui_context: dict[str, Any] = {"action": action, **kwargs}
    gui_context.update(
        (key, value) for key, value in optional_fields if value is not None
    )
    with log_operation(logger, f"gui_{action}", level=level, **gui_context) as completion:
        yield completion