"""Visualization tools for Lattice Elasto-Plastic Models."""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import jax
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger
jax, jnp = safe_import_jax()
logger = get_logger(__name__)
def _plot_scalar_lattice(
    stress: np.ndarray,
    thresholds: np.ndarray,
    title: str = "Lattice EPM State (Scalar)",
    figsize: tuple[int, int] = (12, 5),
    cmap_stress: str = "coolwarm",
    cmap_thresh: str = "viridis",
) -> plt.Figure:
    """Draw a scalar stress field next to the yield-threshold field.

    The stress panel uses colour limits symmetric about zero; the threshold
    panel is left on matplotlib's automatic limits.

    Args:
        stress: 2D array of local stress values (L, L).
        thresholds: 2D array of local yield thresholds (L, L).
        title: Overall figure title.
        figsize: Figure size (width, height).
        cmap_stress: Colormap for the stress field (diverging).
        cmap_thresh: Colormap for the threshold field (sequential).

    Returns:
        Matplotlib Figure object holding both panels.
    """
    fig, (stress_ax, thresh_ax) = plt.subplots(1, 2, figsize=figsize)
    fig.suptitle(title)

    # Symmetric limit for the diverging colormap. An all-zero field would give
    # vmin == vmax == 0 (a degenerate range), so fall back to 1.0 in that case.
    limit = np.max(np.abs(stress))
    limit = 1.0 if limit == 0.0 else limit

    # (axis, data, panel title, extra imshow options), drawn left to right.
    panels = (
        (
            stress_ax,
            stress,
            r"Stress Field $\sigma_{ij}$",
            {"cmap": cmap_stress, "vmin": -limit, "vmax": limit},
        ),
        (
            thresh_ax,
            thresholds,
            r"Yield Thresholds $\sigma_c$",
            {"cmap": cmap_thresh},
        ),
    )
    for axis, field, heading, image_opts in panels:
        image = axis.imshow(field, origin="lower", **image_opts)
        axis.set_title(heading)
        fig.colorbar(image, ax=axis)
        axis.set_xlabel("x")
        axis.set_ylabel("y")

    fig.tight_layout()
    return fig
def _plot_tensorial_lattice(
    stress: np.ndarray,
    thresholds: np.ndarray,
    title: str = "Lattice EPM State (Tensorial)",
    figsize: tuple[int, int] = (16, 4),
    cmap_stress: str = "coolwarm",
    cmap_thresh: str = "viridis",
) -> plt.Figure:
    """Draw the three stress components and the threshold field in one row.

    All three stress panels share one colour scale, symmetric about zero and
    taken from the largest absolute value in the whole tensor.

    Args:
        stress: 3D array of stress tensor (3, L, L) with [σ_xx, σ_yy, σ_xy].
        thresholds: 2D array of local yield thresholds (L, L).
        title: Overall figure title.
        figsize: Figure size (width, height).
        cmap_stress: Colormap for the stress fields (diverging).
        cmap_thresh: Colormap for the threshold field (sequential).

    Returns:
        Matplotlib Figure object holding the four panels.
    """
    fig, axes = plt.subplots(1, 4, figsize=figsize)
    fig.suptitle(title)

    # One global scale so the components are visually comparable; an all-zero
    # tensor would otherwise give a degenerate vmin == vmax == 0 range.
    limit = np.max(np.abs(stress))
    limit = 1.0 if limit == 0.0 else limit
    stress_opts = {
        "cmap": cmap_stress,
        "vmin": -limit,
        "vmax": limit,
        "origin": "lower",
    }

    component_titles = (r"$\sigma_{xx}$", r"$\sigma_{yy}$", r"$\sigma_{xy}$")
    for idx, heading in enumerate(component_titles):
        panel = axes[idx]
        image = panel.imshow(stress[idx], **stress_opts)
        panel.set_title(heading)
        fig.colorbar(image, ax=panel)
        panel.set_xlabel("x")
        panel.set_ylabel("y")

    # Last panel: yield thresholds on their own automatic colour scale.
    thresh_panel = axes[3]
    thresh_image = thresh_panel.imshow(thresholds, cmap=cmap_thresh, origin="lower")
    thresh_panel.set_title(r"Yield Thresholds $\sigma_c$")
    fig.colorbar(thresh_image, ax=thresh_panel)
    thresh_panel.set_xlabel("x")
    thresh_panel.set_ylabel("y")

    fig.tight_layout()
    return fig
