Source code for rheojax.models.fractional.fractional_maxwell_liquid

"""Fractional Maxwell Liquid (FML) model.

This model consists of a spring in series with a SpringPot element. It captures
the behavior of materials with elastic response at short times and power-law
relaxation at long times, typical of polymer melts and concentrated solutions.

Mathematical Description:
    Relaxation Modulus: G(t) = G_m E_{α,1}(-(t/τ_α)^α) = G_m E_α(-(t/τ_α)^α)
    Complex Modulus: G*(ω) = G_m (iωτ_α)^α / (1 + (iωτ_α)^α)
    Creep Compliance: J(t) = (1/G_m) + (t^α)/(G_m τ_α^α) E_{α,1+α}(-(t/τ_α)^α)

Parameters:
    Gm (float): Maxwell modulus (Pa), bounds [1e-3, 1e9]
    alpha (float): Power-law exponent, bounds [0.0, 1.0]
    tau_alpha (float): Relaxation time (s^α), bounds [1e-6, 1e6]

Test Modes: Relaxation, Creep, Oscillation

References:
    - Friedrich, C. (1991). Relaxation and retardation functions of the Maxwell model
      with fractional derivatives. Rheologica Acta, 30(2), 151-158.
    - Schiessel, H., Metzler, R., Blumen, A., & Nonnenmacher, T. F. (1995). Generalized
      viscoelastic models: their fractional equations with solutions. Journal of Physics
      A: Mathematical and General, 28(23), 6567.
"""

from __future__ import annotations

from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger, log_fit
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS

jax, jnp = safe_import_jax()


JAX_ARRAY_TYPES = tuple(
    t for t in (getattr(jax, "Array", None), jax.core.Tracer) if t is not None
)


import numpy as np

from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.utils.mittag_leffler import mittag_leffler_e2

# Module logger
logger = get_logger(__name__)


