Source code for rheojax.models.fractional.fractional_jeffreys

"""Fractional Jeffreys Model (FJM).

This model combines two dashpots and one SpringPot in a parallel-series
arrangement, providing viscous flow with fractional relaxation behavior.

Theory
------
The FJM model consists of:
- Dashpot (η_1) in parallel with
- Series combination of dashpot (η_2) and SpringPot

Relaxation modulus:
    G(t) = (η_1/τ_1) * t^(-α) * E_{1-α,1-α}(-(t/τ_1)^(1-α))

where:
- τ_1 = η_2 / characteristic_modulus (relaxation time)
- E_{α,β} is the two-parameter Mittag-Leffler function

Complex modulus:
    G*(ω) = η_1(iω) * [1 + (iωτ_2)^α] / [1 + (iωτ_1)^α]

where τ_2 = (η_2/η_1) * τ_1 is the second time constant.

Parameters
----------
eta1 : float
    First viscosity (Pa·s), bounds [1e-6, 1e12]
eta2 : float
    Second viscosity (Pa·s), bounds [1e-6, 1e12]
alpha : float
    Fractional order, bounds [0.0, 1.0]
tau1 : float
    Relaxation time (s), bounds [1e-6, 1e6]

Limit Cases
-----------
- alpha → 0: Two dashpots in parallel (Newtonian fluid)
- alpha → 1: Classical Jeffreys model (viscoelastic liquid)

References
----------
- Mainardi, F. (2010). Fractional Calculus and Waves in Linear Viscoelasticity
- Jeffreys, H. (1929). The Earth
- Friedrich, C. (1991). Rheol. Acta 30, 151-158
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS

jax, jnp = safe_import_jax()

jax_gamma = jax.scipy.special.gamma

from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.mittag_leffler import mittag_leffler_e2

# Module logger
logger = get_logger(__name__)


@ModelRegistry.register(
    "fractional_jeffreys",
    protocols=[
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
        Protocol.FLOW_CURVE,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FractionalJeffreysModel(BaseModel):
    """Fractional Jeffreys model.

    A fractional viscoelastic liquid model combining viscous flow with
    fractional relaxation behavior.

    Test Modes
    ----------
    - Relaxation: Supported
    - Creep: Supported
    - Oscillation: Supported
    - Rotation: Supported (viscous flow at low frequencies)

    Notes
    -----
    When no test mode is passed to ``_predict`` and none was stored by a
    previous fit, ``_predict`` falls back to a heuristic: strictly positive
    inputs spanning more than three decades are treated as frequencies
    (oscillation); everything else is treated as relaxation times. Supply
    ``test_mode`` explicitly to avoid relying on that heuristic.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.models import FractionalJeffreysModel
    >>>
    >>> # Create model
    >>> model = FractionalJeffreysModel()
    >>>
    >>> # Set parameters
    >>> model.set_params(eta1=1000.0, eta2=500.0, alpha=0.5, tau1=1.0)
    >>>
    >>> # Predict relaxation modulus
    >>> t = jnp.logspace(-2, 2, 50)
    >>> G_t = model.predict(t)
    """

    def __init__(self):
        """Initialize the model and register its four parameters.

        Parameters are added in the order ``eta1``, ``eta2``, ``alpha``,
        ``tau1``. ``_fit`` and ``model_function`` index parameter arrays
        positionally in that same order.
        """
        super().__init__()

        # Define parameters with bounds and descriptions
        self.parameters = ParameterSet()
        self.parameters.add(
            name="eta1",
            value=1000.0,
            bounds=(1e-6, 1e12),
            units="Pa·s",
            description="First viscosity",
        )
        self.parameters.add(
            name="eta2",
            value=500.0,
            bounds=(1e-6, 1e12),
            units="Pa·s",
            description="Second viscosity",
        )
        self.parameters.add(
            name="alpha",
            value=0.5,
            bounds=FRACTIONAL_ORDER_BOUNDS,
            units="",
            description="Fractional order",
        )
        self.parameters.add(
            name="tau1",
            value=1.0,
            bounds=(1e-6, 1e6),
            units="s",
            description="Relaxation time",
        )

    @staticmethod
    @jax.jit
    def _predict_relaxation(
        t: jnp.ndarray,
        eta1: float,
        eta2: float,
        alpha: float,
        tau1: float,
    ) -> jnp.ndarray:
        """Predict relaxation modulus G(t).

        G(t) = (η_1/τ_1) * t^(-α) * E_{1-α,1-α}(-(t/τ_1)^(1-α))

        Parameters
        ----------
        t : jnp.ndarray
            Time array (s)
        eta1 : float
            First viscosity (Pa·s)
        eta2 : float
            Second viscosity (Pa·s); not used by this expression
        alpha : float
            Fractional order
        tau1 : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Relaxation modulus G(t) (Pa)
        """
        # Small offset keeps divisions and exponents finite
        epsilon = 1e-12

        # Clip alpha to the open interval (0, 1); works with JAX tracers
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)

        # Both Mittag-Leffler indices equal 1 - α: E_{1-α,1-α}
        ml_alpha = 1.0 - alpha_safe
        ml_beta = 1.0 - alpha_safe

        tau1_safe = tau1 + epsilon
        eta1_safe = eta1 + epsilon

        # P2-FRAC-003: Guard t=0 — power(0, -alpha_safe) = +inf when alpha>0.
        t_safe = jnp.maximum(t, 1e-30)

        # Argument of the Mittag-Leffler function: -(t/τ_1)^(1-α)
        z = -jnp.power(t_safe / tau1_safe, ml_alpha)
        ml_term = mittag_leffler_e2(z, alpha=ml_alpha, beta=ml_beta)

        # G(t) = (η_1/τ_1) * t^(-α) * E_{1-α,1-α}(-(t/τ_1)^(1-α))
        prefactor = eta1_safe / tau1_safe
        return prefactor * jnp.power(t_safe, -alpha_safe) * ml_term

    @staticmethod
    @jax.jit
    def _predict_creep(
        t: jnp.ndarray,
        eta1: float,
        eta2: float,
        alpha: float,
        tau1: float,
    ) -> jnp.ndarray:
        """Predict creep compliance J(t).

        The compliance is an exponential blend of a short-time power law
        (SpringPot-like) and a long-time Newtonian flow term, with the
        crossover placed around ``tau1``.

        Parameters
        ----------
        t : jnp.ndarray
            Time array (s)
        eta1 : float
            First viscosity (Pa·s)
        eta2 : float
            Second viscosity (Pa·s)
        alpha : float
            Fractional order
        tau1 : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Creep compliance J(t) (1/Pa)
        """
        # Small offset keeps divisions and exponents finite
        epsilon = 1e-12

        # Clip alpha to the open interval (0, 1); works with JAX tracers
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)

        tau1_safe = tau1 + epsilon
        eta1_safe = eta1 + epsilon
        eta2_safe = eta2 + epsilon

        # Long-time flow: J(t) ~ t/η_eff with η_eff = η_1 η_2 / (η_1 + η_2)
        eta_eff = (eta1_safe * eta2_safe) / (eta1_safe + eta2_safe)

        # Short-time power law: t^α Γ(1+α) / (η_1 τ_1^α)
        J_short = (
            jnp.power(t, alpha_safe)
            * jax_gamma(1.0 + alpha_safe)
            / (eta1_safe * tau1_safe**alpha_safe)
        )

        # Long time: Newtonian flow
        J_long = t / eta_eff

        # weight goes 0 -> 1 as t passes tau1, moving from J_short to J_long
        weight = 1.0 - jnp.exp(-t / tau1_safe)
        return J_short * (1.0 - weight) + J_long * weight

    @staticmethod
    @jax.jit
    def _predict_oscillation(
        omega: jnp.ndarray,
        eta1: float,
        eta2: float,
        alpha: float,
        tau1: float,
    ) -> jnp.ndarray:
        """Predict complex modulus G*(ω).

        G*(ω) = η_1(iω) * [1 + (iωτ_2)^α] / [1 + (iωτ_1)^α]

        where τ_2 = η_2/η_1 * τ_1

        Parameters
        ----------
        omega : jnp.ndarray
            Angular frequency array (rad/s)
        eta1 : float
            First viscosity (Pa·s)
        eta2 : float
            Second viscosity (Pa·s)
        alpha : float
            Fractional order
        tau1 : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Array with a trailing axis of length 2: index 0 is G' and
            index 1 is G''
        """
        # Small offset keeps divisions and exponents finite
        epsilon = 1e-12

        # Clip alpha to the open interval (0, 1); works with JAX tracers
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)

        tau1_safe = tau1 + epsilon
        eta1_safe = eta1 + epsilon
        eta2_safe = eta2 + epsilon

        # Second time constant
        tau2 = (eta2_safe / eta1_safe) * tau1_safe

        i_omega = 1j * omega

        # i^α = exp(iπα/2); the same unit phasor applies to both time constants
        phase = jnp.pi * alpha_safe / 2.0
        i_pow_alpha = jnp.cos(phase) + 1j * jnp.sin(phase)

        # (iωτ_1)^α and (iωτ_2)^α
        i_omega_tau1_alpha = jnp.power(omega * tau1_safe, alpha_safe) * i_pow_alpha
        i_omega_tau2_alpha = jnp.power(omega * tau2, alpha_safe) * i_pow_alpha

        # G*(ω) = η_1(iω) * [1 + (iωτ_2)^α] / [1 + (iωτ_1)^α]
        numerator = 1.0 + i_omega_tau2_alpha
        denominator = 1.0 + i_omega_tau1_alpha
        G_star = eta1_safe * i_omega * (numerator / denominator)

        # Split into storage (real) and loss (imaginary) moduli
        return jnp.stack([jnp.real(G_star), jnp.imag(G_star)], axis=-1)

    def _predict_rotation(
        self,
        gamma_dot: jnp.ndarray,
        eta1: float,
        eta2: float,
        alpha: float,
        tau1: float,
    ) -> jnp.ndarray:
        """Predict steady shear viscosity η(γ̇).

        Returns a rate-independent (Newtonian) viscosity equal to ``eta1``
        plus a 1e-12 offset; ``eta2``, ``alpha`` and ``tau1`` are not used.

        Parameters
        ----------
        gamma_dot : jnp.ndarray
            Shear rate array (1/s)
        eta1 : float
            First viscosity (Pa·s)
        eta2 : float
            Second viscosity (Pa·s)
        alpha : float
            Fractional order
        tau1 : float
            Relaxation time (s)

        Returns
        -------
        jnp.ndarray
            Viscosity η (Pa·s), same shape as ``gamma_dot``
        """
        # Constant viscosity broadcast to the shape of the shear-rate array
        eta1 = eta1 + 1e-12
        return jnp.full_like(gamma_dot, eta1)

    def _fit(self, X: jnp.ndarray, y: jnp.ndarray, **kwargs) -> FractionalJeffreysModel:
        """Fit model to data using NLSQ TRF optimization.

        Parameters
        ----------
        X : jnp.ndarray
            Independent variable (time, frequency or shear rate)
        y : jnp.ndarray
            Dependent variable (modulus, compliance or viscosity)
        **kwargs : dict
            Additional fitting options. Recognized keys: ``test_mode``
            (string or ``TestMode``; default ``"relaxation"``), ``use_jax``
            (default ``True``), ``method`` (default ``"auto"``) and
            ``max_iter`` (default ``1000``).

        Returns
        -------
        self
            Fitted model instance

        Raises
        ------
        RuntimeError
            If the optimizer reports failure.
        """
        from rheojax.core.test_modes import TestMode
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        requested_mode = kwargs.get("test_mode", "relaxation")

        # Convert string to TestMode enum
        if isinstance(requested_mode, str):
            mode_by_name = {
                "relaxation": TestMode.RELAXATION,
                "creep": TestMode.CREEP,
                "oscillation": TestMode.OSCILLATION,
                "rotation": TestMode.ROTATION,
                # The model is registered for flow curves and _predict accepts
                # this name, so fitting must not silently treat it as relaxation.
                "flow_curve": TestMode.FLOW_CURVE,
            }
            test_mode = mode_by_name.get(requested_mode)
            if test_mode is None:
                # Keep the historical relaxation fallback, but make it visible.
                logger.warning(
                    "Unrecognized test mode; falling back to relaxation",
                    test_mode=requested_mode,
                )
                test_mode = TestMode.RELAXATION
        else:
            test_mode = requested_mode

        # Store test mode so later _predict / model_function calls can use it
        self._test_mode = test_mode

        # Get test mode string for logging
        test_mode_log = (
            test_mode.value if hasattr(test_mode, "value") else str(test_mode)
        )

        x_arr = jnp.array(X)
        data_shape = (int(x_arr.shape[0]),) if hasattr(x_arr, "shape") else None

        # Read optimizer options once; they are both logged and forwarded
        use_jax = kwargs.get("use_jax", True)
        method = kwargs.get("method", "auto")
        max_iter = kwargs.get("max_iter", 1000)

        with log_fit(
            logger,
            model="FractionalJeffreys",
            data_shape=data_shape,
            test_mode=test_mode_log,
        ) as ctx:
            logger.debug(
                "Starting Fractional Jeffreys model fit",
                test_mode=test_mode_log,
                n_points=data_shape[0] if data_shape else None,
                initial_eta1=self.parameters.get_value("eta1"),
                initial_eta2=self.parameters.get_value("eta2"),
                initial_alpha=self.parameters.get_value("alpha"),
                initial_tau1=self.parameters.get_value("tau1"),
            )

            # Smart initialization for oscillation mode (Issue #9)
            if test_mode == TestMode.OSCILLATION:
                try:
                    import numpy as np

                    from rheojax.utils.initialization import (
                        initialize_fractional_jeffreys,
                    )

                    success = initialize_fractional_jeffreys(
                        np.array(X), np.array(y), self.parameters
                    )
                    if success:
                        logger.debug(
                            "Smart initialization applied from frequency-domain features",
                            eta1=self.parameters.get_value("eta1"),
                            eta2=self.parameters.get_value("eta2"),
                            alpha=self.parameters.get_value("alpha"),
                            tau1=self.parameters.get_value("tau1"),
                        )
                except Exception as e:
                    # Best effort: a failed initialization must not abort the
                    # fit, so log it and continue from the current values.
                    logger.debug(
                        "Smart initialization failed, using defaults",
                        error_type=type(e).__name__,
                        error_message=str(e),
                    )

            # Create stateless model function for optimization
            def model_fn(x, params):
                """Model function for optimization (stateless)."""
                eta1, eta2, alpha, tau1 = params[0], params[1], params[2], params[3]

                # Direct prediction based on test mode (stateless)
                if test_mode == TestMode.RELAXATION:
                    return self._predict_relaxation(x, eta1, eta2, alpha, tau1)
                elif test_mode == TestMode.CREEP:
                    return self._predict_creep(x, eta1, eta2, alpha, tau1)
                elif test_mode == TestMode.OSCILLATION:
                    return self._predict_oscillation(x, eta1, eta2, alpha, tau1)
                elif test_mode in (TestMode.ROTATION, TestMode.FLOW_CURVE):
                    return self._predict_rotation(x, eta1, eta2, alpha, tau1)
                else:
                    raise ValueError(f"Unsupported test mode: {test_mode}")

            # Create objective function
            logger.debug("Creating least squares objective", normalize=True)
            objective = create_least_squares_objective(
                model_fn, x_arr, jnp.array(y), normalize=True
            )

            # Optimize using NLSQ TRF
            logger.debug(
                "Starting NLSQ optimization",
                use_jax=use_jax,
                method=method,
                max_iter=max_iter,
            )

            try:
                result = nlsq_optimize(
                    objective,
                    self.parameters,
                    use_jax=use_jax,
                    method=method,
                    max_iter=max_iter,
                )
            except Exception as e:
                logger.error(
                    "NLSQ optimization raised exception",
                    error_type=type(e).__name__,
                    error_message=str(e),
                    exc_info=True,
                )
                raise

            # Validate optimization succeeded
            if not result.success:
                logger.error(
                    "Optimization failed",
                    message=result.message,
                    test_mode=test_mode_log,
                )
                raise RuntimeError(
                    f"Optimization failed: {result.message}. "
                    f"Try adjusting initial values, bounds, or max_iter."
                )

            self.fitted_ = True

            # Log fitted parameters
            fitted_eta1 = self.parameters.get_value("eta1")
            fitted_eta2 = self.parameters.get_value("eta2")
            fitted_alpha = self.parameters.get_value("alpha")
            fitted_tau1 = self.parameters.get_value("tau1")

            logger.debug(
                "Fractional Jeffreys fit completed successfully",
                fitted_eta1=fitted_eta1,
                fitted_eta2=fitted_eta2,
                fitted_alpha=fitted_alpha,
                fitted_tau1=fitted_tau1,
            )

            ctx["eta1"] = fitted_eta1
            ctx["eta2"] = fitted_eta2
            ctx["alpha"] = fitted_alpha
            ctx["tau1"] = fitted_tau1

        return self

    def _predict(self, X: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Predict response for given input.

        The test mode is taken from ``kwargs["test_mode"]`` when given,
        otherwise from ``self._test_mode`` when set. If neither is
        available, a range heuristic chooses between oscillation and
        relaxation.

        Parameters
        ----------
        X : jnp.ndarray
            Independent variable
        **kwargs
            Additional arguments; only ``test_mode`` (string or ``TestMode``)
            is read here

        Returns
        -------
        jnp.ndarray
            Predicted values. Oscillation returns stacked [G', G''] on the
            last axis.
        """
        from rheojax.core.test_modes import TestMode

        # Get parameters
        eta1 = self.parameters.get_value("eta1")
        eta2 = self.parameters.get_value("eta2")
        alpha = self.parameters.get_value("alpha")
        tau1 = self.parameters.get_value("tau1")

        # Every parameter must hold a value before predicting
        assert (
            eta1 is not None
            and eta2 is not None
            and alpha is not None
            and tau1 is not None
        )

        # An explicit keyword wins over the mode remembered from the last fit
        _kw_mode = kwargs.get("test_mode")
        test_mode = (
            _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
        )

        if test_mode in ("oscillation", TestMode.OSCILLATION):
            return self._predict_oscillation(X, eta1, eta2, alpha, tau1)
        elif test_mode in ("relaxation", TestMode.RELAXATION):
            return self._predict_relaxation(X, eta1, eta2, alpha, tau1)
        elif test_mode in ("creep", TestMode.CREEP):
            return self._predict_creep(X, eta1, eta2, alpha, tau1)
        elif test_mode in (
            "flow_curve",
            "rotation",
            TestMode.FLOW_CURVE,
            TestMode.ROTATION,
        ):
            return self._predict_rotation(X, eta1, eta2, alpha, tau1)

        # Auto-detect test mode (legacy fallback): positive inputs spanning
        # more than three decades are treated as a frequency sweep
        if jnp.all(X > 0) and len(X) > 1:
            log_range = jnp.log10(jnp.max(X)) - jnp.log10(jnp.min(X) + 1e-12)
            if log_range > 3:
                return self._predict_oscillation(X, eta1, eta2, alpha, tau1)

        # Default to relaxation
        return self._predict_relaxation(X, eta1, eta2, alpha, tau1)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        Computes predictions from an input array and a positional parameter
        array, without reading ``self.parameters``.

        Parameters
        ----------
        X : array-like
            Independent variable (time, frequency or shear rate)
        params : array-like
            Parameter values in the order [eta1, eta2, alpha, tau1]
        test_mode : str or TestMode, optional
            Test mode to evaluate. When ``None``, falls back to
            ``self._test_mode`` and then to relaxation.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        jnp.ndarray
            Model predictions. Oscillation returns a complex array
            G' + i G''; unrecognized modes fall back to relaxation.
        """
        from rheojax.core.test_modes import TestMode

        # Extract parameters from array (in order they were added to ParameterSet)
        eta1 = params[0]
        eta2 = params[1]
        alpha = params[2]
        tau1 = params[3]

        # Prefer the explicit argument; the stored mode is a
        # backward-compatibility fallback
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", TestMode.RELAXATION)

        # Reduce enum members to their value, then match against both the
        # string name and the enum member so dispatch works whether or not
        # TestMode compares equal to plain strings
        if hasattr(test_mode, "value"):
            test_mode = test_mode.value

        if test_mode in ("relaxation", TestMode.RELAXATION):
            return self._predict_relaxation(X, eta1, eta2, alpha, tau1)
        elif test_mode in ("creep", TestMode.CREEP):
            return self._predict_creep(X, eta1, eta2, alpha, tau1)
        elif test_mode in ("oscillation", TestMode.OSCILLATION):
            # Recombine the stacked [G', G''] output into one complex array
            stacked = self._predict_oscillation(X, eta1, eta2, alpha, tau1)
            return stacked[..., 0] + 1j * stacked[..., 1]
        elif test_mode in (
            TestMode.ROTATION,
            TestMode.FLOW_CURVE,
            "rotation",
            "flow_curve",
        ):
            return self._predict_rotation(X, eta1, eta2, alpha, tau1)
        else:
            # Default to relaxation mode for Jeffreys model
            return self._predict_relaxation(X, eta1, eta2, alpha, tau1)
# Convenience alias FJM = FractionalJeffreysModel __all__ = ["FractionalJeffreysModel", "FJM"]