"""SpringPot fractional viscoelastic element.
The SpringPot (also called fractional element or Scott-Blair element) is a
power-law viscoelastic element that interpolates between pure elastic (alpha=1)
and pure viscous (alpha=0) behavior.
Theory:
- Relaxation modulus: G(t) = c_alpha * t^(-alpha) / Gamma(1-alpha)
- Complex modulus: G*(omega) = c_alpha * (i*omega)^alpha
- Creep compliance: J(t) = (1/c_alpha) * t^alpha / Gamma(1+alpha)
- Uses Mittag-Leffler functions for accurate fractional calculus
References:
- Scott Blair, G. W. (1947). The role of psychophysics in rheology.
- Bagley, R. L., & Torvik, P. J. (1983). Fractional calculus model of viscoelastic behavior.
- Schiessel, H., et al. (1995). Generalized viscoelastic models.
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS
jax, jnp = safe_import_jax()
jax_gamma = jax.scipy.special.gamma
from rheojax.core.base import BaseModel
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode, detect_test_mode
from rheojax.logging import get_logger
# Module logger
logger = get_logger(__name__)
[docs]
@ModelRegistry.register(
"springpot",
protocols=[
Protocol.RELAXATION,
Protocol.CREEP,
Protocol.OSCILLATION,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class SpringPot(BaseModel):
"""SpringPot fractional viscoelastic element.
The SpringPot represents a power-law viscoelastic material that exhibits
fractional behavior between pure solid (alpha=1) and pure fluid (alpha=0).
Parameters:
c_alpha (float): Material constant in Pa·s^alpha, range [1e-3, 1e9], default 1e5
alpha (float): Power-law exponent (0=elastic/solid, 1=viscous/fluid), range [0.0, 1.0], default 0.5
Supported test modes:
- Relaxation: Stress relaxation under constant strain
- Creep: Strain development under constant stress
- Oscillation: Small amplitude oscillatory shear (SAOS)
- Rotation: NOT SUPPORTED (SpringPot is linear viscoelastic)
Example:
>>> from rheojax.models.springpot import SpringPot
>>> from rheojax.core.data import RheoData
>>> import jax.numpy as jnp
>>>
>>> # Create model
>>> model = SpringPot()
>>> model.parameters.set_value('c_alpha', 1e5)
>>> model.parameters.set_value('alpha', 0.5)
>>>
>>> # Predict relaxation
>>> t = jnp.linspace(0.01, 10, 100)
>>> data = RheoData(x=t, y=jnp.zeros_like(t), domain='time')
>>> G_t = model.predict(data)
"""
[docs]
def __init__(self):
"""Initialize SpringPot model with default parameters."""
super().__init__()
# Define parameters with physical bounds
self.parameters = ParameterSet()
self.parameters.add(
name="c_alpha",
value=1e5,
bounds=(1e-3, 1e9),
units="Pa·s^alpha",
description="Material constant",
)
self.parameters.add(
name="alpha",
value=0.5,
bounds=FRACTIONAL_ORDER_BOUNDS,
units="dimensionless",
description="Power-law exponent (0=elastic/solid, 1=viscous/fluid)",
)
self.fitted_ = False
self._test_mode = TestMode.RELAXATION # Store test mode for model_function
def _fit(self, X, y, **kwargs):
"""Fit SpringPot model to data.
Args:
X: RheoData object or independent variable array
y: Dependent variable array (if X is not RheoData)
**kwargs: Additional fitting options
Returns:
self for method chaining
"""
# Pre-validate: SpringPot does not support rotation
supplied_mode = kwargs.get("test_mode", TestMode.RELAXATION)
if isinstance(X, RheoData):
supplied_mode = X.test_mode
if supplied_mode == TestMode.ROTATION:
logger.error(
"Invalid test mode for SpringPot",
test_mode="ROTATION",
supported_modes=["RELAXATION", "CREEP", "OSCILLATION"],
)
raise ValueError(
"SpringPot model does not support steady shear (rotation) test mode"
)
def model_fn(x, params):
"""Model function for optimization (stateless)."""
c_alpha, alpha = params[0], params[1]
tm = self._test_mode
if tm == TestMode.RELAXATION:
return self._predict_relaxation(x, c_alpha, alpha)
elif tm == TestMode.CREEP:
return self._predict_creep(x, c_alpha, alpha)
elif tm == TestMode.OSCILLATION:
return self._predict_oscillation(x, c_alpha, alpha)
else:
raise ValueError(f"Unsupported test mode: {tm}")
return self._standard_nlsq_fit(
X, y, model_fn, default_test_mode=TestMode.RELAXATION, **kwargs
)
def _predict(self, X, **kwargs):
"""Predict response based on input data.
Args:
X: RheoData object or independent variable array
Returns:
Predicted values as JAX array
"""
# Handle RheoData input
if isinstance(X, RheoData):
rheo_data = X
test_mode = detect_test_mode(rheo_data)
x_data = jnp.array(rheo_data.x)
else:
x_data = jnp.array(X)
# Use test_mode from last fit if available, otherwise default to RELAXATION
test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)
# Validate test mode
if test_mode == TestMode.ROTATION:
raise ValueError(
"SpringPot model does not support steady shear (rotation) test mode"
)
# Get parameter values
c_alpha = self.parameters.get_value("c_alpha")
alpha = self.parameters.get_value("alpha")
# Dispatch to appropriate prediction method
if test_mode == TestMode.RELAXATION:
return self._predict_relaxation(x_data, c_alpha, alpha)
elif test_mode == TestMode.CREEP:
return self._predict_creep(x_data, c_alpha, alpha)
elif test_mode == TestMode.OSCILLATION:
return self._predict_oscillation(x_data, c_alpha, alpha)
else:
supported = [TestMode.RELAXATION, TestMode.CREEP, TestMode.OSCILLATION]
raise ValueError(
f"Unsupported test mode: {test_mode}. "
f"SpringPot supports: {', '.join(supported)}"
)
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
"""Model function for Bayesian inference.
This method is required by BayesianMixin for NumPyro NUTS sampling.
It computes predictions given input X and a parameter array.
Args:
X: Independent variable (time or frequency)
params: Array of parameter values [c_alpha, alpha]
Returns:
Model predictions as JAX array
"""
# Extract parameters from array (in order they were added to ParameterSet)
c_alpha = params[0]
alpha = params[1]
# Use stored test mode from last fit, or default to RELAXATION
if test_mode is None:
test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)
# Dispatch to appropriate prediction method
if test_mode == TestMode.RELAXATION:
return self._predict_relaxation(X, c_alpha, alpha)
elif test_mode == TestMode.CREEP:
return self._predict_creep(X, c_alpha, alpha)
elif test_mode == TestMode.OSCILLATION:
return self._predict_oscillation(X, c_alpha, alpha)
else:
raise ValueError(f"Unsupported test mode: {test_mode}")
@staticmethod
@jax.jit
def _predict_relaxation(
t: jnp.ndarray, c_alpha: float, alpha: float
) -> jnp.ndarray:
"""Predict relaxation modulus G(t).
Theory: G(t) = c_alpha * t^(-alpha) / Gamma(1-alpha)
For alpha=0 (pure elastic): G(t) = c_alpha (constant)
For alpha=1 (pure viscous): G(t) → 0 for t > 0 (Gamma(0) = ∞)
Args:
t: Time array (s)
c_alpha: Material constant (Pa·s^alpha)
alpha: Power-law exponent (0=elastic/solid, 1=viscous/fluid)
Returns:
Relaxation modulus G(t) in Pa
"""
# Handle special cases
# alpha -> 0: pure elastic (G -> c_alpha, constant)
# alpha -> 1: pure viscous (G -> 0 for t > 0)
# General formula: G(t) = c_alpha * t^(-alpha) / Gamma(1-alpha)
# P2-FRAC-001: Guard t=0 — power(0, -alpha) = +inf when alpha>0, which
# propagates NaN through the NLSQ Jacobian. Use 1e-30 floor so that
# t=0 maps to a very large (but finite) modulus, consistent with the
# singular G(0) = +∞ predicted by the SpringPot model.
t_safe = jnp.maximum(t, 1e-30)
gamma_factor = jax_gamma(1.0 - alpha)
return c_alpha * jnp.power(t_safe, -alpha) / gamma_factor
@staticmethod
@jax.jit
def _predict_creep(t: jnp.ndarray, c_alpha: float, alpha: float) -> jnp.ndarray:
"""Predict creep compliance J(t).
Theory: J(t) = (1/c_alpha) * t^alpha / Gamma(1+alpha)
For alpha=0 (pure elastic): J(t) = 1/c_alpha (constant)
For alpha=1 (pure viscous): J(t) = t/c_alpha (linear in t)
Args:
t: Time array (s)
c_alpha: Material constant (Pa·s^alpha)
alpha: Power-law exponent (0=elastic/solid, 1=viscous/fluid)
Returns:
Creep compliance J(t) in 1/Pa
"""
# General formula: J(t) = (1/c_alpha) * t^alpha / Gamma(1+alpha)
gamma_factor = jax_gamma(1.0 + alpha)
return (1.0 / c_alpha) * jnp.power(t, alpha) / gamma_factor
@staticmethod
@jax.jit
def _predict_oscillation(
omega: jnp.ndarray, c_alpha: float, alpha: float
) -> jnp.ndarray:
"""Predict complex modulus G*(omega).
Theory: G*(omega) = c_alpha * (i*omega)^alpha
This can be written as:
G*(omega) = c_alpha * omega^alpha * (cos(pi*alpha/2) + i*sin(pi*alpha/2))
Therefore:
G'(omega) = c_alpha * omega^alpha * cos(pi*alpha/2)
G''(omega) = c_alpha * omega^alpha * sin(pi*alpha/2)
Args:
omega: Angular frequency array (rad/s)
c_alpha: Material constant (Pa·s^alpha)
alpha: Power-law exponent (0=elastic/solid, 1=viscous/fluid)
Returns:
Complex modulus G*(omega) in Pa
"""
# Compute (i*omega)^alpha = omega^alpha * exp(i*pi*alpha/2)
omega_alpha = jnp.power(omega, alpha)
phase = jnp.pi * alpha / 2.0
# Storage modulus G'
G_prime = c_alpha * omega_alpha * jnp.cos(phase)
# Loss modulus G''
G_double_prime = c_alpha * omega_alpha * jnp.sin(phase)
# Complex modulus
return G_prime + 1j * G_double_prime
[docs]
def get_characteristic_time(self, reference_value: float = 1.0) -> float:
"""Get characteristic time scale for the material.
For SpringPot, there's no single characteristic time, but we can define
a reference time where G(t) = reference_value.
From G(t) = c_alpha * t^(-alpha) / Gamma(1-alpha) = reference_value:
t = (c_alpha / (reference_value * Gamma(1-alpha)))^(1/alpha)
Args:
reference_value: Reference modulus value (Pa), default 1.0
Returns:
Characteristic time in seconds
"""
c_alpha = self.parameters.get_value("c_alpha")
alpha = self.parameters.get_value("alpha")
# Avoid division by zero for alpha=0
if alpha < 1e-10:
return float("inf")
gamma_factor = float(jax_gamma(1.0 - alpha))
return (c_alpha / (reference_value * gamma_factor)) ** (1.0 / alpha)
[docs]
def __repr__(self) -> str:
"""String representation of SpringPot model."""
c_alpha = self.parameters.get_value("c_alpha")
alpha = self.parameters.get_value("alpha")
return f"SpringPot(c_alpha={c_alpha:.2e} Pa·s^{alpha:.2f}, alpha={alpha:.2f})"
__all__ = ["SpringPot"]