Source code for rheojax.models.fractional.fractional_maxwell_gel

"""Fractional Maxwell Gel (FMG) model.

This model consists of a SpringPot element (power-law viscoelastic element) in series
with a dashpot (Newtonian viscous element). It captures the transition from power-law
viscoelastic behavior to terminal flow.

Mathematical Description:
    Relaxation Modulus: G(t) = c_α t^(-α) E_{1-α,1-α}(-(t/τ)^(1-α))
    Complex Modulus: G*(ω) = c_α (iω)^α · (iωτ) / (1 + iωτ)
    Creep Compliance: J(t) = (1/c_α) t^α E_{1+α,1+α}(-(t/τ)^(1-α))

where τ = (η / c_α)^(1/(1-α)) is a characteristic relaxation time, so that τ^(1-α) = η / c_α.

Parameters:
    c_alpha (float): Material constant (Pa·s^α), bounds [1e-3, 1e9]
    alpha (float): Power-law exponent, bounds [0.0, 1.0]
    eta (float): Viscosity (Pa·s), bounds [1e-6, 1e12]

Test Modes: Relaxation, Creep, Oscillation

References:
    - Blair, G. S., Veinoglou, B. C., & Caffyn, J. E. (1947). Limitations of the Newtonian
      time scale in relation to non-equilibrium rheological states and a theory of
      quasi-properties. Proc. R. Soc. Lond. A, 189(1016), 69-87.
    - Friedrich, C., & Braun, H. (1992). Generalized Cole-Cole behavior and its rheological
      relevance. Rheologica Acta, 31(4), 309-322.
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS

jax, jnp = safe_import_jax()


import numpy as np

from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.mittag_leffler import mittag_leffler_e2

# Module logger
logger = get_logger(__name__)


@ModelRegistry.register(
    "fractional_maxwell_gel",
    protocols=[
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.OSCILLATION,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FractionalMaxwellGel(BaseModel):
    """Fractional Maxwell Gel model: SpringPot in series with dashpot.

    This model describes the rheology of materials transitioning from power-law
    viscoelastic behavior to terminal flow, such as polymer solutions and gels.

    Attributes:
        parameters: ParameterSet with c_alpha, alpha, eta

    Examples:
        >>> from rheojax.models import FractionalMaxwellGel
        >>> from rheojax.core.data import RheoData
        >>> import numpy as np
        >>>
        >>> # Create model with parameters
        >>> model = FractionalMaxwellGel()
        >>> model.parameters.set_value('c_alpha', 1e5)
        >>> model.parameters.set_value('alpha', 0.5)
        >>> model.parameters.set_value('eta', 1e3)
        >>>
        >>> # Predict relaxation modulus
        >>> t = np.logspace(-3, 3, 50)
        >>> data = RheoData(x=t, y=np.zeros_like(t), domain='time')
        >>> data.metadata['test_mode'] = 'relaxation'
        >>> G_t = model.predict(data)
        >>>
        >>> # Predict complex modulus
        >>> omega = np.logspace(-2, 2, 50)
        >>> data = RheoData(x=omega, y=np.zeros_like(omega), domain='frequency')
        >>> G_star = model.predict(data)
    """

    def __init__(self):
        """Initialize the model with default parameters c_alpha, alpha and eta.

        The registration order (c_alpha, alpha, eta) matters: ``_fit`` and
        ``model_function`` map positional parameter arrays onto these names.
        """
        super().__init__()
        self.parameters = ParameterSet()
        self.parameters.add(
            name="c_alpha",
            value=10.0,
            bounds=(1e-3, 1e9),
            units="Pa·s^α",
            description="SpringPot material constant",
        )
        self.parameters.add(
            name="alpha",
            value=0.5,
            bounds=FRACTIONAL_ORDER_BOUNDS,
            units="dimensionless",
            description="Power-law exponent",
        )
        # NOTE(review): with the defaults (c_alpha=10, alpha=0.5, eta=1e4)
        # tau = (eta / c_alpha)^(1/(1-alpha)) evaluates to 1e6. An earlier
        # comment claimed tau ~ O(1) "with c_alpha=100" -- confirm the
        # intended default time scale.
        self.parameters.add(
            name="eta",
            value=1e4,
            bounds=(1e-6, 1e12),
            units="Pa·s",
            description="Dashpot viscosity",
        )
        self.fitted_ = False

    def bayesian_nuts_kwargs(self) -> dict:
        """Prefer conservative NUTS settings for the stiff Mittag-Leffler kernel."""
        return {"target_accept_prob": 0.99}

    def _compute_tau(self, c_alpha: float, alpha: float) -> float:
        """Compute the characteristic relaxation time.

        ``eta`` is not an argument; it is read from ``self.parameters``.

        Args:
            c_alpha: SpringPot constant
            alpha: Power-law exponent

        Returns:
            Characteristic time τ = (η / c_α)^(1/(1-α)), or ``inf`` when alpha
            is within 1e-6 of 1 or the power would overflow.
        """
        eta = self.parameters.get_value("eta")
        # Small epsilon keeps the exponent finite as alpha approaches 1.
        epsilon = 1e-12

        try:
            # alpha -> 1 makes the exponent 1/(1-alpha) blow up.
            if alpha > 1.0 - 1e-6:
                return float("inf")

            # tau^(1-alpha) = eta / c_alpha
            # tau = (eta / c_alpha)^(1/(1-alpha))
            assert eta is not None and c_alpha is not None
            exponent = 1.0 / (1.0 - alpha + epsilon)
            base = eta / c_alpha

            # Cheap pre-check before exponentiating (exp overflows near 709).
            if base > 1.0 and exponent > 700:
                return float("inf")

            return base**exponent
        except (OverflowError, ZeroDivisionError):
            return float("inf")

    @staticmethod
    @jax.jit
    def _predict_relaxation_jax(
        t: jnp.ndarray, c_alpha: float, alpha: float, eta: float
    ) -> jnp.ndarray:
        """Predict relaxation modulus G(t) using JAX.

        G(t) = c_α t^(-α) E_{1-α,1-α}(-(t/τ)^(1-α))

        Args:
            t: Time values; values below 1e-12 are clamped to 1e-12.
            c_alpha: SpringPot constant.
            alpha: Power-law exponent; clipped to [1e-12, 1 - 1e-12].
            eta: Dashpot viscosity.

        Returns:
            Relaxation modulus evaluated at each entry of ``t``.
        """
        # Epsilon guards t=0 and the alpha=0 / alpha=1 end points.
        epsilon = 1e-12

        # jnp.clip works on traced values, so this is safe under jit.
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        t_safe = jnp.maximum(t, epsilon)

        # Mittag-Leffler argument z = -(t/τ)^(1-α). Because
        # τ^(1-α) = eta / c_alpha, this simplifies to
        # z = -t^(1-α) * (c_alpha / eta), which avoids forming τ (and its
        # potentially overflowing 1/(1-α) power) explicitly.
        beta_safe = 1.0 - alpha_safe
        z = -jnp.power(t_safe, beta_safe) * (c_alpha / eta)

        # E_{1-α,1-α}(z)
        ml_value = mittag_leffler_e2(z, alpha=beta_safe, beta=beta_safe)

        # G(t) = c_α * t^(-α) * E(...)
        G_t = c_alpha * jnp.power(t_safe, -alpha_safe) * ml_value

        return G_t

    @staticmethod
    @jax.jit
    def _predict_creep_jax(
        t: jnp.ndarray, c_alpha: float, alpha: float, eta: float
    ) -> jnp.ndarray:
        """Predict creep compliance J(t) using JAX.

        J(t) = (1/c_α) t^α E_{1+α,1+α}(-(t/τ)^(1-α))

        Args:
            t: Time values; values below 1e-12 are clamped to 1e-12.
            c_alpha: SpringPot constant.
            alpha: Power-law exponent; clipped to [1e-12, 1 - 1e-12].
            eta: Dashpot viscosity.

        Returns:
            Creep compliance evaluated at each entry of ``t``.
        """
        epsilon = 1e-12

        # jnp.clip works on traced values, so this is safe under jit.
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)
        t_safe = jnp.maximum(t, epsilon)

        # z = -(t/τ)^(1-α) = -t^(1-α) * (c_alpha / eta)
        beta_exp = 1.0 - alpha_safe
        z = -jnp.power(t_safe, beta_exp) * (c_alpha / eta)

        # E_{1+α,1+α}(z)
        ml_alpha = 1.0 + alpha_safe
        ml_beta = 1.0 + alpha_safe
        ml_value = mittag_leffler_e2(z, alpha=ml_alpha, beta=ml_beta)

        # J(t) = (1/c_alpha) * t^alpha * E(...)
        J_t = (1.0 / c_alpha) * jnp.power(t_safe, alpha_safe) * ml_value

        # Monotonicity enforced by physical parameter bounds, not in NUTS path
        return J_t

    @staticmethod
    @jax.jit
    def _predict_oscillation_jax(
        omega: jnp.ndarray, c_alpha: float, alpha: float, eta: float
    ) -> jnp.ndarray:
        """Predict complex modulus G*(ω) using JAX.

        G*(ω) = c_α (iω)^α / (1 + (iωτ)^(1-α))

        NOTE(review): as alpha -> 0 this expression tends to
        c_α / (1 + iωτ) rather than the classical Maxwell form
        c_α·iωτ / (1 + iωτ), and it differs from the complex-modulus formula
        quoted in the module docstring. Confirm which expression is intended
        before relying on cross-protocol consistency.

        Args:
            omega: Angular frequencies; values below 1e-12 are clamped.
            c_alpha: SpringPot constant.
            alpha: Power-law exponent; clipped to [1e-12, 1 - 1e-12].
            eta: Dashpot viscosity.

        Returns:
            Real array with a trailing axis of size 2 holding
            ``[G', G'']`` (storage and loss modulus).
        """
        epsilon = 1e-12

        # jnp.clip works on traced values, so this is safe under jit.
        alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon)

        # Exponent used in the denominator term.
        beta_safe = 1.0 - alpha_safe

        omega_safe = jnp.maximum(omega, epsilon)

        # (iω)^α = |ω|^α * exp(i α π/2)
        omega_alpha = jnp.power(omega_safe, alpha_safe)
        phase_alpha = jnp.pi * alpha_safe / 2.0
        i_omega_alpha = omega_alpha * (jnp.cos(phase_alpha) + 1j * jnp.sin(phase_alpha))

        # (iωτ)^(1-α) = (iω)^(1-α) * τ^(1-α), and
        # τ^(1-α) = [(eta/c_alpha)^(1/(1-alpha))]^(1-alpha) = eta/c_alpha,
        # so the term is (iω)^(1-α) * (eta / c_alpha) without ever forming τ.
        omega_beta = jnp.power(omega_safe, beta_safe)
        phase_beta = jnp.pi * beta_safe / 2.0
        i_omega_beta = omega_beta * (jnp.cos(phase_beta) + 1j * jnp.sin(phase_beta))

        denominator_term = i_omega_beta * (eta / c_alpha)

        # G*(ω) = c_α (iω)^α / (1 + (iωτ)^(1-α))
        G_star = c_alpha * i_omega_alpha / (1.0 + denominator_term)

        # Split into storage and loss moduli.
        G_prime = jnp.real(G_star)
        G_double_prime = jnp.imag(G_star)

        return jnp.stack([G_prime, G_double_prime], axis=-1)

    def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> FractionalMaxwellGel:
        """Fit model parameters to data.

        Args:
            X: Independent variable (time or frequency), or a RheoData object
                carrying x, y and the test mode.
            y: Dependent variable (modulus or compliance). Ignored when ``X``
                is a RheoData object.
            **kwargs: Additional fitting options. Recognised keys:
                ``test_mode`` (array input only, default ``"relaxation"``),
                ``use_jax``, ``method`` and ``max_iter``.

        Returns:
            self for method chaining

        Raises:
            RuntimeError: If the optimizer reports failure.
            ValueError: If the test mode is not relaxation/creep/oscillation.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        # Unpack the raw x/y once so that everything below (optimisation,
        # logging, smart initialisation) works for both RheoData and plain
        # array input. Previously the logging/initialisation code used the
        # raw ``X``/``y`` arguments, which breaks when ``X`` is a RheoData
        # object and ``y`` is not supplied separately.
        if isinstance(X, RheoData):
            x_raw, y_raw = X.x, X.y
            test_mode = X.test_mode
        else:
            x_raw, y_raw = X, y
            test_mode = kwargs.get("test_mode", "relaxation")

        x_data = jnp.array(x_raw)
        y_data = jnp.array(y_raw)
        # NumPy views used only for logging statistics.
        x_np = np.asarray(x_raw)
        y_np = np.asarray(y_raw)

        # R13-FMG-001: Store test_mode so Bayesian inference picks up the correct
        # protocol (otherwise _resolve_test_mode defaults to 'relaxation').
        self._test_mode = test_mode

        with log_fit(
            logger, model="FractionalMaxwellGel", data_shape=x_np.shape
        ) as ctx:
            try:
                logger.info(
                    "Starting Fractional Maxwell Gel model fit",
                    test_mode=test_mode,
                    n_points=len(x_np),
                )

                logger.debug(
                    "Input data statistics",
                    x_range=(
                        float(np.min(np.abs(x_np))),
                        float(np.max(np.abs(x_np))),
                    ),
                    y_range=(
                        float(np.min(np.abs(y_np))),
                        float(np.max(np.abs(y_np))),
                    ),
                )

                ctx["test_mode"] = test_mode

                # Smart initialization for oscillation mode (Issue #9)
                if test_mode == "oscillation":
                    try:
                        from rheojax.utils.initialization import (
                            initialize_fractional_maxwell_gel,
                        )

                        success = initialize_fractional_maxwell_gel(
                            np.array(x_raw), np.array(y_raw), self.parameters
                        )
                        if success:
                            logger.debug(
                                "Smart initialization applied from frequency-domain features",
                                c_alpha=self.parameters.get_value("c_alpha"),
                                alpha=self.parameters.get_value("alpha"),
                                eta=self.parameters.get_value("eta"),
                            )
                    except Exception as e:
                        # Deliberate best-effort: fall back to the current
                        # parameter values if initialization fails.
                        logger.debug(
                            "Smart initialization failed, using defaults",
                            error=str(e),
                        )

                # Objective uses the stateless jitted kernels directly so the
                # optimizer never depends on mutable model state.
                def model_fn(x, params):
                    """Model function for optimization (stateless)."""
                    c_alpha, alpha, eta = params[0], params[1], params[2]

                    if test_mode == "relaxation":
                        return self._predict_relaxation_jax(x, c_alpha, alpha, eta)
                    elif test_mode == "creep":
                        return self._predict_creep_jax(x, c_alpha, alpha, eta)
                    elif test_mode == "oscillation":
                        return self._predict_oscillation_jax(x, c_alpha, alpha, eta)
                    else:
                        raise ValueError(f"Unsupported test mode: {test_mode}")

                objective = create_least_squares_objective(
                    model_fn, x_data, y_data, normalize=True
                )

                logger.debug(
                    "Running NLSQ optimization",
                    use_jax=kwargs.get("use_jax", True),
                    method=kwargs.get("method", "auto"),
                    max_iter=kwargs.get("max_iter", 1000),
                )

                # Optimize using NLSQ (JAX enabled by default)
                result = nlsq_optimize(
                    objective,
                    self.parameters,
                    use_jax=kwargs.get("use_jax", True),
                    method=kwargs.get("method", "auto"),
                    max_iter=kwargs.get("max_iter", 1000),
                )

                # Validate optimization succeeded
                if not result.success:
                    logger.error(
                        "Optimization failed",
                        message=result.message,
                        n_iterations=getattr(result, "nfev", None),
                    )
                    raise RuntimeError(
                        f"Optimization failed: {result.message}. "
                        f"Try adjusting initial values, bounds, or max_iter."
                    )

                # Log final parameters
                c_alpha_val = self.parameters.get_value("c_alpha")
                alpha_val = self.parameters.get_value("alpha")
                eta_val = self.parameters.get_value("eta")
                assert c_alpha_val is not None and alpha_val is not None
                tau_val = self._compute_tau(c_alpha_val, alpha_val)

                ctx["c_alpha"] = c_alpha_val
                ctx["alpha"] = alpha_val
                ctx["eta"] = eta_val
                ctx["tau"] = tau_val
                ctx["cost"] = float(result.fun) if hasattr(result, "fun") else None

                logger.info(
                    "Fractional Maxwell Gel model fit completed",
                    c_alpha=c_alpha_val,
                    alpha=alpha_val,
                    eta=eta_val,
                    tau=tau_val,
                    cost=ctx["cost"],
                )

                self.fitted_ = True
                return self

            except Exception as e:
                # Log with traceback, then let the caller see the original error.
                logger.error(
                    "Fractional Maxwell Gel model fit failed",
                    test_mode=test_mode,
                    error=str(e),
                    exc_info=True,
                )
                raise

    def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
        """Internal predict implementation.

        Args:
            X: RheoData object or array of x-values
            **kwargs: Additional arguments. ``test_mode`` takes precedence over
                ``self._test_mode``; any mode other than oscillation or creep
                (including ``None``) is treated as relaxation.

        Returns:
            Predicted values. Oscillation predictions keep the stacked
            ``[G', G'']`` layout on the last axis. RheoData input is forwarded
            to ``predict_rheodata`` and returns a RheoData object.
        """
        # Handle RheoData input
        if isinstance(X, RheoData):
            return self.predict_rheodata(X)

        # Handle raw array input
        from rheojax.core.test_modes import TestMode

        x = jnp.asarray(X)
        c_alpha = self.parameters.get_value("c_alpha")
        alpha = self.parameters.get_value("alpha")
        eta = self.parameters.get_value("eta")

        _kw_mode = kwargs.get("test_mode")
        test_mode = (
            _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
        )
        if test_mode in ("oscillation", TestMode.OSCILLATION):
            result = self._predict_oscillation_jax(x, c_alpha, alpha, eta)
        elif test_mode in ("creep", TestMode.CREEP):
            result = self._predict_creep_jax(x, c_alpha, alpha, eta)
        else:
            result = self._predict_relaxation_jax(x, c_alpha, alpha, eta)

        return np.array(result)

    def model_function(self, X, params, test_mode=None, **kwargs):
        """Model function for Bayesian inference.

        This method is required by BayesianMixin for NumPyro NUTS sampling.
        It computes predictions given input X and a parameter array.

        Args:
            X: Independent variable (time or frequency)
            params: Array of parameter values, ordered like
                ``self.parameters`` ([c_alpha, alpha, eta])
            test_mode: Protocol to evaluate, as a string or an object with a
                ``.value`` attribute. If None, falls back to
                ``self._test_mode`` and then to ``"relaxation"``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Model predictions as JAX array. Oscillation predictions are
            returned as complex ``G' + i G''``; unknown modes fall back to the
            relaxation kernel.
        """
        # Use explicit test_mode parameter (closure-captured in fit_bayesian)
        # Fall back to self._test_mode only for backward compatibility
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", "relaxation")

        # Normalize enum-like test modes to their string value.
        if hasattr(test_mode, "value"):
            test_mode = test_mode.value

        # Map positional parameter values onto the names registered in
        # self.parameters, which match the kernels' keyword arguments.
        params_dict = {name: params[i] for i, name in enumerate(self.parameters.keys())}

        # Dispatch to appropriate prediction method
        if test_mode == "relaxation":
            return self._predict_relaxation_jax(X, **params_dict)
        elif test_mode == "creep":
            return self._predict_creep_jax(X, **params_dict)
        elif test_mode == "oscillation":
            stacked = self._predict_oscillation_jax(X, **params_dict)
            return stacked[..., 0] + 1j * stacked[..., 1]
        else:
            # Default to relaxation for unknown modes
            return self._predict_relaxation_jax(X, **params_dict)

    def predict_rheodata(
        self, rheo_data: RheoData, test_mode: str | None = None
    ) -> RheoData:
        """Predict response for RheoData.

        Args:
            rheo_data: Input RheoData with x values
            test_mode: Test mode ('relaxation', 'creep', 'oscillation').
                If None, empty, or not a string, it is taken from
                ``rheo_data.test_mode``.

        Returns:
            RheoData with predicted y values (complex ``G' + i G''`` for
            oscillation); units, domain and a copy of the metadata are carried
            over from ``rheo_data``.

        Raises:
            ValueError: If the resolved test mode is not one of the three
                supported strings.
        """
        # Auto-detect test mode if not provided
        if not isinstance(test_mode, str) or not test_mode:
            test_mode = rheo_data.test_mode

        # Get parameters
        c_alpha = self.parameters.get_value("c_alpha")
        alpha = self.parameters.get_value("alpha")
        eta = self.parameters.get_value("eta")

        # Convert input to JAX
        x = jnp.asarray(rheo_data.x)

        # Route to appropriate prediction method
        if test_mode == "relaxation":
            y_pred = self._predict_relaxation_jax(x, c_alpha, alpha, eta)
        elif test_mode == "creep":
            y_pred = self._predict_creep_jax(x, c_alpha, alpha, eta)
        elif test_mode == "oscillation":
            y_pred_stacked = self._predict_oscillation_jax(x, c_alpha, alpha, eta)
            y_pred = y_pred_stacked[..., 0] + 1j * y_pred_stacked[..., 1]
        else:
            raise ValueError(
                f"Unknown test mode: {test_mode}. "
                f"Must be 'relaxation', 'creep', or 'oscillation'"
            )

        # Create output RheoData
        result = RheoData(
            x=np.array(x),
            y=np.array(y_pred),
            x_units=rheo_data.x_units,
            y_units=rheo_data.y_units,
            domain=rheo_data.domain,
            metadata=rheo_data.metadata.copy(),
        )

        return result

    def predict(self, X, test_mode: str | None = None, **kwargs):  # type: ignore[override]
        """Predict response.

        R13-FMG-002: Delegates to BaseModel.predict() for plain-array inputs so
        that deformation_mode (E*->G* conversion) and test_mode restoration are
        handled correctly. Only RheoData inputs bypass super() because they
        return a RheoData wrapper object.

        Args:
            X: RheoData object or array of x-values
            test_mode: Test mode for prediction ('relaxation', 'creep',
                'oscillation'). For raw arrays it is forwarded to
                BaseModel.predict(); how ``None`` is resolved there is not
                visible from this class (``_predict`` itself falls back to
                ``self._test_mode`` and then to relaxation).
            **kwargs: Additional arguments passed to BaseModel.predict()

        Returns:
            Predicted values (RheoData if input is RheoData, else array)
        """
        if isinstance(X, RheoData):
            return self.predict_rheodata(X, test_mode=test_mode)
        else:
            return super().predict(X, test_mode=test_mode, **kwargs)