Source code for rheojax.models.dmt.local

"""Local (0D) de Souza Mendes-Thompson (DMT) model.

Implements the homogeneous DMT model for thixotropic yield-stress fluids
with optional Maxwell viscoelastic backbone.

Supports all standard rheological protocols:
- Flow curve (steady shear)
- Start-up shear (stress overshoot)
- Stress relaxation (Maxwell only)
- Creep (delayed yielding)
- SAOS (Maxwell only)
- LAOS (nonlinear oscillatory)
"""

from __future__ import annotations

from typing import Literal

import numpy as np

from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger
from rheojax.models.dmt._base import DMTBase
from rheojax.models.dmt._kernels import (
    elastic_modulus,
    invert_stress_for_gamma_dot_exponential,
    invert_stress_for_gamma_dot_hb,
    maxwell_stress_evolution,
    saos_moduli_maxwell,
    steady_stress_exponential,
    steady_stress_herschel_bulkley,
    structure_evolution,
    viscosity_exponential,
    viscosity_herschel_bulkley_regularized,
)

# Safe JAX import
jax, jnp = safe_import_jax()

# Module logger
logger = get_logger(__name__)

_MISSING = object()


[docs] @ModelRegistry.register( "dmt_local", protocols=[ Protocol.FLOW_CURVE, Protocol.STARTUP, Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, Protocol.LAOS, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class DMTLocal(DMTBase): """Local (0D) DMT model for homogeneous thixotropic flow. This model assumes spatially homogeneous flow (no shear banding). For shear banding analysis, use DMTNonlocal. The model captures: - **Yielding**: Stress plateau at low shear rates (HB closure) - **Thixotropy**: Time-dependent viscosity via structure kinetics - **Viscoelasticity**: Optional Maxwell backbone for overshoot/SAOS Two viscosity closures: - "exponential": η(λ) = η_∞·(η_0/η_∞)^λ (smooth, original DMT) - "herschel_bulkley": Explicit yield stress τ_y(λ) + K(λ)|γ̇|^n Parameters ---------- closure : {"exponential", "herschel_bulkley"}, default "exponential" Viscosity closure type. include_elasticity : bool, default True Include Maxwell viscoelastic backbone for stress overshoot and SAOS. Examples -------- >>> from rheojax.models.dmt import DMTLocal >>> >>> # Create model with Herschel-Bulkley closure >>> model = DMTLocal(closure="herschel_bulkley", include_elasticity=True) >>> >>> # Fit to flow curve data >>> model.fit(gamma_dot, stress, test_mode="flow_curve") >>> >>> # Simulate startup shear >>> t, stress, lam = model.simulate_startup(gamma_dot=10.0, t_end=100.0) See Also -------- DMTNonlocal : Nonlocal (1D) variant with shear banding FluidityLocal : Simpler fluidity-based thixotropic model References ---------- de Souza Mendes, P.R. & Thompson, R.L. (2013). "A unified approach to model elasto-viscoplastic thixotropic yield-stress materials and apparent yield-stress fluids." Rheol. Acta 52, 673-694. """
def __init__(
    self,
    closure: Literal["exponential", "herschel_bulkley"] = "exponential",
    include_elasticity: bool = True,
):
    """Create a local DMT model with the chosen closure and backbone.

    Parameters
    ----------
    closure : {"exponential", "herschel_bulkley"}, default "exponential"
        Viscosity closure type, forwarded unchanged to the base class.
    include_elasticity : bool, default True
        Whether the Maxwell viscoelastic backbone is enabled, forwarded
        unchanged to the base class.
    """
    # The same two settings configure the base class and are echoed in the
    # initialization log record.
    options = {"closure": closure, "include_elasticity": include_elasticity}
    super().__init__(**options)
    logger.info("DMTLocal initialized", **options)
# ========================================================================= # Required Abstract Methods # ========================================================================= def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> DMTLocal: """Fit model to data. Dispatches to protocol-specific fitting method based on test_mode. Parameters ---------- X : array Independent variable (γ̇ for flow_curve, t for transients) y : array Dependent variable (σ for flow_curve/startup, γ for creep) **kwargs Additional arguments including test_mode Returns ------- self Fitted model instance """ test_mode = kwargs.get("test_mode", "flow_curve") # P2-DMT-001: Cache test_mode so model_function has a safe fallback # when called outside the NUTS closure (e.g. direct model_function call). self._test_mode = test_mode if test_mode in ("flow_curve", "rotation"): return self._fit_flow_curve(X, y, **kwargs) elif test_mode == "startup": return self._fit_transient(X, y, **kwargs) elif test_mode == "relaxation": if not self.include_elasticity: raise ValueError( "Relaxation requires include_elasticity=True (DMT-Maxwell)" ) return self._fit_relaxation(X, y, **kwargs) elif test_mode == "creep": return self._fit_creep(X, y, **kwargs) elif test_mode == "oscillation": if not self.include_elasticity: raise ValueError("SAOS requires include_elasticity=True (DMT-Maxwell)") return self._fit_oscillation(X, y, **kwargs) elif test_mode == "laos": return self._fit_laos(X, y, **kwargs) else: raise ValueError(f"Unknown test_mode: {test_mode}") def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """Predict model response. Dispatches to protocol-specific prediction method based on test_mode. 
Parameters ---------- X : array Independent variable **kwargs Additional arguments including test_mode Returns ------- array Predicted response """ test_mode = kwargs.get("test_mode", "flow_curve") if test_mode in ("flow_curve", "rotation"): return self._predict_flow_curve(X) elif test_mode == "startup": return self._predict_startup(X, **kwargs) elif test_mode == "relaxation": return self._predict_relaxation(X, **kwargs) elif test_mode == "creep": return self._predict_creep(X, **kwargs) elif test_mode == "oscillation": return self._predict_oscillation(X, **kwargs) elif test_mode == "laos": return self._predict_laos(X, **kwargs) else: raise ValueError(f"Unknown test_mode for prediction: {test_mode}") # ========================================================================= # Flow Curve (Steady Shear) # ========================================================================= def _fit_flow_curve( self, gamma_dot: np.ndarray, stress: np.ndarray, **kwargs ) -> DMTLocal: """Fit to steady-state flow curve σ(γ̇). Uses NLSQ to optimize parameters to match equilibrium stress-rate curve. 
Parameters ---------- gamma_dot : array Shear rate array [1/s] stress : array Stress array [Pa] **kwargs Fitting options Returns ------- self Fitted model """ from rheojax.core.parameters import ParameterSet from rheojax.utils.optimization import nlsq_curve_fit # Convert to numpy gamma_dot_np = np.asarray(gamma_dot, dtype=np.float64) stress_np = np.asarray(stress, dtype=np.float64) # Create a ParameterSet with only flow curve parameters if self.closure == "exponential": param_names = ["eta_0", "eta_inf", "a", "c"] else: param_names = ["tau_y0", "K0", "n_flow", "eta_inf", "a", "c", "m1", "m2"] fit_params = ParameterSet() for name in param_names: param = self.parameters[name] fit_params.add( name=name, value=param.value, bounds=param.bounds, units=param.units, description=param.description, ) # Define model function f(x, params_array) -> y_pred def model_fn(x, params_array): if self.closure == "exponential": eta_0, eta_inf, a, c = params_array[:4] return steady_stress_exponential(jnp.array(x), eta_0, eta_inf, a, c) else: tau_y0, K0, n_flow, eta_inf, a, c, m1, m2 = params_array[:8] return steady_stress_herschel_bulkley( jnp.array(x), tau_y0, K0, n_flow, eta_inf, a, c, m1, m2 ) # Filter protocol kwargs before forwarding to NLSQ _dmt_reserved = { "test_mode", "gamma_dot", "lam_init", "sigma_init", "sigma_0", "lam_0", "gamma_0", "omega_laos", "n_cycles", "points_per_cycle", "deformation_mode", "poisson_ratio", } nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _dmt_reserved} result = nlsq_curve_fit( model_fn, gamma_dot_np, stress_np, fit_params, **nlsq_kwargs ) # Update main parameters with fitted values for name in param_names: self.parameters[name].value = fit_params[name].value self._fitted = True self._fit_result = result logger.info( "DMTLocal flow curve fit complete", r_squared=result.r_squared, rmse=result.rmse, ) return self def _predict_flow_curve(self, gamma_dot: np.ndarray) -> np.ndarray: """Predict steady-state stress from flow curve. 
Parameters ---------- gamma_dot : array Shear rate [1/s] Returns ------- array Predicted stress [Pa] """ gamma_dot_jax = jnp.array(gamma_dot) params = self.get_parameter_dict() if self.closure == "exponential": stress = steady_stress_exponential( gamma_dot_jax, params["eta_0"], params["eta_inf"], params["a"], params["c"], ) else: stress = steady_stress_herschel_bulkley( gamma_dot_jax, params["tau_y0"], params["K0"], params["n_flow"], params["eta_inf"], params["a"], params["c"], params["m1"], params["m2"], ) return np.array(stress) # ========================================================================= # Startup Shear # =========================================================================
def simulate_startup(
    self,
    gamma_dot: float,
    t_end: float,
    dt: float = 0.01,
    lam_init: float = 1.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate startup of steady shear from rest.

    Delegates to the Maxwell or the purely viscous integrator depending on
    ``include_elasticity`` and converts the results to NumPy arrays.

    Parameters
    ----------
    gamma_dot : float
        Applied constant shear rate [1/s]
    t_end : float
        Simulation end time [s]
    dt : float
        Time step [s]; must be positive
    lam_init : float
        Initial structure parameter (default: 1.0, fully structured)

    Returns
    -------
    t : array
        Time array [s]
    stress : array
        Stress response [Pa]
    lam : array
        Structure parameter evolution

    Raises
    ------
    ValueError
        If ``dt`` is not positive or ``t_end`` is negative.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    if t_end < 0:
        raise ValueError(f"t_end must be non-negative, got {t_end}")

    # Plain int() truncates quotients such as 0.7 / 0.1 == 6.999999999999999
    # and silently drops the last step; the small tolerance keeps floor
    # semantics while absorbing that rounding error.
    n_steps = int(t_end / dt + 1e-9)
    t = jnp.linspace(0, t_end, n_steps)

    params = self.get_parameter_dict()

    if self.include_elasticity:
        t_out, stress, lam = self._simulate_startup_maxwell(
            t, dt, gamma_dot, lam_init, params
        )
    else:
        t_out, stress, lam = self._simulate_startup_viscous(
            t, dt, gamma_dot, lam_init, params
        )

    # Convert to numpy for public API
    return np.array(t_out), np.array(stress), np.array(lam)
def _simulate_startup_viscous(
    self,
    t: jnp.ndarray,
    dt: float,
    gamma_dot: float,
    lam_init: float,
    params: dict,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate startup for DMT-Viscous (no elasticity).

    Advances the structure parameter with an explicit Euler step of size
    ``dt`` (clipped to [0, 1]) and evaluates the stress as viscosity times
    the imposed shear rate at the updated structure.

    Parameters
    ----------
    t : array
        Time array; only its length is used (one scan step per entry) and
        it is returned unchanged.
    dt : float
        Integration time step [s]
    gamma_dot : float
        Applied constant shear rate [1/s]
    lam_init : float
        Initial structure parameter
    params : dict
        Model parameters keyed by name

    Returns
    -------
    t, stress, lam
        The input time array and the per-step stress and structure.
        These are JAX arrays despite the annotation; callers convert.
    """

    def step(lam, _):
        # Structure evolution
        dlam = structure_evolution(
            lam, gamma_dot, params["t_eq"], params["a"], params["c"]
        )
        lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0)

        # Viscosity
        if self.closure == "exponential":
            eta = viscosity_exponential(lam_new, params["eta_0"], params["eta_inf"])
        else:
            eta = viscosity_herschel_bulkley_regularized(
                lam_new,
                gamma_dot,
                params["tau_y0"],
                params["K0"],
                params["n_flow"],
                params["eta_inf"],
                params["m1"],
                params["m2"],
            )

        stress = eta * gamma_dot
        return lam_new, (stress, lam_new)

    # Wrap the step in jax.checkpoint before scanning over len(t) steps.
    step = jax.checkpoint(step)
    _, (stress, lam) = jax.lax.scan(step, lam_init, None, length=len(t))

    # Return JAX arrays (conversion to numpy happens in public methods)
    return t, stress, lam


def _simulate_startup_maxwell(
    self,
    t: jnp.ndarray,
    dt: float,
    gamma_dot: float,
    lam_init: float,
    params: dict,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate startup for DMT-Maxwell (with elasticity).

    Uses semi-implicit (backward Euler) integration for unconditional
    stability when dt > theta_1 (relaxation time). The structure parameter
    itself is advanced with an explicit Euler step clipped to [0, 1].

    Parameters
    ----------
    t : array
        Time array; only its length is used (one scan step per entry) and
        it is returned unchanged.
    dt : float
        Integration time step [s]
    gamma_dot : float
        Applied constant shear rate [1/s]
    lam_init : float
        Initial structure parameter
    params : dict
        Model parameters keyed by name

    Returns
    -------
    t, stress, lam
        The input time array and the per-step stress and structure.
        These are JAX arrays despite the annotation; callers convert.
    """

    def step(state, _):
        sigma, lam = state

        # Structure evolution
        dlam = structure_evolution(
            lam, gamma_dot, params["t_eq"], params["a"], params["c"]
        )
        lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0)

        # Elastic modulus
        G = elastic_modulus(lam_new, params["G0"], params["m_G"])

        # Viscosity
        if self.closure == "exponential":
            eta = viscosity_exponential(lam_new, params["eta_0"], params["eta_inf"])
        else:
            eta = viscosity_herschel_bulkley_regularized(
                lam_new,
                gamma_dot,
                params["tau_y0"],
                params["K0"],
                params["n_flow"],
                params["eta_inf"],
                params["m1"],
                params["m2"],
            )

        # Relaxation time (modulus floored to avoid division by zero)
        theta_1 = eta / jnp.maximum(G, 1e-10)

        # Semi-implicit stress evolution (unconditionally stable)
        # dsigma/dt = G*gamma_dot - sigma/theta_1
        # Using backward Euler: sigma_new = (sigma + dt*G*gamma_dot) / (1 + dt/theta_1)
        sigma_new = (sigma + dt * G * gamma_dot) / (1.0 + dt / theta_1)

        return (sigma_new, lam_new), (sigma_new, lam_new)

    step = jax.checkpoint(step)
    init_state = (0.0, lam_init)  # Zero initial stress
    _, (stress, lam) = jax.lax.scan(step, init_state, None, length=len(t))

    # Return JAX arrays (conversion to numpy happens in public methods)
    return t, stress, lam


def _fit_transient(self, t: np.ndarray, stress: np.ndarray, **kwargs) -> DMTLocal:
    """Fit to transient startup data.

    Parameters
    ----------
    t : array
        Time array [s]; the step is taken from the first two samples
        (assumes uniform spacing — TODO confirm with callers).
    stress : array
        Measured stress [Pa]
    **kwargs
        gamma_dot : float
            Applied shear rate (default: 1.0)
        lam_init : float
            Initial structure parameter (default: 1.0)
        Remaining non-protocol keys are forwarded to ``fit_with_nlsq``.

    Returns
    -------
    self
        Model with parameters set from the optimiser result.
    """
    # Extract gamma_dot from kwargs
    gamma_dot = kwargs.get("gamma_dot", 1.0)
    lam_init = kwargs.get("lam_init", 1.0)

    # Cache for model_function (Bayesian inference bridge)
    self._gamma_dot_applied = gamma_dot
    self._startup_lam_init = lam_init

    from rheojax.utils.optimization import fit_with_nlsq

    t_jax = jnp.array(t)
    stress_jax = jnp.array(stress)
    dt = float(t[1] - t[0])

    # Estimate stress scale for normalization (never below 1.0)
    stress_scale = jnp.maximum(jnp.std(stress_jax), 1.0)

    def residual_fn(params_array):
        # Reconstruct parameter dict
        param_dict = self._params_array_to_dict(params_array)

        # NOTE(review): _simulate_with_params is defined outside this view;
        # presumably it runs the startup integrator with param_dict.
        _, stress_pred, _ = self._simulate_with_params(
            t_jax, dt, gamma_dot, lam_init, param_dict
        )
        # Clip extreme predictions to avoid NaN gradients
        stress_pred = jnp.clip(stress_pred, -1e12, 1e12)
        # Normalize residuals for numerical stability
        return (stress_pred - stress_jax) / stress_scale

    params_array, bounds = self._get_params_for_optimization()

    # Filter protocol kwargs before forwarding to NLSQ
    _dmt_reserved = {
        "test_mode",
        "gamma_dot",
        "lam_init",
        "sigma_init",
        "sigma_0",
        "lam_0",
        "gamma_0",
        "omega_laos",
        "n_cycles",
        "points_per_cycle",
        "deformation_mode",
        "poisson_ratio",
    }
    nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _dmt_reserved}
    result = fit_with_nlsq(residual_fn, params_array, bounds=bounds, **nlsq_kwargs)

    self._set_params_from_array(result.x)
    self._fitted = True

    return self


def _predict_startup(self, t: np.ndarray, **kwargs) -> np.ndarray:
    """Predict startup stress.

    Parameters
    ----------
    t : array
        Time array [s]; one stress value is produced per entry.
    **kwargs
        gamma_dot : float
            Applied shear rate (default: 1.0)
        lam_init : float
            Initial structure parameter (default: 1.0)

    Returns
    -------
    array
        Predicted stress as a NumPy array with the same length as ``t``.
    """
    gamma_dot = kwargs.get("gamma_dot", 1.0)
    lam_init = kwargs.get("lam_init", 1.0)

    t_jax = jnp.array(t)
    # Step from the first two samples; fall back to 0.01 for a single sample.
    dt = float(t[1] - t[0]) if len(t) > 1 else 0.01

    params = self.get_parameter_dict()

    # Use internal method directly to ensure consistent output length
    if self.include_elasticity:
        _, stress, _ = self._simulate_startup_maxwell(
            t_jax, dt, gamma_dot, lam_init, params
        )
    else:
        _, stress, _ = self._simulate_startup_viscous(
            t_jax, dt, gamma_dot, lam_init, params
        )
    return np.array(stress)


# =========================================================================
# Stress Relaxation (Maxwell only)
# =========================================================================
def simulate_relaxation(
    self,
    t_end: float,
    dt: float = 0.01,
    sigma_init: float = 100.0,
    lam_init: float = 0.5,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate stress relaxation after cessation of shear.

    Requires include_elasticity=True. With γ̇ = 0 the structure only
    recovers, and the stress decays with the structure-dependent Maxwell
    relaxation time. Both are advanced with explicit Euler steps.

    Parameters
    ----------
    t_end : float
        Simulation end time [s]
    dt : float
        Time step [s]; must be positive
    sigma_init : float
        Initial stress at cessation [Pa]
    lam_init : float
        Initial structure at cessation

    Returns
    -------
    t : array
        Time array [s]
    stress : array
        Relaxing stress [Pa]
    lam : array
        Recovering structure

    Raises
    ------
    ValueError
        If ``include_elasticity`` is False, ``dt`` is not positive, or
        ``t_end`` is negative.
    """
    if not self.include_elasticity:
        raise ValueError("Stress relaxation requires include_elasticity=True")
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    if t_end < 0:
        raise ValueError(f"t_end must be non-negative, got {t_end}")

    # Plain int() truncates quotients such as 0.7 / 0.1 == 6.999999999999999
    # and silently drops the last step; the small tolerance keeps floor
    # semantics while absorbing that rounding error.
    n_steps = int(t_end / dt + 1e-9)
    t = jnp.linspace(0, t_end, n_steps)

    params = self.get_parameter_dict()

    def step(state, _):
        sigma, lam = state

        # Structure recovery (no breakdown, γ̇ = 0)
        dlam = (1.0 - lam) / params["t_eq"]
        lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0)

        # Elastic modulus
        G = elastic_modulus(lam_new, params["G0"], params["m_G"])

        # Viscosity at zero shear rate
        if self.closure == "exponential":
            eta = viscosity_exponential(lam_new, params["eta_0"], params["eta_inf"])
        else:
            eta = params["eta_inf"]  # HB at zero shear rate

        # Relaxation time (modulus floored to avoid division by zero)
        theta_1 = eta / jnp.maximum(G, 1e-10)

        # Stress relaxation
        dsigma = -sigma / jnp.maximum(theta_1, 1e-12)
        sigma_new = sigma + dt * dsigma

        return (sigma_new, lam_new), (sigma_new, lam_new)

    step = jax.checkpoint(step)
    init_state = (sigma_init, lam_init)
    _, (stress, lam) = jax.lax.scan(step, init_state, None, length=n_steps)

    return np.array(t), np.array(stress), np.array(lam)
def _fit_relaxation(self, t: np.ndarray, stress: np.ndarray, **kwargs) -> DMTLocal: """Fit to stress relaxation data σ(t) after cessation of shear. Requires include_elasticity=True. Parameters ---------- t : array Time array [s] stress : array Relaxing stress [Pa] **kwargs sigma_init : float Initial stress at cessation [Pa] (default: stress[0]) lam_init : float Initial structure at cessation (default: 0.5) """ from rheojax.utils.optimization import fit_with_nlsq sigma_init = kwargs.get("sigma_init", float(stress[0])) lam_init = kwargs.get("lam_init", 0.5) # Cache for model_function (Bayesian inference bridge) self._relax_sigma_init = sigma_init self._relax_lam_init = lam_init stress_jax = jnp.array(stress) dt = float(t[1] - t[0]) n_steps = len(t) stress_scale = jnp.maximum(jnp.std(stress_jax), 1.0) def residual_fn(params_array): param_dict = self._params_array_to_dict(params_array) def step(state, _): sigma, lam = state dlam = (1.0 - lam) / param_dict["t_eq"] lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) G = elastic_modulus(lam_new, param_dict["G0"], param_dict["m_G"]) if self.closure == "exponential": eta = viscosity_exponential( lam_new, param_dict["eta_0"], param_dict["eta_inf"] ) else: eta = param_dict["eta_inf"] theta_1 = eta / jnp.maximum(G, 1e-10) dsigma = -sigma / jnp.maximum(theta_1, 1e-12) sigma_new = sigma + dt * dsigma return (sigma_new, lam_new), sigma_new step = jax.checkpoint(step) init_state = (jnp.float64(sigma_init), jnp.float64(lam_init)) _, stress_pred = jax.lax.scan(step, init_state, None, length=n_steps) stress_pred = jnp.clip(stress_pred, -1e12, 1e12) return (stress_pred - stress_jax) / stress_scale params_array, bounds = self._get_params_for_optimization() _dmt_reserved = { "test_mode", "gamma_dot", "lam_init", "sigma_init", "sigma_0", "lam_0", "gamma_0", "omega_laos", "n_cycles", "points_per_cycle", "deformation_mode", "poisson_ratio", } nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _dmt_reserved} result = 
fit_with_nlsq(residual_fn, params_array, bounds=bounds, **nlsq_kwargs) self._set_params_from_array(result.x) self._fitted = True return self def _predict_relaxation(self, t: np.ndarray, **kwargs) -> np.ndarray: """Predict relaxation stress.""" sigma_init = kwargs.get("sigma_init", 100.0) lam_init = kwargs.get("lam_init", 0.5) _, stress, _ = self.simulate_relaxation( float(t[-1]), float(t[1] - t[0]), sigma_init, lam_init ) return stress # ========================================================================= # Creep # =========================================================================
def simulate_creep(
    self,
    sigma_0: float,
    t_end: float,
    dt: float = 0.01,
    lam_init: float = 1.0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Simulate creep under constant applied stress.

    For the Maxwell variant (include_elasticity=True), the total strain
    includes both elastic and viscous contributions:

        γ(t) = γ_e(t) + γ_v(t)

    where:
    - γ_e(t) = σ₀/G(λ(t)) is the elastic strain (changes with structure)
    - γ_v(t) = ∫₀ᵗ σ₀/η(λ(s)) ds is the viscous strain

    This correctly captures:
    - Initial elastic jump: γ(0⁺) = σ₀/G(λ_init)
    - Elastic strain recovery/growth as structure evolves
    - Viscous flow accumulation

    Parameters
    ----------
    sigma_0 : float
        Applied constant stress [Pa]
    t_end : float
        Simulation end time [s]
    dt : float
        Time step [s]; must be positive
    lam_init : float
        Initial structure parameter

    Returns
    -------
    t : array
        Time array [s]
    gamma : array
        Total accumulated strain (elastic + viscous for Maxwell variant)
    gamma_dot : array
        Total shear rate evolution [1/s]
    lam : array
        Structure parameter evolution

    Raises
    ------
    ValueError
        If ``dt`` is not positive or ``t_end`` is negative.
    """
    if dt <= 0:
        raise ValueError(f"dt must be positive, got {dt}")
    if t_end < 0:
        raise ValueError(f"t_end must be non-negative, got {t_end}")

    # Plain int() truncates quotients such as 0.7 / 0.1 == 6.999999999999999
    # and silently drops the last step; the small tolerance keeps floor
    # semantics while absorbing that rounding error.
    n_steps = int(t_end / dt + 1e-9)
    t = jnp.linspace(0, t_end, n_steps)

    params = self.get_parameter_dict()

    if self.include_elasticity:
        # Maxwell variant: track elastic and viscous strain separately
        return self._simulate_creep_maxwell(
            t, dt, n_steps, sigma_0, lam_init, params
        )
    else:
        # Viscous variant: purely viscous flow
        return self._simulate_creep_viscous(
            t, dt, n_steps, sigma_0, lam_init, params
        )
def _simulate_creep_viscous( self, t: jnp.ndarray, dt: float, n_steps: int, sigma_0: float, lam_init: float, params: dict, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Simulate creep for DMT-Viscous (no elasticity). Pure viscous flow: γ̇ = σ₀/η(λ) """ def step(state, _): lam, gamma = state # Viscous flow rate: γ̇ = σ₀/η(λ) if self.closure == "exponential": gamma_dot = invert_stress_for_gamma_dot_exponential( sigma_0, lam, params["eta_0"], params["eta_inf"] ) else: gamma_dot = invert_stress_for_gamma_dot_hb( sigma_0, lam, params["tau_y0"], params["K0"], params["n_flow"], params["eta_inf"], params["m1"], params["m2"], ) # Structure evolution (driven by viscous flow rate) dlam = structure_evolution( lam, gamma_dot, params["t_eq"], params["a"], params["c"] ) lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) # Strain accumulation gamma_new = gamma + dt * gamma_dot return (lam_new, gamma_new), (gamma_new, gamma_dot, lam_new) step = jax.checkpoint(step) init_state = (lam_init, 0.0) _, (gamma, gamma_dot, lam) = jax.lax.scan( step, init_state, None, length=n_steps ) return np.array(t), np.array(gamma), np.array(gamma_dot), np.array(lam) def _simulate_creep_maxwell( self, t: jnp.ndarray, dt: float, n_steps: int, sigma_0: float, lam_init: float, params: dict, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Simulate creep for DMT-Maxwell (with elasticity). Total strain: γ = γ_e + γ_v - Elastic: γ_e = σ₀/G(λ) - Viscous: dγ_v/dt = σ₀/η(λ) The total shear rate includes both viscous flow and elastic strain change: γ̇ = dγ_v/dt + dγ_e/dt = σ₀/η + d(σ₀/G)/dt Structure evolution uses the viscous flow rate as the driving deformation. 
""" def step(state, _): lam, gamma_v, lam_prev = state # Compute elastic modulus and strain G = elastic_modulus(lam, params["G0"], params["m_G"]) gamma_e = sigma_0 / jnp.maximum(G, 1e-10) # Previous elastic strain (for rate calculation) G_prev = elastic_modulus(lam_prev, params["G0"], params["m_G"]) gamma_e_prev = sigma_0 / jnp.maximum(G_prev, 1e-10) # Viscous flow rate: γ̇_v = σ₀/η(λ) if self.closure == "exponential": gamma_dot_v = invert_stress_for_gamma_dot_exponential( sigma_0, lam, params["eta_0"], params["eta_inf"] ) else: gamma_dot_v = invert_stress_for_gamma_dot_hb( sigma_0, lam, params["tau_y0"], params["K0"], params["n_flow"], params["eta_inf"], params["m1"], params["m2"], ) # Elastic strain rate (from structure change) gamma_dot_e = (gamma_e - gamma_e_prev) / dt # Total shear rate gamma_dot_total = gamma_dot_v + gamma_dot_e # Structure evolution (driven by viscous flow rate) # Use viscous rate since that represents actual material deformation dlam = structure_evolution( lam, gamma_dot_v, params["t_eq"], params["a"], params["c"] ) lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) # Viscous strain accumulation gamma_v_new = gamma_v + dt * gamma_dot_v # Total strain = elastic + viscous gamma_total = gamma_e + gamma_v_new return (lam_new, gamma_v_new, lam), (gamma_total, gamma_dot_total, lam_new) # State: (λ, γ_v, λ_prev) step = jax.checkpoint(step) init_state = (lam_init, 0.0, lam_init) _, (gamma, gamma_dot, lam) = jax.lax.scan( step, init_state, None, length=n_steps ) return np.array(t), np.array(gamma), np.array(gamma_dot), np.array(lam) def _fit_creep(self, t: np.ndarray, gamma: np.ndarray, **kwargs) -> DMTLocal: """Fit to creep data γ(t) under constant applied stress. 
Parameters ---------- t : array Time array [s] gamma : array Total strain array **kwargs sigma_0 : float Applied constant stress [Pa] (default: 10.0) lam_init : float Initial structure parameter (default: 1.0) """ from rheojax.utils.optimization import fit_with_nlsq sigma_0 = kwargs.get("sigma_0", 10.0) lam_init = kwargs.get("lam_init", 1.0) # Cache for model_function (Bayesian inference bridge) self._sigma_applied = sigma_0 self._creep_lam_init = lam_init gamma_jax = jnp.array(gamma) dt = float(t[1] - t[0]) n_steps = len(t) gamma_scale = jnp.maximum(jnp.std(gamma_jax), 1e-6) def residual_fn(params_array): param_dict = self._params_array_to_dict(params_array) if self.include_elasticity: def step(state, _): lam, gamma_v, lam_prev = state G = elastic_modulus(lam, param_dict["G0"], param_dict["m_G"]) gamma_e = sigma_0 / jnp.maximum(G, 1e-10) if self.closure == "exponential": gamma_dot_v = invert_stress_for_gamma_dot_exponential( sigma_0, lam, param_dict["eta_0"], param_dict["eta_inf"] ) else: gamma_dot_v = invert_stress_for_gamma_dot_hb( sigma_0, lam, param_dict["tau_y0"], param_dict["K0"], param_dict["n_flow"], param_dict["eta_inf"], param_dict["m1"], param_dict["m2"], ) dlam = structure_evolution( lam, gamma_dot_v, param_dict["t_eq"], param_dict["a"], param_dict["c"], ) lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) gamma_v_new = gamma_v + dt * gamma_dot_v gamma_total = gamma_e + gamma_v_new return (lam_new, gamma_v_new, lam), gamma_total step = jax.checkpoint(step) init_state = ( jnp.float64(lam_init), jnp.float64(0.0), jnp.float64(lam_init), ) _, gamma_pred = jax.lax.scan(step, init_state, None, length=n_steps) else: def step(state, _): lam, gamma_acc = state if self.closure == "exponential": gamma_dot = invert_stress_for_gamma_dot_exponential( sigma_0, lam, param_dict["eta_0"], param_dict["eta_inf"] ) else: gamma_dot = invert_stress_for_gamma_dot_hb( sigma_0, lam, param_dict["tau_y0"], param_dict["K0"], param_dict["n_flow"], param_dict["eta_inf"], 
param_dict["m1"], param_dict["m2"], ) dlam = structure_evolution( lam, gamma_dot, param_dict["t_eq"], param_dict["a"], param_dict["c"], ) lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) gamma_new = gamma_acc + dt * gamma_dot return (lam_new, gamma_new), gamma_new step = jax.checkpoint(step) init_state = (jnp.float64(lam_init), jnp.float64(0.0)) _, gamma_pred = jax.lax.scan(step, init_state, None, length=n_steps) gamma_pred = jnp.clip(gamma_pred, -1e12, 1e12) return (gamma_pred - gamma_jax) / gamma_scale params_array, bounds = self._get_params_for_optimization() _dmt_reserved = { "test_mode", "gamma_dot", "lam_init", "sigma_init", "sigma_0", "lam_0", "gamma_0", "omega_laos", "n_cycles", "points_per_cycle", "deformation_mode", "poisson_ratio", } nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _dmt_reserved} result = fit_with_nlsq(residual_fn, params_array, bounds=bounds, **nlsq_kwargs) self._set_params_from_array(result.x) self._fitted = True return self def _predict_creep(self, t: np.ndarray, **kwargs) -> np.ndarray: """Predict creep strain.""" sigma_0 = kwargs.get("sigma_0", 10.0) lam_init = kwargs.get("lam_init", 1.0) _, gamma, _, _ = self.simulate_creep( sigma_0, float(t[-1]), float(t[1] - t[0]), lam_init ) return gamma # ========================================================================= # SAOS (Maxwell only) # =========================================================================
def predict_saos(
    self, omega: np.ndarray, lam_0: float = 1.0
) -> tuple[np.ndarray, np.ndarray]:
    """Compute the SAOS moduli G'(ω) and G''(ω) at a frozen structure level.

    Requires include_elasticity=True. The amplitude is assumed small enough
    that the structure stays at λ₀, so modulus and viscosity are evaluated
    once at ``lam_0`` and fed to the Maxwell SAOS kernel.

    Parameters
    ----------
    omega : array
        Angular frequency [rad/s]
    lam_0 : float
        Reference structure level (default: 1.0, fully structured)

    Returns
    -------
    G_prime : array
        Storage modulus G'(ω) [Pa]
    G_double_prime : array
        Loss modulus G''(ω) [Pa]

    Raises
    ------
    ValueError
        If ``include_elasticity`` is False.
    """
    if not self.include_elasticity:
        raise ValueError("SAOS requires include_elasticity=True")

    frequencies = jnp.array(omega)
    p = self.get_parameter_dict()

    # Modulus and viscosity at the reference structure
    modulus = elastic_modulus(lam_0, p["G0"], p["m_G"])

    if self.closure != "exponential":
        # HB closure: evaluate the regularised viscosity at a low shear rate
        viscosity = viscosity_herschel_bulkley_regularized(
            lam_0,
            1e-6,
            p["tau_y0"],
            p["K0"],
            p["n_flow"],
            p["eta_inf"],
            p["m1"],
            p["m2"],
        )
    else:
        viscosity = viscosity_exponential(lam_0, p["eta_0"], p["eta_inf"])

    # Maxwell relaxation time, with the modulus floored against division by zero
    relaxation_time = viscosity / jnp.maximum(modulus, 1e-10)

    storage, loss = saos_moduli_maxwell(
        frequencies, float(modulus), float(relaxation_time), p["eta_inf"]
    )
    return np.array(storage), np.array(loss)
def _fit_oscillation( self, omega: np.ndarray, G_star: np.ndarray, **kwargs ) -> DMTLocal: """Fit to SAOS data G*(ω) = G'(ω) + jG''(ω). Requires include_elasticity=True. Assumes small amplitude so structure remains at the reference level λ₀. Parameters ---------- omega : array Angular frequency [rad/s] G_star : array Complex modulus data (real+imag or (N,2) array of [G', G'']) **kwargs lam_0 : float Reference structure level (default: 1.0) """ from rheojax.utils.optimization import fit_with_nlsq lam_0 = kwargs.get("lam_0", 1.0) # Cache for model_function (Bayesian inference bridge) self._saos_lam_0 = lam_0 omega_jax = jnp.array(omega) # Handle complex or (N,2) input G_star_np = np.asarray(G_star) if np.iscomplexobj(G_star_np): G_prime_data = jnp.array(np.real(G_star_np)) G_double_prime_data = jnp.array(np.imag(G_star_np)) elif G_star_np.ndim == 2 and G_star_np.shape[1] == 2: G_prime_data = jnp.array(G_star_np[:, 0]) G_double_prime_data = jnp.array(G_star_np[:, 1]) else: raise ValueError( "G_star must be complex (G'+jG'') or shape (N,2) array [G', G'']" ) # Scale for normalization modulus_scale = jnp.maximum( jnp.std(jnp.concatenate([G_prime_data, G_double_prime_data])), 1.0 ) def residual_fn(params_array): param_dict = self._params_array_to_dict(params_array) G = elastic_modulus(lam_0, param_dict["G0"], param_dict["m_G"]) if self.closure == "exponential": eta = viscosity_exponential( lam_0, param_dict["eta_0"], param_dict["eta_inf"] ) else: eta = viscosity_herschel_bulkley_regularized( lam_0, 1e-6, param_dict["tau_y0"], param_dict["K0"], param_dict["n_flow"], param_dict["eta_inf"], param_dict["m1"], param_dict["m2"], ) theta_1 = eta / jnp.maximum(G, 1e-10) G_prime_pred, G_double_prime_pred = saos_moduli_maxwell( omega_jax, G, theta_1, param_dict["eta_inf"] ) res_prime = (G_prime_pred - G_prime_data) / modulus_scale res_double_prime = ( G_double_prime_pred - G_double_prime_data ) / modulus_scale return jnp.concatenate([res_prime, res_double_prime]) params_array, 
bounds = self._get_params_for_optimization() _dmt_reserved = { "test_mode", "gamma_dot", "lam_init", "sigma_init", "sigma_0", "lam_0", "gamma_0", "omega_laos", "n_cycles", "points_per_cycle", "deformation_mode", "poisson_ratio", } nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _dmt_reserved} result = fit_with_nlsq(residual_fn, params_array, bounds=bounds, **nlsq_kwargs) self._set_params_from_array(result.x) self._fitted = True return self def _predict_oscillation(self, omega: np.ndarray, **kwargs) -> np.ndarray: """Predict complex modulus.""" lam_0 = kwargs.get("lam_0", 1.0) G_prime, G_double_prime = self.predict_saos(omega, lam_0) return G_prime + 1j * G_double_prime # ========================================================================= # LAOS # =========================================================================
def simulate_laos(
    self,
    gamma_0: float,
    omega: float,
    n_cycles: int = 10,
    points_per_cycle: int = 128,
    lam_init: float = 1.0,
) -> dict[str, np.ndarray]:
    """Simulate LAOS (Large Amplitude Oscillatory Shear).

    The imposed strain is γ(t) = γ₀·sin(ωt). Structure and, for the Maxwell
    variant, stress are advanced with explicit Euler steps of size
    ``dt = period / points_per_cycle``. The time grid uses that same
    spacing, so every cycle holds exactly ``points_per_cycle`` samples.

    Parameters
    ----------
    gamma_0 : float
        Strain amplitude
    omega : float
        Angular frequency [rad/s]
    n_cycles : int
        Number of cycles to simulate
    points_per_cycle : int
        Points per cycle for discretization
    lam_init : float
        Initial structure parameter

    Returns
    -------
    dict
        't': time array
        'strain': strain γ(t)
        'strain_rate': strain rate γ̇(t)
        'stress': stress σ(t)
        'lam': structure λ(t)
    """
    period = 2 * np.pi / omega
    t_total = n_cycles * period
    n_points = n_cycles * points_per_cycle
    dt = t_total / n_points

    # endpoint=False gives a grid spacing of exactly dt. Including the
    # endpoint would space the samples by t_total / (n_points - 1), so the
    # forcing would be sampled on a different grid from the one the Euler
    # steps below integrate over.
    t = jnp.linspace(0, t_total, n_points, endpoint=False)
    strain = gamma_0 * jnp.sin(omega * t)
    strain_rate = gamma_0 * omega * jnp.cos(omega * t)

    params = self.get_parameter_dict()

    if self.include_elasticity:
        # Maxwell LAOS
        def step(state, sr):
            sigma, lam = state

            # Structure evolution
            dlam = structure_evolution(
                lam, sr, params["t_eq"], params["a"], params["c"]
            )
            lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0)

            # Elastic modulus
            G = elastic_modulus(lam_new, params["G0"], params["m_G"])

            # Viscosity
            if self.closure == "exponential":
                eta = viscosity_exponential(
                    lam_new, params["eta_0"], params["eta_inf"]
                )
            else:
                eta = viscosity_herschel_bulkley_regularized(
                    lam_new,
                    sr,
                    params["tau_y0"],
                    params["K0"],
                    params["n_flow"],
                    params["eta_inf"],
                    params["m1"],
                    params["m2"],
                )

            # Relaxation time (modulus floored to avoid division by zero)
            theta_1 = eta / jnp.maximum(G, 1e-10)

            # Stress evolution
            dsigma = maxwell_stress_evolution(sigma, sr, G, theta_1)
            sigma_new = sigma + dt * dsigma

            return (sigma_new, lam_new), (sigma_new, lam_new)

        step = jax.checkpoint(step)
        init_state = (0.0, lam_init)
        _, (stress, lam) = jax.lax.scan(step, init_state, strain_rate)
    else:
        # Viscous LAOS
        def step(lam, sr):
            dlam = structure_evolution(
                lam, sr, params["t_eq"], params["a"], params["c"]
            )
            lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0)

            if self.closure == "exponential":
                eta = viscosity_exponential(
                    lam_new, params["eta_0"], params["eta_inf"]
                )
            else:
                eta = viscosity_herschel_bulkley_regularized(
                    lam_new,
                    sr,
                    params["tau_y0"],
                    params["K0"],
                    params["n_flow"],
                    params["eta_inf"],
                    params["m1"],
                    params["m2"],
                )

            stress = eta * sr
            return lam_new, (stress, lam_new)

        step = jax.checkpoint(step)
        _, (stress, lam) = jax.lax.scan(step, lam_init, strain_rate)

    return {
        "t": np.array(t),
        "strain": np.array(strain),
        "strain_rate": np.array(strain_rate),
        "stress": np.array(stress),
        "lam": np.array(lam),
    }
def extract_harmonics(
    self,
    laos_result: dict,
    n_harmonics: int = 5,
) -> dict[str, np.ndarray]:
    """Extract Fourier harmonics from LAOS stress response.

    The last oscillation cycle of the record is projected onto sin(nφ) and
    cos(nφ) for the odd harmonics n = 1, 3, 5, ..., where φ is the phase of
    the imposed strain γ = γ₀·sin(φ).

    Parameters
    ----------
    laos_result : dict
        Result from simulate_laos(); the keys 't', 'strain', 'strain_rate'
        and 'stress' are read.
    n_harmonics : int
        Number of harmonics to extract

    Returns
    -------
    dict
        'sigma_prime': in-phase coefficients (odd harmonics)
        'sigma_double_prime': out-of-phase coefficients
        'I_n_1': normalized harmonic intensities
    """
    from scipy.integrate import trapezoid

    t = np.asarray(laos_result["t"])
    stress = np.asarray(laos_result["stress"])
    strain = np.asarray(laos_result["strain"])
    strain_rate = np.asarray(laos_result["strain_rate"])

    # Angular frequency from the amplitude ratio max(γ̇) / max(γ) = ω
    omega = strain_rate.max() / strain.max()

    # Use last cycle for steady-state analysis.
    period = 2 * np.pi / omega
    dt = t[1] - t[0]
    # Round rather than truncate: int(period / dt) typically lands one
    # interval short of a cycle. A cycle of k intervals needs k + 1 samples
    # so that the trapezoid rule covers the whole period.
    n_intervals = int(round(period / dt))
    n_intervals = max(1, min(n_intervals, len(t) - 1))
    start = len(t) - (n_intervals + 1)

    tau = t[start:] - t[start]  # Reset to 0
    stress_cycle = stress[start:]
    window = tau[-1]

    # Phase of the imposed strain at the window start. Resetting time to zero
    # discards it, and the window does not generally start at a strain zero
    # crossing, so recover it from γ = γ₀·sin(φ) and γ̇ = γ₀·ω·cos(φ).
    # Without this, in-phase and out-of-phase coefficients are mixed.
    phase_0 = np.arctan2(omega * strain[start], strain_rate[start])
    phase = omega * tau + phase_0

    orders = np.arange(1, 2 * n_harmonics, 2)  # Odd harmonics
    # In-phase (sin) and out-of-phase (cos) Fourier coefficients, normalised
    # by the actual duration of the integration window.
    sigma_prime = np.array(
        [2 * trapezoid(stress_cycle * np.sin(n * phase), tau) / window for n in orders]
    )
    sigma_double_prime = np.array(
        [2 * trapezoid(stress_cycle * np.cos(n * phase), tau) / window for n in orders]
    )

    # Normalized intensities I_n/I_1
    intensities = np.sqrt(sigma_prime**2 + sigma_double_prime**2)
    I_n_1 = intensities / intensities[0]

    return {
        "sigma_prime": sigma_prime,
        "sigma_double_prime": sigma_double_prime,
        "I_n_1": I_n_1,
    }
def _fit_laos(self, t: np.ndarray, stress: np.ndarray, **kwargs) -> DMTLocal: """Fit to LAOS stress waveform σ(t) under oscillatory strain. Parameters ---------- t : array Time array [s] stress : array Measured stress waveform [Pa] **kwargs gamma_0 : float Strain amplitude (default: 0.1) omega_laos : float Angular frequency [rad/s] (default: 1.0) lam_init : float Initial structure parameter (default: 1.0) """ from rheojax.utils.optimization import fit_with_nlsq gamma_0 = kwargs.get("gamma_0", 0.1) omega = kwargs.get("omega_laos", kwargs.get("omega", 1.0)) lam_init = kwargs.get("lam_init", 1.0) # Cache for model_function (Bayesian inference bridge) self._gamma_0 = gamma_0 self._omega_laos = omega self._laos_lam_init = lam_init t_jax = jnp.array(t) stress_jax = jnp.array(stress) dt = float(t[1] - t[0]) # Compute driving strain rate strain_rate = gamma_0 * omega * jnp.cos(omega * t_jax) stress_scale = jnp.maximum(jnp.std(stress_jax), 1.0) if self.include_elasticity: def residual_fn(params_array): param_dict = self._params_array_to_dict(params_array) def step(state, sr): sigma, lam = state dlam = structure_evolution( lam, sr, param_dict["t_eq"], param_dict["a"], param_dict["c"], ) lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) G = elastic_modulus(lam_new, param_dict["G0"], param_dict["m_G"]) if self.closure == "exponential": eta = viscosity_exponential( lam_new, param_dict["eta_0"], param_dict["eta_inf"] ) else: eta = viscosity_herschel_bulkley_regularized( lam_new, sr, param_dict["tau_y0"], param_dict["K0"], param_dict["n_flow"], param_dict["eta_inf"], param_dict["m1"], param_dict["m2"], ) theta_1 = eta / jnp.maximum(G, 1e-10) dsigma = maxwell_stress_evolution(sigma, sr, G, theta_1) sigma_new = sigma + dt * dsigma return (sigma_new, lam_new), sigma_new step = jax.checkpoint(step) init_state = (jnp.float64(0.0), jnp.float64(lam_init)) _, stress_pred = jax.lax.scan(step, init_state, strain_rate) stress_pred = jnp.clip(stress_pred, -1e12, 1e12) return (stress_pred - 
stress_jax) / stress_scale else: def residual_fn(params_array): param_dict = self._params_array_to_dict(params_array) def step(lam, sr): dlam = structure_evolution( lam, sr, param_dict["t_eq"], param_dict["a"], param_dict["c"], ) lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0) if self.closure == "exponential": eta = viscosity_exponential( lam_new, param_dict["eta_0"], param_dict["eta_inf"] ) else: eta = viscosity_herschel_bulkley_regularized( lam_new, sr, param_dict["tau_y0"], param_dict["K0"], param_dict["n_flow"], param_dict["eta_inf"], param_dict["m1"], param_dict["m2"], ) return lam_new, eta * sr step = jax.checkpoint(step) _, stress_pred = jax.lax.scan(step, jnp.float64(lam_init), strain_rate) stress_pred = jnp.clip(stress_pred, -1e12, 1e12) return (stress_pred - stress_jax) / stress_scale params_array, bounds = self._get_params_for_optimization() _dmt_reserved = { "test_mode", "gamma_dot", "lam_init", "sigma_init", "sigma_0", "lam_0", "gamma_0", "omega_laos", "omega", "n_cycles", "points_per_cycle", "deformation_mode", "poisson_ratio", } nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _dmt_reserved} result = fit_with_nlsq(residual_fn, params_array, bounds=bounds, **nlsq_kwargs) self._set_params_from_array(result.x) self._fitted = True return self def _predict_laos(self, t: np.ndarray, **kwargs) -> np.ndarray: """Predict LAOS stress waveform.""" gamma_0 = kwargs.get("gamma_0", 0.1) omega = kwargs.get("omega_laos", kwargs.get("omega", 1.0)) lam_init = kwargs.get("lam_init", 1.0) n_cycles = kwargs.get("n_cycles", 10) points_per_cycle = max(1, len(t) // n_cycles) if len(t) > n_cycles else 128 result = self.simulate_laos( gamma_0, omega, n_cycles, points_per_cycle, lam_init ) return result["stress"] # ========================================================================= # Helper Methods # ========================================================================= def _compute_r2(self, y_true: jnp.ndarray, y_pred: np.ndarray) -> float: """Compute R² 
coefficient of determination.""" ss_res = jnp.sum((y_true - y_pred) ** 2) ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2) return float(1 - ss_res / ss_tot) def _get_params_for_optimization(self) -> tuple[jnp.ndarray, tuple]: """Get parameter array and bounds for optimization.""" param_names = list(self.parameters.keys()) params = jnp.array([self.parameters.get_value(n) for n in param_names]) bounds_lower = jnp.array([self.parameters[n].bounds[0] for n in param_names]) bounds_upper = jnp.array([self.parameters[n].bounds[1] for n in param_names]) return params, (bounds_lower, bounds_upper) def _params_array_to_dict(self, params_array: jnp.ndarray) -> dict: """Convert parameter array to dictionary. Note: Values are kept as JAX arrays to maintain compatibility with JAX tracing during optimization. Use float() only after optimization. """ param_names = list(self.parameters.keys()) return {name: params_array[i] for i, name in enumerate(param_names)} def _set_params_from_array(self, params_array: jnp.ndarray) -> None: """Set parameters from array.""" param_names = list(self.parameters.keys()) for i, name in enumerate(param_names): self.parameters.set_value(name, float(params_array[i])) def _simulate_with_params( self, t: jnp.ndarray, dt: float, gamma_dot: float, lam_init: float, params: dict, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Simulate with given parameters (for fitting). This method directly calls internal simulation methods with the params dict to maintain JAX traceability during optimization. """ # Directly call internal methods with params dict (JAX-traceable) if self.include_elasticity: t_out, stress, lam = self._simulate_startup_maxwell( t, dt, gamma_dot, lam_init, params ) else: t_out, stress, lam = self._simulate_startup_viscous( t, dt, gamma_dot, lam_init, params ) return t_out, jnp.array(stress), jnp.array(lam)
[docs] def model_function(self, X, params, test_mode=None, **kwargs): """NumPyro/BayesianMixin model function for DMT. Routes to appropriate JAX-traceable prediction based on test_mode. Required by BayesianMixin for NumPyro NUTS sampling. Parameters ---------- X : array-like Independent variable (shear rate, time, or frequency) params : array-like Parameter values in ParameterSet order test_mode : str, optional Override stored test mode Returns ------- jnp.ndarray Predicted response (stress, strain, or complex modulus) """ p_values = dict(zip(self.parameters.keys(), params, strict=True)) # P2-DMT-001: Use getattr with default so this is safe even if called # before _fit() (e.g. during model_function probe or unit tests). mode = ( test_mode if test_mode is not None else getattr(self, "_test_mode", "flow_curve") ) if mode is None: mode = "flow_curve" X_jax = jnp.asarray(X, dtype=jnp.float64) if mode in ["steady_shear", "rotation", "flow_curve"]: return self._model_function_flow_curve(X_jax, p_values) elif mode == "oscillation": return self._model_function_oscillation(X_jax, p_values) elif mode == "startup": return self._model_function_startup(X_jax, p_values, **kwargs) elif mode == "relaxation": return self._model_function_relaxation(X_jax, p_values, **kwargs) elif mode == "creep": return self._model_function_creep(X_jax, p_values, **kwargs) elif mode == "laos": return self._model_function_laos(X_jax, p_values) else: raise ValueError(f"Unsupported test mode for DMT: {mode}")
def _mf_closure_viscosity(self, lam, gamma_dot, p_values):
    """Return the viscosity from the closure selected by ``self.closure``.

    "exponential" uses ``viscosity_exponential`` (no rate dependence);
    any other value uses ``viscosity_herschel_bulkley_regularized``.
    """
    if self.closure == "exponential":
        return viscosity_exponential(lam, p_values["eta_0"], p_values["eta_inf"])
    return viscosity_herschel_bulkley_regularized(
        lam,
        gamma_dot,
        p_values["tau_y0"],
        p_values["K0"],
        p_values["n_flow"],
        p_values["eta_inf"],
        p_values["m1"],
        p_values["m2"],
    )


