"""Maxwell viscoelastic model.

The Maxwell model consists of a spring (G0) and dashpot (eta) in series,
representing the simplest linear viscoelastic behavior with stress relaxation.

Theory:
    - Relaxation modulus: G(t) = G0 * exp(-t/tau) where tau = eta/G0
    - Complex modulus: G*(omega) = G0*(omega*tau)^2/(1+(omega*tau)^2) + i*G0*omega*tau/(1+(omega*tau)^2)
    - Creep compliance: J(t) = 1/G0 + t/eta
    - Steady shear viscosity: eta(gamma_dot) = eta (constant)

References:
    - Ferry, J. D. (1980). Viscoelastic properties of polymers.
    - Tschoegl, N. W. (1989). The phenomenological theory of linear viscoelastic behavior.
"""
from __future__ import annotations
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
import numpy as np
from rheojax.core.base import BaseModel
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode, detect_test_mode
from rheojax.logging import get_logger, log_fit
# Module logger
logger = get_logger(__name__)
@ModelRegistry.register(
    "maxwell",
    protocols=[
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
        Protocol.FLOW_CURVE,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class Maxwell(BaseModel):
    """Maxwell viscoelastic model (spring and dashpot in series).

    The Maxwell model is the simplest viscoelastic model, consisting of a linear
    spring (elastic modulus G0) in series with a linear dashpot (viscosity eta).

    Parameters:
        G0 (float): Elastic modulus in Pa, range [1e-3, 1e9], default 1e5
        eta (float): Viscosity in Pa·s, range [1e-6, 1e12], default 1e3

    Supported test modes:
        - Relaxation: Stress relaxation under constant strain
        - Creep: Strain development under constant stress
        - Oscillation: Small amplitude oscillatory shear (SAOS)
        - Rotation / flow curve: Steady shear flow

    Example:
        >>> from rheojax.models.maxwell import Maxwell
        >>> from rheojax.core.data import RheoData
        >>> import jax.numpy as jnp
        >>>
        >>> # Create model
        >>> model = Maxwell()
        >>> model.parameters.set_value('G0', 1e5)
        >>> model.parameters.set_value('eta', 1e3)
        >>>
        >>> # Predict relaxation
        >>> t = jnp.linspace(0.01, 10, 100)
        >>> data = RheoData(x=t, y=jnp.zeros_like(t), domain='time')
        >>> G_t = model.predict(data)
    """

    def __init__(self):
        """Initialize Maxwell model with default parameters."""
        super().__init__()
        # Define parameters with physical bounds
        self.parameters = ParameterSet()
        self.parameters.add(
            name="G0",
            value=1e5,
            bounds=(1e-3, 1e9),
            units="Pa",
            description="Elastic modulus",
        )
        self.parameters.add(
            name="eta",
            value=1e3,
            bounds=(1e-6, 1e12),
            units="Pa·s",
            description="Viscosity",
        )
        self.fitted_ = False
        # Baseline subtracted from relaxation data during the last fit; added
        # back onto relaxation predictions.
        self._relaxation_offset = 0.0
        self._test_mode = TestMode.RELAXATION  # Store test mode for model_function

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_numeric_array(values):
        """Convert ``values`` to a NumPy array of complex128 or float."""
        arr = np.asarray(values)
        if np.iscomplexobj(arr):
            return arr.astype(np.complex128)
        return arr.astype(float)

    @staticmethod
    def _resolve_test_mode(
        mode: TestMode | str | None, fallback: TestMode | None = None
    ):
        """Turn a test-mode name into a ``TestMode`` member.

        Args:
            mode: A ``TestMode`` member, a member name in any letter case, or
                any other object (returned unchanged).
            fallback: Member to use when ``mode`` is a string that names no
                member. When ``None`` the unrecognised string is returned
                unchanged so the caller's dispatch rejects it.

        Returns:
            The resolved ``TestMode`` member, ``fallback``, or ``mode`` itself.
        """
        if isinstance(mode, TestMode) or not isinstance(mode, str):
            return mode
        try:
            return TestMode[mode.upper()]
        except KeyError:
            if fallback is None:
                return mode
            # The fallback is kept for backward compatibility, but it is no
            # longer silent: a typo in the mode name is now visible in logs.
            logger.warning(
                "Unrecognized test mode; using fallback",
                requested=mode,
                fallback=fallback.name,
            )
            return fallback

    @staticmethod
    def _tail_median(x_np, y_np, tail):
        """Median of the ``tail`` latest-time samples of ``y_np``.

        When ``x_np`` is one-dimensional and matches ``y_np`` in shape, the
        samples are ordered by ``x_np`` first so that unsorted input still uses
        the long-time end of the curve. Otherwise the trailing array entries
        are used as they are.
        """
        if x_np.ndim == 1 and x_np.shape == y_np.shape:
            # Stable sort: already-ascending input keeps its exact order.
            y_np = y_np[np.argsort(x_np, kind="stable")]
        return float(np.median(y_np[-tail:]))

    def _prepare_fit_inputs(self, X, y, supplied_mode):
        """Extract NumPy x/y arrays and the test mode for a fit.

        Args:
            X: RheoData object or independent variable array
            y: Dependent variable array (used only if X is not RheoData)
            supplied_mode: Value of the ``test_mode`` fit option, or None

        Returns:
            Tuple ``(x_np, y_np, test_mode)``

        Raises:
            ValueError: If ``y`` is missing for array input or the data is empty
        """
        if isinstance(X, RheoData):
            x_np = np.asarray(X.x, dtype=float)
            y_np = self._as_numeric_array(X.y)
            test_mode = X.test_mode
        else:
            if y is None:
                # np.asarray(None).astype(float) would yield NaN and the fit
                # would run against a meaningless target.
                raise ValueError("y must be provided when X is not a RheoData object")
            x_np = np.asarray(X, dtype=float)
            y_np = self._as_numeric_array(y)
            if supplied_mode is not None:
                test_mode = supplied_mode
            elif np.iscomplexobj(y_np):
                # Complex data without an explicit mode is treated as G*(omega).
                test_mode = TestMode.OSCILLATION
            else:
                test_mode = TestMode.RELAXATION
        if x_np.size == 0:
            raise ValueError("Cannot fit Maxwell model to empty data")
        test_mode = self._resolve_test_mode(test_mode, fallback=TestMode.RELAXATION)
        return x_np, y_np, test_mode

    def _evaluate(self, test_mode, x, G0, eta):
        """Evaluate the model response for ``test_mode`` (no baseline offset).

        The dispatch happens in Python, so ``test_mode`` must be a concrete
        value and never a JAX tracer.

        Raises:
            ValueError: If ``test_mode`` is not a supported mode
        """
        if test_mode == TestMode.RELAXATION:
            return self._predict_relaxation(x, G0, eta)
        if test_mode == TestMode.CREEP:
            return self._predict_creep(x, G0, eta)
        if test_mode == TestMode.OSCILLATION:
            return self._predict_oscillation(x, G0, eta)
        if test_mode in (TestMode.ROTATION, TestMode.FLOW_CURVE):
            return self._predict_rotation(x, G0, eta)
        raise ValueError(f"Unsupported test mode: {test_mode}")

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------
    def _fit(self, X, y, **kwargs):
        """Fit Maxwell model to data.

        Args:
            X: RheoData object or independent variable array
            y: Dependent variable array (if X is not RheoData)
            **kwargs: Additional fitting options. Read here: ``test_mode``,
                ``use_log_residuals``, ``use_jax``, ``method``, ``max_iter``.

        Returns:
            self for method chaining

        Raises:
            ValueError: If the data is empty, ``y`` is missing for array
                input, or the test mode is unsupported
            RuntimeError: If the optimizer fails with a non-finite or very
                large final cost
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        x_np, y_np, test_mode = self._prepare_fit_inputs(
            X, y, kwargs.get("test_mode")
        )
        # Determine test_mode string for logging
        test_mode_str = test_mode.name if hasattr(test_mode, "name") else str(test_mode)

        with log_fit(
            logger,
            self.__class__.__name__,
            data_shape=x_np.shape,
            test_mode=test_mode_str,
        ) as ctx:
            logger.debug(
                "Processing input data",
                x_range=(float(x_np.min()), float(x_np.max())),
                y_range=(float(np.real(y_np).min()), float(np.real(y_np).max())),
                is_complex=np.iscomplexobj(y_np),
            )
            # Store test mode for model_function
            self._test_mode = test_mode
            self._relaxation_offset = 0.0
            if test_mode == TestMode.RELAXATION:
                # Remove the long-time plateau so the remaining signal is a
                # pure decay; predictions add the offset back.
                tail = max(3, y_np.size // 6)
                offset = self._tail_median(x_np, y_np, tail)
                y_np = y_np - offset
                self._relaxation_offset = offset
                # Store in _last_fit_kwargs for Bayesian pipeline forwarding
                # (_last_fit_kwargs is not created in this class; presumably
                # provided by BaseModel — confirm).
                self._last_fit_kwargs["_relaxation_offset"] = offset
                logger.debug(
                    "Applied relaxation offset correction",
                    offset=offset,
                    tail_points=tail,
                )
            x_data = jnp.array(x_np)
            y_data = jnp.array(y_np)

            # Provide simple heuristics for relaxation data to improve deterministic fits
            if test_mode == TestMode.RELAXATION:
                init_success = self._initialize_relaxation_parameters(x_data, y_data)
                logger.debug(
                    "Relaxation parameter initialization",
                    success=init_success,
                    G0_init=self.parameters.get_value("G0"),
                    eta_init=self.parameters.get_value("eta"),
                )

            # Create objective function with stateless predictions
            def model_fn(x, params):
                """Model function for optimization (stateless)."""
                return self._evaluate(test_mode, x, params[0], params[1])

            # Honor use_log_residuals from kwargs (set by BaseModel auto-detect
            # or passed explicitly) so wide-range relaxation/master-curve data
            # is fit with equal weight per decade.
            use_log_residuals = kwargs.get("use_log_residuals", False)
            objective = create_least_squares_objective(
                model_fn,
                x_data,
                y_data,
                normalize=True,
                use_log_residuals=use_log_residuals,
            )
            logger.debug(
                "Starting NLSQ optimization",
                method=kwargs.get("method", "auto"),
                max_iter=kwargs.get("max_iter", 1000),
                use_jax=kwargs.get("use_jax", True),
            )
            # Optimize
            try:
                result = nlsq_optimize(
                    objective,
                    self.parameters,
                    use_jax=kwargs.get("use_jax", True),
                    method=kwargs.get("method", "auto"),
                    max_iter=kwargs.get("max_iter", 1000),
                )
            except Exception as e:
                logger.error(
                    "NLSQ optimization raised exception",
                    error_type=type(e).__name__,
                    error_message=str(e),
                    exc_info=True,
                )
                raise

            # Validate optimization succeeded. A non-converged run is only
            # fatal when its cost is unusable; otherwise it is a warning.
            if not result.success:
                if not np.isfinite(result.fun) or result.fun > 1e6 * len(x_np):
                    logger.error(
                        "Optimization failed",
                        message=result.message,
                        iterations=getattr(result, "nit", None),
                    )
                    raise RuntimeError(
                        f"Optimization failed: {result.message}. "
                        f"Try adjusting initial values, bounds, or max_iter."
                    )
                else:
                    logger.warning(
                        "Optimization did not fully converge",
                        message=result.message,
                        model=self.__class__.__name__,
                    )
            self._nlsq_result = result
            self.fitted_ = True

            # Log fitted parameters and result metrics
            G0_fitted = self.parameters.get_value("G0")
            eta_fitted = self.parameters.get_value("eta")
            tau_fitted = eta_fitted / G0_fitted
            ctx["G0"] = G0_fitted
            ctx["eta"] = eta_fitted
            ctx["tau"] = tau_fitted
            ctx["iterations"] = getattr(result, "nit", None)
            ctx["cost"] = getattr(result, "fun", None)
            logger.debug(
                "Optimization completed successfully",
                G0=G0_fitted,
                eta=eta_fitted,
                tau=tau_fitted,
                iterations=getattr(result, "nit", None),
                final_cost=getattr(result, "fun", None),
            )
        return self

    def _initialize_relaxation_parameters(self, X, y) -> bool:
        """Estimate G0 and eta from relaxation data for faster convergence.

        The estimate uses the first two time-ordered samples that rise above
        the tail baseline and solves the two-point exponential decay for tau
        and G0. This is a best-effort heuristic: any failure leaves the
        parameters untouched.

        Args:
            X: Time values
            y: Relaxation modulus values (same shape as X)

        Returns:
            True if G0 and eta were updated, False otherwise
        """
        try:
            t = np.asarray(X, dtype=float).ravel()
            g = np.asarray(y, dtype=float).ravel()
            if t.shape != g.shape or t.size < 3:
                logger.debug(
                    "Initialization skipped: insufficient data",
                    t_shape=t.shape,
                    g_shape=g.shape,
                )
                return False

            order = np.argsort(t)
            t_sorted = t[order]
            g_sorted = g[order]
            tail = max(3, t_sorted.size // 6)
            baseline = float(np.median(g_sorted[-tail:]))
            transient = g_sorted - baseline
            g0_bounds = self.parameters.get("G0").bounds or (1e-3, 1e9)
            eta_bounds = self.parameters.get("eta").bounds or (1e-6, 1e12)

            # Attempt to estimate parameters from the first two signal-dominant points
            positive_mask = transient > 0
            signal_floor = max(float(np.max(transient)), 1e-12) * 1e-3
            idx_candidates = np.where(positive_mask & (transient > signal_floor))[0]
            if idx_candidates.size < 2:
                idx_candidates = np.where(positive_mask)[0]
            if idx_candidates.size < 2:
                logger.debug(
                    "Initialization skipped: insufficient positive transient points",
                    n_candidates=idx_candidates.size,
                )
                return False

            i0, i1 = idx_candidates[0], idx_candidates[1]
            t0, t1 = t_sorted[i0], t_sorted[i1]
            y0, y1 = transient[i0], transient[i1]
            if not (y0 > 0 and y1 > 0 and t1 > t0 and y1 != y0):
                logger.debug(
                    "Initialization skipped: invalid transient values",
                    y0=y0,
                    y1=y1,
                    t0=t0,
                    t1=t1,
                )
                return False

            # Both samples are positive here, so the ratio is positive; only
            # reject a drop too steep to resolve from two points.
            ratio = y1 / y0
            if ratio < 1e-3:
                logger.debug(
                    "Initialization skipped: invalid decay ratio",
                    ratio=ratio,
                )
                return False

            with np.errstate(divide="ignore"):
                tau_estimate = -(t1 - t0) / np.log(ratio)
            # ratio > 1 (growth between the two samples) yields a negative tau.
            if not (np.isfinite(tau_estimate) and tau_estimate > 0):
                logger.debug(
                    "Initialization skipped: invalid tau estimate",
                    tau_estimate=tau_estimate,
                )
                return False

            # Extrapolate the first sample back to t = 0 to get G0.
            g0_estimate = float(y0 * np.exp(t0 / tau_estimate))
            g0_guess = float(np.clip(g0_estimate, g0_bounds[0], g0_bounds[1]))
            eta_guess = float(
                np.clip(g0_guess * tau_estimate, eta_bounds[0], eta_bounds[1])
            )
            self.parameters.set_value("G0", g0_guess)
            self.parameters.set_value("eta", eta_guess)
            logger.debug(
                "Maxwell relaxation initialization successful",
                G0=g0_guess,
                eta=eta_guess,
                tau_estimate=tau_estimate,
                baseline=baseline,
            )
            return True
        except Exception as exc:  # pragma: no cover - heuristic best effort
            logger.debug(
                "Maxwell relaxation initialization failed",
                error_type=type(exc).__name__,
                error_message=str(exc),
                exc_info=True,
            )
            return False

    # ------------------------------------------------------------------
    # Prediction
    # ------------------------------------------------------------------
    def _predict(self, X, **kwargs):
        """Predict response based on input data.

        Args:
            X: RheoData object or independent variable array
            **kwargs: Additional arguments (ignored for Maxwell)

        Returns:
            Predicted values as JAX array

        Raises:
            ValueError: If the resolved test mode is unsupported
        """
        if isinstance(X, RheoData):
            test_mode = detect_test_mode(X)
            x_data = jnp.array(X.x)
        else:
            x_data = jnp.array(X)
            # Use test_mode from last fit if available, otherwise default to RELAXATION
            test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)
        test_mode = self._resolve_test_mode(test_mode)

        # Get parameter values
        G0 = self.parameters.get_value("G0")
        eta = self.parameters.get_value("eta")

        prediction = self._evaluate(test_mode, x_data, G0, eta)
        if test_mode == TestMode.RELAXATION:
            # Restore the baseline removed from the data during fitting.
            prediction = prediction + getattr(self, "_relaxation_offset", 0.0)
        return prediction

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        CRITICAL: test_mode is now passed as parameter (NOT read from self._test_mode)
        to ensure correct posteriors in Bayesian inference (v0.4.0 fix).

        Args:
            X: Independent variable (time, frequency, or shear rate)
            params: Array of parameter values [G0, eta]
            test_mode: Explicit test mode for predictions (``TestMode`` member
                or member name). If None, falls back to self._test_mode for
                backward compatibility.
            **kwargs: ``_relaxation_offset`` is read if present; everything
                else is ignored.

        Returns:
            Model predictions as JAX array

        Raises:
            ValueError: If the test mode is unsupported
        """
        # Extract parameters from array
        G0 = params[0]
        eta = params[1]

        # Use explicit test_mode parameter (closure-captured in fit_bayesian)
        # Fall back to self._test_mode only for backward compatibility.
        # R5-JAX-009: This Python-level TestMode dispatch is safe ONLY because
        # `test_mode` is a concrete Python enum value captured at closure-build
        # time by BayesianMixin._build_numpyro_model(). It must NEVER be a
        # JAX tracer — passing a traced test_mode into model_function would
        # freeze one branch at trace time and silently produce wrong predictions.
        # All 20+ models share this pattern; enforce by keeping model_function
        # signatures as `test_mode: TestMode | str | None` (never jnp.ndarray).
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)
        # Accept member names as strings, matching what _fit accepts.
        test_mode = self._resolve_test_mode(test_mode)

        # Read offset from kwargs (forwarded by NUTS closure) if available,
        # otherwise fall back to self._ (correct when fit_bayesian follows
        # fit on the same data — the recommended workflow).
        _relaxation_offset = kwargs.get(
            "_relaxation_offset", getattr(self, "_relaxation_offset", 0.0)
        )

        prediction = self._evaluate(test_mode, X, G0, eta)
        if test_mode == TestMode.RELAXATION:
            prediction = prediction + _relaxation_offset
        return prediction

    @staticmethod
    @jax.jit
    def _predict_relaxation(t: jnp.ndarray, G0: float, eta: float) -> jnp.ndarray:
        """Predict relaxation modulus G(t).

        Theory: G(t) = G0 * exp(-t/tau) where tau = eta/G0

        Args:
            t: Time array (s)
            G0: Elastic modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Relaxation modulus G(t) in Pa
        """
        tau = eta / G0  # Relaxation time
        return G0 * jnp.exp(-t / tau)

    @staticmethod
    @jax.jit
    def _predict_creep(t: jnp.ndarray, G0: float, eta: float) -> jnp.ndarray:
        """Predict creep compliance J(t).

        Theory: J(t) = 1/G0 + t/eta

        Args:
            t: Time array (s)
            G0: Elastic modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Creep compliance J(t) in 1/Pa
        """
        return (1.0 / G0) + (t / eta)

    @staticmethod
    @jax.jit
    def _predict_oscillation(omega: jnp.ndarray, G0: float, eta: float) -> jnp.ndarray:
        """Predict complex modulus G*(omega).

        Theory:
            G'(omega) = G0 * (omega*tau)^2 / (1 + (omega*tau)^2)
            G''(omega) = G0 * omega*tau / (1 + (omega*tau)^2)
            G*(omega) = G'(omega) + i*G''(omega)

        Args:
            omega: Angular frequency array (rad/s)
            G0: Elastic modulus (Pa)
            eta: Viscosity (Pa·s)

        Returns:
            Complex modulus G*(omega) in Pa
        """
        tau = eta / G0  # Relaxation time
        omega_tau = omega * tau
        omega_tau_sq = omega_tau**2
        # Storage modulus G'
        G_prime = G0 * omega_tau_sq / (1.0 + omega_tau_sq)
        # Loss modulus G''
        G_double_prime = G0 * omega_tau / (1.0 + omega_tau_sq)
        # Complex modulus
        return G_prime + 1j * G_double_prime

    @staticmethod
    @jax.jit
    def _predict_rotation(gamma_dot: jnp.ndarray, G0: float, eta: float) -> jnp.ndarray:
        """Predict steady shear stress sigma(gamma_dot).

        Theory: sigma = eta * gamma_dot (Newtonian flow)

        Args:
            gamma_dot: Shear rate array (1/s)
            G0: Elastic modulus (Pa) - not used but kept for interface consistency
            eta: Viscosity (Pa·s)

        Returns:
            Shear stress in Pa
        """
        return eta * gamma_dot

    def get_relaxation_time(self) -> float:
        """Get characteristic relaxation time tau = eta/G0.

        Returns:
            Relaxation time in seconds
        """
        G0 = self.parameters.get_value("G0")
        eta = self.parameters.get_value("eta")
        return eta / G0

    def __repr__(self) -> str:
        """String representation of Maxwell model."""
        G0 = self.parameters.get_value("G0")
        eta = self.parameters.get_value("eta")
        tau = self.get_relaxation_time()
        return f"Maxwell(G0={G0:.2e} Pa, eta={eta:.2e} Pa·s, tau={tau:.2e} s)"
# Public API of this module: only the Maxwell model class is exported by
# `from ... import *`.
__all__ = ["Maxwell"]