"""
RheoJAX JAX-Safe Logging Utilities.
Utilities for logging JAX arrays and operations without interfering
with JIT compilation or causing expensive device transfers.
NOTE: This module uses ``import jax`` directly (inside try/except ImportError
guards) for API introspection — trace detection, debug callbacks, and config
reading. These are NOT numeric imports and do not require ``safe_import_jax()``,
which sets float64 mode for array computation.
"""
import logging
from typing import Any
[docs]
def log_array_info(
    arr: Any, name: str = "array", include_device: bool = True
) -> dict[str, Any]:
    """Collect cheap, loggable metadata about an array-like object.

    Only attributes are read (``shape``, ``dtype``, ``size`` and, optionally,
    the result of ``devices()``); the array contents are never converted or
    copied, so this is safe to call in performance-critical code.

    Args:
        arr: JAX array, NumPy array, or any object; missing attributes are
            reported as ``"unknown"`` or simply omitted.
        name: Prefix used for every key of the returned dictionary.
        include_device: When true and ``arr`` exposes ``devices()``, record
            the first reported device.

    Returns:
        Dictionary with ``<name>_shape`` and ``<name>_dtype``, plus
        ``<name>_size`` and ``<name>_device`` when they are available.

    Example:
        >>> import jax.numpy as jnp
        >>> x = jnp.ones((100, 50))
        >>> info = log_array_info(x, "input_data")
        >>> info["input_data_shape"]
        (100, 50)
    """
    metadata: dict[str, Any] = {
        f"{name}_shape": getattr(arr, "shape", "unknown"),
        f"{name}_dtype": str(getattr(arr, "dtype", "unknown")),
    }

    # Element count is optional: not every array-like exposes ``size``.
    if hasattr(arr, "size"):
        metadata[f"{name}_size"] = arr.size

    # Nothing more to do unless device placement was requested and is exposed.
    if not (include_device and hasattr(arr, "devices")):
        return metadata

    try:
        placement = arr.devices()
        if placement:
            metadata[f"{name}_device"] = str(list(placement)[0])
    except Exception as exc:
        # Device lookup is best-effort; never let it break the caller.
        logging.getLogger(__name__).debug(
            "Could not read device info for %s: %s", name, exc
        )
    return metadata
[docs]
def log_array_stats(
    arr: Any,
    name: str = "array",
    logger: logging.Logger | None = None,
    level: int = logging.DEBUG,
) -> dict[str, Any]:
    """Gather array metadata plus summary statistics, optionally logging them.

    WARNING: The statistics are computed on ``np.asarray(arr)``, which
    materialises the data as a NumPy array (the original author notes this
    forces a device-to-host transfer for JAX arrays). Intended for debugging.

    Args:
        arr: JAX or NumPy array.
        name: Prefix used for every key of the returned dictionary.
        logger: Optional logger. When given, the statistics are computed only
            if it is enabled for ``level``, and the result is logged at once.
        level: Log level used for the enablement check and the emitted record.

    Returns:
        The metadata from ``log_array_info`` extended with ``<name>_min``,
        ``_max``, ``_mean``, ``_std``, ``_has_nan`` and ``_has_inf``; or with
        ``<name>_stats_error`` when computing the statistics failed.

    Example:
        >>> stats = log_array_stats(residuals, "residuals", logger)
    """
    import numpy as np

    # Cheap metadata first; this part never touches the array contents.
    stats = log_array_info(arr, name, include_device=True)

    emit = logger is not None
    # Skip the expensive conversion when the record would be dropped anyway.
    if not emit or logger.isEnabledFor(level):
        try:
            host = np.asarray(arr)
            computed = {
                f"{name}_min": float(np.min(host)),
                f"{name}_max": float(np.max(host)),
                f"{name}_mean": float(np.mean(host)),
                f"{name}_std": float(np.std(host)),
                f"{name}_has_nan": bool(np.any(np.isnan(host))),
                f"{name}_has_inf": bool(np.any(np.isinf(host))),
            }
        except Exception as err:
            # Record the failure instead of raising: this is a debug helper.
            stats[f"{name}_stats_error"] = str(err)
        else:
            stats.update(computed)

    if emit:
        logger.log(level, f"Array statistics for {name}", extra=stats)

    return stats