[docs]
def plot_lattice_fields(
    stress: np.ndarray | jax.Array,
    thresholds: np.ndarray | jax.Array,
    title: str | None = None,
    figsize: tuple[int, int] | None = None,
    cmap_stress: str = "coolwarm",
    cmap_thresh: str = "viridis",
) -> plt.Figure:
    """Plot EPM lattice fields, choosing the scalar or tensorial layout.

    The layout is picked from the shape of ``stress``: a 2D array goes to the
    scalar two-panel plot, a 3D array whose leading axis has length 3 goes to
    the tensorial four-panel plot.

    Args:
        stress: Either (L, L) scalar or (3, L, L) tensorial stress field.
        thresholds: 2D array of local yield thresholds (L, L).
        title: Overall figure title; a layout-specific default is used when
            this is falsy.
        figsize: Figure size (width, height); a layout-specific default is
            used when this is falsy.
        cmap_stress: Colormap for the stress field (diverging).
        cmap_thresh: Colormap for the threshold field (sequential).

    Returns:
        Matplotlib Figure object.

    Raises:
        ValueError: If stress is neither (L, L) nor (3, L, L).
    """
    stress_arr = np.array(stress)
    thresholds_arr = np.array(thresholds)

    is_scalar = stress_arr.ndim == 2
    is_tensorial = stress_arr.ndim == 3 and stress_arr.shape[0] == 3
    if not (is_scalar or is_tensorial):
        raise ValueError(
            f"Invalid stress shape: {stress_arr.shape}. "
            "Expected (L, L) for scalar or (3, L, L) for tensorial."
        )

    # Select the renderer together with its fallback title and figure size.
    if is_scalar:
        renderer = _plot_scalar_lattice
        fallback_title = "Lattice EPM State (Scalar)"
        fallback_figsize = (12, 5)
    else:
        renderer = _plot_tensorial_lattice
        fallback_title = "Lattice EPM State (Tensorial)"
        fallback_figsize = (16, 4)

    return renderer(
        stress_arr,
        thresholds_arr,
        title=title or fallback_title,
        figsize=figsize or fallback_figsize,
        cmap_stress=cmap_stress,
        cmap_thresh=cmap_thresh,
    )
