"""Hébraud–Lequeux (HL) Model implementation.
This module implements the Hébraud–Lequeux mean-field elastoplastic model
for yield-stress fluids and soft glassy materials. It integrates JAX-accelerated
kernels for high-performance simulation of flow curves, creep, relaxation,
and LAOS protocols.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import numpy as np
from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.hl_kernels import (
creep_kernel,
laos_kernel,
relaxation_kernel,
run_creep,
run_flow_curve,
run_laos,
run_relaxation,
run_saos,
run_startup,
startup_kernel,
)
# JAX comes from the project helper, which enforces float64 precision.
jax, jnp = safe_import_jax()

# ``jnp_typing`` is the real ``jax.numpy`` module only for static type
# checkers; at runtime it is simply an alias of ``typing.Any``.
if not TYPE_CHECKING:
    jnp_typing = Any
else:
    import jax.numpy as jnp_typing

# Logger shared by every code path in this module.
logger = get_logger(__name__)
@ModelRegistry.register(
    "hebraud_lequeux",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.CREEP,
        Protocol.RELAXATION,
        Protocol.STARTUP,
        Protocol.OSCILLATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class HebraudLequeux(BaseModel):
    """Hébraud–Lequeux (HL) Model for Soft Glassy Materials.

    The HL model (1998) is a mean-field description of yield-stress fluids where
    mesoscopic blocks of stress evolve via elastic loading, plastic yielding, and
    stress diffusion (mechanical noise) generated by yielding events elsewhere.

    It predicts:
        - Finite yield stress for coupling parameter alpha < 0.5
        - Herschel-Bulkley flow curves near yield with exponent 1/2
          (sigma - sigma_y ~ gdot^0.5)
        - Creep and delayed yielding
        - Stress overshoots in startup flow
        - Non-linear LAOS response

    Parameters:
        alpha: Coupling parameter (dimensionless). Controls phase state.
            alpha < 0.5: Glassy (yield stress)
            alpha >= 0.5: Fluid (no yield stress)
        tau: Microscopic yield timescale (s).
        sigma_c: Critical yield stress threshold (Pa).

    Attributes:
        parameters: ParameterSet containing alpha, tau, sigma_c.
        grid_n_bins: Number of stress-grid bins used for predictions.
        grid_sigma_factor: The stress grid extends to sigma_c * factor
            (never below 5.0).
    """

    def __init__(self):
        """Initialize Hébraud–Lequeux Model."""
        super().__init__()

        # Create parameter set
        self.parameters = ParameterSet()

        # alpha: Coupling parameter
        # Range: 0 to 1. alpha=0.5 is the critical point.
        self.parameters.add(
            name="alpha",
            value=0.3,
            bounds=(1e-4, 0.9999),
            units="dimensionless",
            description="Coupling parameter (alpha < 0.5 -> yield stress)",
        )

        # tau: Yield timescale
        self.parameters.add(
            name="tau",
            value=1.0,
            bounds=(1e-6, 1e4),
            units="s",
            description="Microscopic yield timescale",
        )

        # sigma_c: Yield stress threshold
        self.parameters.add(
            name="sigma_c",
            value=1.0,
            bounds=(1e-3, 1e6),
            units="Pa",
            description="Critical yield stress threshold",
        )

        # Internal state for protocol settings
        self._test_mode: str | None = None
        self._last_fit_kwargs: dict[str, Any] = {}
        # Store metadata for reconstructing time axes in Bayesian mode
        self._fit_data_metadata: dict[str, Any] = {}

        # Grid settings (can be adjusted by user)
        self.grid_n_bins = 501
        self.grid_sigma_factor = 5.0  # grid extends to sigma_c * factor

        # Adaptive time-stepping: cap lax.scan length to avoid OOM/slow JIT.
        # CFL sub-stepping inside step_hl() ensures physics accuracy
        # regardless of outer dt, so larger dt is safe.
        self._max_scan_steps = 20000
        self._min_dt = 0.005
        # Creep kernel has servo controller feedback loop whose XLA compilation
        # scales super-linearly with n_steps (~O(n^1.5)):
        #   500 steps → 0.6s compile, 1000 → 1.9s, 2000 → 5.8s
        # Cap at 500 for tractable fitting.
        self._max_scan_steps_creep = 500
        # Bayesian (forward-mode AD through scan) is much more expensive
        self._max_scan_steps_bayesian = 2000
        self._max_scan_steps_bayesian_creep = 500

        # HL kernels use lax.fori_loop with dynamic bounds for numerical stability,
        # which requires forward-mode autodiff for NUTS sampling.
        self._use_forward_mode_ad = True

    def _get_grid_params(self, sigma_c_val: float | None = None) -> tuple[float, int]:
        """Return ``(sigma_max, n_bins)`` for the stress grid.

        Args:
            sigma_c_val: Yield threshold to size the grid for. Defaults to the
                current ``sigma_c`` parameter value (1.0 if that is unset,
                matching the fallback used elsewhere in this class).

        Returns:
            The grid half-width ``sigma_max`` and ``self.grid_n_bins``.
        """
        if sigma_c_val is None:
            sigma_c_val = self.parameters.get_value("sigma_c")
        if sigma_c_val is None:
            sigma_c_val = 1.0
        # Ensure grid covers relevant stress range:
        # minimum sigma_max of 5.0 handles standard normalized cases,
        # otherwise scale with sigma_c.
        sigma_max = max(5.0, sigma_c_val * self.grid_sigma_factor)
        return sigma_max, self.grid_n_bins

    def _adaptive_dt(self, t_max: float) -> tuple[float, int]:
        """Compute adaptive dt and n_steps to cap scan length.

        The CFL sub-stepping inside step_hl() ensures physics accuracy
        regardless of outer dt, so we can safely increase dt for long
        experiments to keep n_steps bounded.

        Args:
            t_max: Final time of the experiment (s).

        Returns:
            ``(dt, n_steps)`` with ``dt >= self._min_dt`` and
            ``n_steps = int(t_max / dt) + 1``.
        """
        dt = max(self._min_dt, t_max / self._max_scan_steps)
        n_steps = int(t_max / dt) + 1
        return dt, n_steps

    def _fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        **kwargs: Any,
    ) -> HebraudLequeux:
        """Fit HL model parameters to data.

        Args:
            X: Independent variable (shear rate, time, frequency, ...)
            y: Dependent variable (stress, compliance, modulus, ...)
            **kwargs: Must include 'test_mode'. Options:
                test_mode: Protocol ('steady_shear'/'flow_curve', 'creep',
                    'relaxation', 'startup', 'laos', 'saos'/'oscillation')
                Other optimizer/protocol-specific parameters (e.g. gamma0)

        Returns:
            self (fitted).

        Raises:
            ValueError: If test_mode is missing or unsupported. The model's
                previous fit state is left untouched in that case.
        """
        test_mode: str | None = kwargs.pop("test_mode", None)
        if test_mode is None:
            raise ValueError("test_mode must be specified for HL fitting")

        dispatch = {
            "steady_shear": self._fit_steady_shear,
            "flow_curve": self._fit_steady_shear,
            "creep": self._fit_creep,
            "relaxation": self._fit_relaxation,
            "startup": self._fit_startup,
            "laos": self._fit_laos,
            "saos": self._fit_oscillation,
            "oscillation": self._fit_oscillation,
        }
        fit_method = dispatch.get(test_mode)
        # Validate before mutating any state so a bad mode cannot corrupt a
        # previously fitted model.
        if fit_method is None:
            raise ValueError(f"Unsupported test mode: {test_mode}")

        self._test_mode = test_mode

        # Strip optimization meta-kwargs injected by BaseModel.fit() —
        # these are consumed by _fit() and should not leak to model_function.
        _optimization_keys = {
            "use_log_residuals",
            "use_multi_start",
            "n_starts",
            "perturb_factor",
            "_optimization_meta",
            "method",
        }
        self._last_fit_kwargs = {
            k: v for k, v in kwargs.items() if k not in _optimization_keys
        }

        # Store metadata for Bayesian reconstruction. Only time-domain
        # protocols have a meaningful t_max; for flow curves X holds shear
        # rates and for SAOS it holds frequencies.
        non_time_modes = ("steady_shear", "flow_curve", "saos", "oscillation")
        if len(X) > 0:
            self._fit_data_metadata = {
                "t_max": None if test_mode in non_time_modes else float(X[-1]),
                "len_X": len(X),
                "X_start": float(X[0]),
                "X_end": float(X[-1]),
            }
        else:
            # Do not keep stale metadata from a previous fit.
            self._fit_data_metadata = {}

        with log_fit(logger, model="HebraudLequeux", data_shape=X.shape) as ctx:
            logger.info(f"Fitting HL model in mode: {test_mode}")
            ctx["test_mode"] = test_mode
            fit_method(X, y, **kwargs)

        self.fitted_ = True
        return self

    def _fit_steady_shear(self, gdot: np.ndarray, stress: np.ndarray, **kwargs):
        """Fit flow curve using derivative-free optimization.

        The HL PDE solver has sharp gradients near the glass transition
        (alpha ~ 0.5) that cause overflow in finite-difference Jacobians.
        We use Nelder-Mead (derivative-free) with multi-start and log-space
        MSE to robustly fit flow curves spanning multiple decades.

        Stress is normalized by the stress at the lowest |shear rate| so
        sigma_c ~ 2 (HL yield stress ≈ sigma_c/2), and tau is parameterized
        in log10-space for better scaling.

        Args:
            gdot: Shear rates (1/s).
            stress: Measured steady-state stresses (Pa).
            **kwargs: Optional ``callback`` forwarded to scipy's minimize
                (used for cancellation support).
        """
        import time as _time

        from scipy.optimize import minimize

        from rheojax.utils.hl_kernels import _compute_dt_and_steps_for_rate

        # --- Normalize by low-rate stress ---
        idx_low = int(np.argmin(np.abs(gdot)))
        stress_scale = max(float(stress[idx_low]), 1e-12)
        stress_norm = stress / stress_scale

        gdot_jax = jnp.asarray(gdot, dtype=jnp.float64)
        target = np.asarray(stress_norm, dtype=np.float64)
        target_safe = np.maximum(target, 1e-10)
        max_stress_norm = float(np.max(stress_norm))

        # Grid: n_bins=201 for fitting speed (501 for final predictions)
        n_bins_fit = 201
        # Minimum sigma_max covers the data stress range
        sigma_max_min = 1.5 * max_stress_norm

        # --- Cost: log-space MSE with dynamic grid per sigma_c ---
        def cost_fn(x):
            alpha_v = x[0]
            tau_v = 10.0 ** x[1]
            sigma_c_v = x[2]

            # Out-of-box candidates get a large constant penalty.
            if not (0.01 <= alpha_v <= 0.99):
                return 1e6
            if not (0.1 <= sigma_c_v <= 15.0):
                return 1e6

            try:
                # Dynamic sigma_max: always covers yield boundary AND data
                sigma_max_v = max(5.0 * sigma_c_v, sigma_max_min)
                ds_v = 2.0 * sigma_max_v / (n_bins_fit - 1)
                schedule = [
                    _compute_dt_and_steps_for_rate(
                        abs(float(g)),
                        tau_v,
                        sigma_c_v,
                        ds=ds_v,
                        max_steps=5_000,
                        bucket_size=5_000,
                    )
                    for g in gdot
                ]
                pred = np.array(
                    run_flow_curve(
                        gdot_jax,
                        alpha_v,
                        tau_v,
                        sigma_c_v,
                        0.005,
                        sigma_max_v,
                        n_bins_fit,
                        per_rate_schedule=schedule,
                    )
                )
                pred_safe = np.maximum(np.abs(pred), 1e-10)
                log_resid = np.log10(pred_safe) - np.log10(target_safe)
                return float(np.mean(log_resid**2))
            except Exception as exc:
                # Solver failures are treated as a bad candidate, not fatal.
                logger.debug("Flow curve cost_fn exception: %s", exc)
                return 1e6

        # --- Multi-start Nelder-Mead ---
        # x = [alpha, log10(tau), sigma_c_norm]
        # HL yield stress ≈ sigma_c/2, so sigma_c ~ 2 for normalized data
        starts = [
            [0.10, -1.3, 2.5],  # alpha=0.10, tau=0.05, sigma_c=2.5
            [0.30, -2.0, 3.0],  # alpha=0.30, tau=0.01, sigma_c=3.0
        ]

        best_x = starts[0]
        best_cost = np.inf
        t0 = _time.time()

        # Extract callback for cancellation support (F-HL-009)
        _callback = kwargs.get("callback")

        for i, x0 in enumerate(starts):
            try:
                res = minimize(
                    cost_fn,
                    x0,
                    method="Nelder-Mead",
                    callback=_callback,
                    options={
                        "maxfev": 40,
                        "xatol": 0.02,
                        "fatol": 0.005,
                        "adaptive": True,
                    },
                )
                if res.fun < best_cost:
                    best_cost = res.fun
                    best_x = res.x.copy()
                logger.info(
                    f"Start {i+1}/{len(starts)}: cost={res.fun:.5f}, "
                    f"alpha={res.x[0]:.3f}, tau={10**res.x[1]:.3e}, "
                    f"sigma_c={res.x[2]:.3f} ({res.nfev} evals)"
                )
            except Exception as e:
                logger.warning(f"Start {i+1} failed: {e}")

        elapsed = _time.time() - t0
        logger.info(f"HL fit: {elapsed:.1f}s, best cost={best_cost:.5f}")

        # --- Set fitted parameters ---
        alpha_fit = float(np.clip(best_x[0], 0.01, 0.99))
        tau_fit = float(10.0 ** np.clip(best_x[1], -6, 4))
        sigma_c_fit = float(np.clip(best_x[2], 0.1, 15.0))

        self.parameters.set_value("alpha", alpha_fit)
        self.parameters.set_value("tau", tau_fit)
        # sigma_c was fitted in normalized units; convert back to Pa.
        self.parameters.set_value("sigma_c", sigma_c_fit * stress_scale)

        # Store for predict/Bayesian
        _tau_val = self.parameters.get_value("tau")
        self._last_fit_kwargs["_tau_est"] = float(1.0 if _tau_val is None else _tau_val)
        _sc_val = self.parameters.get_value("sigma_c")
        self._last_fit_kwargs["_sigma_c_est"] = float(
            1.0 if _sc_val is None else _sc_val
        )
        self._last_fit_kwargs["_stress_scale"] = stress_scale
        self._last_fit_kwargs["_sigma_max_min_norm"] = sigma_max_min
        self._last_fit_kwargs["_n_bins_fit"] = n_bins_fit
        _sc_phys_val = self.parameters.get_value("sigma_c")
        sc_phys = float(1.0 if _sc_phys_val is None else _sc_phys_val)
        self._last_fit_kwargs["_sigma_max"] = max(5.0, self.grid_sigma_factor * sc_phys)

        # Precompute per-rate schedule for Bayesian model_function.
        # This avoids np.asarray(X_jax) inside JIT which can fail if
        # jit_model_args=True or NumPyro traces model arguments.
        sc_norm_est = (sc_phys / stress_scale) if stress_scale > 0 else 1.0
        sigma_max_norm = max(5.0 * sc_norm_est, sigma_max_min)
        ds = 2.0 * sigma_max_norm / (n_bins_fit - 1)
        _tau_est_val = self.parameters.get_value("tau")
        tau_est = float(1.0 if _tau_est_val is None else _tau_est_val)
        self._last_fit_kwargs["_precomputed_schedule"] = [
            _compute_dt_and_steps_for_rate(abs(float(g)), tau_est, sc_norm_est, ds=ds)
            for g in gdot
        ]

    def _fit_creep(self, t: np.ndarray, compliance: np.ndarray, **kwargs):
        """Fit creep compliance J(t) = gamma(t) / stress_target.

        Args:
            t: Time points (s).
            compliance: Measured creep compliance.
            **kwargs: Must include ``stress_target``; optional ``max_iter``.

        Raises:
            ValueError: If ``stress_target`` is not provided.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        stress_target = kwargs.get("stress_target")
        if stress_target is None:
            raise ValueError(
                "stress_target must be provided in kwargs for creep fitting"
            )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        J_jax = jnp.asarray(compliance, dtype=jnp.float64)

        # Calculate n_steps statically for the objective function.
        # Use creep-specific cap — servo controller causes super-linear
        # XLA compilation cost with n_steps.
        t_max = float(t[-1])
        dt = max(self._min_dt, t_max / self._max_scan_steps_creep)
        n_steps = int(t_max / dt) + 1

        # Grid sizing — use coarser grid for fitting speed.
        # Coarser grid → larger ds → larger dt_stable → fewer CFL sub-steps
        # per outer step → dramatically faster (O(n_bins * n_sub) per step).
        sigma_max, _ = self._get_grid_params()
        n_bins = 51

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, gamma_hist = creep_kernel(
                n_steps, stress_target, alpha, tau, sigma_c, 1.0, dt, sigma_max, n_bins
            )
            # Prepend the t=0 state (zero strain) before interpolating.
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            gamma_full = jnp.concatenate([jnp.array([0.0]), gamma_hist])
            gamma_interp = jnp.interp(x_data, time_full, gamma_full)
            return gamma_interp / stress_target

        objective = create_least_squares_objective(
            model_fn, t_jax, J_jax, normalize=True, use_log_residuals=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 200),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["stress_target"] = float(stress_target)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_relaxation(self, t: np.ndarray, modulus: np.ndarray, **kwargs):
        """Fit stress relaxation modulus G(t) = sigma(t) / gamma0.

        Args:
            t: Time points (s).
            modulus: Measured relaxation modulus.
            **kwargs: Must include ``gamma0`` (step strain); optional
                ``max_iter``.

        Raises:
            ValueError: If ``gamma0`` is not provided.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gamma0 = kwargs.get("gamma0")
        if gamma0 is None:
            raise ValueError(
                "gamma0 (step strain) must be provided in kwargs for relaxation fitting"
            )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        G_jax = jnp.asarray(modulus, dtype=jnp.float64)

        t_max = float(t[-1])
        dt, n_steps = self._adaptive_dt(t_max)

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, stress_hist = relaxation_kernel(
                n_steps, gamma0, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            # Initial stress approximation: purely elastic response to gamma0
            init_stress = gamma0
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([init_stress]), stress_hist])
            sigma_interp = jnp.interp(x_data, time_full, stress_full)
            return sigma_interp / gamma0

        objective = create_least_squares_objective(
            model_fn, t_jax, G_jax, normalize=True, use_log_residuals=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 500),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["gamma0"] = float(gamma0)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_startup(self, t: np.ndarray, stress: np.ndarray, **kwargs):
        """Fit startup stress transient at constant shear rate.

        Args:
            t: Time points (s).
            stress: Measured stress (Pa).
            **kwargs: Must include ``gdot`` (shear rate); optional ``max_iter``.

        Raises:
            ValueError: If ``gdot`` is not provided.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gdot = kwargs.get("gdot")
        if gdot is None:
            raise ValueError(
                "gdot (shear rate) must be provided in kwargs for startup fitting"
            )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        stress_jax = jnp.asarray(stress, dtype=jnp.float64)

        t_max = float(t[-1])
        dt, n_steps = self._adaptive_dt(t_max)

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, stress_hist = startup_kernel(
                n_steps, gdot, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(x_data, time_full, stress_full)

        objective = create_least_squares_objective(
            model_fn, t_jax, stress_jax, normalize=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 500),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["gdot"] = float(gdot)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_laos(self, t: np.ndarray, stress: np.ndarray, **kwargs):
        """Fit LAOS stress response.

        Args:
            t: Time points (s).
            stress: Measured stress (Pa).
            **kwargs: Must include ``gamma0`` and ``omega``; optional
                ``max_iter``.

        Raises:
            ValueError: If ``gamma0`` or ``omega`` is not provided.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gamma0 = kwargs.get("gamma0")
        omega = kwargs.get("omega")
        if gamma0 is None or omega is None:
            raise ValueError("gamma0 and omega must be provided for LAOS fitting")

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        stress_jax = jnp.asarray(stress, dtype=jnp.float64)

        t_max = float(t[-1])
        dt, n_steps = self._adaptive_dt(t_max)

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, stress_hist = laos_kernel(
                n_steps, gamma0, omega, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(x_data, time_full, stress_full)

        objective = create_least_squares_objective(
            model_fn, t_jax, stress_jax, normalize=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 500),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["gamma0"] = float(gamma0)
        self._last_fit_kwargs["omega"] = float(omega)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_oscillation(self, omega: np.ndarray, G_star: np.ndarray, **kwargs):
        """Fit SAOS oscillatory data (G', G'').

        Uses multi-start Nelder-Mead (derivative-free, like the flow-curve
        fit) on a log-space MSE cost summed over G' and G''.

        Args:
            omega: Angular frequency array (rad/s)
            G_star: Complex modulus — either complex array or (M, 2) [G', G'']
            **kwargs: Optional ``n_cycles`` (default 10), ``gamma0``
                (default 0.01) and ``callback`` for cancellation support.

        Raises:
            ValueError: If G_star is neither complex nor shaped (M, 2).
        """
        from scipy.optimize import minimize

        # Accept list-like input as well as ndarrays (``.ndim`` is used below).
        G_star = np.asarray(G_star)

        # Parse complex or (M, 2) format
        if np.iscomplexobj(G_star):
            G_prime_data = np.real(G_star)
            G_double_prime_data = np.imag(G_star)
        elif G_star.ndim == 2 and G_star.shape[1] == 2:
            G_prime_data = G_star[:, 0]
            G_double_prime_data = G_star[:, 1]
        else:
            raise ValueError(
                "G_star must be complex array or (M, 2) array of [G', G'']"
            )

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()
        n_cycles = kwargs.get("n_cycles", 10)
        gamma0_saos = kwargs.get("gamma0", 0.01)

        # Safe log targets
        Gp_safe = np.maximum(np.abs(G_prime_data), 1e-10)
        Gpp_safe = np.maximum(np.abs(G_double_prime_data), 1e-10)

        def cost_fn(x):
            alpha_v = x[0]
            tau_v = 10.0 ** x[1]
            sigma_c_v = x[2]

            # Out-of-box candidates get a large constant penalty.
            if not (0.01 <= alpha_v <= 0.99):
                return 1e6
            if not (0.01 <= sigma_c_v <= 100.0):
                return 1e6

            try:
                result = run_saos(
                    jnp.asarray(omega),
                    alpha_v,
                    tau_v,
                    sigma_c_v,
                    gamma0=gamma0_saos,
                    n_cycles=n_cycles,
                    sigma_max=sigma_max,
                    n_bins=n_bins,
                )
                pred = np.array(result)
                pred_Gp = np.maximum(np.abs(pred[:, 0]), 1e-10)
                pred_Gpp = np.maximum(np.abs(pred[:, 1]), 1e-10)

                log_resid_Gp = np.log10(pred_Gp) - np.log10(Gp_safe)
                log_resid_Gpp = np.log10(pred_Gpp) - np.log10(Gpp_safe)
                return float(np.mean(log_resid_Gp**2 + log_resid_Gpp**2))
            except Exception as exc:
                logger.debug("SAOS cost_fn exception: %s", exc)
                return 1e6

        # Multi-start optimization: x = [alpha, log10(tau), sigma_c]
        starts = [
            [0.30, -1.0, 2.0],
            [0.10, -2.0, 3.0],
            [0.50, 0.0, 1.0],
        ]

        best_x = starts[0]
        best_cost = np.inf

        # Extract callback for cancellation support (F-HL-009)
        _callback = kwargs.get("callback")

        for i, x0 in enumerate(starts):
            try:
                res = minimize(
                    cost_fn,
                    x0,
                    method="Nelder-Mead",
                    callback=_callback,
                    options={
                        "maxfev": 60,
                        "xatol": 0.02,
                        "fatol": 0.005,
                        "adaptive": True,
                    },
                )
                if res.fun < best_cost:
                    best_cost = res.fun
                    best_x = res.x.copy()
                logger.info(
                    f"SAOS start {i+1}/{len(starts)}: cost={res.fun:.5f}, "
                    f"alpha={res.x[0]:.3f}, tau={10**res.x[1]:.3e}, "
                    f"sigma_c={res.x[2]:.3f}"
                )
            except Exception as e:
                logger.warning(f"SAOS start {i+1} failed: {e}")

        # Set fitted parameters
        alpha_fit = float(np.clip(best_x[0], 0.01, 0.99))
        tau_fit = float(10.0 ** np.clip(best_x[1], -6, 4))
        sigma_c_fit = float(np.clip(best_x[2], 0.01, 100.0))

        self.parameters.set_value("alpha", alpha_fit)
        self.parameters.set_value("tau", tau_fit)
        self.parameters.set_value("sigma_c", sigma_c_fit)

        logger.info(
            f"HL SAOS fit: alpha={alpha_fit:.3f}, tau={tau_fit:.3e}, "
            f"sigma_c={sigma_c_fit:.3f}, cost={best_cost:.5f}"
        )

        # Store protocol kwargs for model_function (Bayesian inference)
        # and _predict, so both use the excitation the fit used.
        self._last_fit_kwargs["n_cycles"] = int(n_cycles)
        self._last_fit_kwargs["gamma0"] = float(gamma0_saos)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
        """Predict response using fitted parameters and stored test_mode.

        Args:
            X: Independent variable for the stored protocol (shear rates,
                times or angular frequencies).

        Returns:
            Predicted response; complex G* for SAOS/oscillation.

        Raises:
            ValueError: If no test_mode is set or it is unknown.
        """
        if self._test_mode is None:
            raise ValueError("Model not fitted or test_mode not set.")

        X_jax = jnp.asarray(X, dtype=jnp.float64)
        alpha = self.parameters.get_value("alpha")
        tau = self.parameters.get_value("tau")
        sigma_c = self.parameters.get_value("sigma_c")

        sigma_max, n_bins = self._get_grid_params(sigma_c)

        # The run_xxx functions in hl_kernels.py are wrappers that
        # handle t_max/n_steps (computed from the X array),
        # so we can use them directly here as X is provided at runtime.

        if self._test_mode in ("steady_shear", "flow_curve"):
            from rheojax.utils.hl_kernels import _compute_dt_and_steps_for_rate

            # Predict in normalized units (same as fitting) then scale back
            stress_scale = self._last_fit_kwargs.get("_stress_scale", 1.0)
            sigma_max_min_norm = self._last_fit_kwargs.get("_sigma_max_min_norm", 5.0)
            n_bins_pred = self._last_fit_kwargs.get("_n_bins_fit", 501)
            tau_val = float(1.0 if tau is None else tau)
            sigma_c_norm = float(1.0 if sigma_c is None else sigma_c) / stress_scale
            sigma_max_norm = max(5.0 * sigma_c_norm, sigma_max_min_norm)
            ds = 2.0 * sigma_max_norm / (n_bins_pred - 1)
            X_np = np.asarray(X_jax)  # Single vectorized transfer
            per_rate_schedule = [
                _compute_dt_and_steps_for_rate(
                    abs(float(g)),
                    tau_val,
                    sigma_c_norm,
                    ds=ds,
                )
                for g in X_np
            ]
            pred_norm = run_flow_curve(
                X_jax,
                float(0.5 if alpha is None else alpha),
                tau_val,
                sigma_c_norm,
                0.005,
                float(sigma_max_norm),
                int(n_bins_pred),
                per_rate_schedule=per_rate_schedule,
            )
            return np.array(pred_norm) * stress_scale

        elif self._test_mode == "creep":
            stress_target = self._last_fit_kwargs.get("stress_target", 1.0)
            t_max = float(X_jax[-1])
            # Same creep-specific step cap as _fit_creep().
            dt_pred = max(self._min_dt, t_max / self._max_scan_steps_creep)
            return np.array(
                run_creep(
                    X_jax,
                    float(stress_target),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    1.0,
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode == "relaxation":
            gamma0 = self._last_fit_kwargs.get("gamma0", 1.0)
            t_max = float(X_jax[-1])
            dt_pred, _ = self._adaptive_dt(t_max)
            return np.array(
                run_relaxation(
                    X_jax,
                    float(gamma0),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode == "startup":
            gdot = self._last_fit_kwargs.get("gdot", 1.0)
            t_max = float(X_jax[-1])
            dt_pred, _ = self._adaptive_dt(t_max)
            return np.array(
                run_startup(
                    X_jax,
                    float(gdot),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode == "laos":
            gamma0 = self._last_fit_kwargs.get("gamma0", 1.0)
            omega = self._last_fit_kwargs.get("omega", 1.0)
            t_max = float(X_jax[-1])
            dt_pred, _ = self._adaptive_dt(t_max)
            return np.array(
                run_laos(
                    X_jax,
                    float(gamma0),
                    float(omega),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode in ("oscillation", "saos"):
            # Use the same strain amplitude / cycle count as the fit (stored
            # by _fit_oscillation and also used by model_function); otherwise
            # predictions would be computed under a different excitation.
            saos_kwargs: dict[str, Any] = {}
            if self._last_fit_kwargs.get("gamma0") is not None:
                saos_kwargs["gamma0"] = float(self._last_fit_kwargs["gamma0"])
            if self._last_fit_kwargs.get("n_cycles") is not None:
                saos_kwargs["n_cycles"] = int(self._last_fit_kwargs["n_cycles"])
            result = np.array(
                run_saos(
                    X_jax,
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    sigma_max=float(sigma_max),
                    n_bins=int(n_bins),
                    **saos_kwargs,
                )
            )
            # Convert (N,2) [G', G''] to complex G* for consistent API
            return result[:, 0] + 1j * result[:, 1]

        else:
            raise ValueError(f"Unknown test mode: {self._test_mode}")

    def model_function(
        self, X: np.ndarray, params: np.ndarray, test_mode: str | None = None, **kwargs
    ):
        """Model function for Bayesian inference (NumPyro NUTS).

        Args:
            X: Input array
            params: Parameter values [alpha, tau, sigma_c]
            test_mode: Override test mode
            **kwargs: Protocol kwargs forwarded by BayesianMixin. Falls back
                to _last_fit_kwargs for values not provided.

        Returns:
            JAX array with the predicted response for the protocol.

        Raises:
            ValueError: If no test mode is available or it is unknown.
            RuntimeError: If a time-domain protocol cannot determine t_max.
        """
        mode = test_mode if test_mode is not None else self._test_mode
        if mode is None:
            raise ValueError("test_mode required for Bayesian inference")

        alpha, tau, sigma_c = params
        X_jax = jnp.asarray(X, dtype=jnp.float64)

        # Helper: read protocol kwarg from explicit kwargs (forwarded by
        # BayesianMixin), falling back to _last_fit_kwargs, then default.
        def _kw(key, default):
            val = kwargs.get(key)
            if val is not None:
                return val
            val = self._last_fit_kwargs.get(key)
            if val is not None:
                return val
            return default

        # Use fixed grid for Bayesian (sigma_c is dynamic tracer, can't resize
        # in JIT). Use sigma_max from NLSQ fit if available, else conservative.
        sigma_max = _kw("_sigma_max", 50.0)
        n_bins = _kw("_n_bins", 201)

        # Helper to get adaptive dt and n_steps safely.
        # Bayesian uses a coarser cap because forward-mode AD through
        # lax.scan is much more expensive than plain evaluation.
        def get_dt_and_n_steps(x_arr, creep=False):
            try:
                t_max = float(x_arr[-1])
            except Exception as e:
                # float() can fail when X is not concrete (e.g. traced);
                # fall back to the t_max recorded at fit time. It may be
                # absent or None, in which case n_steps cannot be derived.
                t_max = self._fit_data_metadata.get("t_max")
                if t_max is None:
                    raise RuntimeError(
                        "Cannot determine n_steps for Bayesian inference."
                    ) from e
            cap = (
                self._max_scan_steps_bayesian_creep
                if creep
                else self._max_scan_steps_bayesian
            )
            dt = max(self._min_dt, t_max / cap)
            n_steps = int(t_max / dt) + 1
            return dt, n_steps

        # Dispatch to kernels
        if mode in ("steady_shear", "flow_curve"):
            from rheojax.utils.hl_kernels import _compute_dt_and_steps_for_rate

            # Default for flow curve (overridden by per_rate_schedule)
            dt = self._min_dt
            # Run in normalized units (same as fitting) for consistency
            stress_scale = self._last_fit_kwargs.get("_stress_scale", 1.0)
            sigma_max_min_norm = self._last_fit_kwargs.get("_sigma_max_min_norm", 5.0)
            n_bins_bayes = self._last_fit_kwargs.get("_n_bins_fit", 501)
            # Use stored NLSQ estimates for schedule (tau/sigma_c are tracers)
            tau_est = self._last_fit_kwargs.get("_tau_est", 1.0)
            sc_est = self._last_fit_kwargs.get("_sigma_c_est", 1.0)
            sc_norm_est = sc_est / stress_scale
            sigma_max_norm = max(5.0 * sc_norm_est, sigma_max_min_norm)
            ds = 2.0 * sigma_max_norm / (n_bins_bayes - 1)
            try:
                # Prefer the schedule precomputed in _fit_steady_shear() —
                # it avoids np.asarray(X_jax), which can fail during JIT
                # tracing. It has one entry per *fitted* rate, so discard it
                # when X has a different number of rates (X_jax.shape is
                # static even when X is traced).
                schedule = self._last_fit_kwargs.get("_precomputed_schedule")
                if schedule is not None and len(schedule) != X_jax.shape[0]:
                    schedule = None
                if schedule is None:
                    X_np = np.asarray(X_jax)
                    schedule = [
                        _compute_dt_and_steps_for_rate(
                            abs(float(g)),
                            tau_est,
                            sc_norm_est,
                            ds=ds,
                        )
                        for g in X_np
                    ]
                # sigma_c tracer divided by stress_scale to get normalized
                sigma_c_norm = sigma_c / stress_scale
                pred_norm = run_flow_curve(
                    X_jax,
                    alpha,
                    tau,
                    sigma_c_norm,
                    dt,
                    sigma_max_norm,
                    n_bins_bayes,
                    per_rate_schedule=schedule,
                )
                return pred_norm * stress_scale
            except Exception as exc:
                # F-HL-017 fix: fallback also uses normalized units for
                # consistency. Previously used raw sigma_c without rescaling.
                # %-style args match the other logger calls in this module.
                logger.warning(
                    "Flow curve normalized path failed, using fallback: %s", exc
                )
                sigma_c_norm_fb = sigma_c / stress_scale
                pred_fb = run_flow_curve(
                    X_jax,
                    alpha,
                    tau,
                    sigma_c_norm_fb,
                    dt,
                    sigma_max_norm,
                    n_bins_bayes,
                )
                return pred_fb * stress_scale

        elif mode == "creep":
            # Creep uses tighter cap due to super-linear compilation cost
            dt, n_steps = get_dt_and_n_steps(X_jax, creep=True)
            stress_target = _kw("stress_target", 1.0)
            time_hist, gamma_hist = creep_kernel(
                n_steps, stress_target, alpha, tau, sigma_c, 1.0, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            gamma_full = jnp.concatenate([jnp.array([0.0]), gamma_hist])
            return jnp.interp(X_jax, time_full, gamma_full) / stress_target

        elif mode == "relaxation":
            dt, n_steps = get_dt_and_n_steps(X_jax)
            gamma0 = _kw("gamma0", 1.0)
            time_hist, stress_hist = relaxation_kernel(
                n_steps, gamma0, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            # Initial stress approximation: elastic response to gamma0
            init_stress = gamma0
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([init_stress]), stress_hist])
            return jnp.interp(X_jax, time_full, stress_full) / gamma0

        elif mode == "startup":
            dt, n_steps = get_dt_and_n_steps(X_jax)
            gdot = _kw("gdot", 1.0)
            time_hist, stress_hist = startup_kernel(
                n_steps, gdot, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(X_jax, time_full, stress_full)

        elif mode == "laos":
            dt, n_steps = get_dt_and_n_steps(X_jax)
            gamma0 = _kw("gamma0", 1.0)
            omega = _kw("omega", 1.0)
            time_hist, stress_hist = laos_kernel(
                n_steps, gamma0, omega, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(X_jax, time_full, stress_full)

        elif mode in ("oscillation", "saos"):
            # SAOS for Bayesian: use stored n_cycles/gamma0 from
            # _fit_oscillation(); defaults match that method's defaults.
            n_cycles_bayes = int(_kw("n_cycles", 10))
            gamma0_bayes = float(_kw("gamma0", 0.01))
            return run_saos(
                X_jax,
                alpha,
                tau,
                sigma_c,
                gamma0=gamma0_bayes,
                n_cycles=n_cycles_bayes,
                sigma_max=sigma_max,
                n_bins=n_bins,
            )

        else:
            raise ValueError(f"Unknown test mode for Bayesian: {mode}")

    def get_phase_state(self) -> str:
        """Return the phase state based on alpha.

        Returns:
            "glass" when alpha < 0.5 (yield stress), otherwise "fluid".
            An unset alpha is treated as the default 0.3.
        """
        _alpha_val = self.parameters.get_value("alpha")
        alpha = 0.3 if _alpha_val is None else _alpha_val
        if alpha < 0.5:
            return "glass"
        else:
            return "fluid"