def _mf_advance_structure(self, lam, gamma_dot, dt, p_values):
    """Forward-Euler update of the structure parameter, clipped to [0, 1]."""
    dlam = structure_evolution(
        lam,
        gamma_dot,
        p_values["t_eq"],
        p_values["a"],
        p_values["c"],
    )
    return jnp.clip(lam + dt * dlam, 0.0, 1.0)


def _mf_stress_carry(self, lam_init):
    """Initial scan carry: (σ=0, λ) with elasticity, λ alone without."""
    if self.include_elasticity:
        return (jnp.float64(0.0), jnp.float64(lam_init))
    return jnp.float64(lam_init)


def _mf_stress_step(self, carry, gamma_dot, dt, p_values):
    """One forward-Euler stress step in ``jax.lax.scan`` form.

    Returns ``(new_carry, stress)`` where ``stress`` is the value after
    this step's update. The carry layout matches ``_mf_stress_carry``.
    """
    if not self.include_elasticity:
        lam_new = self._mf_advance_structure(carry, gamma_dot, dt, p_values)
        eta = self._mf_closure_viscosity(lam_new, gamma_dot, p_values)
        return lam_new, eta * gamma_dot

    sigma, lam = carry
    lam_new = self._mf_advance_structure(lam, gamma_dot, dt, p_values)
    G = elastic_modulus(lam_new, p_values["G0"], p_values["m_G"])
    eta = self._mf_closure_viscosity(lam_new, gamma_dot, p_values)
    # Relaxation time of the Maxwell element; the floor avoids 0-division.
    theta_1 = eta / jnp.maximum(G, 1e-10)
    dsigma = maxwell_stress_evolution(sigma, gamma_dot, G, theta_1)
    sigma_new = sigma + dt * dsigma
    return (sigma_new, lam_new), sigma_new


