"""Generalized Maxwell Model (Prony series) for multi-mode viscoelastic relaxation.
The Generalized Maxwell Model (GMM) extends the single Maxwell element to N modes,
providing a flexible framework for capturing complex relaxation spectra:
E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ)
Key features:
- Tri-mode consistency: relaxation, oscillation, and creep predictions from a single Prony parameter set
- Two-step NLSQ fitting with softmax penalty for physical constraints
- Transparent element minimization (auto-optimize N)
- Bayesian inference via NumPyro NUTS with warm-start
- Tiered Bayesian prior safety mechanism (fail-fast on bad NLSQ convergence)
- JIT-compiled predictions for GPU acceleration
References:
- Park, S. W., & Schapery, R. A. (1999). Methods of interconversion between
linear viscoelastic material functions. Part I—A numerical method based on
Prony series. International Journal of Solids and Structures, 36(11), 1653-1675.
- pyvisco: https://github.com/saintsfan342000/pyvisco
"""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, cast
import nlsq
import numpy as np
from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
# Lazy import diffrax for transient simulations (deferred to avoid ~250ms startup cost)
from rheojax.core.jax_config import lazy_import as _lazy_import
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.optimization import OptimizationResult
from rheojax.utils.prony import (
compute_r_squared,
create_prony_parameter_set,
select_optimal_n,
softmax_penalty,
)
# Lazily imported diffrax handle (used for transient simulations); deferring
# the import avoids roughly 250 ms of startup cost.
diffrax = _lazy_import("diffrax")
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
# Static type checkers see the real jax.numpy under this alias; at runtime the
# alias simply points at NumPy so no extra import happens here.
if TYPE_CHECKING:  # pragma: no cover
    import jax.numpy as jnp_typing
else:
    jnp_typing = np
