# Source code for rheojax.models.ikh.ml_ikh

"""Multi-Lambda Isotropic-Kinematic Hardening (ML-IKH) Model.

Extends MIKH to N modes for capturing distributed thixotropic timescales.
Supports two yield surface formulations:

1. **Per-Mode Yield** (default): Each mode has independent yield surface
   - Total stress = Σ σᵢ (parallel connection)
   - Parameters: 7 per mode + 2 global (eta_inf, mu_p)

2. **Weighted-Sum Yield**: Single global yield surface
   - σ_y = σ_y0 + k3·Σ(wᵢ·λᵢ)
   - All modes share elastic/plastic response
   - Parameters: 8 global (G, C, gamma_dyn, m, sigma_y0, k3, eta_inf, mu_p) + 3 per mode
"""

from typing import Literal

import numpy as np

from rheojax.core.base import ArrayLike
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
from rheojax.logging import get_logger

diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.ikh._base import IKHBase
from rheojax.models.ikh._kernels import (
    make_ml_ikh_creep_ode_rhs_per_mode,
    make_ml_ikh_creep_ode_rhs_weighted_sum,
    make_ml_ikh_maxwell_ode_rhs_per_mode,
    make_ml_ikh_maxwell_ode_rhs_weighted_sum,
    ml_ikh_flow_curve_steady_state_per_mode,
    ml_ikh_flow_curve_steady_state_weighted_sum,
    ml_ikh_scan_kernel,
    ml_ikh_weighted_sum_kernel,
)

jax, jnp = safe_import_jax()

logger = get_logger(__name__)


# Keyword arguments that the MLIKH fit routines consume themselves (protocol
# selection, protocol inputs, and initialization switches). Every fit path
# drops these keys from **kwargs before forwarding the remainder to
# nlsq_optimize.
_MLIKH_RESERVED = {
    "test_mode",
    "gamma_dot",
    "sigma_applied",
    "sigma_0",
    "deformation_mode",
    "poisson_ratio",
    "smart_init",
    "mikh_warmstart",
}


