"""STZ Conventional Model Implementation.
This module implements the concrete Shear Transformation Zone (STZ) model,
supporting multiple protocols (Flow, Transient, SAOS, LAOS) via JAX and Diffrax.
"""
from __future__ import annotations
from typing import Any, cast
import numpy as np
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.models.stz._base import STZBase, VariantType
from rheojax.models.stz._kernels import (
stz_creep_ode_rhs,
stz_ode_rhs,
)
# Obtain the jax module and jax.numpy through the project's safe_import_jax() helper.
jax, jnp = safe_import_jax()
# Module-level logger used by the fit routines below.
logger = get_logger(__name__)
# Sentinel meaning "keyword not supplied"; lets model_function tell a missing
# keyword apart from an explicitly passed falsy value such as gamma_dot=0.0.
_MISSING = object()
# Protocol/control keywords consumed by this model's fit routines. They are
# stripped from **kwargs before the remainder is forwarded to nlsq_optimize.
_STZ_RESERVED = {
"test_mode",
"gamma_dot",
"sigma_applied",
"sigma_0",
"gamma_0",
"omega",
"use_log_residuals",
"use_multi_start",
"n_starts",
"perturb_factor",
"deformation_mode",
"poisson_ratio",
}
[docs]
@ModelRegistry.register(
"stz_conventional",
protocols=[
Protocol.FLOW_CURVE,
Protocol.CREEP,
Protocol.RELAXATION,
Protocol.STARTUP,
Protocol.OSCILLATION,
Protocol.LAOS,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class STZConventional(STZBase):
"""Conventional Shear Transformation Zone (STZ) Model.
Implements STZ plasticity with Langer (2008) formulation.
Supports Minimal, Standard, and Full complexity variants.
Protocols:
- Steady-State Flow: Algebraic solution for flow curve
- Transient: Diffrax ODE integration for creep/relaxation/startup
- SAOS: Closed-form linearized (Maxwell-type) approximation of G' and G''
- LAOS: Diffrax ODE integration of the time-domain stress (+ FFT for harmonic analysis)
"""
[docs]
def __init__(self, variant: VariantType = "standard"):
    """Create an STZ conventional model.

    Args:
        variant: Model variant ('minimal', 'standard', 'full'); forwarded
            unchanged to the base-class initializer.
    """
    super().__init__(variant=variant)

    # Protocol selected by the most recent fit (None until then).
    self._test_mode: str | None = None

    # Inputs of the transient protocols: applied shear rate (startup),
    # applied stress (creep) and initial stress (relaxation).
    self._gamma_dot_applied: float | None = None
    self._sigma_applied: float | None = None
    self._sigma_0: float | None = None

    # Inputs of the LAOS protocol: strain amplitude and angular frequency.
    self._gamma_0: float | None = None
    self._omega_laos: float | None = None
def _fit(
self,
X: np.ndarray,
y: np.ndarray,
**kwargs,
) -> STZConventional:
"""Fit STZ model to data.
Args:
X: Independent variable (time, frequency, or shear rate)
y: Dependent variable (stress, modulus, viscosity)
**kwargs: Optimizer options. Must include 'test_mode'.
"""
test_mode = kwargs.get("test_mode")
if test_mode is None:
# Fallback for compatibility or explicit check
if hasattr(self, "_test_mode") and self._test_mode is not None:
test_mode = self._test_mode
else:
raise ValueError("test_mode must be specified for STZ fitting")
with log_fit(logger, model="STZConventional", data_shape=X.shape) as ctx:
self._test_mode = cast(str, test_mode)
ctx["test_mode"] = test_mode
ctx["variant"] = self.variant
if test_mode in ["steady_shear", "rotation", "flow_curve"]:
self._fit_steady_shear(X, y, **kwargs)
elif test_mode in ["relaxation", "creep", "startup"]:
self._fit_transient(X, y, mode=cast(str, test_mode), **kwargs)
elif test_mode in ["laos", "oscillation"]:
self._fit_oscillation(X, y, **kwargs)
else:
raise ValueError(f"Unsupported test_mode: {test_mode}")
self.fitted_ = True
return self
# =========================================================================
# Steady State Flow
# =========================================================================
def _fit_steady_shear(
    self, gamma_dot: np.ndarray, stress: np.ndarray, **kwargs
) -> None:
    """Fit the steady-state flow curve (stress versus shear rate).

    Args:
        gamma_dot: Shear rate array (1/s).
        stress: Shear stress array (Pa).
        **kwargs: Optimizer options. ``use_log_residuals`` (default True)
            is consumed here; every key not listed in ``_STZ_RESERVED``
            (e.g. max_iter, ftol, xtol, gtol) is forwarded to
            ``nlsq_optimize``.

    Raises:
        RuntimeError: If the optimizer reports failure.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    rates = jnp.asarray(gamma_dot, dtype=jnp.float64)
    stresses = jnp.asarray(stress, dtype=jnp.float64)

    # Order in which the analytical flow-curve kernel expects its parameters.
    kernel_params = ("sigma_y", "chi_inf", "tau0", "epsilon0", "ez")

    def model_fn(x_data, params):
        # Pair the flat optimizer vector with parameter names; values stay
        # JAX scalars so the kernel remains traceable.
        by_name = dict(zip(self.parameters.keys(), params, strict=True))
        kernel_args = [by_name[name] for name in kernel_params]
        return self._predict_steady_shear_jit(x_data, *kernel_args)

    objective = create_least_squares_objective(
        model_fn,
        rates,
        stresses,
        use_log_residuals=kwargs.get("use_log_residuals", True),
    )

    optimizer_options = {}
    for key, value in kwargs.items():
        if key not in _STZ_RESERVED:
            optimizer_options[key] = value

    result = nlsq_optimize(objective, self.parameters, **optimizer_options)
    if result.success:
        return
    raise RuntimeError(f"STZ steady shear fit failed: {result.message}")
@staticmethod
@jax.jit
def _predict_steady_shear_jit(gamma_dot, sigma_y, chi_inf, tau0, epsilon0, ez):
"""Analytical steady-state flow curve prediction.
At steady state (Langer 2008):
- chi -> chi_inf
- Lambda_ss = exp(-ez / chi_inf)
- gamma_dot = (2*epsilon0/tau0) * Lambda_ss * cosh(s/sy) * tanh(s/sy)
= (2*epsilon0/tau0) * Lambda_ss * sinh(s/sy)
- Inverting: sigma = sigma_y * arcsinh(gamma_dot * tau0 / (2*epsilon0*Lambda_ss))
"""
Lambda_ss = jnp.exp(-ez / chi_inf)
prefactor = 2.0 * epsilon0 * Lambda_ss + 1e-30
arg = (gamma_dot * tau0) / prefactor
sigma = sigma_y * jnp.arcsinh(arg)
return sigma
# =========================================================================
# Transient (ODE) - Startup, Relaxation, Creep
# =========================================================================
def _fit_transient(self, t: np.ndarray, y: np.ndarray, mode: str, **kwargs) -> None:
    """Fit a transient response (startup, relaxation or creep).

    Args:
        t: Time array (s).
        y: Response data (stress for startup/relaxation, strain for creep).
        mode: 'startup', 'relaxation', or 'creep'.
        **kwargs: Protocol inputs and optimizer options:
            - gamma_dot (float): Applied shear rate; required for startup.
            - sigma_applied (float): Applied stress; required for creep.
            - sigma_0 (float): Initial stress for relaxation (optional).
            - use_log_residuals (bool): Log-space fitting (default: False).
            - Keys not in ``_STZ_RESERVED`` are forwarded to ``nlsq_optimize``.

    Raises:
        ValueError: If the input required by ``mode`` is missing.

    An unsuccessful optimizer result is logged as a warning, not raised.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    times = jnp.asarray(t, dtype=jnp.float64)

    # Complex input keeps a complex dtype; everything else becomes float64.
    response = np.asarray(y)
    target_dtype = jnp.complex128 if np.iscomplexobj(response) else jnp.float64
    targets = jnp.asarray(response, dtype=target_dtype)

    # Read (never pop) the protocol inputs so the caller's kwargs stay intact.
    gamma_dot = kwargs.get("gamma_dot", None)
    sigma_applied = kwargs.get("sigma_applied", None)
    sigma_0 = kwargs.get("sigma_0", None)

    if mode == "startup" and gamma_dot is None:
        raise ValueError("startup mode requires gamma_dot in kwargs")
    if mode == "creep" and sigma_applied is None:
        raise ValueError("creep mode requires sigma_applied in kwargs")

    # Remember the inputs for later prediction and Bayesian evaluation.
    self._gamma_dot_applied = gamma_dot
    self._sigma_applied = sigma_applied
    self._sigma_0 = sigma_0

    def model_fn(x_data, params):
        # Name the flat parameter vector and integrate the ODE for this mode.
        by_name = dict(zip(self.parameters.keys(), params, strict=True))
        return self._simulate_transient_jit(
            x_data, by_name, mode, gamma_dot, sigma_applied, sigma_0, self.variant
        )

    objective = create_least_squares_objective(
        model_fn,
        times,
        targets,
        use_log_residuals=kwargs.get("use_log_residuals", False),
    )

    optimizer_options = {}
    for key, value in kwargs.items():
        if key not in _STZ_RESERVED:
            optimizer_options[key] = value

    result = nlsq_optimize(objective, self.parameters, **optimizer_options)
    if result.success:
        return
    logger.warning(f"STZ transient fit warning: {result.message}")
def _simulate_transient_jit(
self,
t: jnp.ndarray,
params: dict,
mode: str,
gamma_dot: float | None,
sigma_applied: float | None,
sigma_0: float | None,
variant: str,
) -> jnp.ndarray:
"""Simulate transient response using Diffrax ODE integration.
Args:
t: Time array
params: Parameter dictionary
mode: 'startup', 'relaxation', or 'creep'
gamma_dot: Applied shear rate (for startup)
sigma_applied: Applied stress (for creep)
sigma_0: Initial stress (for relaxation)
variant: Model variant
Returns:
Stress (for startup/relaxation) or strain (for creep)
"""
# R11-STZ-001: `variant` must remain a Python-level static dispatch key.
# DO NOT move it into the ODE args dict — strings are not valid JAX types
# and will crash under jax.checkpoint.
# Build args dict for stz_ode_rhs
args = {
"G0": params["G0"],
"sigma_y": params["sigma_y"],
"tau0": params["tau0"],
"epsilon0": params["epsilon0"],
"chi_inf": params["chi_inf"],
"c0": params["c0"],
"ez": params.get("ez", 1.0),
}
# Add variant-specific parameters
if variant in ["standard", "full"]:
args["tau_beta"] = params.get("tau_beta", params["tau0"] * 100)
if variant == "full":
args["m_inf"] = params.get("m_inf", 0.1)
args["rate_m"] = params.get("rate_m", 1.0)
# Set up initial conditions based on mode
chi_init = 0.05 # Annealed state
ez = params.get("ez", 1.0)
Lambda_init = jnp.exp(-ez / chi_init)
# Define ODE function and initial state based on mode
if mode == "creep":
# Creep: Constant stress, measure strain
# State vector: [strain, chi, Lambda, m] (strain replaces stress)
ode_fn = stz_creep_ode_rhs
sigma_app_safe = sigma_applied if sigma_applied is not None else 0.0
args["sigma_applied"] = sigma_app_safe
# Initial total strain γ(0+) = σ_applied / G0 (instantaneous elastic
# response to the step stress). The ODE evolves total strain because
# dγ_total/dt = γ̇_plastic under constant stress (dσ/dt = 0), so the
# elastic component must enter through the IC. Without this the
# prediction is missing γ_e ~ σ/G, which dominates the early-time
# signal for soft materials (low G0).
y0_val = sigma_app_safe / params["G0"]
# Initial state construction
if variant == "minimal":
y0 = jnp.array([y0_val, chi_init])
elif variant == "standard":
y0 = jnp.array([y0_val, chi_init, Lambda_init])
else: # full
y0 = jnp.array([y0_val, chi_init, Lambda_init, 0.0])
else:
# Startup/Relaxation: Controlled rate, measure stress
# State vector: [stress, chi, Lambda, m]
ode_fn = stz_ode_rhs
if mode == "startup":
# Strain-controlled: apply constant gamma_dot, measure stress
args["gamma_dot"] = gamma_dot
sigma_init = 0.0
else: # relaxation
# Strain-controlled: gamma_dot = 0, initial stress decays
args["gamma_dot"] = 0.0
sigma_init = sigma_0 if sigma_0 is not None else params["sigma_y"]
chi_init = params["chi_inf"] # Start at steady-state chi
Lambda_init = jnp.exp(-ez / chi_init)
# Initial state construction
if variant == "minimal":
y0 = jnp.array([sigma_init, chi_init])
elif variant == "standard":
y0 = jnp.array([sigma_init, chi_init, Lambda_init])
else: # full
y0 = jnp.array([sigma_init, chi_init, Lambda_init, 0.0])
# Set up Diffrax solver
# Wrap with checkpoint to reduce VJP memory during NUTS reverse-mode AD
def _rhs(ti, yi, args_i):
return ode_fn(cast(float, ti), yi, args_i)
term = diffrax.ODETerm(jax.checkpoint(_rhs))
solver = diffrax.Tsit5()
stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6)
t0 = t[0]
t1 = t[-1]
dt0 = (t1 - t0) / max(len(t), 1000)
saveat = diffrax.SaveAt(ts=t)
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
args=args,
saveat=saveat,
stepsize_controller=stepsize_controller,
max_steps=10_000_000,
throw=False,
)
# Extract primary variable (index 0)
# For creep, this is strain. For others, this is stress.
result = sol.ys[:, 0]
# Handle solver failures
result = jnp.where(
sol.result == diffrax.RESULTS.successful,
result,
jnp.nan * jnp.ones_like(result),
)
return result
def _predict_transient(self, t: np.ndarray, mode: str | None = None) -> np.ndarray:
"""Predict transient response."""
t_jax = jnp.asarray(t, dtype=jnp.float64)
p_values = {k: self.parameters.get_value(k) for k in self.parameters.keys()}
mode = mode if mode is not None else self._test_mode
if mode is None:
raise ValueError("Test mode not specified for prediction")
result = self._simulate_transient_jit(
t_jax,
p_values,
mode,
self._gamma_dot_applied,
self._sigma_applied,
self._sigma_0,
self.variant,
)
return np.array(result)
# =========================================================================
# SAOS / LAOS (ODE + FFT)
# =========================================================================
def _fit_oscillation(self, X: np.ndarray, y: np.ndarray, **kwargs) -> None:
"""Fit oscillation data (SAOS or LAOS).
Routes to specific fitting method based on strain amplitude `gamma_0`.
If `gamma_0 > 0.01` (1%), uses LAOS mode (full ODE). Otherwise uses SAOS
mode (linear approximation).
Args:
X: Frequency array (rad/s) for SAOS, or time array for LAOS.
y: Complex modulus [G', G''] for SAOS, or stress for LAOS.
**kwargs: Protocol parameters:
- gamma_0 (float): Strain amplitude (optional, triggers LAOS if > 0.01).
- omega (float): Angular frequency (required if gamma_0 provided).
- use_log_residuals (bool): Log-space fitting (default varies).
"""
gamma_0 = kwargs.get("gamma_0", None)
omega = kwargs.get("omega", None)
# Store for prediction
self._gamma_0 = gamma_0
self._omega_laos = omega
if gamma_0 is not None and gamma_0 > 0.01:
# LAOS mode - full ODE integration
self._fit_laos_mode(X, y, gamma_0, omega, **kwargs)
else:
# SAOS mode - linear viscoelastic approximation
self._fit_saos_mode(X, y, **kwargs)
def _fit_saos_mode(self, omega: np.ndarray, G_star: np.ndarray, **kwargs) -> None:
    """Fit SAOS data with the linearized (Maxwell-like) approximation.

    Args:
        omega: Angular frequency array (rad/s).
        G_star: Complex modulus, either a complex array or an (M, 2) real
            array with columns [G', G''].
        **kwargs: Optimizer options; keys not in ``_STZ_RESERVED`` (e.g.
            max_iter) are forwarded to ``nlsq_optimize``. The objective is
            always built with ``normalize=True``.

    Raises:
        ValueError: If ``G_star`` is neither complex nor of shape (M, 2).

    An unsuccessful optimizer result is logged as a warning, not raised.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    frequencies = jnp.asarray(omega, dtype=jnp.float64)

    # Bring the modulus data into a real (M, 2) layout: [G', G''].
    modulus = np.asarray(G_star)
    if np.iscomplexobj(modulus):
        modulus_2d = np.column_stack([np.real(modulus), np.imag(modulus)])
    elif modulus.ndim == 2 and modulus.shape[1] == 2:
        modulus_2d = modulus
    else:
        raise ValueError(f"G_star must be complex or (M, 2), got {modulus.shape}")
    targets = jnp.asarray(modulus_2d, dtype=jnp.float64)

    def model_fn(x_data, params):
        by_name = dict(zip(self.parameters.keys(), params, strict=True))
        required = [
            by_name[name]
            for name in ("G0", "sigma_y", "chi_inf", "tau0", "epsilon0")
        ]
        # `ez` is optional here and defaults to 1.0 when not a parameter.
        return self._predict_saos_jit(x_data, *required, by_name.get("ez", 1.0))

    objective = create_least_squares_objective(
        model_fn,
        frequencies,
        targets,
        normalize=True,
    )

    optimizer_options = {}
    for key, value in kwargs.items():
        if key not in _STZ_RESERVED:
            optimizer_options[key] = value

    result = nlsq_optimize(objective, self.parameters, **optimizer_options)
    if result.success:
        return
    logger.warning(f"STZ SAOS fit warning: {result.message}")
@staticmethod
@jax.jit
def _predict_saos_jit(omega, G0, sigma_y, chi_inf, tau0, epsilon0, ez):
"""SAOS prediction using linear viscoelastic approximation.
In the linear limit (small strain), the STZ plastic rate linearizes as:
gamma_dot_pl ≈ (2*epsilon0/tau0) * Lambda_ss * (sigma / sigma_y)
Combined with ds/dt = G0*(gamma_dot - gamma_dot_pl), this gives a
Maxwell model with effective relaxation time:
tau_M = tau0 * sigma_y / (2 * epsilon0 * Lambda_ss * G0)
"""
# At steady state chi -> chi_inf
Lambda_ss = jnp.exp(-ez / chi_inf)
# Effective Maxwell relaxation time (Langer 2008, linearized)
tau_eff = (tau0 * sigma_y) / (2.0 * epsilon0 * Lambda_ss * G0 + 1e-30)
# Maxwell model: G* = G0 * (i * omega * tau) / (1 + i * omega * tau)
omega_tau = omega * tau_eff
denom = 1.0 + omega_tau**2
G_prime = G0 * omega_tau**2 / denom
G_double_prime = G0 * omega_tau / denom
return jnp.stack([G_prime, G_double_prime], axis=1)
def _fit_laos_mode(
    self,
    t: np.ndarray,
    sigma: np.ndarray,
    gamma_0: float,
    omega: float,
    **kwargs,
) -> None:
    """Fit LAOS time-series data by integrating the full STZ ODE.

    Args:
        t: Time array (s).
        sigma: Stress response array (Pa).
        gamma_0: Strain amplitude.
        omega: Angular frequency (rad/s).
        **kwargs: Optimizer options; keys not in ``_STZ_RESERVED`` (e.g.
            max_iter) are forwarded to ``nlsq_optimize``. The objective is
            always built with ``normalize=True``.

    An unsuccessful optimizer result is logged as a warning, not raised.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    times = jnp.asarray(t, dtype=jnp.float64)
    measured_stress = jnp.asarray(sigma, dtype=jnp.float64)

    def model_fn(x_data, params):
        by_name = dict(zip(self.parameters.keys(), params, strict=True))
        # The simulator returns (strain, stress); only stress is fitted.
        simulated = self._simulate_laos_internal(
            x_data, by_name, gamma_0, omega, self.variant
        )
        return simulated[1]

    objective = create_least_squares_objective(
        model_fn,
        times,
        measured_stress,
        normalize=True,
    )

    optimizer_options = {}
    for key, value in kwargs.items():
        if key not in _STZ_RESERVED:
            optimizer_options[key] = value

    result = nlsq_optimize(objective, self.parameters, **optimizer_options)
    if result.success:
        return
    logger.warning(f"STZ LAOS fit warning: {result.message}")
