Source code for rheojax.models.multimode.generalized_maxwell

"""Generalized Maxwell Model (Prony series) for multi-mode viscoelastic relaxation.

The Generalized Maxwell Model (GMM) extends the single Maxwell element to N modes,
providing a flexible framework for capturing complex relaxation spectra:

    E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ)

Key features:
- Tri-mode equality: relaxation, oscillation, and creep predictions
- Two-step NLSQ fitting with softmax penalty for physical constraints
- Transparent element minimization (auto-optimize N)
- Bayesian inference via NumPyro NUTS with warm-start
- Tiered Bayesian prior safety mechanism (fail-fast on bad NLSQ convergence)
- JIT-compiled predictions for GPU acceleration

References:
    - Park, S. W., & Schapery, R. A. (1999). Methods of interconversion between
      linear viscoelastic material functions. Part I—A numerical method based on
      Prony series. International Journal of Solids and Structures, 36(11), 1653-1675.
    - pyvisco: https://github.com/saintsfan342000/pyvisco
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, cast

import nlsq
import numpy as np

from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol

# Lazy import diffrax for transient simulations (deferred to avoid ~250ms startup cost)
from rheojax.core.jax_config import lazy_import as _lazy_import
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.optimization import OptimizationResult
from rheojax.utils.prony import (
    compute_r_squared,
    create_prony_parameter_set,
    select_optimal_n,
    softmax_penalty,
)

# diffrax is only needed for transient simulations, so it is bound through the
# project's lazy-import helper instead of a regular import; this defers its
# ~250 ms import cost until the module attribute is first used.
diffrax = _lazy_import("diffrax")

# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()

# ``jnp_typing`` is a name for use in annotations: static type checkers see
# the real ``jax.numpy`` module, while at runtime it is simply an alias of
# NumPy so no extra import is performed here.
if TYPE_CHECKING:  # pragma: no cover
    import jax.numpy as jnp_typing
else:
    jnp_typing = np

# Module logger, named after this module.
logger = get_logger(__name__)


[docs] @ModelRegistry.register( "generalized_maxwell", protocols=[ Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, Protocol.FLOW_CURVE, Protocol.STARTUP, Protocol.LAOS, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class GeneralizedMaxwell(BaseModel): """Generalized Maxwell Model with N exponential relaxation modes. The GMM uses Prony series representation for tri-mode viscoelastic behavior: **Relaxation mode:** E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ) **Oscillation mode (closed-form Fourier transform):** E'(ω) = E_∞ + Σᵢ Eᵢ (ωτᵢ)²/(1+(ωτᵢ)²) E"(ω) = Σᵢ Eᵢ (ωτᵢ)/(1+(ωτᵢ)²) **Creep mode (numerical simulation):** J(t) = ε(t)/σ₀ via backward-Euler integration **Performance Optimization (v0.4.0+):** Element minimization workflows use warm-start optimization for 2-5x speedup: - Successive fits initialized from optimal N+1 parameters - Compilation reuse across n_modes iterations - Early termination when R² degrades below threshold - Transparent optimization (no API changes required) - Typical speedup: 20-50s → 4-25s for N=10 element search Parameters: n_modes: Number of relaxation modes (N) modulus_type: 'shear' (G) or 'tensile' (E) Attributes: parameters: ParameterSet containing E_inf, E_i, tau_i (or G equivalents) Example: >>> from rheojax.models.generalized_maxwell import GeneralizedMaxwell >>> import numpy as np >>> model = GeneralizedMaxwell(n_modes=3, modulus_type='shear') >>> t = np.logspace(-3, 2, 50) >>> G_data = ... # Relaxation modulus data >>> model.fit(t, G_data, test_mode='relaxation', optimization_factor=1.5) >>> G_pred = model.predict(t) >>> # Element minimization automatically uses warm-start for 2-5x speedup >>> print(f"Optimal modes: {model._n_modes}") # Auto-reduced from 3 """
def __init__(self, n_modes: int = 3, modulus_type: str = "shear"):
    """Build a Generalized Maxwell model with ``n_modes`` Prony terms.

    Args:
        n_modes: Number of exponential relaxation modes (N ≥ 1).
        modulus_type: ``'shear'`` (G, the default) or ``'tensile'`` (E).

    Raises:
        ValueError: If ``n_modes`` is below 1 or ``modulus_type`` is not
            one of the two supported names.
    """
    super().__init__()

    # Validate the mode count first, then the modulus family.
    if n_modes < 1:
        raise ValueError(f"n_modes must be ≥ 1, got {n_modes}")

    supported_moduli = ("shear", "tensile")
    if modulus_type not in supported_moduli:
        message = f"modulus_type must be 'shear' or 'tensile', got '{modulus_type}'"
        raise ValueError(message)

    # Model configuration.
    self._modulus_type = modulus_type
    self._n_modes = n_modes
    self._test_mode: TestMode | str | None = None

    # Prony parameter set sized for the requested number of modes.
    self.parameters = create_prony_parameter_set(n_modes, modulus_type)

    # Filled in by the fitting routines: the last NLSQ result (used for
    # warm-start and diagnostics) and the element-minimization report.
    self._nlsq_result: OptimizationResult | None = None
    self._element_minimization_diagnostics: dict[str, object] | None = None
def _fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    test_mode: str | None = None,
    optimization_factor: float | None = 1.5,
    **kwargs,
) -> None:
    """Fit the GMM to data by dispatching to the mode-specific fitter.

    Args:
        X: Independent variable (time or frequency).
        y: Dependent variable (modulus or compliance).
        test_mode: One of ``'relaxation'``, ``'oscillation'``, ``'creep'``,
            ``'steady_shear'``, ``'startup'`` or ``'laos'``.
        optimization_factor: R² threshold multiplier for element
            minimization (``None`` to disable).
        **kwargs: Forwarded unchanged to the mode-specific fitter
            (e.g. ``max_iter``, ``ftol``, ``xtol``, ``gtol``).

    Raises:
        ValueError: If ``test_mode`` is ``None`` or not a supported mode.
    """
    if test_mode is None:
        logger.error("test_mode must be specified for GMM fitting")
        raise ValueError("test_mode must be specified for GMM fitting")

    self._test_mode = test_mode

    # (mode name, fitter method name). Names are resolved lazily with
    # getattr so only the fitter that is actually requested is looked up,
    # and modes are matched with ``==`` exactly like an if/elif chain.
    mode_fitters = (
        ("relaxation", "_fit_relaxation_mode"),
        ("oscillation", "_fit_oscillation_mode"),
        ("creep", "_fit_creep_mode"),
        ("steady_shear", "_fit_steady_shear_mode"),
        ("startup", "_fit_startup_mode"),
        ("laos", "_fit_laos_mode"),
    )

    with log_fit(
        logger,
        self.__class__.__name__,
        data_shape=X.shape,
        test_mode=test_mode,
        n_modes=self._n_modes,
        modulus_type=self._modulus_type,
    ) as ctx:
        logger.debug(
            "Processing GMM input data",
            x_range=(float(X.min()), float(X.max())),
            y_range=(float(np.real(y).min()), float(np.real(y).max())),
            optimization_factor=optimization_factor,
        )

        try:
            for mode_name, fitter_name in mode_fitters:
                if test_mode == mode_name:
                    getattr(self, fitter_name)(
                        X, y, optimization_factor=optimization_factor, **kwargs
                    )
                    break
            else:
                logger.error("Unknown test_mode", test_mode=test_mode)
                raise ValueError(f"Unknown test_mode: {test_mode}")
        except Exception as e:
            # Log with traceback, then let the caller see the original error.
            logger.error(
                "GMM fitting failed",
                error_type=type(e).__name__,
                error_message=str(e),
                exc_info=True,
            )
            raise

        # Record the outcome; n_modes may have been reduced by element
        # minimization inside the fitter.
        symbol = "E" if self._modulus_type == "tensile" else "G"
        ctx["n_modes_final"] = self._n_modes
        ctx[f"{symbol}_inf"] = self.parameters.get_value(f"{symbol}_inf")
        logger.debug(
            "GMM fitting completed",
            n_modes_final=self._n_modes,
            modulus_inf=self.parameters.get_value(f"{symbol}_inf"),
        )


def _nlsq_fit(
    self,
    objective,
    x0,
    bounds,
    max_nfev=1000,
    ftol=1e-6,
    xtol=1e-6,
    gtol=1e-6,
    y_data=None,
) -> OptimizationResult:
    """Run one bounded NLSQ (TRF) solve and wrap it as OptimizationResult.

    Args:
        objective: Residual function of the parameter vector.
        x0: Initial parameter guess.
        bounds: ``(lower, upper)`` parameter bounds.
        max_nfev: Maximum function evaluations.
        ftol: Function tolerance.
        xtol: Parameter tolerance.
        gtol: Gradient tolerance.
        y_data: Optional raw dependent-variable array. When provided it is
            attached to the result so ``r_squared`` can be computed. When
            absent, ``objective._y_data`` and then
            ``self._current_y_data`` are tried.

    Returns:
        OptimizationResult with fitted parameters and diagnostics.

    Raises:
        RuntimeError: If the optimizer raises ``ValueError`` (for example
            an initial guess outside the bounds).
    """
    logger.debug(
        "Starting NLSQ optimization",
        n_params=len(x0),
        max_nfev=max_nfev,
        ftol=ftol,
        xtol=xtol,
        gtol=gtol,
    )

    ls = nlsq.LeastSquares()

    try:
        nlsq_result = ls.least_squares(
            objective,
            x0=np.asarray(x0),
            bounds=bounds,
            method="trf",
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
            max_nfev=max_nfev,
            verbose=0,
        )
    except ValueError as e:
        # Handle infeasible initial guess
        logger.error(
            "NLSQ optimization failed with ValueError",
            error_message=str(e),
            exc_info=True,
        )
        raise RuntimeError(
            f"NLSQ optimization failed with error: {e}\n"
            "This may indicate:\n"
            "  1. Data is unsuitable for GMM fitting (e.g., constant values)\n"
            "  2. Initial parameter guess is outside bounds\n"
            "  3. Too many modes for the available data"
        ) from e

    # OPT-YDATA-001: compute residuals at optimum and attach y_data so
    # ``r_squared`` works. Without this the GMM custom path leaves
    # residuals=None and y_data=None, masking fit success as r_squared=None.
    from rheojax.utils.optimization import attach_y_data_to_result

    try:
        _final_res = np.asarray(objective(nlsq_result.x))
        if np.iscomplexobj(_final_res):
            # Stack real and imaginary parts into one real residual vector.
            _final_res = np.concatenate([np.real(_final_res), np.imag(_final_res)])
        _final_res = _final_res.astype(np.float64)
    except Exception:  # pragma: no cover - defensive
        # Residuals are diagnostics only; never fail the fit over them,
        # but leave a trace instead of swallowing the error silently.
        logger.debug(
            "Could not evaluate residuals at the NLSQ optimum", exc_info=True
        )
        _final_res = None

    _grad = getattr(nlsq_result, "grad", None)

    # Convert to OptimizationResult. ``nit`` mirrors ``nfev`` here, exactly
    # as before.
    result = OptimizationResult(
        x=np.asarray(nlsq_result.x),
        fun=nlsq_result.cost,
        jac=np.asarray(nlsq_result.jac) if nlsq_result.jac is not None else None,
        success=nlsq_result.success,
        message=nlsq_result.message,
        nit=nlsq_result.nfev,
        nfev=nlsq_result.nfev,
        njev=getattr(nlsq_result, "njev", 0),
        optimality=getattr(nlsq_result, "optimality", None),
        active_mask=getattr(nlsq_result, "active_mask", None),
        cost=nlsq_result.cost,
        grad=np.asarray(_grad) if _grad is not None else None,
        nlsq_result=nlsq_result,
        residuals=_final_res,
    )

    # Prefer the explicit y_data argument; fall back to one stashed on
    # the objective itself by the caller, then on self (set by the
    # _fit_*_mode method that originated the call).
    _y = y_data if y_data is not None else getattr(objective, "_y_data", None)
    if _y is None:
        _y = getattr(self, "_current_y_data", None)
    attach_y_data_to_result(result, _y)

    logger.debug(
        "NLSQ optimization completed",
        success=result.success,
        cost=result.cost,
        nfev=result.nfev,
        message=result.message,
    )

    return result


