"""Fractional Zener Solid-Liquid (FZSL) Model.
This model combines a fractional Maxwell element (SpringPot + dashpot) in parallel
with a spring, providing both equilibrium elasticity and fractional relaxation behavior.
Theory
------
The FZSL model consists of:
- Spring (G_e) in parallel with
- Fractional Maxwell element (SpringPot c_alpha + dashpot eta in series)
Relaxation modulus:
G(t) = G_e + c_α * t^(-α) * E_{1-α,1}(-(t/τ)^(1-α))
Complex modulus:
G*(ω) = G_e + c_α * (iω)^α / (1 + (iωτ)^(1-α))
where E_{α,β} is the two-parameter Mittag-Leffler function.
Parameters
----------
Ge : float
Equilibrium modulus (Pa), bounds [1e-3, 1e9]
c_alpha : float
SpringPot constant (Pa·s^α), bounds [1e-3, 1e9]
alpha : float
Fractional order, bounds [0.0, 1.0]
tau : float
Relaxation time (s), bounds [1e-6, 1e6]
Limit Cases
-----------
- alpha → 0: Purely elastic behavior (spring only)
- alpha → 1: Classical Zener solid (two springs and one dashpot)
References
----------
- Mainardi, F. (2010). Fractional Calculus and Waves in Linear Viscoelasticity
- Schiessel, H., et al. (1995). J. Phys. A: Math. Gen. 28, 6567
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger, log_fit
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS
jax, jnp = safe_import_jax()
from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.utils.mittag_leffler import mittag_leffler_e2
# Module-level logger named after this module; the fit routine below uses it
# for its debug and error records.
logger = get_logger(__name__)
@ModelRegistry.register(
    "fractional_zener_sl",
    protocols=[
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FractionalZenerSolidLiquid(BaseModel):
    """Fractional Zener Solid-Liquid model.

    A fractional viscoelastic model combining equilibrium elasticity
    with fractional relaxation behavior.

    Test Modes
    ----------
    - Relaxation: Supported
    - Creep: Supported (heuristic blend between the short-time SpringPot
      compliance and the equilibrium compliance, not an exact inversion)
    - Oscillation: Supported
    - Rotation: Not supported (no steady-state flow)

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.models import FractionalZenerSolidLiquid
    >>>
    >>> # Create model
    >>> model = FractionalZenerSolidLiquid()
    >>>
    >>> # Set parameters
    >>> model.set_params(Ge=1000.0, c_alpha=500.0, alpha=0.5, tau=1.0)
    >>>
    >>> # Predict relaxation modulus (relaxation is the default mode)
    >>> t = jnp.logspace(-2, 2, 50)
    >>> G_t = model.predict(t)
    >>>
    >>> # Predict complex modulus (the mode must be requested explicitly)
    >>> omega = jnp.logspace(-2, 2, 50)
    >>> G_star = model.predict(omega, test_mode="oscillation")
    """

    def __init__(self):
        """Initialize Fractional Zener Solid-Liquid model."""
        super().__init__()

        # Define parameters with bounds and descriptions.  The insertion
        # order (Ge, c_alpha, alpha, tau) is the order in which `_fit` and
        # `model_function` unpack their flat parameter arrays.
        self.parameters = ParameterSet()
        self.parameters.add(
            name="Ge",
            value=1000.0,
            bounds=(1e-3, 1e9),
            units="Pa",
            description="Equilibrium modulus",
        )
        self.parameters.add(
            name="c_alpha",
            value=500.0,
            bounds=(1e-3, 1e9),
            units="Pa·s^α",
            description="SpringPot constant",
        )
        self.parameters.add(
            name="alpha",
            value=0.5,
            bounds=FRACTIONAL_ORDER_BOUNDS,
            units="",
            description="Fractional order",
        )
        self.parameters.add(
            name="tau",
            value=1.0,
            bounds=(1e-6, 1e6),
            units="s",
            description="Relaxation time",
        )

    @staticmethod
    def _mode_name(test_mode):
        """Return the plain name of a test mode.

        Parameters
        ----------
        test_mode : str, enum-like or None
            A mode given as a string, or as an object exposing ``.value``
            (such as the enum member stored by ``_fit``).

        Returns
        -------
        object
            ``test_mode.value`` when that attribute exists, otherwise
            ``test_mode`` unchanged (strings and ``None`` pass through).
        """
        return getattr(test_mode, "value", test_mode)

    @staticmethod
    @jax.jit
    def _predict_relaxation(
        t: jnp.ndarray,
        Ge: float,
        c_alpha: float,
        alpha: float,
        tau: float,
    ) -> jnp.ndarray:
        """Predict relaxation modulus G(t).

        G(t) = G_e + c_α * t^(-α) * E_{1-α,1}(-(t/τ)^(1-α))

        Parameters
        ----------
        t : jnp.ndarray
            Time array (s)
        Ge : float
            Equilibrium modulus (Pa)
        c_alpha : float
            SpringPot constant (Pa·s^α)
        alpha : float
            Fractional order
        tau : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Relaxation modulus G(t) (Pa)
        """
        # Small epsilon keeps alpha strictly inside (0, 1) and tau non-zero.
        epsilon = 1e-12

        # Clip alpha to safe range (works with JAX tracers)
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)

        # Parameters for two-parameter Mittag-Leffler: E_{1-α,1}
        ml_alpha = 1.0 - alpha_safe
        ml_beta = 1.0

        tau_safe = tau + epsilon

        # P2-FRAC-002: Guard t=0 — power(0, -alpha_safe) = +inf when alpha>0.
        t_safe = jnp.maximum(t, 1e-30)

        # Argument of the Mittag-Leffler function: -(t/τ)^(1-α)
        z = -jnp.power(t_safe / tau_safe, ml_alpha)
        ml_term = mittag_leffler_e2(z, alpha=ml_alpha, beta=ml_beta)

        # G(t) = G_e + c_α * t^(-α) * E_{1-α,1}(...)
        fractional_term = c_alpha * jnp.power(t_safe, -alpha_safe) * ml_term
        G_t = Ge + fractional_term
        return G_t

    @staticmethod
    @jax.jit
    def _predict_creep(
        t: jnp.ndarray,
        Ge: float,
        c_alpha: float,
        alpha: float,
        tau: float,
    ) -> jnp.ndarray:
        """Predict creep compliance J(t).

        Note: Analytical creep compliance for FZSL is complex.  This is an
        approximation that blends the short-time SpringPot compliance
        ``t^α / c_α`` into the equilibrium compliance ``1 / G_e`` with a
        sigmoid weight in ``log10(t / τ)``.

        Parameters
        ----------
        t : jnp.ndarray
            Time array (s)
        Ge : float
            Equilibrium modulus (Pa)
        c_alpha : float
            SpringPot constant (Pa·s^α)
        alpha : float
            Fractional order
        tau : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Creep compliance J(t) (1/Pa)
        """
        # Small epsilon guards divisions and the logarithm below.
        epsilon = 1e-12

        # Clip alpha to safe range (works with JAX tracers)
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)

        # Long-time limit: J(∞) = 1/G_e
        J_eq = 1.0 / (Ge + epsilon)

        # Short time: dominated by SpringPot, J(t) ≈ t^α / c_α
        J_short = jnp.power(t, alpha_safe) / (c_alpha + epsilon)

        # Sigmoid weight on a log-time axis centred on t = tau; the /2.0
        # widens the transition.
        x = jnp.log10(t / tau + epsilon) / 2.0
        sigmoid_weight = 1.0 / (1.0 + jnp.exp(-x))

        # Cap the short-time branch at 90% of J_eq so it cannot exceed the
        # equilibrium compliance it is being blended into.
        J_short_scaled = jnp.minimum(J_short, J_eq * 0.9)

        # Blend: start from the (capped) short-time branch, approach J_eq.
        J_t = J_short_scaled * (1.0 - sigmoid_weight) + J_eq * sigmoid_weight
        return J_t

    @staticmethod
    @jax.jit
    def _predict_oscillation(
        omega: jnp.ndarray,
        Ge: float,
        c_alpha: float,
        alpha: float,
        tau: float,
    ) -> jnp.ndarray:
        """Predict complex modulus G*(ω).

        G*(ω) = G_e + c_α * (iω)^α / (1 + (iωτ)^(1-α))

        Parameters
        ----------
        omega : jnp.ndarray
            Angular frequency array (rad/s)
        Ge : float
            Equilibrium modulus (Pa)
        c_alpha : float
            SpringPot constant (Pa·s^α)
        alpha : float
            Fractional order
        tau : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Real array with a trailing axis of length 2: ``[..., 0]`` is the
            storage modulus G' and ``[..., 1]`` is the loss modulus G''.
        """
        # Small epsilon keeps alpha inside (0, 1) and omega, tau positive.
        epsilon = 1e-12

        # Clip alpha to safe range (works with JAX tracers)
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        # beta = 1 - alpha for this model
        beta_safe = 1.0 - alpha_safe

        tau_safe = tau + epsilon
        omega_safe = jnp.maximum(omega, epsilon)

        # (iω)^α = ω^α * exp(i*π*α/2)
        omega_alpha = jnp.power(omega_safe, alpha_safe)
        phase_alpha = jnp.pi * alpha_safe / 2.0
        i_omega_alpha = omega_alpha * (
            jnp.cos(phase_alpha) + 1j * jnp.sin(phase_alpha)
        )

        # (iωτ)^(1-α) = |ωτ|^(1-α) * exp(i*(1-α)*π/2)
        omega_tau = omega_safe * tau_safe
        omega_tau_beta = jnp.power(omega_tau, beta_safe)
        phase_beta = jnp.pi * beta_safe / 2.0
        i_omega_tau_beta = omega_tau_beta * (
            jnp.cos(phase_beta) + 1j * jnp.sin(phase_beta)
        )

        # Denominator: 1 + (iωτ)^(1-α)
        denominator = 1.0 + i_omega_tau_beta

        # Fractional term: c_α * (iω)^α / (1 + (iωτ)^(1-α))
        fractional_term = c_alpha * i_omega_alpha / denominator

        # Total complex modulus
        G_star = Ge + fractional_term

        # Split into storage and loss moduli and stack on a new last axis.
        G_prime = jnp.real(G_star)
        G_double_prime = jnp.imag(G_star)
        return jnp.stack([G_prime, G_double_prime], axis=-1)

    def _fit(
        self, X: jnp.ndarray, y: jnp.ndarray, **kwargs
    ) -> FractionalZenerSolidLiquid:
        """Fit model to data using NLSQ TRF optimization.

        Parameters
        ----------
        X : jnp.ndarray
            Independent variable (time or frequency)
        y : jnp.ndarray
            Dependent variable (modulus or compliance)
        **kwargs : dict
            Additional fitting options.  Keys read here: ``test_mode``
            (string or ``TestMode``; defaults to ``"relaxation"``),
            ``use_jax`` (default ``True``), ``method`` (default ``"auto"``)
            and ``max_iter`` (default ``1000``).

        Returns
        -------
        self
            Fitted model instance

        Raises
        ------
        RuntimeError
            If the optimizer reports that it did not succeed.
        ValueError
            Raised from the objective if ``test_mode`` is a ``TestMode``
            other than relaxation, creep or oscillation.
        """
        from rheojax.core.test_modes import TestMode
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        # An absent or explicit-None test_mode both mean "not provided".
        test_mode_str = kwargs.get("test_mode", "relaxation")
        if test_mode_str is None:
            test_mode_str = "relaxation"

        # Convert string to TestMode enum
        if isinstance(test_mode_str, str):
            test_mode_map = {
                "relaxation": TestMode.RELAXATION,
                "creep": TestMode.CREEP,
                "oscillation": TestMode.OSCILLATION,
            }
            if test_mode_str not in test_mode_map:
                # Keep the historical fallback, but make it visible.
                logger.warning(
                    "Unrecognized test_mode, falling back to relaxation",
                    test_mode=test_mode_str,
                )
            test_mode = test_mode_map.get(test_mode_str, TestMode.RELAXATION)
        else:
            test_mode = test_mode_str

        # Remember the mode so later predict/model_function calls reuse it.
        self._test_mode = test_mode

        # Determine data shape for logging
        data_shape = (len(X),) if hasattr(X, "__len__") else None

        with log_fit(
            logger,
            model="FractionalZenerSolidLiquid",
            data_shape=data_shape,
            test_mode=(
                test_mode_str if isinstance(test_mode_str, str) else str(test_mode)
            ),
        ) as ctx:
            logger.debug(
                "Starting FZSL fit",
                n_points=len(X) if hasattr(X, "__len__") else 1,
                test_mode=str(test_mode),
                initial_params=self.parameters.to_dict(),
            )

            # Smart initialization for oscillation mode (Issue #9).  This is
            # best-effort: any failure is logged and the current parameter
            # values are used as the starting point instead.
            if test_mode == TestMode.OSCILLATION:
                try:
                    import numpy as np

                    from rheojax.utils.initialization import (
                        initialize_fractional_zener_sl,
                    )

                    success = initialize_fractional_zener_sl(
                        np.array(X), np.array(y), self.parameters
                    )
                    if success:
                        logger.debug(
                            "Smart initialization applied from frequency-domain features",
                            initialized_params=self.parameters.to_dict(),
                        )
                except Exception as e:
                    logger.debug(
                        "Smart initialization failed, using defaults",
                        error=str(e),
                    )

            # Create stateless model function for optimization
            def model_fn(x, params):
                """Model function for optimization (stateless)."""
                Ge, c_alpha, alpha, tau = params[0], params[1], params[2], params[3]

                # Direct prediction based on test mode (stateless)
                if test_mode == TestMode.RELAXATION:
                    return self._predict_relaxation(x, Ge, c_alpha, alpha, tau)
                elif test_mode == TestMode.CREEP:
                    return self._predict_creep(x, Ge, c_alpha, alpha, tau)
                elif test_mode == TestMode.OSCILLATION:
                    return self._predict_oscillation(x, Ge, c_alpha, alpha, tau)
                else:
                    raise ValueError(f"Unsupported test mode: {test_mode}")

            # Create objective function
            logger.debug("Creating least squares objective", normalize=True)
            objective = create_least_squares_objective(
                model_fn, jnp.array(X), jnp.array(y), normalize=True
            )

            # Optimize using NLSQ TRF
            logger.debug(
                "Starting NLSQ optimization",
                method=kwargs.get("method", "auto"),
                max_iter=kwargs.get("max_iter", 1000),
            )
            try:
                result = nlsq_optimize(
                    objective,
                    self.parameters,
                    use_jax=kwargs.get("use_jax", True),
                    method=kwargs.get("method", "auto"),
                    max_iter=kwargs.get("max_iter", 1000),
                )
            except Exception as e:
                # Log with traceback, then let the caller see the original error.
                logger.error(
                    "NLSQ optimization raised exception",
                    error_type=type(e).__name__,
                    error=str(e),
                    exc_info=True,
                )
                raise

            # Validate optimization succeeded
            if not result.success:
                logger.error(
                    "Optimization failed",
                    message=result.message,
                    final_params=self.parameters.to_dict(),
                )
                raise RuntimeError(
                    f"Optimization failed: {result.message}. "
                    f"Try adjusting initial values, bounds, or max_iter."
                )

            self.fitted_ = True
            ctx["final_params"] = self.parameters.to_dict()
            ctx["success"] = True
            logger.debug(
                "FZSL fit completed successfully",
                final_params=self.parameters.to_dict(),
            )

        return self

    def _predict(self, X: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Predict response for given input.

        Parameters
        ----------
        X : jnp.ndarray
            Independent variable (time or frequency)
        **kwargs : dict
            ``test_mode`` (string or enum exposing ``.value``) selects the
            response.  When it is absent or ``None`` the mode remembered
            from the last fit is used; if there is none, relaxation is used.

        Returns
        -------
        jnp.ndarray
            Predicted values.  Oscillation mode returns the stacked
            ``[G', G'']`` array produced by ``_predict_oscillation``.
        """
        # Get parameter values
        Ge = self.parameters.get_value("Ge")
        c_alpha = self.parameters.get_value("c_alpha")
        alpha = self.parameters.get_value("alpha")
        tau = self.parameters.get_value("tau")

        # An explicit keyword wins over the mode stored by `_fit`.
        _kw_mode = kwargs.get("test_mode")
        test_mode = (
            _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
        )
        # `_fit` stores an enum member, so reduce it to its string name
        # before comparing with the literals below.
        mode = self._mode_name(test_mode)

        if mode == "oscillation":
            return self._predict_oscillation(X, Ge, c_alpha, alpha, tau)
        elif mode == "creep":
            return self._predict_creep(X, Ge, c_alpha, alpha, tau)
        else:
            # Default to relaxation
            return self._predict_relaxation(X, Ge, c_alpha, alpha, tau)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Parameters
        ----------
        X : array-like
            Independent variable (time or frequency)
        params : array-like
            Parameter values in the order ``[Ge, c_alpha, alpha, tau]``
        test_mode : str or enum-like, optional
            Response to compute.  When ``None``, the mode remembered from
            the last fit is used, or ``"relaxation"`` if there is none.
        **kwargs : dict
            Accepted for interface compatibility; not used here.

        Returns
        -------
        jnp.ndarray
            Model predictions.  Oscillation mode returns the complex
            modulus ``G' + i G''``; unrecognized modes fall back to
            relaxation.
        """
        # Extract parameters from array (in order they were added to ParameterSet)
        Ge = params[0]
        c_alpha = params[1]
        alpha = params[2]
        tau = params[3]

        # Use test_mode from last fit if available, otherwise default to relaxation
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", "relaxation")

        # Normalize test_mode to handle both string and TestMode enum
        mode = self._mode_name(test_mode)

        # Call appropriate prediction function based on test mode
        if mode == "relaxation":
            return self._predict_relaxation(X, Ge, c_alpha, alpha, tau)
        elif mode == "creep":
            return self._predict_creep(X, Ge, c_alpha, alpha, tau)
        elif mode == "oscillation":
            # Recombine the stacked [G', G''] columns into a complex array.
            stacked = self._predict_oscillation(X, Ge, c_alpha, alpha, tau)
            return stacked[..., 0] + 1j * stacked[..., 1]
        else:
            # Default to relaxation mode for FZSL model
            return self._predict_relaxation(X, Ge, c_alpha, alpha, tau)
# Convenience alias: the short name FZSL is bound to the same class object.
FZSL = FractionalZenerSolidLiquid
# Public names of this module: the model class and its short alias.
__all__ = ["FractionalZenerSolidLiquid", "FZSL"]