@ModelRegistry.register(
    "ml_ikh",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class MLIKH(IKHBase):
    """Multi-Lambda Isotropic-Kinematic Hardening (ML-IKH) Model.

    Extends MIKH to N modes connected in parallel. Each mode evolves its own
    internal variables (stress, backstress, structural lambda) with distinct
    timescales (tau_thix_i) and properties.

    Two Yield Mode Options:
        - **per_mode** (default): Each mode has independent yield surface.
          Total stress is sum of mode stresses.
        - **weighted_sum**: Single global yield surface with structure
          contribution from all modes: σ_y = σ_y0 + k3·Σ(wᵢ·λᵢ)

    Per-Mode Parameters (for each mode i=1..N):
        G_i: Shear modulus
        C_i: Backstress modulus
        gamma_dyn_i: Dynamic recovery
        sigma_y0_i: Minimal yield stress
        delta_sigma_y_i: Structural yield stress
        tau_thix_i: Rebuilding timescale
        Gamma_i: Breakdown coefficient

    Weighted-Sum Parameters:
        G: Global shear modulus
        C: Global hardening modulus
        gamma_dyn: Global dynamic recovery
        m: AF recovery exponent
        sigma_y0: Base yield stress
        k3: Structure-yield coupling
        tau_thix_i: Per-mode rebuilding timescales
        Gamma_i: Per-mode breakdown coefficients
        w_i: Per-mode structure weights

    Global Parameters (both yield modes):
        eta_inf: High-shear viscosity
        mu_p: Plastic viscosity (Perzyna regularization) — controls
            creep/flow rate

    Supported Protocols:
        - FLOW_CURVE: Steady-state stress vs shear rate (analytical solution)
        - STARTUP: Transient stress growth at constant shear rate
          (return mapping)
        - RELAXATION: Stress decay at constant strain (ODE formulation via
          Diffrax)
        - CREEP: Strain evolution at constant stress (ODE formulation via
          Diffrax)
        - OSCILLATION: Small amplitude oscillatory shear response
        - LAOS: Large amplitude oscillatory shear (return mapping with
          sinusoidal strain)

    Note:
        Both yield modes (per_mode, weighted_sum) support all protocols.
        ODE protocols (creep, relaxation) use Diffrax for numerical
        integration. Return mapping protocols (startup, LAOS) use JAX scan
        for time stepping.

    Args:
        n_modes: Number of structural modes (default: 2)
        yield_mode: Yield formulation ('per_mode' or 'weighted_sum')
    """

    # Declared bounds of the global ``mu_p`` parameter. Kept in one place so
    # data-driven initial guesses can be clamped to the same interval the
    # parameter is registered with.
    _MU_P_BOUNDS = (1e-5, 1e5)

    def __init__(
        self,
        n_modes: int = 2,
        yield_mode: Literal["per_mode", "weighted_sum"] = "per_mode",
    ):
        """Create the model and register its parameters.

        Args:
            n_modes: Number of structural modes; must be >= 1.
            yield_mode: 'per_mode' or 'weighted_sum'.

        Raises:
            ValueError: If ``n_modes`` < 1 or ``yield_mode`` is not one of
                the two supported strings.
        """
        super().__init__()

        if n_modes < 1:
            raise ValueError(f"n_modes must be >= 1, got {n_modes}")
        if yield_mode not in ("per_mode", "weighted_sum"):
            raise ValueError(
                f"yield_mode must be 'per_mode' or 'weighted_sum', got {yield_mode}"
            )

        self._n_modes = n_modes
        self._yield_mode = yield_mode
        self._test_mode = None

        self._create_parameters()

    # -------------------------------------------------------------------------
    # Parameter construction
    # -------------------------------------------------------------------------

    def _create_parameters(self):
        """Initialize parameters based on yield_mode."""
        self.parameters = ParameterSet()
        if self._yield_mode == "per_mode":
            self._create_per_mode_parameters()
        else:
            self._create_weighted_sum_parameters()

    def _add_thixotropy_parameters(self, i):
        """Register ``tau_thix_i`` and ``Gamma_i`` for mode ``i`` (1-based)."""
        # Timescales distributed logarithmically (one decade per mode)
        tau_val = 10.0 ** (i - 1 - self._n_modes / 2)
        self.parameters.add(
            f"tau_thix_{i}",
            value=tau_val,
            bounds=(1e-6, 1e12),
            units="s",
            description=f"Mode {i} Rebuilding time scale",
        )
        self.parameters.add(
            f"Gamma_{i}",
            value=0.5,
            bounds=(0.0, 1e4),
            units="-",
            description=f"Mode {i} Breakdown coefficient",
        )

    def _add_viscosity_parameters(self):
        """Register the global ``eta_inf`` and ``mu_p`` parameters."""
        self.parameters.add(
            "eta_inf",
            value=0.1,
            bounds=(0.0, 1e9),
            units="Pa s",
            description="High-shear viscosity",
        )
        self.parameters.add(
            "mu_p",
            value=1e-3,
            bounds=self._MU_P_BOUNDS,
            units="Pa s",
            description="Plastic viscosity (Perzyna regularization)",
        )

    def _create_per_mode_parameters(self):
        """Create parameters for per-mode yield formulation.

        Registers seven parameters per mode followed by the two global
        viscosity parameters. Registration order defines the order of the
        parameter vector seen by the optimizer.
        """
        n = self._n_modes
        for i in range(1, n + 1):
            # Elasticity & Hardening
            self.parameters.add(
                f"G_{i}",
                value=1e3 / n,
                bounds=(0.0, 1e9),
                units="Pa",
                description=f"Mode {i} Shear modulus",
            )
            self.parameters.add(
                f"C_{i}",
                value=5e2 / n,
                bounds=(0.0, 1e9),
                units="Pa",
                description=f"Mode {i} Kinematic hardening modulus",
            )
            self.parameters.add(
                f"gamma_dyn_{i}",
                value=1.0,
                bounds=(0.0, 1e4),
                units="-",
                description=f"Mode {i} Dynamic recovery",
            )

            # Yield Stress & Thixotropy
            self.parameters.add(
                f"sigma_y0_{i}",
                value=10.0 / n,
                bounds=(0.0, 1e9),
                units="Pa",
                description=f"Mode {i} Minimal yield stress",
            )
            self.parameters.add(
                f"delta_sigma_y_{i}",
                value=50.0 / n,
                bounds=(0.0, 1e9),
                units="Pa",
                description=f"Mode {i} Structural yield stress",
            )
            self._add_thixotropy_parameters(i)

        # Global Viscosity
        self._add_viscosity_parameters()

    def _create_weighted_sum_parameters(self):
        """Create parameters for weighted-sum yield formulation.

        Registers the six global mechanical/yield parameters, then three
        structure parameters per mode, then the two global viscosity
        parameters.
        """
        # Global mechanical parameters
        self.parameters.add(
            "G",
            value=1e3,
            bounds=(1e-1, 1e9),
            units="Pa",
            description="Global shear modulus",
        )
        self.parameters.add(
            "C",
            value=5e2,
            bounds=(0.0, 1e9),
            units="Pa",
            description="Global kinematic hardening modulus",
        )
        self.parameters.add(
            "gamma_dyn",
            value=1.0,
            bounds=(0.0, 1e4),
            units="-",
            description="Global dynamic recovery",
        )
        self.parameters.add(
            "m",
            value=1.0,
            bounds=(0.5, 3.0),
            units="-",
            description="AF recovery exponent",
        )

        # Yield stress
        self.parameters.add(
            "sigma_y0",
            value=10.0,
            bounds=(0.0, 1e9),
            units="Pa",
            description="Base yield stress",
        )
        self.parameters.add(
            "k3",
            value=50.0,
            bounds=(0.0, 1e9),
            units="Pa",
            description="Structure-yield coupling",
        )

        # Per-mode structure parameters
        for i in range(1, self._n_modes + 1):
            self._add_thixotropy_parameters(i)
            self.parameters.add(
                f"w_{i}",
                value=1.0 / self._n_modes,
                bounds=(0.0, 1.0),
                units="-",
                description=f"Mode {i} structure weight",
            )

        # Global Viscosity
        self._add_viscosity_parameters()

    # -------------------------------------------------------------------------
    # Parameter packing helpers
    # -------------------------------------------------------------------------

    def _as_param_dict(self, param_values):
        """Pair a flat sequence of values with the registered parameter names.

        ``strict=True`` makes a length mismatch raise instead of silently
        truncating.
        """
        return dict(zip(self.parameters.keys(), param_values, strict=True))

    def _stack_mode_params(self, params, names=None):
        """Stack per-mode parameters into one array per parameter family.

        For every base name in ``names`` the entries ``{name}_1`` ..
        ``{name}_N`` are looked up in ``params`` and stacked with
        ``jnp.stack``.

        Args:
            params: Mapping from parameter name to value.
            names: Base names to stack. Defaults to the seven per-mode
                families of the per-mode yield formulation.

        Returns:
            Dict mapping each base name to its stacked length-N array.
        """
        n = self._n_modes
        if names is None:
            names = [
                "G",
                "C",
                "gamma_dyn",
                "sigma_y0",
                "delta_sigma_y",
                "tau_thix",
                "Gamma",
            ]
        return {
            name: jnp.stack([params[f"{name}_{i}"] for i in range(1, n + 1)])
            for name in names
        }

    # -------------------------------------------------------------------------
    # Return-mapping and flow-curve predictions
    # -------------------------------------------------------------------------

    def _predict_from_params(self, times, strains, params):
        """Predict using parameter dictionary (for NLSQ/Bayesian)."""
        if self._yield_mode == "per_mode":
            return self._predict_per_mode(times, strains, params)
        return self._predict_weighted_sum(times, strains, params)

    def _predict_per_mode(self, times, strains, params):
        """Predict with per-mode yield surfaces."""
        kernel_params = self._stack_mode_params(params)
        return ml_ikh_scan_kernel(
            times,
            strains,
            num_modes=self._n_modes,
            use_viscosity=True,
            eta_inf=params["eta_inf"],
            **kernel_params,
        )

    def _predict_weighted_sum(self, times, strains, params):
        """Predict with weighted-sum yield surface."""
        stacked = self._stack_mode_params(params, names=["tau_thix", "Gamma", "w"])
        kernel_params = {
            "G": params["G"],
            "C": params["C"],
            "gamma_dyn": params["gamma_dyn"],
            "m": params.get("m", 1.0),
            "sigma_y0": params["sigma_y0"],
            "k3": params["k3"],
            "eta_inf": params["eta_inf"],
            **stacked,
        }
        return ml_ikh_weighted_sum_kernel(
            times, strains, num_modes=self._n_modes, use_viscosity=True, **kernel_params
        )

    def _predict_flow_curve_from_params(self, gamma_dot, params):
        """Predict steady-state flow curve from parameter dictionary.

        The whole parameter dictionary is forwarded to the kernel as keyword
        arguments.
        """
        if self._yield_mode == "per_mode":
            return ml_ikh_flow_curve_steady_state_per_mode(
                gamma_dot, self._n_modes, **params
            )
        return ml_ikh_flow_curve_steady_state_weighted_sum(
            gamma_dot, self._n_modes, **params
        )

    # -------------------------------------------------------------------------
    # ODE (Diffrax) formulation
    # -------------------------------------------------------------------------

    def _build_ode_args(self, params, **kwargs):
        """Build args dictionary for ODE integration.

        Args:
            params: Mapping from parameter name to value.
            **kwargs: ``gamma_dot`` and/or ``sigma_applied`` are copied into
                the result when present; other keys are ignored.

        Returns:
            Dict passed as ``args`` to the ODE right-hand side.
        """
        args = {"n_modes": self._n_modes}
        mu_p_val = params.get("mu_p", 1e-3)

        if self._yield_mode == "per_mode":
            args.update(self._stack_mode_params(params))

            # Default arrays for optional parameters
            args["eta"] = jnp.full(self._n_modes, 1e12)
            args["mu_p"] = jnp.full(self._n_modes, mu_p_val)
            args["m"] = jnp.ones(self._n_modes)
        else:
            # Global parameters
            args["G"] = params["G"]
            args["C"] = params["C"]
            args["gamma_dyn"] = params["gamma_dyn"]
            args["m"] = params.get("m", 1.0)
            args["sigma_y0"] = params["sigma_y0"]
            args["k3"] = params.get("k3", 0.0)

            # Per-mode structure parameters
            args.update(
                self._stack_mode_params(params, names=["tau_thix", "Gamma", "w"])
            )

            args["eta"] = 1e12
            args["mu_p"] = mu_p_val

        args["eta_inf"] = params.get("eta_inf", 0.0)

        # Add protocol-specific args
        for key in ("gamma_dot", "sigma_applied"):
            if key in kwargs:
                args[key] = kwargs[key]

        return args

    def _simulate_transient(
        self,
        t: jnp.ndarray,
        params: dict,
        mode: str,
        gamma_dot: float | None = None,
        sigma_applied: float | None = None,
        sigma_0: float | None = None,
    ) -> jnp.ndarray:
        """Simulate transient response using Diffrax ODE integration.

        Args:
            t: Time array; the solution is saved at exactly these times.
            params: Parameter dictionary
            mode: 'startup', 'relaxation', or 'creep'. Any value other than
                'creep' or 'startup' is treated as relaxation.
            gamma_dot: Applied shear rate (for startup); defaults to 1.0.
            sigma_applied: Applied stress (for creep); defaults to 100.0.
            sigma_0: Initial stress (for relaxation); when None it is derived
                from the yield-stress parameters.

        Returns:
            Stress (for startup/relaxation) or strain (for creep). All
            entries are NaN when the solver does not report success.
        """
        n = self._n_modes
        args = self._build_ode_args(params)

        # Initial lambda (fully structured)
        lambda_init = 1.0

        if self._yield_mode == "per_mode":
            if mode == "creep":
                # State: [γ, α_1..α_N, λ_1..λ_N] (1+2N)
                ode_fn = make_ml_ikh_creep_ode_rhs_per_mode(n)
                args["sigma_applied"] = (
                    sigma_applied if sigma_applied is not None else 100.0
                )
                y0 = jnp.concatenate(
                    [
                        jnp.array([0.0]),  # gamma
                        jnp.zeros(n),  # alphas
                        jnp.full(n, lambda_init),  # lambdas
                    ]
                )
            elif mode == "startup":
                # State: [σ_1..σ_N, α_1..α_N, λ_1..λ_N] (3N)
                ode_fn = make_ml_ikh_maxwell_ode_rhs_per_mode(n)
                args["gamma_dot"] = gamma_dot if gamma_dot is not None else 1.0
                y0 = jnp.concatenate(
                    [
                        jnp.zeros(n),  # sigmas
                        jnp.zeros(n),  # alphas
                        jnp.full(n, lambda_init),  # lambdas
                    ]
                )
            else:  # relaxation
                ode_fn = make_ml_ikh_maxwell_ode_rhs_per_mode(n)
                args["gamma_dot"] = 0.0
                # Initial stress distributed across modes
                sigma_init = (
                    sigma_0
                    if sigma_0 is not None
                    else (jnp.sum(args["sigma_y0"]) + jnp.sum(args["delta_sigma_y"]))
                )
                lambda_init_relax = 0.5
                y0 = jnp.concatenate(
                    [
                        jnp.full(n, sigma_init / n),  # sigmas (distributed)
                        jnp.zeros(n),  # alphas
                        jnp.full(n, lambda_init_relax),  # lambdas
                    ]
                )
        else:  # weighted_sum
            if mode == "creep":
                # State: [γ, α, λ_1..λ_N] (2+N)
                ode_fn = make_ml_ikh_creep_ode_rhs_weighted_sum(n)
                args["sigma_applied"] = (
                    sigma_applied if sigma_applied is not None else 100.0
                )
                y0 = jnp.concatenate(
                    [
                        jnp.array([0.0, 0.0]),  # gamma, alpha
                        jnp.full(n, lambda_init),  # lambdas
                    ]
                )
            elif mode == "startup":
                # State: [σ, α, λ_1..λ_N] (2+N)
                ode_fn = make_ml_ikh_maxwell_ode_rhs_weighted_sum(n)
                args["gamma_dot"] = gamma_dot if gamma_dot is not None else 1.0
                y0 = jnp.concatenate(
                    [
                        jnp.array([0.0, 0.0]),  # sigma, alpha
                        jnp.full(n, lambda_init),  # lambdas
                    ]
                )
            else:  # relaxation
                ode_fn = make_ml_ikh_maxwell_ode_rhs_weighted_sum(n)
                args["gamma_dot"] = 0.0
                sigma_init = (
                    sigma_0
                    if sigma_0 is not None
                    else (args["sigma_y0"] + args["k3"])
                )
                lambda_init_relax = 0.5
                y0 = jnp.concatenate(
                    [
                        jnp.array([sigma_init, 0.0]),
                        jnp.full(n, lambda_init_relax),
                    ]
                )

        # Diffrax setup
        term = diffrax.ODETerm(lambda ti, yi, args_i: ode_fn(ti, yi, args_i))
        solver = diffrax.Tsit5()
        stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-7)

        t0 = t[0]
        t1 = t[-1]
        # Initial step: at most 1/1000 of the span, finer for long arrays
        dt0 = (t1 - t0) / max(len(t), 1000)

        saveat = diffrax.SaveAt(ts=t)

        # throw=False: failures are reported through sol.result and turned
        # into NaNs below instead of raising.
        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=args,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            max_steps=1_000_000,
            throw=False,
        )

        # Extract primary variable
        if mode == "creep":
            # Return strain (first component)
            result = sol.ys[:, 0]
        elif self._yield_mode == "per_mode":
            # Sum mode stresses (first n components)
            result = jnp.sum(sol.ys[:, :n], axis=1)
        else:
            # Single global stress (first component)
            result = sol.ys[:, 0]

        # Handle solver failures
        result = jnp.where(
            sol.result == diffrax.RESULTS.successful,
            result,
            jnp.nan * jnp.ones_like(result),
        )

        # Add viscous contribution for startup
        if mode == "startup":
            eta_inf_val = params.get("eta_inf", 0.0)
            result = result + jnp.where(
                jnp.greater(eta_inf_val, 0.0),
                eta_inf_val * args["gamma_dot"],
                jnp.zeros_like(result),
            )

        return result

    # -------------------------------------------------------------------------
    # Predict / fit dispatch
    # -------------------------------------------------------------------------

    def _predict(self, X: ArrayLike, **kwargs) -> ArrayLike:
        """Predict response with protocol-aware dispatch.

        The protocol is taken from ``kwargs['test_mode']``, then from the
        mode cached by the last fit, and finally defaults to 'startup'.

        Args:
            X: Input data (shear rates for flow_curve, time for others)
            **kwargs: Options including test_mode, gamma_dot, sigma_applied,
                etc.

        Returns:
            Predicted stress or strain depending on protocol
        """
        test_mode = kwargs.get("test_mode")
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", None)
        if test_mode is None:
            test_mode = "startup"

        # Get parameters as dict
        param_dict = self._as_param_dict(self.parameters.get_values())

        if test_mode == "flow_curve":
            return self._predict_flow_curve_from_params(jnp.asarray(X), param_dict)

        if test_mode in ["creep", "relaxation"]:
            # Protocol inputs fall back to the values cached by
            # _fit_ode_formulation when not passed explicitly.
            return self._simulate_transient(
                jnp.asarray(X),
                param_dict,
                test_mode,
                gamma_dot=kwargs.get(
                    "gamma_dot", getattr(self, "_fit_gamma_dot", float("nan"))
                ),
                sigma_applied=kwargs.get(
                    "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
                ),
                sigma_0=kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 100.0)),
            )

        # startup, laos, oscillation
        times, strains = self._extract_time_strain(X, **kwargs)
        return self._predict_from_params(times, strains, param_dict)

    def _fit(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
        """Fit model parameters using protocol-aware optimization.

        Args:
            X: Input data (shear rates, time array, or time/strain)
            y: Target data (stress or strain depending on protocol)
            **kwargs: Options including:
                - test_mode: Protocol ('flow_curve', 'startup', 'relaxation',
                  'creep', 'oscillation', 'laos')
                - gamma_dot: Shear rate (for startup)
                - sigma_applied: Applied stress (for creep)
                - sigma_0: Initial stress (for relaxation)

        Returns:
            ``self``, with the chosen protocol cached in ``_test_mode``.
        """
        test_mode = kwargs.get("test_mode", "startup")
        self._test_mode = test_mode

        if test_mode == "flow_curve":
            return self._fit_flow_curve(X, y, **kwargs)
        if test_mode in ["creep", "relaxation"]:
            return self._fit_ode_formulation(X, y, **kwargs)
        if test_mode in ["oscillation", "saos"]:
            return self._fit_oscillation(X, y, **kwargs)
        # startup, laos, and any unrecognised mode: return mapping for
        # strain-driven protocols
        return self._fit_return_mapping(X, y, **kwargs)

    def _fit_flow_curve(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
        """Fit to steady-state flow curve data."""
        from rheojax.utils.optimization import nlsq_optimize

        gamma_dot = jnp.asarray(X)
        sigma_target = jnp.asarray(y)

        def objective(param_values):
            p_dict = self._as_param_dict(param_values)
            sigma_pred = self._predict_flow_curve_from_params(gamma_dot, p_dict)
            return sigma_pred - sigma_target

        filtered = {k: v for k, v in kwargs.items() if k not in _MLIKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def _smart_init_creep(self, t, sigma_applied):
        """Seed parameters from the creep data before optimization.

        Places the per-mode yield stresses just below the applied stress (so
        the objective is not flat) and picks a large ``mu_p`` so the creep
        ODE starts out non-stiff. In weighted-sum mode only ``mu_p`` is
        touched.

        Args:
            t: Time array of the creep experiment.
            sigma_applied: Applied stress.
        """
        sigma_a = float(sigma_applied)
        # Only the end points are needed; avoid materialising the whole array.
        t_span = max(float(t[-1]) - float(t[0]), 1.0) if t.size > 1 else 1.0

        # mu_p = sigma_applied × t_span → slow plastic rate, non-stiff ODE.
        # Clamp to the declared upper bound: long or high-stress tests would
        # otherwise produce an initial guess outside the parameter's range.
        mu_p_init = min(max(sigma_a * t_span, 0.1), self._MU_P_BOUNDS[1])
        sy_total = sigma_a * 0.9  # sigma_y = 90% of applied → small overstress

        if self._yield_mode == "per_mode":
            n = self._n_modes
            for i in range(1, n + 1):
                self.parameters.set_value(f"G_{i}", sigma_a * 5.0 / n)
                self.parameters.set_value(f"sigma_y0_{i}", sy_total * 0.4 / n)
                self.parameters.set_value(f"delta_sigma_y_{i}", sy_total * 0.6 / n)
                self.parameters.set_value(f"C_{i}", sy_total * 0.2 / n)

        if "mu_p" in self.parameters.keys():
            self.parameters.set_value("mu_p", mu_p_init)

    def _fit_ode_formulation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
        """Fit using ODE formulation (for creep/relaxation).

        Caches ``gamma_dot``, ``sigma_applied`` and ``sigma_0`` on the
        instance so later predictions can reuse them, optionally applies the
        creep smart-init, and always runs the optimizer with
        ``method="scipy"``.
        """
        from rheojax.utils.optimization import nlsq_optimize

        t = jnp.asarray(X)
        y_target = jnp.asarray(y)

        test_mode = kwargs.get("test_mode", "relaxation")
        gamma_dot = kwargs.get("gamma_dot", 0.0)
        sigma_applied = kwargs.get("sigma_applied", 100.0)
        sigma_0 = kwargs.get("sigma_0", 100.0)

        # Cache protocol kwargs for model_function (NUTS reads these)
        self._fit_gamma_dot = gamma_dot
        self._fit_sigma_applied = sigma_applied
        self._fit_sigma_0 = sigma_0

        # Data-informed initialization for creep: default yield stresses can
        # exceed sigma_applied, making the objective flat and the Jacobian
        # zero.
        if test_mode == "creep" and kwargs.get("smart_init", True) and sigma_applied:
            self._smart_init_creep(t, sigma_applied)

        def objective(param_values):
            p_dict = self._as_param_dict(param_values)
            y_pred = self._simulate_transient(
                t, p_dict, test_mode, gamma_dot, sigma_applied, sigma_0
            )
            return y_pred - y_target

        filtered = {k: v for k, v in kwargs.items() if k not in _MLIKH_RESERVED}
        # Force method="scipy": diffrax ODE solvers use custom_vjp which is
        # incompatible with NLSQ's forward-mode autodiff (jvp).
        filtered["method"] = "scipy"
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def _mikh_warmstart(
        self,
        times: jnp.ndarray,
        strains: jnp.ndarray,
        sigma_target: jnp.ndarray,
        **kwargs,
    ) -> None:
        """Fit single-mode MIKH and distribute its params as MLIKH start point.

        Solves the simpler single-mode problem first, then distributes the
        result across MLIKH modes with logarithmically-spaced tau_thix
        values.

        This is best-effort: if the MIKH fit raises, the failure is logged at
        debug level and the method returns without modifying parameters.
        """
        from rheojax.models.ikh.mikh import MIKH

        mikh = MIKH()

        # Apply data-informed scaling to MIKH before fitting
        stress_amp = float(jnp.max(jnp.abs(sigma_target)))
        # Steady-state estimate: mean of the last (up to) 10 samples
        stress_ss = float(jnp.mean(sigma_target[-min(10, len(sigma_target)) :]))
        gamma_dot_val = float(kwargs.get("gamma_dot", 1.0) or 1.0)

        if stress_amp > 0:
            mikh.parameters.set_value("G", stress_amp)
            mikh.parameters.set_value("sigma_y0", max(stress_ss * 0.3, 1e-3))
            mikh.parameters.set_value("delta_sigma_y", max(stress_ss * 0.5, 1e-3))
            mikh.parameters.set_value("C", max(stress_ss * 0.3, 1e-3))
            if gamma_dot_val > 0:
                mikh.parameters.set_value("eta_inf", stress_ss / gamma_dot_val)

        # Fit single-mode MIKH (return-mapping scan)
        fit_kwargs = {k: v for k, v in kwargs.items() if k not in _MLIKH_RESERVED}
        fit_kwargs.setdefault("max_iter", 500)
        try:
            mikh._fit_return_mapping(
                jnp.stack([times, strains]),
                sigma_target,
                **fit_kwargs,
            )
        except Exception as exc:
            # Deliberate best-effort: keep current parameters, but leave a
            # trace so a failing warm-start can be diagnosed.
            logger.debug(f"MIKH warm-start failed; keeping current init: {exc!r}")
            return

        def _gv(name: str, default: float) -> float:
            v = mikh.parameters.get_value(name)
            return float(v) if v is not None else default

        G_m = _gv("G", 1e3)
        C_m = _gv("C", 500.0)
        gd_m = _gv("gamma_dyn", 1.0)
        sy0_m = _gv("sigma_y0", 10.0)
        dsy_m = _gv("delta_sigma_y", 50.0)
        tau_m = max(_gv("tau_thix", 1.0), 1e-6)
        Gam_m = _gv("Gamma", 0.5)
        etainf_m = _gv("eta_inf", 0.1)
        mup_m = _gv("mu_p", 1e-3)

        # Distribute across modes: G/n, σ_y/n, τ log-spread around τ_mikh
        n = self._n_modes
        for i in range(1, n + 1):
            self.parameters.set_value(f"G_{i}", G_m / n)
            self.parameters.set_value(f"C_{i}", C_m / n)
            self.parameters.set_value(f"gamma_dyn_{i}", gd_m)
            self.parameters.set_value(f"sigma_y0_{i}", sy0_m / n)
            self.parameters.set_value(f"delta_sigma_y_{i}", dsy_m / n)
            # tau_thix spread: mode i offset by 10^(i-1 - (n-1)/2) relative
            # to tau_m, clipped to the tau_thix bounds
            tau_i = float(
                jnp.clip(tau_m * (10.0 ** (i - 1 - (n - 1) / 2.0)), 1e-6, 1e12)
            )
            self.parameters.set_value(f"tau_thix_{i}", tau_i)
            self.parameters.set_value(f"Gamma_{i}", Gam_m)

        self.parameters.set_value("eta_inf", etainf_m)
        if "mu_p" in self.parameters.keys():
            self.parameters.set_value("mu_p", mup_m)

    def _fit_return_mapping(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
        """Fit using return-mapping algorithm (for startup/LAOS).

        Initialization (skipped entirely when ``smart_init=False``):

        * With ``mikh_warmstart=True``, per_mode yield, ``n_modes > 1`` and a
          non-LAOS protocol, a single-mode MIKH model is fitted first and its
          result distributed across modes as a warm start.
        * Otherwise the moduli and base yield stresses are raised when they
          are small relative to the peak stress in the data.
        """
        from rheojax.utils.optimization import nlsq_optimize

        times, strains = self._extract_time_strain(X, **kwargs)
        sigma_target = jnp.asarray(y)
        test_mode = kwargs.get("test_mode", "startup")
        smart_init = kwargs.get("smart_init", True)

        # Stage 1 init: MIKH warm-start (per_mode startup only, opt-in).
        # Set mikh_warmstart=True to enable: model.fit(..., mikh_warmstart=True)
        use_warmstart = (
            self._yield_mode == "per_mode"
            and self._n_modes > 1
            and test_mode != "laos"
            and kwargs.get("mikh_warmstart", False)
            and smart_init
        )

        if use_warmstart:
            self._mikh_warmstart(times, strains, sigma_target, **kwargs)
        elif smart_init:
            # Fallback: amplitude-based scaling when warm-start is disabled
            # or not applicable
            stress_amp = float(jnp.max(jnp.abs(sigma_target)))
            if stress_amp > 0 and self._yield_mode == "per_mode":
                for i in range(1, self._n_modes + 1):
                    cur_G = self.parameters.get_value(f"G_{i}")
                    if cur_G is not None and cur_G < stress_amp * 0.1:
                        self.parameters.set_value(f"G_{i}", stress_amp / self._n_modes)
                    cur_sy = self.parameters.get_value(f"sigma_y0_{i}")
                    if cur_sy is not None and cur_sy < stress_amp * 0.01:
                        self.parameters.set_value(
                            f"sigma_y0_{i}", stress_amp * 0.1 / self._n_modes
                        )
            elif stress_amp > 0:
                cur_G = self.parameters.get_value("G")
                if cur_G is not None and cur_G < stress_amp * 0.1:
                    self.parameters.set_value("G", stress_amp)
                cur_sy = self.parameters.get_value("sigma_y0")
                if cur_sy is not None and cur_sy < stress_amp * 0.01:
                    self.parameters.set_value("sigma_y0", stress_amp * 0.1)

        def objective(param_values):
            p_dict = self._as_param_dict(param_values)
            sigma_pred = self._predict_from_params(times, strains, p_dict)
            return sigma_pred - sigma_target

        filtered = {k: v for k, v in kwargs.items() if k not in _MLIKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def _fit_oscillation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
        """Fit to oscillation data (SAOS/MAOS).

        Supports two modes:
        1. Frequency-domain: X=omega, y=|G*| or complex G*
           (uses Maxwell analytical solution)
        2. Time-domain: X=time, y=stress
           (uses return mapping with sinusoidal strain)

        The domain is guessed from the input length: more than 100 entries
        along the first axis is treated as time-domain, anything shorter as
        a frequency sweep.
        """
        X_arr = jnp.asarray(X)

        # Detect if this is frequency-domain or time-domain
        if len(X_arr) > 100:
            return self._fit_return_mapping(X, y, **kwargs)
        return self._fit_saos_frequency_domain(X, y, **kwargs)

    def _saos_moduli(self, omega, params):
        """Return ``(G', G'')`` from Maxwell analytical SAOS expressions.

        Per-mode yield sums parallel Maxwell elements, one per mode;
        weighted-sum yield uses the single global modulus. The model has no
        registered viscosity parameter, so ``eta`` defaults to 1e12 unless
        ``params`` carries an explicit ``eta_i`` / ``eta`` entry.

        Args:
            omega: Angular frequency array.
            params: Mapping from parameter name to value.
        """
        if self._yield_mode == "per_mode":
            # Sum contributions from all modes (parallel Maxwell elements)
            g_storage = jnp.zeros_like(omega)
            g_loss = jnp.zeros_like(omega)
            for i in range(1, self._n_modes + 1):
                G_i = params[f"G_{i}"]
                eta_i = params.get(f"eta_{i}", 1e12)  # effectively infinite
                tau_i = eta_i / G_i
                wt_i = omega * tau_i
                g_storage = g_storage + G_i * wt_i**2 / (1 + wt_i**2)
                g_loss = g_loss + G_i * wt_i / (1 + wt_i**2)
            return g_storage, g_loss

        # Weighted-sum mode: use global G
        G = params["G"]
        eta = params.get("eta", 1e12)  # effectively infinite
        tau = eta / G
        wt = omega * tau
        return G * wt**2 / (1 + wt**2), G * wt / (1 + wt**2)

    def _fit_saos_frequency_domain(
        self, X: ArrayLike, y: ArrayLike, **kwargs
    ) -> "MLIKH":
        """Fit to frequency-domain SAOS data using Maxwell analytical expressions.

        Fits G' and G'' independently when complex or (N, 2) input is
        provided. Falls back to magnitude-only fitting for real 1D input.

        .. warning::
            The IKH constitutive model has no viscosity parameter, so the
            Maxwell relaxation time τ = η/G is approximated with η = 1e12
            (effectively infinite). This gives ωτ >> 1 at all accessible
            frequencies, producing G' ≈ G (constant) and G'' ≈ 0. As a
            result, **SAOS fitting can only recover the elastic modulus G;
            it cannot reproduce frequency-dependent loss behaviour**. If your
            data shows significant G'' variation, consider a model family
            with an explicit viscosity parameter (e.g. Giesekus, fluidity,
            or Maxwell).

        Args:
            X: Angular frequency array (omega)
            y: Complex G* = G' + iG'', (N, 2) array [G', G''], or real |G*|
        """
        logger.warning(
            "IKH SAOS: model has no viscosity parameter; τ approximated as "
            "η/G with η=1e12, giving G'≈G (constant) and G''≈0. "
            "SAOS fitting can only recover elastic modulus, not loss behaviour."
        )
        from rheojax.utils.optimization import nlsq_optimize

        omega = jnp.asarray(X)

        # Handle different y formats — extract G' and G'' for component-wise
        # fitting whenever possible (magnitude-only discards phase angle δ)
        y_arr = jnp.asarray(y)
        if jnp.iscomplexobj(y_arr):
            # Complex G* = G' + iG'' provided
            target_G_prime = jnp.real(y_arr)
            target_G_double_prime = jnp.imag(y_arr)
            fit_components = True
        elif y_arr.ndim == 2 and y_arr.shape[1] == 2:
            # (N, 2) array provided - [G', G''] format
            target_G_prime = y_arr[:, 0]
            target_G_double_prime = y_arr[:, 1]
            fit_components = True
        else:
            # Real 1D array — assume magnitude |G*| (no phase info available)
            target_magnitude = y_arr
            fit_components = False

        def objective(param_values):
            """Compute residual using Maxwell analytical SAOS expressions."""
            p_dict = self._as_param_dict(param_values)
            G_prime_total, G_double_prime_total = self._saos_moduli(omega, p_dict)

            if fit_components:
                # Fit G' and G'' independently (preserves phase information)
                return jnp.concatenate(
                    [
                        G_prime_total - target_G_prime,
                        G_double_prime_total - target_G_double_prime,
                    ]
                )

            # Magnitude-only fallback; the 1e-30 keeps sqrt differentiable
            # at zero.
            G_star_magnitude = jnp.sqrt(
                G_prime_total**2 + G_double_prime_total**2 + 1e-30
            )
            return G_star_magnitude - target_magnitude

        filtered = {k: v for k, v in kwargs.items() if k not in _MLIKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def model_function(self, X, params, test_mode=None, **kwargs):
        """NumPyro model function with protocol-aware dispatch.

        Accepts protocol-specific kwargs (gamma_dot, sigma_applied, sigma_0).
        Falls back to values cached during _fit() if not provided.

        Args:
            X: Input data
            params: Parameter array or dict from NumPyro
            test_mode: Optional protocol override; defaults to the mode
                cached by the last fit, then to 'startup'.
            **kwargs: Protocol-specific arguments

        Returns:
            Predicted response. For 'oscillation' this is an (N, 2) array
            with columns [G', G''].
        """
        # Convert params to dict if array
        if isinstance(params, (np.ndarray, jnp.ndarray)):
            param_dict = self._as_param_dict(params)
        else:
            param_dict = params

        mode = test_mode
        if mode is None:
            mode = getattr(self, "_test_mode", None)
        if mode is None:
            mode = "startup"

        if mode == "flow_curve":
            return self._predict_flow_curve_from_params(jnp.asarray(X), param_dict)

        if mode in ["creep", "relaxation"]:
            # Protocol-specific args from kwargs, falling back to cached
            # values from _fit_ode_formulation()
            return self._simulate_transient(
                jnp.asarray(X),
                param_dict,
                mode,
                gamma_dot=kwargs.get(
                    "gamma_dot", getattr(self, "_fit_gamma_dot", 1.0)
                ),
                sigma_applied=kwargs.get(
                    "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
                ),
                sigma_0=kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 100.0)),
            )

        if mode == "oscillation":
            # Frequency-domain SAOS using multi-Maxwell analytical expressions
            G_prime_total, G_double_prime_total = self._saos_moduli(
                jnp.asarray(X), param_dict
            )
            return jnp.column_stack([G_prime_total, G_double_prime_total])

        # startup/laos modes need strain computed from kwargs
        times, strains = self._extract_time_strain(X, **kwargs)
        return self._predict_from_params(times, strains, param_dict)

    @property
    def n_modes(self) -> int:
        """Number of structural modes."""
        return self._n_modes

    @property
    def yield_mode(self) -> str:
        """Yield formulation mode ('per_mode' or 'weighted_sum')."""
        return self._yield_mode

    # -------------------------------------------------------------------------
    # Convenience Methods for Protocol-Specific Predictions
    # -------------------------------------------------------------------------

    def predict_flow_curve(self, gamma_dot: ArrayLike) -> ArrayLike:
        """Predict steady-state flow curve.

        Args:
            gamma_dot: Array of shear rates

        Returns:
            Array of steady-state stresses
        """
        return self._predict(gamma_dot, test_mode="flow_curve")

    def predict_startup(
        self, t: ArrayLike, gamma_dot: float = 1.0, strain: ArrayLike | None = None
    ) -> ArrayLike:
        """Predict startup shear response.

        Args:
            t: Time array
            gamma_dot: Applied shear rate (default: 1.0)
            strain: Optional strain array (if None, uses gamma_dot * t)

        Returns:
            Array of stresses
        """
        t_arr = jnp.asarray(t)
        # Coerce user-supplied strain: jnp.stack rejects plain lists/tuples.
        strain_arr = gamma_dot * t_arr if strain is None else jnp.asarray(strain)
        return self._predict(jnp.stack([t_arr, strain_arr]), test_mode="startup")

    def predict_relaxation(self, t: ArrayLike, sigma_0: float = 100.0) -> ArrayLike:
        """Predict stress relaxation after step strain.

        Args:
            t: Time array
            sigma_0: Initial stress (default: 100.0)

        Returns:
            Array of decaying stresses
        """
        return self._predict(t, test_mode="relaxation", sigma_0=sigma_0)

    def predict_creep(self, t: ArrayLike, sigma_applied: float = 50.0) -> ArrayLike:
        """Predict creep response under constant stress.

        Args:
            t: Time array
            sigma_applied: Applied constant stress (default: 50.0)

        Returns:
            Array of strains
        """
        return self._predict(t, test_mode="creep", sigma_applied=sigma_applied)

    def predict_laos(
        self, t: ArrayLike, gamma_0: float = 1.0, omega: float = 1.0
    ) -> ArrayLike:
        """Predict large amplitude oscillatory shear response.

        Args:
            t: Time array
            gamma_0: Strain amplitude (default: 1.0)
            omega: Angular frequency in rad/s (default: 1.0)

        Returns:
            Array of stresses
        """
        t_arr = jnp.asarray(t)
        strain = gamma_0 * jnp.sin(omega * t_arr)
        return self._predict(jnp.stack([t_arr, strain]), test_mode="laos")