def _fit_relaxation_mode(
    self,
    t: np.ndarray,
    E_t: np.ndarray,
    optimization_factor: float | None = 1.5,
    initial_params: np.ndarray | None = None,
    **kwargs,
) -> None:
    """Fit GMM to relaxation modulus data.

    Args:
        t: Time array.
        E_t: Relaxation modulus array.
        optimization_factor: R² threshold multiplier for element
            minimization (``None`` to disable).
        initial_params: Optional initial parameter guess for warm-start.
            Shape: (2*n_modes + 1,) [E_inf, E_1...E_N, tau_1...tau_N].
            If None, a data-derived heuristic guess is used and the
            multi-start retries are enabled.
        **kwargs: NLSQ optimizer arguments (``max_iter``, ``ftol``,
            ``xtol``, ``gtol``) and ``use_log_residuals``.
    """
    # OPT-YDATA-001: stash y_data so _nlsq_fit can attach it to the
    # OptimizationResult and ``r_squared`` works on _nlsq_result.
    self._current_y_data = np.asarray(E_t)

    max_iter = kwargs.get("max_iter", 1000)
    ftol = kwargs.get("ftol", 1e-6)
    xtol = kwargs.get("xtol", 1e-6)
    gtol = kwargs.get("gtol", 1e-6)
    use_log_residuals = kwargs.get("use_log_residuals", False)

    symbol = "E" if self._modulus_type == "tensile" else "G"
    n_modes = self._n_modes

    # Loop-invariant inputs of the objective, converted once. The
    # log-space observation is only needed in log-residual mode.
    t_jax = jnp.asarray(t)
    _log_E_t = (
        jnp.log10(jnp.maximum(jnp.asarray(E_t), 1e-30))
        if use_log_residuals
        else None
    )

    def objective(params):
        """Residual for relaxation modulus.

        Uses log-space residuals when ``use_log_residuals`` is set, so
        that master curves spanning many decades in E(t) weight every
        decade equally instead of being dominated by the glassy plateau.
        """
        E_inf = params[0]
        E_i = params[1 : 1 + n_modes]
        tau_i = params[1 + n_modes :]

        E_pred = self._predict_relaxation_jit(t_jax, E_inf, E_i, tau_i)

        if use_log_residuals:
            return jnp.log10(jnp.maximum(E_pred, 1e-30)) - _log_E_t
        return E_pred - E_t

    # Derive tau range from the actual time data so master curves spanning
    # many decades are fittable (a fixed logspace(-2, 2) grid would
    # truncate any t-range outside [0.01, 100] s).
    t_arr = np.asarray(t)
    t_pos = t_arr[t_arr > 0]
    if t_pos.size > 0:
        log_t_lo = float(np.log10(t_pos.min()))
        log_t_hi = float(np.log10(t_pos.max()))
    else:
        log_t_lo, log_t_hi = -2.0, 2.0
    # Pad by two decades on each side and clamp to a safe numerical floor.
    tau_lo_bound = max(10.0 ** (log_t_lo - 2.0), 1e-30)
    tau_hi_bound = 10.0 ** (log_t_hi + 2.0)

    if n_modes == 1:
        tau_guess_arr = jnp.array([10.0 ** (0.5 * (log_t_lo + log_t_hi))])
    else:
        tau_guess_arr = jnp.logspace(log_t_lo, log_t_hi, n_modes)

    # The heuristic guesses are always computed: they form x0 when no
    # ``initial_params`` is given and are the base point of the
    # multi-start retries below.
    E_inf_guess = jnp.min(E_t)
    E_sum_guess = jnp.max(E_t) - E_inf_guess

    # Derivative-based initial E_i: estimate the contribution from each
    # tau bin using the local drop in E(t) around that tau. This breaks
    # the uniform-guess Jacobian degeneracy at high n_modes.
    E_arr = np.asarray(E_t)
    tau_arr = np.asarray(tau_guess_arr)
    if t_arr.size >= 2 and n_modes > 1:
        order = np.argsort(t_arr)
        t_sorted = t_arr[order]
        E_sorted = E_arr[order]
        log_t = np.log(np.maximum(t_sorted, 1e-30))
        dEdlogt = np.gradient(E_sorted, log_t)
        contrib = np.interp(
            np.log(np.clip(tau_arr, t_sorted[0], t_sorted[-1])),
            log_t,
            -dEdlogt,
        )
        # Keep every seed strictly positive before normalizing.
        contrib = np.clip(contrib, 1e-6, None)
        contrib_sum = float(contrib.sum())
        total = float(E_sum_guess)
        if contrib_sum > 0 and total > 0:
            E_i_guess = jnp.asarray(contrib * (total / contrib_sum))
        else:
            E_i_guess = jnp.full(n_modes, total / max(n_modes, 1))
    else:
        E_i_guess = jnp.full(n_modes, E_sum_guess / max(n_modes, 1))
    tau_i_guess = tau_guess_arr

    if initial_params is not None:
        x0 = jnp.asarray(initial_params)
    else:
        x0 = jnp.concatenate([jnp.array([E_inf_guess]), E_i_guess, tau_i_guess])

    # Parameter bounds — use data-derived tau range (with wide padding) so
    # the optimizer can actually reach the relaxation times in the data.
    E_cap = jnp.max(E_t) * 10
    bounds_lower = jnp.concatenate(
        [
            jnp.array([0.0]),
            jnp.full(n_modes, 1e-12),
            jnp.full(n_modes, tau_lo_bound),
        ]
    )
    bounds_upper = jnp.concatenate(
        [
            jnp.array([E_cap]),
            jnp.full(n_modes, E_cap),
            jnp.full(n_modes, tau_hi_bound),
        ]
    )

    # Step 1: Fit with softmax penalty
    def objective_step1(params):
        """Objective with softmax penalty."""
        E_i = params[1 : 1 + n_modes]
        residual = objective(params)
        penalty = softmax_penalty(E_i, scale=1e-3)
        return jnp.concatenate([residual, jnp.array([penalty])])

    def _run_fit_relax(x_init):
        return self._nlsq_fit(
            objective_step1,
            x_init,
            bounds=(bounds_lower, bounds_upper),
            max_nfev=max_iter,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
        )

    result_step1 = _run_fit_relax(x0)

    # --- Multi-start: Prony fitting has many local minima because
    # adjacent modes overlap in their contributions to E(t). When the
    # caller did not supply ``initial_params`` (and there are at least two
    # modes), perturb the heuristic guess a few times and keep the
    # lowest-cost result. This costs a few extra fits but removes
    # seed-specific bad minima and Jacobian-ridge stalls.
    best_result = result_step1
    if initial_params is None and n_modes >= 2:
        # Fixed seed keeps the retries reproducible.
        rng_retry = np.random.default_rng(0)
        total_E = float(jnp.max(E_t) - jnp.min(E_t))
        base_E = np.asarray(E_i_guess)
        base_tau = np.asarray(tau_i_guess)
        for _attempt in range(4):
            pert_E = rng_retry.uniform(0.3, 3.0, size=n_modes)
            pert_tau = 10.0 ** rng_retry.uniform(-0.5, 0.5, size=n_modes)
            E_init = jnp.asarray(
                np.clip(
                    base_E * pert_E,
                    1e-6 * max(total_E, 1.0),
                    10.0 * max(total_E, 1.0),
                )
            )
            tau_init = jnp.asarray(
                np.clip(base_tau * pert_tau, tau_lo_bound, tau_hi_bound)
            )
            x_retry = jnp.concatenate([jnp.array([E_inf_guess]), E_init, tau_init])
            try:
                result_retry = _run_fit_relax(x_retry)
            except Exception:
                # A failed retry is not fatal; the baseline fit stands.
                continue
            if float(result_retry.cost) < float(best_result.cost):
                best_result = result_retry
    result_step1 = best_result

    # Safety net: the bounds already keep E_i >= 1e-12, so this refit
    # without the penalty only triggers if the optimizer returned a point
    # outside them.
    params_opt = result_step1.x
    E_i_opt = params_opt[1 : 1 + n_modes]

    if jnp.any(E_i_opt < 0):
        logger.warning(
            "Negative Eᵢ detected in relaxation fit. Refitting with hard bounds."
        )

        # Step 2: Refit with hard bounds
        result_final = self._nlsq_fit(
            objective,
            params_opt,
            bounds=(bounds_lower, bounds_upper),
            max_nfev=max_iter,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
        )
        params_opt = result_final.x
    else:
        result_final = result_step1

    # Store NLSQ result
    self._nlsq_result = result_final

    # Set fitted parameters in one batch update.
    E_inf_opt = params_opt[0]
    E_i_opt = params_opt[1 : 1 + n_modes]
    tau_i_opt = params_opt[1 + n_modes :]

    param_values = {f"{symbol}_inf": float(E_inf_opt)}
    param_values.update(
        {f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(n_modes)}
    )
    param_values.update(
        {f"tau_{i+1}": float(tau_i_opt[i]) for i in range(n_modes)}
    )
    self.parameters.set_values(param_values)

    # Element minimization
    if optimization_factor is not None and self._n_modes > 1:
        self._apply_element_minimization(t, E_t, optimization_factor, **kwargs)


