# Source code for rheojax.models.hl.hebraud_lequeux

"""Hébraud–Lequeux (HL) Model implementation.

This module implements the Hébraud–Lequeux mean-field elastoplastic model
for yield-stress fluids and soft glassy materials. It integrates JAX-accelerated
kernels for high-performance simulation of flow curves, creep, relaxation,
and LAOS protocols.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.hl_kernels import (
    creep_kernel,
    laos_kernel,
    relaxation_kernel,
    run_creep,
    run_flow_curve,
    run_laos,
    run_relaxation,
    run_saos,
    run_startup,
    startup_kernel,
)

# JAX is obtained through the project's safe importer, which enforces
# float64 precision for every kernel in this module.
jax, jnp = safe_import_jax()

# Static type checkers see the real `jax.numpy` module under the alias
# `jnp_typing`; at runtime the alias is simply bound to `typing.Any`.
if TYPE_CHECKING:
    import jax.numpy as jnp_typing
else:
    jnp_typing = Any

# Logger named after this module's import path.
logger = get_logger(__name__)


@ModelRegistry.register(
    "hebraud_lequeux",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.CREEP,
        Protocol.RELAXATION,
        Protocol.STARTUP,
        Protocol.OSCILLATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class HebraudLequeux(BaseModel):
    """Hébraud–Lequeux (HL) Model for Soft Glassy Materials.

    The HL model (1998) is a mean-field description of yield-stress fluids
    where mesoscopic blocks of stress evolve via elastic loading, plastic
    yielding, and stress diffusion (mechanical noise) generated by yielding
    events elsewhere.

    It predicts:
    - A finite yield stress sigma_y for coupling parameter alpha < 0.5
    - Herschel-Bulkley flow curves near yield, with
      (sigma - sigma_y) ~ gdot^0.5 in the glassy phase
    - Creep and delayed yielding
    - Stress overshoots in startup flow
    - Non-linear LAOS response

    Parameters:
        alpha: Coupling parameter (dimensionless). Controls phase state.
            alpha < 0.5: Glassy (yield stress)
            alpha >= 0.5: Fluid (no yield stress)
        tau: Microscopic yield timescale (s).
        sigma_c: Critical yield stress threshold (Pa).

    Attributes:
        parameters: ParameterSet containing alpha, tau, sigma_c.
        grid_n_bins: Number of stress-grid bins used by the PDE solver
            (the flow-curve and creep fits use coarser grids for speed).
        grid_sigma_factor: Stress-grid half-width as a multiple of sigma_c
            (never smaller than 5.0).

    Supported ``test_mode`` values when fitting: ``"steady_shear"`` /
    ``"flow_curve"``, ``"creep"`` (requires ``stress_target``),
    ``"relaxation"`` (requires ``gamma0``), ``"startup"`` (requires
    ``gdot``), ``"laos"`` (requires ``gamma0`` and ``omega``), and
    ``"saos"`` / ``"oscillation"``.
    """

    def __init__(self):
        """Initialize Hébraud–Lequeux Model."""
        super().__init__()

        # Create parameter set
        self.parameters = ParameterSet()

        # alpha: Coupling parameter
        # Range: 0 to 1. alpha=0.5 is the critical point.
        self.parameters.add(
            name="alpha",
            value=0.3,
            bounds=(1e-4, 0.9999),
            units="dimensionless",
            description="Coupling parameter (alpha < 0.5 -> yield stress)",
        )

        # tau: Yield timescale
        self.parameters.add(
            name="tau",
            value=1.0,
            bounds=(1e-6, 1e4),
            units="s",
            description="Microscopic yield timescale",
        )

        # sigma_c: Yield stress threshold
        self.parameters.add(
            name="sigma_c",
            value=1.0,
            bounds=(1e-3, 1e6),
            units="Pa",
            description="Critical yield stress threshold",
        )

        # Internal state for protocol settings
        self._test_mode: str | None = None
        self._last_fit_kwargs: dict[str, Any] = {}
        # Store metadata for reconstructing time axes in Bayesian mode
        self._fit_data_metadata: dict[str, Any] = {}

        # Grid settings (can be adjusted by user)
        self.grid_n_bins = 501
        self.grid_sigma_factor = 5.0  # grid extends to sigma_c * factor

        # Adaptive time-stepping: cap lax.scan length to avoid OOM/slow JIT.
        # CFL sub-stepping inside step_hl() ensures physics accuracy
        # regardless of outer dt, so larger dt is safe.
        self._max_scan_steps = 20000
        self._min_dt = 0.005
        # Creep kernel has servo controller feedback loop whose XLA compilation
        # scales super-linearly with n_steps (~O(n^1.5)):
        #   500 steps → 0.6s compile, 1000 → 1.9s, 2000 → 5.8s
        # Cap at 500 for tractable fitting.
        self._max_scan_steps_creep = 500
        # Bayesian (forward-mode AD through scan) is much more expensive
        self._max_scan_steps_bayesian = 2000
        self._max_scan_steps_bayesian_creep = 500

        # HL kernels use lax.fori_loop with dynamic bounds for numerical stability,
        # which requires forward-mode autodiff for NUTS sampling.
        self._use_forward_mode_ad = True

    def _get_grid_params(self, sigma_c_val: float | None = None) -> tuple[float, int]:
        """Get grid parameters based on current or provided sigma_c.

        Args:
            sigma_c_val: Yield threshold to size the grid for. When None, the
                current ``sigma_c`` parameter value is used (falling back to
                1.0 if it is unset, as elsewhere in this class).

        Returns:
            ``(sigma_max, n_bins)`` where ``sigma_max`` is the grid half-width.
        """
        if sigma_c_val is None:
            sigma_c_val = self.parameters.get_value("sigma_c")
        if sigma_c_val is None:
            # Avoid `None * float` if the parameter has no value yet.
            sigma_c_val = 1.0

        # Ensure grid covers relevant stress range
        # Minimum sigma_max of 5.0 to handle standard normalized cases
        # Otherwise scale with sigma_c
        sigma_max = max(5.0, sigma_c_val * self.grid_sigma_factor)
        return sigma_max, self.grid_n_bins

    def _adaptive_dt(self, t_max: float) -> tuple[float, int]:
        """Compute adaptive dt and n_steps to cap scan length.

        The CFL sub-stepping inside step_hl() ensures physics accuracy
        regardless of outer dt, so we can safely increase dt for long
        experiments to keep n_steps bounded.

        Returns:
            ``(dt, n_steps)`` with ``dt >= self._min_dt`` and
            ``n_steps = int(t_max / dt) + 1``.
        """
        dt = max(self._min_dt, t_max / self._max_scan_steps)
        n_steps = int(t_max / dt) + 1
        return dt, n_steps

    def _fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        **kwargs: Any,
    ) -> HebraudLequeux:
        """Fit HL model parameters to data.

        Args:
            X: Independent variable (shear rate, time, or angular frequency)
            y: Dependent variable (stress, compliance, modulus, or G*)
            **kwargs: Must include 'test_mode'. Options:
                test_mode: Protocol ('steady_shear'/'flow_curve', 'creep',
                    'relaxation', 'startup', 'laos', 'saos'/'oscillation')
                Other optimizer/protocol-specific parameters (e.g. gamma0)

        Returns:
            self, with fitted parameters and ``fitted_`` set to True.

        Raises:
            ValueError: If ``test_mode`` is missing or unsupported, or if a
                protocol-specific kwarg required by that mode is missing.
        """
        test_mode: str | None = kwargs.pop("test_mode", None)
        if test_mode is None:
            raise ValueError("test_mode must be specified for HL fitting")

        self._test_mode = test_mode
        # Strip optimization meta-kwargs injected by BaseModel.fit() —
        # these are consumed by _fit() and should not leak to model_function.
        _optimization_keys = {
            "use_log_residuals",
            "use_multi_start",
            "n_starts",
            "perturb_factor",
            "_optimization_meta",
            "method",
        }
        self._last_fit_kwargs = {
            k: v for k, v in kwargs.items() if k not in _optimization_keys
        }

        # Store metadata for Bayesian reconstruction. Reset first so values
        # from a previous fit never leak into this one. Only time-domain
        # protocols have a meaningful t_max: for flow curves X is a shear-rate
        # axis and for SAOS it is a frequency axis.
        self._fit_data_metadata = {}
        if len(X) > 0:
            is_time_domain = test_mode not in (
                "steady_shear",
                "flow_curve",
                "saos",
                "oscillation",
            )
            self._fit_data_metadata = {
                "t_max": float(X[-1]) if is_time_domain else None,
                "len_X": len(X),
                "X_start": float(X[0]),
                "X_end": float(X[-1]),
            }

        with log_fit(logger, model="HebraudLequeux", data_shape=X.shape) as ctx:
            logger.info(f"Fitting HL model in mode: {test_mode}")
            ctx["test_mode"] = test_mode

            if test_mode == "steady_shear" or test_mode == "flow_curve":
                self._fit_steady_shear(X, y, **kwargs)
            elif test_mode == "creep":
                self._fit_creep(X, y, **kwargs)
            elif test_mode == "relaxation":
                self._fit_relaxation(X, y, **kwargs)
            elif test_mode == "startup":
                self._fit_startup(X, y, **kwargs)
            elif test_mode == "laos":
                self._fit_laos(X, y, **kwargs)
            elif test_mode == "saos" or test_mode == "oscillation":
                self._fit_oscillation(X, y, **kwargs)
            else:
                raise ValueError(f"Unsupported test mode: {test_mode}")

        self.fitted_ = True
        return self

    def _fit_steady_shear(self, gdot: np.ndarray, stress: np.ndarray, **kwargs):
        """Fit flow curve using derivative-free optimization.

        The HL PDE solver has sharp gradients near the glass transition
        (alpha ~ 0.5) that cause overflow in finite-difference Jacobians.
        We use Nelder-Mead (derivative-free) with multi-start and log-space
        MSE to robustly fit flow curves spanning multiple decades.

        Stress is normalized by the low-rate plateau so sigma_c ~ 2 (HL yield
        stress ≈ sigma_c/2), and tau is parameterized in log10-space for
        better scaling.
        """
        import time as _time

        from scipy.optimize import minimize

        from rheojax.utils.hl_kernels import _compute_dt_and_steps_for_rate

        # --- Normalize by low-rate stress ---
        idx_low = int(np.argmin(np.abs(gdot)))
        stress_scale = max(float(stress[idx_low]), 1e-12)
        stress_norm = stress / stress_scale

        gdot_jax = jnp.asarray(gdot, dtype=jnp.float64)
        target = np.asarray(stress_norm, dtype=np.float64)
        target_safe = np.maximum(target, 1e-10)
        max_stress_norm = float(np.max(stress_norm))

        # Grid: n_bins=201 for fitting speed (501 for final predictions)
        n_bins_fit = 201
        # Minimum sigma_max covers the data stress range
        sigma_max_min = 1.5 * max_stress_norm

        # --- Cost: log-space MSE with dynamic grid per sigma_c ---
        def cost_fn(x):
            alpha_v = x[0]
            tau_v = 10.0 ** x[1]
            sigma_c_v = x[2]

            # Out-of-box candidates get a large constant penalty.
            if not (0.01 <= alpha_v <= 0.99):
                return 1e6
            if not (0.1 <= sigma_c_v <= 15.0):
                return 1e6

            try:
                # Dynamic sigma_max: always covers yield boundary AND data
                sigma_max_v = max(5.0 * sigma_c_v, sigma_max_min)
                ds_v = 2.0 * sigma_max_v / (n_bins_fit - 1)
                schedule = [
                    _compute_dt_and_steps_for_rate(
                        abs(float(g)),
                        tau_v,
                        sigma_c_v,
                        ds=ds_v,
                        max_steps=5_000,
                        bucket_size=5_000,
                    )
                    for g in gdot
                ]
                pred = np.array(
                    run_flow_curve(
                        gdot_jax,
                        alpha_v,
                        tau_v,
                        sigma_c_v,
                        0.005,
                        sigma_max_v,
                        n_bins_fit,
                        per_rate_schedule=schedule,
                    )
                )
                pred_safe = np.maximum(np.abs(pred), 1e-10)
                log_resid = np.log10(pred_safe) - np.log10(target_safe)
                return float(np.mean(log_resid**2))
            except Exception as exc:
                logger.debug("Flow curve cost_fn exception: %s", exc)
                return 1e6

        # --- Multi-start Nelder-Mead ---
        # x = [alpha, log10(tau), sigma_c_norm]
        # HL yield stress ≈ sigma_c/2, so sigma_c ~ 2 for normalized data
        starts = [
            [0.10, -1.3, 2.5],  # alpha=0.10, tau=0.05, sigma_c=2.5
            [0.30, -2.0, 3.0],  # alpha=0.30, tau=0.01, sigma_c=3.0
        ]

        best_x = starts[0]
        best_cost = np.inf
        t0 = _time.time()

        # Extract callback for cancellation support (F-HL-009)
        _callback = kwargs.get("callback")

        for i, x0 in enumerate(starts):
            try:
                res = minimize(
                    cost_fn,
                    x0,
                    method="Nelder-Mead",
                    callback=_callback,
                    options={
                        "maxfev": 40,
                        "xatol": 0.02,
                        "fatol": 0.005,
                        "adaptive": True,
                    },
                )
                if res.fun < best_cost:
                    best_cost = res.fun
                    best_x = res.x.copy()
                logger.info(
                    f"Start {i+1}/{len(starts)}: cost={res.fun:.5f}, "
                    f"alpha={res.x[0]:.3f}, tau={10**res.x[1]:.3e}, "
                    f"sigma_c={res.x[2]:.3f} ({res.nfev} evals)"
                )
            except Exception as e:
                logger.warning(f"Start {i+1} failed: {e}")

        elapsed = _time.time() - t0
        logger.info(f"HL fit: {elapsed:.1f}s, best cost={best_cost:.5f}")

        # --- Set fitted parameters ---
        alpha_fit = float(np.clip(best_x[0], 0.01, 0.99))
        tau_fit = float(10.0 ** np.clip(best_x[1], -6, 4))
        sigma_c_fit = float(np.clip(best_x[2], 0.1, 15.0))

        self.parameters.set_value("alpha", alpha_fit)
        self.parameters.set_value("tau", tau_fit)
        # sigma_c was fitted in normalized units; convert back to physical.
        self.parameters.set_value("sigma_c", sigma_c_fit * stress_scale)

        # Store for predict/Bayesian
        _tau_val = self.parameters.get_value("tau")
        self._last_fit_kwargs["_tau_est"] = float(1.0 if _tau_val is None else _tau_val)
        _sc_val = self.parameters.get_value("sigma_c")
        self._last_fit_kwargs["_sigma_c_est"] = float(
            1.0 if _sc_val is None else _sc_val
        )
        self._last_fit_kwargs["_stress_scale"] = stress_scale
        self._last_fit_kwargs["_sigma_max_min_norm"] = sigma_max_min
        self._last_fit_kwargs["_n_bins_fit"] = n_bins_fit
        _sc_phys_val = self.parameters.get_value("sigma_c")
        sc_phys = float(1.0 if _sc_phys_val is None else _sc_phys_val)
        self._last_fit_kwargs["_sigma_max"] = max(5.0, self.grid_sigma_factor * sc_phys)

        # Precompute per-rate schedule for Bayesian model_function.
        # This avoids np.asarray(X_jax) inside JIT which can fail if
        # jit_model_args=True or NumPyro traces model arguments.
        sc_norm_est = (sc_phys / stress_scale) if stress_scale > 0 else 1.0
        sigma_max_norm = max(5.0 * sc_norm_est, sigma_max_min)
        ds = 2.0 * sigma_max_norm / (n_bins_fit - 1)
        _tau_est_val = self.parameters.get_value("tau")
        tau_est = float(1.0 if _tau_est_val is None else _tau_est_val)
        self._last_fit_kwargs["_precomputed_schedule"] = [
            _compute_dt_and_steps_for_rate(
                abs(float(gdot[i])), tau_est, sc_norm_est, ds=ds
            )
            for i in range(len(gdot))
        ]

    def _fit_creep(self, t: np.ndarray, compliance: np.ndarray, **kwargs):
        """Fit creep compliance J(t) = gamma(t) / stress_target.

        Requires ``stress_target`` in kwargs; raises ValueError otherwise.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        stress_target = kwargs.get("stress_target")
        if stress_target is None:
            raise ValueError(
                "stress_target must be provided in kwargs for creep fitting"
            )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        J_jax = jnp.asarray(compliance, dtype=jnp.float64)

        # Calculate n_steps statically for the objective function.
        # Use creep-specific cap — servo controller causes super-linear
        # XLA compilation cost with n_steps.
        t_max = float(t[-1])
        dt = max(self._min_dt, t_max / self._max_scan_steps_creep)
        n_steps = int(t_max / dt) + 1

        # Grid sizing — use coarser grid for fitting speed.
        # Coarser grid → larger ds → larger dt_stable → fewer CFL sub-steps
        # per outer step → dramatically faster (O(n_bins * n_sub) per step).
        sigma_max, _ = self._get_grid_params()
        n_bins = 51

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, gamma_hist = creep_kernel(
                n_steps, stress_target, alpha, tau, sigma_c, 1.0, dt, sigma_max, n_bins
            )
            # Add t=0
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            gamma_full = jnp.concatenate([jnp.array([0.0]), gamma_hist])
            gamma_interp = jnp.interp(x_data, time_full, gamma_full)
            return gamma_interp / stress_target

        objective = create_least_squares_objective(
            model_fn, t_jax, J_jax, normalize=True, use_log_residuals=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 200),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["stress_target"] = float(stress_target)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_relaxation(self, t: np.ndarray, modulus: np.ndarray, **kwargs):
        """Fit stress relaxation modulus G(t) = sigma(t) / gamma0.

        Requires ``gamma0`` (step strain) in kwargs; raises ValueError otherwise.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gamma0 = kwargs.get("gamma0")
        if gamma0 is None:
            raise ValueError(
                "gamma0 (step strain) must be provided in kwargs for relaxation fitting"
            )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        G_jax = jnp.asarray(modulus, dtype=jnp.float64)

        t_max = float(t[-1])
        dt, n_steps = self._adaptive_dt(t_max)

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, stress_hist = relaxation_kernel(
                n_steps, gamma0, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            # Stress at t=0 approximated as gamma0 (i.e. G(0) = 1).
            init_stress = gamma0
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([init_stress]), stress_hist])
            sigma_interp = jnp.interp(x_data, time_full, stress_full)
            return sigma_interp / gamma0

        objective = create_least_squares_objective(
            model_fn, t_jax, G_jax, normalize=True, use_log_residuals=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 500),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["gamma0"] = float(gamma0)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_startup(self, t: np.ndarray, stress: np.ndarray, **kwargs):
        """Fit startup stress transient at constant shear rate.

        Requires ``gdot`` in kwargs; raises ValueError otherwise.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gdot = kwargs.get("gdot")
        if gdot is None:
            raise ValueError(
                "gdot (shear rate) must be provided in kwargs for startup fitting"
            )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        stress_jax = jnp.asarray(stress, dtype=jnp.float64)

        t_max = float(t[-1])
        dt, n_steps = self._adaptive_dt(t_max)

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, stress_hist = startup_kernel(
                n_steps, gdot, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(x_data, time_full, stress_full)

        objective = create_least_squares_objective(
            model_fn, t_jax, stress_jax, normalize=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 500),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["gdot"] = float(gdot)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_laos(self, t: np.ndarray, stress: np.ndarray, **kwargs):
        """Fit LAOS stress response.

        Requires ``gamma0`` (strain amplitude) and ``omega`` in kwargs;
        raises ValueError otherwise.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gamma0 = kwargs.get("gamma0")
        omega = kwargs.get("omega")
        if gamma0 is None or omega is None:
            raise ValueError("gamma0 and omega must be provided for LAOS fitting")

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        stress_jax = jnp.asarray(stress, dtype=jnp.float64)

        t_max = float(t[-1])
        dt, n_steps = self._adaptive_dt(t_max)

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()

        def model_fn(x_data, params):
            alpha, tau, sigma_c = params
            time_hist, stress_hist = laos_kernel(
                n_steps, gamma0, omega, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(x_data, time_full, stress_full)

        objective = create_least_squares_objective(
            model_fn, t_jax, stress_jax, normalize=True
        )

        result = nlsq_optimize(
            objective,
            self.parameters,
            method="scipy",
            use_jax=True,
            max_iter=kwargs.get("max_iter", 500),
        )

        if not result.success:
            logger.warning(f"Optimization warning: {result.message}")

        # Store protocol kwargs for model_function (Bayesian inference)
        self._last_fit_kwargs["gamma0"] = float(gamma0)
        self._last_fit_kwargs["omega"] = float(omega)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _fit_oscillation(self, omega: np.ndarray, G_star: np.ndarray, **kwargs):
        """Fit SAOS oscillatory data (G', G'').

        Uses a multi-start, derivative-free Nelder-Mead search on a log-space
        MSE cost over [alpha, log10(tau), sigma_c], following the same
        pattern as the flow-curve fit.

        Args:
            omega: Angular frequency array (rad/s)
            G_star: Complex modulus — either complex array or (M, 2) [G', G'']
            **kwargs: Additional fitting arguments. Recognized keys:
                ``n_cycles`` (default 10), ``gamma0`` (SAOS strain amplitude,
                default 0.01) and ``callback`` (forwarded to the optimizer).

        Raises:
            ValueError: If ``G_star`` is neither complex nor shaped (M, 2).
        """
        from scipy.optimize import minimize

        # Accept any array-like (e.g. nested lists), not just ndarrays.
        G_star = np.asarray(G_star)

        # Parse complex or (M, 2) format
        if np.iscomplexobj(G_star):
            G_prime_data = np.real(G_star)
            G_double_prime_data = np.imag(G_star)
        elif G_star.ndim == 2 and G_star.shape[1] == 2:
            G_prime_data = G_star[:, 0]
            G_double_prime_data = G_star[:, 1]
        else:
            raise ValueError(
                "G_star must be complex array or (M, 2) array of [G', G'']"
            )

        # Grid sizing
        sigma_max, n_bins = self._get_grid_params()
        n_cycles = kwargs.get("n_cycles", 10)
        gamma0_saos = kwargs.get("gamma0", 0.01)

        # Safe log targets
        Gp_safe = np.maximum(np.abs(G_prime_data), 1e-10)
        Gpp_safe = np.maximum(np.abs(G_double_prime_data), 1e-10)

        def cost_fn(x):
            alpha_v = x[0]
            tau_v = 10.0 ** x[1]
            sigma_c_v = x[2]

            # Out-of-box candidates get a large constant penalty.
            if not (0.01 <= alpha_v <= 0.99):
                return 1e6
            if not (0.01 <= sigma_c_v <= 100.0):
                return 1e6

            try:
                result = run_saos(
                    jnp.asarray(omega),
                    alpha_v,
                    tau_v,
                    sigma_c_v,
                    gamma0=gamma0_saos,
                    n_cycles=n_cycles,
                    sigma_max=sigma_max,
                    n_bins=n_bins,
                )
                pred = np.array(result)
                pred_Gp = np.maximum(np.abs(pred[:, 0]), 1e-10)
                pred_Gpp = np.maximum(np.abs(pred[:, 1]), 1e-10)

                log_resid_Gp = np.log10(pred_Gp) - np.log10(Gp_safe)
                log_resid_Gpp = np.log10(pred_Gpp) - np.log10(Gpp_safe)
                return float(np.mean(log_resid_Gp**2 + log_resid_Gpp**2))
            except Exception as exc:
                logger.debug("SAOS cost_fn exception: %s", exc)
                return 1e6

        # Multi-start optimization
        starts = [
            [0.30, -1.0, 2.0],
            [0.10, -2.0, 3.0],
            [0.50, 0.0, 1.0],
        ]

        best_x = starts[0]
        best_cost = np.inf

        # Extract callback for cancellation support (F-HL-009)
        _callback = kwargs.get("callback")

        for i, x0 in enumerate(starts):
            try:
                res = minimize(
                    cost_fn,
                    x0,
                    method="Nelder-Mead",
                    callback=_callback,
                    options={
                        "maxfev": 60,
                        "xatol": 0.02,
                        "fatol": 0.005,
                        "adaptive": True,
                    },
                )
                if res.fun < best_cost:
                    best_cost = res.fun
                    best_x = res.x.copy()
                logger.info(
                    f"SAOS start {i+1}/{len(starts)}: cost={res.fun:.5f}, "
                    f"alpha={res.x[0]:.3f}, tau={10**res.x[1]:.3e}, "
                    f"sigma_c={res.x[2]:.3f}"
                )
            except Exception as e:
                logger.warning(f"SAOS start {i+1} failed: {e}")

        # Set fitted parameters
        alpha_fit = float(np.clip(best_x[0], 0.01, 0.99))
        tau_fit = float(10.0 ** np.clip(best_x[1], -6, 4))
        sigma_c_fit = float(np.clip(best_x[2], 0.01, 100.0))

        self.parameters.set_value("alpha", alpha_fit)
        self.parameters.set_value("tau", tau_fit)
        self.parameters.set_value("sigma_c", sigma_c_fit)

        logger.info(
            f"HL SAOS fit: alpha={alpha_fit:.3f}, tau={tau_fit:.3e}, "
            f"sigma_c={sigma_c_fit:.3f}, cost={best_cost:.5f}"
        )

        # Store protocol kwargs for model_function (Bayesian inference)
        # and _predict (so predictions use the same excitation as the fit).
        self._last_fit_kwargs["n_cycles"] = int(n_cycles)
        self._last_fit_kwargs["gamma0"] = float(gamma0_saos)
        self._last_fit_kwargs["_sigma_max"] = float(sigma_max)
        self._last_fit_kwargs["_n_bins"] = int(n_bins)

    def _predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
        """Predict response using fitted parameters and stored test_mode.

        Protocol settings (stress_target, gamma0, gdot, omega, n_cycles, ...)
        are taken from the last fit.

        Returns:
            Predicted response as a NumPy array; complex G* for SAOS.

        Raises:
            ValueError: If the model has no test_mode or it is unknown.
        """
        if self._test_mode is None:
            raise ValueError("Model not fitted or test_mode not set.")

        X_jax = jnp.asarray(X, dtype=jnp.float64)
        alpha = self.parameters.get_value("alpha")
        tau = self.parameters.get_value("tau")
        sigma_c = self.parameters.get_value("sigma_c")

        sigma_max, n_bins = self._get_grid_params(sigma_c)

        # The run_xxx functions in hl_kernels.py are now wrappers that
        # handle t_max/n_steps correctly (computes it from X array)
        # So we can use them directly here as X is provided at runtime.

        if self._test_mode in ("steady_shear", "flow_curve"):
            from rheojax.utils.hl_kernels import _compute_dt_and_steps_for_rate

            # Predict in normalized units (same as fitting) then scale back
            stress_scale = self._last_fit_kwargs.get("_stress_scale", 1.0)
            sigma_max_min_norm = self._last_fit_kwargs.get("_sigma_max_min_norm", 5.0)
            n_bins_pred = self._last_fit_kwargs.get("_n_bins_fit", 501)
            tau_val = float(1.0 if tau is None else tau)
            sigma_c_norm = float(1.0 if sigma_c is None else sigma_c) / stress_scale
            sigma_max_norm = max(5.0 * sigma_c_norm, sigma_max_min_norm)
            ds = 2.0 * sigma_max_norm / (n_bins_pred - 1)

            X_np = np.asarray(X_jax)  # Single vectorized transfer
            per_rate_schedule = [
                _compute_dt_and_steps_for_rate(
                    abs(float(X_np[i])),
                    tau_val,
                    sigma_c_norm,
                    ds=ds,
                )
                for i in range(len(X_np))
            ]

            pred_norm = run_flow_curve(
                X_jax,
                float(0.5 if alpha is None else alpha),
                tau_val,
                sigma_c_norm,
                0.005,
                float(sigma_max_norm),
                int(n_bins_pred),
                per_rate_schedule=per_rate_schedule,
            )
            return np.array(pred_norm) * stress_scale

        elif self._test_mode == "creep":
            stress_target = self._last_fit_kwargs.get("stress_target", 1.0)
            t_max = float(X_jax[-1])
            # Same creep-specific dt cap as used during fitting.
            dt_pred = max(self._min_dt, t_max / self._max_scan_steps_creep)
            return np.array(
                run_creep(
                    X_jax,
                    float(stress_target),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    1.0,
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode == "relaxation":
            gamma0 = self._last_fit_kwargs.get("gamma0", 1.0)
            t_max = float(X_jax[-1])
            dt_pred, _ = self._adaptive_dt(t_max)
            return np.array(
                run_relaxation(
                    X_jax,
                    float(gamma0),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode == "startup":
            gdot = self._last_fit_kwargs.get("gdot", 1.0)
            t_max = float(X_jax[-1])
            dt_pred, _ = self._adaptive_dt(t_max)
            return np.array(
                run_startup(
                    X_jax,
                    float(gdot),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode == "laos":
            gamma0 = self._last_fit_kwargs.get("gamma0", 1.0)
            omega = self._last_fit_kwargs.get("omega", 1.0)
            t_max = float(X_jax[-1])
            dt_pred, _ = self._adaptive_dt(t_max)
            return np.array(
                run_laos(
                    X_jax,
                    float(gamma0),
                    float(omega),
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    dt_pred,
                    float(sigma_max),
                    int(n_bins),
                )
            )

        elif self._test_mode in ("oscillation", "saos"):
            # Reuse the excitation settings stored by _fit_oscillation() so
            # predictions are made with the same protocol that was fitted
            # (defaults mirror _fit_oscillation's own defaults).
            n_cycles = self._last_fit_kwargs.get("n_cycles", 10)
            gamma0_saos = self._last_fit_kwargs.get("gamma0", 0.01)
            result = np.array(
                run_saos(
                    X_jax,
                    float(0.5 if alpha is None else alpha),
                    float(1.0 if tau is None else tau),
                    float(1.0 if sigma_c is None else sigma_c),
                    gamma0=float(gamma0_saos),
                    n_cycles=int(n_cycles),
                    sigma_max=float(sigma_max),
                    n_bins=int(n_bins),
                )
            )
            # Convert (N,2) [G', G''] to complex G* for consistent API
            return result[:, 0] + 1j * result[:, 1]

        else:
            raise ValueError(f"Unknown test mode: {self._test_mode}")

    def model_function(
        self, X: np.ndarray, params: np.ndarray, test_mode: str | None = None, **kwargs
    ):
        """Model function for Bayesian inference (NumPyro NUTS).

        Args:
            X: Input array
            params: Parameter values [alpha, tau, sigma_c]
            test_mode: Override test mode
            **kwargs: Protocol kwargs forwarded by BayesianMixin.
                Falls back to _last_fit_kwargs for values not provided.

        Returns:
            JAX array of predictions for the selected protocol.

        Raises:
            ValueError: If no test mode is available or it is unknown.
            RuntimeError: If a time-domain protocol cannot determine t_max.
        """
        mode = test_mode if test_mode is not None else self._test_mode
        if mode is None:
            raise ValueError("test_mode required for Bayesian inference")

        alpha, tau, sigma_c = params
        X_jax = jnp.asarray(X, dtype=jnp.float64)

        # Helper: read protocol kwarg from explicit kwargs (forwarded by
        # BayesianMixin), falling back to _last_fit_kwargs, then default.
        def _kw(key, default):
            val = kwargs.get(key)
            if val is not None:
                return val
            val = self._last_fit_kwargs.get(key)
            if val is not None:
                return val
            return default

        # Use fixed grid for Bayesian (sigma_c is dynamic tracer, can't resize
        # in JIT). Use sigma_max from NLSQ fit if available, else conservative.
        sigma_max = _kw("_sigma_max", 50.0)
        n_bins = _kw("_n_bins", 201)

        # Helper to get adaptive dt and n_steps safely.
        # Bayesian uses a coarser cap because forward-mode AD through
        # lax.scan is much more expensive than plain evaluation.
        def get_dt_and_n_steps(x_arr, creep=False):
            try:
                t_max = float(x_arr[-1])
            except Exception as e:
                # X may be a tracer; fall back to the t_max recorded at fit
                # time. Non-time-domain fits store t_max=None, which cannot
                # be used here.
                stored_t_max = (self._fit_data_metadata or {}).get("t_max")
                if stored_t_max is None:
                    raise RuntimeError(
                        "Cannot determine n_steps for Bayesian inference."
                    ) from e
                t_max = stored_t_max
            cap = (
                self._max_scan_steps_bayesian_creep
                if creep
                else self._max_scan_steps_bayesian
            )
            dt = max(self._min_dt, t_max / cap)
            n_steps = int(t_max / dt) + 1
            return dt, n_steps

        # Dispatch to kernels
        if mode in ("steady_shear", "flow_curve"):
            from rheojax.utils.hl_kernels import _compute_dt_and_steps_for_rate

            # Default for flow curve (overridden by per_rate_schedule)
            dt = self._min_dt
            # Run in normalized units (same as fitting) for consistency
            stress_scale = self._last_fit_kwargs.get("_stress_scale", 1.0)
            sigma_max_min_norm = self._last_fit_kwargs.get("_sigma_max_min_norm", 5.0)
            n_bins_bayes = self._last_fit_kwargs.get("_n_bins_fit", 501)
            # Use stored NLSQ estimates for schedule (tau/sigma_c are tracers)
            tau_est = self._last_fit_kwargs.get("_tau_est", 1.0)
            sc_est = self._last_fit_kwargs.get("_sigma_c_est", 1.0)
            sc_norm_est = sc_est / stress_scale
            sigma_max_norm = max(5.0 * sc_norm_est, sigma_max_min_norm)
            ds = 2.0 * sigma_max_norm / (n_bins_bayes - 1)

            try:
                # Use precomputed schedule from _fit() if available —
                # avoids np.asarray(X_jax) which can fail during JIT tracing.
                schedule = self._last_fit_kwargs.get("_precomputed_schedule")
                # That schedule has one entry per fitted shear rate, so it is
                # only usable when X has the same number of rates (the shape
                # is static even under tracing). Equal length is assumed to
                # mean the same rates.
                if schedule is not None and len(schedule) != X_jax.shape[0]:
                    schedule = None
                if schedule is None:
                    X_np = np.asarray(X_jax)
                    schedule = [
                        _compute_dt_and_steps_for_rate(
                            abs(float(X_np[i])),
                            tau_est,
                            sc_norm_est,
                            ds=ds,
                        )
                        for i in range(len(X_np))
                    ]
                # sigma_c tracer divided by stress_scale to get normalized
                sigma_c_norm = sigma_c / stress_scale
                pred_norm = run_flow_curve(
                    X_jax,
                    alpha,
                    tau,
                    sigma_c_norm,
                    dt,
                    sigma_max_norm,
                    n_bins_bayes,
                    per_rate_schedule=schedule,
                )
                return pred_norm * stress_scale
            except Exception as exc:
                # F-HL-017 fix: fallback also uses normalized units for
                # consistency. Previously used raw sigma_c without rescaling.
                logger.warning(
                    "Flow curve normalized path failed, using fallback: %s",
                    exc,
                )
                sigma_c_norm_fb = sigma_c / stress_scale
                pred_fb = run_flow_curve(
                    X_jax,
                    alpha,
                    tau,
                    sigma_c_norm_fb,
                    dt,
                    sigma_max_norm,
                    n_bins_bayes,
                )
                return pred_fb * stress_scale

        elif mode == "creep":
            # Creep uses tighter cap due to super-linear compilation cost
            dt, n_steps = get_dt_and_n_steps(X_jax, creep=True)
            stress_target = _kw("stress_target", 1.0)
            time_hist, gamma_hist = creep_kernel(
                n_steps, stress_target, alpha, tau, sigma_c, 1.0, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            gamma_full = jnp.concatenate([jnp.array([0.0]), gamma_hist])
            return jnp.interp(X_jax, time_full, gamma_full) / stress_target

        elif mode == "relaxation":
            dt, n_steps = get_dt_and_n_steps(X_jax)
            gamma0 = _kw("gamma0", 1.0)
            time_hist, stress_hist = relaxation_kernel(
                n_steps, gamma0, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            # Stress at t=0 approximated as gamma0 (same as _fit_relaxation)
            init_stress = gamma0
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([init_stress]), stress_hist])
            return jnp.interp(X_jax, time_full, stress_full) / gamma0

        elif mode == "startup":
            dt, n_steps = get_dt_and_n_steps(X_jax)
            gdot = _kw("gdot", 1.0)
            time_hist, stress_hist = startup_kernel(
                n_steps, gdot, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(X_jax, time_full, stress_full)

        elif mode == "laos":
            dt, n_steps = get_dt_and_n_steps(X_jax)
            gamma0 = _kw("gamma0", 1.0)
            omega = _kw("omega", 1.0)
            time_hist, stress_hist = laos_kernel(
                n_steps, gamma0, omega, alpha, tau, sigma_c, dt, sigma_max, n_bins
            )
            time_full = jnp.concatenate([jnp.array([0.0]), time_hist])
            stress_full = jnp.concatenate([jnp.array([0.0]), stress_hist])
            return jnp.interp(X_jax, time_full, stress_full)

        elif mode in ("oscillation", "saos"):
            # SAOS for Bayesian: use stored n_cycles/gamma0 from _fit_oscillation()
            n_cycles_bayes = int(_kw("n_cycles", 5))
            gamma0_bayes = float(_kw("gamma0", 0.01))
            return run_saos(
                X_jax,
                alpha,
                tau,
                sigma_c,
                gamma0=gamma0_bayes,
                n_cycles=n_cycles_bayes,
                sigma_max=sigma_max,
                n_bins=n_bins,
            )

        else:
            raise ValueError(f"Unknown test mode for Bayesian: {mode}")

    def get_phase_state(self) -> str:
        """Return the phase state based on alpha.

        Returns:
            ``"glass"`` when alpha < 0.5 (yield-stress phase), otherwise
            ``"fluid"``. An unset alpha is treated as 0.3.
        """
        _alpha_val = self.parameters.get_value("alpha")
        alpha = 0.3 if _alpha_val is None else _alpha_val
        if alpha < 0.5:
            return "glass"
        else:
            return "fluid"