"""Fractional Maxwell Model (FMM).
This is the most general fractional Maxwell model with two SpringPots in series,
each with independent fractional orders. It provides maximum flexibility in
describing viscoelastic materials with fractional dynamics.
Mathematical Description:
Relaxation Modulus: G(t) = (c_1/τ^β) t^(β-α) E_{β, β-α+1}(-(t/τ)^β)
Complex Modulus: G*(ω) = c_1 (iω)^α / (1 + (iωτ)^β)
Creep Compliance: J(t) = (1/c_1) [t^α/Γ(1+α) + τ^β t^(α-β)/Γ(1+α-β)]
where α and β are independent fractional orders, τ is the relaxation time, and E_{a,b} denotes the two-parameter Mittag-Leffler function.
Parameters:
c1 (float): Material constant (Pa·s^α), bounds [1e-3, 1e9]
alpha (float): First fractional order, bounds [0.0, 1.0]
beta (float): Second fractional order, bounds [0.0, 1.0]
tau (float): Relaxation time (s), bounds [1e-6, 1e6]
Test Modes: Relaxation, Creep, Oscillation
References:
- Schiessel, H., & Blumen, A. (1993). Hierarchical analogues to fractional relaxation
equations. Journal of Physics A: Mathematical and General, 26(19), 5057.
- Heymans, N., & Bauwens, J. C. (1994). Fractal rheological models and fractional
differential equations for viscoelastic behavior. Rheologica Acta, 33(3), 210-219.
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger, log_fit
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS
jax, jnp = safe_import_jax()
import numpy as np
from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.utils.mittag_leffler import mittag_leffler_e2
# Module logger used by FractionalMaxwellModel to report fit progress,
# smart-initialisation outcomes and optimizer failures.
logger = get_logger(__name__)
[docs]
@ModelRegistry.register(
    "fractional_maxwell_model",
    protocols=[
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FractionalMaxwellModel(BaseModel):
    """Fractional Maxwell Model: Two SpringPots in series with independent orders.

    This is the most general fractional Maxwell model, allowing for complex
    viscoelastic behavior with two independent fractional orders.

    The model is defined by its complex modulus
    ``G*(ω) = c1 (iω)^α / (1 + (iωτ)^β)``. The relaxation modulus and creep
    compliance computed by the ``_predict_*_jax`` kernels follow from its
    Laplace transform.

    Attributes:
        parameters: ParameterSet with c1, alpha, beta, tau. The registration
            order matters: ``model_function`` unpacks a flat parameter array
            in exactly this order.
        fitted_: Set to True after a successful ``_fit``.

    Examples:
        >>> from rheojax.models import FractionalMaxwellModel
        >>> from rheojax.core.data import RheoData
        >>> import numpy as np
        >>>
        >>> # Create model with parameters
        >>> model = FractionalMaxwellModel()
        >>> model.parameters.set_value('c1', 1e5)
        >>> model.parameters.set_value('alpha', 0.5)
        >>> model.parameters.set_value('beta', 0.7)
        >>> model.parameters.set_value('tau', 1.0)
        >>>
        >>> # Predict relaxation modulus
        >>> t = np.logspace(-3, 3, 50)
        >>> data = RheoData(x=t, y=np.zeros_like(t), domain='time')
        >>> data.metadata['test_mode'] = 'relaxation'
        >>> G_t = model.predict(data)
    """

    def __init__(self):
        """Initialize Fractional Maxwell Model with default parameters.

        Parameters are registered in the order c1, alpha, beta, tau; the
        flat parameter vectors used by ``_fit`` and ``model_function`` rely
        on this order.
        """
        super().__init__()
        self.parameters = ParameterSet()
        self.parameters.add(
            name="c1",
            value=1e5,
            bounds=(1e-3, 1e9),
            units="Pa·s^α",
            description="Material constant",
        )
        self.parameters.add(
            name="alpha",
            value=0.5,
            bounds=FRACTIONAL_ORDER_BOUNDS,
            units="dimensionless",
            description="First fractional order",
        )
        self.parameters.add(
            name="beta",
            value=0.5,
            bounds=FRACTIONAL_ORDER_BOUNDS,
            units="dimensionless",
            description="Second fractional order",
        )
        self.parameters.add(
            name="tau",
            value=1.0,
            bounds=(1e-6, 1e6),
            units="s",
            description="Relaxation time",
        )
        self.fitted_ = False

    @staticmethod
    @jax.jit
    def _predict_relaxation_jax(
        t: jnp.ndarray, c1: float, alpha: float, beta: float, tau: float
    ) -> jnp.ndarray:
        """Predict relaxation modulus G(t) using JAX.

        Derived from the inverse Laplace transform of G̃(s) = G*(s)/s:
            G(t) = (c₁/τ^β) · t^(β-α) · E_{β, β-α+1}(-(t/τ)^β)
        where E_{β, β-α+1} is the two-parameter Mittag-Leffler function.

        Args:
            t: Time array
            c1: Material constant (Pa·s^α)
            alpha: First fractional order
            beta: Second fractional order
            tau: Relaxation time (s)

        Returns:
            Relaxation modulus array
        """
        epsilon = 1e-12
        # Keep orders strictly inside (0, 1) and times strictly positive so the
        # fractional powers below stay finite.
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        beta_safe = jnp.clip(beta, epsilon, 1.0 - epsilon)
        t_safe = jnp.maximum(t, epsilon)
        tau_safe = jnp.maximum(tau, epsilon)
        # Mittag-Leffler argument
        z = -((t_safe / tau_safe) ** beta_safe)
        # E_{β, β-α+1}(z): the helper's own `alpha`/`beta` keywords are the
        # Mittag-Leffler indices, which here are β and β-α+1 respectively.
        ml_beta_param = beta_safe - alpha_safe + 1.0
        ml_value = mittag_leffler_e2(z, alpha=beta_safe, beta=ml_beta_param)
        # G(t) = (c1 / tau^beta) * t^{beta - alpha} * E_{beta, beta-alpha+1}(z)
        prefactor = c1 / (tau_safe**beta_safe)
        G_t = prefactor * (t_safe ** (beta_safe - alpha_safe)) * ml_value
        return G_t

    @staticmethod
    @jax.jit
    def _predict_creep_jax(
        t: jnp.ndarray, c1: float, alpha: float, beta: float, tau: float
    ) -> jnp.ndarray:
        """Predict creep compliance J(t) using JAX.

        Derived from J̃(s) = 1/(s·G*(s)) = (1 + (sτ)^β) / (c₁ s^{α+1}):
            J(t) = (1/c₁) [t^α/Γ(1+α) + τ^β · t^{α-β}/Γ(1+α-β)]
        Two power-law terms: the first is the springpot creep, the second
        captures the contribution from the second element.

        Args:
            t: Time array
            c1: Material constant (Pa·s^α)
            alpha: First fractional order
            beta: Second fractional order
            tau: Relaxation time (s)

        Returns:
            Creep compliance array, clipped to [1e-12, 1e10].
        """
        epsilon = 1e-12
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        beta_safe = jnp.clip(beta, epsilon, 1.0 - epsilon)
        t_safe = jnp.maximum(t, epsilon)
        tau_safe = jnp.maximum(tau, epsilon)
        c1_safe = jnp.maximum(c1, epsilon)
        # 1/Γ(x) via exp(-lgamma(x)) to stay numerically stable; 1+α-β > 0
        # always holds because β < 1.
        inv_gamma_1 = jnp.exp(-jax.lax.lgamma(1.0 + alpha_safe))
        inv_gamma_2 = jnp.exp(-jax.lax.lgamma(1.0 + alpha_safe - beta_safe))
        term1 = (t_safe**alpha_safe) * inv_gamma_1
        term2 = (
            (tau_safe**beta_safe) * (t_safe ** (alpha_safe - beta_safe)) * inv_gamma_2
        )
        J_t = (term1 + term2) / c1_safe
        # Clip to a bounded range to guard the optimizer against overflow.
        # NOTE(review): values above 1e10 are saturated here (zero gradient).
        return jnp.clip(J_t, epsilon, 1e10)

    def _predict_oscillation_jax(
        self, omega: jnp.ndarray, c1: float, alpha: float, beta: float, tau: float
    ) -> jnp.ndarray:
        """Predict complex modulus G*(ω) using JAX.

        G*(ω) = c_1 (iω)^α / (1 + (iωτ)^β)

        Args:
            omega: Angular frequency array
            c1: Material constant
            alpha: First fractional order
            beta: Second fractional order
            tau: Relaxation time

        Returns:
            Complex modulus array
        """
        epsilon = 1e-12
        # Clip alpha and beta to a safe open interval (works with JAX tracers).
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        beta_safe = jnp.clip(beta, epsilon, 1.0 - epsilon)
        omega_safe = jnp.maximum(omega, epsilon)
        tau_safe = jnp.maximum(tau, epsilon)
        # (iω)^α = |ω|^α * exp(i α π/2)
        i_omega_alpha = (omega_safe**alpha_safe) * jnp.exp(
            1j * alpha_safe * jnp.pi / 2.0
        )
        # (iωτ)^β = |ωτ|^β * exp(i β π/2)
        omega_tau = omega_safe * tau_safe
        i_omega_tau_beta = (omega_tau**beta_safe) * jnp.exp(
            1j * beta_safe * jnp.pi / 2.0
        )
        G_star = c1 * i_omega_alpha / (1.0 + i_omega_tau_beta)
        return G_star

    @staticmethod
    def _as_mode_string(test_mode):
        """Return the plain value of a test mode, unwrapping Enum ``.value``.

        Strings (and None) are returned unchanged.
        """
        return test_mode.value if hasattr(test_mode, "value") else test_mode

    def _current_param_values(self) -> tuple[float, float, float, float]:
        """Return the current (c1, alpha, beta, tau) values from the ParameterSet."""
        c1 = self.parameters.get_value("c1")
        alpha = self.parameters.get_value("alpha")
        beta = self.parameters.get_value("beta")
        tau = self.parameters.get_value("tau")
        assert (
            c1 is not None
            and alpha is not None
            and beta is not None
            and tau is not None
        )
        return c1, alpha, beta, tau

    def _apply_smart_initialization(
        self, x_np: np.ndarray, y_np: np.ndarray, mode: str
    ) -> None:
        """Best-effort, data-driven initial guesses for the parameters.

        Any failure is logged at debug level and swallowed, so the fit then
        proceeds from the current parameter values.

        Args:
            x_np: Independent variable as a NumPy array.
            y_np: Dependent variable as a NumPy array (may be complex).
            mode: Normalised test mode string.
        """
        if mode == "oscillation":
            # Smart initialization for oscillation mode (Issue #9)
            try:
                from rheojax.utils.initialization import (
                    initialize_fractional_maxwell_model,
                )

                success = initialize_fractional_maxwell_model(
                    np.array(x_np), np.array(y_np), self.parameters
                )
                if success:
                    logger.debug(
                        "Smart initialization applied from frequency-domain features",
                        initialized_params=self.parameters.to_dict(),
                    )
            except Exception as e:
                logger.debug(
                    "Smart initialization failed, using defaults",
                    error=str(e),
                )
            return

        if mode not in ("creep", "relaxation"):
            return

        # Smart initialization for creep/relaxation mode
        try:
            y_real = np.abs(y_np) if np.iscomplexobj(y_np) else y_np
            # Only strictly positive, finite points support log-scale estimates.
            valid = (x_np > 0) & (y_real > 0) & np.isfinite(y_real)
            if np.sum(valid) < 2:
                return
            x_valid = x_np[valid]
            y_valid = y_real[valid]
            # tau: geometric mean of time range (characteristic time)
            tau_init = np.sqrt(x_valid.min() * x_valid.max())
            y_geo_mid = np.sqrt(y_valid.min() * y_valid.max())
            # Creep: J(t) ~ 1/c1, so c1 ~ 1/J_mid. Relaxation: G(t) ~ c1.
            c1_init = 1.0 / y_geo_mid if mode == "creep" else y_geo_mid
            # Clip to parameter bounds
            c1_param = self.parameters.get("c1")
            tau_param = self.parameters.get("tau")
            assert c1_param is not None and c1_param.bounds is not None
            assert tau_param is not None and tau_param.bounds is not None
            c1_bounds = c1_param.bounds
            tau_bounds = tau_param.bounds
            c1_init = np.clip(c1_init, c1_bounds[0], c1_bounds[1])
            tau_init = np.clip(tau_init, tau_bounds[0], tau_bounds[1])
            self.parameters.set_value("c1", float(c1_init))
            self.parameters.set_value("tau", float(tau_init))
            logger.debug(
                f"Smart initialization applied for {mode} mode",
                c1=c1_init,
                tau=tau_init,
            )
        except Exception as e:
            logger.debug(
                f"Smart initialization failed for {mode} mode, using defaults",
                error=str(e),
            )

    def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> FractionalMaxwellModel:
        """Fit model parameters to data.

        Args:
            X: Independent variable (time or frequency), or a RheoData object.
                For RheoData, its ``x``, ``y`` and ``test_mode`` are used and
                ``y`` / ``kwargs['test_mode']`` are ignored.
            y: Dependent variable (modulus or compliance)
            **kwargs: Additional fitting options. Keys read here:
                ``test_mode`` (None or missing -> 'relaxation'),
                ``use_jax`` (default True), ``method`` (default 'auto'),
                ``max_iter`` (default 1000).

        Returns:
            self for method chaining

        Raises:
            ValueError: If the test mode is not relaxation, creep or oscillation.
            RuntimeError: If the optimizer reports failure.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        # Extract data once. The NumPy copies feed the heuristic
        # initialisation; the JAX copies feed the objective.
        if isinstance(X, RheoData):
            x_np = np.asarray(X.x)
            y_np = np.asarray(X.y)
            test_mode = X.test_mode
        else:
            x_np = np.asarray(X)
            y_np = np.asarray(y)
            test_mode = kwargs.get("test_mode")
            if test_mode is None:
                test_mode = "relaxation"
        x_data = jnp.asarray(x_np)
        y_data = jnp.asarray(y_np)

        # R13-FMM-001: Store test_mode so Bayesian inference picks up the correct
        # protocol (otherwise _resolve_test_mode defaults to 'relaxation').
        self._test_mode = test_mode
        # Dispatch on the plain string so Enum-valued modes behave like strings.
        mode = self._as_mode_string(test_mode)

        has_points = x_data.ndim >= 1
        n_points = int(x_data.shape[0]) if has_points else 1
        data_shape = (n_points,) if has_points else None

        with log_fit(
            logger,
            model="FractionalMaxwellModel",
            data_shape=data_shape,
            test_mode=str(mode),
        ) as ctx:
            logger.debug(
                "Starting FMM fit",
                n_points=n_points,
                test_mode=str(mode),
                initial_params=self.parameters.to_dict(),
            )
            # Fail fast rather than deep inside the optimizer loop.
            if mode not in ("relaxation", "creep", "oscillation"):
                raise ValueError(f"Unsupported test mode: {test_mode}")

            self._apply_smart_initialization(x_np, y_np, mode)

            def model_fn(x, params):
                """Model function for optimization (stateless)."""
                c1, alpha, beta, tau = params[0], params[1], params[2], params[3]
                if mode == "relaxation":
                    return self._predict_relaxation_jax(x, c1, alpha, beta, tau)
                if mode == "creep":
                    return self._predict_creep_jax(x, c1, alpha, beta, tau)
                return self._predict_oscillation_jax(x, c1, alpha, beta, tau)

            logger.debug("Creating least squares objective", normalize=True)
            objective = create_least_squares_objective(
                model_fn, x_data, y_data, normalize=True
            )
            # Optimize using NLSQ (JAX enabled by default)
            logger.debug(
                "Starting NLSQ optimization",
                method=kwargs.get("method", "auto"),
                max_iter=kwargs.get("max_iter", 1000),
            )
            try:
                result = nlsq_optimize(
                    objective,
                    self.parameters,
                    use_jax=kwargs.get("use_jax", True),
                    method=kwargs.get("method", "auto"),
                    max_iter=kwargs.get("max_iter", 1000),
                )
            except Exception as e:
                logger.error(
                    "NLSQ optimization raised exception",
                    error_type=type(e).__name__,
                    error=str(e),
                    exc_info=True,
                )
                raise
            if not result.success:
                logger.error(
                    "Optimization failed",
                    message=result.message,
                    final_params=self.parameters.to_dict(),
                )
                raise RuntimeError(
                    f"Optimization failed: {result.message}. "
                    f"Try adjusting initial values, bounds, or max_iter."
                )
            self.fitted_ = True
            ctx["final_params"] = self.parameters.to_dict()
            ctx["success"] = True
            logger.debug(
                "FMM fit completed successfully",
                final_params=self.parameters.to_dict(),
            )
        return self

    def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Internal predict implementation.

        Args:
            X: RheoData object or array of x-values
            **kwargs: ``test_mode`` selects the response. If absent or None,
                the mode stored by the last fit (``self._test_mode``) is used.
                Unknown or missing modes fall back to relaxation.

        Returns:
            Predicted values. This is a RheoData when X is RheoData, otherwise
            a NumPy array.
        """
        if isinstance(X, RheoData):
            return self.predict_rheodata(X)

        from rheojax.core.test_modes import TestMode

        x = jnp.asarray(X)
        c1, alpha, beta, tau = self._current_param_values()
        _kw_mode = kwargs.get("test_mode")
        test_mode = (
            _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
        )
        if test_mode in ("oscillation", TestMode.OSCILLATION):
            result = self._predict_oscillation_jax(x, c1, alpha, beta, tau)
        elif test_mode in ("creep", TestMode.CREEP):
            result = self._predict_creep_jax(x, c1, alpha, beta, tau)
        else:
            result = self._predict_relaxation_jax(x, c1, alpha, beta, tau)
        return np.array(result)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        CRITICAL: test_mode is now passed as parameter (NOT read from self._test_mode)
        to ensure correct posteriors in Bayesian inference (v0.4.0 fix).

        Args:
            X: Independent variable (time or frequency)
            params: Array of parameter values [c1, alpha, beta, tau]
            test_mode: Explicit test mode for predictions. If None, falls back
                to ``self._test_mode`` and then 'relaxation'.

        Returns:
            Model predictions as JAX array. Unknown modes use relaxation.
        """
        # Parameters are unpacked in the order they were added to ParameterSet.
        c1, alpha, beta, tau = params[0], params[1], params[2], params[3]
        # Use explicit test_mode parameter (closure-captured in fit_bayesian).
        # Fall back to self._test_mode only for backward compatibility.
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", "relaxation")
        mode = self._as_mode_string(test_mode)
        if mode == "creep":
            return self._predict_creep_jax(X, c1, alpha, beta, tau)
        if mode == "oscillation":
            return self._predict_oscillation_jax(X, c1, alpha, beta, tau)
        # 'relaxation' and any unknown mode
        return self._predict_relaxation_jax(X, c1, alpha, beta, tau)

    def predict_rheodata(
        self, rheo_data: RheoData, test_mode: str | None = None
    ) -> RheoData:
        """Predict response for RheoData.

        Args:
            rheo_data: Input RheoData with x values
            test_mode: Test mode ('relaxation', 'creep', 'oscillation').
                If None, ``rheo_data.metadata['test_mode']`` is used when
                present, otherwise ``rheo_data.test_mode``. Enum-valued
                modes are unwrapped via ``.value``.

        Returns:
            RheoData with predicted y values. Units, domain and a copy of the
            metadata come from the input.

        Raises:
            ValueError: If the resolved test mode is not supported.
        """
        if test_mode is None:
            if "test_mode" in rheo_data.metadata:
                test_mode = rheo_data.metadata["test_mode"]
            else:
                test_mode = rheo_data.test_mode
        mode = self._as_mode_string(test_mode)
        c1, alpha, beta, tau = self._current_param_values()
        x = jnp.asarray(rheo_data.x)
        if mode == "relaxation":
            y_pred = self._predict_relaxation_jax(x, c1, alpha, beta, tau)
        elif mode == "creep":
            y_pred = self._predict_creep_jax(x, c1, alpha, beta, tau)
        elif mode == "oscillation":
            y_pred = self._predict_oscillation_jax(x, c1, alpha, beta, tau)
        else:
            raise ValueError(
                f"Unknown test mode: {test_mode}. "
                f"Must be 'relaxation', 'creep', or 'oscillation'"
            )
        return RheoData(
            x=np.array(x),
            y=np.array(y_pred),
            x_units=rheo_data.x_units,
            y_units=rheo_data.y_units,
            domain=rheo_data.domain,
            metadata=rheo_data.metadata.copy(),
        )

    def predict(self, X, test_mode: str | None = None, **kwargs):  # type: ignore[override]
        """Predict response.

        R13-FMM-002: Delegates to BaseModel.predict() for plain-array inputs so
        that deformation_mode (E*->G* conversion) and test_mode restoration are
        handled correctly. Only RheoData inputs bypass super() because they
        return a RheoData wrapper object.

        Args:
            X: RheoData object or array of x-values
            test_mode: Test mode for prediction ('relaxation', 'creep',
                'oscillation'). For RheoData input, None means auto-detect from
                the data. For arrays, None lets ``_predict`` fall back to the
                mode from the last fit and then 'relaxation'. This assumes
                BaseModel.predict forwards ``test_mode``; TODO confirm.
            **kwargs: Additional arguments passed to BaseModel.predict()

        Returns:
            Predicted values (RheoData if input is RheoData, else array)
        """
        if isinstance(X, RheoData):
            return self.predict_rheodata(X, test_mode=test_mode)
        # Delegate to BaseModel.predict() which handles deformation_mode
        # conversion (DMTA E*->G*->E*) and test_mode restoration in its
        # finally block.
        return super().predict(X, test_mode=test_mode, **kwargs)