def _mf_creep_rate(self, sigma_0, lam, p_values):
    """Shear rate obtained by inverting the closure at applied stress σ₀."""
    if self.closure == "exponential":
        return invert_stress_for_gamma_dot_exponential(
            sigma_0, lam, p_values["eta_0"], p_values["eta_inf"]
        )
    return invert_stress_for_gamma_dot_hb(
        sigma_0,
        lam,
        p_values["tau_y0"],
        p_values["K0"],
        p_values["n_flow"],
        p_values["eta_inf"],
        p_values["m1"],
        p_values["m2"],
    )


def _model_function_flow_curve(self, X_jax, p_values):
    """Flow curve prediction: σ(γ̇) at steady state.

    ``X_jax`` holds the shear rates; the steady-stress kernel matching
    ``self.closure`` is evaluated on them.
    """
    if self.closure == "exponential":
        return steady_stress_exponential(
            X_jax,
            p_values["eta_0"],
            p_values["eta_inf"],
            p_values["a"],
            p_values["c"],
        )
    return steady_stress_herschel_bulkley(
        X_jax,
        p_values["tau_y0"],
        p_values["K0"],
        p_values["n_flow"],
        p_values["eta_inf"],
        p_values["a"],
        p_values["c"],
        p_values["m1"],
        p_values["m2"],
    )