def _apply_element_minimization(
    self, X: np.ndarray, y: np.ndarray, optimization_factor: float, **kwargs
) -> None:
    """Reduce the number of modes while R² stays within a tolerance.

    Parameter arrays keep the fixed N_max shape throughout the N-reduction
    loop so the objective does not change shape between iterations.
    Inactive modes are (nearly) frozen through their bounds: E_i is boxed
    into [0, 1e-30] and tau_i into 1 ± 1e-12, so an inactive mode
    contributes essentially nothing to the additive Prony sum and no
    explicit masking is needed.

    The R² floor is derived from the first (N_max) refit:
    ``r2_threshold = r2_max - (1 - r2_max) * (optimization_factor - 1)``.
    The loop stops at the first N whose R² falls below it; the final N is
    chosen by ``select_optimal_n`` from all recorded R² values.

    NOTE(review): the objective here uses plain ``pred - y`` residuals for
    every mode, whereas ``_fit_oscillation_mode`` fits RMS-normalized (or
    log) residuals — confirm this difference is intended.

    Args:
        X: Independent variable (time or frequency).
        y: Dependent variable (modulus or compliance).
            - For relaxation/creep: 1D array of shape (M,)
            - For oscillation: 1D concatenated array [G', G"] of shape (2*M,)
        optimization_factor: R² threshold multiplier (see formula above).
        **kwargs: NLSQ optimizer arguments (``max_iter``, ``ftol``,
            ``xtol``, ``gtol``).

    Raises:
        ValueError: If the current test mode has no padded objective.
    """
    # OPT-YDATA-001: ensure y_data is stashed (may already be set by the
    # caller, but be robust if called directly).
    self._current_y_data = np.asarray(y)

    # Store initial n_modes for diagnostics
    n_max = self._n_modes
    n_initial = n_max

    max_iter = kwargs.get("max_iter", 1000)
    ftol = kwargs.get("ftol", 1e-6)
    xtol = kwargs.get("xtol", 1e-6)
    gtol = kwargs.get("gtol", 1e-6)

    symbol = "E" if self._modulus_type == "tensile" else "G"

    # Convert data to JAX arrays (once)
    X_jax = jnp.asarray(X)
    y_jax = jnp.asarray(y)

    # Compute data-based upper bound for moduli
    E_max = float(jnp.max(jnp.abs(y_jax)) * 10)

    test_mode = self._test_mode

    def _split(params):
        """Split a padded vector into (E_inf, E_i[N_max], tau_i[N_max])."""
        return params[0], params[1 : 1 + n_max], params[1 + n_max :]

    # Padded objective (always N_max-shaped arrays), built once and reused
    # for every n_active value.
    if test_mode == "relaxation":

        def objective(params):
            E_inf, E_i, tau_i = _split(params)
            pred = self._predict_relaxation_jit(X_jax, E_inf, E_i, tau_i)
            return pred - y_jax

    elif test_mode in ("oscillation", "laos"):

        def objective(params):
            E_inf, E_i, tau_i = _split(params)
            pred = self._predict_oscillation_jit(X_jax, E_inf, E_i, tau_i)
            # pred is (2, M); flatten to match the concatenated [G', G"] data.
            return jnp.concatenate([pred[0], pred[1]]) - y_jax

    elif test_mode == "creep":

        def objective(params):
            E_inf, E_i, tau_i = _split(params)
            pred = self._predict_creep_jit(X_jax, E_inf, E_i, tau_i)
            return pred - y_jax

    elif test_mode == "startup":
        gamma_dot = getattr(self, "_startup_gamma_dot", 1.0)

        def objective(params):
            E_inf, E_i, tau_i = _split(params)
            pred = self._predict_startup_jit(X_jax, E_inf, E_i, tau_i, gamma_dot)
            return pred - y_jax

    else:
        raise ValueError(
            f"Element minimization not supported for test_mode: {test_mode}"
        )

    # Softmax penalty wrapper (also fixed shape)
    def objective_step1(params):
        E_i = params[1 : 1 + n_max]
        residual = objective(params)
        penalty = softmax_penalty(E_i, scale=1e-3)
        return jnp.concatenate([residual, jnp.array([penalty])])

    # Get current best params from the initial N_max fit
    if self._nlsq_result is not None:
        current_params = np.asarray(self._nlsq_result.x)
    else:
        E_inf = self.parameters.get_value(f"{symbol}_inf")
        E_i = [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(n_max)]
        tau_i = [self.parameters.get_value(f"tau_{i+1}") for i in range(n_max)]
        current_params = np.array([E_inf] + E_i + tau_i)

    # Active tau box. The default [1e-6, 1e6] is widened, when necessary,
    # to contain the taus of the warm start: the fit methods use
    # data-derived tau bounds, so a fitted tau may legitimately lie outside
    # the default box and would otherwise be silently clipped before the
    # reference (N_max) refit.
    tau_start = np.asarray(current_params[1 + n_max :], dtype=float)
    tau_valid = tau_start[np.isfinite(tau_start) & (tau_start > 0)]
    tau_lo, tau_hi = 1e-6, 1e6
    if tau_valid.size > 0:
        tau_lo = min(tau_lo, float(tau_valid.min()))
        tau_hi = max(tau_hi, float(tau_valid.max()))

    # Iterative N reduction with padded arrays
    fit_results: dict = {}
    best_params = current_params.copy()
    r2_max = None
    r2_threshold = None

    # Base bounds (all modes active). Only the active/inactive boundary
    # changes per iteration, so slices are updated in place below.
    lower = np.zeros(2 * n_max + 1)
    upper = np.zeros(2 * n_max + 1)
    lower[0] = 0.0
    upper[0] = E_max
    lower[1 : 1 + n_max] = 1e-12
    upper[1 : 1 + n_max] = E_max
    lower[1 + n_max :] = tau_lo
    upper[1 + n_max :] = tau_hi

    for n_active in range(n_max, 0, -1):
        try:
            # Freeze modes beyond n_active.
            # E_i bounds: inactive nearly frozen (NLSQ TRF requires lower < upper).
            # E_i < 1e-30 Pa is effectively zero.
            lower[1 + n_active : 1 + n_max] = 0.0
            upper[1 + n_active : 1 + n_max] = 1e-30
            # tau_i bounds: inactive nearly frozen around 1.0.
            lower[1 + n_max + n_active :] = 1.0 - 1e-12
            upper[1 + n_max + n_active :] = 1.0 + 1e-12

            # Warm-start: zero out inactive modes from previous best
            x0 = best_params.copy()
            x0[1 + n_active : 1 + n_max] = 0.0  # Inactive E_i
            x0[1 + n_max + n_active :] = 1.0  # Inactive tau_i

            # Clamp active params to bounds
            x0 = np.clip(x0, lower, upper)

            # Step 1: Fit with softmax penalty
            result = self._nlsq_fit(
                objective_step1,
                x0,
                bounds=(lower, upper),
                max_nfev=max_iter,
                ftol=ftol,
                xtol=xtol,
                gtol=gtol,
            )

            # Check for negative E_i in active modes and refit if needed
            params_opt = result.x
            E_i_active = params_opt[1 : 1 + n_active]

            if jnp.any(E_i_active < 0):
                result = self._nlsq_fit(
                    objective,
                    params_opt,
                    bounds=(lower, upper),
                    max_nfev=max_iter,
                    ftol=ftol,
                    xtol=xtol,
                    gtol=gtol,
                )
                params_opt = result.x

            # objective returns pred - y, so pred = y + residual.
            residual = np.asarray(objective(params_opt))
            y_pred = np.asarray(y) + residual
            r2_n = compute_r_squared(y, y_pred)

            fit_results[n_active] = {
                "r2": r2_n,
                "params": params_opt.copy(),
                "result": result,
            }

            best_params = params_opt.copy()

            # Set R² threshold after first fit (highest N)
            if r2_max is None:
                r2_max = r2_n
                degradation_room = 1.0 - r2_max
                allowed_degradation = degradation_room * (optimization_factor - 1.0)
                r2_threshold = r2_max - allowed_degradation

            # Early termination: stop if R² falls below threshold
            if r2_threshold is not None and r2_n < r2_threshold:
                logger.info(
                    f"Element minimization: early termination at n_modes={n_active} "
                    f"(R²={r2_n:.6f} < threshold={r2_threshold:.6f})"
                )
                break

        except (RuntimeError, ValueError) as e:
            logger.warning(
                f"Element minimization: fitting failed for n_modes={n_active}: {e}"
            )
            break

    # If even the first (N_max) refit failed there is nothing to choose
    # from: keep the model fitted by the caller instead of handing an
    # empty R² table to ``select_optimal_n``.
    if not fit_results:
        logger.warning(
            "Element minimization: no successful fit; keeping current model"
        )
        self._element_minimization_diagnostics = {
            "n_initial": n_initial,
            "r2": [],
            "n_modes": [],
            "n_optimal": n_initial,
            "optimization_factor": optimization_factor,
        }
        return

    # Select optimal N
    r2_values = {n: cast(float, result["r2"]) for n, result in fit_results.items()}
    n_optimal = select_optimal_n(r2_values, optimization_factor=optimization_factor)

    # Store diagnostics with all required keys
    n_modes_list = sorted(r2_values.keys())
    r2_list = [r2_values[n] for n in n_modes_list]

    self._element_minimization_diagnostics = {
        "n_initial": n_initial,
        "r2": r2_list,
        "n_modes": n_modes_list,
        "n_optimal": n_optimal,
        "optimization_factor": optimization_factor,
    }

    # Update model if optimal N is different
    if n_optimal < self._n_modes:
        logger.info(
            f"Element minimization: reducing from {self._n_modes} to {n_optimal} modes"
        )

        # Extract active parameters from padded result
        optimal_params = fit_results[n_optimal]["params"]
        E_inf_opt = optimal_params[0]
        E_i_opt = optimal_params[1 : 1 + n_optimal]
        tau_i_opt = optimal_params[1 + n_max : 1 + n_max + n_optimal]

        # Rebuild ParameterSet with n_optimal modes
        self._n_modes = n_optimal
        self.parameters = create_prony_parameter_set(
            n_optimal, modulus_type=self._modulus_type
        )

        # Set fitted parameter values
        param_values = {f"{symbol}_inf": float(E_inf_opt)}
        param_values.update(
            {f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(n_optimal)}
        )
        param_values.update(
            {f"tau_{i+1}": float(tau_i_opt[i]) for i in range(n_optimal)}
        )
        self.parameters.set_values(param_values)

        # Build slimmed-down NLSQ result for the optimal model. The padded
        # Jacobian / gradient / active mask no longer match the slim
        # parameter vector, so they are dropped.
        slim_x = np.concatenate([[E_inf_opt], E_i_opt, tau_i_opt])
        optimal_result = fit_results[n_optimal]["result"]
        self._nlsq_result = OptimizationResult(
            x=slim_x,
            fun=optimal_result.fun,
            jac=None,
            success=optimal_result.success,
            message=optimal_result.message,
            nit=optimal_result.nit,
            nfev=optimal_result.nfev,
            njev=optimal_result.njev,
            optimality=optimal_result.optimality,
            active_mask=None,
            cost=optimal_result.cost,
            grad=None,
            nlsq_result=optimal_result.nlsq_result,
            # OPT-YDATA-001: forward y_data so r_squared is computable on
            # the slimmed (post-element-minimization) result too.
            residuals=getattr(optimal_result, "residuals", None),
            y_data=getattr(optimal_result, "y_data", None),
            n_data=getattr(optimal_result, "n_data", None),
        )