[docs] @ModelRegistry.register( "fractional_maxwell_liquid", protocols=[ Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, Protocol.FLOW_CURVE, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class FractionalMaxwellLiquid(BaseModel): """Fractional Maxwell Liquid model: Spring in series with SpringPot. This model describes materials with elastic response at short times and power-law relaxation at long times, such as polymer melts. Attributes: parameters: ParameterSet with Gm, alpha, tau_alpha Examples: >>> from rheojax.models import FractionalMaxwellLiquid >>> from rheojax.core.data import RheoData >>> import numpy as np >>> >>> # Create model with parameters >>> model = FractionalMaxwellLiquid() >>> model.parameters.set_value('Gm', 1e6) >>> model.parameters.set_value('alpha', 0.7) >>> model.parameters.set_value('tau_alpha', 1.0) >>> >>> # Predict relaxation modulus >>> t = np.logspace(-3, 3, 50) >>> data = RheoData(x=t, y=np.zeros_like(t), domain='time') >>> data.metadata['test_mode'] = 'relaxation' >>> G_t = model.predict(data) """
[docs] def __init__(self): """Initialize Fractional Maxwell Liquid model.""" super().__init__() self.parameters = ParameterSet() self.parameters.add( name="Gm", value=1e6, bounds=(1e-3, 1e9), units="Pa", description="Maxwell modulus", ) self.parameters.add( name="alpha", value=0.5, bounds=FRACTIONAL_ORDER_BOUNDS, units="dimensionless", description="Power-law exponent", ) self.parameters.add( name="tau_alpha", value=1.0, bounds=(1e-6, 1e6), units="s^α", description="Relaxation time", ) self.fitted_ = False
[docs] def bayesian_prior_factory( self, param_name: str, lower: float | None, upper: float | None ): """Provide custom priors that stay near realistic data-informed scales.""" return None # Disable custom priors for stability
# stats = getattr(self, "_fml_bayes_stats", {}) # if lower is None or upper is None: # return None # # def _log_normal(target: float, scale: float = 0.6): # if not np.isfinite(target) or target <= 0: # return None # low = float(max(lower, 1e-12)) # high = float(max(upper, low * 1.01)) # log_low = np.log(low) # log_high = np.log(high) # loc = float(np.clip(np.log(target), log_low + 1e-6, log_high - 1e-6)) # base = dist.TruncatedNormal( # loc=loc, scale=scale, low=log_low, high=log_high # ) # return dist.TransformedDistribution(base, dist_transforms.ExpTransform()) # # if param_name == "Gm" and "gm_target" in stats: # return _log_normal(stats["gm_target"], scale=0.7) # if param_name == "tau_alpha" and "tau_target" in stats: # return _log_normal(stats["tau_target"], scale=0.7) # return None
[docs] def bayesian_parameter_bounds( self, bounds: dict[str, tuple[float | None, float | None]], X: np.ndarray, y: np.ndarray, test_mode, ) -> dict[str, tuple[float | None, float | None]]: """Tighten tau bounds based on data scale to avoid pathological samples.""" return bounds # Disable bounds tightening for stability
# stats: dict[str, float] = {} # # if "tau_alpha" in bounds and X.size > 0: # positive_times = np.asarray(X, dtype=float) # positive_times = positive_times[positive_times > 0] # if positive_times.size: # t_min = float(np.min(positive_times)) # t_max = float(np.max(positive_times)) # lower, upper = bounds["tau_alpha"] # new_lower = max(lower or 0.0, t_min * 0.2, 1e-5) # new_upper = min(upper or np.inf, t_max * 5.0) # if new_upper <= new_lower: # new_upper = new_lower * 10.0 # bounds["tau_alpha"] = (new_lower, new_upper) # # tau_geo = float(np.exp(0.5 * (np.log(new_lower) + np.log(new_upper)))) # stats["tau_target"] = tau_geo # # y_abs = np.asarray(y, dtype=float) # if y_abs.size: # gm_lower, gm_upper = bounds.get("Gm", (None, None)) # gm_lower = gm_lower if gm_lower is not None else 1e-3 # gm_upper = gm_upper if gm_upper is not None else y_abs.max() * 10.0 # median_scale = float(np.median(np.abs(y_abs))) # gm_target = float(np.clip(median_scale, gm_lower * 1.1, gm_upper * 0.9)) # stats["gm_target"] = gm_target # # if stats: # self._fml_bayes_stats = stats # # return bounds
[docs] def bayesian_nuts_kwargs(self) -> dict: """Prefer conservative NUTS settings for the stiff Mittag-Leffler kernel.""" return {"target_accept_prob": 0.999, "max_tree_depth": 12}
@staticmethod @jax.jit def _predict_relaxation_jax( t: jnp.ndarray, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict relaxation modulus G(t) using JAX. G(t) = G_m E_{α,1}(-(t/τ)^α) Args: t: Time array Gm: Maxwell modulus alpha: Power-law exponent tau_alpha: Relaxation time Returns: Relaxation modulus array """ # Add small epsilon to prevent issues epsilon = 1e-12 # Clip alpha but allow traced values when running inside JAX alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) # Compute relaxation modulus t_safe = jnp.maximum(t, epsilon) tau_alpha_safe = jnp.maximum(tau_alpha, epsilon) # Compute argument for Mittag-Leffler function # z = - (t/τ)^α z = -jnp.power(t_safe / tau_alpha_safe, alpha_safe) # Compute E_{α,1}(z) ml_value = mittag_leffler_e2(z, alpha=alpha_safe, beta=1.0) # Compute G(t) # G(t) = Gm * E(...) G_t = Gm * ml_value return G_t @staticmethod @jax.jit def _predict_creep_jax( t: jnp.ndarray, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict creep compliance J(t) using JAX. J(t) = (1/G_m) + (t^α)/(G_m τ^α Γ(1+α)) Args: t: Time array Gm: Maxwell modulus alpha: Power-law exponent tau_alpha: Relaxation time Returns: Creep compliance array """ # Add small epsilon epsilon = 1e-12 alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) # Compute creep compliance t_safe = jnp.maximum(t, epsilon) tau_alpha_safe = jnp.maximum(tau_alpha, epsilon) # Instantaneous compliance (elastic part) J_instant = 1.0 / Gm # Viscous/Fractional part # J_frac = t^α / (G_m * τ^α * Γ(1+α)) num = jnp.power(t_safe, alpha_safe) denom = ( Gm * jnp.power(tau_alpha_safe, alpha_safe) * jax.scipy.special.gamma(1.0 + alpha_safe) ) J_t = J_instant + num / denom return J_t @staticmethod @jax.jit def _predict_oscillation_jax( omega: jnp.ndarray, Gm: float, alpha: float, tau_alpha: float ) -> jnp.ndarray: """Predict complex modulus G*(ω) using JAX. G*(ω) = G_m (iωτ_α)^α / (1 + (iωτ_α)^α) Args: omega: Angular frequency array Gm: Maxwell modulus alpha: Power-law exponent tau_alpha: Relaxation time Returns: Complex modulus array [G', G''] """ # Add small epsilon epsilon = 1e-12 alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) # Compute oscillation response omega_safe = jnp.maximum(omega, epsilon) tau_alpha_safe = jnp.maximum(tau_alpha, epsilon) # (iωτ_α)^α = |ωτ_α|^α * exp(i α π/2) omega_tau = omega_safe * tau_alpha_safe omega_tau_alpha = jnp.power(omega_tau, alpha_safe) phase_alpha = jnp.pi * alpha_safe / 2.0 cos_phase = jnp.cos(phase_alpha) sin_phase = jnp.sin(phase_alpha) i_omega_tau_alpha = omega_tau_alpha * (cos_phase + 1j * sin_phase) # Complex modulus # G* = Gm * X / (1 + X) where X = (iωτ)^α G_star = Gm * i_omega_tau_alpha / (1.0 + i_omega_tau_alpha) return jnp.stack([jnp.real(G_star), jnp.imag(G_star)], axis=-1) def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> FractionalMaxwellLiquid: """Fit model parameters to data. Args: X: Independent variable (time or frequency) y: Dependent variable (modulus or compliance) **kwargs: Additional fitting options Returns: self for method chaining """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Handle RheoData input if isinstance(X, RheoData): rheo_data = X x_data = jnp.array(rheo_data.x) y_data = jnp.array(rheo_data.y) test_mode = rheo_data.test_mode else: x_data = jnp.array(X) y_data = jnp.array(y) test_mode = kwargs.get("test_mode", "relaxation") # R13-FML-001: Store test_mode so Bayesian inference picks up the correct # protocol (otherwise _resolve_test_mode defaults to 'relaxation'). self._test_mode = test_mode # Determine data shape for logging data_shape = (len(X),) if hasattr(X, "__len__") else None with log_fit( logger, model="FractionalMaxwellLiquid", data_shape=data_shape, test_mode=test_mode if isinstance(test_mode, str) else str(test_mode), ) as ctx: logger.debug( "Starting FML fit", n_points=len(X) if hasattr(X, "__len__") else 1, test_mode=str(test_mode), initial_params=self.parameters.to_dict(), ) # Smart initialization for oscillation mode (Issue #9) if test_mode == "oscillation": try: from rheojax.utils.initialization import ( initialize_fractional_maxwell_liquid, ) success = initialize_fractional_maxwell_liquid( np.array(X), np.array(y), self.parameters ) if success: logger.debug( "Smart initialization applied from frequency-domain features", initialized_params=self.parameters.to_dict(), ) except Exception as e: logger.debug( "Smart initialization failed, using defaults", error=str(e), ) # Create objective function with stateless predictions def model_fn(x, params): """Model function for optimization (stateless).""" Gm, alpha, tau_alpha = params[0], params[1], params[2] # Direct prediction based on test mode (stateless, calls _jax methods) if test_mode == "relaxation": return self._predict_relaxation_jax(x, Gm, alpha, tau_alpha) elif test_mode == "creep": return self._predict_creep_jax(x, Gm, alpha, tau_alpha) elif test_mode == "oscillation": return self._predict_oscillation_jax(x, Gm, alpha, tau_alpha) else: raise ValueError(f"Unsupported test mode: {test_mode}") # Extract optimization strategy from kwargs (set by BaseModel.fit) use_log_residuals = kwargs.get("use_log_residuals", False) use_multi_start = kwargs.get("use_multi_start", False) n_starts = kwargs.get("n_starts", 5) perturb_factor = kwargs.get("perturb_factor", 0.3) logger.debug( "Creating least squares objective", normalize=True, use_log_residuals=use_log_residuals, ) objective = create_least_squares_objective( model_fn, x_data, y_data, normalize=True, use_log_residuals=use_log_residuals, ) # Choose optimization strategy try: if use_multi_start: from rheojax.utils.optimization import nlsq_multistart_optimize logger.debug( "Starting multi-start NLSQ optimization", n_starts=n_starts, perturb_factor=perturb_factor, method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) result = nlsq_multistart_optimize( objective, self.parameters, n_starts=n_starts, perturb_factor=perturb_factor, use_jax=kwargs.get("use_jax", True), method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), verbose=kwargs.get("verbose", False), ) else: logger.debug( "Starting NLSQ optimization", method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) except Exception as e: logger.error( "NLSQ optimization raised exception", error_type=type(e).__name__, error=str(e), exc_info=True, ) raise # Validate optimization succeeded if not result.success: logger.error( "Optimization failed", message=result.message, final_params=self.parameters.to_dict(), ) raise RuntimeError( f"Optimization failed: {result.message}. " f"Try adjusting initial values, bounds, or max_iter." ) self.fitted_ = True ctx["final_params"] = self.parameters.to_dict() ctx["success"] = True logger.debug( "FML fit completed successfully", final_params=self.parameters.to_dict(), ) return self def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """Internal predict implementation. Args: X: RheoData object or array of x-values **kwargs: Additional arguments (test_mode handled via self._test_mode) Returns: Predicted values """ # Handle RheoData input if isinstance(X, RheoData): return self.predict_rheodata(X) # Handle raw array input from rheojax.core.test_modes import TestMode x = jnp.asarray(X) Gm = self.parameters.get_value("Gm") alpha = self.parameters.get_value("alpha") tau_alpha = self.parameters.get_value("tau_alpha") _kw_mode = kwargs.get("test_mode") test_mode = ( _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None) ) if test_mode in ("oscillation", TestMode.OSCILLATION): result = self._predict_oscillation_jax(x, Gm, alpha, tau_alpha) elif test_mode in ("creep", TestMode.CREEP): result = self._predict_creep_jax(x, Gm, alpha, tau_alpha) elif test_mode in ( "flow_curve", "rotation", TestMode.FLOW_CURVE, TestMode.ROTATION, ): # Flow curve: use relaxation-based prediction result = self._predict_relaxation_jax(x, Gm, alpha, tau_alpha) else: result = self._predict_relaxation_jax(x, Gm, alpha, tau_alpha) return np.array(result)
[docs] def model_function(self, X, params, test_mode=None, **kwargs): """Model function for Bayesian inference. This method is required by BayesianMixin for NumPyro NUTS sampling. It computes predictions given input X and a parameter array. Args: X: Independent variable (time or frequency) params: Array of parameter values [Gm, alpha, tau_alpha] Returns: Model predictions as JAX array """ # Use explicit test_mode parameter (closure-captured in fit_bayesian) # Fall back to self._test_mode only for backward compatibility if test_mode is None: test_mode = getattr(self, "_test_mode", "relaxation") # Normalize test_mode to string if hasattr(test_mode, "value"): test_mode = test_mode.value # Extract parameter names from function signature params_dict = {name: params[i] for i, name in enumerate(self.parameters.keys())} # Dispatch to appropriate prediction method if test_mode == "relaxation": return self._predict_relaxation_jax(X, **params_dict) elif test_mode == "creep": return self._predict_creep_jax(X, **params_dict) elif test_mode == "oscillation": # Return complex array for oscillation mode complex_result = self._predict_oscillation_jax(X, **params_dict) return complex_result[..., 0] + 1j * complex_result[..., 1] elif test_mode in ("rotation", "flow_curve"): # FML: no yield stress, so flow curve falls back to relaxation-based prediction return self._predict_relaxation_jax(X, **params_dict) else: # Default to relaxation for unknown modes return self._predict_relaxation_jax(X, **params_dict)
[docs] def predict_rheodata( self, rheo_data: RheoData, test_mode: str | None = None ) -> RheoData: """Predict response for RheoData. Args: rheo_data: Input RheoData with x values test_mode: Test mode ('relaxation', 'creep', 'oscillation') If None, auto-detect from rheo_data Returns: RheoData with predicted y values """ # Auto-detect test mode if not provided if not isinstance(test_mode, str) or not test_mode: test_mode = rheo_data.test_mode # Get parameters Gm = self.parameters.get_value("Gm") alpha = self.parameters.get_value("alpha") tau_alpha = self.parameters.get_value("tau_alpha") # Convert input to JAX x = jnp.asarray(rheo_data.x) # Route to appropriate prediction method if test_mode == "relaxation": y_pred = self._predict_relaxation_jax(x, Gm, alpha, tau_alpha) elif test_mode == "creep": y_pred = self._predict_creep_jax(x, Gm, alpha, tau_alpha) elif test_mode == "oscillation": # Return complex array for RheoData [G' + iG''] y_pred_stacked = self._predict_oscillation_jax(x, Gm, alpha, tau_alpha) y_pred = y_pred_stacked[..., 0] + 1j * y_pred_stacked[..., 1] else: raise ValueError( f"Unknown test mode: {test_mode}. " f"Must be 'relaxation', 'creep', or 'oscillation'" ) # Create output RheoData result = RheoData( x=np.array(x), y=np.array(y_pred), x_units=rheo_data.x_units, y_units=rheo_data.y_units, domain=rheo_data.domain, metadata=rheo_data.metadata.copy(), ) return result
[docs] def predict(self, X, test_mode: str | None = None, **kwargs): # type: ignore[override] """Predict response. R13-FML-002: Delegates to BaseModel.predict() for plain-array inputs so that deformation_mode (E*->G* conversion) and test_mode restoration are handled correctly. Only RheoData inputs bypass super() because they return a RheoData wrapper object. Args: X: RheoData object or array of x-values test_mode: Test mode for prediction ('relaxation', 'creep', 'oscillation', 'flow_curve'). If None, defaults to 'relaxation'. **kwargs: Additional arguments passed to BaseModel.predict() Returns: Predicted values (RheoData if input is RheoData, else array) """ if isinstance(X, RheoData): return self.predict_rheodata(X, test_mode=test_mode) else: return super().predict(X, test_mode=test_mode, **kwargs)