"""Fractional Zener Liquid-Liquid (FZLL) Model.
This is the most general fractional Zener model with two SpringPots and one dashpot,
providing maximum flexibility in describing fractional viscoelastic behavior.
Theory
------
The FZLL model consists of:
- Two SpringPots with different fractional orders
- One dashpot
- Complex arrangement providing liquid-like behavior at long times
Complex modulus:
G*(ω) = c_1 * (iω)^α / (1 + (iωτ)^β) + c_2 * (iω)^γ
where all three fractional orders (α, β, γ) can be different.
Parameters
----------
c1 : float
First SpringPot constant (Pa·s^α), bounds [1e-3, 1e9]
c2 : float
Second SpringPot constant (Pa·s^γ), bounds [1e-3, 1e9]
alpha : float
First fractional order, bounds [0.0, 1.0]
beta : float
Second fractional order, bounds [0.0, 1.0]
gamma : float
Third fractional order, bounds [0.0, 1.0]
tau : float
Relaxation time (s), bounds [1e-6, 1e6]
Limit Cases
-----------
- alpha, beta, gamma → 1: Classical viscoelastic liquid
- beta → 0: Simplifies to parallel SpringPots
References
----------
- Mainardi, F. (2010). Fractional Calculus and Waves in Linear Viscoelasticity
- Schiessel, H., et al. (1995). J. Phys. A: Math. Gen. 28, 6567
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger, log_fit
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS
jax, jnp = safe_import_jax()
from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
logger = get_logger(__name__)
[docs]
@ModelRegistry.register(
"fractional_zener_ll",
protocols=[
Protocol.RELAXATION,
Protocol.CREEP,
Protocol.OSCILLATION,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class FractionalZenerLiquidLiquid(BaseModel):
"""Fractional Zener Liquid-Liquid model.
The most general fractional Zener model with three independent
fractional orders.
Test Modes
----------
- Relaxation: Supported (numerical)
- Creep: Supported (numerical)
- Oscillation: Supported (analytical)
- Rotation: Partial support (power-law at high shear rates)
Examples
--------
>>> import jax.numpy as jnp
>>> from rheojax.models import FractionalZenerLiquidLiquid
>>>
>>> # Create model
>>> model = FractionalZenerLiquidLiquid()
>>>
>>> # Set parameters
>>> model.set_params(c1=500.0, c2=100.0, alpha=0.5, beta=0.3, gamma=0.7, tau=1.0)
>>>
>>> # Predict complex modulus
>>> omega = jnp.logspace(-2, 2, 50)
>>> G_star = model.predict(omega)
"""
[docs]
def __init__(self):
"""Initialize Fractional Zener Liquid-Liquid model."""
super().__init__()
# Define parameters with bounds and descriptions
self.parameters = ParameterSet()
self.parameters.add(
name="c1",
value=500.0,
bounds=(1e-3, 1e9),
units="Pa·s^α",
description="First SpringPot constant",
)
self.parameters.add(
name="c2",
value=500.0,
bounds=(1e-3, 1e9),
units="Pa·s^γ",
description="Second SpringPot constant",
)
self.parameters.add(
name="alpha",
value=0.5,
bounds=FRACTIONAL_ORDER_BOUNDS,
units="",
description="First fractional order",
)
self.parameters.add(
name="beta",
value=0.5,
bounds=FRACTIONAL_ORDER_BOUNDS,
units="",
description="Second fractional order",
)
self.parameters.add(
name="gamma",
value=0.5,
bounds=FRACTIONAL_ORDER_BOUNDS,
units="",
description="Third fractional order",
)
self.parameters.add(
name="tau",
value=1.0,
bounds=(1e-6, 1e6),
units="s",
description="Relaxation time",
)
@staticmethod
@jax.jit
def _predict_oscillation(
omega: jnp.ndarray,
c1: float,
c2: float,
alpha: float,
beta: float,
gamma: float,
tau: float,
) -> jnp.ndarray:
"""Predict complex modulus G*(ω).
G*(ω) = c_1 * (iω)^α / (1 + (iωτ)^β) + c_2 * (iω)^γ
"""
epsilon = 1e-12
# Clip fractional orders using JAX operations (tracer-safe)
alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
beta_safe = jnp.clip(beta, epsilon, 1.0 - epsilon)
gamma_safe = jnp.clip(gamma, epsilon, 1.0 - epsilon)
tau_safe = tau + epsilon
# First term: c_1 * (iω)^α / (1 + (iωτ)^β)
# Compute (iω)^α
omega_alpha = jnp.power(omega, alpha_safe)
phase_alpha = jnp.pi * alpha_safe / 2.0
i_omega_alpha = omega_alpha * (jnp.cos(phase_alpha) + 1j * jnp.sin(phase_alpha))
# Compute (iωτ)^β
omega_tau_beta = jnp.power(omega * tau_safe, beta_safe)
phase_beta = jnp.pi * beta_safe / 2.0
i_omega_tau_beta = omega_tau_beta * (
jnp.cos(phase_beta) + 1j * jnp.sin(phase_beta)
)
# First term
term1 = c1 * i_omega_alpha / (1.0 + i_omega_tau_beta)
# Second term: c_2 * (iω)^γ
omega_gamma = jnp.power(omega, gamma_safe)
phase_gamma = jnp.pi * gamma_safe / 2.0
i_omega_gamma = omega_gamma * (jnp.cos(phase_gamma) + 1j * jnp.sin(phase_gamma))
term2 = c2 * i_omega_gamma
# Total complex modulus
G_star = term1 + term2
# Extract storage and loss moduli
G_prime = jnp.real(G_star)
G_double_prime = jnp.imag(G_star)
return jnp.stack([G_prime, G_double_prime], axis=-1)
@staticmethod
@jax.jit
def _predict_relaxation(
t: jnp.ndarray,
c1: float,
c2: float,
alpha: float,
beta: float,
gamma: float,
tau: float,
) -> jnp.ndarray:
"""Predict relaxation modulus G(t).
Approximate form (no exact closed-form in Mittag-Leffler for FZLL):
G(t) ≈ c1 / (1 + (t/τ)^α) + c2 · exp(-t/τ)
The first term captures the power-law decay from the springpot branch,
and the second term captures the exponential decay from the dashpot
branch. This approximation preserves the correct limiting behaviors:
G(0+) = c1 + c2, and G(∞) = 0 (liquid).
"""
epsilon = 1e-12
# Clip fractional orders using JAX operations (tracer-safe)
alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
tau_safe = tau + epsilon
# Use c1 as the main modulus contribution and alpha as the decay exponent
# This breaks the c1/c2 degeneracy by using them differently
# G(t) = c1 / (1 + (t/tau)^alpha) + c2 * exp(-t/tau)
t_tau_ratio = t / tau_safe
# Primary power-law decay term from c1
term1 = c1 / (1.0 + jnp.power(t_tau_ratio, alpha_safe))
# Secondary exponential decay term from c2 (simpler decay)
term2 = c2 * jnp.exp(-t_tau_ratio)
G_t = term1 + term2
return G_t
@staticmethod
@jax.jit
def _predict_creep(
t: jnp.ndarray,
c1: float,
c2: float,
alpha: float,
beta: float,
gamma: float,
tau: float,
) -> jnp.ndarray:
"""Predict creep compliance J(t).
Note: Analytical creep compliance is complex for FZLL.
This provides a numerical approximation.
"""
epsilon = 1e-12
# Clip fractional orders using JAX operations (tracer-safe)
alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
gamma_safe = jnp.clip(gamma, epsilon, 1.0 - epsilon)
# Compute average order
avg_order = (alpha_safe + gamma_safe) / 2.0
# Short time behavior
J_short = jnp.power(t, alpha_safe) / (c1 + epsilon)
# Long time behavior (unbounded growth for liquid)
J_long = jnp.power(t, avg_order) / (c2 + epsilon)
# Crossover
weight = jnp.tanh(t / (tau + epsilon))
J_t = J_short * (1.0 - weight) + J_long * weight
return J_t
def _initialize_relaxation_parameters(self, X, y) -> bool:
"""Derive heuristic starting values from relaxation data.
For FZLL, estimates c1, c2, and tau from the data characteristics.
"""
import numpy as np
logger.debug(
"Attempting relaxation parameter initialization",
data_size=len(X) if hasattr(X, "__len__") else "unknown",
)
try:
t = np.asarray(X, dtype=float).ravel()
g = np.asarray(y, dtype=float).ravel()
if t.shape != g.shape or t.size < 4:
logger.debug(
"Insufficient data for relaxation initialization",
t_shape=t.shape,
g_shape=g.shape,
)
return False
order = np.argsort(t)
t_sorted = t[order]
g_sorted = g[order]
# For liquid model, G(t→∞) → 0
# Formula: G(t) = c1 / (1 + (t/tau)^alpha) + c2 * exp(-t/tau)
# At t=0: G(0) ≈ c1 + c2
# Find where G is at half of its maximum - that gives us tau estimate
g_max = float(np.max(g_sorted))
# At early times, G ≈ c1 / (1 + (t_min/tau)^alpha) + c2 * exp(-t_min/tau)
# If t_min = tau: G = c1/2 + c2*exp(-1) ≈ c1/2 + 0.37*c2
# So g_max ≈ c1/2 for the case where t_min = tau
# Estimate: c1 ≈ 2*g_max, c2 ≈ 0 (small contribution)
c1_guess = g_max * 2.0
c2_guess = 1.0 # Small contribution from exponential term
# Find tau from where G decays to half (c1/(1+1) = c1/2 at t=tau)
# Since g_max ≈ c1/2, we have t_min ≈ tau
tau_guess = float(t_sorted[0])
# Get bounds
c1_param = self.parameters.get("c1")
c2_param = self.parameters.get("c2")
tau_param = self.parameters.get("tau")
assert c1_param is not None and c1_param.bounds is not None
assert c2_param is not None and c2_param.bounds is not None
assert tau_param is not None and tau_param.bounds is not None
c1_bounds = c1_param.bounds
c2_bounds = c2_param.bounds
tau_bounds = tau_param.bounds
c1_guess = float(np.clip(c1_guess, c1_bounds[0], c1_bounds[1]))
c2_guess = float(np.clip(c2_guess, c2_bounds[0], c2_bounds[1]))
tau_guess = float(np.clip(tau_guess, tau_bounds[0], tau_bounds[1]))
self.parameters.set_value("c1", c1_guess)
self.parameters.set_value("c2", c2_guess)
self.parameters.set_value("tau", tau_guess)
# Keep fractional orders at defaults (0.5)
logger.debug(
"FZLL relaxation initialization completed",
c1=c1_guess,
c2=c2_guess,
tau=tau_guess,
g_max=g_max,
)
return True
except Exception as exc:
logger.debug(
"Relaxation initialization failed",
error=str(exc),
exc_info=True,
)
return False
def _fit(
self, X: jnp.ndarray, y: jnp.ndarray, **kwargs
) -> FractionalZenerLiquidLiquid:
"""Fit model to data using NLSQ TRF optimization.
Parameters
----------
X : jnp.ndarray
Independent variable (time or frequency)
y : jnp.ndarray
Dependent variable (modulus or compliance)
**kwargs : dict
Additional fitting options
Returns
-------
self
Fitted model instance
"""
from rheojax.core.test_modes import TestMode
from rheojax.utils.optimization import (
create_least_squares_objective,
nlsq_optimize,
)
# Detect test mode
test_mode_str = kwargs.get("test_mode", "oscillation")
# Convert string to TestMode enum
if isinstance(test_mode_str, str):
test_mode_map = {
"relaxation": TestMode.RELAXATION,
"creep": TestMode.CREEP,
"oscillation": TestMode.OSCILLATION,
}
test_mode = test_mode_map.get(test_mode_str, TestMode.OSCILLATION)
else:
test_mode = test_mode_str
# Store test mode for model_function
self._test_mode = test_mode
logger.info(
"Starting FractionalZenerLiquidLiquid fit",
test_mode=(
test_mode.value if hasattr(test_mode, "value") else str(test_mode)
),
data_shape=X.shape,
)
with log_fit(
logger, model="FractionalZenerLiquidLiquid", data_shape=X.shape
) as ctx:
# Data-aware initialization for relaxation mode
if test_mode == TestMode.RELAXATION:
logger.debug("Applying relaxation-mode parameter initialization")
self._initialize_relaxation_parameters(X, y)
# Smart initialization for oscillation mode (Issue #9)
if test_mode == TestMode.OSCILLATION:
try:
import numpy as np
from rheojax.utils.initialization import (
initialize_fractional_zener_ll,
)
logger.debug(
"Attempting smart initialization for oscillation mode",
data_points=len(X),
)
success = initialize_fractional_zener_ll(
np.array(X), np.array(y), self.parameters
)
if success:
logger.debug(
"Smart initialization applied from frequency-domain features",
c1=self.parameters.get_value("c1"),
c2=self.parameters.get_value("c2"),
alpha=self.parameters.get_value("alpha"),
beta=self.parameters.get_value("beta"),
gamma=self.parameters.get_value("gamma"),
tau=self.parameters.get_value("tau"),
)
except Exception as e:
# Silent fallback to defaults - don't break if initialization fails
logger.debug(
"Smart initialization failed, using defaults",
error=str(e),
exc_info=True,
)
# Create stateless model function for optimization
def model_fn(x, params):
"""Model function for optimization (stateless)."""
c1, c2, alpha, beta, gamma, tau = (
params[0],
params[1],
params[2],
params[3],
params[4],
params[5],
)
# Direct prediction based on test mode (stateless)
if test_mode == TestMode.RELAXATION:
return self._predict_relaxation(x, c1, c2, alpha, beta, gamma, tau)
elif test_mode == TestMode.CREEP:
return self._predict_creep(x, c1, c2, alpha, beta, gamma, tau)
elif test_mode == TestMode.OSCILLATION:
return self._predict_oscillation(x, c1, c2, alpha, beta, gamma, tau)
else:
raise ValueError(f"Unsupported test mode: {test_mode}")
# Create objective function
logger.debug("Creating least squares objective function")
objective = create_least_squares_objective(
model_fn, jnp.array(X), jnp.array(y), normalize=True
)
# Optimize using NLSQ TRF
logger.debug(
"Starting NLSQ optimization",
method=kwargs.get("method", "auto"),
max_iter=kwargs.get("max_iter", 1000),
)
result = nlsq_optimize(
objective,
self.parameters,
use_jax=kwargs.get("use_jax", True),
method=kwargs.get("method", "auto"),
max_iter=kwargs.get("max_iter", 1000),
)
# Validate optimization succeeded
if not result.success:
logger.error(
"NLSQ optimization failed",
message=result.message,
exc_info=True,
)
raise RuntimeError(
f"Optimization failed: {result.message}. "
f"Try adjusting initial values, bounds, or max_iter."
)
self.fitted_ = True
ctx["success"] = True
ctx["fitted_params"] = {
"c1": self.parameters.get_value("c1"),
"c2": self.parameters.get_value("c2"),
"alpha": self.parameters.get_value("alpha"),
"beta": self.parameters.get_value("beta"),
"gamma": self.parameters.get_value("gamma"),
"tau": self.parameters.get_value("tau"),
}
logger.info(
"FractionalZenerLiquidLiquid fit completed",
c1=self.parameters.get_value("c1"),
c2=self.parameters.get_value("c2"),
alpha=self.parameters.get_value("alpha"),
beta=self.parameters.get_value("beta"),
gamma=self.parameters.get_value("gamma"),
tau=self.parameters.get_value("tau"),
)
return self
def _predict(self, X: jnp.ndarray, **kwargs) -> jnp.ndarray:
"""Predict response for given input.
Parameters
----------
X : jnp.ndarray
Independent variable
Returns
-------
jnp.ndarray
Predicted values
"""
# Get parameter values
c1 = self.parameters.get_value("c1")
c2 = self.parameters.get_value("c2")
alpha = self.parameters.get_value("alpha")
beta = self.parameters.get_value("beta")
gamma = self.parameters.get_value("gamma")
tau = self.parameters.get_value("tau")
# Dispatch based on test_mode
_kw_mode = kwargs.get("test_mode")
test_mode = (
_kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
)
if test_mode in ("oscillation",):
return self._predict_oscillation(X, c1, c2, alpha, beta, gamma, tau)
elif test_mode in ("creep",):
return self._predict_creep(X, c1, c2, alpha, beta, gamma, tau)
else:
# Default to relaxation
return self._predict_relaxation(X, c1, c2, alpha, beta, gamma, tau)
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
"""Model function for Bayesian inference.
This method is required by BayesianMixin for NumPyro NUTS sampling.
It computes predictions given input X and a parameter array.
Args:
X: Independent variable (time or frequency)
params: Array of parameter values [c1, c2, alpha, beta, gamma, tau]
Returns:
Model predictions as JAX array
"""
# Extract parameters from array (in order they were added to ParameterSet)
c1 = params[0]
c2 = params[1]
alpha = params[2]
beta = params[3]
gamma = params[4]
tau = params[5]
# Use test_mode from last fit if available, otherwise default to oscillation
# Get test_mode value BEFORE entering JIT region to avoid tracing issues
if test_mode is None:
test_mode = getattr(self, "_test_mode", "oscillation")
# Convert to string representation for comparison (JAX-safe)
if hasattr(test_mode, "value"):
test_mode_str = test_mode.value
elif isinstance(test_mode, str):
test_mode_str = test_mode
else:
test_mode_str = "oscillation"
logger.debug(
"model_function evaluation",
test_mode=test_mode_str,
alpha=alpha,
beta=beta,
gamma=gamma,
input_shape=X.shape if hasattr(X, "shape") else len(X),
)
# Use string comparison (JAX-safe) instead of enum comparison
if test_mode_str == "relaxation":
logger.debug("Computing relaxation modulus for FZLL")
return self._predict_relaxation(X, c1, c2, alpha, beta, gamma, tau)
elif test_mode_str == "creep":
logger.debug("Computing creep compliance for FZLL")
return self._predict_creep(X, c1, c2, alpha, beta, gamma, tau)
else: # Default to oscillation
logger.debug("Computing complex modulus for oscillation mode")
stacked = self._predict_oscillation(X, c1, c2, alpha, beta, gamma, tau)
return stacked[..., 0] + 1j * stacked[..., 1]
# Convenience alias
FZLL = FractionalZenerLiquidLiquid
__all__ = ["FractionalZenerLiquidLiquid", "FZLL"]