Source code for rheojax.models.fractional.fractional_kelvin_voigt

"""Fractional Kelvin-Voigt Model (FKV).

This model consists of a spring and a SpringPot element in parallel. It describes
materials with solid-like behavior with power-law creep, typical of filled polymers
and soft solids.

Mathematical Description:
    Relaxation Modulus: G(t) = G_e + c_α t^(-α) / Γ(1-α)
    Complex Modulus: G*(ω) = G_e + c_α (iω)^α
    Creep Compliance: J(t) = (1/G_e) (1 - E_α(-(t/τ_ε)^α))

where τ_ε = (c_α/G_e)^(1/α) is a characteristic retardation time.

Parameters:
    Ge (float): Equilibrium modulus (Pa), bounds [1e-3, 1e9]
    c_alpha (float): SpringPot constant (Pa·s^α), bounds [1e-3, 1e9]
    alpha (float): Fractional order, bounds [0.0, 1.0]

Test Modes: Relaxation, Creep, Oscillation

References:
    - Bagley, R. L., & Torvik, P. J. (1983). A theoretical basis for the application
      of fractional calculus to viscoelasticity. Journal of Rheology, 27(3), 201-210.
    - Makris, N., & Constantinou, M. C. (1991). Fractional-derivative Maxwell model
      for viscous dampers. Journal of Structural Engineering, 117(9), 2708-2724.
"""

from __future__ import annotations

import numpy as np

from rheojax.core.base import BaseModel, ParameterSet
from rheojax.core.data import RheoData
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode
from rheojax.models.fractional.fractional_mixin import FRACTIONAL_ORDER_BOUNDS

jax, jnp = safe_import_jax()
jax_gamma = jax.scipy.special.gamma
from rheojax.logging import get_logger, log_fit
from rheojax.utils.mittag_leffler import mittag_leffler_e

# Module logger
logger = get_logger(__name__)