[docs]
def jax_safe_log(logger: logging.Logger, level: int, msg: str, **kwargs) -> None:
    """Emit a log record unless the call happens while JAX is tracing.

    The trace state is probed through ``jax.core.trace_ctx.is_top_level()``
    when that attribute exists, otherwise through the older
    ``jax.core.cur_sublevel()``. If neither probe is available, or JAX cannot
    be imported, the message is logged normally.

    Args:
        logger: Logger instance.
        level: Log level.
        msg: Log message.
        **kwargs: Forwarded unchanged to ``logger.log``.

    Example:
        >>> @jax.jit
        ... def my_function(x):
        ...     jax_safe_log(logger, logging.DEBUG, "Inside JIT", value=x.shape)
        ...     return x * 2
    """
    try:
        import jax.core

        ctx = getattr(jax.core, "trace_ctx", None)
        if ctx is None:
            # Older JAX: fall back to the sublevel counter, if it exists.
            legacy_probe = getattr(jax.core, "cur_sublevel", None)
            if legacy_probe is not None:
                try:
                    current = legacy_probe()
                    if hasattr(current, "level") and current.level > 0:
                        return
                except (RuntimeError, AttributeError):
                    pass
        else:
            top_level_check = getattr(ctx, "is_top_level", None)
            if top_level_check is not None and not top_level_check():
                # Inside a trace: stay silent.
                return
    except ImportError:
        # Without JAX there is nothing to trace; log as usual.
        pass

    logger.log(level, msg, **kwargs)
[docs]
def jax_debug_log(
    logger: logging.Logger, msg: str, *values: Any, level: int = logging.DEBUG
) -> None:
    """Log from inside JIT-compiled code through ``jax.debug.callback``.

    When JAX is importable, the values are handed to ``jax.debug.callback``
    and the message is formatted and logged inside that callback. When JAX is
    not importable, the values are stringified and the message is logged
    immediately.

    Formatting is best-effort on both paths: if ``msg`` cannot be formatted
    with the supplied values (too few values, named or malformed
    placeholders, a format spec the value rejects) the raw ``msg`` is logged
    instead, so a bad template never raises from within the callback.

    Args:
        logger: Logger instance.
        msg: Log message; may contain ``{}`` placeholders for ``values``.
        *values: Values to log (passed through ``jax.debug.callback``).
        level: Log level (default DEBUG).

    Example:
        >>> @jax.jit
        ... def my_function(x):
        ...     y = x * 2
        ...     jax_debug_log(logger, "Computed y, total = {}", y.sum())
        ...     return y
    """

    def _render(args: tuple[Any, ...]) -> str:
        # Shared by both paths so they fail soft in exactly the same way.
        if not args:
            return msg
        try:
            return msg.format(*args)
        except (IndexError, KeyError, ValueError, TypeError):
            return msg

    try:
        import jax
    except ImportError:
        # JAX not available: format with str() to avoid value repr issues.
        logger.log(level, _render(tuple(str(v) for v in values)))
        return

    def _log_callback(*args: Any) -> None:
        logger.log(level, _render(args))

    jax.debug.callback(_log_callback, *values)