def _fit_oscillation_mode(
    self,
    omega: np.ndarray,
    E_star: np.ndarray,
    optimization_factor: float | None = 1.5,
    initial_params: np.ndarray | None = None,
    **kwargs,
) -> None:
    """Fit GMM to complex modulus data.

    Args:
        omega: Angular frequency array.
        E_star: Complex modulus. Accepted layouts: complex 1D array,
            real 1D concatenated [E', E"] (even length), (2, M) or (M, 2).
        optimization_factor: R² threshold multiplier for element
            minimization (``None`` to disable).
        initial_params: Optional initial parameter guess for warm-start.
            Shape: (2*n_modes + 1,) [E_inf, E_1...E_N, tau_1...tau_N].
            If None, a data-derived heuristic guess is used and the
            multi-start retries are enabled.
        **kwargs: NLSQ optimizer arguments (``max_iter``, ``ftol``,
            ``xtol``, ``gtol``) and ``use_log_residuals``.

    Raises:
        ValueError: If ``E_star`` has none of the accepted layouts.
    """
    # OPT-YDATA-001: stash y_data so _nlsq_fit attaches it for r_squared.
    # E_star may be complex or (M,2); attach as-is, r_squared handles both.
    E_star = np.asarray(E_star)
    self._current_y_data = E_star

    max_iter = kwargs.get("max_iter", 1000)
    ftol = kwargs.get("ftol", 1e-6)
    xtol = kwargs.get("xtol", 1e-6)
    gtol = kwargs.get("gtol", 1e-6)
    use_log_residuals = kwargs.get("use_log_residuals", False)

    symbol = "E" if self._modulus_type == "tensile" else "G"
    n_modes = self._n_modes

    # Split the input into storage (E') and loss (E") components. The
    # ndim checks make unsupported ranks fail with the documented
    # ValueError instead of an IndexError or a silent mis-parse.
    shape_error = (
        f"E_star must have shape (2, M), (M, 2), or be 1D concatenated [G', G\"], got {E_star.shape}"
    )
    if E_star.ndim == 1:
        if np.iscomplexobj(E_star):
            # Handle complex 1D array [G*_1, G*_2, ..., G*_M]
            E_prime = np.real(E_star)
            E_double_prime = np.imag(E_star)
        else:
            # Handle 1D concatenated [G', G"] from element minimization.
            # An odd length cannot be split into two equal halves.
            if E_star.size % 2 != 0:
                raise ValueError(shape_error)
            M = len(E_star) // 2
            E_prime = np.real(E_star[:M])
            E_double_prime = np.real(E_star[M:])
    elif E_star.ndim == 2 and E_star.shape[0] == 2:
        # Input is (2, M), extract directly
        E_prime = np.real(E_star[0])
        E_double_prime = np.real(E_star[1])
    elif E_star.ndim == 2 and E_star.shape[1] == 2:
        # Input is (M, 2), take columns
        E_prime = np.real(E_star[:, 0])
        E_double_prime = np.real(E_star[:, 1])
    else:
        raise ValueError(shape_error)

    omega_jax = jnp.asarray(omega)

    # Observation-side constants, computed once and only for the residual
    # mode that is actually used.
    _log_Ep = _log_Epp = None
    _Ep_scale = _Epp_scale = None
    if use_log_residuals:
        _log_Ep = jnp.log10(jnp.maximum(jnp.asarray(E_prime), 1e-30))
        _log_Epp = jnp.log10(jnp.maximum(jnp.asarray(E_double_prime), 1e-30))
    else:
        # Per-component scalar normalization for the linear-residual mode.
        # Residuals are divided by the RMS of each component so that E' and
        # E'' contribute with balanced weight regardless of their absolute
        # magnitudes. RMS is preferred over max(|obs|) because it is robust
        # to outliers, and over per-point |obs| because per-point division
        # amplifies noise near the low-magnitude tails of E''.
        _Ep_rms = jnp.sqrt(jnp.mean(jnp.asarray(E_prime) ** 2))
        _Epp_rms = jnp.sqrt(jnp.mean(jnp.asarray(E_double_prime) ** 2))
        _Ep_scale = jnp.maximum(_Ep_rms, jnp.float64(1e-12))
        _Epp_scale = jnp.maximum(_Epp_rms, jnp.float64(1e-12))

    def objective(params):
        """Residual for complex modulus.

        Uses log-space residuals when ``use_log_residuals`` is set, which
        is essential when E'(ω) and E''(ω) span many decades. Otherwise
        uses residuals scaled by the RMS of each observed component,
        (pred − obs) / rms(obs), so that E' and E'' contribute with
        balanced weight; unscaled residuals would let the glassy plateau
        of E' dominate the sum-of-squares.
        """
        E_inf = params[0]
        E_i = params[1 : 1 + n_modes]
        tau_i = params[1 + n_modes :]

        # Prediction is (2, M): row 0 is E', row 1 is E".
        E_star_pred = self._predict_oscillation_jit(omega_jax, E_inf, E_i, tau_i)
        E_prime_pred = E_star_pred[0]
        E_double_prime_pred = E_star_pred[1]

        if use_log_residuals:
            resid_p = jnp.log10(jnp.maximum(E_prime_pred, 1e-30)) - _log_Ep
            resid_pp = jnp.log10(jnp.maximum(E_double_prime_pred, 1e-30)) - _log_Epp
        else:
            resid_p = (E_prime_pred - E_prime) / _Ep_scale
            resid_pp = (E_double_prime_pred - E_double_prime) / _Epp_scale

        return jnp.concatenate([resid_p, resid_pp])

    # Derive tau range from the observed frequency window (τ ≈ 1/ω).
    # Pad ±2 decades beyond the data so optimizer can reach boundary modes.
    omega_np = np.asarray(omega)
    omega_pos = omega_np[omega_np > 0]
    if omega_pos.size > 0:
        log_tau_lo_data = float(-np.log10(omega_pos.max()))
        log_tau_hi_data = float(-np.log10(omega_pos.min()))
    else:
        log_tau_lo_data, log_tau_hi_data = -2.0, 2.0
    tau_lo_bound = max(10.0 ** (log_tau_lo_data - 2.0), 1e-30)
    tau_hi_bound = 10.0 ** (log_tau_hi_data + 2.0)

    if n_modes == 1:
        tau_i_guess = jnp.array([10.0 ** (0.5 * (log_tau_lo_data + log_tau_hi_data))])
    else:
        tau_i_guess = jnp.logspace(log_tau_lo_data, log_tau_hi_data, n_modes)

    # The heuristic guesses are always computed: they form x0 when no
    # ``initial_params`` is given and are the base point of the
    # multi-start retries below.
    E_inf_guess = jnp.min(E_prime)  # Low-frequency plateau
    E_sum_guess = jnp.max(E_prime) - E_inf_guess
    # Seed each E_i from the local storage-modulus derivative
    # dE'/d(ln ω) evaluated at ω_k = 1/τ_k (E' rises with ω, so the
    # derivative is taken with a positive sign).
    omega_sorted_idx = np.argsort(omega_np)
    omega_sorted = omega_np[omega_sorted_idx]
    Ep_sorted = np.asarray(E_prime)[omega_sorted_idx]
    if omega_sorted.size >= 2 and n_modes > 1:
        log_omega = np.log(np.maximum(omega_sorted, 1e-30))
        dEp_dlogw = np.gradient(Ep_sorted, log_omega)
        tau_np = np.asarray(tau_i_guess)
        omega_at_tau = 1.0 / np.clip(tau_np, 1e-30, None)
        contrib = np.interp(
            np.log(np.clip(omega_at_tau, omega_sorted[0], omega_sorted[-1])),
            log_omega,
            dEp_dlogw,
        )
        # Keep every seed strictly positive before normalizing.
        contrib = np.clip(contrib, 1e-6, None)
        contrib_sum = float(contrib.sum())
        total = float(E_sum_guess)
        if contrib_sum > 0 and total > 0:
            E_i_guess = jnp.asarray(contrib * (total / contrib_sum))
        else:
            E_i_guess = jnp.full(n_modes, E_sum_guess / max(n_modes, 1))
    else:
        E_i_guess = jnp.full(n_modes, E_sum_guess / max(n_modes, 1))

    if initial_params is not None:
        x0 = jnp.asarray(initial_params)
    else:
        x0 = jnp.concatenate([jnp.array([E_inf_guess]), E_i_guess, tau_i_guess])

    # Parameter bounds — data-derived tau range so master curves spanning
    # many decades stay inside the box.
    E_cap = jnp.max(E_prime) * 10
    bounds_lower = jnp.concatenate(
        [
            jnp.array([0.0]),
            jnp.full(n_modes, 1e-12),
            jnp.full(n_modes, tau_lo_bound),
        ]
    )
    bounds_upper = jnp.concatenate(
        [
            jnp.array([E_cap]),
            jnp.full(n_modes, E_cap),
            jnp.full(n_modes, tau_hi_bound),
        ]
    )

    # Step 1: Fit with softmax penalty
    def objective_step1(params):
        """Objective with softmax penalty."""
        E_i = params[1 : 1 + n_modes]
        residual = objective(params)
        penalty = softmax_penalty(E_i, scale=1e-3)
        return jnp.concatenate([residual, jnp.array([penalty])])

    def _run_fit(x_init):
        return self._nlsq_fit(
            objective_step1,
            x_init,
            bounds=(bounds_lower, bounds_upper),
            max_nfev=max_iter,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
        )

    result_step1 = _run_fit(x0)

    # --- Multi-start: Prony fitting has many local minima because
    # adjacent modes overlap in their contributions to E*(ω). When the
    # caller did not supply ``initial_params`` (and there are at least two
    # modes), perturb the heuristic guess a few times and keep the
    # lowest-cost result.
    best_result = result_step1
    if initial_params is None and n_modes >= 2:
        # Fixed seed keeps the retries reproducible.
        rng_retry = np.random.default_rng(0)
        total_E = float(jnp.max(E_prime) - jnp.min(E_prime))
        base_E = np.asarray(E_i_guess)
        base_tau = np.asarray(tau_i_guess)
        for _attempt in range(4):
            pert_E = rng_retry.uniform(0.3, 3.0, size=n_modes)
            pert_tau = 10.0 ** rng_retry.uniform(-0.5, 0.5, size=n_modes)
            E_init = jnp.asarray(
                np.clip(base_E * pert_E, 1e-6 * total_E, 10.0 * total_E)
            )
            tau_init = jnp.asarray(
                np.clip(base_tau * pert_tau, tau_lo_bound, tau_hi_bound)
            )
            x_retry = jnp.concatenate([jnp.array([E_inf_guess]), E_init, tau_init])
            try:
                result_retry = _run_fit(x_retry)
            except Exception:
                # A failed retry is not fatal; the baseline fit stands.
                continue
            if float(result_retry.cost) < float(best_result.cost):
                best_result = result_retry
    result_step1 = best_result

    # Safety net: the bounds already keep E_i >= 1e-12, so this refit
    # without the penalty only triggers if the optimizer returned a point
    # outside them.
    params_opt = result_step1.x
    E_i_opt = params_opt[1 : 1 + n_modes]

    if jnp.any(E_i_opt < 0):
        logger.warning(
            "Negative Eᵢ detected in oscillation fit. Refitting with hard bounds."
        )

        # Step 2: Refit with hard bounds
        result_final = self._nlsq_fit(
            objective,
            params_opt,
            bounds=(bounds_lower, bounds_upper),
            max_nfev=max_iter,
            ftol=ftol,
            xtol=xtol,
            gtol=gtol,
        )
        params_opt = result_final.x
    else:
        result_final = result_step1

    # Store NLSQ result
    self._nlsq_result = result_final

    # Set fitted parameters in one batch update.
    E_inf_opt = params_opt[0]
    E_i_opt = params_opt[1 : 1 + n_modes]
    tau_i_opt = params_opt[1 + n_modes :]

    param_values = {f"{symbol}_inf": float(E_inf_opt)}
    param_values.update(
        {f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(n_modes)}
    )
    param_values.update(
        {f"tau_{i+1}": float(tau_i_opt[i]) for i in range(n_modes)}
    )
    self.parameters.set_values(param_values)

    # Element minimization
    if optimization_factor is not None and self._n_modes > 1:
        # Reconstruct combined data for minimization (flatten to 1D)
        combined_data = np.concatenate([E_prime, E_double_prime])
        self._apply_element_minimization(
            omega, combined_data, optimization_factor, **kwargs
        )