def _model_function_oscillation(self, X_jax, p_values):
    """SAOS prediction: G*(ω) = G'(ω) + jG''(ω).

    The moduli come from ``saos_moduli_maxwell`` at the fixed structure
    level ``_saos_lam_0`` (1.0 if not cached). Requires the elastic
    parameters ``G0`` and ``m_G`` in ``p_values``.
    """
    lam_0 = getattr(self, "_saos_lam_0", 1.0)

    G = elastic_modulus(lam_0, p_values["G0"], p_values["m_G"])
    # The HB closure is evaluated at a nominal rate of 1e-6 because SAOS
    # has no imposed steady shear rate.
    eta = self._mf_closure_viscosity(lam_0, 1e-6, p_values)

    theta_1 = eta / jnp.maximum(G, 1e-10)
    G_prime, G_double_prime = saos_moduli_maxwell(
        X_jax, G, theta_1, p_values["eta_inf"]
    )

    return G_prime + 1j * G_double_prime


def _model_function_startup(self, X_jax, p_values, **kwargs):
    """Startup shear prediction: σ(t) at constant γ̇.

    ``gamma_dot`` comes from kwargs, else the cached
    ``_gamma_dot_applied``, else 1.0. The time step is X[1] - X[0] and
    one Euler step is taken per entry of ``X_jax``; each output is the
    stress after that step.
    """
    _gd = kwargs.get("gamma_dot", _MISSING)
    gamma_dot = (
        _gd if _gd is not _MISSING else getattr(self, "_gamma_dot_applied", 1.0)
    )
    lam_init = getattr(self, "_startup_lam_init", 1.0)

    dt = X_jax[1] - X_jax[0]
    n_steps = X_jax.shape[0]

    def step(carry, _):
        return self._mf_stress_step(carry, gamma_dot, dt, p_values)

    _, stress = jax.lax.scan(
        jax.checkpoint(step),
        self._mf_stress_carry(lam_init),
        None,
        length=n_steps,
    )
    return stress


