Source code for rheojax.models.ikh.mikh

"""Maxwell-Isotropic-Kinematic Hardening (MIKH) Model.

A thixotropic elasto-viscoplastic model combining:
1. Maxwell viscoelastic element
2. Armstrong-Frederick kinematic hardening (backstress evolution)
3. Isotropic hardening/softening via structural parameter lambda (thixotropy)
4. Viscous background solvent

Supports all 6 experimental protocols:
- Flow curve (steady state)
- Startup shear
- Stress relaxation
- Creep
- SAOS (small amplitude oscillatory shear)
- LAOS (large amplitude oscillatory shear)
"""

from typing import cast

import numpy as np

from rheojax.core.base import ArrayLike
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax

diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.ikh._base import IKHBase
from rheojax.models.ikh._kernels import (
    ikh_creep_ode_rhs,
    ikh_flow_curve_steady_state,
    ikh_maxwell_ode_rhs,
    ikh_scan_kernel,
)

jax, jnp = safe_import_jax()


# kwargs to filter before passing to nlsq_optimize
_IKH_RESERVED = {
    "test_mode",
    "gamma_dot",
    "sigma_applied",
    "sigma_0",
    "deformation_mode",
    "poisson_ratio",
}


@ModelRegistry.register(
    "mikh",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class MIKH(IKHBase):
    r"""Maxwell-Isotropic-Kinematic Hardening (MIKH) Model.

    A thixotropic elasto-viscoplastic model combining:

    1. Armstrong-Frederick kinematic hardening (backstress evolution).
    2. Isotropic hardening/softening via structural parameter lambda
       (thixotropy).
    3. Maxwell viscoelastic element for proper relaxation behavior.
    4. Viscous background solvent.

    Two Formulations:
        - **Maxwell ODE** (via Diffrax): For creep/relaxation protocols
        - **Return Mapping**: For startup/LAOS protocols (incremental)

    Governing Equations:
        σ_total = σ + η_inf * γ̇

    Stress Evolution (ODE formulation):
        dσ/dt = G(γ̇ - γ̇ᵖ) - (G/η)σ

    Yield Surface:
        \|σ - α\| ≤ σ_y(λ)

        σ_y(λ) = σ_y0 + Δσ_y * λ

    Structure Evolution:
        dλ/dt = (1-λ)/τ_thix - Γ*λ*\|γ̇ᵖ\|

    Backstress Evolution (Armstrong-Frederick):
        dα = C*dγ_p - γ_dyn*\|α\|^(m-1)*α*\|dγ_p\|

    Parameters:
        G: Shear modulus [Pa]
        eta: Maxwell viscosity [Pa·s] (controls relaxation time τ = η/G)
        C: Kinematic hardening modulus [Pa]
        gamma_dyn: Dynamic recovery parameter for backstress [-]
        m: AF recovery exponent [-] (typically 1.0)
        sigma_y0: Minimal (destructured) yield stress [Pa]
        delta_sigma_y: Yield stress increment (structured - destructured) [Pa]
        tau_thix: Thixotropic rebuilding time scale [s]
        Gamma: Structural breakdown coefficient [-]
        eta_inf: High-shear viscosity [Pa·s]
        mu_p: Plastic viscosity for Perzyna regularization [Pa·s]
    """

    def __init__(self):
        """Register the eleven MIKH parameters with their defaults and bounds."""
        super().__init__()
        # Protocol used by the last fit; read back by _predict/model_function
        # (e.g. during Bayesian inference) when no test_mode is passed.
        self._test_mode = None

        # Elasticity
        self.parameters.add(
            "G", value=1e3, bounds=(1e-1, 1e9), units="Pa", description="Shear modulus"
        )
        self.parameters.add(
            "eta",
            value=1e6,
            bounds=(1e-3, 1e12),
            units="Pa s",
            description="Maxwell viscosity (relaxation time = eta/G)",
        )

        # Kinematic Hardening (Armstrong-Frederick)
        self.parameters.add(
            "C",
            value=5e2,
            bounds=(0.0, 1e9),
            units="Pa",
            description="Kinematic hardening modulus",
        )
        self.parameters.add(
            "gamma_dyn",
            value=1.0,
            bounds=(0.0, 1e4),
            units="-",
            description="Dynamic recovery parameter",
        )
        self.parameters.add(
            "m",
            value=1.0,
            bounds=(0.5, 3.0),
            units="-",
            description="AF recovery exponent",
        )

        # Yield Stress & Thixotropy
        self.parameters.add(
            "sigma_y0",
            value=10.0,
            bounds=(0.0, 1e9),
            units="Pa",
            description="Minimal yield stress (destructured)",
        )
        self.parameters.add(
            "delta_sigma_y",
            value=50.0,
            bounds=(0.0, 1e9),
            units="Pa",
            description="Structural yield stress contribution",
        )
        self.parameters.add(
            "tau_thix",
            value=1.0,
            bounds=(1e-6, 1e12),
            units="s",
            description="Rebuilding time scale",
        )
        self.parameters.add(
            "Gamma",
            value=0.5,
            bounds=(0.0, 1e4),
            units="-",
            description="Breakdown coefficient",
        )

        # Viscosity
        self.parameters.add(
            "eta_inf",
            value=0.1,
            bounds=(0.0, 1e9),
            units="Pa s",
            description="High-shear viscosity (solvent)",
        )
        self.parameters.add(
            "mu_p",
            value=1e-3,
            bounds=(1e-6, 1e3),
            units="Pa s",
            description="Plastic viscosity (Perzyna regularization)",
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _current_param_dict(self) -> dict:
        """Return the current parameter values as a ``{name: value}`` dict."""
        return dict(
            zip(self.parameters.keys(), self.parameters.get_values(), strict=True)
        )

    @staticmethod
    def _maxwell_saos_moduli(omega, G, eta):
        """Return ``(G', G'')`` of a single Maxwell element at frequency *omega*.

        Uses the relaxation time ``tau = eta / G``; ``G`` is floored at 1e-30
        in that division to avoid dividing by zero.
        """
        tau = eta / jnp.maximum(G, 1e-30)  # Maxwell relaxation time
        wt = omega * tau
        denom = 1 + wt**2
        return G * wt**2 / denom, G * wt / denom

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------

    def _fit(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MIKH":
        """Fit model parameters to data using protocol-aware optimization.

        The chosen ``test_mode`` is stored on the instance so later
        prediction / Bayesian calls can default to it.

        Args:
            X: Input data (time/strain array or RheoData)
            y: Target data (stress or strain depending on protocol)
            **kwargs: Options including:
                - test_mode: Protocol ('flow_curve', 'startup', 'relaxation',
                  'creep', 'oscillation'/'saos', 'laos'). Defaults to
                  'startup'.
                - gamma_dot: Shear rate (for startup)
                - sigma_applied: Applied stress (for creep)
                - sigma_0: Initial stress (for relaxation)

        Returns:
            ``self`` with optimized parameters.
        """
        test_mode = kwargs.get("test_mode", "startup")
        self._test_mode = test_mode

        if test_mode == "flow_curve":
            return self._fit_flow_curve(X, y, **kwargs)
        if test_mode in ("creep", "relaxation"):
            return self._fit_ode_formulation(X, y, **kwargs)
        if test_mode in ("oscillation", "saos"):
            return self._fit_oscillation(X, y, **kwargs)
        # 'startup', 'laos' and any unrecognized mode: strain-driven return
        # mapping is the default formulation.
        return self._fit_return_mapping(X, y, **kwargs)

    def _fit_flow_curve(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MIKH":
        """Fit to steady-state flow curve data (X = shear rate, y = stress)."""
        from rheojax.utils.optimization import nlsq_optimize

        gamma_dot = jnp.asarray(X)
        sigma_target = jnp.asarray(y)
        # Parameter names are fixed for the duration of the fit, so resolve
        # them once rather than on every objective evaluation.
        p_names = list(self.parameters.keys())

        def objective(param_values):
            p_dict = dict(zip(p_names, param_values, strict=True))
            sigma_pred = ikh_flow_curve_steady_state(gamma_dot, **p_dict)
            return sigma_pred - sigma_target

        filtered = {k: v for k, v in kwargs.items() if k not in _IKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def _fit_ode_formulation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MIKH":
        """Fit using ODE formulation (for creep/relaxation).

        Reads ``test_mode`` (default 'relaxation'), ``gamma_dot`` (default
        0.0), ``sigma_applied`` (default 100.0) and ``sigma_0`` (default
        100.0) from ``kwargs`` and caches the last three on the instance for
        later prediction calls.
        """
        from rheojax.utils.optimization import nlsq_optimize

        t = jnp.asarray(X)
        y_target = jnp.asarray(y)
        test_mode = kwargs.get("test_mode", "relaxation")
        gamma_dot = kwargs.get("gamma_dot", 0.0)
        sigma_applied = kwargs.get("sigma_applied", 100.0)
        sigma_0 = kwargs.get("sigma_0", 100.0)

        # Cache protocol kwargs for model_function (NUTS reads these)
        self._fit_gamma_dot = gamma_dot
        self._fit_sigma_applied = sigma_applied
        self._fit_sigma_0 = sigma_0

        p_names = list(self.parameters.keys())

        def objective(param_values):
            p_dict = dict(zip(p_names, param_values, strict=True))
            y_pred = self._simulate_transient(
                t, p_dict, test_mode, gamma_dot, sigma_applied, sigma_0
            )
            return y_pred - y_target

        # Force method="scipy": diffrax ODE solvers use custom_vjp which is
        # incompatible with NLSQ's forward-mode autodiff (jvp).
        kwargs["method"] = "scipy"
        filtered = {k: v for k, v in kwargs.items() if k not in _IKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def _fit_return_mapping(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MIKH":
        """Fit using return mapping formulation (for startup/LAOS)."""
        from rheojax.utils.optimization import nlsq_optimize

        times, strains = self._extract_time_strain(X, **kwargs)
        sigma_target = jnp.asarray(y)
        p_names = list(self.parameters.keys())

        def objective(param_values):
            p_dict = dict(zip(p_names, param_values, strict=True))
            sigma_pred = self._predict_from_params(times, strains, p_dict)
            return sigma_pred - sigma_target

        filtered = {k: v for k, v in kwargs.items() if k not in _IKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    def _fit_oscillation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MIKH":
        """Fit to oscillatory data (SAOS).

        Supports two modes:

        1. Frequency-domain: X=omega, y=|G*| or complex G*
           (uses Maxwell analytical solution)
        2. Time-domain: X=time, y=stress
           (uses return mapping with sinusoidal strain)

        The mode can be forced with ``oscillation_mode='time_domain'`` or
        ``oscillation_mode='frequency_domain'``; otherwise it is guessed from
        the shape and values of ``X``.
        """
        X_arr = jnp.asarray(X)

        # `oscillation_mode` is consumed here. Pop it so it is not forwarded
        # through **kwargs into the optimizer by the delegated fit method.
        oscillation_mode = kwargs.pop("oscillation_mode", None)

        if oscillation_mode == "time_domain":
            is_time_domain = True
        elif oscillation_mode == "frequency_domain":
            is_time_domain = False
        else:
            # Heuristic: frequency data is short, positive, and monotonic
            is_short = len(X_arr) <= 200
            is_positive = bool(jnp.all(X_arr > 0))
            is_monotone = bool(jnp.all(jnp.diff(X_arr) > 0))
            is_time_domain = not (is_short and is_positive and is_monotone)

        if is_time_domain:
            # Time-domain: use return mapping with sinusoidal strain
            return self._fit_return_mapping(X, y, **kwargs)
        # Frequency-domain: fit G' and G'' using Maxwell analytical solution
        return self._fit_saos_frequency_domain(X, y, **kwargs)

    def _fit_saos_frequency_domain(
        self, X: ArrayLike, y: ArrayLike, **kwargs
    ) -> "MIKH":
        """Fit to frequency-domain SAOS data using Maxwell analytical expressions.

        Fits G' and G'' independently when complex or (N, 2) input is
        provided. Falls back to magnitude-only fitting for real 1D input.
        Only ``G`` and ``eta`` enter the residual.

        Args:
            X: Angular frequency array (omega)
            y: Complex G* = G' + iG'', (N, 2) array [G', G''], or real |G*|
        """
        from rheojax.utils.optimization import nlsq_optimize

        omega = jnp.asarray(X)

        # Handle different y formats — always extract G' and G'' for
        # component-wise fitting (magnitude-only discards phase angle δ)
        y_arr = jnp.asarray(y)
        if jnp.iscomplexobj(y_arr):
            # Complex G* = G' + iG'' provided
            target_G_prime = jnp.real(y_arr)
            target_G_double_prime = jnp.imag(y_arr)
            fit_components = True
        elif y_arr.ndim == 2 and y_arr.shape[1] == 2:
            # (N, 2) array provided - [G', G''] format
            target_G_prime = y_arr[:, 0]
            target_G_double_prime = y_arr[:, 1]
            fit_components = True
        else:
            # Real 1D array — assume magnitude |G*| (no phase info available)
            target_magnitude = y_arr
            fit_components = False

        p_names = list(self.parameters.keys())

        def objective(param_values):
            """Compute residual using Maxwell analytical SAOS expressions."""
            p_dict = dict(zip(p_names, param_values, strict=True))
            G_prime, G_double_prime = self._maxwell_saos_moduli(
                omega, p_dict["G"], p_dict["eta"]
            )

            if fit_components:
                # Fit G' and G'' independently (preserves phase information)
                return jnp.concatenate(
                    [G_prime - target_G_prime, G_double_prime - target_G_double_prime]
                )
            # Magnitude-only fallback (no phase info available)
            G_star_magnitude = jnp.sqrt(
                jnp.maximum(G_prime**2 + G_double_prime**2, 1e-30)
            )
            return G_star_magnitude - target_magnitude

        filtered = {k: v for k, v in kwargs.items() if k not in _IKH_RESERVED}
        nlsq_optimize(objective, self.parameters, **filtered)
        return self

    # ------------------------------------------------------------------
    # Simulation kernels
    # ------------------------------------------------------------------

    def _simulate_transient(
        self,
        t: jnp.ndarray,
        params: dict,
        mode: str,
        gamma_dot: float | None = None,
        sigma_applied: float | None = None,
        sigma_0: float | None = None,
    ) -> jnp.ndarray:
        """Simulate transient response using Diffrax ODE integration.

        Args:
            t: Time array; integration runs from ``t[0]`` to ``t[-1]`` and the
                solution is saved at every entry of ``t``.
            params: Parameter dictionary
            mode: 'startup' or 'creep'; any other value is treated as
                'relaxation'.
            gamma_dot: Applied shear rate (for startup; defaults to 1.0)
            sigma_applied: Applied stress (for creep; defaults to 100.0)
            sigma_0: Initial stress (for relaxation; defaults to
                ``sigma_y0 + delta_sigma_y`` from ``params``)

        Returns:
            Stress (for startup/relaxation) or strain (for creep). The whole
            array is NaN if the solver did not report success.
        """
        # Build args for ODE RHS (copy so the caller's dict is not mutated)
        args = dict(params)

        # Initial state based on mode
        lambda_init = 1.0  # Fully structured initially

        if mode == "creep":
            # Creep: constant stress, track strain
            ode_fn = ikh_creep_ode_rhs
            args["sigma_applied"] = (
                sigma_applied if sigma_applied is not None else 100.0
            )
            # State: [strain, alpha, lambda]
            y0 = jnp.array([0.0, 0.0, lambda_init])
        elif mode == "startup":
            # Startup: constant rate, track stress
            ode_fn = ikh_maxwell_ode_rhs
            args["gamma_dot"] = gamma_dot if gamma_dot is not None else 1.0
            # State: [sigma, alpha, lambda]
            y0 = jnp.array([0.0, 0.0, lambda_init])
        else:  # relaxation
            # Relaxation: rate = 0, stress decays
            ode_fn = ikh_maxwell_ode_rhs
            args["gamma_dot"] = 0.0
            sigma_init = (
                sigma_0
                if sigma_0 is not None
                else params.get("sigma_y0", 10.0) + params.get("delta_sigma_y", 50.0)
            )
            # Start partially destructured
            lambda_init_relax = 0.5
            y0 = jnp.array([sigma_init, 0.0, lambda_init_relax])

        # Diffrax setup
        term = diffrax.ODETerm(
            lambda ti, yi, args_i: ode_fn(cast(float, ti), yi, args_i)
        )
        solver = diffrax.Tsit5()
        stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-7)

        t0 = t[0]
        t1 = t[-1]
        # Initial step: at most 1/1000 of the time span.
        dt0 = (t1 - t0) / max(len(t), 1000)

        saveat = diffrax.SaveAt(ts=t)

        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=args,
            saveat=saveat,
            stepsize_controller=stepsize_controller,
            max_steps=1_000_000,
            throw=False,  # failures are reported via sol.result instead
        )

        # Extract primary variable (index 0)
        # For creep: strain; for startup/relaxation: stress
        result = sol.ys[:, 0]

        # Handle solver failures
        result = jnp.where(
            sol.result == diffrax.RESULTS.successful,
            result,
            jnp.full_like(result, jnp.nan),
        )

        # Add viscous contribution for startup
        if mode == "startup":
            eta_inf_val = params.get("eta_inf", 0.0)
            result = result + jnp.where(
                jnp.greater(eta_inf_val, 0.0),
                eta_inf_val * args["gamma_dot"],
                jnp.zeros_like(result),
            )

        return result

    def _predict_from_params(self, times, strains, params):
        """Predict using parameter dictionary (for NLSQ/Bayesian)."""
        return ikh_scan_kernel(times, strains, use_viscosity=True, **params)

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------

    def _predict(self, X: ArrayLike, **kwargs) -> ArrayLike:
        """Predict stress from time/strain history or based on test_mode.

        ``test_mode`` is taken from ``kwargs`` if given and not None, else
        from the mode stored by the last fit, else 'startup'.

        Args:
            X: Input data. Shape depends on test_mode:
                - flow_curve: shear rates
                - startup/laos: (2, N) array of [time, strain] or RheoData
                - relaxation: time array (requires sigma_0)
                - creep: time array (requires sigma_applied)
            **kwargs: Additional parameters (test_mode, gamma_dot, etc.)
        """
        test_mode = kwargs.get("test_mode")
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", None)
        if test_mode is None:
            test_mode = "startup"

        param_dict = self._current_param_dict()

        if test_mode == "flow_curve":
            gamma_dot = jnp.asarray(X)
            return ikh_flow_curve_steady_state(gamma_dot, **param_dict)

        if test_mode in ("creep", "relaxation"):
            t = jnp.asarray(X)
            # Fall back to the protocol values cached by the last ODE fit.
            gamma_dot = kwargs.get("gamma_dot", getattr(self, "_fit_gamma_dot", 0.0))
            sigma_applied = kwargs.get(
                "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
            )
            sigma_0 = kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 100.0))
            return self._simulate_transient(
                t, param_dict, test_mode, gamma_dot, sigma_applied, sigma_0
            )

        # Strain-driven protocols (startup, laos, oscillation)
        times, strains = self._extract_time_strain(X, **kwargs)
        return self._predict_from_params(times, strains, param_dict)

    def predict_flow_curve(self, gamma_dot: ArrayLike) -> ArrayLike:
        """Predict steady-state flow curve."""
        return self._predict(gamma_dot, test_mode="flow_curve")

    def predict_startup(self, t: ArrayLike, gamma_dot: float = 1.0) -> ArrayLike:
        """Predict startup shear response.

        Args:
            t: Time array
            gamma_dot: Constant shear rate

        Returns:
            Stress vs time
        """
        return self._simulate_transient(
            jnp.asarray(t), self._current_param_dict(), "startup", gamma_dot=gamma_dot
        )

    def predict_relaxation(self, t: ArrayLike, sigma_0: float = 100.0) -> ArrayLike:
        """Predict stress relaxation.

        Args:
            t: Time array
            sigma_0: Initial stress

        Returns:
            Stress vs time
        """
        return self._simulate_transient(
            jnp.asarray(t), self._current_param_dict(), "relaxation", sigma_0=sigma_0
        )

    def predict_creep(self, t: ArrayLike, sigma_applied: float = 50.0) -> ArrayLike:
        """Predict creep response.

        Args:
            t: Time array
            sigma_applied: Applied constant stress

        Returns:
            Strain vs time
        """
        return self._simulate_transient(
            jnp.asarray(t),
            self._current_param_dict(),
            "creep",
            sigma_applied=sigma_applied,
        )

    def predict_laos(
        self, t: ArrayLike, gamma_0: float = 1.0, omega: float = 1.0
    ) -> ArrayLike:
        """Predict LAOS response.

        The imposed strain is ``gamma_0 * sin(omega * t)``.

        Args:
            t: Time array
            gamma_0: Strain amplitude
            omega: Angular frequency

        Returns:
            Stress vs time
        """
        t_arr = jnp.asarray(t)
        strain = gamma_0 * jnp.sin(omega * t_arr)
        return self._predict_from_params(t_arr, strain, self._current_param_dict())

    def model_function(self, X, params, test_mode=None, **kwargs):
        """NumPyro model function for Bayesian inference.

        Accepts protocol-specific kwargs (gamma_dot, sigma_applied, sigma_0).
        Falls back to values cached during _fit() if not provided.

        Args:
            X: Protocol input (shear rates, times, angular frequencies, or
                time/strain data for the strain-driven protocols).
            params: Parameter values, either as an array ordered like
                ``self.parameters`` or as a ``{name: value}`` dict.
            test_mode: Protocol name; defaults to the mode stored by the last
                fit, then to 'startup'.

        Returns:
            Model prediction. For 'oscillation'/'saos' this is an (N, 2)
            array with columns [G', G''].
        """
        # Use stored test_mode if not provided
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", None)
        if test_mode is None:
            test_mode = "startup"
        mode = test_mode

        # Convert array to dict for kernel
        if isinstance(params, (np.ndarray, jnp.ndarray)):
            param_names = list(self.parameters.keys())
            param_dict = dict(zip(param_names, params, strict=True))
        else:
            param_dict = params

        # Extract protocol-specific args from kwargs, falling back to
        # cached values from _fit_ode_formulation()
        gamma_dot = kwargs.get("gamma_dot", getattr(self, "_fit_gamma_dot", 0.0))
        sigma_applied = kwargs.get(
            "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
        )
        sigma_0 = kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 100.0))

        if mode == "flow_curve":
            gamma_dot_arr = jnp.asarray(X)
            return ikh_flow_curve_steady_state(gamma_dot_arr, **param_dict)

        if mode in ("creep", "relaxation"):
            t = jnp.asarray(X)
            return self._simulate_transient(
                t, param_dict, mode, gamma_dot, sigma_applied, sigma_0
            )

        # 'saos' is accepted by _fit as an alias of 'oscillation' and is what
        # gets stored in self._test_mode, so it must be recognized here too.
        if mode in ("oscillation", "saos"):
            # Frequency-domain SAOS using Maxwell analytical expressions
            omega = jnp.asarray(X)
            G_prime, G_double_prime = self._maxwell_saos_moduli(
                omega, param_dict["G"], param_dict["eta"]
            )
            return jnp.column_stack([G_prime, G_double_prime])

        # startup/laos modes need strain computed from kwargs
        times, strains = self._extract_time_strain(X, **kwargs)
        return self._predict_from_params(times, strains, param_dict)