# Module logger
logger = get_logger(__name__)
[docs]
@ModelRegistry.register(
"generalized_maxwell",
protocols=[
Protocol.RELAXATION,
Protocol.CREEP,
Protocol.OSCILLATION,
Protocol.FLOW_CURVE,
Protocol.STARTUP,
Protocol.LAOS,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class GeneralizedMaxwell(BaseModel):
"""Generalized Maxwell Model with N exponential relaxation modes.
The GMM uses Prony series representation for tri-mode viscoelastic behavior:
**Relaxation mode:**
E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ)
**Oscillation mode (closed-form Fourier transform):**
E'(ω) = E_∞ + Σᵢ Eᵢ (ωτᵢ)²/(1+(ωτᵢ)²)
E"(ω) = Σᵢ Eᵢ (ωτᵢ)/(1+(ωτᵢ)²)
**Creep mode (numerical simulation):**
J(t) = ε(t)/σ₀ via backward-Euler integration
**Performance Optimization (v0.4.0+):**
Element minimization workflows use warm-start optimization for 2-5x speedup:
- Successive fits initialized from optimal N+1 parameters
- Compilation reuse across n_modes iterations
- Early termination when R² degrades below threshold
- Transparent optimization (no API changes required)
- Typical speedup: 20-50s → 4-25s for N=10 element search
Parameters:
n_modes: Number of relaxation modes (N)
modulus_type: 'shear' (G) or 'tensile' (E)
Attributes:
parameters: ParameterSet containing E_inf, E_i, tau_i (or G equivalents)
Example:
>>> from rheojax.models.generalized_maxwell import GeneralizedMaxwell
>>> import numpy as np
>>> model = GeneralizedMaxwell(n_modes=3, modulus_type='shear')
>>> t = np.logspace(-3, 2, 50)
>>> G_data = ... # Relaxation modulus data
>>> model.fit(t, G_data, test_mode='relaxation', optimization_factor=1.5)
>>> G_pred = model.predict(t)
>>> # Element minimization automatically uses warm-start for 2-5x speedup
>>> print(f"Optimal modes: {model._n_modes}") # Auto-reduced from 3
"""
[docs]
def __init__(self, n_modes: int = 3, modulus_type: str = "shear"):
    """Create a Generalized Maxwell model with ``n_modes`` Prony terms.

    Args:
        n_modes: Number of exponential relaxation modes; must be at least 1.
        modulus_type: ``'shear'`` (parameters named with G, the default) or
            ``'tensile'`` (parameters named with E).

    Raises:
        ValueError: If ``n_modes`` is below 1 or ``modulus_type`` is not one
            of the supported values.
    """
    super().__init__()

    # Guard clauses: validate configuration before building any state.
    if n_modes < 1:
        raise ValueError(f"n_modes must be ≥ 1, got {n_modes}")
    supported_types = ("shear", "tensile")
    if modulus_type not in supported_types:
        raise ValueError(
            f"modulus_type must be 'shear' or 'tensile', got '{modulus_type}'"
        )

    # Model structure.
    self._n_modes = n_modes
    self._modulus_type = modulus_type
    # Test mode of the most recent fit; None until a fit assigns it.
    self._test_mode: TestMode | str | None = None

    # Prony parameter set sized for the requested number of modes.
    self.parameters = create_prony_parameter_set(n_modes, modulus_type)

    # Last NLSQ result, kept for warm-starts and diagnostics.
    self._nlsq_result: OptimizationResult | None = None
    # Diagnostics produced by element minimization (None until it runs).
    self._element_minimization_diagnostics: dict[str, object] | None = None
def _fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    test_mode: str | None = None,
    optimization_factor: float | None = 1.5,
    **kwargs,
) -> None:
    """Fit the GMM by routing to the fitting routine for ``test_mode``.

    Args:
        X: Independent variable (time or frequency).
        y: Dependent variable (modulus or compliance).
        test_mode: One of ``'relaxation'``, ``'oscillation'``, ``'creep'``,
            ``'steady_shear'``, ``'startup'`` or ``'laos'``.
        optimization_factor: R² threshold multiplier forwarded to element
            minimization (None to disable).
        **kwargs: Optimizer arguments forwarded to the selected routine
            (e.g. max_iter, ftol, xtol, gtol).

    Raises:
        ValueError: If ``test_mode`` is not provided or not recognized.
        Exception: Any error from the selected routine is logged and
            re-raised unchanged.
    """
    if test_mode is None:
        missing_msg = "test_mode must be specified for GMM fitting"
        logger.error(missing_msg)
        raise ValueError(missing_msg)
    self._test_mode = test_mode

    # Ordered (mode, routine name) pairs. Modes are compared with ``==`` in
    # this order, and the routine is looked up only once a mode matches.
    routes = (
        ("relaxation", "_fit_relaxation_mode"),
        ("oscillation", "_fit_oscillation_mode"),
        ("creep", "_fit_creep_mode"),
        ("steady_shear", "_fit_steady_shear_mode"),
        ("startup", "_fit_startup_mode"),
        ("laos", "_fit_laos_mode"),
    )

    with log_fit(
        logger,
        self.__class__.__name__,
        data_shape=X.shape,
        test_mode=test_mode,
        n_modes=self._n_modes,
        modulus_type=self._modulus_type,
    ) as ctx:
        x_range = (float(X.min()), float(X.max()))
        y_real = np.real(y)
        y_range = (float(y_real.min()), float(y_real.max()))
        logger.debug(
            "Processing GMM input data",
            x_range=x_range,
            y_range=y_range,
            optimization_factor=optimization_factor,
        )

        try:
            routine_name = next(
                (name for mode, name in routes if test_mode == mode), None
            )
            if routine_name is None:
                logger.error("Unknown test_mode", test_mode=test_mode)
                raise ValueError(f"Unknown test_mode: {test_mode}")
            fit_routine = getattr(self, routine_name)
            fit_routine(X, y, optimization_factor=optimization_factor, **kwargs)
        except Exception as err:
            logger.error(
                "GMM fitting failed",
                error_type=type(err).__name__,
                error_message=str(err),
                exc_info=True,
            )
            raise

        # Record the outcome in the fit-logging context.
        symbol = "E" if self._modulus_type == "tensile" else "G"
        modulus_inf = self.parameters.get_value(f"{symbol}_inf")
        ctx["n_modes_final"] = self._n_modes
        ctx[f"{symbol}_inf"] = modulus_inf
        logger.debug(
            "GMM fitting completed",
            n_modes_final=self._n_modes,
            modulus_inf=modulus_inf,
        )
def _nlsq_fit(
    self,
    objective,
    x0,
    bounds,
    max_nfev=1000,
    ftol=1e-6,
    xtol=1e-6,
    gtol=1e-6,
    y_data=None,
) -> OptimizationResult:
    """Run a bounded TRF least-squares fit and wrap it as an OptimizationResult.

    Shared by all fitting modes so results are reported consistently.

    Args:
        objective: Residual function mapping a parameter vector to residuals.
        x0: Initial parameter guess.
        bounds: (lower, upper) parameter bounds.
        max_nfev: Maximum function evaluations.
        ftol: Function tolerance.
        xtol: Parameter tolerance.
        gtol: Gradient tolerance.
        y_data: Optional raw dependent-variable array. When provided, it is
            attached to the result so ``r_squared`` computes correctly.
            If absent, ``getattr(objective, "_y_data", None)`` is tried, then
            ``self._current_y_data`` (stashed by the ``_fit_*_mode`` caller).

    Returns:
        OptimizationResult with fitted parameters and diagnostics. Its
        ``residuals`` are whatever ``objective`` returns at the optimum,
        including any appended penalty entries, with complex residuals split
        into [real, imag]. They are None if that evaluation fails.

    Raises:
        RuntimeError: If the optimizer raises ValueError (e.g. infeasible
            initial guess).
    """
    logger.debug(
        "Starting NLSQ optimization",
        n_params=len(x0),
        max_nfev=max_nfev,
        ftol=ftol,
        xtol=xtol,
        gtol=gtol,
    )
    ls = nlsq.LeastSquares()
    try:
        nlsq_result = ls.least_squares(
            objective,
            x0=np.asarray(x0),
            bounds=bounds,
            method="trf",
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
            max_nfev=max_nfev,
            verbose=0,
        )
    except ValueError as e:
        # Handle infeasible initial guess
        logger.error(
            "NLSQ optimization failed with ValueError",
            error_message=str(e),
            exc_info=True,
        )
        raise RuntimeError(
            f"NLSQ optimization failed with error: {e}\n"
            "This may indicate:\n"
            "  1. Data is unsuitable for GMM fitting (e.g., constant values)\n"
            "  2. Initial parameter guess is outside bounds\n"
            "  3. Too many modes for the available data"
        ) from e
    # OPT-YDATA-001: compute residuals at optimum and attach y_data so
    # ``r_squared`` works. Without this the GMM custom path leaves
    # residuals=None and y_data=None, masking fit success as r_squared=None.
    from rheojax.utils.optimization import attach_y_data_to_result

    try:
        _final_res = np.asarray(objective(nlsq_result.x))
        if np.iscomplexobj(_final_res):
            _final_res = np.concatenate(
                [np.real(_final_res), np.imag(_final_res)]
            )
        _final_res = _final_res.astype(np.float64)
    except Exception as exc:  # pragma: no cover - defensive
        # Best-effort only: never discard a successful fit because the
        # diagnostic residual evaluation failed, but record why it did.
        logger.debug(
            "Could not evaluate residuals at NLSQ optimum",
            error_type=type(exc).__name__,
            error_message=str(exc),
        )
        _final_res = None

    # Optional solver attributes may be missing depending on the NLSQ result.
    grad = getattr(nlsq_result, "grad", None)
    result = OptimizationResult(
        x=np.asarray(nlsq_result.x),
        fun=nlsq_result.cost,
        jac=np.asarray(nlsq_result.jac) if nlsq_result.jac is not None else None,
        success=nlsq_result.success,
        message=nlsq_result.message,
        # NOTE: the iteration count is reported as the evaluation count here.
        nit=nlsq_result.nfev,
        nfev=nlsq_result.nfev,
        njev=getattr(nlsq_result, "njev", 0),
        optimality=getattr(nlsq_result, "optimality", None),
        active_mask=getattr(nlsq_result, "active_mask", None),
        cost=nlsq_result.cost,
        grad=np.asarray(grad) if grad is not None else None,
        nlsq_result=nlsq_result,
        residuals=_final_res,
    )
    # Prefer the explicit y_data argument; fall back to one stashed on
    # the objective itself by the caller, then on self (set by the
    # _fit_*_mode method that originated the call).
    _y = y_data if y_data is not None else getattr(objective, "_y_data", None)
    if _y is None:
        _y = getattr(self, "_current_y_data", None)
    attach_y_data_to_result(result, _y)
    logger.debug(
        "NLSQ optimization completed",
        success=result.success,
        cost=result.cost,
        nfev=result.nfev,
        message=result.message,
    )
    return result
def _fit_relaxation_mode(
    self,
    t: np.ndarray,
    E_t: np.ndarray,
    optimization_factor: float | None = 1.5,
    initial_params: np.ndarray | None = None,
    **kwargs,
) -> None:
    """Fit GMM to relaxation modulus data.

    Step 1 fits all parameters with a softmax penalty on the moduli inside
    data-derived bounds. When no ``initial_params`` is supplied and
    ``n_modes >= 2``, up to four perturbed restarts are also run and the
    lowest-cost result is kept. Step 2 refits without the penalty only if a
    negative modulus remains. Element minimization runs afterwards when
    ``optimization_factor`` is not None and ``n_modes > 1``.

    Args:
        t: Time array
        E_t: Relaxation modulus array
        optimization_factor: R² threshold multiplier for element
            minimization (None disables it)
        initial_params: Optional initial parameter guess for warm-start.
            Shape: (2*n_modes + 1,) [E_inf, E_1...E_N, tau_1...tau_N].
            If None, uses default heuristic initialization.
        **kwargs: NLSQ optimizer arguments (max_iter, ftol, xtol, gtol) plus
            ``use_log_residuals`` to fit in log10 space.
    """
    # OPT-YDATA-001: stash y_data so _nlsq_fit can attach it to the
    # OptimizationResult and ``r_squared`` works on _nlsq_result.
    self._current_y_data = np.asarray(E_t)
    # Extract kwargs
    max_iter = kwargs.get("max_iter", 1000)
    ftol = kwargs.get("ftol", 1e-6)
    xtol = kwargs.get("xtol", 1e-6)
    gtol = kwargs.get("gtol", 1e-6)
    use_log_residuals = kwargs.get("use_log_residuals", False)
    symbol = "E" if self._modulus_type == "tensile" else "G"
    # Precompute the log-space observation once, and only when log residuals
    # were requested (the objective reads it only in that case).
    _log_E_t = (
        jnp.log10(jnp.maximum(jnp.asarray(E_t), 1e-30))
        if use_log_residuals
        else None
    )

    # Define objective function
    def objective(params):
        """Residual for relaxation modulus.

        Uses log-space residuals when ``use_log_residuals`` is set, so that
        master curves spanning many decades in E(t) weight every decade
        equally instead of being dominated by the glassy plateau.
        """
        E_inf = params[0]
        E_i = params[1 : 1 + self._n_modes]
        tau_i = params[1 + self._n_modes :]
        # Predict relaxation modulus
        E_pred = self._predict_relaxation_jit(jnp.asarray(t), E_inf, E_i, tau_i)
        if use_log_residuals:
            return jnp.log10(jnp.maximum(E_pred, 1e-30)) - _log_E_t
        return E_pred - E_t

    # Derive tau range from the actual time data so master curves spanning
    # many decades are fittable. Previously hardcoded to logspace(-2, 2),
    # which silently truncated any t-range outside [0.01, 100] s.
    t_np = np.asarray(t)
    t_pos = t_np[t_np > 0]
    if t_pos.size > 0:
        log_t_lo = float(np.log10(t_pos.min()))
        log_t_hi = float(np.log10(t_pos.max()))
    else:
        log_t_lo, log_t_hi = -2.0, 2.0
    # Pad by two decades on each side and clamp to a safe numerical floor.
    tau_lo_bound = max(10.0 ** (log_t_lo - 2.0), 1e-30)
    tau_hi_bound = 10.0 ** (log_t_hi + 2.0)
    if self._n_modes == 1:
        tau_guess_arr = jnp.array([10.0 ** (0.5 * (log_t_lo + log_t_hi))])
    else:
        tau_guess_arr = jnp.logspace(log_t_lo, log_t_hi, self._n_modes)
    # Heuristic E_inf / E_i guesses. They seed x0 when no initial_params is
    # given, and they are the base point of the multi-start retries below
    # (which also only run when initial_params is None).
    E_inf_guess = jnp.min(E_t)
    E_sum_guess = jnp.max(E_t) - E_inf_guess
    # Derivative-based initial E_i: estimate the contribution from each
    # tau bin using the local drop in E(t) around that tau. This breaks
    # the uniform-guess Jacobian degeneracy at high n_modes.
    t_arr = np.asarray(t)
    E_arr = np.asarray(E_t)
    tau_arr = np.asarray(tau_guess_arr)
    if t_arr.size >= 2 and self._n_modes > 1:
        order = np.argsort(t_arr)
        t_sorted = t_arr[order]
        E_sorted = E_arr[order]
        log_t = np.log(np.maximum(t_sorted, 1e-30))
        dEdlogt = np.gradient(E_sorted, log_t)
        contrib = np.interp(
            np.log(np.clip(tau_arr, t_sorted[0], t_sorted[-1])),
            log_t,
            -dEdlogt,
        )
        contrib = np.clip(contrib, 1e-6, None)
        contrib_sum = float(contrib.sum())
        total = float(E_sum_guess)
        # NaN-safe: a NaN contrib_sum (e.g. duplicate times) fails `> 0`.
        if contrib_sum > 0 and total > 0:
            E_i_guess = jnp.asarray(contrib * (total / contrib_sum))
        else:
            E_i_guess = jnp.full(self._n_modes, total / max(self._n_modes, 1))
    else:
        E_i_guess = jnp.full(
            self._n_modes,
            E_sum_guess / max(self._n_modes, 1),
        )
    tau_i_guess = tau_guess_arr
    if initial_params is not None:
        x0 = jnp.asarray(initial_params)
    else:
        x0 = jnp.concatenate([jnp.array([E_inf_guess]), E_i_guess, tau_i_guess])
    # Parameter bounds — use data-derived tau range (with wide padding) so
    # the optimizer can actually reach the relaxation times in the data.
    bounds_lower = jnp.concatenate(
        [
            jnp.array([0.0]),
            jnp.full(self._n_modes, 1e-12),
            jnp.full(self._n_modes, tau_lo_bound),
        ]
    )
    bounds_upper = jnp.concatenate(
        [
            jnp.array([jnp.max(E_t) * 10]),
            jnp.full(self._n_modes, jnp.max(E_t) * 10),
            jnp.full(self._n_modes, tau_hi_bound),
        ]
    )

    # Step 1: Fit with softmax penalty
    def objective_step1(params):
        """Objective with softmax penalty."""
        E_i = params[1 : 1 + self._n_modes]
        residual = objective(params)
        penalty = softmax_penalty(E_i, scale=1e-3)
        return jnp.concatenate([residual, jnp.array([penalty])])

    def _run_fit_relax(x_init):
        return self._nlsq_fit(
            objective_step1,
            x_init,
            bounds=(bounds_lower, bounds_upper),
            max_nfev=max_iter,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
        )

    result_step1 = _run_fit_relax(x0)
    # --- Multi-start: Prony fitting has many local minima because
    # adjacent modes overlap in their contributions to E(t). Perturb the
    # heuristic initial guess a few times and keep the lowest-cost result.
    # This is ~4x the cost of a single fit but eliminates seed-specific bad
    # minima and Jacobian-ridge stalls at once.
    best_result = result_step1
    if initial_params is None and self._n_modes >= 2:
        rng_retry = np.random.default_rng(0)
        n_p = self._n_modes
        total_E = float(jnp.max(E_t) - jnp.min(E_t))
        base_E = np.asarray(E_i_guess)
        base_tau = np.asarray(tau_i_guess)
        for attempt in range(4):
            pert_E = rng_retry.uniform(0.3, 3.0, size=n_p)
            pert_tau = 10.0 ** rng_retry.uniform(-0.5, 0.5, size=n_p)
            E_init = jnp.asarray(
                np.clip(
                    base_E * pert_E,
                    1e-6 * max(total_E, 1.0),
                    10.0 * max(total_E, 1.0),
                )
            )
            tau_init = jnp.asarray(
                np.clip(base_tau * pert_tau, tau_lo_bound, tau_hi_bound)
            )
            x_retry = jnp.concatenate([jnp.array([E_inf_guess]), E_init, tau_init])
            # The heuristic clip above is scale-dependent: its 1e-6 floor can
            # exceed the E_i upper bound (10 * max(E_t)) for very small moduli,
            # which made every retry infeasible. Project the start onto the
            # actual box so each retry is always a valid starting point.
            x_retry = jnp.clip(x_retry, bounds_lower, bounds_upper)
            try:
                result_retry = _run_fit_relax(x_retry)
            except Exception as exc:
                # A failed restart is not fatal; keep the best result so far.
                logger.debug(
                    "Relaxation multi-start retry failed",
                    attempt=attempt,
                    error_type=type(exc).__name__,
                    error_message=str(exc),
                )
                continue
            if float(result_retry.cost) < float(best_result.cost):
                best_result = result_retry
    result_step1 = best_result
    # Check for negative Eᵢ (defensive: the E_i lower bound is 1e-12, so this
    # should not normally trigger when the optimizer respects the bounds).
    params_opt = result_step1.x
    E_i_opt = params_opt[1 : 1 + self._n_modes]
    if jnp.any(E_i_opt < 0):
        logger.warning(
            "Negative Eᵢ detected in relaxation fit. Refitting with hard bounds."
        )
        # Step 2: refit without the softmax penalty (same bounds).
        result_step2 = self._nlsq_fit(
            objective,
            params_opt,
            bounds=(bounds_lower, bounds_upper),
            max_nfev=max_iter,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
        )
        result_final = result_step2
        params_opt = result_final.x
    else:
        result_final = result_step1
    # Store NLSQ result
    self._nlsq_result = result_final
    # Set fitted parameters (batch update for 5-10% speedup)
    E_inf_opt = params_opt[0]
    E_i_opt = params_opt[1 : 1 + self._n_modes]
    tau_i_opt = params_opt[1 + self._n_modes :]
    param_values = {f"{symbol}_inf": float(E_inf_opt)}
    param_values.update(
        {f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(self._n_modes)}
    )
    param_values.update(
        {f"tau_{i+1}": float(tau_i_opt[i]) for i in range(self._n_modes)}
    )
    self.parameters.set_values(param_values)
    # Element minimization
    if optimization_factor is not None and self._n_modes > 1:
        self._apply_element_minimization(t, E_t, optimization_factor, **kwargs)
def _apply_element_minimization(
    self, X: np.ndarray, y: np.ndarray, optimization_factor: float, **kwargs
) -> None:
    """Apply element minimization with padded arrays to avoid JIT recompilation.

    Refits the model with progressively fewer active modes (N_max, N_max-1,
    ..., 1), warm-starting each fit from the previous one. The number of
    modes is then chosen with ``select_optimal_n``.

    Performance optimization: parameter arrays keep the fixed N_max shape
    throughout the N-reduction loop, so the objective keeps a single shape.
    Inactive modes are frozen via near-degenerate bounds (E_i in [0, 1e-30],
    tau_i in 1 ± 1e-12), because the TRF solver requires lower < upper.
    Setting E_i=0 for an inactive mode zeroes its contribution in the
    additive Prony sum (0 * exp(-t/tau) = 0), so no explicit masking is
    needed.

    The parameter box is widened when necessary so it always contains the
    warm-start point from the initial N_max fit. Otherwise clipping the
    start to the box would silently move that solution, e.g. relaxation
    times outside [1e-6, 1e6] on wide master curves.

    If no reduced fit succeeds, the initial N_max fit is kept unchanged.

    Args:
        X: Independent variable (time or frequency)
        y: Dependent variable (modulus or compliance)
            - For relaxation/creep: 1D array of shape (M,)
            - For oscillation: 1D concatenated array [G', G"] of shape (2*M,)
        optimization_factor: R² threshold multiplier. The first (N_max) fit
            sets R²_max. The loop stops early once R² drops below
            R²_max - (1 - R²_max) * (optimization_factor - 1). The final
            choice of N is delegated to ``select_optimal_n``.
        **kwargs: NLSQ optimizer arguments

    Raises:
        ValueError: If the current test mode is not supported.
    """
    # OPT-YDATA-001: ensure y_data is stashed (may already be set by the
    # caller, but be robust if called directly).
    self._current_y_data = np.asarray(y)
    # Store initial n_modes for diagnostics
    n_max = self._n_modes
    n_initial = n_max
    # Extract NLSQ kwargs
    max_iter = kwargs.get("max_iter", 1000)
    ftol = kwargs.get("ftol", 1e-6)
    xtol = kwargs.get("xtol", 1e-6)
    gtol = kwargs.get("gtol", 1e-6)
    symbol = "E" if self._modulus_type == "tensile" else "G"
    # Convert data to JAX arrays (once)
    X_jax = jnp.asarray(X)
    y_jax = jnp.asarray(y)
    # Data-based default upper bound for moduli (may be widened below).
    E_max = float(jnp.max(jnp.abs(y_jax)) * 10)
    # Select JIT prediction function based on test mode.
    # All prediction functions use E_i[:, None] broadcasting or jnp.sum(E_i * ...),
    # so E_i=0 for inactive modes naturally contributes zero.
    test_mode = self._test_mode
    # Define padded objective function (always uses N_max-shaped arrays).
    # Its input shape never changes across n_active values.
    if test_mode in ("relaxation",):

        def objective(params):
            E_inf = params[0]
            E_i = params[1 : 1 + n_max]
            tau_i = params[1 + n_max :]
            pred = self._predict_relaxation_jit(X_jax, E_inf, E_i, tau_i)
            return pred - y_jax

    elif test_mode in ("oscillation", "laos"):

        def objective(params):
            E_inf = params[0]
            E_i = params[1 : 1 + n_max]
            tau_i = params[1 + n_max :]
            pred = self._predict_oscillation_jit(X_jax, E_inf, E_i, tau_i)
            return jnp.concatenate([pred[0], pred[1]]) - y_jax

    elif test_mode == "creep":

        def objective(params):
            E_inf = params[0]
            E_i = params[1 : 1 + n_max]
            tau_i = params[1 + n_max :]
            pred = self._predict_creep_jit(X_jax, E_inf, E_i, tau_i)
            return pred - y_jax

    elif test_mode == "startup":
        gamma_dot = getattr(self, "_startup_gamma_dot", 1.0)

        def objective(params):
            E_inf = params[0]
            E_i = params[1 : 1 + n_max]
            tau_i = params[1 + n_max :]
            pred = self._predict_startup_jit(X_jax, E_inf, E_i, tau_i, gamma_dot)
            return pred - y_jax

    else:
        raise ValueError(
            f"Element minimization not supported for test_mode: {test_mode}"
        )

    # Softmax penalty wrapper (also fixed shape)
    def objective_step1(params):
        E_i = params[1 : 1 + n_max]
        residual = objective(params)
        penalty = softmax_penalty(E_i, scale=1e-3)
        return jnp.concatenate([residual, jnp.array([penalty])])

    # Get current best params from the initial N_max fit
    if self._nlsq_result is not None:
        current_params = np.asarray(self._nlsq_result.x)
    else:
        E_inf = self.parameters.get_value(f"{symbol}_inf")
        E_i = [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(n_max)]
        tau_i = [self.parameters.get_value(f"tau_{i+1}") for i in range(n_max)]
        current_params = np.array([E_inf] + E_i + tau_i)

    # Make sure the box contains the warm-start point. The data-derived
    # modulus cap (10 * max|y|) is not a modulus scale when y is not a
    # modulus (presumably creep compliance), and the initial fit may place
    # tau_i outside the default [1e-6, 1e6] window. The defaults are kept
    # whenever the start already lies inside them; otherwise the box is
    # widened with one order of magnitude of headroom.
    tau_lo, tau_hi = 1e-6, 1e6
    moduli0 = np.abs(current_params[: 1 + n_max])
    if moduli0.size and np.all(np.isfinite(moduli0)):
        moduli0_max = float(moduli0.max())
        if moduli0_max > E_max:
            E_max = 10.0 * moduli0_max
    tau0 = current_params[1 + n_max :]
    tau0 = tau0[np.isfinite(tau0) & (tau0 > 0)]
    if tau0.size:
        if float(tau0.min()) < tau_lo:
            tau_lo = float(tau0.min()) / 10.0
        if float(tau0.max()) > tau_hi:
            tau_hi = float(tau0.max()) * 10.0

    # Iterative N reduction with padded arrays
    fit_results: dict = {}
    best_params = current_params.copy()
    r2_max = None
    r2_threshold = None
    # Pre-compute base bounds arrays (all-active case) and inactive values.
    # Only the active/inactive boundary changes per iteration, so we update
    # slices in-place instead of rebuilding from scratch each time.
    lower = np.zeros(2 * n_max + 1)
    upper = np.zeros(2 * n_max + 1)
    lower[0] = 0.0
    upper[0] = E_max
    # Start with all modes active
    lower[1 : 1 + n_max] = 1e-12
    upper[1 : 1 + n_max] = E_max
    lower[1 + n_max :] = tau_lo
    upper[1 + n_max :] = tau_hi
    for n_active in range(n_max, 0, -1):
        try:
            # Freeze modes beyond n_active.
            # E_i bounds: inactive nearly frozen (NLSQ TRF requires lower < upper).
            # E_i < 1e-30 Pa is effectively zero.
            lower[1 + n_active : 1 + n_max] = 0.0
            upper[1 + n_active : 1 + n_max] = 1e-30
            # tau_i bounds: inactive nearly frozen around 1.0.
            lower[1 + n_max + n_active :] = 1.0 - 1e-12
            upper[1 + n_max + n_active :] = 1.0 + 1e-12
            # Warm-start: zero out inactive modes from previous best
            x0 = best_params.copy()
            x0[1 + n_active : 1 + n_max] = 0.0  # Inactive E_i
            x0[1 + n_max + n_active :] = 1.0  # Inactive tau_i
            # Clamp active params to bounds
            x0 = np.clip(x0, lower, upper)
            # Step 1: Fit with softmax penalty
            result = self._nlsq_fit(
                objective_step1,
                x0,
                bounds=(lower, upper),
                max_nfev=max_iter,
                ftol=ftol,
                xtol=xtol,
                gtol=gtol,
            )
            # Check for negative E_i in active modes and refit if needed
            params_opt = result.x
            E_i_active = params_opt[1 : 1 + n_active]
            if jnp.any(E_i_active < 0):
                result = self._nlsq_fit(
                    objective,
                    params_opt,
                    bounds=(lower, upper),
                    max_nfev=max_iter,
                    ftol=ftol,
                    xtol=xtol,
                    gtol=gtol,
                )
                params_opt = result.x
            # Compute prediction for R² (objective returns pred - y).
            residual = np.asarray(objective(params_opt))
            y_pred = np.asarray(y) + residual
            r2_n = compute_r_squared(y, y_pred)
            fit_results[n_active] = {
                "r2": r2_n,
                "params": params_opt.copy(),
                "result": result,
            }
            best_params = params_opt.copy()
            # Set R² threshold after first fit (highest N)
            if r2_max is None:
                r2_max = r2_n
                degradation_room = 1.0 - r2_max
                allowed_degradation = degradation_room * (optimization_factor - 1.0)
                r2_threshold = r2_max - allowed_degradation
            # Early termination: stop if R² falls below threshold
            if r2_threshold is not None and r2_n < r2_threshold:
                logger.info(
                    f"Element minimization: early termination at n_modes={n_active} "
                    f"(R²={r2_n:.6f} < threshold={r2_threshold:.6f})"
                )
                break
        except (RuntimeError, ValueError) as e:
            logger.warning(
                f"Element minimization: fitting failed for n_modes={n_active}: {e}"
            )
            break
    # Nothing succeeded (e.g. the very first refit failed): keep the initial
    # N_max fit instead of failing on an empty R² table.
    if not fit_results:
        logger.warning(
            "Element minimization: no successful fits; keeping the initial "
            f"{n_initial}-mode fit"
        )
        self._element_minimization_diagnostics = {
            "n_initial": n_initial,
            "r2": [],
            "n_modes": [],
            "n_optimal": n_initial,
            "optimization_factor": optimization_factor,
        }
        return
    # Select optimal N
    r2_values = {n: cast(float, entry["r2"]) for n, entry in fit_results.items()}
    n_optimal = select_optimal_n(r2_values, optimization_factor=optimization_factor)
    # Store diagnostics with all required keys
    n_modes_list = sorted(r2_values.keys())
    r2_list = [r2_values[n] for n in n_modes_list]
    self._element_minimization_diagnostics = {
        "n_initial": n_initial,
        "r2": r2_list,
        "n_modes": n_modes_list,
        "n_optimal": n_optimal,
        "optimization_factor": optimization_factor,
    }
    # Update model if optimal N is different
    if n_optimal < self._n_modes:
        logger.info(
            f"Element minimization: reducing from {self._n_modes} to {n_optimal} modes"
        )
        # Extract active parameters from padded result
        optimal_params = fit_results[n_optimal]["params"]
        E_inf_opt = optimal_params[0]
        E_i_opt = optimal_params[1 : 1 + n_optimal]
        tau_i_opt = optimal_params[1 + n_max : 1 + n_max + n_optimal]
        # Rebuild ParameterSet with n_optimal modes
        self._n_modes = n_optimal
        self.parameters = create_prony_parameter_set(
            n_optimal, modulus_type=self._modulus_type
        )
        # Set fitted parameter values
        param_values = {f"{symbol}_inf": float(E_inf_opt)}
        param_values.update(
            {f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(n_optimal)}
        )
        param_values.update(
            {f"tau_{i+1}": float(tau_i_opt[i]) for i in range(n_optimal)}
        )
        self.parameters.set_values(param_values)
        # Build slimmed-down NLSQ result for the optimal model
        slim_x = np.concatenate([[E_inf_opt], E_i_opt, tau_i_opt])
        optimal_result = fit_results[n_optimal]["result"]
        self._nlsq_result = OptimizationResult(
            x=slim_x,
            fun=optimal_result.fun,
            jac=None,
            success=optimal_result.success,
            message=optimal_result.message,
            nit=optimal_result.nit,
            nfev=optimal_result.nfev,
            njev=optimal_result.njev,
            optimality=optimal_result.optimality,
            active_mask=None,
            cost=optimal_result.cost,
            grad=None,
            nlsq_result=optimal_result.nlsq_result,
            # OPT-YDATA-001: forward y_data so r_squared is computable on
            # the slimmed (post-element-minimization) result too.
            residuals=getattr(optimal_result, "residuals", None),
            y_data=getattr(optimal_result, "y_data", None),
            n_data=getattr(optimal_result, "n_data", None),
        )
def _fit_oscillation_mode(
self,
omega: np.ndarray,
E_star: np.ndarray,
optimization_factor: float | None = 1.5,
initial_params: np.ndarray | None = None,
**kwargs,
) -> None:
"""Fit GMM to complex modulus data.
Args:
omega: Angular frequency array
E_star: Complex modulus [E', E"] - can be (2, M) or (M, 2)
optimization_factor: R² threshold multiplier for element minimization
initial_params: Optional initial parameter guess for warm-start
Shape: (2*n_modes + 1,) [E_inf, E_1...E_N, tau_1...tau_N]
If None, uses default heuristic initialization
**kwargs: NLSQ optimizer arguments
"""
# OPT-YDATA-001: stash y_data so _nlsq_fit attaches it for r_squared.
# E_star may be complex or (M,2); attach as-is, r_squared handles both.
self._current_y_data = np.asarray(E_star)
# Extract kwargs
max_iter = kwargs.get("max_iter", 1000)
ftol = kwargs.get("ftol", 1e-6)
xtol = kwargs.get("xtol", 1e-6)
gtol = kwargs.get("gtol", 1e-6)
use_log_residuals = kwargs.get("use_log_residuals", False)
symbol = "E" if self._modulus_type == "tensile" else "G"
# Standardize input shape to (2, M)
E_star = np.asarray(E_star)
if E_star.ndim == 1:
if np.iscomplexobj(E_star):
# Handle complex 1D array [G*_1, G*_2, ..., G*_M]
E_prime = np.real(E_star)
E_double_prime = np.imag(E_star)
else:
# Handle 1D concatenated [G', G"] from element minimization
M = len(E_star) // 2
E_prime = np.real(E_star[:M])
E_double_prime = np.real(E_star[M:])
elif E_star.shape[0] == 2:
# Input is (2, M), extract directly
E_prime = np.real(E_star[0])
E_double_prime = np.real(E_star[1]) # FIX: Added missing assignment
elif E_star.shape[1] == 2:
# Input is (M, 2), transpose to (2, M)
E_prime = np.real(E_star[:, 0])
E_double_prime = np.real(E_star[:, 1])
else:
raise ValueError(
f"E_star must have shape (2, M), (M, 2), or be 1D concatenated [G', G\"], got {E_star.shape}"
)
# Precompute log observations for log-residual mode so we avoid a
# jnp.log10 call on every optimizer iteration.
_log_Ep = jnp.log10(jnp.maximum(jnp.asarray(E_prime), 1e-30))
_log_Epp = jnp.log10(jnp.maximum(jnp.asarray(E_double_prime), 1e-30))
# Per-component scalar normalization for the linear-residual mode.
# We divide residuals by the RMS of each component so that E' and E''
# contribute with balanced weight regardless of their absolute
# magnitudes. RMS is preferred over max(|obs|) because it is robust
# to outliers, and over per-point |obs| because per-point division
# amplifies noise near the low-magnitude tails of E''.
_Ep_rms = jnp.sqrt(jnp.mean(jnp.asarray(E_prime) ** 2))
_Epp_rms = jnp.sqrt(jnp.mean(jnp.asarray(E_double_prime) ** 2))
_Ep_scale = jnp.maximum(_Ep_rms, jnp.float64(1e-12))
_Epp_scale = jnp.maximum(_Epp_rms, jnp.float64(1e-12))
# Define objective function
def objective(params):
"""Residual for complex modulus.
Uses log-space residuals when ``use_log_residuals`` is set, which
is essential when E'(ω) and E''(ω) span many decades. Otherwise
uses *relative* residuals (pred−obs)/|obs| so that E' and E''
contribute with balanced weight regardless of their absolute
magnitudes. Absolute residuals (the old default) let the glassy
plateau of E' dominate the sum-of-squares and ignore E''.
"""
E_inf = params[0]
E_i = params[1 : 1 + self._n_modes]
tau_i = params[1 + self._n_modes :]
# Predict complex modulus (returns (2, M))
E_star_pred = self._predict_oscillation_jit(
jnp.asarray(omega), E_inf, E_i, tau_i
)
E_prime_pred = E_star_pred[0] # Extract G' from (2, M)
E_double_prime_pred = E_star_pred[1] # Extract G" from (2, M)
if use_log_residuals:
resid_p = jnp.log10(jnp.maximum(E_prime_pred, 1e-30)) - _log_Ep
resid_pp = jnp.log10(jnp.maximum(E_double_prime_pred, 1e-30)) - _log_Epp
else:
resid_p = (E_prime_pred - E_prime) / _Ep_scale
resid_pp = (E_double_prime_pred - E_double_prime) / _Epp_scale
return jnp.concatenate([resid_p, resid_pp])
# Derive tau range from the observed frequency window (τ ≈ 1/ω).
# Pad ±2 decades beyond the data so optimizer can reach boundary modes.
omega_np = np.asarray(omega)
omega_pos = omega_np[omega_np > 0]
if omega_pos.size > 0:
log_tau_lo_data = float(-np.log10(omega_pos.max()))
log_tau_hi_data = float(-np.log10(omega_pos.min()))
else:
log_tau_lo_data, log_tau_hi_data = -2.0, 2.0
tau_lo_bound = max(10.0 ** (log_tau_lo_data - 2.0), 1e-30)
tau_hi_bound = 10.0 ** (log_tau_hi_data + 2.0)
if self._n_modes == 1:
tau_i_guess = jnp.array(
[10.0 ** (0.5 * (log_tau_lo_data + log_tau_hi_data))]
)
else:
tau_i_guess = jnp.logspace(log_tau_lo_data, log_tau_hi_data, self._n_modes)
# Always compute derivative-based heuristic guesses so the multi-start
# retry block below can use them even when the caller supplied
# ``initial_params``.
E_inf_guess = jnp.min(E_prime) # Low-frequency plateau
E_sum_guess = jnp.max(E_prime) - E_inf_guess
# Seed each E_i from the local storage-modulus derivative
# −dE'/d(ln ω) evaluated at each τ_k (since 1/τ_k ≈ ω_k).
omega_sorted_idx = np.argsort(omega_np)
omega_sorted = omega_np[omega_sorted_idx]
Ep_sorted = np.asarray(E_prime)[omega_sorted_idx]
if omega_sorted.size >= 2 and self._n_modes > 1:
log_omega = np.log(np.maximum(omega_sorted, 1e-30))
dEp_dlogw = np.gradient(Ep_sorted, log_omega)
tau_np = np.asarray(tau_i_guess)
omega_at_tau = 1.0 / np.clip(tau_np, 1e-30, None)
contrib = np.interp(
np.log(np.clip(omega_at_tau, omega_sorted[0], omega_sorted[-1])),
log_omega,
dEp_dlogw,
)
contrib = np.clip(contrib, 1e-6, None)
contrib_sum = float(contrib.sum())
total = float(E_sum_guess)
if contrib_sum > 0 and total > 0:
E_i_guess = jnp.asarray(contrib * (total / contrib_sum))
else:
E_i_guess = jnp.full(self._n_modes, E_sum_guess / max(self._n_modes, 1))
else:
E_i_guess = jnp.full(self._n_modes, E_sum_guess / max(self._n_modes, 1))
if initial_params is not None:
x0 = jnp.asarray(initial_params)
else:
x0 = jnp.concatenate([jnp.array([E_inf_guess]), E_i_guess, tau_i_guess])
# Parameter bounds — data-derived tau range so master curves spanning
# many decades stay inside the box.
bounds_lower = jnp.concatenate(
[
jnp.array([0.0]),
jnp.full(self._n_modes, 1e-12),
jnp.full(self._n_modes, tau_lo_bound),
]
)
bounds_upper = jnp.concatenate(
[
jnp.array([jnp.max(E_prime) * 10]),
jnp.full(self._n_modes, jnp.max(E_prime) * 10),
jnp.full(self._n_modes, tau_hi_bound),
]
)
# Step 1: Fit with softmax penalty
def objective_step1(params):
"""Objective with softmax penalty."""
E_i = params[1 : 1 + self._n_modes]
residual = objective(params)
penalty = softmax_penalty(E_i, scale=1e-3)
return jnp.concatenate([residual, jnp.array([penalty])])
def _run_fit(x_init):
return self._nlsq_fit(
objective_step1,
x_init,
bounds=(bounds_lower, bounds_upper),
max_nfev=max_iter,
ftol=ftol,
xtol=xtol,
gtol=gtol,
)
result_step1 = _run_fit(x0)
# --- Multi-start: Prony fitting has many local minima because
# adjacent modes overlap in their contributions to E*(ω). Always
# perturb the initial guess a few times and keep the lowest-cost
# result. Eliminates both seed-specific bad minima and
# Jacobian-ridge stalls at once.
best_result = result_step1
if initial_params is None and self._n_modes >= 2:
rng_retry = np.random.default_rng(0)
n_p = self._n_modes
total_E = float(jnp.max(E_prime) - jnp.min(E_prime))
base_E = np.asarray(E_i_guess)
base_tau = np.asarray(tau_i_guess)
for _attempt in range(4):
pert_E = rng_retry.uniform(0.3, 3.0, size=n_p)
pert_tau = 10.0 ** rng_retry.uniform(-0.5, 0.5, size=n_p)
E_init = jnp.asarray(
np.clip(base_E * pert_E, 1e-6 * total_E, 10.0 * total_E)
)
tau_init = jnp.asarray(
np.clip(base_tau * pert_tau, tau_lo_bound, tau_hi_bound)
)
x_retry = jnp.concatenate([jnp.array([E_inf_guess]), E_init, tau_init])
try:
result_retry = _run_fit(x_retry)
except Exception:
continue
if float(result_retry.cost) < float(best_result.cost):
best_result = result_retry
result_step1 = best_result
# Check for negative Eᵢ
params_opt = result_step1.x
E_i_opt = params_opt[1 : 1 + self._n_modes]
if jnp.any(E_i_opt < 0):
logger.warning(
"Negative Eᵢ detected in oscillation fit. Refitting with hard bounds."
)
# Step 2: Refit with hard bounds
result_step2 = self._nlsq_fit(
objective,
params_opt,
bounds=(bounds_lower, bounds_upper),
max_nfev=max_iter,
ftol=ftol,
xtol=xtol,
gtol=gtol,
)
result_final = result_step2
params_opt = result_final.x
else:
result_final = result_step1
# Store NLSQ result
self._nlsq_result = result_final
# Set fitted parameters (batch update for 5-10% speedup)
E_inf_opt = params_opt[0]
E_i_opt = params_opt[1 : 1 + self._n_modes]
tau_i_opt = params_opt[1 + self._n_modes :]
param_values = {f"{symbol}_inf": float(E_inf_opt)}
param_values.update(
{f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(self._n_modes)}
)
param_values.update(
{f"tau_{i+1}": float(tau_i_opt[i]) for i in range(self._n_modes)}
)
self.parameters.set_values(param_values)
# Element minimization
if optimization_factor is not None and self._n_modes > 1:
# Reconstruct combined data for minimization (flatten to 1D)
combined_data = np.concatenate([E_prime, E_double_prime])
self._apply_element_minimization(
omega, combined_data, optimization_factor, **kwargs
)
def _fit_creep_mode(
self,
t: np.ndarray,
J_t: np.ndarray,
optimization_factor: float | None = 1.5,
initial_params: np.ndarray | None = None,
**kwargs,
) -> None:
"""Fit GMM to creep compliance data.
Args:
t: Time array
J_t: Creep compliance array
optimization_factor: R² threshold multiplier for element minimization
initial_params: Optional initial parameter guess for warm-start
Shape: (2*n_modes + 1,) [J_0, J_1...J_N, tau_1...tau_N]
If None, uses default heuristic initialization
**kwargs: NLSQ optimizer arguments
"""
# OPT-YDATA-001: stash y_data so _nlsq_fit attaches it for r_squared.
self._current_y_data = np.asarray(J_t)
# Extract kwargs
max_iter = kwargs.get("max_iter", 1000)
ftol = kwargs.get("ftol", 1e-6)
xtol = kwargs.get("xtol", 1e-6)
gtol = kwargs.get("gtol", 1e-6)
symbol = "E" if self._modulus_type == "tensile" else "G"
# Define objective function (predict creep from GMM simulation)
def objective(params):
"""Residual for creep compliance."""
E_inf = params[0]
E_i = params[1 : 1 + self._n_modes]
tau_i = params[1 + self._n_modes :]
# Predict creep compliance via GMM simulation
# Apply step stress σ₀ = 1, solve for strain ε(t), compute J(t) = ε(t)/σ₀
J_pred = self._predict_creep_internal(t, E_inf, E_i, tau_i)
return J_pred - J_t
# Compute data-based bounds (needed regardless of warm-start)
J_0 = jnp.min(J_t) # Initial compliance (instant response)
J_inf = jnp.max(J_t) # Final compliance (long-time)
# Initial parameter guess (warm-start if provided, else default heuristic)
if initial_params is not None:
x0 = jnp.asarray(initial_params)
else:
# For creep: J_0 = 1/(E_∞ + ΣEᵢ), J_∞ = 1/E_∞
# E_∞ corresponds to long-time equilibrium: J_∞ = 1/E_∞
E_inf_guess = 1.0 / J_inf
# Total instant modulus: J_0 = 1/(E_∞ + ΣEᵢ)
E_total_guess = 1.0 / J_0
E_sum_guess = max(E_total_guess - E_inf_guess, 1e-12)
E_i_guess = jnp.full(self._n_modes, E_sum_guess / self._n_modes)
tau_i_guess = jnp.logspace(-2, 2, self._n_modes)
x0 = jnp.concatenate(
[jnp.array([max(E_inf_guess, 1e-12)]), E_i_guess, tau_i_guess]
)
# Parameter bounds
bounds_lower = jnp.concatenate(
[
jnp.array([0.0]),
jnp.full(self._n_modes, 1e-12),
jnp.full(self._n_modes, 1e-6),
]
)
bounds_upper = jnp.concatenate(
[
jnp.array([1.0 / J_0 * 10]),
jnp.full(self._n_modes, 1.0 / J_0 * 10),
jnp.full(self._n_modes, 1e6),
]
)
# Step 1: Fit with softmax penalty
def objective_step1(params):
"""Objective with softmax penalty."""
E_i = params[1 : 1 + self._n_modes]
residual = objective(params)
penalty = softmax_penalty(E_i, scale=1e-3)
return jnp.concatenate([residual, jnp.array([penalty])])
result_step1 = self._nlsq_fit(
objective_step1,
x0,
bounds=(bounds_lower, bounds_upper),
max_nfev=max_iter,
ftol=ftol,
xtol=xtol,
gtol=gtol,
)
# Check for negative Eᵢ
params_opt = result_step1.x
E_i_opt = params_opt[1 : 1 + self._n_modes]
if jnp.any(E_i_opt < 0):
logger.warning(
"Negative Eᵢ detected in creep fit. Refitting with hard bounds."
)
# Step 2: Refit with hard bounds
result_step2 = self._nlsq_fit(
objective,
params_opt,
bounds=(bounds_lower, bounds_upper),
max_nfev=max_iter,
ftol=ftol,
xtol=xtol,
gtol=gtol,
)
result_final = result_step2
params_opt = result_final.x
else:
result_final = result_step1
# Store NLSQ result
self._nlsq_result = result_final
# Set fitted parameters (batch update for 5-10% speedup)
E_inf_opt = params_opt[0]
E_i_opt = params_opt[1 : 1 + self._n_modes]
tau_i_opt = params_opt[1 + self._n_modes :]
param_values = {f"{symbol}_inf": float(E_inf_opt)}
param_values.update(
{f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(self._n_modes)}
)
param_values.update(
{f"tau_{i+1}": float(tau_i_opt[i]) for i in range(self._n_modes)}
)
self.parameters.set_values(param_values)
# Element minimization
if optimization_factor is not None and self._n_modes > 1:
self._apply_element_minimization(t, J_t, optimization_factor, **kwargs)
def _predict_creep_internal(
self,
t: np.ndarray | jnp_typing.ndarray,
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
sigma_0: float = 1.0,
) -> jnp_typing.ndarray:
"""Internal creep prediction for optimization.
Args:
t: Time array
E_inf: Equilibrium modulus
E_i: Prony coefficients (N,)
tau_i: Relaxation times (N,)
sigma_0: Applied stress (default 1.0)
Returns:
Creep compliance J(t)
"""
# Call JIT-compiled creep prediction
J_t = self._predict_creep_jit(jnp.asarray(t), E_inf, E_i, tau_i, sigma_0)
return J_t
def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
"""Predict based on fitted test mode.
Args:
X: Independent variable (time or frequency)
**kwargs: Additional arguments (test_mode handled via self._test_mode)
Returns:
Predicted values (modulus or compliance)
Raises:
ValueError: If test_mode not set (model not fitted)
"""
_kw_mode = kwargs.get("test_mode")
test_mode = _kw_mode if _kw_mode is not None else self._test_mode
if test_mode is None:
raise ValueError("Model not fitted. Call fit() first.")
# Normalize test_mode to string
if hasattr(test_mode, "value"):
test_mode = test_mode.value
# Route to appropriate prediction method
if test_mode == "relaxation":
return self._predict_relaxation(X)
elif test_mode == "oscillation":
return self._predict_oscillation(X)
elif test_mode == "creep":
return self._predict_creep(X)
elif test_mode in ("steady_shear", "flow_curve"):
return self._predict_steady_shear(X)
elif test_mode == "startup":
return self._predict_startup(X)
elif test_mode == "laos":
return self._predict_laos(X)
else:
raise ValueError(f"Unknown test_mode: {test_mode}")
@staticmethod
@jax.jit
def _predict_relaxation_jit(
t: jnp_typing.ndarray,
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
) -> jnp_typing.ndarray:
"""JIT-compiled relaxation prediction.
Args:
t: Time array
E_inf: Equilibrium modulus
E_i: Prony coefficients (N,)
tau_i: Relaxation times (N,)
Returns:
Relaxation modulus E(t)
"""
# E(t) = E_∞ + Σᵢ Eᵢ exp(-t/τᵢ)
E_t = E_inf + jnp.sum(
E_i[:, None] * jnp.exp(-t[None, :] / tau_i[:, None]), axis=0
)
return E_t
def _predict_relaxation(self, t: np.ndarray | jnp_typing.ndarray) -> np.ndarray:
"""Predict relaxation modulus E(t).
Args:
t: Time array
Returns:
Relaxation modulus array
"""
symbol = "E" if self._modulus_type == "tensile" else "G"
# Extract parameters
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = jnp.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = jnp.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
# Convert input to JAX array
t_jax = jnp.asarray(t)
# Call JIT-compiled prediction
E_t = self._predict_relaxation_jit(t_jax, E_inf, E_i, tau_i)
return np.asarray(E_t)
@staticmethod
@jax.jit
def _predict_oscillation_jit(
omega: jnp_typing.ndarray,
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
) -> jnp_typing.ndarray:
"""JIT-compiled oscillation prediction.
Args:
omega: Angular frequency array
E_inf: Equilibrium modulus
E_i: Prony coefficients (N,)
tau_i: Relaxation times (N,)
Returns:
Complex modulus [E', E"] (2, M)
"""
# Closed-form Fourier transform
omega_tau = omega[None, :] * tau_i[:, None]
omega_tau_sq = omega_tau**2
# E'(ω) = E_∞ + Σᵢ Eᵢ (ωτᵢ)²/(1+(ωτᵢ)²)
E_prime = E_inf + jnp.sum(
E_i[:, None] * omega_tau_sq / (1 + omega_tau_sq), axis=0
)
# E"(ω) = Σᵢ Eᵢ (ωτᵢ)/(1+(ωτᵢ)²)
E_double_prime = jnp.sum(E_i[:, None] * omega_tau / (1 + omega_tau_sq), axis=0)
# Return as (2, M) for standard complex modulus convention
return jnp.stack([E_prime, E_double_prime], axis=0)
def _predict_oscillation(
self, omega: np.ndarray | jnp_typing.ndarray
) -> np.ndarray:
"""Predict complex modulus in oscillation mode.
Args:
omega: Angular frequency array
Returns:
Complex modulus G* = G' + iG'' (or E* for tensile)
"""
symbol = "E" if self._modulus_type == "tensile" else "G"
# Extract parameters
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = jnp.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = jnp.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
# Convert input to JAX array
omega_jax = jnp.asarray(omega)
# Call JIT-compiled prediction (returns (2, M))
E_star = self._predict_oscillation_jit(omega_jax, E_inf, E_i, tau_i)
# Return as complex G* = G' + iG'' (consistent with all other models)
E_prime = np.asarray(E_star[0])
E_double_prime = np.asarray(E_star[1])
return E_prime + 1j * E_double_prime
@staticmethod
@jax.jit
def _predict_creep_jit(
t: jnp_typing.ndarray,
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
sigma_0: float = 1.0,
) -> jnp_typing.ndarray:
"""JIT-compiled creep prediction via backward-Euler.
Args:
t: Time array
E_inf: Equilibrium modulus
E_i: Prony coefficients (N,)
tau_i: Relaxation times (N,)
sigma_0: Applied stress (default 1.0)
Returns:
Creep compliance J(t)
"""
# Backward-Euler scheme for unconditional stability
# GMM ODEs: dσᵢ/dt = -σᵢ/τᵢ + Eᵢ dε/dt
# Total stress: σ = E_∞ ε + Σᵢ σᵢ
# Apply step stress σ₀, solve for ε(t), compute J(t) = ε(t)/σ₀
n_steps = len(t)
n_modes = len(E_i)
# Initialize arrays
epsilon = jnp.zeros(n_steps)
# Time step (assume uniform spacing for now, handle variable later)
dt = jnp.diff(t, prepend=0.0)
# Backward-Euler update loop
def update_step(carry, inputs):
"""Update internal variables and strain."""
eps_prev, sig_i_prev = carry
t_curr, dt_curr = inputs
# Protect against zero dt at first step
dt_safe = jnp.maximum(dt_curr, 1e-12)
# Solve for new strain from total stress balance
# σ₀ = E_∞ εⁿ⁺¹ + Σᵢ σᵢⁿ⁺¹
# σᵢⁿ⁺¹ = (σᵢⁿ + Eᵢ Δε) / (1 + Δt/τᵢ)
# Substitute and solve for Δε
# Coefficients for backward-Euler
alpha_i = jnp.exp(-dt_safe / tau_i) # Exact exponential integration
beta_i = E_i * tau_i * (1 - alpha_i) / dt_safe
# Total effective modulus
E_eff = E_inf + jnp.sum(beta_i)
# Solve for strain increment
stress_from_prev = jnp.sum(alpha_i * sig_i_prev)
d_eps = (sigma_0 - stress_from_prev) / E_eff
eps_new = eps_prev + d_eps
# Update internal stresses
sig_i_new = alpha_i * sig_i_prev + beta_i * d_eps
return (eps_new, sig_i_new), eps_new
# Initialize
eps_init = 0.0
sig_i_init = jnp.zeros(n_modes)
# Scan over time steps
_, epsilon = jax.lax.scan(update_step, (eps_init, sig_i_init), (t, dt))
# Compute compliance
J_t = epsilon / sigma_0
return J_t
def _predict_creep(self, t: np.ndarray | jnp_typing.ndarray) -> np.ndarray:
"""Predict creep compliance J(t).
Args:
t: Time array
Returns:
Creep compliance array
"""
symbol = "E" if self._modulus_type == "tensile" else "G"
# Extract parameters
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = jnp.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = jnp.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
# Convert input to JAX array
t_jax = jnp.asarray(t)
# Call JIT-compiled prediction
J_t = self._predict_creep_jit(t_jax, E_inf, E_i, tau_i, sigma_0=1.0)
return np.asarray(J_t)
def _extract_nlsq_diagnostics(self, nlsq_result) -> dict:
"""Extract diagnostics from NLSQ OptimizationResult.
Args:
nlsq_result: OptimizationResult from nlsq_optimize()
Returns:
Dictionary with diagnostic metrics
"""
# Extract convergence flag
convergence_flag = nlsq_result.success
# Extract gradient norm (optimality metric)
gradient_norm = (
nlsq_result.optimality if nlsq_result.optimality is not None else np.inf
)
# Estimate Hessian condition number from Jacobian
# For least-squares: Hessian ≈ J^T J
if nlsq_result.jac is not None:
jac = np.asarray(nlsq_result.jac)
# Compute approximate Hessian
hessian_approx = jac.T @ jac
# Compute condition number (ratio of largest/smallest singular values)
try:
cond_number = np.linalg.cond(hessian_approx)
except np.linalg.LinAlgError:
cond_number = np.inf
else:
cond_number = np.inf
# Estimate parameter uncertainties from diagonal of covariance matrix
# Cov ≈ inv(J^T J) if well-conditioned
param_uncertainties = {}
symbol = "E" if self._modulus_type == "tensile" else "G"
if nlsq_result.jac is not None and cond_number < 1e10:
try:
# Compute covariance matrix
cov_matrix = np.linalg.inv(hessian_approx)
std_devs = np.sqrt(np.abs(np.diag(cov_matrix)))
# Map to parameter names
param_names = [f"{symbol}_inf"]
param_names += [f"{symbol}_{i+1}" for i in range(self._n_modes)]
param_names += [f"tau_{i+1}" for i in range(self._n_modes)]
for i, name in enumerate(param_names):
if i < len(std_devs):
param_uncertainties[name] = float(std_devs[i])
except (np.linalg.LinAlgError, ValueError):
# Covariance matrix computation failed
pass
# Check proximity to bounds
params_near_bounds = {}
for param_name in self.parameters.keys():
value = self.parameters.get_value(param_name)
assert value is not None
param = self.parameters.get(param_name)
assert param is not None
bounds = param.bounds
assert bounds is not None
lower, upper = bounds
# Check if within 10% of bounds
bound_range = upper - lower
if abs(value - lower) < 0.1 * bound_range:
params_near_bounds[param_name] = "lower"
elif abs(value - upper) < 0.1 * bound_range:
params_near_bounds[param_name] = "upper"
return {
"convergence_flag": convergence_flag,
"gradient_norm": gradient_norm,
"hessian_condition": cond_number,
"param_uncertainties": param_uncertainties,
"params_near_bounds": params_near_bounds,
}
def _classify_nlsq_convergence(self, diagnostics: dict) -> str:
"""Classify NLSQ convergence quality.
Args:
diagnostics: Dictionary from _extract_nlsq_diagnostics()
Returns:
Classification: 'hard_failure', 'suspicious', or 'good'
"""
# Hard failure conditions
if not diagnostics["convergence_flag"]:
return "hard_failure"
# GMM-specific: High Hessian condition and params near bounds are often acceptable
# Only classify as suspicious if BOTH conditions are true AND uncertainties are high
# Check if any uncertainties are > 100% of parameter value (very unreliable)
high_uncertainty_count = 0
for param_name, std_dev in diagnostics["param_uncertainties"].items():
value = self.parameters.get_value(param_name)
assert value is not None
if abs(value) > 1e-12 and std_dev / abs(value) > 1.0:
high_uncertainty_count += 1
# Suspicious if: (high condition OR many params near bounds) AND high uncertainties
if (
high_uncertainty_count > self._n_modes
): # More than half the parameters are highly uncertain
if (
diagnostics["hessian_condition"] > 1e10
or len(diagnostics["params_near_bounds"]) > self._n_modes
):
return "suspicious"
# Good convergence if optimizer says so
return "good"
def _construct_bayesian_priors(
self,
classification: str,
prior_mode: str = "warn",
allow_fallback_priors: bool = False,
) -> dict:
"""Construct Bayesian priors based on NLSQ convergence classification.
Args:
classification: 'hard_failure', 'suspicious', or 'good'
prior_mode: 'strict', 'warn', or 'auto_widen'
allow_fallback_priors: Enable generic priors on hard failure
Returns:
Dictionary of priors for NumPyro: {param_name: {'mean': float, 'std': float}}
Raises:
ValueError: If hard failure and prior_mode='strict' or allow_fallback_priors=False
"""
priors = {}
if classification == "hard_failure":
# Hard failure: raise error or use fallback priors
if prior_mode == "strict" or not allow_fallback_priors:
raise ValueError(
"NLSQ optimization failed or did not converge properly. "
"Cannot construct reliable priors from failed fit. "
"Please:\n"
" 1. Check model suitability for your data\n"
" 2. Adjust initial values or bounds\n"
" 3. Increase max_iter if optimization terminated early\n"
" 4. Provide manual priors via fit_bayesian(priors={...})\n"
" 5. Set allow_fallback_priors=True for generic weakly informative priors (not recommended)"
)
# Fallback: generic weakly informative priors
warnings.warn(
"WARNING: NLSQ optimization failed. Using generic weakly informative priors. "
"Results may not be reliable. Consider manual prior specification.",
UserWarning,
stacklevel=2,
)
# Use parameter bounds as guides for generic priors
for param_name in self.parameters.keys():
param = self.parameters.get(param_name)
assert param is not None
bounds = param.bounds
assert bounds is not None
lower, upper = bounds
mean = (lower + upper) / 2
std = (upper - lower) / 4 # Wide prior covering ~95% of bounds
priors[param_name] = {"mean": mean, "std": std}
elif classification == "suspicious":
# Suspicious: use safer priors, optionally widen
if prior_mode == "auto_widen":
warnings.warn(
"Suspicious NLSQ convergence detected (high Hessian condition, params near bounds, or high uncertainty). "
"Using inflated priors centered at NLSQ estimates.",
UserWarning,
stacklevel=2,
)
# Center at NLSQ, inflate std
for param_name in self.parameters.keys():
value = self.parameters.get_value(param_name)
assert value is not None
param = self.parameters.get(param_name)
assert param is not None
bounds = param.bounds
assert bounds is not None
lower, upper = bounds
# Inflate std to 50% of estimate or 10% of bounds, whichever is larger
std_from_estimate = 0.5 * abs(value)
std_from_bounds = 0.1 * (upper - lower)
std = max(std_from_estimate, std_from_bounds)
priors[param_name] = {"mean": value, "std": std}
else:
# Warn mode: decouple from Hessian, use wider priors
logger.warning(
"Suspicious NLSQ convergence. Using safer priors decoupled from Hessian."
)
for param_name in self.parameters.keys():
value = self.parameters.get_value(param_name)
assert value is not None
param = self.parameters.get(param_name)
assert param is not None
bounds = param.bounds
assert bounds is not None
lower, upper = bounds
# Use 20% of bounds range as std
std = 0.2 * (upper - lower)
priors[param_name] = {"mean": value, "std": std}
else: # Good convergence
# Use NLSQ estimates and covariance for prior construction
diagnostics = self._extract_nlsq_diagnostics(self._nlsq_result)
for param_name in self.parameters.keys():
value = self.parameters.get_value(param_name)
assert value is not None
# Get uncertainty from Hessian if available
if param_name in diagnostics["param_uncertainties"]:
std = diagnostics["param_uncertainties"][param_name]
# Cap minimum std to avoid delta-like distributions
min_std = 0.01 * abs(value) if abs(value) > 1e-12 else 1e-6
std = max(std, min_std)
else:
# Fallback: use 5% of parameter value or 5% of bounds
param = self.parameters.get(param_name)
assert param is not None
bounds = param.bounds
assert bounds is not None
lower, upper = bounds
std = max(0.05 * abs(value), 0.05 * (upper - lower))
priors[param_name] = {"mean": value, "std": std}
return priors
[docs]
def get_relaxation_spectrum(self) -> dict:
"""Get discrete relaxation spectrum (E_i, τ_i).
Returns:
Dictionary with 'E_inf', 'E_i', 'tau_i'
"""
symbol = "E" if self._modulus_type == "tensile" else "G"
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = np.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = np.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
return {f"{symbol}_inf": E_inf, f"{symbol}_i": E_i, "tau_i": tau_i}
[docs]
def get_element_minimization_diagnostics(self) -> dict | None:
"""Get element minimization diagnostics.
Returns:
Dictionary with .n_initial., .r2., .n_modes., .n_optimal., .optimization_factor. or None if not run
"""
return self._element_minimization_diagnostics
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
"""Model function for Bayesian inference with NumPyro NUTS.
This method is required by BayesianMixin for NumPyro NUTS sampling.
It computes GMM predictions given input X and a parameter array.
Args:
X: Independent variable (time or frequency)
params: Array of parameter values [E_inf, E_1, ..., E_N, tau_1, ..., tau_N]
Length: 1 + 2*n_modes
Returns:
Model predictions as JAX array
Note:
Uses self._test_mode (set during fit()) to route to appropriate prediction method.
For oscillation mode, returns complex modulus [G', G"] with shape (M, 2).
"""
# Extract parameters from array
E_inf = params[0]
E_i = params[1 : 1 + self._n_modes]
tau_i = params[1 + self._n_modes :]
# Use stored test mode from last fit
if test_mode is None:
test_mode = getattr(self, "_test_mode", "relaxation")
# Route to appropriate prediction method
if test_mode == "relaxation":
return self._predict_relaxation_jit(jnp.asarray(X), E_inf, E_i, tau_i)
elif test_mode == "oscillation":
# _predict_oscillation_jit returns (2, M); transpose to (M, 2)
E_star = self._predict_oscillation_jit(jnp.asarray(X), E_inf, E_i, tau_i)
return E_star.T
elif test_mode == "creep":
return self._predict_creep_jit(
jnp.asarray(X), E_inf, E_i, tau_i, sigma_0=1.0
)
elif test_mode == "steady_shear":
return self._predict_steady_shear_jit(E_inf, E_i, tau_i)
elif test_mode == "startup":
gamma_dot = kwargs.get(
"gamma_dot", getattr(self, "_startup_gamma_dot", 1.0)
)
return self._predict_startup_jit(
jnp.asarray(X), E_inf, E_i, tau_i, gamma_dot
)
elif test_mode == "laos":
omega = kwargs.get("omega", getattr(self, "_laos_omega", 1.0))
gamma_0 = kwargs.get("gamma_0", getattr(self, "_laos_gamma_0", 0.01))
return self._predict_laos_jit(
jnp.asarray(X), E_inf, E_i, tau_i, omega, gamma_0
)
else:
raise ValueError(f"Unsupported test mode: {test_mode}")
# =========================================================================
# Steady-State Flow Protocol
# =========================================================================
def _fit_steady_shear_mode(
self,
gamma_dot: np.ndarray,
eta: np.ndarray,
optimization_factor: float | None = None,
**kwargs,
) -> None:
"""Fit GMM to steady-shear viscosity data.
For a linear viscoelastic model, steady-state viscosity is constant:
η₀ = Σᵢ Gᵢτᵢ (zero-shear viscosity)
Since GMM is linear, it cannot capture shear-thinning. This fit finds
parameters that best match the given viscosity data by using the
zero-shear viscosity relationship.
Args:
gamma_dot: Shear rate array (1/s)
eta: Viscosity array (Pa.s)
optimization_factor: Not used (no element minimization for steady-shear)
**kwargs: NLSQ optimizer arguments
"""
# For linear viscoelastic model, η = η₀ = Σᵢ Gᵢτᵢ (constant)
# Fit by matching average viscosity
eta_avg = np.mean(eta)
symbol = "G" if self._modulus_type == "shear" else "E"
# Initialize with simple estimate: distribute η₀ across modes
eta_per_mode = eta_avg / self._n_modes
tau_i_guess = np.logspace(-2, 2, self._n_modes)
G_i_guess = eta_per_mode / tau_i_guess
# Set parameters
self.parameters.set_value(
f"{symbol}_inf", 0.0
) # No equilibrium modulus for flow
for i in range(self._n_modes):
self.parameters.set_value(f"{symbol}_{i+1}", float(G_i_guess[i]))
self.parameters.set_value(f"tau_{i+1}", float(tau_i_guess[i]))
logger.info(
"GMM fitted to steady-shear mode",
eta_0=eta_avg,
note="Linear model gives constant viscosity η₀=ΣGᵢτᵢ",
)
@staticmethod
@jax.jit
def _predict_steady_shear_jit(
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
) -> jnp_typing.ndarray:
"""JIT-compiled zero-shear viscosity calculation.
η₀ = Σᵢ Gᵢτᵢ
"""
eta_0 = jnp.sum(E_i * tau_i)
return eta_0
def _predict_steady_shear(self, gamma_dot: np.ndarray) -> np.ndarray:
"""Predict steady-shear viscosity (constant for linear model).
Args:
gamma_dot: Shear rate array (ignored for linear model)
Returns:
Viscosity array (constant η₀ for all shear rates)
"""
symbol = "G" if self._modulus_type == "shear" else "E"
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = jnp.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = jnp.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
eta_0 = self._predict_steady_shear_jit(E_inf, E_i, tau_i)
# Return constant viscosity for all shear rates
# Use jnp.full_like to avoid explicit float() conversion (JIT blocker)
return jnp.full_like(jnp.asarray(gamma_dot), eta_0)
# =========================================================================
# Startup Flow Protocol
# =========================================================================
def _fit_startup_mode(
self,
t: np.ndarray,
eta_plus: np.ndarray,
optimization_factor: float | None = 1.5,
gamma_dot: float = 1.0,
**kwargs,
) -> None:
"""Fit GMM to startup flow (stress growth) data.
The stress growth coefficient η⁺(t) = σ(t)/γ̇ for step shear rate.
Args:
t: Time array (s)
eta_plus: Stress growth coefficient η⁺(t) = σ(t)/γ̇ (Pa.s)
optimization_factor: R² threshold for element minimization
gamma_dot: Applied shear rate (1/s) - stored for predictions
**kwargs: NLSQ optimizer arguments
"""
# Store gamma_dot for predictions
self._startup_gamma_dot = gamma_dot
# OPT-YDATA-001: stash y_data so _nlsq_fit attaches it for r_squared.
self._current_y_data = np.asarray(eta_plus)
# Extract kwargs
max_iter = kwargs.get("max_iter", 1000)
ftol = kwargs.get("ftol", 1e-6)
xtol = kwargs.get("xtol", 1e-6)
gtol = kwargs.get("gtol", 1e-6)
symbol = "G" if self._modulus_type == "shear" else "E"
# Define objective function
def objective(params):
"""Residual for startup flow."""
E_inf = params[0]
E_i = params[1 : 1 + self._n_modes]
tau_i = params[1 + self._n_modes :]
eta_plus_pred = self._predict_startup_jit(
jnp.asarray(t), E_inf, E_i, tau_i, gamma_dot
)
return eta_plus_pred - eta_plus
# Initial guess from relaxation behavior
# Use initial_params if provided (for warm-start in element minimization)
initial_params = kwargs.get("initial_params", None)
if initial_params is not None and len(initial_params) == 1 + 2 * self._n_modes:
x0 = jnp.asarray(initial_params)
else:
eta_inf = np.max(eta_plus) # Long-time viscosity
E_i_guess = jnp.full(self._n_modes, eta_inf / self._n_modes / 1.0)
tau_i_guess = jnp.logspace(-2, 2, self._n_modes)
x0 = jnp.concatenate([jnp.array([0.0]), E_i_guess, tau_i_guess])
# Bounds
bounds_lower = jnp.concatenate(
[
jnp.array([0.0]),
jnp.full(self._n_modes, 1e-12),
jnp.full(self._n_modes, 1e-6),
]
)
bounds_upper = jnp.concatenate(
[
jnp.array([np.max(eta_plus) * 10]),
jnp.full(self._n_modes, np.max(eta_plus) * 10),
jnp.full(self._n_modes, 1e6),
]
)
result = self._nlsq_fit(
objective,
x0,
bounds=(bounds_lower, bounds_upper),
max_nfev=max_iter,
ftol=ftol,
xtol=xtol,
gtol=gtol,
)
# Set parameters (batch update for 5-10% speedup)
params_opt = result.x
param_values = {f"{symbol}_inf": float(params_opt[0])}
param_values.update(
{f"{symbol}_{i+1}": float(params_opt[1 + i]) for i in range(self._n_modes)}
)
param_values.update(
{
f"tau_{i+1}": float(params_opt[1 + self._n_modes + i])
for i in range(self._n_modes)
}
)
self.parameters.set_values(param_values)
self._nlsq_result = result
# Element minimization
if optimization_factor is not None and self._n_modes > 1:
self._apply_element_minimization(t, eta_plus, optimization_factor, **kwargs)
@staticmethod
@jax.jit
def _predict_startup_jit(
t: jnp_typing.ndarray,
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
gamma_dot: float,
) -> jnp_typing.ndarray:
"""JIT-compiled startup flow prediction.
Stress growth coefficient: η⁺(t) = σ(t)/γ̇
For Maxwell element: ηᵢ⁺(t) = Gᵢτᵢ(1 - exp(-t/τᵢ))
Total: η⁺(t) = Σᵢ Gᵢτᵢ(1 - exp(-t/τᵢ))
"""
# Each mode contribution: Gᵢτᵢ(1 - exp(-t/τᵢ))
eta_plus = jnp.sum(
E_i[:, None] * tau_i[:, None] * (1 - jnp.exp(-t[None, :] / tau_i[:, None])),
axis=0,
)
return eta_plus
def _predict_startup(self, t: np.ndarray) -> np.ndarray:
"""Predict stress growth coefficient η⁺(t).
Args:
t: Time array (s)
Returns:
Stress growth coefficient η⁺(t) (Pa.s)
"""
symbol = "G" if self._modulus_type == "shear" else "E"
gamma_dot = getattr(self, "_startup_gamma_dot", 1.0)
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = jnp.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = jnp.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
eta_plus = self._predict_startup_jit(
jnp.asarray(t), E_inf, E_i, tau_i, gamma_dot
)
return np.asarray(eta_plus)
# =========================================================================
# LAOS Protocol
# =========================================================================
def _fit_laos_mode(
self,
omega: np.ndarray,
G_star: np.ndarray,
optimization_factor: float | None = 1.5,
gamma_0: float = 0.01,
**kwargs,
) -> None:
"""Fit GMM to LAOS data.
For a linear viscoelastic model, LAOS = SAOS (no nonlinear harmonics).
This delegates to oscillation fitting.
Args:
omega: Angular frequency array (rad/s)
G_star: Complex modulus [G', G''] - same format as oscillation
optimization_factor: R² threshold for element minimization
gamma_0: Strain amplitude (stored for predictions)
**kwargs: NLSQ optimizer arguments
"""
# Store LAOS parameters
self._laos_omega = omega[0] if len(omega) > 0 else 1.0
self._laos_gamma_0 = gamma_0
# For linear model, LAOS = SAOS
logger.info(
"GMM LAOS mode: Linear model gives SAOS response (no nonlinear harmonics)"
)
self._fit_oscillation_mode(omega, G_star, optimization_factor, **kwargs)
@staticmethod
@jax.jit
def _predict_laos_jit(
t: jnp_typing.ndarray,
E_inf: float,
E_i: jnp_typing.ndarray,
tau_i: jnp_typing.ndarray,
omega: float,
gamma_0: float,
) -> jnp_typing.ndarray:
"""JIT-compiled LAOS stress prediction.
For linear viscoelastic model:
γ(t) = γ₀ sin(ωt)
σ(t) = G'γ₀ sin(ωt) + G''γ₀ cos(ωt)
Returns stress(t) array.
"""
# Compute G' and G'' at this frequency
omega_tau = omega * tau_i
omega_tau_sq = omega_tau**2
G_prime = E_inf + jnp.sum(E_i * omega_tau_sq / (1 + omega_tau_sq))
G_double_prime = jnp.sum(E_i * omega_tau / (1 + omega_tau_sq))
# Linear response: σ(t) = G'γ₀ sin(ωt) + G''γ₀ cos(ωt)
stress = G_prime * gamma_0 * jnp.sin(
omega * t
) + G_double_prime * gamma_0 * jnp.cos(omega * t)
return stress
def _predict_laos(self, t: np.ndarray) -> np.ndarray:
"""Predict LAOS stress response.
For linear model, returns sinusoidal stress (no higher harmonics).
Args:
t: Time array (s)
Returns:
Stress response σ(t) (Pa)
"""
symbol = "G" if self._modulus_type == "shear" else "E"
omega = getattr(self, "_laos_omega", 1.0)
gamma_0 = getattr(self, "_laos_gamma_0", 0.01)
E_inf = self.parameters.get_value(f"{symbol}_inf")
E_i = jnp.array(
[self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
)
tau_i = jnp.array(
[self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
)
stress = self._predict_laos_jit(
jnp.asarray(t), E_inf, E_i, tau_i, omega, gamma_0
)
return np.asarray(stress)
[docs]
def simulate_laos(
self,
omega: float,
gamma_0: float,
n_cycles: int = 5,
n_points_per_cycle: int = 64,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Simulate LAOS response.
Args:
omega: Angular frequency (rad/s)
gamma_0: Strain amplitude
n_cycles: Number of oscillation cycles
n_points_per_cycle: Points per cycle
Returns:
t: Time array
strain: Strain array
stress: Stress array
"""
# Store for predictions
self._laos_omega = omega
self._laos_gamma_0 = gamma_0
# Generate time array
period = 2 * np.pi / omega
t = np.linspace(0, n_cycles * period, n_cycles * n_points_per_cycle)
# Strain
strain = gamma_0 * np.sin(omega * t)
# Stress (linear response)
stress = self._predict_laos(t)
return t, strain, stress