[docs]
def animate_stress_evolution(
    stress_history: np.ndarray | jax.Array,
    interval: int = 50,
    cmap: str = "coolwarm",
    save_path: str | None = None,
) -> animation.FuncAnimation:
    """Create an animation of the stress field evolution.

    Every frame is drawn with the same colour limits, symmetric about zero
    and taken from the largest absolute value over the whole history.

    Args:
        stress_history: 3D array of stress history (Time, L, L).
        interval: Delay between frames in milliseconds.
        cmap: Colormap for stress.
        save_path: If provided, save the animation to this path (e.g. 'movie.mp4').

    Returns:
        Matplotlib FuncAnimation object.

    Raises:
        ValueError: If the history has no frames or is not a 3D array.
    """
    history = np.array(stress_history)

    # Validate before touching matplotlib so bad input gives a clear message
    # instead of an IndexError or a tuple-unpacking error.
    if history.ndim >= 1 and history.shape[0] == 0:
        raise ValueError("stress_history has no frames (shape[0] == 0)")
    if history.ndim != 3:
        raise ValueError(
            f"Expected stress_history shape (Time, L, L), got {history.shape}"
        )
    n_frames = history.shape[0]

    fig, ax = plt.subplots(figsize=(6, 5))

    # Global limits keep the colouring stable across frames; an all-zero
    # history would give a degenerate vmin == vmax == 0 range.
    max_val = np.max(np.abs(history))
    if max_val == 0.0:
        max_val = 1.0

    im = ax.imshow(
        history[0],
        cmap=cmap,
        vmin=-max_val,
        vmax=max_val,
        origin="lower",
        animated=True,
    )
    ax.set_title("Time Step: 0")
    fig.colorbar(im, ax=ax, label=r"Stress $\sigma$")

    def update(frame):
        im.set_array(history[frame])
        ax.set_title(f"Time Step: {frame}")
        return (im,)

    # NOTE(review): with blit=True only the artists returned by update() are
    # redrawn — confirm the title refreshes on interactive backends.
    anim = animation.FuncAnimation(
        fig, update, frames=n_frames, interval=interval, blit=True
    )

    if save_path:
        # Log a hint about missing writers, then re-raise the original error.
        try:
            anim.save(save_path)
        except Exception as e:
            logger.error(
                "Failed to save animation — check that the required writer "
                "(ffmpeg for .mp4, Pillow for .gif) is installed",
                save_path=str(save_path),
                error=str(e),
            )
            raise
    return anim