def _fit_creep_mode(
    self,
    t: np.ndarray,
    J_t: np.ndarray,
    optimization_factor: float | None = 1.5,
    initial_params: np.ndarray | None = None,
    **kwargs,
) -> None:
    """Fit GMM to creep compliance data.

    Args:
        t: Time array
        J_t: Creep compliance array
        optimization_factor: R² threshold multiplier for element minimization
        initial_params: Optional initial parameter guess for warm-start
            Shape: (2*n_modes + 1,) [J_0, J_1...J_N, tau_1...tau_N]
            If None, uses default heuristic initialization
        **kwargs: NLSQ optimizer arguments
    """
    # OPT-YDATA-001: stash y_data so _nlsq_fit attaches it for r_squared.
self._current_y_data = np.asarray(J_t) # Extract kwargs max_iter = kwargs.get("max_iter", 1000) ftol = kwargs.get("ftol", 1e-6) xtol = kwargs.get("xtol", 1e-6) gtol = kwargs.get("gtol", 1e-6) symbol = "E" if self._modulus_type == "tensile" else "G" # Define objective function (predict creep from GMM simulation) def objective(params): """Residual for creep compliance.""" E_inf = params[0] E_i = params[1 : 1 + self._n_modes] tau_i = params[1 + self._n_modes :] # Predict creep compliance via GMM simulation # Apply step stress σ₀ = 1, solve for strain ε(t), compute J(t) = ε(t)/σ₀ J_pred = self._predict_creep_internal(t, E_inf, E_i, tau_i) return J_pred - J_t # Compute data-based bounds (needed regardless of warm-start) J_0 = jnp.min(J_t) # Initial compliance (instant response) J_inf = jnp.max(J_t) # Final compliance (long-time) # Initial parameter guess (warm-start if provided, else default heuristic) if initial_params is not None: x0 = jnp.asarray(initial_params) else: # For creep: J_0 = 1/(E_∞ + ΣEᵢ), J_∞ = 1/E_∞ # E_∞ corresponds to long-time equilibrium: J_∞ = 1/E_∞ E_inf_guess = 1.0 / J_inf # Total instant modulus: J_0 = 1/(E_∞ + ΣEᵢ) E_total_guess = 1.0 / J_0 E_sum_guess = max(E_total_guess - E_inf_guess, 1e-12) E_i_guess = jnp.full(self._n_modes, E_sum_guess / self._n_modes) tau_i_guess = jnp.logspace(-2, 2, self._n_modes) x0 = jnp.concatenate( [jnp.array([max(E_inf_guess, 1e-12)]), E_i_guess, tau_i_guess] ) # Parameter bounds bounds_lower = jnp.concatenate( [ jnp.array([0.0]), jnp.full(self._n_modes, 1e-12), jnp.full(self._n_modes, 1e-6), ] ) bounds_upper = jnp.concatenate( [ jnp.array([1.0 / J_0 * 10]), jnp.full(self._n_modes, 1.0 / J_0 * 10), jnp.full(self._n_modes, 1e6), ] ) # Step 1: Fit with softmax penalty def objective_step1(params): """Objective with softmax penalty.""" E_i = params[1 : 1 + self._n_modes] residual = objective(params) penalty = softmax_penalty(E_i, scale=1e-3) return jnp.concatenate([residual, jnp.array([penalty])]) result_step1 = 
self._nlsq_fit( objective_step1, x0, bounds=(bounds_lower, bounds_upper), max_nfev=max_iter, ftol=ftol, xtol=xtol, gtol=gtol, ) # Check for negative Eᵢ params_opt = result_step1.x E_i_opt = params_opt[1 : 1 + self._n_modes] if jnp.any(E_i_opt < 0): logger.warning( "Negative Eᵢ detected in creep fit. Refitting with hard bounds." ) # Step 2: Refit with hard bounds result_step2 = self._nlsq_fit( objective, params_opt, bounds=(bounds_lower, bounds_upper), max_nfev=max_iter, ftol=ftol, xtol=xtol, gtol=gtol, ) result_final = result_step2 params_opt = result_final.x else: result_final = result_step1 # Store NLSQ result self._nlsq_result = result_final # Set fitted parameters (batch update for 5-10% speedup) E_inf_opt = params_opt[0] E_i_opt = params_opt[1 : 1 + self._n_modes] tau_i_opt = params_opt[1 + self._n_modes :] param_values = {f"{symbol}_inf": float(E_inf_opt)} param_values.update( {f"{symbol}_{i+1}": float(E_i_opt[i]) for i in range(self._n_modes)} ) param_values.update( {f"tau_{i+1}": float(tau_i_opt[i]) for i in range(self._n_modes)} ) self.parameters.set_values(param_values) # Element minimization if optimization_factor is not None and self._n_modes > 1: self._apply_element_minimization(t, J_t, optimization_factor, **kwargs) def _predict_creep_internal( self, t: np.ndarray | jnp_typing.ndarray, E_inf: float, E_i: jnp_typing.ndarray, tau_i: jnp_typing.ndarray, sigma_0: float = 1.0, ) -> jnp_typing.ndarray: """Internal creep prediction for optimization. Args: t: Time array E_inf: Equilibrium modulus E_i: Prony coefficients (N,) tau_i: Relaxation times (N,) sigma_0: Applied stress (default 1.0) Returns: Creep compliance J(t) """ # Call JIT-compiled creep prediction J_t = self._predict_creep_jit(jnp.asarray(t), E_inf, E_i, tau_i, sigma_0) return J_t def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """Predict based on fitted test mode. 
Args: X: Independent variable (time or frequency) **kwargs: Additional arguments (test_mode handled via self._test_mode) Returns: Predicted values (modulus or compliance) Raises: ValueError: If test_mode not set (model not fitted) """ _kw_mode = kwargs.get("test_mode") test_mode = _kw_mode if _kw_mode is not None else self._test_mode if test_mode is None: raise ValueError("Model not fitted. Call fit() first.") # Normalize test_mode to string if hasattr(test_mode, "value"): test_mode = test_mode.value # Route to appropriate prediction method if test_mode == "relaxation": return self._predict_relaxation(X) elif test_mode == "oscillation": return self._predict_oscillation(X) elif test_mode == "creep": return self._predict_creep(X) elif test_mode in ("steady_shear", "flow_curve"): return self._predict_steady_shear(X) elif test_mode == "startup": return self._predict_startup(X) elif test_mode == "laos": return self._predict_laos(X) else: raise ValueError(f"Unknown test_mode: {test_mode}") @staticmethod @jax.jit def _predict_relaxation_jit( t: jnp_typing.ndarray, E_inf: float, E_i: jnp_typing.ndarray, tau_i: jnp_typing.ndarray, ) -> jnp_typing.ndarray: """JIT-compiled relaxation prediction. Args: t: Time array E_inf: Equilibrium modulus E_i: Prony coefficients (N,) tau_i: Relaxation times (N,) Returns: Relaxation modulus E(t) """ # E(t) = E_∞ + Σᵢ Eᵢ exp(-t/τᵢ) E_t = E_inf + jnp.sum( E_i[:, None] * jnp.exp(-t[None, :] / tau_i[:, None]), axis=0 ) return E_t def _predict_relaxation(self, t: np.ndarray | jnp_typing.ndarray) -> np.ndarray: """Predict relaxation modulus E(t). 
Args: t: Time array Returns: Relaxation modulus array """ symbol = "E" if self._modulus_type == "tensile" else "G" # Extract parameters E_inf = self.parameters.get_value(f"{symbol}_inf") E_i = jnp.array( [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)] ) tau_i = jnp.array( [self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)] ) # Convert input to JAX array t_jax = jnp.asarray(t) # Call JIT-compiled prediction E_t = self._predict_relaxation_jit(t_jax, E_inf, E_i, tau_i) return np.asarray(E_t) @staticmethod @jax.jit def _predict_oscillation_jit( omega: jnp_typing.ndarray, E_inf: float, E_i: jnp_typing.ndarray, tau_i: jnp_typing.ndarray, ) -> jnp_typing.ndarray: """JIT-compiled oscillation prediction. Args: omega: Angular frequency array E_inf: Equilibrium modulus E_i: Prony coefficients (N,) tau_i: Relaxation times (N,) Returns: Complex modulus [E', E"] (2, M) """ # Closed-form Fourier transform omega_tau = omega[None, :] * tau_i[:, None] omega_tau_sq = omega_tau**2 # E'(ω) = E_∞ + Σᵢ Eᵢ (ωτᵢ)²/(1+(ωτᵢ)²) E_prime = E_inf + jnp.sum( E_i[:, None] * omega_tau_sq / (1 + omega_tau_sq), axis=0 ) # E"(ω) = Σᵢ Eᵢ (ωτᵢ)/(1+(ωτᵢ)²) E_double_prime = jnp.sum(E_i[:, None] * omega_tau / (1 + omega_tau_sq), axis=0) # Return as (2, M) for standard complex modulus convention return jnp.stack([E_prime, E_double_prime], axis=0) def _predict_oscillation( self, omega: np.ndarray | jnp_typing.ndarray ) -> np.ndarray: """Predict complex modulus in oscillation mode. 
Args: omega: Angular frequency array Returns: Complex modulus G* = G' + iG'' (or E* for tensile) """ symbol = "E" if self._modulus_type == "tensile" else "G" # Extract parameters E_inf = self.parameters.get_value(f"{symbol}_inf") E_i = jnp.array( [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)] ) tau_i = jnp.array( [self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)] ) # Convert input to JAX array omega_jax = jnp.asarray(omega) # Call JIT-compiled prediction (returns (2, M)) E_star = self._predict_oscillation_jit(omega_jax, E_inf, E_i, tau_i) # Return as complex G* = G' + iG'' (consistent with all other models) E_prime = np.asarray(E_star[0]) E_double_prime = np.asarray(E_star[1]) return E_prime + 1j * E_double_prime @staticmethod @jax.jit def _predict_creep_jit( t: jnp_typing.ndarray, E_inf: float, E_i: jnp_typing.ndarray, tau_i: jnp_typing.ndarray, sigma_0: float = 1.0, ) -> jnp_typing.ndarray: """JIT-compiled creep prediction via backward-Euler. 
Args: t: Time array E_inf: Equilibrium modulus E_i: Prony coefficients (N,) tau_i: Relaxation times (N,) sigma_0: Applied stress (default 1.0) Returns: Creep compliance J(t) """ # Backward-Euler scheme for unconditional stability # GMM ODEs: dσᵢ/dt = -σᵢ/τᵢ + Eᵢ dε/dt # Total stress: σ = E_∞ ε + Σᵢ σᵢ # Apply step stress σ₀, solve for ε(t), compute J(t) = ε(t)/σ₀ n_steps = len(t) n_modes = len(E_i) # Initialize arrays epsilon = jnp.zeros(n_steps) # Time step (assume uniform spacing for now, handle variable later) dt = jnp.diff(t, prepend=0.0) # Backward-Euler update loop def update_step(carry, inputs): """Update internal variables and strain.""" eps_prev, sig_i_prev = carry t_curr, dt_curr = inputs # Protect against zero dt at first step dt_safe = jnp.maximum(dt_curr, 1e-12) # Solve for new strain from total stress balance # σ₀ = E_∞ εⁿ⁺¹ + Σᵢ σᵢⁿ⁺¹ # σᵢⁿ⁺¹ = (σᵢⁿ + Eᵢ Δε) / (1 + Δt/τᵢ) # Substitute and solve for Δε # Coefficients for backward-Euler alpha_i = jnp.exp(-dt_safe / tau_i) # Exact exponential integration beta_i = E_i * tau_i * (1 - alpha_i) / dt_safe # Total effective modulus E_eff = E_inf + jnp.sum(beta_i) # Solve for strain increment stress_from_prev = jnp.sum(alpha_i * sig_i_prev) d_eps = (sigma_0 - stress_from_prev) / E_eff eps_new = eps_prev + d_eps # Update internal stresses sig_i_new = alpha_i * sig_i_prev + beta_i * d_eps return (eps_new, sig_i_new), eps_new # Initialize eps_init = 0.0 sig_i_init = jnp.zeros(n_modes) # Scan over time steps _, epsilon = jax.lax.scan(update_step, (eps_init, sig_i_init), (t, dt)) # Compute compliance J_t = epsilon / sigma_0 return J_t def _predict_creep(self, t: np.ndarray | jnp_typing.ndarray) -> np.ndarray: """Predict creep compliance J(t). 
Args: t: Time array Returns: Creep compliance array """ symbol = "E" if self._modulus_type == "tensile" else "G" # Extract parameters E_inf = self.parameters.get_value(f"{symbol}_inf") E_i = jnp.array( [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)] ) tau_i = jnp.array( [self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)] ) # Convert input to JAX array t_jax = jnp.asarray(t) # Call JIT-compiled prediction J_t = self._predict_creep_jit(t_jax, E_inf, E_i, tau_i, sigma_0=1.0) return np.asarray(J_t) def _extract_nlsq_diagnostics(self, nlsq_result) -> dict: """Extract diagnostics from NLSQ OptimizationResult. Args: nlsq_result: OptimizationResult from nlsq_optimize() Returns: Dictionary with diagnostic metrics """ # Extract convergence flag convergence_flag = nlsq_result.success # Extract gradient norm (optimality metric) gradient_norm = ( nlsq_result.optimality if nlsq_result.optimality is not None else np.inf ) # Estimate Hessian condition number from Jacobian # For least-squares: Hessian ≈ J^T J if nlsq_result.jac is not None: jac = np.asarray(nlsq_result.jac) # Compute approximate Hessian hessian_approx = jac.T @ jac # Compute condition number (ratio of largest/smallest singular values) try: cond_number = np.linalg.cond(hessian_approx) except np.linalg.LinAlgError: cond_number = np.inf else: cond_number = np.inf # Estimate parameter uncertainties from diagonal of covariance matrix # Cov ≈ inv(J^T J) if well-conditioned param_uncertainties = {} symbol = "E" if self._modulus_type == "tensile" else "G" if nlsq_result.jac is not None and cond_number < 1e10: try: # Compute covariance matrix cov_matrix = np.linalg.inv(hessian_approx) std_devs = np.sqrt(np.abs(np.diag(cov_matrix))) # Map to parameter names param_names = [f"{symbol}_inf"] param_names += [f"{symbol}_{i+1}" for i in range(self._n_modes)] param_names += [f"tau_{i+1}" for i in range(self._n_modes)] for i, name in enumerate(param_names): if i < len(std_devs): 
param_uncertainties[name] = float(std_devs[i]) except (np.linalg.LinAlgError, ValueError): # Covariance matrix computation failed pass # Check proximity to bounds params_near_bounds = {} for param_name in self.parameters.keys(): value = self.parameters.get_value(param_name) assert value is not None param = self.parameters.get(param_name) assert param is not None bounds = param.bounds assert bounds is not None lower, upper = bounds # Check if within 10% of bounds bound_range = upper - lower if abs(value - lower) < 0.1 * bound_range: params_near_bounds[param_name] = "lower" elif abs(value - upper) < 0.1 * bound_range: params_near_bounds[param_name] = "upper" return { "convergence_flag": convergence_flag, "gradient_norm": gradient_norm, "hessian_condition": cond_number, "param_uncertainties": param_uncertainties, "params_near_bounds": params_near_bounds, } def _classify_nlsq_convergence(self, diagnostics: dict) -> str: """Classify NLSQ convergence quality. Args: diagnostics: Dictionary from _extract_nlsq_diagnostics() Returns: Classification: 'hard_failure', 'suspicious', or 'good' """ # Hard failure conditions if not diagnostics["convergence_flag"]: return "hard_failure" # GMM-specific: High Hessian condition and params near bounds are often acceptable # Only classify as suspicious if BOTH conditions are true AND uncertainties are high # Check if any uncertainties are > 100% of parameter value (very unreliable) high_uncertainty_count = 0 for param_name, std_dev in diagnostics["param_uncertainties"].items(): value = self.parameters.get_value(param_name) assert value is not None if abs(value) > 1e-12 and std_dev / abs(value) > 1.0: high_uncertainty_count += 1 # Suspicious if: (high condition OR many params near bounds) AND high uncertainties if ( high_uncertainty_count > self._n_modes ): # More than half the parameters are highly uncertain if ( diagnostics["hessian_condition"] > 1e10 or len(diagnostics["params_near_bounds"]) > self._n_modes ): return "suspicious" # 
Good convergence if optimizer says so return "good" def _construct_bayesian_priors( self, classification: str, prior_mode: str = "warn", allow_fallback_priors: bool = False, ) -> dict: """Construct Bayesian priors based on NLSQ convergence classification. Args: classification: 'hard_failure', 'suspicious', or 'good' prior_mode: 'strict', 'warn', or 'auto_widen' allow_fallback_priors: Enable generic priors on hard failure Returns: Dictionary of priors for NumPyro: {param_name: {'mean': float, 'std': float}} Raises: ValueError: If hard failure and prior_mode='strict' or allow_fallback_priors=False """ priors = {} if classification == "hard_failure": # Hard failure: raise error or use fallback priors if prior_mode == "strict" or not allow_fallback_priors: raise ValueError( "NLSQ optimization failed or did not converge properly. " "Cannot construct reliable priors from failed fit. " "Please:\n" " 1. Check model suitability for your data\n" " 2. Adjust initial values or bounds\n" " 3. Increase max_iter if optimization terminated early\n" " 4. Provide manual priors via fit_bayesian(priors={...})\n" " 5. Set allow_fallback_priors=True for generic weakly informative priors (not recommended)" ) # Fallback: generic weakly informative priors warnings.warn( "WARNING: NLSQ optimization failed. Using generic weakly informative priors. " "Results may not be reliable. 
Consider manual prior specification.", UserWarning, stacklevel=2, ) # Use parameter bounds as guides for generic priors for param_name in self.parameters.keys(): param = self.parameters.get(param_name) assert param is not None bounds = param.bounds assert bounds is not None lower, upper = bounds mean = (lower + upper) / 2 std = (upper - lower) / 4 # Wide prior covering ~95% of bounds priors[param_name] = {"mean": mean, "std": std} elif classification == "suspicious": # Suspicious: use safer priors, optionally widen if prior_mode == "auto_widen": warnings.warn( "Suspicious NLSQ convergence detected (high Hessian condition, params near bounds, or high uncertainty). " "Using inflated priors centered at NLSQ estimates.", UserWarning, stacklevel=2, ) # Center at NLSQ, inflate std for param_name in self.parameters.keys(): value = self.parameters.get_value(param_name) assert value is not None param = self.parameters.get(param_name) assert param is not None bounds = param.bounds assert bounds is not None lower, upper = bounds # Inflate std to 50% of estimate or 10% of bounds, whichever is larger std_from_estimate = 0.5 * abs(value) std_from_bounds = 0.1 * (upper - lower) std = max(std_from_estimate, std_from_bounds) priors[param_name] = {"mean": value, "std": std} else: # Warn mode: decouple from Hessian, use wider priors logger.warning( "Suspicious NLSQ convergence. Using safer priors decoupled from Hessian." 
) for param_name in self.parameters.keys(): value = self.parameters.get_value(param_name) assert value is not None param = self.parameters.get(param_name) assert param is not None bounds = param.bounds assert bounds is not None lower, upper = bounds # Use 20% of bounds range as std std = 0.2 * (upper - lower) priors[param_name] = {"mean": value, "std": std} else: # Good convergence # Use NLSQ estimates and covariance for prior construction diagnostics = self._extract_nlsq_diagnostics(self._nlsq_result) for param_name in self.parameters.keys(): value = self.parameters.get_value(param_name) assert value is not None # Get uncertainty from Hessian if available if param_name in diagnostics["param_uncertainties"]: std = diagnostics["param_uncertainties"][param_name] # Cap minimum std to avoid delta-like distributions min_std = 0.01 * abs(value) if abs(value) > 1e-12 else 1e-6 std = max(std, min_std) else: # Fallback: use 5% of parameter value or 5% of bounds param = self.parameters.get(param_name) assert param is not None bounds = param.bounds assert bounds is not None lower, upper = bounds std = max(0.05 * abs(value), 0.05 * (upper - lower)) priors[param_name] = {"mean": value, "std": std} return priors
def get_relaxation_spectrum(self) -> dict:
    """Return the discrete relaxation spectrum currently stored on the model.

    The key prefix follows the modulus type: ``"E"`` when
    ``self._modulus_type == "tensile"``, ``"G"`` otherwise.

    Returns:
        Dictionary with three entries:

        - ``"<symbol>_inf"``: equilibrium modulus as stored in the parameter set
        - ``"<symbol>_i"``: NumPy array with the strength of each of the N modes
        - ``"tau_i"``: NumPy array with the relaxation time of each of the N modes
    """
    prefix = "E" if self._modulus_type == "tensile" else "G"
    read = self.parameters.get_value
    mode_numbers = range(1, self._n_modes + 1)

    # Read order matches the storage layout: plateau, strengths, then times.
    plateau = read(f"{prefix}_inf")
    strengths = np.array([read(f"{prefix}_{k}") for k in mode_numbers])
    times = np.array([read(f"tau_{k}") for k in mode_numbers])

    spectrum = {}
    spectrum[f"{prefix}_inf"] = plateau
    spectrum[f"{prefix}_i"] = strengths
    spectrum["tau_i"] = times
    return spectrum
def get_element_minimization_diagnostics(self) -> dict | None:
    """Return the diagnostics recorded by the element-minimization step.

    Returns:
        The object stored in ``self._element_minimization_diagnostics``,
        returned as-is (no copy). Per the original documentation this is a
        dictionary with the keys ``n_initial``, ``r2``, ``n_modes``,
        ``n_optimal`` and ``optimization_factor``, or ``None`` when element
        minimization has not been run.
    """
    diagnostics = self._element_minimization_diagnostics
    return diagnostics
def model_function(self, X, params, test_mode=None, **kwargs):
    """Model function for Bayesian inference with NumPyro NUTS.

    This method is required by BayesianMixin for NumPyro NUTS sampling.
    It computes GMM predictions given input X and a flat parameter array.

    Args:
        X: Independent variable (time or frequency). Ignored for
            steady shear, where the prediction does not depend on X.
        params: Array of parameter values
            [E_inf, E_1, ..., E_N, tau_1, ..., tau_N], length 1 + 2*n_modes.
        test_mode: Protocol to evaluate. May be a plain string or an
            enum-like object exposing ``.value``. When ``None``, the mode
            stored on the instance as ``_test_mode`` is used (falling back to
            ``"relaxation"`` if the attribute does not exist).
        **kwargs: Optional protocol settings. ``gamma_dot`` is read for
            startup; ``omega`` and ``gamma_0`` are read for LAOS. Each falls
            back to the value stored on the instance, then to a default.

    Returns:
        Model predictions as a JAX array. For oscillation the result has
        shape (M, 2) with columns [G', G"]. For steady shear the result is
        the scalar zero-shear viscosity.

    Raises:
        ValueError: If the resolved test mode is not supported.
    """
    # Unpack the flat parameter vector.
    n_modes = self._n_modes
    E_inf = params[0]
    E_i = params[1 : 1 + n_modes]
    tau_i = params[1 + n_modes :]

    # Use stored test mode from last fit when none is given explicitly.
    if test_mode is None:
        test_mode = getattr(self, "_test_mode", "relaxation")

    # Normalize enum-style modes to their string value, as _predict does;
    # without this an enum member never matches the string comparisons below.
    if hasattr(test_mode, "value"):
        test_mode = test_mode.value

    # Route to the matching JIT-compiled kernel.
    if test_mode == "relaxation":
        return self._predict_relaxation_jit(jnp.asarray(X), E_inf, E_i, tau_i)
    elif test_mode == "oscillation":
        # _predict_oscillation_jit returns (2, M); transpose to (M, 2)
        E_star = self._predict_oscillation_jit(jnp.asarray(X), E_inf, E_i, tau_i)
        return E_star.T
    elif test_mode == "creep":
        return self._predict_creep_jit(
            jnp.asarray(X), E_inf, E_i, tau_i, sigma_0=1.0
        )
    elif test_mode in ("steady_shear", "flow_curve"):
        # "flow_curve" is accepted as an alias, matching _predict.
        return self._predict_steady_shear_jit(E_inf, E_i, tau_i)
    elif test_mode == "startup":
        gamma_dot = kwargs.get(
            "gamma_dot", getattr(self, "_startup_gamma_dot", 1.0)
        )
        return self._predict_startup_jit(
            jnp.asarray(X), E_inf, E_i, tau_i, gamma_dot
        )
    elif test_mode == "laos":
        omega = kwargs.get("omega", getattr(self, "_laos_omega", 1.0))
        gamma_0 = kwargs.get("gamma_0", getattr(self, "_laos_gamma_0", 0.01))
        return self._predict_laos_jit(
            jnp.asarray(X), E_inf, E_i, tau_i, omega, gamma_0
        )
    else:
        raise ValueError(f"Unsupported test mode: {test_mode}")
# =========================================================================
# Steady-State Flow Protocol
# =========================================================================

def _fit_steady_shear_mode(
    self,
    gamma_dot: np.ndarray,
    eta: np.ndarray,
    optimization_factor: float | None = None,
    **kwargs,
) -> None:
    """Fit GMM to steady-shear viscosity data.

    For a linear viscoelastic model, steady-state viscosity is constant:
        η₀ = Σᵢ Gᵢτᵢ (zero-shear viscosity)

    Since GMM is linear, it cannot capture shear-thinning. No optimizer is
    run here: the mean of ``eta`` is split equally across the modes on a
    fixed log-spaced grid of relaxation times (1e-2 to 1e2), and each mode
    strength is chosen so that Gᵢτᵢ equals that equal share. The equilibrium
    modulus is set to 0.

    Args:
        gamma_dot: Shear rate array (1/s). Not read by this method.
        eta: Viscosity array (Pa.s)
        optimization_factor: Not used (no element minimization for steady-shear)
        **kwargs: Accepted for signature compatibility; not read here.
    """
    # For linear viscoelastic model, η = η₀ = Σᵢ Gᵢτᵢ (constant)
    # Fit by matching average viscosity
    eta_avg = np.mean(eta)

    symbol = "G" if self._modulus_type == "shear" else "E"

    # Initialize with simple estimate: distribute η₀ across modes
    eta_per_mode = eta_avg / self._n_modes
    tau_i_guess = np.logspace(-2, 2, self._n_modes)
    # Gᵢ = (η₀/N)/τᵢ so that every mode contributes the same viscosity.
    G_i_guess = eta_per_mode / tau_i_guess

    # Set parameters
    self.parameters.set_value(
        f"{symbol}_inf", 0.0
    )  # No equilibrium modulus for flow
    for i in range(self._n_modes):
        self.parameters.set_value(f"{symbol}_{i+1}", float(G_i_guess[i]))
        self.parameters.set_value(f"tau_{i+1}", float(tau_i_guess[i]))

    logger.info(
        "GMM fitted to steady-shear mode",
        eta_0=eta_avg,
        note="Linear model gives constant viscosity η₀=ΣGᵢτᵢ",
    )

@staticmethod
@jax.jit
def _predict_steady_shear_jit(
    E_inf: float,
    E_i: jnp_typing.ndarray,
    tau_i: jnp_typing.ndarray,
) -> jnp_typing.ndarray:
    """JIT-compiled zero-shear viscosity calculation.

    η₀ = Σᵢ Gᵢτᵢ

    Args:
        E_inf: Equilibrium modulus. Accepted for a uniform signature but
            not used in the computation.
        E_i: Mode strengths (N,)
        tau_i: Relaxation times (N,)

    Returns:
        Scalar sum of the element-wise products ``E_i * tau_i``.
    """
    eta_0 = jnp.sum(E_i * tau_i)
    return eta_0

def _predict_steady_shear(self, gamma_dot: np.ndarray) -> np.ndarray:
    """Predict steady-shear viscosity (constant for linear model).

    Args:
        gamma_dot: Shear rate array. Only its shape and dtype are used;
            the values are ignored for the linear model.

    Returns:
        Array shaped like ``gamma_dot`` filled with the constant η₀.
        NOTE(review): this is produced by ``jnp.full_like`` and is therefore
        a JAX array, whereas the sibling predictors convert with
        ``np.asarray`` — confirm callers accept either.
    """
    symbol = "G" if self._modulus_type == "shear" else "E"

    E_inf = self.parameters.get_value(f"{symbol}_inf")
    E_i = jnp.array(
        [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
    )
    tau_i = jnp.array(
        [self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
    )

    eta_0 = self._predict_steady_shear_jit(E_inf, E_i, tau_i)
    # Return constant viscosity for all shear rates
    # Use jnp.full_like to avoid explicit float() conversion (JIT blocker)
    return jnp.full_like(jnp.asarray(gamma_dot), eta_0)

# =========================================================================
# Startup Flow Protocol
# =========================================================================

def _fit_startup_mode(
    self,
    t: np.ndarray,
    eta_plus: np.ndarray,
    optimization_factor: float | None = 1.5,
    gamma_dot: float = 1.0,
    **kwargs,
) -> None:
    """Fit GMM to startup flow (stress growth) data.

    The stress growth coefficient η⁺(t) = σ(t)/γ̇ for step shear rate.
    A single bounded least-squares fit is run on the residual
    ``prediction - eta_plus``; the fitted values are written to the
    parameter set and the optimizer result is stored on ``_nlsq_result``.

    Args:
        t: Time array (s)
        eta_plus: Stress growth coefficient η⁺(t) = σ(t)/γ̇ (Pa.s)
        optimization_factor: R² threshold for element minimization; pass
            ``None`` to skip that step. It is also skipped for a single mode.
        gamma_dot: Applied shear rate (1/s) - stored for predictions
        **kwargs: NLSQ optimizer arguments. ``max_iter`` (default 1000),
            ``ftol``, ``xtol``, ``gtol`` (default 1e-6 each) and
            ``initial_params`` are read here; the full set is also forwarded
            to element minimization.
    """
    # Store gamma_dot for predictions
    self._startup_gamma_dot = gamma_dot

    # OPT-YDATA-001: stash y_data so _nlsq_fit attaches it for r_squared.
    self._current_y_data = np.asarray(eta_plus)

    # Extract kwargs
    max_iter = kwargs.get("max_iter", 1000)
    ftol = kwargs.get("ftol", 1e-6)
    xtol = kwargs.get("xtol", 1e-6)
    gtol = kwargs.get("gtol", 1e-6)

    symbol = "G" if self._modulus_type == "shear" else "E"

    # Define objective function
    def objective(params):
        """Residual for startup flow."""
        # Parameter layout: [E_inf, E_1..E_N, tau_1..tau_N]
        E_inf = params[0]
        E_i = params[1 : 1 + self._n_modes]
        tau_i = params[1 + self._n_modes :]

        eta_plus_pred = self._predict_startup_jit(
            jnp.asarray(t), E_inf, E_i, tau_i, gamma_dot
        )
        return eta_plus_pred - eta_plus

    # Initial guess from relaxation behavior
    # Use initial_params if provided (for warm-start in element minimization)
    # A warm-start vector of the wrong length is ignored in favour of the
    # heuristic guess.
    initial_params = kwargs.get("initial_params", None)
    if initial_params is not None and len(initial_params) == 1 + 2 * self._n_modes:
        x0 = jnp.asarray(initial_params)
    else:
        eta_inf = np.max(eta_plus)  # Long-time viscosity
        E_i_guess = jnp.full(self._n_modes, eta_inf / self._n_modes / 1.0)
        tau_i_guess = jnp.logspace(-2, 2, self._n_modes)
        x0 = jnp.concatenate([jnp.array([0.0]), E_i_guess, tau_i_guess])

    # Bounds
    # Moduli are capped at 10x the largest observed η⁺; relaxation times are
    # confined to [1e-6, 1e6].
    bounds_lower = jnp.concatenate(
        [
            jnp.array([0.0]),
            jnp.full(self._n_modes, 1e-12),
            jnp.full(self._n_modes, 1e-6),
        ]
    )
    bounds_upper = jnp.concatenate(
        [
            jnp.array([np.max(eta_plus) * 10]),
            jnp.full(self._n_modes, np.max(eta_plus) * 10),
            jnp.full(self._n_modes, 1e6),
        ]
    )

    result = self._nlsq_fit(
        objective,
        x0,
        bounds=(bounds_lower, bounds_upper),
        max_nfev=max_iter,
        ftol=ftol,
        xtol=xtol,
        gtol=gtol,
    )

    # Set parameters (batch update for 5-10% speedup)
    params_opt = result.x
    param_values = {f"{symbol}_inf": float(params_opt[0])}
    param_values.update(
        {f"{symbol}_{i+1}": float(params_opt[1 + i]) for i in range(self._n_modes)}
    )
    param_values.update(
        {
            f"tau_{i+1}": float(params_opt[1 + self._n_modes + i])
            for i in range(self._n_modes)
        }
    )
    self.parameters.set_values(param_values)

    self._nlsq_result = result

    # Element minimization
    if optimization_factor is not None and self._n_modes > 1:
        self._apply_element_minimization(t, eta_plus, optimization_factor, **kwargs)

@staticmethod
@jax.jit
def _predict_startup_jit(
    t: jnp_typing.ndarray,
    E_inf: float,
    E_i: jnp_typing.ndarray,
    tau_i: jnp_typing.ndarray,
    gamma_dot: float,
) -> jnp_typing.ndarray:
    """JIT-compiled startup flow prediction.

    Stress growth coefficient: η⁺(t) = σ(t)/γ̇

    For Maxwell element: ηᵢ⁺(t) = Gᵢτᵢ(1 - exp(-t/τᵢ))
    Total: η⁺(t) = Σᵢ Gᵢτᵢ(1 - exp(-t/τᵢ))

    Args:
        t: Time array
        E_inf: Equilibrium modulus. Not used in the expression below.
        E_i: Mode strengths (N,)
        tau_i: Relaxation times (N,)
        gamma_dot: Applied shear rate. Not used in the expression below,
            since the result is already normalised by the rate.

    Returns:
        η⁺(t), one value per entry of ``t``.
    """
    # Each mode contribution: Gᵢτᵢ(1 - exp(-t/τᵢ))
    # Broadcast modes along axis 0 and time along axis 1, then sum over modes.
    eta_plus = jnp.sum(
        E_i[:, None] * tau_i[:, None] * (1 - jnp.exp(-t[None, :] / tau_i[:, None])),
        axis=0,
    )
    return eta_plus

def _predict_startup(self, t: np.ndarray) -> np.ndarray:
    """Predict stress growth coefficient η⁺(t).

    Reads the current parameters and the shear rate stored by the last
    startup fit (default 1.0 if none was stored).

    Args:
        t: Time array (s)

    Returns:
        Stress growth coefficient η⁺(t) (Pa.s) as a NumPy array
    """
    symbol = "G" if self._modulus_type == "shear" else "E"
    gamma_dot = getattr(self, "_startup_gamma_dot", 1.0)

    E_inf = self.parameters.get_value(f"{symbol}_inf")
    E_i = jnp.array(
        [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
    )
    tau_i = jnp.array(
        [self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
    )

    eta_plus = self._predict_startup_jit(
        jnp.asarray(t), E_inf, E_i, tau_i, gamma_dot
    )
    return np.asarray(eta_plus)

# =========================================================================
# LAOS Protocol
# =========================================================================

def _fit_laos_mode(
    self,
    omega: np.ndarray,
    G_star: np.ndarray,
    optimization_factor: float | None = 1.5,
    gamma_0: float = 0.01,
    **kwargs,
) -> None:
    """Fit GMM to LAOS data.

    For a linear viscoelastic model, LAOS = SAOS (no nonlinear harmonics).
    This delegates to oscillation fitting.

    Args:
        omega: Angular frequency array (rad/s). Its first entry is stored as
            the LAOS frequency for later predictions (1.0 if empty).
        G_star: Complex modulus [G', G''] - same format as oscillation
        optimization_factor: R² threshold for element minimization
        gamma_0: Strain amplitude (stored for predictions)
        **kwargs: NLSQ optimizer arguments, forwarded to the oscillation fit
    """
    # Store LAOS parameters
    self._laos_omega = omega[0] if len(omega) > 0 else 1.0
    self._laos_gamma_0 = gamma_0

    # For linear model, LAOS = SAOS
    logger.info(
        "GMM LAOS mode: Linear model gives SAOS response (no nonlinear harmonics)"
    )
    self._fit_oscillation_mode(omega, G_star, optimization_factor, **kwargs)

@staticmethod
@jax.jit
def _predict_laos_jit(
    t: jnp_typing.ndarray,
    E_inf: float,
    E_i: jnp_typing.ndarray,
    tau_i: jnp_typing.ndarray,
    omega: float,
    gamma_0: float,
) -> jnp_typing.ndarray:
    """JIT-compiled LAOS stress prediction.

    For linear viscoelastic model:
        γ(t) = γ₀ sin(ωt)
        σ(t) = G'γ₀ sin(ωt) + G''γ₀ cos(ωt)

    Args:
        t: Time array
        E_inf: Equilibrium modulus
        E_i: Mode strengths (N,)
        tau_i: Relaxation times (N,)
        omega: Angular frequency of the imposed strain
        gamma_0: Strain amplitude

    Returns stress(t) array.
    """
    # Compute G' and G'' at this frequency
    omega_tau = omega * tau_i
    omega_tau_sq = omega_tau**2

    G_prime = E_inf + jnp.sum(E_i * omega_tau_sq / (1 + omega_tau_sq))
    G_double_prime = jnp.sum(E_i * omega_tau / (1 + omega_tau_sq))

    # Linear response: σ(t) = G'γ₀ sin(ωt) + G''γ₀ cos(ωt)
    stress = G_prime * gamma_0 * jnp.sin(
        omega * t
    ) + G_double_prime * gamma_0 * jnp.cos(omega * t)
    return stress

def _predict_laos(self, t: np.ndarray) -> np.ndarray:
    """Predict LAOS stress response.

    For linear model, returns sinusoidal stress (no higher harmonics).
    Frequency and strain amplitude are taken from the values stored on the
    instance, defaulting to 1.0 and 0.01 when none were stored.

    Args:
        t: Time array (s)

    Returns:
        Stress response σ(t) (Pa) as a NumPy array
    """
    symbol = "G" if self._modulus_type == "shear" else "E"
    omega = getattr(self, "_laos_omega", 1.0)
    gamma_0 = getattr(self, "_laos_gamma_0", 0.01)

    E_inf = self.parameters.get_value(f"{symbol}_inf")
    E_i = jnp.array(
        [self.parameters.get_value(f"{symbol}_{i+1}") for i in range(self._n_modes)]
    )
    tau_i = jnp.array(
        [self.parameters.get_value(f"tau_{i+1}") for i in range(self._n_modes)]
    )

    stress = self._predict_laos_jit(
        jnp.asarray(t), E_inf, E_i, tau_i, omega, gamma_0
    )
    return np.asarray(stress)
def simulate_laos(
    self,
    omega: float,
    gamma_0: float,
    n_cycles: int = 5,
    n_points_per_cycle: int = 64,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Simulate LAOS response.

    The frequency and amplitude are stored on the instance
    (``_laos_omega``, ``_laos_gamma_0``) because ``_predict_laos`` reads them
    from there.

    Args:
        omega: Angular frequency (rad/s)
        gamma_0: Strain amplitude
        n_cycles: Number of oscillation cycles
        n_points_per_cycle: Points per cycle

    Returns:
        t: Time array of length ``n_cycles * n_points_per_cycle`` covering
            ``[0, n_cycles * period)`` with step ``period / n_points_per_cycle``
        strain: Strain array ``gamma_0 * sin(omega * t)``
        stress: Stress array from the linear-response prediction
    """
    # Store for predictions
    self._laos_omega = omega
    self._laos_gamma_0 = gamma_0

    # Generate time array.
    # The endpoint is excluded so that every block of n_points_per_cycle
    # consecutive samples spans exactly one period. Including it would
    # stretch the step to n_cycles*period/(N-1), and the "last cycle" slice
    # used for harmonic extraction would no longer be periodic, which leaks
    # energy into higher FFT bins.
    period = 2 * np.pi / omega
    n_samples = n_cycles * n_points_per_cycle
    t = np.linspace(0, n_cycles * period, n_samples, endpoint=False)

    # Strain
    strain = gamma_0 * np.sin(omega * t)

    # Stress (linear response)
    stress = self._predict_laos(t)

    return t, strain, stress
def extract_harmonics(
    self, stress: np.ndarray, n_points_per_cycle: int
) -> dict[str, float]:
    """Extract Fourier harmonics from LAOS stress signal.

    The last ``n_points_per_cycle`` samples are taken as one full cycle and
    transformed with an FFT; magnitudes are scaled by ``2 / n_points_per_cycle``
    so that a pure sinusoid of amplitude A yields ``I_1 == A``.

    For linear model, only I_1 is non-zero.

    Args:
        stress: Stress signal; must contain at least one full cycle
        n_points_per_cycle: Points per cycle (at least 2)

    Returns:
        Dictionary with I_1, I_3, I_3_I_1 (I_3/I_1 ratio). ``I_3`` is 0.0
        when the cycle has too few samples to hold a third-harmonic bin, and
        ``I_3_I_1`` is 0.0 when ``I_1`` is not above 1e-12.

    Raises:
        ValueError: If ``n_points_per_cycle`` is below 2, or ``stress`` holds
            fewer than ``n_points_per_cycle`` samples.
    """
    if n_points_per_cycle < 2:
        # With fewer than two samples there is no bin for the fundamental;
        # zero would also make the slice below select the whole signal.
        raise ValueError("n_points_per_cycle must be at least 2")

    signal = np.asarray(stress)
    if len(signal) < n_points_per_cycle:
        # A partial cycle would be scaled by the wrong sample count and its
        # FFT bins would not line up with the harmonics.
        raise ValueError(
            "stress must contain at least n_points_per_cycle samples "
            f"(got {len(signal)}, need {n_points_per_cycle})"
        )

    # Use last cycle for analysis
    last_cycle = signal[-n_points_per_cycle:]

    # FFT
    fft_result = jnp.fft.fft(last_cycle)
    magnitudes = jnp.abs(fft_result) / (n_points_per_cycle / 2)

    # Extract harmonics
    I_1 = float(magnitudes[1])  # Fundamental
    I_3 = float(magnitudes[3]) if len(magnitudes) > 3 else 0.0  # Third harmonic

    return {
        "I_1": I_1,
        "I_3": I_3,
        "I_3_I_1": I_3 / I_1 if I_1 > 1e-12 else 0.0,
    }