Source code for rheojax.models.fractional.fractional_burgers

"""Fractional Burgers Model (FBM).

This model places a fractional Maxwell-type element and a fractional
Kelvin-Voigt element in series. Its five parameters describe instantaneous
(glassy) elasticity, fractional viscous flow, and fractional retardation,
for complex viscoelastic behavior.

Theory
------
The FBM model consists of:
- a Maxwell-type element (spring + fractional dashpot, i.e. a SpringPot of
  order α) in series with
- a fractional Kelvin-Voigt element (spring + SpringPot of the same order α)

Creep compliance:
    J(t) = J_g + (t^α)/(η_1 * Γ(1+α)) + J_k * (1 - E_α(-(t/τ_k)^α))

where:
- J_g: Glassy compliance (instantaneous)
- η_1: Viscosity (Maxwell arm)
- J_k: Kelvin compliance
- τ_k: Retardation time

Parameters
----------
Jg : float
    Glassy compliance (1/Pa), bounds [1e-9, 1e3]
eta1 : float
    Viscosity (Pa·s), bounds [1e-6, 1e12]
Jk : float
    Kelvin compliance (1/Pa), bounds [1e-9, 1e3]
alpha : float
    Fractional order, bounds [0.0, 1.0]
tau_k : float
    Retardation time (s), bounds [1e-6, 1e6]

Limit Cases
-----------
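- alpha → 1: Classical Burgers model. The flow term becomes Newtonian (t/η_1)
  and the retardation becomes exponential, since E_1(-x) = exp(-x).
- alpha → 0: The flow and retardation terms lose their time dependence, so
  J(t) tends to a constant (purely elastic response).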

References
----------
- Mainardi, F. (2010). Fractional Calculus and Waves in Linear Viscoelasticity
- Bagley, R.L. & Torvik, P.J. (1986). J. Rheol. 30, 133-155
- Schiessel, H. & Blumen, A. (1993). J. Phys. A: Math. Gen. 26, 5057
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger, log_fit
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS

jax, jnp = safe_import_jax()

jax_gamma = jax.scipy.special.gamma

from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.utils.mittag_leffler import mittag_leffler_e

logger = get_logger(__name__)


@ModelRegistry.register(
    "fractional_burgers",
    protocols=[
        Protocol.CREEP,
        Protocol.RELAXATION,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FractionalBurgersModel(BaseModel):
    """Fractional Burgers model.

    A five-parameter fractional viscoelastic model (``Jg``, ``eta1``, ``Jk``,
    ``alpha``, ``tau_k``). It combines instantaneous compliance, fractional
    viscous flow and fractional retardation. The flow and retardation terms
    share the same fractional order ``alpha``.

    Test Modes
    ----------
    - Creep: Supported (primary mode, closed form)
    - Oscillation: Supported (closed-form complex compliance, inverted to G*)
    - Relaxation: Approximate only (heuristic Mittag-Leffler interpolation,
      not an exact inversion of the creep compliance)
    - Rotation: Not registered as a protocol for this model

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.models import FractionalBurgersModel
    >>>
    >>> # Create model
    >>> model = FractionalBurgersModel()
    >>>
    >>> # Set parameters
    >>> model.set_params(Jg=1e-6, eta1=1000.0, Jk=5e-6, alpha=0.5, tau_k=1.0)
    >>>
    >>> # Predict creep compliance. Pass test_mode explicitly: an unfitted
    >>> # model otherwise auto-detects, and an axis spanning more than three
    >>> # decades (as here) is treated as frequency data.
    >>> t = jnp.logspace(-2, 2, 50)
    >>> J_t = model.predict(t, test_mode="creep")
    """

    def __init__(self):
        """Initialize Fractional Burgers model."""
        super().__init__()

        # Define parameters with bounds and descriptions. The insertion order
        # (Jg, eta1, Jk, alpha, tau_k) is the order of the flat parameter
        # vectors used by _fit and model_function.
        self.parameters = ParameterSet()

        self.parameters.add(
            name="Jg",
            value=1e-6,
            bounds=(1e-9, 1e3),
            units="1/Pa",
            description="Glassy compliance",
        )

        self.parameters.add(
            name="eta1",
            value=1000.0,
            bounds=(1e-6, 1e12),
            units="Pa·s",
            description="Viscosity (Maxwell arm)",
        )

        self.parameters.add(
            name="Jk",
            value=1e-5,
            bounds=(1e-9, 1e3),
            units="1/Pa",
            description="Kelvin compliance",
        )

        self.parameters.add(
            name="alpha",
            value=0.5,
            bounds=FRACTIONAL_ORDER_BOUNDS,
            units="",
            description="Fractional order",
        )

        self.parameters.add(
            name="tau_k",
            value=1.0,
            bounds=(1e-6, 1e6),
            units="s",
            description="Retardation time",
        )

    @staticmethod
    def _resolve_test_mode(mode):
        """Normalize a test-mode spec to a ``TestMode`` member.

        ``TestMode`` members are returned unchanged. Strings, and objects that
        expose a string ``.value``, are matched case-insensitively against
        ``"creep"``, ``"relaxation"`` and ``"oscillation"``.

        Parameters
        ----------
        mode : TestMode, str or None
            Requested test mode.

        Returns
        -------
        TestMode or None
            The resolved member, or ``None`` when ``mode`` is ``None`` or
            not recognized.
        """
        from rheojax.core.test_modes import TestMode

        if mode is None:
            return None
        if isinstance(mode, TestMode):
            return mode
        key = getattr(mode, "value", mode)
        if not isinstance(key, str):
            return None
        return {
            "creep": TestMode.CREEP,
            "relaxation": TestMode.RELAXATION,
            "oscillation": TestMode.OSCILLATION,
        }.get(key.strip().lower())

    @staticmethod
    @jax.jit
    def _predict_creep(
        t: jnp.ndarray,
        Jg: float,
        eta1: float,
        Jk: float,
        alpha: float,
        tau_k: float,
    ) -> jnp.ndarray:
        """Predict creep compliance J(t).

        J(t) = J_g + t^α/(η_1 * Γ(1+α)) + J_k * (1 - E_α(-(t/τ_k)^α))

        Parameters
        ----------
        t : jnp.ndarray
            Time array (s)
        Jg : float
            Glassy compliance (1/Pa)
        eta1 : float
            Viscosity (Pa·s)
        Jk : float
            Kelvin compliance (1/Pa)
        alpha : float
            Fractional order
        tau_k : float
            Retardation time (s)

        Returns
        -------
        jnp.ndarray
            Creep compliance J(t) (1/Pa)
        """
        # Small epsilon keeps alpha strictly inside (0, 1) and avoids
        # division by exactly zero.
        epsilon = 1e-12

        # Clip alpha to safe range (works with JAX tracers)
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        tau_k_safe = tau_k + epsilon
        eta1_safe = eta1 + epsilon

        # Instantaneous compliance (elastic response)
        J_instant = Jg

        # Fractional viscous flow term: t^α / (η_1 * Γ(1+α))
        gamma_term = jax_gamma(1.0 + alpha_safe)
        J_flow = jnp.power(t, alpha_safe) / (eta1_safe * gamma_term)

        # Retardation term: J_k * (1 - E_α(-(t/τ_k)^α))
        z = -jnp.power(t / tau_k_safe, alpha_safe)
        ml_term = mittag_leffler_e(z, alpha=alpha_safe)
        J_retard = Jk * (1.0 - ml_term)

        # Total creep compliance
        J_t = J_instant + J_flow + J_retard

        return J_t

    @staticmethod
    @jax.jit
    def _predict_relaxation(
        t: jnp.ndarray,
        Jg: float,
        eta1: float,
        Jk: float,
        alpha: float,
        tau_k: float,
    ) -> jnp.ndarray:
        """Predict an approximate relaxation modulus G(t).

        This is a heuristic interpolation, NOT an exact inversion of J(t):

            G(t) = G_inst * E + G_inst * (t/τ_k)^(-α) * (1 - E),
            E = E_α(-(t/τ_k)^α),  G_inst = 1/J_g

        Limitations that follow from the formula above:

        - ``eta1`` and ``Jk`` are not used. They are accepted only so the
          signature matches the other predictors.
        - As t → 0+ the result tends to G_inst * (1 + 1/Γ(1+α)), not 1/J_g.
        - At long times it decays as G_inst * (t/τ_k)^(-α).

        Parameters
        ----------
        t : jnp.ndarray
            Time array (s)
        Jg : float
            Glassy compliance (1/Pa)
        eta1 : float
            Viscosity (Pa·s) — unused by this approximation
        Jk : float
            Kelvin compliance (1/Pa) — unused by this approximation
        alpha : float
            Fractional order
        tau_k : float
            Retardation time (s)

        Returns
        -------
        jnp.ndarray
            Relaxation modulus G(t) (Pa)
        """
        # Small epsilon keeps alpha strictly inside (0, 1) and avoids
        # division by exactly zero.
        epsilon = 1e-12

        # Clip alpha to safe range (works with JAX tracers)
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        tau_k_safe = tau_k + epsilon

        # P2-FRAC-004: Guard t=0 — power(0/tau, -alpha_safe) = +inf when alpha>0.
        t_safe = jnp.maximum(t, 1e-30)

        # Modulus scale taken from the glassy compliance
        G_inst = 1.0 / (Jg + epsilon)

        # Power-law decay branch: G_inst * (t/τ_k)^(-α)
        G_decay = G_inst * jnp.power(t_safe / tau_k_safe, -alpha_safe)

        # Mittag-Leffler weight blending the two branches
        z = -jnp.power(t_safe / tau_k_safe, alpha_safe)
        ml_term = mittag_leffler_e(z, alpha=alpha_safe)

        # Combine terms
        G_t = G_inst * ml_term + G_decay * (1.0 - ml_term)

        return G_t

    @staticmethod
    @jax.jit
    def _predict_oscillation(
        omega: jnp.ndarray,
        Jg: float,
        eta1: float,
        Jk: float,
        alpha: float,
        tau_k: float,
    ) -> jnp.ndarray:
        """Predict complex modulus G*(ω).

        Computed from the complex compliance implied by the creep law
        (J*(ω) = s·L{J}(s) evaluated at s = iω):

            J*(ω) = J_g + (iω)^(-α)/η_1 + J_k/(1 + (iωτ_k)^α)
            G*(ω) = 1/J*(ω)

        Parameters
        ----------
        omega : jnp.ndarray
            Angular frequency array (rad/s)
        Jg : float
            Glassy compliance (1/Pa)
        eta1 : float
            Viscosity (Pa·s)
        Jk : float
            Kelvin compliance (1/Pa)
        alpha : float
            Fractional order
        tau_k : float
            Retardation time (s)

        Returns
        -------
        jnp.ndarray
            Array with shape (..., 2), where [..., 0] is G' and [..., 1] is G''.
        """
        # Small epsilon keeps alpha strictly inside (0, 1) and avoids
        # division by exactly zero.
        epsilon = 1e-12

        # Clip alpha to safe range (works with JAX tracers)
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        tau_k_safe = tau_k + epsilon
        eta1_safe = eta1 + epsilon

        # Instantaneous compliance
        J_inst = Jg

        # Fractional viscous flow. The creep term t^α/(η_1 Γ(1+α)) has
        # Laplace transform 1/(η_1 s^(1+α)). Multiplying by s and setting
        # s = iω gives (iω)^(-α)/η_1; the Γ(1+α) cancels and no other Gamma
        # factor appears. As α → 1 this reduces to the Newtonian -i/(ωη_1).
        omega_neg_alpha = jnp.power(omega, -alpha_safe)
        phase = -jnp.pi * alpha_safe / 2.0
        i_omega_neg_alpha = omega_neg_alpha * (jnp.cos(phase) + 1j * jnp.sin(phase))
        J_flow = i_omega_neg_alpha / eta1_safe

        # Retardation term: J_k / (1 + (iωτ_k)^α)
        omega_tau_alpha = jnp.power(omega * tau_k_safe, alpha_safe)
        phase_alpha = jnp.pi * alpha_safe / 2.0
        i_omega_tau_alpha = omega_tau_alpha * (
            jnp.cos(phase_alpha) + 1j * jnp.sin(phase_alpha)
        )
        J_retard = Jk / (1.0 + i_omega_tau_alpha)

        # Total complex compliance
        J_star = J_inst + J_flow + J_retard

        # Complex modulus (inverse)
        G_star = 1.0 / (J_star + epsilon)

        # Extract storage and loss moduli
        G_prime = jnp.real(G_star)
        G_double_prime = jnp.imag(G_star)

        return jnp.stack([G_prime, G_double_prime], axis=-1)

    def _fit(self, X: jnp.ndarray, y: jnp.ndarray, **kwargs) -> FractionalBurgersModel:
        """Fit model to data using NLSQ TRF optimization.

        Parameters
        ----------
        X : jnp.ndarray
            Independent variable (time or frequency)
        y : jnp.ndarray
            Dependent variable (modulus or compliance)
        **kwargs : dict
            Additional fitting options. Recognized keys: ``test_mode``
            (``TestMode`` or case-insensitive string, default ``"creep"``),
            ``use_jax``, ``method``, ``max_iter``.

        Returns
        -------
        self
            Fitted model instance

        Raises
        ------
        RuntimeError
            If the optimizer reports failure.
        ValueError
            If ``test_mode`` is a ``TestMode`` member this model cannot fit.
        """
        from rheojax.core.test_modes import TestMode
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        # Detect test mode. Unknown or missing values fall back to creep,
        # the primary mode for this model.
        raw_mode = kwargs.get("test_mode", "creep")
        test_mode = self._resolve_test_mode(raw_mode)
        if test_mode is None:
            if raw_mode is not None:
                logger.warning(
                    "Unrecognized test_mode; defaulting to creep",
                    test_mode=str(raw_mode),
                )
            test_mode = TestMode.CREEP

        # Store test mode for model_function / _predict
        self._test_mode = test_mode

        logger.info(
            "Starting FractionalBurgersModel fit",
            test_mode=(
                test_mode.value if hasattr(test_mode, "value") else str(test_mode)
            ),
            data_shape=X.shape,
        )

        with log_fit(logger, model="FractionalBurgersModel", data_shape=X.shape) as ctx:
            # Smart initialization for oscillation mode (Issue #9)
            if test_mode == TestMode.OSCILLATION:
                try:
                    import numpy as np

                    from rheojax.utils.initialization import (
                        initialize_fractional_burgers,
                    )

                    logger.debug(
                        "Attempting smart initialization for oscillation mode",
                        data_points=len(X),
                    )
                    success = initialize_fractional_burgers(
                        np.array(X), np.array(y), self.parameters
                    )
                    if success:
                        logger.debug(
                            "Smart initialization applied from frequency-domain features",
                            Jg=self.parameters.get_value("Jg"),
                            eta1=self.parameters.get_value("eta1"),
                            Jk=self.parameters.get_value("Jk"),
                            alpha=self.parameters.get_value("alpha"),
                            tau_k=self.parameters.get_value("tau_k"),
                        )
                except Exception as e:
                    # Deliberate best-effort: keep the default initial values
                    # and record why the heuristic failed.
                    logger.debug(
                        "Smart initialization failed, using defaults",
                        error=str(e),
                        exc_info=True,
                    )

            # Create stateless model function for optimization
            def model_fn(x, params):
                """Model function for optimization (stateless)."""
                Jg, eta1, Jk, alpha, tau_k = (
                    params[0],
                    params[1],
                    params[2],
                    params[3],
                    params[4],
                )

                # Direct prediction based on test mode (stateless)
                if test_mode == TestMode.RELAXATION:
                    return self._predict_relaxation(x, Jg, eta1, Jk, alpha, tau_k)
                elif test_mode == TestMode.CREEP:
                    return self._predict_creep(x, Jg, eta1, Jk, alpha, tau_k)
                elif test_mode == TestMode.OSCILLATION:
                    return self._predict_oscillation(x, Jg, eta1, Jk, alpha, tau_k)
                else:
                    raise ValueError(f"Unsupported test mode: {test_mode}")

            # Create objective function
            logger.debug("Creating least squares objective function")
            objective = create_least_squares_objective(
                model_fn, jnp.array(X), jnp.array(y), normalize=True
            )

            # Optimize using NLSQ TRF
            logger.debug(
                "Starting NLSQ optimization",
                method=kwargs.get("method", "auto"),
                max_iter=kwargs.get("max_iter", 1000),
            )
            result = nlsq_optimize(
                objective,
                self.parameters,
                use_jax=kwargs.get("use_jax", True),
                method=kwargs.get("method", "auto"),
                max_iter=kwargs.get("max_iter", 1000),
            )

            # Validate optimization succeeded. No exception is active here,
            # so exc_info would only attach an empty traceback.
            if not result.success:
                logger.error(
                    "NLSQ optimization failed",
                    message=result.message,
                )
                raise RuntimeError(
                    f"Optimization failed: {result.message}. "
                    f"Try adjusting initial values, bounds, or max_iter."
                )

            self.fitted_ = True

            ctx["success"] = True
            ctx["fitted_params"] = {
                "Jg": self.parameters.get_value("Jg"),
                "eta1": self.parameters.get_value("eta1"),
                "Jk": self.parameters.get_value("Jk"),
                "alpha": self.parameters.get_value("alpha"),
                "tau_k": self.parameters.get_value("tau_k"),
            }

            logger.info(
                "FractionalBurgersModel fit completed",
                Jg=self.parameters.get_value("Jg"),
                eta1=self.parameters.get_value("eta1"),
                Jk=self.parameters.get_value("Jk"),
                alpha=self.parameters.get_value("alpha"),
                tau_k=self.parameters.get_value("tau_k"),
            )

        return self

    def _predict(self, X: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Predict response for given input.

        Dispatch order:

        1. an explicit ``test_mode`` kwarg;
        2. the mode stored by the last ``_fit``;
        3. a legacy heuristic. Strictly positive data spanning more than
           three decades is treated as frequency data; otherwise creep.

        Parameters
        ----------
        X : jnp.ndarray
            Independent variable
        **kwargs
            Optional ``test_mode`` (``TestMode`` or case-insensitive string)

        Returns
        -------
        jnp.ndarray
            Predicted values. Oscillation returns shape (..., 2) with G', G''.
        """
        # Get parameters
        Jg = self.parameters.get_value("Jg")
        eta1 = self.parameters.get_value("eta1")
        Jk = self.parameters.get_value("Jk")
        alpha = self.parameters.get_value("alpha")
        tau_k = self.parameters.get_value("tau_k")

        from rheojax.core.test_modes import TestMode

        _kw_mode = kwargs.get("test_mode")
        test_mode = self._resolve_test_mode(
            _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
        )

        if test_mode == TestMode.OSCILLATION:
            return self._predict_oscillation(X, Jg, eta1, Jk, alpha, tau_k)
        elif test_mode == TestMode.RELAXATION:
            return self._predict_relaxation(X, Jg, eta1, Jk, alpha, tau_k)
        elif test_mode == TestMode.CREEP:
            return self._predict_creep(X, Jg, eta1, Jk, alpha, tau_k)

        # Auto-detect test mode (legacy fallback). asarray makes scalars and
        # lists safe to inspect; ndim guards len() on 0-d input.
        X_arr = jnp.asarray(X)
        if X_arr.ndim > 0 and X_arr.shape[0] > 1 and bool(jnp.all(X_arr > 0)):
            log_range = jnp.log10(jnp.max(X_arr)) - jnp.log10(jnp.min(X_arr) + 1e-12)
            if log_range > 3:
                return self._predict_oscillation(X_arr, Jg, eta1, Jk, alpha, tau_k)

        # Default to creep (primary mode for Burgers)
        return self._predict_creep(X_arr, Jg, eta1, Jk, alpha, tau_k)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Args:
            X: Independent variable (time or frequency)
            params: Array of parameter values [Jg, eta1, Jk, alpha, tau_k]
            test_mode: ``TestMode`` or case-insensitive string. When None,
                the mode stored by the last fit is used, falling back to creep.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Model predictions as JAX array. Oscillation returns complex G*.
        """
        from rheojax.core.test_modes import TestMode

        # Extract parameters from array (in order they were added to ParameterSet)
        Jg = params[0]
        eta1 = params[1]
        Jk = params[2]
        alpha = params[3]
        tau_k = params[4]

        # Prefer the explicit test_mode (closure-captured in fit_bayesian);
        # fall back to self._test_mode only for backward compatibility.
        # Resolve to a TestMode member so strings and enums compare reliably.
        requested = test_mode if test_mode is not None else getattr(self, "_test_mode", None)
        resolved = self._resolve_test_mode(requested)

        logger.debug(
            "model_function evaluation",
            test_mode=str(resolved.value if resolved is not None else requested),
            alpha=alpha,  # Don't cast tracer to float
            input_shape=X.shape if hasattr(X, "shape") else len(X),
        )

        # Call appropriate prediction function based on test mode
        if resolved == TestMode.RELAXATION:
            logger.debug("Computing relaxation modulus with Mittag-Leffler evaluation")
            return self._predict_relaxation(X, Jg, eta1, Jk, alpha, tau_k)
        elif resolved == TestMode.CREEP:
            logger.debug("Computing creep compliance with Mittag-Leffler evaluation")
            return self._predict_creep(X, Jg, eta1, Jk, alpha, tau_k)
        elif resolved == TestMode.OSCILLATION:
            logger.debug("Computing complex modulus for oscillation mode")
            stacked = self._predict_oscillation(X, Jg, eta1, Jk, alpha, tau_k)
            return stacked[..., 0] + 1j * stacked[..., 1]
        else:
            # Default to creep mode for Burgers model
            logger.debug("Default to creep mode prediction")
            return self._predict_creep(X, Jg, eta1, Jk, alpha, tau_k)
# Public API of this module: the model class and its short alias.
__all__ = ["FractionalBurgersModel", "FBM"]

# Short convenience alias for FractionalBurgersModel.
FBM = FractionalBurgersModel