[docs]
def log_jax_config(logger: logging.Logger | None = None) -> dict[str, Any]:
"""Log JAX configuration state.
Logs JAX version, available devices, default backend, and
float64 configuration.
Args:
logger: Logger to use. If provided, logs immediately.
Returns:
Dictionary with JAX configuration.
Example:
>>> log_jax_config(logger)
INFO | rheojax | JAX Configuration | jax_version=0.8.0 | devices=['gpu:0'] | float64_enabled=True
"""
config_info = {}
try:
import jax
# Read the actual x64 state (True/False), not the config descriptor
try:
float64_enabled = jax.config.x64_enabled
except AttributeError:
float64_enabled = getattr(jax.config, "jax_enable_x64", None)
config_info = {
"jax_version": jax.__version__,
"default_backend": jax.default_backend(),
"float64_enabled": float64_enabled,
}
# Get device info (single call, reuse for platform)
try:
devices = jax.devices()
config_info["devices"] = [str(d) for d in devices]
config_info["device_count"] = len(devices)
if devices:
config_info["platform"] = str(devices[0].platform)
except Exception:
config_info["devices"] = ["unavailable"]
except ImportError:
config_info["jax_available"] = False
if logger is not None:
logger.info("JAX Configuration", extra=config_info)
return config_info
[docs]
def log_numerical_issue(
    logger: logging.Logger, arr: Any, name: str = "array", context: str = ""
) -> bool:
    """Report NaN/Inf values found in an array.

    The array is converted with ``np.asarray`` and scanned for NaN and Inf.
    When any are found a WARNING is emitted with the counts, the shape and
    the supplied context attached as ``extra``. JAX tracers are not
    inspected, and any failure during the check is logged at DEBUG level
    rather than raised.

    Args:
        logger: Logger instance.
        arr: Array to check.
        name: Name for the array in log output.
        context: Additional context about where the issue occurred.

    Returns:
        True if NaN or Inf values were found; False otherwise, including
        when the array is a tracer or the check itself failed.

    Example:
        >>> if log_numerical_issue(logger, residuals, "residuals", "during fitting"):
        ...     raise ValueError("Numerical instability detected")
    """
    import numpy as np

    try:
        # Tracers have no concrete values; np.asarray would raise on them.
        try:
            import jax.core

            if isinstance(arr, jax.core.Tracer):
                logger.debug(f"Cannot check {name} for numerical issues: JAX tracer")
                return False
        except ImportError:
            pass

        host = np.asarray(arr)

        # Build each mask once and reuse it for both the flag and the count.
        nan_mask = np.isnan(host)
        has_nan = bool(np.any(nan_mask))
        inf_mask = np.isinf(host)
        has_inf = bool(np.any(inf_mask))

        if not (has_nan or has_inf):
            return False

        problems = []
        if has_nan:
            problems.append(f"NaN ({int(np.sum(nan_mask))} values)")
        if has_inf:
            problems.append(f"Inf ({int(np.sum(inf_mask))} values)")

        details = {
            "array_name": name,
            "issues": ", ".join(problems),
            "shape": host.shape,
            "context": context,
            "has_nan": has_nan,
            "has_inf": has_inf,
        }
        logger.warning(f"Numerical issue detected in {name}", extra=details)
        return True
    except Exception as e:
        # Best-effort diagnostic: report the failure quietly and move on.
        logger.debug(f"Could not check {name} for numerical issues: {e}")

    return False
def log_device_transfer(
    logger: logging.Logger, arr: Any, name: str = "array", target: str = "host"
) -> None:
    """Emit a DEBUG record describing a device transfer.

    Nothing is transferred here; the function only reports the array's size
    (from ``nbytes``, 0 when absent), its shape, the first device returned
    by ``arr.devices()`` (``"unknown"`` when unavailable) and the stated
    target. Useful when hunting for costly GPU-CPU transfers.

    Args:
        logger: Logger instance.
        arr: Array being transferred.
        name: Name for the array.
        target: Target of transfer (e.g., "host", "gpu:0").

    Example:
        >>> log_device_transfer(logger, result, "model_output", "host")
    """
    bytes_per_mb = 1024 * 1024
    megabytes = getattr(arr, "nbytes", 0) / bytes_per_mb

    origin = "unknown"
    if hasattr(arr, "devices"):
        try:
            placement = arr.devices()
            if placement:
                origin = str(list(placement)[0])
        except Exception as exc:
            # Source lookup is informational only; keep going without it.
            logging.getLogger(__name__).debug(
                "Could not resolve device source for %s: %s", name, exc
            )

    payload = {
        "array_name": name,
        "source": origin,
        "target": target,
        "size_mb": round(megabytes, 2),
        "shape": getattr(arr, "shape", "unknown"),
    }
    logger.debug("Device transfer", extra=payload)