def _model_function_relaxation(self, X_jax, p_values, **kwargs):
    """Stress relaxation prediction: σ(t) after cessation of shear.

    ``sigma_init`` comes from kwargs, else the cached
    ``_relax_sigma_init``, else 100.0; the initial structure is the
    cached ``_relax_lam_init`` (default 0.5). Always uses the Maxwell
    element, independent of ``include_elasticity``.
    """
    _si = kwargs.get("sigma_init", _MISSING)
    sigma_init = (
        _si if _si is not _MISSING else getattr(self, "_relax_sigma_init", 100.0)
    )
    lam_init = getattr(self, "_relax_lam_init", 0.5)

    dt = X_jax[1] - X_jax[0]
    n_steps = X_jax.shape[0]

    def step(state, _):
        sigma, lam = state
        # Structure recovery only (γ̇ = 0)
        dlam = (1.0 - lam) / p_values["t_eq"]
        lam_new = jnp.clip(lam + dt * dlam, 0.0, 1.0)

        G = elastic_modulus(lam_new, p_values["G0"], p_values["m_G"])
        if self.closure == "exponential":
            eta = viscosity_exponential(
                lam_new, p_values["eta_0"], p_values["eta_inf"]
            )
        else:
            # HB at zero shear rate
            # NOTE(review): uses eta_inf rather than the regularized HB
            # viscosity — confirm this is the intended limit.
            eta = p_values["eta_inf"]

        theta_1 = eta / jnp.maximum(G, 1e-10)
        dsigma = -sigma / jnp.maximum(theta_1, 1e-12)
        sigma_new = sigma + dt * dsigma
        return (sigma_new, lam_new), sigma_new

    init_state = (jnp.float64(sigma_init), jnp.float64(lam_init))
    _, stress = jax.lax.scan(
        jax.checkpoint(step), init_state, None, length=n_steps
    )
    return stress


