"""Multi-Lambda Isotropic-Kinematic Hardening (ML-IKH) Model.
Extends MIKH to N modes for capturing distributed thixotropic timescales.
Supports two yield surface formulations:
1. **Per-Mode Yield** (default): Each mode has independent yield surface
- Total stress = Σ σᵢ (parallel connection)
- Parameters: 7 per mode + 1 global
2. **Weighted-Sum Yield**: Single global yield surface
- σ_y = σ_y0 + k3·Σ(wᵢ·λᵢ)
- All modes share elastic/plastic response
- Parameters: 7 global (G, C, gamma_dyn, m, sigma_y0, k3, eta_inf) + 3 per mode (tau_thix_i, Gamma_i, w_i)
"""
from typing import Literal
import numpy as np
from rheojax.core.base import ArrayLike
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.ikh._base import IKHBase
from rheojax.models.ikh._kernels import (
make_ml_ikh_creep_ode_rhs_per_mode,
make_ml_ikh_creep_ode_rhs_weighted_sum,
make_ml_ikh_maxwell_ode_rhs_per_mode,
make_ml_ikh_maxwell_ode_rhs_weighted_sum,
ml_ikh_flow_curve_steady_state_per_mode,
ml_ikh_flow_curve_steady_state_weighted_sum,
ml_ikh_scan_kernel,
ml_ikh_weighted_sum_kernel,
)
jax, jnp = safe_import_jax()
# Keyword arguments that the fit routines in this module strip from **kwargs
# before forwarding the remainder to nlsq_optimize. The first four are read
# by the fit/predict code below; "deformation_mode" and "poisson_ratio" are
# not read in this module (presumably handled by the base class -- TODO confirm).
_MLIKH_RESERVED = {
"test_mode",
"gamma_dot",
"sigma_applied",
"sigma_0",
"deformation_mode",
"poisson_ratio",
}
[docs]
@ModelRegistry.register(
"ml_ikh",
protocols=[
Protocol.FLOW_CURVE,
Protocol.STARTUP,
Protocol.RELAXATION,
Protocol.CREEP,
Protocol.OSCILLATION,
Protocol.LAOS,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class MLIKH(IKHBase):
"""Multi-Lambda Isotropic-Kinematic Hardening (ML-IKH) Model.
Extends MIKH to N modes connected in parallel. Each mode evolves its own
internal variables (stress, backstress, structural lambda) with distinct
timescales (tau_thix_i) and properties.
Two Yield Mode Options:
- **per_mode** (default): Each mode has independent yield surface.
Total stress is sum of mode stresses.
- **weighted_sum**: Single global yield surface with structure
contribution from all modes: σ_y = σ_y0 + k3·Σ(wᵢ·λᵢ)
Per-Mode Parameters (for each mode i=1..N):
G_i: Shear modulus
C_i: Backstress modulus
gamma_dyn_i: Dynamic recovery
sigma_y0_i: Minimal yield stress
delta_sigma_y_i: Structural yield stress
tau_thix_i: Rebuilding timescale
Gamma_i: Breakdown coefficient
Weighted-Sum Parameters:
G: Global shear modulus
C: Global hardening modulus
gamma_dyn: Global dynamic recovery
m: AF recovery exponent
sigma_y0: Base yield stress
k3: Structure-yield coupling
tau_thix_i: Per-mode rebuilding timescales
Gamma_i: Per-mode breakdown coefficients
w_i: Per-mode structure weights
Global Parameters:
eta_inf: High-shear viscosity (both yield modes)
Supported Protocols:
- FLOW_CURVE: Steady-state stress vs shear rate (analytical solution)
- STARTUP: Transient stress growth at constant shear rate (return mapping)
- RELAXATION: Stress decay at constant strain (ODE formulation via Diffrax)
- CREEP: Strain evolution at constant stress (ODE formulation via Diffrax)
- OSCILLATION: Small amplitude oscillatory shear response
- LAOS: Large amplitude oscillatory shear (return mapping with sinusoidal strain)
Note:
Both yield modes (per_mode, weighted_sum) support all protocols.
ODE protocols (creep, relaxation) use Diffrax for numerical integration.
Return mapping protocols (startup, LAOS) use JAX scan for time stepping.
Args:
n_modes: Number of structural modes (default: 2)
yield_mode: Yield formulation ('per_mode' or 'weighted_sum')
"""
[docs]
def __init__(
    self,
    n_modes: int = 2,
    yield_mode: Literal["per_mode", "weighted_sum"] = "per_mode",
):
    """Build an ML-IKH model and populate its parameter set.

    Args:
        n_modes: Number of structural modes; must be at least 1.
        yield_mode: Yield formulation, ``'per_mode'`` or ``'weighted_sum'``.

    Raises:
        ValueError: If ``n_modes`` is smaller than 1 or ``yield_mode`` is
            not one of the two supported formulations.
    """
    super().__init__()

    supported_yield_modes = ("per_mode", "weighted_sum")
    if n_modes < 1:
        message = f"n_modes must be >= 1, got {n_modes}"
        raise ValueError(message)
    if yield_mode not in supported_yield_modes:
        message = f"yield_mode must be 'per_mode' or 'weighted_sum', got {yield_mode}"
        raise ValueError(message)

    # No protocol is known until _fit() records one.
    self._test_mode = None
    self._yield_mode = yield_mode
    self._n_modes = n_modes
    self._create_parameters()
def _create_parameters(self):
    """Replace ``self.parameters`` with a fresh set for the active yield mode."""
    self.parameters = ParameterSet()
    populate = (
        self._create_per_mode_parameters
        if self._yield_mode == "per_mode"
        else self._create_weighted_sum_parameters
    )
    populate()
def _create_per_mode_parameters(self):
    """Register the per-mode yield parameters.

    For every mode ``i`` seven parameters are added, in this order:
    ``G_i``, ``C_i``, ``gamma_dyn_i``, ``sigma_y0_i``, ``delta_sigma_y_i``,
    ``tau_thix_i`` and ``Gamma_i``. The single global ``eta_inf`` is added
    last. Registration order matters because parameter vectors are matched
    to names positionally elsewhere in this class.
    """
    mode_count = self._n_modes
    register = self.parameters.add

    for mode in range(1, mode_count + 1):
        # Default rebuilding times are spaced one decade apart across modes.
        tau_default = 10.0 ** (mode - 1 - self._n_modes / 2)

        # (name stem, default value, bounds, units, description suffix)
        per_mode_specs = (
            # Elasticity and hardening
            ("G", 1e3 / mode_count, (0.0, 1e9), "Pa", "Shear modulus"),
            ("C", 5e2 / mode_count, (0.0, 1e9), "Pa", "Kinematic hardening modulus"),
            ("gamma_dyn", 1.0, (0.0, 1e4), "-", "Dynamic recovery"),
            # Yield stress and thixotropy
            ("sigma_y0", 10.0 / mode_count, (0.0, 1e9), "Pa", "Minimal yield stress"),
            ("delta_sigma_y", 50.0 / mode_count, (0.0, 1e9), "Pa", "Structural yield stress"),
            ("tau_thix", tau_default, (1e-6, 1e12), "s", "Rebuilding time scale"),
            ("Gamma", 0.5, (0.0, 1e4), "-", "Breakdown coefficient"),
        )
        for stem, default, limits, unit, label in per_mode_specs:
            register(
                f"{stem}_{mode}",
                value=default,
                bounds=limits,
                units=unit,
                description=f"Mode {mode} {label}",
            )

    # Global viscosity shared by all modes
    register(
        "eta_inf",
        value=0.1,
        bounds=(0.0, 1e9),
        units="Pa s",
        description="High-shear viscosity",
    )
def _create_weighted_sum_parameters(self):
    """Register the weighted-sum yield parameters.

    Registration order: the global mechanical and yield parameters
    (``G``, ``C``, ``gamma_dyn``, ``m``, ``sigma_y0``, ``k3``), then for each
    mode ``tau_thix_i``, ``Gamma_i`` and ``w_i``, and finally the global
    ``eta_inf``. The order matters because parameter vectors are matched to
    names positionally elsewhere in this class.
    """
    mode_count = self._n_modes
    register = self.parameters.add

    # (name, default value, bounds, units, description)
    global_specs = (
        ("G", 1e3, (1e-1, 1e9), "Pa", "Global shear modulus"),
        ("C", 5e2, (0.0, 1e9), "Pa", "Global kinematic hardening modulus"),
        ("gamma_dyn", 1.0, (0.0, 1e4), "-", "Global dynamic recovery"),
        ("m", 1.0, (0.5, 3.0), "-", "AF recovery exponent"),
        ("sigma_y0", 10.0, (0.0, 1e9), "Pa", "Base yield stress"),
        ("k3", 50.0, (0.0, 1e9), "Pa", "Structure-yield coupling"),
    )
    for name, default, limits, unit, text in global_specs:
        register(name, value=default, bounds=limits, units=unit, description=text)

    for mode in range(1, mode_count + 1):
        # Default rebuilding times are spaced one decade apart across modes.
        tau_default = 10.0 ** (mode - 1 - self._n_modes / 2)
        structure_specs = (
            ("tau_thix", tau_default, (1e-6, 1e12), "s", "Rebuilding time scale"),
            ("Gamma", 0.5, (0.0, 1e4), "-", "Breakdown coefficient"),
            ("w", 1.0 / mode_count, (0.0, 1.0), "-", "structure weight"),
        )
        for stem, default, limits, unit, label in structure_specs:
            register(
                f"{stem}_{mode}",
                value=default,
                bounds=limits,
                units=unit,
                description=f"Mode {mode} {label}",
            )

    # Global viscosity
    register(
        "eta_inf",
        value=0.1,
        bounds=(0.0, 1e9),
        units="Pa s",
        description="High-shear viscosity",
    )
def _stack_mode_params(self, params, names=None):
    """Gather per-mode entries of ``params`` into one stacked array per name.

    For every ``name`` the values ``params["<name>_1"] .. params["<name>_N"]``
    are stacked with ``jnp.stack`` in mode order.

    Args:
        params: Mapping from parameter name to value.
        names: Name stems to stack. When ``None``, the seven per-mode yield
            stems are used (G, C, gamma_dyn, sigma_y0, delta_sigma_y,
            tau_thix, Gamma).

    Returns:
        Dict mapping each stem to its stacked array.
    """
    if names is None:
        names = (
            "G",
            "C",
            "gamma_dyn",
            "sigma_y0",
            "delta_sigma_y",
            "tau_thix",
            "Gamma",
        )

    mode_ids = range(1, self._n_modes + 1)
    stacked = {}
    for stem in names:
        stacked[stem] = jnp.stack([params[f"{stem}_{k}"] for k in mode_ids])
    return stacked
def _predict_from_params(self, times, strains, params):
    """Evaluate the strain-driven prediction for an explicit parameter dict.

    Dispatches to the per-mode or weighted-sum predictor according to the
    model's yield mode (used by the NLSQ and Bayesian code paths).
    """
    predictor = (
        self._predict_per_mode
        if self._yield_mode == "per_mode"
        else self._predict_weighted_sum
    )
    return predictor(times, strains, params)
def _predict_per_mode(self, times, strains, params):
    """Run the per-mode-yield scan kernel on a time/strain history."""
    # One stacked array per per-mode parameter stem.
    stacked = self._stack_mode_params(params)
    return ml_ikh_scan_kernel(
        times,
        strains,
        num_modes=self._n_modes,
        use_viscosity=True,
        eta_inf=params["eta_inf"],
        **stacked,
    )
def _predict_weighted_sum(self, times, strains, params):
    """Run the weighted-sum-yield kernel on a time/strain history."""
    structure_arrays = self._stack_mode_params(
        params, names=["tau_thix", "Gamma", "w"]
    )
    return ml_ikh_weighted_sum_kernel(
        times,
        strains,
        num_modes=self._n_modes,
        use_viscosity=True,
        G=params["G"],
        C=params["C"],
        gamma_dyn=params["gamma_dyn"],
        # "m" falls back to 1.0 when absent from the dict.
        m=params.get("m", 1.0),
        sigma_y0=params["sigma_y0"],
        k3=params["k3"],
        eta_inf=params["eta_inf"],
        **structure_arrays,
    )
def _predict_flow_curve_from_params(self, gamma_dot, params):
    """Evaluate the steady-state flow curve for an explicit parameter dict.

    The whole ``params`` dict is forwarded as keyword arguments to the
    steady-state kernel matching the model's yield mode.
    """
    steady_state = (
        ml_ikh_flow_curve_steady_state_per_mode
        if self._yield_mode == "per_mode"
        else ml_ikh_flow_curve_steady_state_weighted_sum
    )
    return steady_state(gamma_dot, self._n_modes, **params)
def _build_ode_args(self, params, **kwargs):
    """Assemble the ``args`` dict handed to the ODE right-hand sides.

    Args:
        params: Mapping from parameter name to value.
        **kwargs: ``gamma_dot`` and/or ``sigma_applied`` are copied into the
            result when present; other keys are ignored.

    Returns:
        Dict with ``n_modes``, the model parameters (stacked per mode where
        applicable), fixed ``eta``/``mu_p`` defaults, ``eta_inf`` and any
        protocol values supplied through ``kwargs``.
    """
    mode_count = self._n_modes

    if self._yield_mode == "per_mode":
        args = {
            "n_modes": mode_count,
            **self._stack_mode_params(params),
            # Fixed defaults for quantities that are not fitted parameters
            "eta": jnp.full(mode_count, 1e12),
            "mu_p": jnp.full(mode_count, 1e-6),
            "m": jnp.ones(mode_count),
        }
    else:
        args = {
            "n_modes": mode_count,
            # Global parameters
            "G": params["G"],
            "C": params["C"],
            "gamma_dyn": params["gamma_dyn"],
            "m": params.get("m", 1.0),
            "sigma_y0": params["sigma_y0"],
            "k3": params.get("k3", 0.0),
            # Per-mode structure parameters
            **self._stack_mode_params(params, names=["tau_thix", "Gamma", "w"]),
            "eta": 1e12,
            "mu_p": 1e-6,
        }

    args["eta_inf"] = params.get("eta_inf", 0.0)

    # Protocol-specific values, only when the caller provided them
    args.update(
        {key: kwargs[key] for key in ("gamma_dot", "sigma_applied") if key in kwargs}
    )
    return args
def _simulate_transient(
    self,
    t: jnp.ndarray,
    params: dict,
    mode: str,
    gamma_dot: float | None = None,
    sigma_applied: float | None = None,
    sigma_0: float | None = None,
) -> jnp.ndarray:
    """Integrate the ML-IKH ODE system with Diffrax and return the response.

    Args:
        t: Time array; the solution is saved at exactly these times.
        params: Parameter dictionary.
        mode: ``'creep'`` or ``'startup'``; any other value is handled as
            relaxation.
        gamma_dot: Applied shear rate for startup (1.0 when ``None``).
        sigma_applied: Applied stress for creep (100.0 when ``None``).
        sigma_0: Initial stress for relaxation. When ``None`` it is derived
            from the yield parameters.

    Returns:
        Strain for creep, total stress otherwise. Every entry is NaN when
        the solver does not report success.
    """
    n = self._n_modes
    per_mode = self._yield_mode == "per_mode"
    args = self._build_ode_args(params)

    # Pick the right-hand-side factory, the protocol forcing and the leading
    # part of the initial state. The structure variables λ_1..λ_N are
    # appended afterwards.
    if mode == "creep":
        make_rhs = (
            make_ml_ikh_creep_ode_rhs_per_mode
            if per_mode
            else make_ml_ikh_creep_ode_rhs_weighted_sum
        )
        args["sigma_applied"] = 100.0 if sigma_applied is None else sigma_applied
        if per_mode:
            # State: [γ, α_1..α_N, λ_1..λ_N] (1+2N)
            leading = [jnp.array([0.0]), jnp.zeros(n)]
        else:
            # State: [γ, α, λ_1..λ_N] (2+N)
            leading = [jnp.array([0.0, 0.0])]
        lambda_start = 1.0  # fully structured
    elif mode == "startup":
        make_rhs = (
            make_ml_ikh_maxwell_ode_rhs_per_mode
            if per_mode
            else make_ml_ikh_maxwell_ode_rhs_weighted_sum
        )
        args["gamma_dot"] = 1.0 if gamma_dot is None else gamma_dot
        if per_mode:
            # State: [σ_1..σ_N, α_1..α_N, λ_1..λ_N] (3N)
            leading = [jnp.zeros(n), jnp.zeros(n)]
        else:
            # State: [σ, α, λ_1..λ_N] (2+N)
            leading = [jnp.array([0.0, 0.0])]
        lambda_start = 1.0  # fully structured
    else:  # relaxation
        make_rhs = (
            make_ml_ikh_maxwell_ode_rhs_per_mode
            if per_mode
            else make_ml_ikh_maxwell_ode_rhs_weighted_sum
        )
        args["gamma_dot"] = 0.0
        if sigma_0 is not None:
            sigma_start = sigma_0
        elif per_mode:
            sigma_start = jnp.sum(args["sigma_y0"]) + jnp.sum(args["delta_sigma_y"])
        else:
            sigma_start = args["sigma_y0"] + args["k3"]
        if per_mode:
            # The initial stress is split evenly across the modes.
            leading = [jnp.full(n, sigma_start / n), jnp.zeros(n)]
        else:
            leading = [jnp.array([sigma_start, 0.0])]
        lambda_start = 0.5

    ode_fn = make_rhs(n)
    y0 = jnp.concatenate([*leading, jnp.full(n, lambda_start)])

    # Diffrax solve: adaptive Tsit5, output sampled at the requested times.
    t_start = t[0]
    t_end = t[-1]
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(ode_fn),
        diffrax.Tsit5(),
        t_start,
        t_end,
        (t_end - t_start) / max(len(t), 1000),
        y0,
        args=args,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
        max_steps=1_000_000,
        throw=False,
    )

    # Primary variable: strain (creep) and the global stress (weighted_sum)
    # both live in the first state component; per-mode stresses are summed.
    if mode == "creep" or not per_mode:
        response = sol.ys[:, 0]
    else:
        response = jnp.sum(sol.ys[:, :n], axis=1)

    # throw=False above means failures are reported via sol.result; turn
    # them into NaNs rather than returning a partial solution.
    response = jnp.where(
        sol.result == diffrax.RESULTS.successful,
        response,
        jnp.full_like(response, jnp.nan),
    )

    # Viscous contribution is added for startup only.
    if mode == "startup":
        eta_inf_val = params.get("eta_inf", 0.0)
        response = response + jnp.where(
            jnp.greater(eta_inf_val, 0.0),
            eta_inf_val * args["gamma_dot"],
            jnp.zeros_like(response),
        )
    return response
def _predict(self, X: ArrayLike, **kwargs) -> ArrayLike:
    """Predict the response for the requested protocol.

    The protocol is taken from ``kwargs['test_mode']``, then from the mode
    recorded by the last fit, and finally defaults to ``'startup'``.

    Args:
        X: Shear rates for ``flow_curve``; time (or time/strain) otherwise.
        **kwargs: ``test_mode`` plus protocol values such as ``gamma_dot``,
            ``sigma_applied`` and ``sigma_0``.

    Returns:
        Predicted stress or strain, depending on the protocol.
    """
    test_mode = kwargs.get("test_mode")
    if test_mode is None:
        test_mode = getattr(self, "_test_mode", None)
    if test_mode is None:
        test_mode = "startup"

    # Current parameter values keyed by name
    values = self.parameters.get_values()
    param_dict = dict(zip(self.parameters.keys(), values, strict=True))

    if test_mode == "flow_curve":
        return self._predict_flow_curve_from_params(jnp.asarray(X), param_dict)

    if test_mode in ("creep", "relaxation"):

        def protocol_value(name, fallback):
            # An explicit kwarg wins, then the value cached by the last ODE
            # fit, then the hard-coded fallback.
            return kwargs.get(name, getattr(self, f"_fit_{name}", fallback))

        return self._simulate_transient(
            jnp.asarray(X),
            param_dict,
            test_mode,
            gamma_dot=protocol_value("gamma_dot", float("nan")),
            sigma_applied=protocol_value("sigma_applied", 100.0),
            sigma_0=protocol_value("sigma_0", 100.0),
        )

    # startup, laos, oscillation: strain-driven return mapping
    times, strains = self._extract_time_strain(X, **kwargs)
    return self._predict_from_params(times, strains, param_dict)
def _fit(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
    """Fit the model with the routine that matches the protocol.

    The chosen protocol is stored on the instance so that later predictions
    can default to it.

    Args:
        X: Input data (shear rates, time array, or time/strain).
        y: Target data (stress or strain depending on protocol).
        **kwargs: Options including:
            - test_mode: Protocol ('flow_curve', 'startup', 'relaxation',
              'creep', 'oscillation', 'laos'); defaults to 'startup'
            - gamma_dot: Shear rate (for startup)
            - sigma_applied: Applied stress (for creep)
            - sigma_0: Initial stress (for relaxation)
    """
    test_mode = kwargs.get("test_mode", "startup")
    self._test_mode = test_mode

    if test_mode == "flow_curve":
        fitter = self._fit_flow_curve
    elif test_mode in ("creep", "relaxation"):
        fitter = self._fit_ode_formulation
    elif test_mode in ("oscillation", "saos"):
        fitter = self._fit_oscillation
    else:
        # 'startup', 'laos' and any unrecognised protocol use return mapping
        fitter = self._fit_return_mapping
    return fitter(X, y, **kwargs)
def _fit_flow_curve(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
    """Fit the parameters to steady-state flow-curve data (rate vs. stress)."""
    from rheojax.utils.optimization import nlsq_optimize

    shear_rates = jnp.asarray(X)
    measured_stress = jnp.asarray(y)

    def objective(param_values):
        # Residual vector: predicted minus measured stress
        named = dict(zip(self.parameters.keys(), param_values, strict=True))
        return (
            self._predict_flow_curve_from_params(shear_rates, named)
            - measured_stress
        )

    # Protocol keywords are not optimizer options; drop them.
    optimizer_kwargs = {
        key: val for key, val in kwargs.items() if key not in _MLIKH_RESERVED
    }
    nlsq_optimize(objective, self.parameters, **optimizer_kwargs)
    return self
def _fit_ode_formulation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
    """Fit creep or relaxation data through the Diffrax ODE simulation.

    The protocol values used for the fit are cached on the instance so that
    ``_predict`` and ``model_function`` can fall back to them later.
    """
    from rheojax.utils.optimization import nlsq_optimize

    times = jnp.asarray(X)
    observed = jnp.asarray(y)

    protocol = kwargs.get("test_mode", "relaxation")
    gamma_dot = kwargs.get("gamma_dot", 0.0)
    sigma_applied = kwargs.get("sigma_applied", 100.0)
    sigma_0 = kwargs.get("sigma_0", 100.0)

    # Cache protocol kwargs for model_function (NUTS reads these)
    self._fit_sigma_0 = sigma_0
    self._fit_sigma_applied = sigma_applied
    self._fit_gamma_dot = gamma_dot

    def objective(param_values):
        # Residual vector: simulated minus observed response
        named = dict(zip(self.parameters.keys(), param_values, strict=True))
        simulated = self._simulate_transient(
            times, named, protocol, gamma_dot, sigma_applied, sigma_0
        )
        return simulated - observed

    optimizer_kwargs = {
        key: val for key, val in kwargs.items() if key not in _MLIKH_RESERVED
    }
    # Force method="scipy": diffrax ODE solvers use custom_vjp which is
    # incompatible with NLSQ's forward-mode autodiff (jvp). This overrides
    # any method the caller passed.
    optimizer_kwargs["method"] = "scipy"
    nlsq_optimize(objective, self.parameters, **optimizer_kwargs)
    return self
def _fit_return_mapping(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
    """Fit strain-driven data (startup/LAOS) with the return-mapping kernels."""
    from rheojax.utils.optimization import nlsq_optimize

    times, strains = self._extract_time_strain(X, **kwargs)
    measured_stress = jnp.asarray(y)

    def objective(param_values):
        # Residual vector: predicted minus measured stress
        named = dict(zip(self.parameters.keys(), param_values, strict=True))
        return self._predict_from_params(times, strains, named) - measured_stress

    # Protocol keywords are not optimizer options; drop them.
    optimizer_kwargs = {
        key: val for key, val in kwargs.items() if key not in _MLIKH_RESERVED
    }
    nlsq_optimize(objective, self.parameters, **optimizer_kwargs)
    return self
def _fit_oscillation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> "MLIKH":
    """Fit to oscillation data (SAOS/MAOS).

    Supports two modes:
    1. Frequency-domain: X=omega, y=|G*| or complex G* (Maxwell analytical
       solution)
    2. Time-domain: X=time (or stacked time/strain), y=stress (return
       mapping with the supplied strain history)

    The domain is inferred from ``X``: multi-dimensional input (such as the
    stacked ``[time, strain]`` arrays built by ``predict_startup`` and
    ``predict_laos``) or more than 100 samples is treated as time-domain;
    anything else as a frequency sweep.
    """
    X_arr = jnp.asarray(X)

    # ``len(X_arr)`` would report 2 for stacked [time, strain] input and
    # raises on 0-d input, so look at the dimensionality and the total
    # sample count instead.
    is_time_domain = X_arr.ndim >= 2 or X_arr.size > 100

    if is_time_domain:
        return self._fit_return_mapping(X, y, **kwargs)
    return self._fit_saos_frequency_domain(X, y, **kwargs)
def _fit_saos_frequency_domain(
    self, X: ArrayLike, y: ArrayLike, **kwargs
) -> "MLIKH":
    """Fit to frequency-domain SAOS data using Maxwell analytical expressions.

    Fits G' and G'' independently when complex or (N, 2) input is provided.
    Falls back to magnitude-only fitting for any other input.

    Each Maxwell element uses a fixed viscosity of 1e12 (effectively
    infinite). In the per-mode formulation the contributions of all modes
    are summed; the weighted-sum formulation uses the single global ``G``.

    Args:
        X: Angular frequency array (omega)
        y: Complex G* = G' + iG'', (N, 2) array [G', G''], or real |G*|

    Returns:
        ``self``, with parameters updated by ``nlsq_optimize``.
    """
    from rheojax.utils.optimization import nlsq_optimize

    omega = jnp.asarray(X)

    # Always extract G' and G'' when available: magnitude-only fitting
    # discards the phase angle.
    y_arr = jnp.asarray(y)
    if jnp.iscomplexobj(y_arr):
        # Complex G* = G' + iG''
        targets = (jnp.real(y_arr), jnp.imag(y_arr))
    elif y_arr.ndim == 2 and y_arr.shape[1] == 2:
        # (N, 2) array in [G', G''] layout
        targets = (y_arr[:, 0], y_arr[:, 1])
    else:
        # Treated as magnitude |G*| (no phase information)
        targets = None

    eta_elastic = 1e12  # Effectively infinite for SAOS
    omega_eta = omega * eta_elastic

    def maxwell_moduli(modulus):
        """Return (G', G'') of one Maxwell element with viscosity 1e12.

        Algebraically identical to G*wt**2/(1+wt**2) and G*wt/(1+wt**2)
        with wt = omega*eta/G, but written without dividing by G. The
        per-mode moduli have a lower bound of 0.0, and dividing by a zero
        modulus turned the whole residual into NaN.
        """
        denom = modulus**2 + omega_eta**2
        storage = modulus * omega_eta**2 / denom
        loss = modulus**2 * omega_eta / denom
        return storage, loss

    def objective(param_values):
        """Compute residual using Maxwell analytical SAOS expressions."""
        p_dict = dict(zip(self.parameters.keys(), param_values, strict=True))

        if self._yield_mode == "per_mode":
            # Parallel Maxwell elements: contributions add up
            G_prime_total = jnp.zeros_like(omega_eta)
            G_double_prime_total = jnp.zeros_like(omega_eta)
            for i in range(1, self._n_modes + 1):
                storage_i, loss_i = maxwell_moduli(p_dict[f"G_{i}"])
                G_prime_total = G_prime_total + storage_i
                G_double_prime_total = G_double_prime_total + loss_i
        else:
            # Weighted-sum mode: single element with the global G
            G_prime_total, G_double_prime_total = maxwell_moduli(p_dict["G"])

        if targets is not None:
            # Fit G' and G'' independently (preserves phase information)
            return jnp.concatenate(
                [
                    G_prime_total - targets[0],
                    G_double_prime_total - targets[1],
                ]
            )

        # Magnitude-only fallback; the 1e-30 keeps sqrt differentiable at 0
        G_star_magnitude = jnp.sqrt(
            G_prime_total**2 + G_double_prime_total**2 + 1e-30
        )
        return G_star_magnitude - y_arr

    # Protocol keywords are not optimizer options; drop them.
    filtered = {k: v for k, v in kwargs.items() if k not in _MLIKH_RESERVED}
    nlsq_optimize(objective, self.parameters, **filtered)
    return self
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
    """NumPyro model function with protocol-aware dispatch.

    Accepts protocol-specific kwargs (gamma_dot, sigma_applied, sigma_0).
    Falls back to values cached during ``_fit_ode_formulation()`` if they
    are not provided.

    Args:
        X: Input data (shear rates, angular frequencies, or time/strain)
        params: Parameter dict, or a positional vector (NumPy/JAX array,
            list or tuple) ordered like ``self.parameters.keys()``
        test_mode: Optional protocol override. When ``None`` the mode
            recorded by the last fit is used, else ``'startup'``.
        **kwargs: Protocol-specific arguments

    Returns:
        Predicted response. For ``'oscillation'``/``'saos'`` this is an
        (N, 2) array with columns [G', G''].
    """
    # Convert positional parameter vectors to a name -> value dict
    if isinstance(params, (np.ndarray, jnp.ndarray, list, tuple)):
        param_dict = dict(zip(self.parameters.keys(), params, strict=True))
    else:
        param_dict = params

    if test_mode is None:
        test_mode = getattr(self, "_test_mode", None)
    if test_mode is None:
        test_mode = "startup"
    mode = test_mode

    if mode == "flow_curve":
        return self._predict_flow_curve_from_params(jnp.asarray(X), param_dict)

    if mode in ("creep", "relaxation"):
        # Explicit kwargs win over the values cached by the last ODE fit
        return self._simulate_transient(
            jnp.asarray(X),
            param_dict,
            mode,
            gamma_dot=kwargs.get(
                "gamma_dot", getattr(self, "_fit_gamma_dot", 1.0)
            ),
            sigma_applied=kwargs.get(
                "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
            ),
            sigma_0=kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 100.0)),
        )

    # _fit() accepts "saos" as an alias of "oscillation", so both names must
    # reach the frequency-domain branch here as well.
    if mode in ("oscillation", "saos"):
        omega = jnp.asarray(X)

        def maxwell_moduli(modulus, eta):
            # Same as G*wt**2/(1+wt**2) and G*wt/(1+wt**2) with
            # wt = omega*eta/G, written without dividing by G so that a
            # modulus of exactly 0 yields 0 instead of NaN.
            omega_eta = omega * eta
            denom = modulus**2 + omega_eta**2
            return modulus * omega_eta**2 / denom, modulus**2 * omega_eta / denom

        if self._yield_mode == "per_mode":
            # Parallel Maxwell elements: contributions add up
            G_prime_total = jnp.zeros_like(omega)
            G_double_prime_total = jnp.zeros_like(omega)
            for i in range(1, self._n_modes + 1):
                storage_i, loss_i = maxwell_moduli(
                    param_dict[f"G_{i}"],
                    # High viscosity if not specified
                    param_dict.get(f"eta_{i}", 1e12),
                )
                G_prime_total = G_prime_total + storage_i
                G_double_prime_total = G_double_prime_total + loss_i
        else:
            # Weighted-sum mode: single element with the global G
            G_prime_total, G_double_prime_total = maxwell_moduli(
                param_dict["G"], param_dict.get("eta", 1e12)
            )
        return jnp.column_stack([G_prime_total, G_double_prime_total])

    # startup, laos: strain history is derived from X and kwargs
    times, strains = self._extract_time_strain(X, **kwargs)
    return self._predict_from_params(times, strains, param_dict)
@property
def n_modes(self) -> int:
    """Return the number of structural modes the model was built with."""
    return self._n_modes
@property
def yield_mode(self) -> str:
    """Return the yield formulation: ``'per_mode'`` or ``'weighted_sum'``."""
    return self._yield_mode
# -------------------------------------------------------------------------
# Convenience Methods for Protocol-Specific Predictions
# -------------------------------------------------------------------------
[docs]
def predict_flow_curve(self, gamma_dot: ArrayLike) -> ArrayLike:
    """Return the steady-state stress for each shear rate.

    Args:
        gamma_dot: Array of shear rates.

    Returns:
        Array of steady-state stresses, computed with the current
        parameter values.
    """
    protocol = "flow_curve"
    return self._predict(gamma_dot, test_mode=protocol)
[docs]
def predict_startup(
    self, t: ArrayLike, gamma_dot: float = 1.0, strain: ArrayLike | None = None
) -> ArrayLike:
    """Predict startup shear response.

    Args:
        t: Time array
        gamma_dot: Applied shear rate (default: 1.0). Only used to build
            the strain history when ``strain`` is not given.
        strain: Optional strain array (if None, uses gamma_dot * t)

    Returns:
        Array of stresses
    """
    t_arr = jnp.asarray(t)
    if strain is None:
        strain_arr = gamma_dot * t_arr
    else:
        # jnp.stack rejects plain Python sequences, so convert explicitly;
        # otherwise a list-valued ``strain`` raised a TypeError.
        strain_arr = jnp.asarray(strain)
    return self._predict(jnp.stack([t_arr, strain_arr]), test_mode="startup")
[docs]
def predict_relaxation(self, t: ArrayLike, sigma_0: float = 100.0) -> ArrayLike:
    """Return the relaxing stress at the given times.

    Args:
        t: Time array.
        sigma_0: Initial stress (default: 100.0).

    Returns:
        Array of stresses from the relaxation simulation.
    """
    protocol = "relaxation"
    return self._predict(t, test_mode=protocol, sigma_0=sigma_0)
[docs]
def predict_creep(self, t: ArrayLike, sigma_applied: float = 50.0) -> ArrayLike:
    """Return the creep strain at the given times under constant stress.

    Args:
        t: Time array.
        sigma_applied: Applied constant stress (default: 50.0).

    Returns:
        Array of strains from the creep simulation.
    """
    protocol = "creep"
    return self._predict(t, test_mode=protocol, sigma_applied=sigma_applied)
[docs]
def predict_laos(
    self, t: ArrayLike, gamma_0: float = 1.0, omega: float = 1.0
) -> ArrayLike:
    """Return the stress response to a sinusoidal strain history.

    The imposed strain is ``gamma_0 * sin(omega * t)``.

    Args:
        t: Time array.
        gamma_0: Strain amplitude (default: 1.0).
        omega: Angular frequency in rad/s (default: 1.0).

    Returns:
        Array of stresses.
    """
    times = jnp.asarray(t)
    strain_wave = gamma_0 * jnp.sin(omega * times)
    history = jnp.stack([times, strain_wave])
    return self._predict(history, test_mode="laos")