"""Core plotting functions for rheological data visualization.
This module provides publication-quality plotting utilities for rheological data,
with automatic plot type selection based on data characteristics.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from rheojax.core.data import RheoData
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
# Module logger
logger = get_logger(__name__)
def _modulus_labels(
data: RheoData | None = None,
y_units: str | None = None,
) -> tuple[str, str, str]:
"""Return (storage_label, loss_label, generic_label) based on deformation mode.
For tension/bending/compression → E'/E''/Modulus; otherwise G'/G''/Modulus.
"""
deformation = None
if data is not None:
deformation = getattr(data, "deformation_mode", None) or (
data.metadata.get("deformation_mode") if data.metadata else None
)
units = y_units or (data.y_units if data else None) or "Pa"
if deformation in ("tension", "bending", "compression"):
return f"E' ({units})", f'E" ({units})', f"Modulus ({units})"
return f"G' ({units})", f'G" ({units})', f"Modulus ({units})"
# Default plotting style parameters
# Module constants for consistent plot styling
# Note: _GRID_ALPHA and _GRID_LINESTYLE are also used (with inline literals) in
# templates.py, spp_plots.py, and epm_plots.py.
_GRID_ALPHA = 0.3
_GRID_LINESTYLE = "--"
# Default figure sizes (can be overridden per-call via figsize parameter)
_DEFAULT_FIGSIZE = (8, 6)
_PUBLICATION_FIGSIZE = (6, 4.5)
_PRESENTATION_FIGSIZE = (10, 7)
# Explicit viscosity unit set for y-label heuristic in plot_time_domain()
_VISCOSITY_UNITS: frozenset[str] = frozenset(
{"Pa·s", "Pa.s", "Pa*s", "Pa s", "mPa·s", "mPa.s"}
)
DEFAULT_STYLE = {
"figure.figsize": _DEFAULT_FIGSIZE,
"font.size": 11,
"axes.labelsize": 12,
"axes.titlesize": 13,
"xtick.labelsize": 10,
"ytick.labelsize": 10,
"legend.fontsize": 10,
"lines.linewidth": 1.5,
"lines.markersize": 6,
}
PUBLICATION_STYLE = {
"figure.figsize": _PUBLICATION_FIGSIZE,
"font.size": 10,
"axes.labelsize": 11,
"axes.titlesize": 12,
"xtick.labelsize": 9,
"ytick.labelsize": 9,
"legend.fontsize": 9,
"lines.linewidth": 1.2,
"lines.markersize": 5,
}
PRESENTATION_STYLE = {
"figure.figsize": _PRESENTATION_FIGSIZE,
"font.size": 14,
"axes.labelsize": 16,
"axes.titlesize": 18,
"xtick.labelsize": 13,
"ytick.labelsize": 13,
"legend.fontsize": 13,
"lines.linewidth": 2.0,
"lines.markersize": 8,
}
def _apply_style(style: str = "default") -> dict[str, Any]:
"""Apply plotting style and return style parameters.
Args:
style: Style name ('default', 'publication', 'presentation')
Returns:
Dictionary of style parameters
"""
if style == "publication":
return PUBLICATION_STYLE.copy()
elif style == "presentation":
return PRESENTATION_STYLE.copy()
else:
return DEFAULT_STYLE.copy()
def _ensure_numpy(data: Any) -> np.ndarray:
"""Ensure data is a NumPy array for plotting.
Args:
data: Input data array
Returns:
NumPy array
"""
# VIS-011: Use hasattr check instead of isinstance(data, jnp.ndarray)
# which is unreliable on JAX >= 0.4.7
if hasattr(data, "devices"):
return np.array(data)
return np.asarray(data)
def _filter_positive(
x: np.ndarray, y: np.ndarray, warn: bool = True
) -> tuple[np.ndarray, np.ndarray]:
"""Filter out non-positive y values for log-scale plotting.
Args:
x: X-axis data
y: Y-axis data
warn: If True, warn when filtering occurs
Returns:
Tuple of (filtered_x, filtered_y) with only positive y values
"""
# VIS-P2-007: Guard against complex arrays — comparison y > 0 is undefined for complex
if np.iscomplexobj(y):
raise TypeError(
"_filter_positive received complex y array; extract real/imag first"
)
positive_mask = np.isfinite(y) & (y > 0)
n_removed = len(y) - np.sum(positive_mask)
if n_removed > 0 and warn:
if np.sum(positive_mask) == 0:
# F-022: All values non-positive — return original data so the caller
# gets something to plot rather than empty arrays that create blank plots.
# Matplotlib will mask non-positive values on log scale by itself.
import warnings
logger.warning(
"All y-values are non-positive; log-scale plot will be empty "
"(matplotlib drops non-positive values). "
"Check data scaling or use a linear-scale plot.",
)
warnings.warn(
"All y-values are non-positive; log-scale plot will be empty. "
"Check data scaling or use a linear-scale plot.",
UserWarning,
stacklevel=3,
)
return x, y
else:
import warnings
warnings.warn(
f"Removed {n_removed} non-positive values from log-scale plot. "
f"This is common for very small G'' values or measurement noise.",
UserWarning,
stacklevel=3,
)
return x[positive_mask], y[positive_mask]
[docs]
def plot_rheo_data(
data: RheoData, style: str = "default", **kwargs: Any
) -> tuple[Figure, Axes | np.ndarray]:
"""Plot RheoData with automatic plot type selection.
This function automatically selects the appropriate plot type based on
the data domain, test mode, and data characteristics.
Args:
data: RheoData object to plot
style: Plotting style ('default', 'publication', 'presentation')
**kwargs: Additional keyword arguments passed to matplotlib
Returns:
Tuple of (Figure, Axes) or (Figure, array of Axes)
Examples:
>>> time = np.linspace(0, 10, 100)
>>> stress = 1000 * np.exp(-time / 2)
>>> data = RheoData(x=time, y=stress, domain="time")
>>> fig, ax = plot_rheo_data(data)
"""
logger.debug("Generating plot", plot_type="rheo_data", style=style)
try:
_meta = data.metadata or {}
test_mode = _meta.get("test_mode", "")
# Select plot type based on domain and test mode
# VIS-P1-004: Forward deformation mode for E'/G' label selection
kwargs = dict(kwargs)
deformation_mode = kwargs.pop("deformation_mode", None)
if deformation_mode is None:
deformation_mode = getattr(data, "deformation_mode", None) or _meta.get(
"deformation_mode"
)
# VIS-P2-003: Detect frequency-domain data even when y is real (e.g., only G' stored)
is_freq_domain = getattr(data, "domain", None) == "frequency" or _meta.get(
"test_mode"
) in (
"oscillation",
"frequency_sweep",
)
if is_freq_domain or np.iscomplexobj(data.y):
# Frequency-domain data — pass deformation_mode for label selection
freq_kwargs = dict(kwargs)
if deformation_mode is not None:
freq_kwargs["deformation_mode"] = deformation_mode
result = plot_frequency_domain(
_ensure_numpy(data.x),
_ensure_numpy(data.y),
x_units=data.x_units,
y_units=data.y_units,
style=style,
**freq_kwargs,
)
elif test_mode in ("startup", "laos"):
# Startup and LAOS: time/strain vs stress (linear axes)
result = plot_time_domain(
_ensure_numpy(data.x),
_ensure_numpy(data.y),
x_units=data.x_units,
y_units=data.y_units,
style=style,
**kwargs,
)
elif test_mode in ("rotation", "flow_curve") or data.x_units in [
"1/s",
"s^-1",
]:
# Flow curve data
result = plot_flow_curve(
_ensure_numpy(data.x),
_ensure_numpy(data.y),
x_units=data.x_units,
y_units=data.y_units,
style=style,
**kwargs,
)
else:
# Time-domain data (relaxation, creep, etc.)
result = plot_time_domain(
_ensure_numpy(data.x),
_ensure_numpy(data.y),
x_units=data.x_units,
y_units=data.y_units,
style=style,
**kwargs,
)
logger.debug("Figure created", plot_type="rheo_data")
return result
except Exception as e:
logger.error(
"Failed to generate rheo_data plot",
plot_type="rheo_data",
error=str(e),
exc_info=True,
)
raise
[docs]
def plot_time_domain(
x: np.ndarray,
y: np.ndarray,
x_units: str | None = None,
y_units: str | None = None,
log_x: bool = False,
log_y: bool = False,
style: str = "default",
**kwargs: Any,
) -> tuple[Figure, Axes]:
"""Plot time-domain rheological data.
Args:
x: Independent variable (typically time)
y: Dependent variable (stress, strain, etc.)
x_units: Units for x-axis
y_units: Units for y-axis
log_x: Use logarithmic scale for x-axis
log_y: Use logarithmic scale for y-axis
style: Plotting style
**kwargs: Additional keyword arguments for matplotlib plot
Returns:
Tuple of (Figure, Axes)
"""
logger.debug("Generating plot", plot_type="time_domain", style=style)
try:
style_params = _apply_style(style)
fig, ax = plt.subplots(figsize=style_params["figure.figsize"])
# VIS-010: Set font sizes on axes directly instead of mutating
# global plt.rcParams (which permanently pollutes process state)
# Plot data
plot_kwargs = {
"linewidth": style_params["lines.linewidth"],
"marker": "o",
"markersize": style_params["lines.markersize"],
"markerfacecolor": "none",
"markeredgewidth": 1.0,
}
plot_kwargs.update(kwargs)
# VIZ-R6-003: Filter non-positive x before plotting when log_x is set,
# otherwise t=0 anchors the x-axis at 0 and breaks the log scale display.
x_plot, y_plot = x, y
if log_x:
pos_mask = np.isfinite(x) & (x > 0)
if not np.all(pos_mask) and np.any(pos_mask):
x_plot, y_plot = x[pos_mask], y[pos_mask]
ax.plot(x_plot, y_plot, **plot_kwargs)
# Set labels
x_label = f"Time ({x_units})" if x_units else "Time"
# VIS-P2-008 / F-035: Explicit unit sets prevent misclassification of edge cases
if y_units:
y_units_stripped = y_units.strip()
if y_units_stripped in ("Pa", "kPa", "MPa", "GPa"):
y_label = f"Stress ({y_units})"
elif y_units_stripped in _VISCOSITY_UNITS:
y_label = f"Viscosity ({y_units})"
elif "Pa" in y_units_stripped and (
"\u00b7s" in y_units_stripped
or ".s" in y_units_stripped
or "*s" in y_units_stripped
):
y_label = f"Viscosity ({y_units})"
else:
y_label = f"y ({y_units})"
else:
y_label = "y"
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
# Set scales
if log_x:
ax.set_xscale("log")
if log_y:
ax.set_yscale("log")
# Grid
ax.grid(True, which="both", alpha=_GRID_ALPHA, linestyle=_GRID_LINESTYLE)
fig.tight_layout()
logger.debug("Figure created", plot_type="time_domain")
return fig, ax
except Exception as e:
# VIZ-R6-002: Guard against fig not yet assigned (e.g. _apply_style raises)
_fig = locals().get("fig")
if _fig is not None:
plt.close(_fig)
logger.error(
"Failed to generate time_domain plot",
plot_type="time_domain",
error=str(e),
exc_info=True,
)
raise
[docs]
def plot_frequency_domain(
x: np.ndarray,
y: np.ndarray,
x_units: str | None = None,
y_units: str | None = None,
style: str = "default",
**kwargs: Any,
) -> tuple[Figure, np.ndarray]:
"""Plot frequency-domain rheological data (complex modulus).
For complex data, creates two subplots for G' (storage modulus) and
G'' (loss modulus). For real data, creates a single plot.
Args:
x: Frequency data
y: Complex modulus data (G* = G' + iG'')
x_units: Units for frequency axis
y_units: Units for modulus
style: Plotting style
**kwargs: Additional keyword arguments for matplotlib plot.
deformation_mode: str, optional — 'tension'/'bending'/'compression'
causes labels to show E'/E'' instead of G'/G''.
Returns:
Tuple of (Figure, np.ndarray of Axes). For complex data the array has
shape (2,); for real data it has shape (1,) so callers can always
index the result uniformly (e.g. ``axes[0]``).
"""
logger.debug("Generating plot", plot_type="frequency_domain", style=style)
try:
style_params = _apply_style(style)
# VIS-010: Set font sizes on axes directly instead of mutating
# global plt.rcParams (which permanently pollutes process state)
# VIS-P1-004: Deformation-mode aware labels (E' vs G')
kwargs = dict(kwargs)
deformation_mode = kwargs.pop("deformation_mode", None)
is_tensile = deformation_mode in ("tension", "bending", "compression")
units = y_units or "Pa"
storage_sym = "E'" if is_tensile else "G'"
loss_sym = 'E"' if is_tensile else 'G"'
if np.iscomplexobj(y):
# Complex data - plot storage and loss on separate subplots
fig, axes = plt.subplots(
2,
1,
figsize=(
style_params["figure.figsize"][0],
style_params["figure.figsize"][1] * 1.5,
),
)
# Plot kwargs
plot_kwargs = {
"linewidth": style_params["lines.linewidth"],
"marker": "o",
"markersize": style_params["lines.markersize"],
"markerfacecolor": "none",
"markeredgewidth": 1.0,
}
plot_kwargs.update(kwargs)
# VIS-P0-001 / VIS-P0-002: Strip keys that are passed explicitly to avoid
# "multiple values for keyword argument" TypeError.
# VIZ-016: Prefer user-supplied label/color over the automatic defaults.
user_storage_label = kwargs.pop("label", None)
user_color = kwargs.pop("color", None)
plot_kwargs_safe = {
k: v for k, v in plot_kwargs.items() if k not in ("label", "color")
}
# Storage modulus
x_gp, gp = _filter_positive(x, np.real(y), warn=True)
axes[0].loglog(
x_gp, gp, **plot_kwargs_safe, label=user_storage_label or storage_sym
)
axes[0].set_ylabel(f"{storage_sym} ({units})")
axes[0].grid(
True, which="both", alpha=_GRID_ALPHA, linestyle=_GRID_LINESTYLE
)
axes[0].legend()
# Loss modulus
# VIZ-R6-001: Use distinct label for loss modulus subplot
x_gpp, gpp = _filter_positive(x, np.imag(y), warn=True)
loss_label = (
f"{user_storage_label} (loss)" if user_storage_label else loss_sym
)
axes[1].loglog(
x_gpp,
gpp,
**plot_kwargs_safe,
label=loss_label,
color=user_color or "C1",
)
axes[1].set_xlabel(
f"Frequency ({x_units})" if x_units else "Frequency (rad/s)"
)
axes[1].set_ylabel(f"{loss_sym} ({units})")
axes[1].grid(
True, which="both", alpha=_GRID_ALPHA, linestyle=_GRID_LINESTYLE
)
axes[1].legend()
fig.tight_layout()
logger.debug("Figure created", plot_type="frequency_domain")
return fig, axes
else:
# Real data - single plot
fig, ax = plt.subplots(figsize=style_params["figure.figsize"])
plot_kwargs = {
"linewidth": style_params["lines.linewidth"],
"marker": "o",
"markersize": style_params["lines.markersize"],
"markerfacecolor": "none",
"markeredgewidth": 1.0,
}
plot_kwargs.update(kwargs)
x_filtered, y_filtered = _filter_positive(x, y, warn=True)
ax.loglog(x_filtered, y_filtered, **plot_kwargs)
ax.set_xlabel(f"Frequency ({x_units})" if x_units else "Frequency (rad/s)")
ax.set_ylabel(f"Modulus ({y_units})" if y_units else "Modulus (Pa)")
ax.grid(True, which="both", alpha=_GRID_ALPHA, linestyle=_GRID_LINESTYLE)
fig.tight_layout()
logger.debug("Figure created", plot_type="frequency_domain")
# F-013: Wrap single ax in an array for consistent return type with complex branch
return fig, np.array([ax])
except Exception as e:
# VIZ-R6-002: Guard against fig not yet assigned
_fig = locals().get("fig")
if _fig is not None:
plt.close(_fig)
logger.error(
"Failed to generate frequency_domain plot",
plot_type="frequency_domain",
error=str(e),
exc_info=True,
)
raise
[docs]
def plot_flow_curve(
x: np.ndarray,
y: np.ndarray,
x_units: str | None = None,
y_units: str | None = None,
x_label: str | None = None,
y_label: str | None = None,
style: str = "default",
**kwargs: Any,
) -> tuple[Figure, Axes]:
"""Plot flow curve (viscosity or stress vs shear rate).
Args:
x: Shear rate data
y: Viscosity or stress data
x_units: Units for shear rate
y_units: Units for y-axis
x_label: Custom x-axis label
y_label: Custom y-axis label
style: Plotting style
**kwargs: Additional keyword arguments for matplotlib plot
Returns:
Tuple of (Figure, Axes)
"""
logger.debug("Generating plot", plot_type="flow_curve", style=style)
try:
style_params = _apply_style(style)
fig, ax = plt.subplots(figsize=style_params["figure.figsize"])
# VIS-010: Set font sizes on axes directly instead of mutating
# global plt.rcParams (which permanently pollutes process state)
# Plot data on log-log scale
plot_kwargs = {
"linewidth": style_params["lines.linewidth"],
"marker": "o",
"markersize": style_params["lines.markersize"],
"markerfacecolor": "none",
"markeredgewidth": 1.0,
}
plot_kwargs.update(kwargs)
# Filter positive values for log-log plot
x_filtered, y_filtered = _filter_positive(x, y, warn=True)
ax.loglog(x_filtered, y_filtered, **plot_kwargs)
# Set labels
if x_label is None:
x_label = f"Shear Rate ({x_units})" if x_units else "Shear Rate (1/s)"
if y_label is None:
# VIZ-R6-005: Infer stress vs viscosity label from units
if y_units:
y_stripped = y_units.strip()
if y_stripped in ("Pa", "kPa", "MPa", "GPa"):
y_label = f"Stress ({y_units})"
else:
y_label = f"Viscosity ({y_units})"
else:
y_label = "Viscosity (Pa.s)"
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
# Grid
ax.grid(True, which="both", alpha=_GRID_ALPHA, linestyle=_GRID_LINESTYLE)
fig.tight_layout()
logger.debug("Figure created", plot_type="flow_curve")
return fig, ax
except Exception as e:
# VIZ-R6-002: Guard against fig not yet assigned
_fig = locals().get("fig")
if _fig is not None:
plt.close(_fig)
logger.error(
"Failed to generate flow_curve plot",
plot_type="flow_curve",
error=str(e),
exc_info=True,
)
raise
[docs]
def plot_residuals(
x: np.ndarray,
residuals: np.ndarray,
y_true: np.ndarray | None = None,
y_pred: np.ndarray | None = None,
x_units: str | None = None,
style: str = "default",
**kwargs: Any,
) -> tuple[Figure, Axes | np.ndarray]:
"""Plot residuals from model fitting.
If y_true and y_pred are provided, creates two subplots: one with data
and predictions, and one with residuals. Otherwise, plots only residuals.
Args:
x: Independent variable
residuals: Residual values (y_true - y_pred)
y_true: True y values (optional)
y_pred: Predicted y values (optional)
x_units: Units for x-axis
style: Plotting style
**kwargs: Additional keyword arguments for matplotlib plot
Returns:
Tuple of (Figure, Axes) or (Figure, array of Axes)
"""
logger.debug("Generating plot", plot_type="residuals", style=style)
try:
style_params = _apply_style(style)
# VIS-010: Set font sizes on axes directly instead of mutating
# global plt.rcParams (which permanently pollutes process state)
if y_true is not None and y_pred is not None:
# Two subplots: data+predictions and residuals
fig, axes = plt.subplots(
2,
1,
figsize=(
style_params["figure.figsize"][0],
style_params["figure.figsize"][1] * 1.5,
),
)
# Data and predictions
axes[0].plot(
x,
y_true,
"o",
label="Data",
markersize=style_params["lines.markersize"],
markerfacecolor="none",
markeredgewidth=1.0,
)
axes[0].plot(
x, y_pred, "-", label="Model", linewidth=style_params["lines.linewidth"]
)
axes[0].set_ylabel("y")
axes[0].legend()
axes[0].grid(True, alpha=0.3, linestyle="--")
# Residuals
axes[1].plot(
x,
residuals,
"o",
markersize=style_params["lines.markersize"],
markerfacecolor="none",
markeredgewidth=1.0,
)
axes[1].axhline(y=0, color="k", linestyle="--", linewidth=1.0)
axes[1].set_xlabel(f"x ({x_units})" if x_units else "x")
axes[1].set_ylabel("Residuals")
axes[1].grid(True, alpha=0.3, linestyle="--")
fig.tight_layout()
logger.debug("Figure created", plot_type="residuals")
return fig, axes
else:
# Single plot: residuals only
fig, ax = plt.subplots(figsize=style_params["figure.figsize"])
ax.plot(
x,
residuals,
"o",
markersize=style_params["lines.markersize"],
markerfacecolor="none",
markeredgewidth=1.0,
**kwargs,
)
ax.axhline(y=0, color="k", linestyle="--", linewidth=1.0)
ax.set_xlabel(f"x ({x_units})" if x_units else "x")
ax.set_ylabel("Residuals")
ax.grid(True, alpha=0.3, linestyle="--")
fig.tight_layout()
logger.debug("Figure created", plot_type="residuals")
return fig, ax
except Exception as e:
_fig = locals().get("fig")
if _fig is not None:
plt.close(_fig) # prevent memory leak
logger.error(
"Failed to generate residuals plot",
plot_type="residuals",
error=str(e),
exc_info=True,
)
raise
[docs]
def compute_uncertainty_band(
model_fn: Any, # callable
x_pred: np.ndarray,
popt: np.ndarray,
pcov: np.ndarray,
confidence: float = 0.95,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Compute prediction uncertainty band via error propagation.
Uses the formula: sigma_y(x) = sqrt(diag(J @ pcov @ J.T))
where J is the Jacobian of the model with respect to parameters.
For 95% confidence interval, the band is +/-1.96 * sigma_y(x).
Args:
model_fn: Model function that takes (x, params) and returns predictions.
Must be compatible with JAX autodiff for Jacobian computation.
x_pred: X values for prediction (n_points,)
popt: Optimal parameter values (n_params,)
pcov: Parameter covariance matrix (n_params x n_params)
confidence: Confidence level (default: 0.95 for 95% CI)
Returns:
Tuple of (y_fit, y_lower, y_upper) where:
- y_fit: Fitted values at x_pred
- y_lower: Lower bound of confidence interval
- y_upper: Upper bound of confidence interval
Example:
>>> def model(x, params):
... a, b = params
... return a * x + b
>>> x = np.linspace(0, 10, 50)
>>> popt = np.array([2.0, 1.0])
>>> pcov = np.array([[0.01, 0.0], [0.0, 0.05]])
>>> y_fit, y_lo, y_hi = compute_uncertainty_band(model, x, popt, pcov)
Note
----
Complex-valued ``y_fit`` is not supported for uncertainty bands.
If ``y_fit`` is complex, the function returns ``(y_fit, None, None)``
and logs a debug message. Callers should check for ``None`` before
using ``y_lower`` and ``y_upper``.
"""
logger.debug(
"Computing uncertainty band",
n_points=len(x_pred),
n_params=len(popt),
confidence=confidence,
)
try:
from scipy.stats import norm
x_pred = np.asarray(x_pred, dtype=np.float64)
popt = np.asarray(popt, dtype=np.float64)
pcov = np.asarray(pcov, dtype=np.float64)
# Compute y_fit — check complex BEFORE forcing float64 dtype,
# since np.asarray(..., dtype=float64) silently truncates complex values.
y_fit_raw = model_fn(x_pred, popt)
if np.iscomplexobj(np.asarray(y_fit_raw)):
logger.debug("Complex output detected, skipping uncertainty band")
return np.asarray(y_fit_raw), None, None
y_fit = np.asarray(y_fit_raw, dtype=np.float64)
# Compute Jacobian: J[i, j] = dy[i] / dparam[j]
# Use finite differences for robustness
n_points = len(x_pred)
n_params = len(popt)
jac = np.zeros((n_points, n_params), dtype=np.float64)
eps = np.sqrt(np.finfo(np.float64).eps)
for j in range(n_params):
params_plus = popt.copy()
params_plus[j] += eps * max(abs(popt[j]), 1.0)
y_plus = np.asarray(model_fn(x_pred, params_plus), dtype=np.float64)
params_minus = popt.copy()
params_minus[j] -= eps * max(abs(popt[j]), 1.0)
y_minus = np.asarray(model_fn(x_pred, params_minus), dtype=np.float64)
jac[:, j] = (y_plus - y_minus) / (2 * eps * max(abs(popt[j]), 1.0))
# Compute variance: var_y = diag(J @ pcov @ J.T)
try:
var_y = np.einsum("ij,jk,ik->i", jac, pcov, jac)
sigma_y = np.sqrt(np.maximum(var_y, 0.0))
except Exception:
# Fallback if einsum fails
logger.debug("Einsum failed for variance computation, returning None")
return y_fit, None, None
# VIS-P2-001: Guard against Inf/NaN from ill-conditioned covariance matrix
if not np.all(np.isfinite(sigma_y)):
logger.warning(
"Non-finite uncertainty from ill-conditioned covariance matrix"
)
return y_fit, None, None
# Compute z-score for confidence interval
z = norm.ppf(1 - (1 - confidence) / 2)
y_lower = y_fit - z * sigma_y
y_upper = y_fit + z * sigma_y
logger.debug("Uncertainty band computed successfully")
return y_fit, y_lower, y_upper
except Exception as e:
logger.error(
"Failed to compute uncertainty band",
error=str(e),
exc_info=True,
)
raise
[docs]
def plot_fit_with_uncertainty(
x_data: np.ndarray,
y_data: np.ndarray,
x_fit: np.ndarray,
y_fit: np.ndarray,
y_lower: np.ndarray | None = None,
y_upper: np.ndarray | None = None,
log_x: bool = True,
log_y: bool = True,
data_label: str = "Data",
fit_label: str = "Fit",
band_label: str = "95% CI",
x_label: str | None = None,
y_label: str | None = None,
style: str = "default",
ax: Axes | None = None,
**kwargs: Any,
) -> tuple[Figure | None, Axes]:
"""Plot data with fitted curve and optional uncertainty band.
Creates publication-quality fit plots with:
- Scatter data points
- Solid fitted curve
- Shaded uncertainty band (if y_lower/y_upper provided)
- Legend with customizable labels
Args:
x_data: Experimental x values
y_data: Experimental y values
x_fit: X values for fitted curve (can be denser than data)
y_fit: Fitted y values
y_lower: Lower bound of uncertainty band (optional)
y_upper: Upper bound of uncertainty band (optional)
log_x: Use log scale for x-axis (default: True)
log_y: Use log scale for y-axis (default: True)
data_label: Legend label for data points
fit_label: Legend label for fitted curve
band_label: Legend label for uncertainty band
x_label: X-axis label
y_label: Y-axis label
style: Plot style ('default', 'publication', 'presentation')
ax: Optional existing axes to plot on
**kwargs: Additional arguments passed to scatter/plot
Returns:
Tuple of (Figure, Axes). Figure is None if ax was provided.
Example:
>>> x = np.logspace(-1, 2, 20)
>>> y = 100 * x ** -0.5 + np.random.randn(20) * 5
>>> x_fit = np.logspace(-1, 2, 100)
>>> y_fit = 100 * x_fit ** -0.5
>>> fig, ax = plot_fit_with_uncertainty(x, y, x_fit, y_fit)
"""
logger.debug("Generating plot", plot_type="fit_with_uncertainty", style=style)
try:
# VIZ-002: Reject complex inputs early — _filter_positive raises TypeError on complex
# and log-scale operations are undefined. Callers should plot G'/G'' separately.
if np.iscomplexobj(np.asarray(y_data)) or np.iscomplexobj(np.asarray(y_fit)):
raise ValueError(
"plot_fit_with_uncertainty does not support complex data. "
"Plot G' and G'' components separately."
)
style_params = _apply_style(style)
# Create figure if no axes provided
fig = None
if ax is None:
fig, ax = plt.subplots(figsize=style_params["figure.figsize"])
x_data = _ensure_numpy(x_data)
y_data = _ensure_numpy(y_data)
x_fit = _ensure_numpy(x_fit)
y_fit = _ensure_numpy(y_fit)
# F-030: Validate uncertainty bound lengths before attempting fill_between.
# Log at error level so callers are clearly notified of the data problem.
if y_lower is not None and y_upper is not None:
if len(y_lower) != len(x_fit) or len(y_upper) != len(x_fit):
logger.error(
"Uncertainty bounds length mismatch — bounds will not be plotted",
x_fit_len=len(x_fit),
y_lower_len=len(y_lower),
y_upper_len=len(y_upper),
)
y_lower = None
y_upper = None
def scatter_fn(x, y, **kw):
return ax.scatter(x, y, **kw)
# Determine plot functions based on log scales
if log_x and log_y:
plot_fn = ax.loglog
# Filter positive values for log scale
x_data_plot, y_data_plot = _filter_positive(x_data, y_data)
x_fit_plot, y_fit_plot = _filter_positive(x_fit, y_fit, warn=False)
elif log_x:
plot_fn = ax.semilogx
# VIS-002: Filter positive x values for log-x scale
x_mask = x_data > 0
x_data_plot, y_data_plot = x_data[x_mask], y_data[x_mask]
x_mask_fit = x_fit > 0
x_fit_plot, y_fit_plot = x_fit[x_mask_fit], y_fit[x_mask_fit]
elif log_y:
plot_fn = ax.semilogy
x_data_plot, y_data_plot = _filter_positive(x_data, y_data)
x_fit_plot, y_fit_plot = _filter_positive(x_fit, y_fit, warn=False)
else:
plot_fn = ax.plot
x_data_plot, y_data_plot = x_data, y_data
x_fit_plot, y_fit_plot = x_fit, y_fit
# Plot uncertainty band first (so it's behind other elements)
if y_lower is not None and y_upper is not None:
y_lower = _ensure_numpy(y_lower)
y_upper = _ensure_numpy(y_upper)
if log_y:
# Filter positive for log scale
mask = (y_lower > 0) & (y_upper > 0) & (x_fit > 0)
ax.fill_between(
x_fit[mask],
y_lower[mask],
y_upper[mask],
alpha=0.3,
color="C0",
label=band_label,
zorder=1,
)
else:
ax.fill_between(
x_fit,
y_lower,
y_upper,
alpha=0.3,
color="C0",
label=band_label,
zorder=1,
)
# Plot fitted curve
plot_fn(
x_fit_plot,
y_fit_plot,
"-",
linewidth=style_params["lines.linewidth"],
color="C0",
label=fit_label,
zorder=2,
)
# Plot data points
scatter_fn(
x_data_plot,
y_data_plot,
s=style_params["lines.markersize"] ** 2,
facecolors="none",
edgecolors="C1",
linewidths=1.5,
label=data_label,
zorder=3,
)
# Labels and legend
if x_label:
ax.set_xlabel(x_label)
if y_label:
ax.set_ylabel(y_label)
ax.legend(loc="best")
ax.grid(True, alpha=0.3, linestyle="--")
if fig is not None:
fig.tight_layout()
logger.debug("Figure created", plot_type="fit_with_uncertainty")
return fig, ax
except Exception as e:
_fig = locals().get("fig")
if _fig is not None:
plt.close(_fig) # prevent memory leak (only when we created the figure)
logger.error(
"Failed to generate fit_with_uncertainty plot",
plot_type="fit_with_uncertainty",
error=str(e),
exc_info=True,
)
raise