def _model_function_creep(self, X_jax, p_values, **kwargs):
    """Creep prediction: γ(t) at constant σ₀.

    ``sigma_0`` comes from kwargs (``sigma_0`` or ``sigma_applied``),
    else the cached ``_sigma_applied``, else 50.0. With elasticity the
    output adds the elastic strain σ₀ / G(λ), evaluated at the
    structure level before the step, to the accumulated viscous strain.
    """
    _s0 = kwargs.get("sigma_0", _MISSING)
    if _s0 is _MISSING:
        _s0 = kwargs.get("sigma_applied", _MISSING)
    sigma_0 = _s0 if _s0 is not _MISSING else getattr(self, "_sigma_applied", 50.0)
    lam_init = getattr(self, "_creep_lam_init", 1.0)

    dt = X_jax[1] - X_jax[0]
    n_steps = X_jax.shape[0]

    def step(state, _):
        lam, gamma_v = state
        gamma_dot_v = self._mf_creep_rate(sigma_0, lam, p_values)
        lam_new = self._mf_advance_structure(lam, gamma_dot_v, dt, p_values)
        gamma_v_new = gamma_v + dt * gamma_dot_v

        if self.include_elasticity:
            G = elastic_modulus(lam, p_values["G0"], p_values["m_G"])
            gamma_e = sigma_0 / jnp.maximum(G, 1e-10)
            gamma_out = gamma_e + gamma_v_new
        else:
            gamma_out = gamma_v_new
        return (lam_new, gamma_v_new), gamma_out

    init_state = (jnp.float64(lam_init), jnp.float64(0.0))
    _, gamma = jax.lax.scan(
        jax.checkpoint(step), init_state, None, length=n_steps
    )
    return gamma


def _model_function_laos(self, X_jax, p_values):
    """LAOS prediction: σ(t) under oscillatory strain.

    Uses the cached protocol values ``_gamma_0`` (default 0.1),
    ``_omega_laos`` (default 1.0) and ``_laos_lam_init`` (default 1.0).
    The driving strain rate is γ₀·ω·cos(ωt) on the time array ``X_jax``
    and the time step is X[1] - X[0].
    """
    gamma_0 = getattr(self, "_gamma_0", 0.1)
    omega = getattr(self, "_omega_laos", 1.0)
    lam_init = getattr(self, "_laos_lam_init", 1.0)

    # Compute strain rate from time array
    strain_rate = gamma_0 * omega * jnp.cos(omega * X_jax)
    dt = X_jax[1] - X_jax[0]

    def step(carry, sr):
        return self._mf_stress_step(carry, sr, dt, p_values)

    _, stress = jax.lax.scan(
        jax.checkpoint(step), self._mf_stress_carry(lam_init), strain_rate
    )
    return stress