[docs] @ModelRegistry.register( "fractional_kelvin_voigt", protocols=[ Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class FractionalKelvinVoigt(BaseModel): """Fractional Kelvin-Voigt model: Spring and SpringPot in parallel. This model describes solid-like materials with power-law creep behavior, typical of filled polymers and soft solids. Attributes: parameters: ParameterSet with Ge, c_alpha, alpha Examples: >>> from rheojax.models import FractionalKelvinVoigt >>> from rheojax.core.data import RheoData >>> import numpy as np >>> >>> # Create model with parameters >>> model = FractionalKelvinVoigt() >>> model.parameters.set_value('Ge', 1e6) >>> model.parameters.set_value('c_alpha', 1e4) >>> model.parameters.set_value('alpha', 0.5) >>> >>> # Predict relaxation modulus >>> t = np.logspace(-3, 3, 50) >>> data = RheoData(x=t, y=np.zeros_like(t), domain='time') >>> data.metadata['test_mode'] = 'relaxation' >>> G_t = model.predict(data) >>> >>> # Predict complex modulus >>> omega = np.logspace(-2, 2, 50) >>> data = RheoData(x=omega, y=np.zeros_like(omega), domain='frequency') >>> G_star = model.predict(data) """
[docs] def __init__(self): """Initialize Fractional Kelvin-Voigt model.""" super().__init__() self.parameters = ParameterSet() self.parameters.add( name="Ge", value=1e6, bounds=(1e-3, 1e9), units="Pa", description="Equilibrium modulus", ) self.parameters.add( name="c_alpha", value=1e4, bounds=(1e-3, 1e9), units="Pa·s^α", description="SpringPot constant", ) self.parameters.add( name="alpha", value=0.5, bounds=FRACTIONAL_ORDER_BOUNDS, units="dimensionless", description="Fractional order", ) self.fitted_ = False
def _compute_tau_epsilon(self, Ge: float, c_alpha: float, alpha: float) -> float: """Compute characteristic retardation time. Args: Ge: Equilibrium modulus c_alpha: SpringPot constant alpha: Fractional order Returns: Characteristic time τ_ε = (c_α/G_e)^(1/α) """ epsilon = 1e-12 alpha_safe = max(alpha, epsilon) return (c_alpha / Ge) ** (1.0 / alpha_safe) @staticmethod @jax.jit def _predict_relaxation_jax( t: jnp.ndarray, Ge: float, c_alpha: float, alpha: float ) -> jnp.ndarray: """Predict relaxation modulus G(t) using JAX. G(t) = G_e + c_α t^(-α) / Γ(1-α) Args: t: Time array Ge: Equilibrium modulus c_alpha: SpringPot constant alpha: Fractional order Returns: Relaxation modulus array """ # Add small epsilon to prevent issues epsilon = 1e-12 # Clip alpha to safe range alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) t_safe = jnp.maximum(t, epsilon) # Elastic part G_elastic = Ge # Viscous part: c_α t^(-α) / Γ(1-α) gamma_term = jax_gamma(1.0 - alpha_safe) G_viscous = c_alpha * (t_safe ** (-alpha_safe)) / gamma_term # Total relaxation modulus G_t = G_elastic + G_viscous return G_t @staticmethod @jax.jit def _predict_creep_jax( t: jnp.ndarray, Ge: float, c_alpha: float, alpha: float ) -> jnp.ndarray: """Predict creep compliance J(t) using JAX. J(t) = (1/G_e) (1 - E_α(-(t/τ_ε)^α)) where τ_ε = (c_α/G_e)^(1/α) Args: t: Time array Ge: Equilibrium modulus c_alpha: SpringPot constant alpha: Fractional order Returns: Creep compliance array """ # Add small epsilon epsilon = 1e-12 # Clip alpha to safe range alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) t_safe = jnp.maximum(t, epsilon) # Characteristic retardation time tau_epsilon = (c_alpha / Ge) ** (1.0 / alpha_safe) # Argument for Mittag-Leffler function z = -((t_safe / tau_epsilon) ** alpha_safe) # Compute E_α(z) with concrete alpha ml_value = mittag_leffler_e(z, alpha=alpha_safe) # Creep compliance J_t = (1.0 / Ge) * (1.0 - ml_value) return J_t @staticmethod @jax.jit def _predict_oscillation_jax( omega: jnp.ndarray, Ge: float, c_alpha: float, alpha: float ) -> jnp.ndarray: """Predict complex modulus G*(ω) using JAX. G*(ω) = G_e + c_α (iω)^α Args: omega: Angular frequency array Ge: Equilibrium modulus c_alpha: SpringPot constant alpha: Fractional order Returns: Complex modulus array """ # Add small epsilon epsilon = 1e-12 # Clip alpha to safe range alpha_safe = jnp.clip(alpha, epsilon, 1.0 - epsilon) omega_safe = jnp.maximum(omega, epsilon) # Elastic part G_elastic = Ge # SpringPot part: c_α (iω)^α = c_α |ω|^α exp(i α π/2) G_springpot = ( c_alpha * (omega_safe**alpha_safe) * jnp.exp(1j * alpha_safe * jnp.pi / 2.0) ) # Complex modulus G_star = G_elastic + G_springpot # Extract storage and loss moduli G_prime = jnp.real(G_star) G_double_prime = jnp.imag(G_star) return jnp.stack([G_prime, G_double_prime], axis=-1) def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> FractionalKelvinVoigt: """Fit model parameters to data. Args: X: Independent variable (time or frequency) y: Dependent variable (modulus or compliance) **kwargs: Additional fitting options Returns: self for method chaining """ from rheojax.utils.optimization import ( create_least_squares_objective, nlsq_optimize, ) # Handle RheoData input if isinstance(X, RheoData): rheo_data = X x_data = jnp.array(rheo_data.x) y_data = jnp.array(rheo_data.y) test_mode = rheo_data.test_mode else: x_data = jnp.array(X) y_data = jnp.array(y) test_mode = kwargs.get("test_mode", "relaxation") # R13-FKV-001: Store test_mode so Bayesian inference picks up the correct # protocol (otherwise _resolve_test_mode defaults to 'relaxation'). self._test_mode = test_mode # Get test mode string for logging test_mode_str = ( test_mode.value if hasattr(test_mode, "value") else str(test_mode) ) data_shape = (int(x_data.shape[0]),) if hasattr(x_data, "shape") else None with log_fit( logger, model="FractionalKelvinVoigt", data_shape=data_shape, test_mode=test_mode_str, ) as ctx: logger.debug( "Starting Fractional Kelvin-Voigt model fit", test_mode=test_mode_str, n_points=data_shape[0] if data_shape else None, initial_Ge=self.parameters.get_value("Ge"), initial_c_alpha=self.parameters.get_value("c_alpha"), initial_alpha=self.parameters.get_value("alpha"), ) # Smart initialization for oscillation mode (Issue #9) if test_mode == "oscillation": try: from rheojax.utils.initialization import ( initialize_fractional_kelvin_voigt, ) success = initialize_fractional_kelvin_voigt( np.array(X), np.array(y), self.parameters ) if success: logger.debug( "Smart initialization applied from frequency-domain features", Ge=self.parameters.get_value("Ge"), c_alpha=self.parameters.get_value("c_alpha"), alpha=self.parameters.get_value("alpha"), ) except Exception as e: # Silent fallback to defaults - don't break if initialization fails logger.debug( "Smart initialization failed, using defaults", error_type=type(e).__name__, error_message=str(e), ) # Create objective function with stateless predictions def model_fn(x, params): """Model function for optimization (stateless).""" Ge, c_alpha, alpha = params[0], params[1], params[2] # Direct prediction based on test mode (stateless, calls _jax methods) if test_mode == "relaxation": return self._predict_relaxation_jax(x, Ge, c_alpha, alpha) elif test_mode == "creep": return self._predict_creep_jax(x, Ge, c_alpha, alpha) elif test_mode == "oscillation": return self._predict_oscillation_jax(x, Ge, c_alpha, alpha) else: raise ValueError(f"Unsupported test mode: {test_mode}") logger.debug("Creating least squares objective", normalize=True) objective = create_least_squares_objective( model_fn, x_data, y_data, normalize=True ) # Optimize using NLSQ logger.debug( "Starting NLSQ optimization", use_jax=kwargs.get("use_jax", True), method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) try: result = nlsq_optimize( objective, self.parameters, use_jax=kwargs.get("use_jax", True), method=kwargs.get("method", "auto"), max_iter=kwargs.get("max_iter", 1000), ) except Exception as e: logger.error( "NLSQ optimization raised exception", error_type=type(e).__name__, error_message=str(e), exc_info=True, ) raise # Validate optimization succeeded if not result.success: logger.error( "Optimization failed", message=result.message, test_mode=test_mode_str, ) raise RuntimeError( f"Optimization failed: {result.message}. " f"Try adjusting initial values, bounds, or max_iter." ) self.fitted_ = True # Log fitted parameters fitted_Ge = self.parameters.get_value("Ge") fitted_c_alpha = self.parameters.get_value("c_alpha") fitted_alpha = self.parameters.get_value("alpha") # Compute characteristic time assert ( fitted_Ge is not None and fitted_c_alpha is not None and fitted_alpha is not None ) tau_epsilon = self._compute_tau_epsilon( fitted_Ge, fitted_c_alpha, fitted_alpha ) logger.debug( "Fractional Kelvin-Voigt fit completed successfully", fitted_Ge=fitted_Ge, fitted_c_alpha=fitted_c_alpha, fitted_alpha=fitted_alpha, characteristic_time=tau_epsilon, ) ctx["Ge"] = fitted_Ge ctx["c_alpha"] = fitted_c_alpha ctx["alpha"] = fitted_alpha return self def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """Internal predict implementation. Args: X: RheoData object or array of x-values Returns: Predicted values """ # Handle RheoData input if isinstance(X, RheoData): return self.predict_rheodata(X) # Handle raw array input with test_mode dispatch _kw_mode = kwargs.get("test_mode") test_mode = ( _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None) ) x = jnp.asarray(X) Ge = self.parameters.get_value("Ge") c_alpha = self.parameters.get_value("c_alpha") alpha = self.parameters.get_value("alpha") assert Ge is not None and c_alpha is not None and alpha is not None if test_mode in ("oscillation", TestMode.OSCILLATION): result = self._predict_oscillation_jax(x, Ge, c_alpha, alpha) return np.array(result) elif test_mode in ("creep", TestMode.CREEP): result = self._predict_creep_jax(x, Ge, c_alpha, alpha) return np.array(result) else: # Default to relaxation result = self._predict_relaxation_jax(x, Ge, c_alpha, alpha) return np.array(result)
[docs] def model_function(self, X, params, test_mode=None, **kwargs): """Model function for Bayesian inference. This method is required by BayesianMixin for NumPyro NUTS sampling. It computes predictions given input X and a parameter array. Args: X: Independent variable (time or frequency) params: Array of parameter values [Ge, c_alpha, alpha] Returns: Model predictions as JAX array """ # Use explicit test_mode parameter (closure-captured in fit_bayesian) # Fall back to self._test_mode only for backward compatibility if test_mode is None: test_mode = getattr(self, "_test_mode", "relaxation") # Normalize test_mode to string if hasattr(test_mode, "value"): test_mode = test_mode.value # Extract parameter names from function signature params_dict = {name: params[i] for i, name in enumerate(self.parameters.keys())} # Dispatch to appropriate prediction method if test_mode == "relaxation": return self._predict_relaxation_jax(X, **params_dict) elif test_mode == "creep": return self._predict_creep_jax(X, **params_dict) elif test_mode == "oscillation": # Return complex array for oscillation mode complex_result = self._predict_oscillation_jax(X, **params_dict) return complex_result[..., 0] + 1j * complex_result[..., 1] else: # Default to relaxation for unknown modes return self._predict_relaxation_jax(X, **params_dict)
[docs] def predict_rheodata( self, rheo_data: RheoData, test_mode: str | None = None ) -> RheoData: """Predict response for RheoData. Args: rheo_data: Input RheoData with x values test_mode: Test mode ('relaxation', 'creep', 'oscillation') If None, auto-detect from rheo_data Returns: RheoData with predicted y values """ # Auto-detect test mode if not provided if test_mode is None: # Check for explicit test_mode in metadata first if "test_mode" in rheo_data.metadata: test_mode = rheo_data.metadata["test_mode"] else: test_mode = rheo_data.test_mode # Get parameters Ge = self.parameters.get_value("Ge") c_alpha = self.parameters.get_value("c_alpha") alpha = self.parameters.get_value("alpha") # Convert input to JAX x = jnp.asarray(rheo_data.x) # Route to appropriate prediction method if test_mode == "relaxation": y_pred = self._predict_relaxation_jax(x, Ge, c_alpha, alpha) elif test_mode == "creep": y_pred = self._predict_creep_jax(x, Ge, c_alpha, alpha) elif test_mode == "oscillation": y_pred_stacked = self._predict_oscillation_jax(x, Ge, c_alpha, alpha) y_pred = y_pred_stacked[..., 0] + 1j * y_pred_stacked[..., 1] else: raise ValueError( f"Unknown test mode: {test_mode}. " f"Must be 'relaxation', 'creep', or 'oscillation'" ) # Create output RheoData result = RheoData( x=np.array(x), y=np.array(y_pred), x_units=rheo_data.x_units, y_units=rheo_data.y_units, domain=rheo_data.domain, metadata=rheo_data.metadata.copy(), ) return result
[docs] def predict(self, X, test_mode: str | None = None, **kwargs): # type: ignore[override] """Predict response. R13-FKV-002: Delegates to BaseModel.predict() for plain-array inputs so that deformation_mode (E*->G* conversion) and test_mode restoration are handled correctly. Only RheoData inputs bypass super() because they return a RheoData wrapper object. Args: X: RheoData object or array of x-values test_mode: Test mode for prediction ('relaxation', 'creep', 'oscillation') Required when X is a raw array. If None, defaults to 'relaxation'. **kwargs: Additional arguments passed to BaseModel.predict() Returns: Predicted values (RheoData if input is RheoData, else array) """ if isinstance(X, RheoData): return self.predict_rheodata(X, test_mode=test_mode) else: return super().predict(X, test_mode=test_mode, **kwargs)