def _simulate_laos_internal(
    self,
    t: jnp.ndarray,
    params: dict,
    gamma_0: float,
    omega: float,
    variant: str,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate the LAOS stress response with Diffrax.

    The imposed strain is gamma(t) = gamma_0 * sin(omega * t), so the rate
    fed to the ODE is gamma_dot(t) = gamma_0 * omega * cos(omega * t).

    Args:
        t: Time array; the solution is saved at exactly these times.
        params: Parameter dictionary (name -> value).
        gamma_0: Strain amplitude.
        omega: Angular frequency.
        variant: Model variant; selects the state-vector length.

    Returns:
        (strain, stress) arrays. Stress is all-NaN when the solver does not
        report success.
    """
    ez = params.get("ez", 1.0)

    # Time-independent ODE arguments; gamma_dot is injected per evaluation.
    static_args = {
        "G0": params["G0"],
        "sigma_y": params["sigma_y"],
        "tau0": params["tau0"],
        "epsilon0": params["epsilon0"],
        "chi_inf": params["chi_inf"],
        "c0": params["c0"],
        "ez": ez,
    }
    if variant in ["standard", "full"]:
        static_args["tau_beta"] = params.get("tau_beta", params["tau0"] * 100)
    if variant == "full":
        static_args["m_inf"] = params.get("m_inf", 0.1)
        static_args["rate_m"] = params.get("rate_m", 1.0)

    # Initial state: zero stress, chi at its steady-state value.
    chi_start = params["chi_inf"]
    lambda_start = jnp.exp(-ez / chi_start)
    state = [0.0, chi_start]
    if variant != "minimal":
        state.append(lambda_start)
        if variant != "standard":
            # Any other variant carries a fourth state entry starting at 0.
            state.append(0.0)
    y0 = jnp.array(state)

    def laos_ode(ti, yi, args_i):
        # stz_ode_rhs reads gamma_dot from its args, so evaluate the
        # analytic strain rate at the current time and merge it in.
        rate_now = gamma_0 * omega * jnp.cos(omega * ti)
        return stz_ode_rhs(ti, yi, {**args_i, "gamma_dot": rate_now})

    # Wrap with checkpoint to reduce VJP memory during NUTS reverse-mode AD.
    term = diffrax.ODETerm(jax.checkpoint(laos_ode))

    t_start = t[0]
    t_end = t[-1]
    first_step = (t_end - t_start) / max(len(t), 1000)

    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t_start,
        t_end,
        first_step,
        y0,
        args=static_args,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
        max_steps=10_000_000,
        throw=False,
    )

    # Stress is the first state component; mask failed solves with NaN.
    raw_stress = sol.ys[:, 0]
    solved = sol.result == diffrax.RESULTS.successful
    stress = jnp.where(solved, raw_stress, jnp.full_like(raw_stress, jnp.nan))

    strain = gamma_0 * jnp.sin(omega * t)
    return strain, stress
[docs]
def simulate_laos(
self,
gamma_0: float,
omega: float,
n_cycles: int = 2,
n_points_per_cycle: int = 256,
) -> tuple[np.ndarray, np.ndarray]:
"""Simulate LAOS response.
Args:
gamma_0: Strain amplitude
omega: Angular frequency (rad/s)
n_cycles: Number of oscillation cycles
n_points_per_cycle: Points per cycle
Returns:
(strain, stress) arrays
"""
self._gamma_0 = gamma_0
self._omega_laos = omega
period = 2.0 * np.pi / omega
t_max = n_cycles * period
n_points = n_cycles * n_points_per_cycle
t = np.linspace(0, t_max, n_points, endpoint=False)
t_jax = jnp.asarray(t, dtype=jnp.float64)
p_values = {k: self.parameters.get_value(k) for k in self.parameters.keys()}
strain, stress = self._simulate_laos_internal(
t_jax, p_values, gamma_0, omega, self.variant
)
return np.array(strain), np.array(stress)
# =========================================================================
# Bayesian Mixin Interface
# =========================================================================
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
"""NumPyro/BayesianMixin model function.
Routes to appropriate prediction based on test_mode.
"""
p_values = dict(zip(self.parameters.keys(), params, strict=True))
# Ensure we have a valid mode
mode = test_mode if test_mode is not None else getattr(self, "_test_mode", None)
if mode is None:
raise ValueError(
"test_mode must be set before calling model_function. "
"Call fit() first or pass test_mode explicitly."
)
X_jax = jnp.asarray(X, dtype=jnp.float64)
if mode in ["steady_shear", "rotation", "flow_curve"]:
return self._predict_steady_shear_jit(
X_jax,
p_values["sigma_y"],
p_values["chi_inf"],
p_values["tau0"],
p_values["epsilon0"],
p_values["ez"],
)
elif mode == "oscillation":
return self._predict_saos_jit(
X_jax,
p_values["G0"],
p_values["sigma_y"],
p_values["chi_inf"],
p_values["tau0"],
p_values["epsilon0"],
p_values.get("ez", 1.0),
)
elif mode in ["startup", "relaxation", "creep"]:
# Use sentinel to avoid swallowing falsy values (e.g. gamma_dot=0.0)
_gd = kwargs.get("gamma_dot", _MISSING)
gamma_dot = (
_gd
if _gd is not _MISSING
else getattr(self, "_gamma_dot_applied", None)
)
_sig = kwargs.get("sigma", _MISSING)
if _sig is _MISSING:
_sig = kwargs.get("sigma_applied", _MISSING)
sigma = (
_sig if _sig is not _MISSING else getattr(self, "_sigma_applied", None)
)
_s0 = kwargs.get("sigma_0", _MISSING)
sigma_0 = _s0 if _s0 is not _MISSING else getattr(self, "_sigma_0", None)
return self._simulate_transient_jit(
X_jax,
p_values,
mode,
gamma_dot,
sigma,
sigma_0,
self.variant,
)
elif mode == "laos":
_g0 = kwargs.get("gamma_0", _MISSING)
gamma_0 = _g0 if _g0 is not _MISSING else getattr(self, "_gamma_0", None)
_ol = kwargs.get("omega", _MISSING)
if _ol is _MISSING:
_ol = kwargs.get("omega_laos", _MISSING)
omega_laos = (
_ol if _ol is not _MISSING else getattr(self, "_omega_laos", None)
)
if gamma_0 is None or omega_laos is None:
raise ValueError("LAOS mode requires gamma_0 and omega")
_, stress = self._simulate_laos_internal(
X_jax, p_values, gamma_0, omega_laos, self.variant
)
return stress
raise ValueError(f"Unsupported test_mode for model_function: {mode}")
# =========================================================================
# Prediction Interface
# =========================================================================
def _predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
"""Predict based on fitted state."""
X_jax = jnp.asarray(X, dtype=jnp.float64)
p_values = {k: self.parameters.get_value(k) for k in self.parameters.keys()}
# Extract transient parameters from kwargs if provided (for direct predict without fit)
if self._test_mode in ["startup", "relaxation", "creep"]:
if self._gamma_dot_applied is None:
self._gamma_dot_applied = kwargs.get("gamma_dot")
if self._sigma_applied is None:
self._sigma_applied = kwargs.get("sigma_applied")
if self._sigma_0 is None:
self._sigma_0 = kwargs.get("sigma_0")
if self._test_mode in ["steady_shear", "rotation", "flow_curve"]:
result = self._predict_steady_shear_jit(
X_jax,
p_values["sigma_y"],
p_values["chi_inf"],
p_values["tau0"],
p_values["epsilon0"],
p_values["ez"],
)
return np.array(result)
elif self._test_mode == "oscillation":
result = self._predict_saos_jit(
X_jax,
p_values["G0"],
p_values["sigma_y"],
p_values["chi_inf"],
p_values["tau0"],
p_values["epsilon0"],
p_values.get("ez", 1.0),
)
# Convert (N,2) [G', G''] to complex G* for consistent API
result = np.array(result)
return result[:, 0] + 1j * result[:, 1]
elif self._test_mode in ["startup", "relaxation", "creep"]:
return self._predict_transient(X)
elif self._test_mode == "laos":
# Extract LAOS parameters from kwargs if provided
if self._gamma_0 is None:
self._gamma_0 = kwargs.get("gamma_0")
if self._omega_laos is None:
self._omega_laos = kwargs.get("omega")
if self._gamma_0 is None or self._omega_laos is None:
raise ValueError("LAOS prediction requires gamma_0 and omega")
_, stress = self._simulate_laos_internal(
X_jax, p_values, self._gamma_0, self._omega_laos, self.variant
)
return np.array(stress)
return np.zeros_like(X)