[docs]
def plot_tensorial_fields(
    stress: np.ndarray | jax.Array,
    figsize: tuple[int, int] = (15, 4),
    cmap: str = "coolwarm",
    ax: plt.Axes | list[plt.Axes] | None = None,
    **kwargs,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot all three stress tensor components in a 3-panel layout.

    The panels share one colour scale, symmetric about zero and taken from
    the largest absolute value in the tensor.

    Args:
        stress: Stress tensor of shape (3, L, L) with [σ_xx, σ_yy, σ_xy].
        figsize: Figure size (width, height); used only when ``ax`` is None.
        cmap: Colormap for stress fields (diverging, centered at 0).
        ax: Optional pre-existing axes (list, tuple or ndarray of 3 axes).
        **kwargs: Additional arguments passed to imshow.

    Returns:
        Tuple of (Figure, list of 3 Axes).

    Raises:
        ValueError: If stress is not (3, L, L), or ``ax`` is given but is not
            a sequence of exactly 3 axes.
    """
    stress = np.array(stress)
    # Check rank as well as the leading axis: a (3,) or (3, L) array would
    # otherwise pass validation and fail later inside imshow.
    if stress.ndim != 3 or stress.shape[0] != 3:
        raise ValueError(f"Expected stress shape (3, L, L), got {stress.shape}")

    if ax is None:
        fig, axes = plt.subplots(1, 3, figsize=figsize)
    else:
        # Tuples are accepted because `fig, (a, b, c) = plt.subplots(1, 3)`
        # style unpacking hands callers a tuple.
        if not isinstance(ax, (list, tuple, np.ndarray)) or len(ax) != 3:
            raise ValueError("If ax provided, must be a sequence of 3 axes")
        axes = list(ax)
        fig = axes[0].get_figure()

    # Global scale for consistent colormaps; guard the all-zero case, which
    # would give a degenerate vmin == vmax == 0 range.
    max_stress = np.max(np.abs(stress))
    if max_stress == 0.0:
        max_stress = 1.0

    labels = [r"$\sigma_{xx}$", r"$\sigma_{yy}$", r"$\sigma_{xy}$"]
    for i, label in enumerate(labels):
        im = axes[i].imshow(
            stress[i],
            cmap=cmap,
            vmin=-max_stress,
            vmax=max_stress,
            origin="lower",
            **kwargs,
        )
        axes[i].set_title(label)
        fig.colorbar(im, ax=axes[i])
        axes[i].set_xlabel("x")
        axes[i].set_ylabel("y")

    # fig.tight_layout() acts on this figure; plt.tight_layout() would act on
    # whatever the current figure happens to be.
    fig.tight_layout()
    return fig, list(axes)
[docs]
def plot_normal_stress_field(
    stress: np.ndarray | jax.Array,
    nu: float = 0.5,
    figsize: tuple[int, int] = (6, 5),
    cmap: str = "coolwarm",
    ax: plt.Axes | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot first normal stress difference field N₁ = σ_xx - σ_yy.

    Args:
        stress: Stress tensor of shape (3, L, L) with [σ_xx, σ_yy, σ_xy].
        nu: Poisson's ratio (not used for N₁, but kept for consistency).
        figsize: Figure size (width, height); used only when ``ax`` is None.
        cmap: Colormap (diverging, centered at 0).
        ax: Optional pre-existing axis.
        **kwargs: Additional arguments passed to imshow.

    Returns:
        Tuple of (Figure, Axes).

    Raises:
        ValueError: If stress is not a (3, L, L) array.
    """
    stress = np.array(stress)
    # Check rank as well as the leading axis: a (3,) array would otherwise
    # pass validation and reach imshow as a scalar N₁.
    if stress.ndim != 3 or stress.shape[0] != 3:
        raise ValueError(f"Expected stress shape (3, L, L), got {stress.shape}")

    # N₁ is the difference of the first two components (σ_xx - σ_yy).
    N1 = stress[0] - stress[1]

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    # Symmetric colour limits centered at 0; guard the all-zero case, which
    # would give a degenerate vmin == vmax == 0 range.
    max_N1 = np.max(np.abs(N1))
    if max_N1 == 0.0:
        max_N1 = 1.0

    im = ax.imshow(N1, cmap=cmap, vmin=-max_N1, vmax=max_N1, origin="lower", **kwargs)
    ax.set_title(r"$N_1 = \sigma_{xx} - \sigma_{yy}$")
    fig.colorbar(im, ax=ax, label=r"$N_1$")
    ax.set_xlabel("x")
    ax.set_ylabel("y")

    # fig.tight_layout() acts on this figure; plt.tight_layout() would act on
    # whatever the current figure happens to be.
    fig.tight_layout()
    return fig, ax
[docs]
def plot_von_mises_field(
    stress: np.ndarray | jax.Array,
    thresholds: np.ndarray | jax.Array,
    nu: float = 0.5,
    figsize: tuple[int, int] = (12, 5),
    ax: plt.Axes | list[plt.Axes] | None = None,
    **kwargs,
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Plot von Mises effective stress and normalized yield map.

    Creates a 2-panel figure:
        - Left: σ_eff with viridis (sequential)
        - Right: σ_eff/σ_c with RdYlGn_r, colour range starting at 0
          (Green <1: elastic, Yellow ≈1: near yield, Red >1: plastic)

    Args:
        stress: Stress tensor of shape (3, L, L) with [σ_xx, σ_yy, σ_xy].
        thresholds: Yield thresholds of shape (L, L).
        nu: Poisson's ratio, forwarded to ``compute_von_mises_stress``.
        figsize: Figure size (width, height); used only when ``ax`` is None.
        ax: Optional pre-existing axes (list, tuple or ndarray of 2 axes).
        **kwargs: Additional arguments passed to both imshow calls.

    Returns:
        Tuple of (Figure, list of 2 Axes).

    Raises:
        ValueError: If stress is not (3, L, L), or ``ax`` is given but is not
            a sequence of exactly 2 axes.
    """
    stress = np.array(stress)
    thresholds = np.array(thresholds)
    # Check rank as well as the leading axis so malformed input is rejected
    # here, before the kernel import and the JAX conversion.
    if stress.ndim != 3 or stress.shape[0] != 3:
        raise ValueError(f"Expected stress shape (3, L, L), got {stress.shape}")

    # Local import of the von Mises kernel.
    from rheojax.utils.epm_kernels_tensorial import compute_von_mises_stress

    # The kernel is fed the component axis last: (3, L, L) -> (L, L, 3).
    stress_jax = jnp.array(np.moveaxis(stress, 0, -1))
    sigma_eff = np.array(compute_von_mises_stress(stress_jax, nu))

    # Small offset avoids division by zero where a threshold is exactly 0.
    sigma_normalized = sigma_eff / (thresholds + 1e-12)

    if ax is None:
        fig, axes = plt.subplots(1, 2, figsize=figsize)
    else:
        # Tuples are accepted because `fig, (ax1, ax2) = plt.subplots(1, 2)`
        # style unpacking hands callers a tuple.
        if not isinstance(ax, (list, tuple, np.ndarray)) or len(ax) != 2:
            raise ValueError("If ax provided, must be list/tuple/array of 2 axes")
        axes = ax
        fig = axes[0].get_figure()

    # Left panel: σ_eff with viridis (sequential).
    im1 = axes[0].imshow(sigma_eff, cmap="viridis", origin="lower", **kwargs)
    axes[0].set_title(r"von Mises $\sigma_{\mathrm{eff}}$")
    fig.colorbar(im1, ax=axes[0], label=r"$\sigma_{\mathrm{eff}}$")
    axes[0].set_xlabel("x")
    axes[0].set_ylabel("y")

    # Right panel: σ_eff/σ_c. vmax is data-driven: at least 2.0, extended to
    # the 99th percentile (NaNs ignored) when the data reach higher.
    vmin = 0.0
    vmax = max(2.0, float(np.nanpercentile(sigma_normalized, 99)))
    im2 = axes[1].imshow(
        sigma_normalized,
        cmap="RdYlGn_r",
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        **kwargs,
    )
    axes[1].set_title(r"Normalized Stress $\sigma_{\mathrm{eff}} / \sigma_c$")
    fig.colorbar(im2, ax=axes[1], label=r"$\sigma_{\mathrm{eff}} / \sigma_c$")
    axes[1].set_xlabel("x")
    axes[1].set_ylabel("y")

    # fig.tight_layout() acts on this figure; plt.tight_layout() would act on
    # whatever the current figure happens to be.
    fig.tight_layout()
    return fig, list(axes)
[docs]
def plot_normal_stress_ratio(
    shear_rates: np.ndarray | jax.Array,
    N1: np.ndarray | jax.Array,
    sigma_xy: np.ndarray | jax.Array,
    figsize: tuple[int, int] = (8, 6),
    ax: plt.Axes | None = None,
    **kwargs,
) -> tuple[plt.Figure, plt.Axes]:
    """Plot N₁/|σ_xy| against shear rate on log-log axes.

    A warning is logged when any ratio is negative, because such points
    cannot be shown on a logarithmic axis.

    Args:
        shear_rates: Array of shear rates.
        N1: First normal stress difference values.
        sigma_xy: Shear stress values.
        figsize: Figure size (width, height); used only when ``ax`` is None.
        ax: Optional pre-existing axis.
        **kwargs: Additional arguments passed to ``ax.loglog``.

    Returns:
        Tuple of (Figure, Axes).
    """
    rates = np.array(shear_rates)
    normal_diff = np.array(N1)
    shear_stress = np.array(sigma_xy)

    # The denominator is |σ_xy| plus a tiny offset to avoid division by zero.
    denominator = np.abs(shear_stress) + 1e-12
    ratio = normal_diff / denominator

    if ax is not None:
        fig = ax.get_figure()
    else:
        fig, ax = plt.subplots(figsize=figsize)

    # Negative ratios have no place on a log axis, so tell the user about them.
    negative_count = (ratio < 0).sum()
    if negative_count > 0:
        logger.warning(
            "N1/sigma_xy ratio has %d negative values; these will be absent from loglog plot. "
            "Consider using a linear or semilogy scale for signed data.",
            negative_count,
        )

    ax.loglog(rates, ratio, marker="o", **kwargs)
    ax.set_xlabel(r"Shear Rate $\dot{\gamma}$ (1/s)")
    ax.set_ylabel(r"$N_1 / \sigma_{xy}$")
    ax.set_title("Normal Stress Ratio")
    ax.grid(True, which="both", alpha=0.3)

    # Lay out this specific figure rather than pyplot's current figure.
    fig.tight_layout()
    return fig, ax
[docs]
def _symmetric_color_limit(values: np.ndarray) -> float:
    """Return max |values|, or 1.0 when that maximum is exactly zero."""
    limit = np.max(np.abs(values))
    # A zero limit would give vmin == vmax == 0, i.e. a degenerate colormap.
    return 1.0 if limit == 0.0 else limit


def _von_mises_history(stress_history: np.ndarray, nu: float) -> np.ndarray:
    """Compute the von Mises effective stress for each frame of a (T, 3, L, L) history."""
    from rheojax.utils.epm_kernels_tensorial import compute_von_mises_stress

    try:
        # Batched path: component axis moved last -> (T, L, L, 3), then the
        # kernel is vmapped over the leading frame axis.
        stacked = jnp.array(np.moveaxis(stress_history, 1, -1))
        batched = jax.vmap(compute_von_mises_stress, in_axes=(0, None))
        return np.array(batched(stacked, nu))
    except (
        jax.errors.TracerBoolConversionError,
        jax.errors.ConcretizationTypeError,
        NotImplementedError,
        ValueError,
        RuntimeError,  # e.g. XLA out-of-memory ("RESOURCE_EXHAUSTED")
        MemoryError,  # host-side allocation failure
    ):
        # Fallback: evaluate the kernel one frame at a time.
        return np.array(
            [
                np.array(
                    compute_von_mises_stress(
                        jnp.array(np.moveaxis(frame, 0, -1)), nu
                    )
                )
                for frame in stress_history
            ]
        )


def _animate_single_panel(
    frames: np.ndarray,
    time: np.ndarray,
    title_prefix: str,
    cmap: str,
    vmin: float,
    vmax: float,
    interval: int,
    colorbar_label: str | None = None,
) -> animation.FuncAnimation:
    """Animate a (T, L, L) stack of frames in one panel with fixed colour limits."""
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(
        frames[0],
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        animated=True,
    )
    ax.set_title(f"{title_prefix} - t={time[0]:.3f}")
    if colorbar_label is None:
        fig.colorbar(im, ax=ax)
    else:
        fig.colorbar(im, ax=ax, label=colorbar_label)

    def update(frame):
        im.set_array(frames[frame])
        ax.set_title(f"{title_prefix} - t={time[frame]:.3f}")
        return (im,)

    return animation.FuncAnimation(
        fig, update, frames=len(frames), interval=interval, blit=True
    )


def animate_tensorial_evolution(
    history: dict[str, np.ndarray | jax.Array],
    component: str = "all",
    interval: int = 50,
    save_path: str | None = None,
    **kwargs,
) -> animation.FuncAnimation:
    """Create animation of tensorial stress field evolution.

    Args:
        history: Dictionary with keys:
            - 'stress': Stress history of shape (T, 3, L, L)
            - 'time': Time array of shape (T,)
        component: Component to animate:
            - 'all': All 3 components (3-panel animation)
            - 'xx', 'yy', 'xy': Individual components
            - 'N1': First normal stress difference
            - 'vm': von Mises effective stress
        interval: Delay between frames in milliseconds.
        save_path: If provided, save animation to this path.
        **kwargs: Additional arguments. Only ``nu`` (default 0.5) is read; it
            is forwarded to the von Mises kernel for ``component='vm'``.

    Returns:
        Matplotlib FuncAnimation object.

    Raises:
        ValueError: If the stress history is not (T, 3, L, L), has no frames,
            the time array is not 1D with at least T entries, or
            ``component`` is not one of the supported names.
    """
    stress_history = np.array(history["stress"])
    time = np.array(history["time"])

    # Validate everything before creating a figure, so bad input gives a
    # clear ValueError rather than an unpacking error or IndexError.
    if stress_history.ndim != 4 or stress_history.shape[1] != 3:
        raise ValueError(
            f"Expected stress shape (T, 3, L, L), got {stress_history.shape}"
        )
    n_frames = stress_history.shape[0]
    if n_frames == 0:
        raise ValueError("history['stress'] has no frames (shape[0] == 0)")
    if time.ndim != 1 or time.shape[0] < n_frames:
        raise ValueError(
            f"history['time'] must be 1D with at least {n_frames} entries, "
            f"got shape {time.shape}"
        )

    # Index of each individually selectable component along axis 1.
    component_map = {
        "xx": 0,
        "yy": 1,
        "xy": 2,
    }
    if component not in ("all", *component_map, "N1", "vm"):
        raise ValueError(
            f"Unknown component: {component}. "
            "Expected 'all', 'xx', 'yy', 'xy', 'N1', or 'vm'."
        )

    nu = kwargs.get("nu", 0.5)

    if component == "all":
        # 3-panel animation sharing one symmetric colour scale.
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        max_stress = _symmetric_color_limit(stress_history)
        labels = [r"$\sigma_{xx}$", r"$\sigma_{yy}$", r"$\sigma_{xy}$"]
        images = []
        for i, label in enumerate(labels):
            im = axes[i].imshow(
                stress_history[0, i],
                cmap="coolwarm",
                vmin=-max_stress,
                vmax=max_stress,
                origin="lower",
                animated=True,
            )
            axes[i].set_title(f"{label} - t={time[0]:.3f}")
            fig.colorbar(im, ax=axes[i])
            images.append(im)

        def update(frame):
            for i, label in enumerate(labels):
                images[i].set_array(stress_history[frame, i])
                axes[i].set_title(f"{label} - t={time[frame]:.3f}")
            return tuple(images)

        anim = animation.FuncAnimation(
            fig, update, frames=n_frames, interval=interval, blit=True
        )
    elif component in component_map:
        # Single tensor component, symmetric colour scale.
        frames = stress_history[:, component_map[component]]
        limit = _symmetric_color_limit(frames)
        anim = _animate_single_panel(
            frames,
            time,
            f"$\\sigma_{{{component}}}$",
            "coolwarm",
            -limit,
            limit,
            interval,
        )
    elif component == "N1":
        # First normal stress difference σ_xx - σ_yy for every frame.
        frames = stress_history[:, 0] - stress_history[:, 1]
        limit = _symmetric_color_limit(frames)
        anim = _animate_single_panel(
            frames,
            time,
            "$N_1$",
            "coolwarm",
            -limit,
            limit,
            interval,
            colorbar_label=r"$N_1$",
        )
    else:
        # component == "vm": colour scale runs from 0 to the maximum, with
        # the same guard against an all-zero range.
        frames = _von_mises_history(stress_history, nu)
        max_vm = np.max(frames)
        if max_vm == 0.0:
            max_vm = 1.0
        anim = _animate_single_panel(
            frames,
            time,
            r"$\sigma_{\mathrm{eff}}$",
            "viridis",
            0,
            max_vm,
            interval,
            colorbar_label=r"$\sigma_{\mathrm{eff}}$",
        )

    if save_path:
        # Log a hint about missing writers, then re-raise the original error.
        try:
            anim.save(save_path)
        except Exception as e:
            logger.error(
                "Failed to save animation — check that the required writer "
                "(ffmpeg for .mp4, Pillow for .gif) is installed",
                save_path=str(save_path),
                error=str(e),
            )
            raise
    return anim