Source code for rheojax.models.stz.conventional

"""STZ Conventional Model Implementation.

This module implements the concrete Shear Transformation Zone (STZ) model,
supporting multiple protocols (Flow, Transient, SAOS, LAOS) via JAX and Diffrax.
"""

from __future__ import annotations

from typing import Any, cast

import numpy as np

from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax

diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.models.stz._base import STZBase, VariantType
from rheojax.models.stz._kernels import (
    stz_creep_ode_rhs,
    stz_ode_rhs,
)

# JAX and jax.numpy are obtained through the project's guarded import helper.
jax, jnp = safe_import_jax()

# Module-level logger.
logger = get_logger(__name__)

# Sentinel used to tell "keyword not supplied" apart from supplied falsy values
# such as gamma_dot=0.0 (see model_function).
_MISSING = object()

# Keyword arguments consumed by this module's fitting code. They are stripped
# from the caller's kwargs before the remainder is forwarded to nlsq_optimize.
_STZ_RESERVED = set(
    """
    test_mode gamma_dot sigma_applied sigma_0 gamma_0 omega
    use_log_residuals use_multi_start n_starts perturb_factor
    deformation_mode poisson_ratio
    """.split()
)


[docs] @ModelRegistry.register( "stz_conventional", protocols=[ Protocol.FLOW_CURVE, Protocol.CREEP, Protocol.RELAXATION, Protocol.STARTUP, Protocol.OSCILLATION, Protocol.LAOS, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class STZConventional(STZBase): """Conventional Shear Transformation Zone (STZ) Model. Implements STZ plasticity with Langer (2008) formulation. Supports Minimal, Standard, and Full complexity variants. Protocols: - Steady-State Flow: Algebraic solution for flow curve - Transient: Diffrax ODE integration for creep/relaxation/startup - SAOS/LAOS: Diffrax ODE integration + FFT for harmonic analysis """
def __init__(self, variant: VariantType = "standard"):
    """Create an STZ Conventional model.

    Args:
        variant: Complexity variant, one of 'minimal', 'standard', 'full'.
    """
    super().__init__(variant=variant)

    # Protocol recorded by fit(); used by predict() and model_function().
    self._test_mode: str | None = None

    # LAOS inputs: strain amplitude and angular frequency.
    self._gamma_0: float | None = None
    self._omega_laos: float | None = None

    # Transient inputs: startup shear rate, creep stress, and the initial
    # stress for relaxation.
    self._gamma_dot_applied: float | None = None
    self._sigma_applied: float | None = None
    self._sigma_0: float | None = None
def _fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    **kwargs,
) -> STZConventional:
    """Fit the STZ model to data, dispatching on the test mode.

    Args:
        X: Independent variable (time, frequency, or shear rate).
        y: Dependent variable (stress, strain, or modulus).
        **kwargs: Protocol inputs and optimizer options. ``test_mode`` is
            required unless a previous fit already recorded one.

    Returns:
        self, with ``fitted_`` set to True.

    Raises:
        ValueError: If no test mode is available or it is unsupported.
    """
    test_mode = kwargs.get("test_mode")
    if test_mode is None:
        # Fall back to the mode recorded by a previous fit, if any.
        test_mode = getattr(self, "_test_mode", None)
        if test_mode is None:
            raise ValueError("test_mode must be specified for STZ fitting")

    with log_fit(logger, model="STZConventional", data_shape=X.shape) as ctx:
        self._test_mode = cast(str, test_mode)
        ctx["test_mode"] = test_mode
        ctx["variant"] = self.variant

        if test_mode in ("steady_shear", "rotation", "flow_curve"):
            self._fit_steady_shear(X, y, **kwargs)
        elif test_mode in ("relaxation", "creep", "startup"):
            self._fit_transient(X, y, mode=cast(str, test_mode), **kwargs)
        elif test_mode in ("laos", "oscillation"):
            self._fit_oscillation(X, y, **kwargs)
        else:
            raise ValueError(f"Unsupported test_mode: {test_mode}")

    self.fitted_ = True
    return self

# =========================================================================
# Steady State Flow
# =========================================================================

def _fit_steady_shear(
    self, gamma_dot: np.ndarray, stress: np.ndarray, **kwargs
) -> None:
    """Fit a steady-state flow curve (stress vs shear rate).

    Args:
        gamma_dot: Shear rate array (1/s).
        stress: Shear stress array (Pa).
        **kwargs: Optimizer options:
            - use_log_residuals (bool): Fit in log space (default: True).
            - max_iter (int): Maximum optimization iterations.
            - ftol (float): Function tolerance.
            - xtol (float): Parameter tolerance.
            - gtol (float): Gradient tolerance.

    Raises:
        RuntimeError: If the optimizer reports failure.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)
    stress_jax = jnp.asarray(stress, dtype=jnp.float64)

    def model_fn(x_data, params):
        p_map = dict(zip(self.parameters.keys(), params, strict=True))
        # `ez` defaults to 1.0 when absent, as in the ODE and SAOS paths.
        return self._predict_steady_shear_jit(
            x_data,
            p_map["sigma_y"],
            p_map["chi_inf"],
            p_map["tau0"],
            p_map["epsilon0"],
            p_map.get("ez", 1.0),
        )

    objective = create_least_squares_objective(
        model_fn,
        gamma_dot_jax,
        stress_jax,
        use_log_residuals=kwargs.get("use_log_residuals", True),
    )

    filtered = {k: v for k, v in kwargs.items() if k not in _STZ_RESERVED}
    result = nlsq_optimize(objective, self.parameters, **filtered)

    if not result.success:
        raise RuntimeError(f"STZ steady shear fit failed: {result.message}")

@staticmethod
@jax.jit
def _predict_steady_shear_jit(gamma_dot, sigma_y, chi_inf, tau0, epsilon0, ez):
    """Analytical steady-state flow curve prediction.

    At steady state (Langer 2008):
        - chi -> chi_inf
        - Lambda_ss = exp(-ez / chi_inf)
        - gamma_dot = (2*epsilon0/tau0) * Lambda_ss * cosh(s/sy) * tanh(s/sy)
                    = (2*epsilon0/tau0) * Lambda_ss * sinh(s/sy)
        - Inverting: sigma = sigma_y * arcsinh(gamma_dot * tau0 / (2*epsilon0*Lambda_ss))
    """
    Lambda_ss = jnp.exp(-ez / chi_inf)
    # Tiny offset guards against division by zero when Lambda_ss underflows.
    prefactor = 2.0 * epsilon0 * Lambda_ss + 1e-30
    arg = (gamma_dot * tau0) / prefactor
    return sigma_y * jnp.arcsinh(arg)

# =========================================================================
# Shared ODE helpers (transient + LAOS)
# =========================================================================

@staticmethod
def _conv_ode_args(params: dict, variant: str) -> dict:
    """Build the variant-dependent parameter dict passed to the STZ ODE kernels.

    ``variant`` is only used for Python-level branching here and is never
    stored in the returned dict (see R11-STZ-001 below).
    """
    args = {
        "G0": params["G0"],
        "sigma_y": params["sigma_y"],
        "tau0": params["tau0"],
        "epsilon0": params["epsilon0"],
        "chi_inf": params["chi_inf"],
        "c0": params["c0"],
        "ez": params.get("ez", 1.0),
    }
    if variant in ("standard", "full"):
        args["tau_beta"] = params.get("tau_beta", params["tau0"] * 100)
    if variant == "full":
        args["m_inf"] = params.get("m_inf", 0.1)
        args["rate_m"] = params.get("rate_m", 1.0)
    return args

@staticmethod
def _conv_initial_state(primary, chi_init, Lambda_init, variant: str) -> jnp.ndarray:
    """Assemble y0 = [primary, chi, Lambda, m] truncated to the variant's size.

    ``primary`` is stress for rate-controlled protocols and strain for creep;
    ``m`` starts at 0.0 for the full variant.
    """
    if variant == "minimal":
        return jnp.array([primary, chi_init])
    if variant == "standard":
        return jnp.array([primary, chi_init, Lambda_init])
    return jnp.array([primary, chi_init, Lambda_init, 0.0])

@staticmethod
def _conv_integrate_primary(rhs, t: jnp.ndarray, y0: jnp.ndarray, args: dict) -> jnp.ndarray:
    """Integrate ``rhs`` over ``t`` with Diffrax and return state component 0.

    Points are replaced by NaN when the solver does not report success
    (``throw=False`` so failures do not raise inside optimization/NUTS).
    """
    # Checkpoint the RHS to reduce VJP memory during NUTS reverse-mode AD.
    term = diffrax.ODETerm(jax.checkpoint(rhs))
    t0 = t[0]
    t1 = t[-1]
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0,
        t1,
        (t1 - t0) / max(len(t), 1000),
        y0,
        args=args,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
        max_steps=10_000_000,
        throw=False,
    )
    primary = sol.ys[:, 0]
    return jnp.where(
        sol.result == diffrax.RESULTS.successful,
        primary,
        jnp.nan * jnp.ones_like(primary),
    )

# =========================================================================
# Transient (ODE) - Startup, Relaxation, Creep
# =========================================================================

def _fit_transient(self, t: np.ndarray, y: np.ndarray, mode: str, **kwargs) -> None:
    """Fit a transient response (stress growth / relaxation / creep).

    Args:
        t: Time array (s).
        y: Response data (stress for startup/relaxation, strain for creep).
        mode: 'startup', 'relaxation', or 'creep'.
        **kwargs: Protocol-specific inputs and optimizer options:
            - gamma_dot (float): Applied shear rate for startup (required).
            - sigma_0 (float): Initial stress for relaxation (optional).
            - sigma_applied (float): Applied stress for creep (required).
            - use_log_residuals (bool): Log-space fitting (default: False).
            - max_iter (int): Maximum optimization iterations.

    Raises:
        ValueError: If a required protocol input is missing.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    # Keep complex input complex instead of dropping the imaginary part.
    y_arr = np.asarray(y)
    if np.iscomplexobj(y_arr):
        y_jax = jnp.asarray(y_arr, dtype=jnp.complex128)
    else:
        y_jax = jnp.asarray(y_arr, dtype=jnp.float64)

    # Read protocol inputs with .get() so the caller's dict is not mutated.
    gamma_dot = kwargs.get("gamma_dot", None)
    sigma_applied = kwargs.get("sigma_applied", None)
    sigma_0 = kwargs.get("sigma_0", None)

    if mode == "startup" and gamma_dot is None:
        raise ValueError("startup mode requires gamma_dot in kwargs")
    if mode == "creep" and sigma_applied is None:
        raise ValueError("creep mode requires sigma_applied in kwargs")

    # Stored for predict() and model_function() (NUTS).
    self._gamma_dot_applied = gamma_dot
    self._sigma_applied = sigma_applied
    self._sigma_0 = sigma_0

    def model_fn(x_data, params):
        p_map = dict(zip(self.parameters.keys(), params, strict=True))
        return self._simulate_transient_jit(
            x_data, p_map, mode, gamma_dot, sigma_applied, sigma_0, self.variant
        )

    objective = create_least_squares_objective(
        model_fn,
        t_jax,
        y_jax,
        use_log_residuals=kwargs.get("use_log_residuals", False),
    )

    filtered = {k: v for k, v in kwargs.items() if k not in _STZ_RESERVED}
    result = nlsq_optimize(objective, self.parameters, **filtered)

    if not result.success:
        logger.warning(f"STZ transient fit warning: {result.message}")

def _simulate_transient_jit(
    self,
    t: jnp.ndarray,
    params: dict,
    mode: str,
    gamma_dot: float | None,
    sigma_applied: float | None,
    sigma_0: float | None,
    variant: str,
) -> jnp.ndarray:
    """Simulate a transient response using Diffrax ODE integration.

    Args:
        t: Time array (integration runs from t[0] to t[-1]).
        params: Parameter dictionary.
        mode: 'startup', 'relaxation', or 'creep'.
        gamma_dot: Applied shear rate (startup).
        sigma_applied: Applied stress (creep; 0.0 if None).
        sigma_0: Initial stress (relaxation; defaults to sigma_y).
        variant: Model variant.

    Returns:
        Stress (startup/relaxation) or strain (creep) at each t, NaN where
        the solver failed.
    """
    # R11-STZ-001: `variant` must remain a Python-level static dispatch key.
    # DO NOT move it into the ODE args dict — strings are not valid JAX types
    # and will crash under jax.checkpoint.
    args = self._conv_ode_args(params, variant)

    ez = params.get("ez", 1.0)
    chi_init = 0.05  # Annealed initial state
    Lambda_init = jnp.exp(-ez / chi_init)

    if mode == "creep":
        # Stress-controlled: state is [strain, chi, Lambda, m].
        ode_fn = stz_creep_ode_rhs
        args["sigma_applied"] = sigma_applied if sigma_applied is not None else 0.0
        primary_init = 0.0  # Strain starts at 0
    else:
        # Rate-controlled: state is [stress, chi, Lambda, m].
        ode_fn = stz_ode_rhs
        if mode == "startup":
            # Constant applied shear rate from zero stress.
            args["gamma_dot"] = gamma_dot
            primary_init = 0.0
        else:  # relaxation
            # gamma_dot = 0; the initial stress decays.
            args["gamma_dot"] = 0.0
            primary_init = sigma_0 if sigma_0 is not None else params["sigma_y"]
            chi_init = params["chi_inf"]  # Start at steady-state chi
            Lambda_init = jnp.exp(-ez / chi_init)

    y0 = self._conv_initial_state(primary_init, chi_init, Lambda_init, variant)

    def _rhs(ti, yi, args_i):
        return ode_fn(cast(float, ti), yi, args_i)

    return self._conv_integrate_primary(_rhs, t, y0, args)

def _predict_transient(self, t: np.ndarray, mode: str | None = None) -> np.ndarray:
    """Predict a transient response with the stored protocol inputs.

    Raises:
        ValueError: If no mode is given and none was recorded by fit().
    """
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    p_values = {k: self.parameters.get_value(k) for k in self.parameters.keys()}

    mode = mode if mode is not None else self._test_mode
    if mode is None:
        raise ValueError("Test mode not specified for prediction")

    result = self._simulate_transient_jit(
        t_jax,
        p_values,
        mode,
        self._gamma_dot_applied,
        self._sigma_applied,
        self._sigma_0,
        self.variant,
    )
    return np.array(result)

# =========================================================================
# SAOS / LAOS (ODE + FFT)
# =========================================================================

def _fit_oscillation(self, X: np.ndarray, y: np.ndarray, **kwargs) -> None:
    """Fit oscillation data (SAOS or LAOS).

    LAOS (full ODE in the time domain) is used when the recorded test mode
    is ``"laos"`` or when ``gamma_0 > 0.01`` (1%). Otherwise the SAOS
    linear-viscoelastic approximation is used. After a LAOS fit the recorded
    test mode is ``"laos"``, so prediction simulates the time-domain stress.

    Args:
        X: Frequency array (rad/s) for SAOS, or time array for LAOS.
        y: Complex modulus / [G', G''] for SAOS, or stress for LAOS.
        **kwargs: Protocol parameters:
            - gamma_0 (float): Strain amplitude (triggers LAOS if > 0.01).
            - omega (float): Angular frequency (required for LAOS).
            - normalize (bool): Normalize residuals (default: True).

    Raises:
        ValueError: If LAOS is selected but gamma_0 or omega is missing.
    """
    gamma_0 = kwargs.get("gamma_0", None)
    omega = kwargs.get("omega", None)

    # Stored for prediction.
    self._gamma_0 = gamma_0
    self._omega_laos = omega

    use_laos = getattr(self, "_test_mode", None) == "laos" or (
        gamma_0 is not None and gamma_0 > 0.01
    )
    if not use_laos:
        # SAOS mode - linear viscoelastic approximation
        self._fit_saos_mode(X, y, **kwargs)
        return

    if gamma_0 is None or omega is None:
        raise ValueError("LAOS fitting requires gamma_0 and omega in kwargs")

    # Record the effective protocol so predict()/model_function() simulate
    # LAOS on this time-domain data instead of evaluating the SAOS spectrum.
    self._test_mode = "laos"

    # gamma_0/omega are passed positionally; forwarding them again through
    # **kwargs would raise "got multiple values for argument 'gamma_0'".
    rest = {k: v for k, v in kwargs.items() if k not in ("gamma_0", "omega")}
    self._fit_laos_mode(X, y, gamma_0, omega, **rest)

def _fit_saos_mode(self, omega: np.ndarray, G_star: np.ndarray, **kwargs) -> None:
    """Fit SAOS data using the linear viscoelastic approximation.

    In the SAOS limit STZ behaves like a Maxwell-like viscoelastic fluid with
    G*(omega) derived from steady-state chi and Lambda.

    Args:
        omega: Angular frequency array (rad/s).
        G_star: Complex modulus data (complex array or [N, 2] array).
        **kwargs: Optimizer options:
            - normalize (bool): Normalize residuals (default: True).
            - max_iter (int): Maximum optimization iterations.

    Raises:
        ValueError: If G_star is neither complex nor shaped (M, 2).
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    omega_jax = jnp.asarray(omega, dtype=jnp.float64)

    G_star_np = np.asarray(G_star)
    if np.iscomplexobj(G_star_np):
        G_star_2d = np.column_stack([np.real(G_star_np), np.imag(G_star_np)])
    elif G_star_np.ndim == 2 and G_star_np.shape[1] == 2:
        G_star_2d = G_star_np
    else:
        raise ValueError(f"G_star must be complex or (M, 2), got {G_star_np.shape}")

    G_star_jax = jnp.asarray(G_star_2d, dtype=jnp.float64)

    def model_fn(x_data, params):
        p_map = dict(zip(self.parameters.keys(), params, strict=True))
        return self._predict_saos_jit(
            x_data,
            p_map["G0"],
            p_map["sigma_y"],
            p_map["chi_inf"],
            p_map["tau0"],
            p_map["epsilon0"],
            p_map.get("ez", 1.0),
        )

    objective = create_least_squares_objective(
        model_fn,
        omega_jax,
        G_star_jax,
        normalize=kwargs.get("normalize", True),
    )

    # `normalize` configures the objective above; do not forward it.
    filtered = {
        k: v
        for k, v in kwargs.items()
        if k not in _STZ_RESERVED and k != "normalize"
    }
    result = nlsq_optimize(objective, self.parameters, **filtered)

    if not result.success:
        logger.warning(f"STZ SAOS fit warning: {result.message}")

@staticmethod
@jax.jit
def _predict_saos_jit(omega, G0, sigma_y, chi_inf, tau0, epsilon0, ez):
    """SAOS prediction using the linear viscoelastic approximation.

    In the linear limit (small strain) the STZ plastic rate linearizes as:
        gamma_dot_pl ≈ (2*epsilon0/tau0) * Lambda_ss * (sigma / sigma_y)

    Combined with ds/dt = G0*(gamma_dot - gamma_dot_pl) this is a Maxwell
    model with relaxation time:
        tau_M = tau0 * sigma_y / (2 * epsilon0 * Lambda_ss * G0)

    Returns:
        (N, 2) array of [G', G''].
    """
    # At steady state chi -> chi_inf
    Lambda_ss = jnp.exp(-ez / chi_inf)

    # Effective Maxwell relaxation time (Langer 2008, linearized)
    tau_eff = (tau0 * sigma_y) / (2.0 * epsilon0 * Lambda_ss * G0 + 1e-30)

    # Maxwell model: G* = G0 * (i * omega * tau) / (1 + i * omega * tau)
    omega_tau = omega * tau_eff
    denom = 1.0 + omega_tau**2
    G_prime = G0 * omega_tau**2 / denom
    G_double_prime = G0 * omega_tau / denom

    return jnp.stack([G_prime, G_double_prime], axis=1)

def _fit_laos_mode(
    self,
    t: np.ndarray,
    sigma: np.ndarray,
    gamma_0: float,
    omega: float,
    **kwargs,
) -> None:
    """Fit LAOS data using full ODE integration.

    Args:
        t: Time array (s).
        sigma: Stress response array (Pa).
        gamma_0: Strain amplitude.
        omega: Angular frequency (rad/s).
        **kwargs: Optimizer options:
            - normalize (bool): Normalize residuals (default: True).
            - max_iter (int): Maximum optimization iterations.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    sigma_jax = jnp.asarray(sigma, dtype=jnp.float64)

    def model_fn(x_data, params):
        p_map = dict(zip(self.parameters.keys(), params, strict=True))
        _, stress = self._simulate_laos_internal(
            x_data, p_map, gamma_0, omega, self.variant
        )
        return stress

    objective = create_least_squares_objective(
        model_fn,
        t_jax,
        sigma_jax,
        normalize=kwargs.get("normalize", True),
    )

    # `normalize` configures the objective above; do not forward it.
    filtered = {
        k: v
        for k, v in kwargs.items()
        if k not in _STZ_RESERVED and k != "normalize"
    }
    result = nlsq_optimize(objective, self.parameters, **filtered)

    if not result.success:
        logger.warning(f"STZ LAOS fit warning: {result.message}")

def _simulate_laos_internal(
    self,
    t: jnp.ndarray,
    params: dict,
    gamma_0: float,
    omega: float,
    variant: str,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate a LAOS response using Diffrax.

    Imposed strain gamma(t) = gamma_0 * sin(omega * t), i.e. strain rate
    gamma_dot(t) = gamma_0 * omega * cos(omega * t), injected into the ODE
    args at each RHS evaluation. Integration starts from zero stress at
    steady-state chi.

    Args:
        t: Time array.
        params: Parameter dictionary.
        gamma_0: Strain amplitude.
        omega: Angular frequency.
        variant: Model variant.

    Returns:
        (strain, stress) arrays; stress is NaN where the solver failed.
    """
    base_args = self._conv_ode_args(params, variant)

    ez = params.get("ez", 1.0)
    chi_init = params["chi_inf"]  # Start at steady state for LAOS
    Lambda_init = jnp.exp(-ez / chi_init)
    y0 = self._conv_initial_state(0.0, chi_init, Lambda_init, variant)

    def laos_ode(ti, yi, args_i):
        gamma_dot_t = gamma_0 * omega * jnp.cos(omega * ti)
        return stz_ode_rhs(ti, yi, {**args_i, "gamma_dot": gamma_dot_t})

    stress = self._conv_integrate_primary(laos_ode, t, y0, base_args)
    strain = gamma_0 * jnp.sin(omega * t)
    return strain, stress
def simulate_laos(
    self,
    gamma_0: float,
    omega: float,
    n_cycles: int = 2,
    n_points_per_cycle: int = 256,
) -> tuple[np.ndarray, np.ndarray]:
    """Simulate the LAOS response with the current parameter values.

    The imposed strain gamma(t) = gamma_0 * sin(omega * t) is sampled at
    ``n_cycles * n_points_per_cycle`` evenly spaced times over
    ``[0, n_cycles * 2*pi/omega)`` (endpoint excluded).

    Args:
        gamma_0: Strain amplitude.
        omega: Angular frequency (rad/s); must be positive.
        n_cycles: Number of oscillation cycles (>= 1).
        n_points_per_cycle: Points per cycle (>= 1).

    Returns:
        (strain, stress) NumPy arrays.

    Raises:
        ValueError: If omega is not positive, or n_cycles / n_points_per_cycle
            is smaller than 1.
    """
    # Validate before touching model state: omega <= 0 would divide by zero
    # or produce a reversed time grid, and an empty grid cannot be integrated.
    if not omega > 0:
        raise ValueError(f"omega must be positive, got {omega}")
    if n_cycles < 1 or n_points_per_cycle < 1:
        raise ValueError(
            "n_cycles and n_points_per_cycle must be >= 1, "
            f"got {n_cycles} and {n_points_per_cycle}"
        )

    # Remember the protocol inputs (read back by LAOS predict()).
    self._gamma_0 = gamma_0
    self._omega_laos = omega

    period = 2.0 * np.pi / omega
    t_max = n_cycles * period
    n_points = n_cycles * n_points_per_cycle
    t = np.linspace(0, t_max, n_points, endpoint=False)
    t_jax = jnp.asarray(t, dtype=jnp.float64)

    p_values = {k: self.parameters.get_value(k) for k in self.parameters.keys()}

    strain, stress = self._simulate_laos_internal(
        t_jax, p_values, gamma_0, omega, self.variant
    )
    return np.array(strain), np.array(stress)
def extract_harmonics(
    self,
    stress: np.ndarray,
    n_points_per_cycle: int = 256,
) -> dict:
    """Return Fourier harmonic amplitudes of the last cycle of a LAOS stress.

    The last ``n_points_per_cycle`` samples are transformed with an FFT; the
    amplitude of harmonic k is ``2 * |FFT[k]| / n``. Harmonics 3 and 5 are
    reported as 0.0 when k is not below ``n // 2``.

    Args:
        stress: Stress array (e.g. from simulate_laos).
        n_points_per_cycle: Points per cycle.

    Returns:
        Dict with keys "I_1", "I_3", "I_5" (amplitudes) and "I_3_I_1",
        "I_5_I_1" (ratios to the fundamental; 0.0 unless I_1 > 0).
    """
    window = stress[-n_points_per_cycle:]
    n_samples = len(window)
    spectrum = np.fft.fft(window)

    def _amplitude(order: int):
        return 2.0 * np.abs(spectrum[order]) / n_samples

    fundamental = _amplitude(1)
    third = _amplitude(3) if 3 < n_samples // 2 else 0.0
    fifth = _amplitude(5) if 5 < n_samples // 2 else 0.0

    if fundamental > 0:
        ratio_3 = third / fundamental
        ratio_5 = fifth / fundamental
    else:
        ratio_3 = ratio_5 = 0.0

    return {
        "I_1": fundamental,
        "I_3": third,
        "I_5": fifth,
        "I_3_I_1": ratio_3,
        "I_5_I_1": ratio_5,
    }
# ========================================================================= # Bayesian Mixin Interface # =========================================================================
def model_function(self, X, params, test_mode=None, **kwargs):
    """NumPyro/BayesianMixin model function.

    Evaluates the model for a flat parameter vector under the requested
    protocol.

    Args:
        X: Independent variable for the protocol.
        params: Parameter values ordered like ``self.parameters.keys()``.
        test_mode: Protocol; defaults to the mode recorded by ``fit()``.
        **kwargs: Protocol inputs that override stored values, even when
            falsy (e.g. gamma_dot=0.0): ``gamma_dot``, ``sigma`` /
            ``sigma_applied``, ``sigma_0`` (transient) and ``gamma_0``,
            ``omega`` / ``omega_laos`` (LAOS).

    Returns:
        The model prediction as a JAX array ((N, 2) [G', G''] for
        oscillation).

    Raises:
        ValueError: If no test mode is available, LAOS inputs are missing,
            or the mode is unsupported.
    """
    p_values = dict(zip(self.parameters.keys(), params, strict=True))

    # Ensure we have a valid mode
    mode = test_mode if test_mode is not None else getattr(self, "_test_mode", None)
    if mode is None:
        raise ValueError(
            "test_mode must be set before calling model_function. "
            "Call fit() first or pass test_mode explicitly."
        )

    def _resolve(keys, attr):
        # First explicitly supplied kwarg among `keys` wins (the sentinel keeps
        # falsy values such as 0.0 or None); otherwise the stored attribute.
        for key in keys:
            value = kwargs.get(key, _MISSING)
            if value is not _MISSING:
                return value
        return getattr(self, attr, None)

    X_jax = jnp.asarray(X, dtype=jnp.float64)
    # Default ez to 1.0 when absent, as the SAOS and ODE paths do.
    ez = p_values.get("ez", 1.0)

    if mode in ("steady_shear", "rotation", "flow_curve"):
        return self._predict_steady_shear_jit(
            X_jax,
            p_values["sigma_y"],
            p_values["chi_inf"],
            p_values["tau0"],
            p_values["epsilon0"],
            ez,
        )

    if mode == "oscillation":
        return self._predict_saos_jit(
            X_jax,
            p_values["G0"],
            p_values["sigma_y"],
            p_values["chi_inf"],
            p_values["tau0"],
            p_values["epsilon0"],
            ez,
        )

    if mode in ("startup", "relaxation", "creep"):
        return self._simulate_transient_jit(
            X_jax,
            p_values,
            mode,
            _resolve(("gamma_dot",), "_gamma_dot_applied"),
            _resolve(("sigma", "sigma_applied"), "_sigma_applied"),
            _resolve(("sigma_0",), "_sigma_0"),
            self.variant,
        )

    if mode == "laos":
        gamma_0 = _resolve(("gamma_0",), "_gamma_0")
        omega_laos = _resolve(("omega", "omega_laos"), "_omega_laos")
        if gamma_0 is None or omega_laos is None:
            raise ValueError("LAOS mode requires gamma_0 and omega")
        _, stress = self._simulate_laos_internal(
            X_jax, p_values, gamma_0, omega_laos, self.variant
        )
        return stress

    raise ValueError(f"Unsupported test_mode for model_function: {mode}")
# =========================================================================
# Prediction Interface
# =========================================================================
def _predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
    """Predict the response at X using the protocol recorded by ``fit()``.

    Args:
        X: Independent variable for the active protocol (shear rate,
            frequency, or time).
        **kwargs: Optional protocol inputs, used only when the model has no
            stored value (e.g. predicting without a prior fit); supplied
            values are stored on the model. Transient: ``gamma_dot``,
            ``sigma_applied``, ``sigma_0``. LAOS: ``gamma_0``, ``omega``.
            ``test_mode`` selects the protocol when none was recorded.

    Returns:
        Stress for flow-curve/startup/relaxation/LAOS, strain for creep,
        complex G* = G' + iG'' for oscillation. An unsupported or missing
        mode logs a warning and returns zeros shaped like X.

    Raises:
        ValueError: If LAOS prediction lacks gamma_0 or omega.
    """
    mode = self._test_mode
    if mode is None:
        mode = kwargs.get("test_mode")

    X_jax = jnp.asarray(X, dtype=jnp.float64)
    p_values = {k: self.parameters.get_value(k) for k in self.parameters.keys()}
    # Default ez to 1.0 when absent, as the SAOS and ODE paths do.
    ez = p_values.get("ez", 1.0)

    if mode in ("steady_shear", "rotation", "flow_curve"):
        result = self._predict_steady_shear_jit(
            X_jax,
            p_values["sigma_y"],
            p_values["chi_inf"],
            p_values["tau0"],
            p_values["epsilon0"],
            ez,
        )
        return np.array(result)

    if mode == "oscillation":
        result = np.array(
            self._predict_saos_jit(
                X_jax,
                p_values["G0"],
                p_values["sigma_y"],
                p_values["chi_inf"],
                p_values["tau0"],
                p_values["epsilon0"],
                ez,
            )
        )
        # Convert (N, 2) [G', G''] to complex G* for a consistent API.
        return result[:, 0] + 1j * result[:, 1]

    if mode in ("startup", "relaxation", "creep"):
        # Fill protocol inputs from kwargs only when no fit has stored them.
        if self._gamma_dot_applied is None:
            self._gamma_dot_applied = kwargs.get("gamma_dot")
        if self._sigma_applied is None:
            self._sigma_applied = kwargs.get("sigma_applied")
        if self._sigma_0 is None:
            self._sigma_0 = kwargs.get("sigma_0")
        return self._predict_transient(X, mode=mode)

    if mode == "laos":
        # Fill LAOS inputs from kwargs only when no fit has stored them.
        if self._gamma_0 is None:
            self._gamma_0 = kwargs.get("gamma_0")
        if self._omega_laos is None:
            self._omega_laos = kwargs.get("omega")
        if self._gamma_0 is None or self._omega_laos is None:
            raise ValueError("LAOS prediction requires gamma_0 and omega")
        _, stress = self._simulate_laos_internal(
            X_jax, p_values, self._gamma_0, self._omega_laos, self.variant
        )
        return np.array(stress)

    # Best-effort fallback kept for compatibility, but no longer silent.
    logger.warning(
        f"STZConventional: unsupported or missing test_mode {mode!r}; "
        "returning zeros"
    )
    return np.zeros_like(X)