Source code for rheojax.models.fractional.fractional_zener_ss

"""Fractional Zener Solid-Solid (FZSS) Model.

This model combines two springs and one SpringPot, providing both
instantaneous and equilibrium elasticity with fractional relaxation.

Theory
------
The FZSS model consists of:
- Spring (G_e) in parallel with
- Series combination of spring (G_m) and SpringPot

Relaxation modulus:
    G(t) = G_e + G_m * E_α(-(t/τ_α)^α)

Complex modulus:
    G*(ω) = G_e + G_m / (1 + (iωτ_α)^(-α))

where E_α is the one-parameter Mittag-Leffler function.

Parameters
----------
Ge : float
    Equilibrium modulus (Pa), bounds [1e-3, 1e9]
Gm : float
    Maxwell arm modulus (Pa), bounds [1e-3, 1e9]
alpha : float
    Fractional order, bounds [0.0, 1.0]
tau_alpha : float
    Relaxation time (s^α), bounds [1e-6, 1e6]

Limit Cases
-----------
- alpha → 0: Two springs in parallel (G = G_e + G_m)
- alpha → 1: Classical Zener solid with G(t) = G_e + G_m*exp(-t/τ)

References
----------
- Mainardi, F. (2010). Fractional Calculus and Waves in Linear Viscoelasticity
- Schiessel, H., et al. (1995). J. Phys. A: Math. Gen. 28, 6567
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger, log_fit
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS

jax, jnp = safe_import_jax()


from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.utils.compatibility import format_compatibility_message
from rheojax.utils.mittag_leffler import mittag_leffler_e

# Module logger
logger = get_logger(__name__)


[docs] @ModelRegistry.register( "fractional_zener_ss", protocols=[ Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class FractionalZenerSolidSolid(BaseModel): """Fractional Zener Solid-Solid model. A fractional viscoelastic model with both instantaneous and equilibrium elasticity. Test Modes ---------- - Relaxation: Supported - Creep: Supported - Oscillation: Supported - Rotation: Not supported (no steady-state flow) Examples -------- >>> import jax.numpy as jnp >>> from rheojax.models import FractionalZenerSolidSolid >>> >>> # Create model >>> model = FractionalZenerSolidSolid() >>> >>> # Set parameters >>> model.set_params(Ge=1000.0, Gm=500.0, alpha=0.5, tau_alpha=1.0) >>> >>> # Predict relaxation modulus >>> t = jnp.logspace(-2, 2, 50) >>> G_t = model.predict(t) """
[docs] def __init__(self): """Initialize Fractional Zener Solid-Solid model.""" super().__init__() # Define parameters with bounds and descriptions self.parameters = ParameterSet() # Upper bounds widened to 1e11 Pa so glassy polymers (E_g ~ 1-10 GPa, # G_g ~ 0.4-4 GPa) and DMTA posterior samples do not violate the Ge/Gm # constraints during set_value(). tau_alpha span widened to ±10^10 s^α # to cover master curves spanning ~20 decades after TTS shifting. self.parameters.add( name="Ge", value=1000.0, bounds=(1e-3, 1e11), units="Pa", description="Equilibrium modulus", ) self.parameters.add( name="Gm", value=1000.0, bounds=(1e-3, 1e11), units="Pa", description="Maxwell arm modulus", ) self.parameters.add( name="alpha", value=0.5, bounds=FRACTIONAL_ORDER_BOUNDS, units="", description="Fractional order", ) self.parameters.add( name="tau_alpha", value=1.0, bounds=(1e-10, 1e10), units="s^α", description="Relaxation time", )
@staticmethod @jax.jit def _predict_relaxation_jax( t: jnp.ndarray, Ge: float, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict relaxation modulus G(t) using JAX. G(t) = G_e + G_m * E_α(-(t/τ_α)^α) """ epsilon = 1e-12 # Clip alpha using JAX operations (tracer-safe) alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) tau_alpha_safe = tau_alpha + epsilon # Compute argument: z = -(t/τ_α)^α z = -jnp.power(t / tau_alpha_safe, alpha_safe) # Mittag-Leffler function E_α(z) ml_term = mittag_leffler_e(z, alpha_safe) # G(t) = G_e + G_m * E_α(-(t/τ_α)^α) return Ge + Gm * ml_term def _predict_relaxation( self, t: jnp.ndarray, Ge: float, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict relaxation modulus G(t). Wrapper for JIT-compiled implementation. """ return self._predict_relaxation_jax(t, Ge, Gm, alpha, tau_alpha) @staticmethod @jax.jit def _predict_creep_jax( t: jnp.ndarray, Ge: float, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict creep compliance J(t) using JAX. For FZSS, creep compliance is: J(t) = 1/(G_e + G_m) + (1/G_e - 1/(G_e + G_m)) * (1 - E_α(-(t/τ_α)^α)) """ epsilon = 1e-12 # Clip alpha using JAX operations (tracer-safe) alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) tau_alpha_safe = tau_alpha + epsilon # Instantaneous and equilibrium compliances G_total = Ge + Gm + epsilon J_inst = 1.0 / G_total J_eq = 1.0 / (Ge + epsilon) # Compute argument: z = -(t/τ_α)^α z = -jnp.power(t / tau_alpha_safe, alpha_safe) # Mittag-Leffler function ml_term = mittag_leffler_e(z, alpha_safe) # J(t) = J_inst + (J_eq - J_inst) * (1 - E_α(-t^α/τ_α)) return J_inst + (J_eq - J_inst) * (1.0 - ml_term) def _predict_creep( self, t: jnp.ndarray, Ge: float, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict creep compliance J(t). Wrapper for JIT-compiled implementation. """ return self._predict_creep_jax(t, Ge, Gm, alpha, tau_alpha) @staticmethod @jax.jit def _predict_oscillation_jax( omega: jnp.ndarray, Ge: float, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict complex modulus G*(ω) using JAX. G*(ω) = G_e + G_m / (1 + (iωτ_α)^(-α)) """ epsilon = 1e-12 # Clip alpha using JAX operations (tracer-safe) alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) tau_alpha_safe = tau_alpha + epsilon # Compute (iω)^(-α) = ω^(-α) * exp(-i*π*α/2) omega_neg_alpha = jnp.power(omega, -alpha_safe) phase = -jnp.pi * alpha_safe / 2.0 # (iω)^(-α) in complex form i_omega_neg_alpha = omega_neg_alpha * (jnp.cos(phase) + 1j * jnp.sin(phase)) # Denominator: 1 + (iωτ_α)^(-α) = 1 + τ_α^(-α) * (iω)^(-α) tau_neg_alpha = jnp.power(tau_alpha_safe, -alpha_safe) denominator = 1.0 + tau_neg_alpha * i_omega_neg_alpha # Maxwell arm contribution: G_m / (1 + (iωτ_α)^(-α)) maxwell_term = Gm / denominator # Total complex modulus G_star = Ge + maxwell_term # Extract storage and loss moduli G_prime = jnp.real(G_star) G_double_prime = jnp.imag(G_star) return jnp.stack([G_prime, G_double_prime], axis=-1) def _predict_oscillation( self, omega: jnp.ndarray, Ge: float, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict complex modulus G*(ω). Wrapper for JIT-compiled implementation. """ return self._predict_oscillation_jax(omega, Ge, Gm, alpha, tau_alpha) def _fit( self, X: jnp.ndarray, y: jnp.ndarray, **kwargs ) -> FractionalZenerSolidSolid: """Fit model to data using NLSQ TRF optimization. Parameters ---------- X : jnp.ndarray Independent variable (time or frequency) y : jnp.ndarray Dependent variable (modulus or compliance) **kwargs : dict Additional fitting options Returns ------- self Fitted model instance """ from rheojax.core.test_modes import TestMode from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Detect test mode test_mode_str = kwargs.get("test_mode", "relaxation") # Convert string to TestMode enum if isinstance(test_mode_str, str): test_mode_map = { "relaxation": TestMode.RELAXATION, "creep": TestMode.CREEP, "oscillation": TestMode.OSCILLATION, } test_mode = test_mode_map.get(test_mode_str, TestMode.RELAXATION) else: test_mode = test_mode_str # Store test mode for model_function self._test_mode = test_mode # Provide simple data-aware initialization for relaxation fits compatibility_guard = kwargs.pop("compatibility_guard", False) incompat_confidence = kwargs.pop("compatibility_confidence_threshold", 0.65) compatibility_report = None # Determine data shape for logging data_shape = (len(X),) if hasattr(X, "__len__") else None with log_fit( logger, model="FractionalZenerSolidSolid", data_shape=data_shape, test_mode=( test_mode_str if isinstance(test_mode_str, str) else str(test_mode) ), ) as ctx: logger.debug( "Starting FZSS fit", n_points=len(X) if hasattr(X, "__len__") else 1, test_mode=str(test_mode), initial_params=self.parameters.to_dict(), compatibility_guard=compatibility_guard, ) if test_mode == TestMode.RELAXATION: self._initialize_relaxation_parameters(X, y) logger.debug( "Relaxation parameters initialized", initialized_params=self.parameters.to_dict(), ) if compatibility_guard: compatibility_report = _compute_relaxation_compatibility(self, X, y) if compatibility_report and _should_block_relaxation_fit( compatibility_report, incompat_confidence ): message = format_compatibility_message(compatibility_report) logger.error( "Data incompatible with FZSS model", compatibility_report=compatibility_report, ) raise RuntimeError( "Optimization failed: data is incompatible with " "FractionalZenerSolidSolid.\n" f"Model-data compatibility:\n{message}" ) # Smart initialization for oscillation mode (Issue #9) if test_mode == TestMode.OSCILLATION: try: import numpy as np from rheojax.utils.initialization import ( initialize_fractional_zener_ss, ) success = initialize_fractional_zener_ss( np.array(X), np.array(y), self.parameters ) if success: logger.debug( "Smart initialization applied from frequency-domain features", initialized_params=self.parameters.to_dict(), ) except Exception as e: logger.debug( "Smart initialization failed, using defaults", error=str(e), ) # Create stateless model function for optimization def model_fn(x, params): """Model function for optimization (stateless).""" Ge, Gm, alpha, tau_alpha = params[0], params[1], params[2], params[3] # Direct prediction based on test mode (stateless) if test_mode == TestMode.RELAXATION: return self._predict_relaxation(x, Ge, Gm, alpha, tau_alpha) elif test_mode == TestMode.CREEP: return self._predict_creep(x, Ge, Gm, alpha, tau_alpha) elif test_mode == TestMode.OSCILLATION: return self._predict_oscillation(x, Ge, Gm, alpha, tau_alpha) else: raise ValueError(f"Unsupported test mode: {test_mode}") # Create objective function. Honor ``use_log_residuals`` from # kwargs so the auto-detection in ``BaseModel._detect_optimization_strategy`` # reaches the NLSQ residual builder for wide-range data # (>8 decades in ω). Matches the FractionalMaxwellLiquid pattern. use_log_residuals = kwargs.get("use_log_residuals", False) logger.debug( "Creating least squares objective", normalize=True, use_log_residuals=use_log_residuals, ) objective = create_least_squares_objective( model_fn, jnp.array(X), jnp.array(y), normalize=True, use_log_residuals=use_log_residuals, ) # Optimize using NLSQ TRF logger.debug( "Starting NLSQ optimization", method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) try: result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) except Exception as e: logger.error( "NLSQ optimization raised exception", error_type=type(e).__name__, error=str(e), exc_info=True, ) raise # Validate optimization succeeded if not result.success: logger.error( "Optimization failed", message=result.message, final_params=self.parameters.to_dict(), ) raise RuntimeError( f"Optimization failed: {result.message}. " f"Try adjusting initial values, bounds, or max_iter." ) # Detect incompatible relaxation data even when optimization converges if test_mode == TestMode.RELAXATION and compatibility_guard: if compatibility_report is None: compatibility_report = _compute_relaxation_compatibility(self, X, y) if compatibility_report and _should_block_relaxation_fit( compatibility_report, incompat_confidence ): message = format_compatibility_message(compatibility_report) logger.error( "Post-fit compatibility check failed", compatibility_report=compatibility_report, ) raise RuntimeError( "Optimization failed: data is incompatible with " "FractionalZenerSolidSolid.\n" f"Model-data compatibility:\n{message}" ) self.fitted_ = True ctx["final_params"] = self.parameters.to_dict() ctx["success"] = True logger.debug( "FZSS fit completed successfully", final_params=self.parameters.to_dict(), ) return self def _initialize_relaxation_parameters(self, X, y) -> bool: """Derive heuristic starting values from relaxation data.""" import logging import numpy as np try: t = np.asarray(X, dtype=float).ravel() g = np.asarray(y, dtype=float).ravel() if t.shape != g.shape or t.size < 4: return False order = np.argsort(t) t_sorted = t[order] g_sorted = g[order] tail = max(3, t_sorted.size // 6) head = max(3, t_sorted.size // 6) ge_guess = float(np.median(g_sorted[-tail:])) gm_guess = float(np.median(g_sorted[:head]) - ge_guess) gm_guess = max(gm_guess, 1e-3) ge_param = self.parameters.get("Ge") gm_param = self.parameters.get("Gm") tau_param = self.parameters.get("tau_alpha") alpha_param = self.parameters.get("alpha") assert ge_param is not None and ge_param.bounds is not None assert gm_param is not None and gm_param.bounds is not None assert tau_param is not None and tau_param.bounds is not None assert alpha_param is not None and alpha_param.bounds is not None ge_bounds = ge_param.bounds gm_bounds = gm_param.bounds tau_bounds = tau_param.bounds alpha_bounds = alpha_param.bounds ge_guess = float(np.clip(ge_guess, ge_bounds[0], ge_bounds[1])) gm_guess = float(np.clip(gm_guess, gm_bounds[0], gm_bounds[1])) normalized = np.clip((g_sorted - ge_guess) / (gm_guess + 1e-9), 0.0, 1.0) target = np.exp(-1.0) idx = int(np.argmin(np.abs(normalized - target))) tau_guess = float(np.clip(t_sorted[idx], tau_bounds[0], tau_bounds[1])) alpha_guess = float(np.clip(0.6, alpha_bounds[0], alpha_bounds[1])) self.parameters.set_value("Ge", ge_guess) self.parameters.set_value("Gm", gm_guess) self.parameters.set_value("tau_alpha", tau_guess) self.parameters.set_value("alpha", alpha_guess) logging.debug( "FZSS relaxation init | Ge=%.3g Gm=%.3g tau_alpha=%.3g alpha=%.2f", ge_guess, gm_guess, tau_guess, alpha_guess, ) return True except Exception as exc: # pragma: no cover - fallback only logging.debug(f"Relaxation initialization failed: {exc}") return False def _predict(self, X: jnp.ndarray, **kwargs) -> jnp.ndarray: """Predict response for given input. Parameters ---------- X : jnp.ndarray Independent variable Returns ------- jnp.ndarray Predicted values """ # Get parameter values Ge = self.parameters.get_value("Ge") Gm = self.parameters.get_value("Gm") alpha = self.parameters.get_value("alpha") tau_alpha = self.parameters.get_value("tau_alpha") assert ( Ge is not None and Gm is not None and alpha is not None and tau_alpha is not None ) # Dispatch based on test_mode _kw_mode = kwargs.get("test_mode") test_mode = ( _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None) ) if test_mode in ("oscillation",): return self._predict_oscillation(X, Ge, Gm, alpha, tau_alpha) elif test_mode in ("creep",): return self._predict_creep(X, Ge, Gm, alpha, tau_alpha) else: # Default to relaxation return self._predict_relaxation(X, Ge, Gm, alpha, tau_alpha)
[docs] def model_function(self, X, params, test_mode=None, **kwargs): """Model function for Bayesian inference. This method is required by BayesianMixin for NumPyro NUTS sampling. It computes predictions given input X and a parameter array. Args: X: Independent variable (time or frequency) params: Array of parameter values [Ge, Gm, alpha, tau_alpha] Returns: Model predictions as JAX array """ # Extract parameters from array (in order they were added to ParameterSet) Ge = params[0] Gm = params[1] alpha = params[2] tau_alpha = params[3] # Use test_mode from last fit if available, otherwise default to relaxation if test_mode is None: test_mode = getattr(self, "_test_mode", "relaxation") # Normalize test_mode to handle both string and TestMode enum if hasattr(test_mode, "value"): test_mode = test_mode.value # Call appropriate prediction function based on test mode if test_mode == "relaxation": return self._predict_relaxation(X, Ge, Gm, alpha, tau_alpha) elif test_mode == "creep": return self._predict_creep(X, Ge, Gm, alpha, tau_alpha) elif test_mode == "oscillation": # _predict_oscillation returns stacked (n, 2) for NLSQ fitting, # but Bayesian inference needs complex array for gradient compatibility stacked = self._predict_oscillation(X, Ge, Gm, alpha, tau_alpha) # Convert [G', G''] stacked array to complex G* = G' + i*G'' return stacked[:, 0] + 1j * stacked[:, 1] else: # Default to relaxation mode for FZSS model return self._predict_relaxation(X, Ge, Gm, alpha, tau_alpha)
def _should_block_relaxation_fit(compat: dict, minimum_confidence: float) -> bool: """Return True when compatibility analysis flags obvious mismatches.""" if compat.get("compatible", True): return False if compat.get("confidence", 0.0) < minimum_confidence: return False try: from rheojax.utils.compatibility import DecayType, MaterialType except Exception: # pragma: no cover - defensive guard return False decay_type = compat.get("decay_type") material_type = compat.get("material_type") return ( decay_type == DecayType.EXPONENTIAL or material_type == MaterialType.VISCOELASTIC_LIQUID ) def _compute_relaxation_compatibility(model, X, y) -> dict | None: """Best-effort compatibility evaluation for relaxation data.""" try: import numpy as np from rheojax.utils.compatibility import check_model_compatibility return check_model_compatibility( model, t=np.asarray(X, dtype=float), G_t=np.asarray(y, dtype=float), test_mode="relaxation", ) except Exception: # pragma: no cover - compatibility is heuristic return None # Convenience alias FZSS = FractionalZenerSolidSolid __all__ = ["FractionalZenerSolidSolid", "FZSS"]