Source code for rheojax.models.fikh.fikh

"""FIKH (Fractional Isotropic-Kinematic Hardening) Model.

This module implements the FIKH model, a thixotropic elasto-viscoplastic
model with Caputo fractional derivative for structure evolution and
optional thermokinematic coupling.

Key Features:
    - Power-law memory in structure evolution (Caputo derivative)
    - Temperature-dependent viscosity and yield stress (Arrhenius)
    - Viscous heating with convective cooling
    - Armstrong-Frederick kinematic hardening

Mathematical Framework:
    Stress: σ_total = σ + η_inf·γ̇
    Maxwell relaxation: dσ/dt = G(γ̇ - γ̇ᵖ) - σ/τ
    Yield: |σ - α| ≤ σ_y(λ, T)
    Backstress: dα = C·dγᵖ - γ_dyn·|α|^(m-1)·α·|dγᵖ|
    Structure: D^α_C λ = (1-λ)/τ_thix - Γ·λ·|γ̇ᵖ|
    Temperature: ρc_p·dT/dt = χ·σ·γ̇ᵖ - h·(T-T_env)

Example:
    >>> from rheojax.models.fikh import FIKH
    >>> model = FIKH(include_thermal=True, alpha_structure=0.5)
    >>> model.fit(t, stress, test_mode='startup', strain=strain)
    >>> sigma_pred = model.predict(t_new, strain=strain_new)
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, Protocol, TestMode
from rheojax.logging import get_logger
from rheojax.models.fikh._base import FIKHBase
from rheojax.utils.optimization import create_least_squares_objective, nlsq_optimize

if TYPE_CHECKING:
    from numpy.typing import ArrayLike

jax, jnp = safe_import_jax()

logger = get_logger(__name__)


[docs] @ModelRegistry.register( "fikh", protocols=[ Protocol.FLOW_CURVE, Protocol.STARTUP, Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, Protocol.LAOS, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class FIKH(FIKHBase): r"""Fractional Isotropic-Kinematic Hardening (FIKH) Model. A thixotropic elasto-viscoplastic model extending MIKH with: 1. Caputo fractional derivative for structure evolution (power-law memory). 2. Full thermokinematic coupling (Arrhenius + viscous heating). The fractional derivative captures memory effects in thixotropic recovery, where the structure remembers its history via a power-law kernel rather than simple exponential decay. Governing Equations: σ_total = σ + η_inf·γ̇ Stress Evolution (ODE): dσ/dt = G(γ̇ - γ̇ᵖ) - (G/η)σ Yield Surface: |σ - α| ≤ σ_y(λ, T) σ_y = σ_y0 + Δσ_y·λ^m_y · exp(E_y/R·(1/T - 1/T_ref)) Fractional Structure Evolution (Caputo): D^α_C λ = (1-λ)/τ_thix - Γ·λ·|γ̇ᵖ| Backstress (Armstrong-Frederick): dα = C·dγᵖ - γ_dyn·|α|^(m-1)·α·|dγᵖ| Temperature: ρc_p·dT/dt = χ·σ·γ̇ᵖ - h·(T - T_env) Parameters (22 with thermal): G: Shear modulus [Pa] eta: Maxwell viscosity [Pa·s] C: Kinematic hardening modulus [Pa] gamma_dyn: Dynamic recovery parameter [-] m: AF recovery exponent [-] sigma_y0: Minimal yield stress [Pa] delta_sigma_y: Structural yield contribution [Pa] tau_thix: Thixotropic time scale [s] Gamma: Breakdown coefficient [-] alpha_structure: Fractional order (0 < α < 1) [-] eta_inf: High-shear viscosity [Pa·s] mu_p: Plastic viscosity [Pa·s] T_ref: Reference temperature [K] E_a: Viscosity activation energy [J/mol] E_y: Yield stress activation energy [J/mol] m_y: Structure exponent for yield [-] rho_cp: Volumetric heat capacity [J/(m³·K)] chi: Taylor-Quinney coefficient [-] h: Heat transfer coefficient [W/(m²·K)] T_env: Environmental temperature [K] Limiting Behavior: α → 1: Recovers classical IKH/MIKH (exponential structure 
relaxation) E_a = E_y = 0: Isothermal behavior (temperature-independent) Example: >>> # Isothermal FIKH >>> model = FIKH(include_thermal=False, alpha_structure=0.7) >>> model.fit(omega, G_star, test_mode='oscillation') >>> # Thermal FIKH with startup >>> model = FIKH(include_thermal=True) >>> result = model.fit(t, stress, test_mode='startup', strain=strain) >>> sigma_pred = model.predict_startup(t_new, gamma_dot=1.0) """
def __init__(
    self,
    include_thermal: bool = True,
    include_isotropic_hardening: bool = False,
    alpha_structure: float = 0.5,
    n_history: int = 100,
    stable_dt: float = 0.01,
):
    """Initialize FIKH model.

    All arguments are forwarded unchanged to ``FIKHBase.__init__``.

    Args:
        include_thermal: Enable thermokinematic coupling (Arrhenius + heating).
        include_isotropic_hardening: Enable isotropic hardening R.
        alpha_structure: Fractional order for structure (0 < α < 1).
            - α → 0: Strong memory (slow recovery)
            - α → 1: Weak memory (fast, exponential recovery)
        n_history: History buffer size for Caputo derivative.
        stable_dt: Internal integration substep (seconds) for startup /
            LAOS. See ``FIKHBase`` for the full explanation. Coarse user
            grids are densified to this step before the explicit
            return-mapping kernel runs. Set to 0 to disable.
            Default 0.01 s.
    """
    super().__init__(
        include_thermal=include_thermal,
        include_isotropic_hardening=include_isotropic_hardening,
        alpha_structure=alpha_structure,
        n_history=n_history,
        stable_dt=stable_dt,
    )

    logger.debug(
        "Initialized FIKH model",
        include_thermal=include_thermal,
        alpha_structure=alpha_structure,
    )
# =========================================================================
# Fitting Methods
# =========================================================================

def _fit(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH:
    """Fit model parameters using protocol-aware optimization.

    Stores the requested ``test_mode`` on ``self._test_mode`` and dispatches
    to the protocol-specific fitting routine.

    Args:
        X: Input data (depends on test_mode).
        y: Target data (stress or strain).
        **kwargs: Options including:
            - test_mode: Protocol type (default ``"startup"``)
            - gamma_dot: Shear rate (startup)
            - sigma_applied: Applied stress (creep)
            - sigma_0: Initial stress (relaxation)
            - strain: Strain array (if X is time only)

    Returns:
        Self with fitted parameters.
    """
    test_mode = kwargs.get("test_mode", "startup")
    self._test_mode = test_mode
    mode = self._validate_test_mode(test_mode)

    if mode == TestMode.FLOW_CURVE:
        return self._fit_flow_curve(X, y, **kwargs)
    if mode in (TestMode.CREEP, TestMode.RELAXATION):
        return self._fit_ode_formulation(X, y, **kwargs)
    if mode == TestMode.OSCILLATION:
        return self._fit_oscillation(X, y, **kwargs)
    # STARTUP, LAOS and every remaining mode are strain-driven and share
    # the return-mapping fit.
    return self._fit_return_mapping(X, y, **kwargs)


def _fit_flow_curve(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH:
    """Fit to steady-state flow curve data.

    Args:
        X: Shear-rate array.
        y: Target stress array.
        **kwargs: ``use_log_residuals`` (default True) is consumed here;
            everything else is forwarded to ``nlsq_optimize``.

    Returns:
        Self with fitted parameters.
    """
    from rheojax.models.fikh._kernels import fikh_flow_curve_steady_state

    gamma_dot = jnp.asarray(X)
    sigma_target = jnp.asarray(y, dtype=jnp.float64)

    # Resolved once: the ordering of the parameter vector does not change
    # between objective evaluations.
    p_names = list(self.parameters.keys())
    include_thermal = self.include_thermal

    def model_fn(x_data, params):
        p_dict = dict(zip(p_names, params, strict=False))
        return fikh_flow_curve_steady_state(
            x_data, include_thermal=include_thermal, **p_dict
        )

    # Flow curves span decades — log residuals give equal weight to
    # low and high shear rate regions.
    objective = create_least_squares_objective(
        model_fn,
        gamma_dot,
        sigma_target,
        use_log_residuals=kwargs.pop("use_log_residuals", True),
    )

    nlsq_optimize(objective, self.parameters, **kwargs)
    return self


def _fit_ode_formulation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH:
    """Fit using ODE formulation (creep/relaxation).

    The protocol values used for the fit are cached on the instance
    (``_fit_gamma_dot``, ``_fit_sigma_applied``, ``_fit_sigma_0``) so later
    predictions can reuse them.

    Args:
        X: Time array.
        y: Target response.
        **kwargs: ``test_mode`` (default ``"relaxation"``), ``gamma_dot``
            (default 0.0), ``sigma_applied`` (default 100.0), ``sigma_0``
            (default 100.0), ``T_init`` (default None) and
            ``use_log_residuals`` (default False, consumed here). The
            remaining kwargs are forwarded to ``nlsq_optimize``.

    Returns:
        Self with fitted parameters.
    """
    t = jnp.asarray(X)
    y_target = jnp.asarray(y, dtype=jnp.float64)

    # Pass the canonical protocol string (``mode.value``) to the simulator,
    # exactly as ``_predict`` does, so fit and predict agree even when the
    # caller supplied an enum member instead of a string.
    mode_name = self._validate_test_mode(
        kwargs.get("test_mode", "relaxation")
    ).value
    gamma_dot = kwargs.get("gamma_dot", 0.0)
    sigma_applied = kwargs.get("sigma_applied", 100.0)
    sigma_0 = kwargs.get("sigma_0", 100.0)
    T_init = kwargs.get("T_init", None)

    # Cache protocol kwargs so model_function() can retrieve them during NUTS
    self._fit_gamma_dot = gamma_dot
    self._fit_sigma_applied = sigma_applied
    self._fit_sigma_0 = sigma_0

    p_names = list(self.parameters.keys())

    def model_fn(x_data, params):
        p_dict = dict(zip(p_names, params, strict=False))
        return self._simulate_transient(
            x_data, p_dict, mode_name, gamma_dot, sigma_applied, sigma_0, T_init
        )

    # Transient data (creep/relaxation) often starts at zero — normalize=False
    # avoids division by ~0 at early time points (same rationale as fluidity).
    objective = create_least_squares_objective(
        model_fn,
        t,
        y_target,
        normalize=False,
        use_log_residuals=kwargs.pop("use_log_residuals", False),
    )

    nlsq_optimize(objective, self.parameters, **kwargs)
    return self


def _fit_return_mapping(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH:
    """Fit using return mapping (startup/LAOS).

    Args:
        X: Input from which ``_extract_time_strain`` derives time and strain.
        y: Target stress array.
        **kwargs: Passed to ``_extract_time_strain``; ``use_log_residuals``
            (default False) is consumed here and the rest is forwarded to
            ``nlsq_optimize``.

    Returns:
        Self with fitted parameters.
    """
    times, strains = self._extract_time_strain(X, **kwargs)
    sigma_target = jnp.asarray(y, dtype=jnp.float64)

    # Pre-compute the stable-dt substep count from concrete times and cache
    # it so the subsequent NUTS trace reuses the same integration grid.
    # See FIKHBase._compute_n_sub / _densify_grid_for_return_mapping for
    # why this is necessary (explicit return mapping is only stable when
    # G·Δγ per step is small relative to the yield stress).
    self._n_sub_cached = self._compute_n_sub(times)

    p_names = list(self.parameters.keys())

    def model_fn(x_data, params):
        p_dict = dict(zip(p_names, params, strict=False))
        return self._predict_from_params(x_data, strains, p_dict)

    # Startup/LAOS stress crosses zero — normalize=False avoids
    # division by ~0 at the zero crossings.
    objective = create_least_squares_objective(
        model_fn,
        times,
        sigma_target,
        normalize=False,
        use_log_residuals=kwargs.pop("use_log_residuals", False),
    )

    nlsq_optimize(objective, self.parameters, **kwargs)
    return self


def _fit_oscillation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH:
    """Fit to oscillatory data (SAOS).

    This method fits to frequency-domain SAOS data by internally simulating
    time-domain oscillations at each frequency and extracting G* from the
    last simulated cycle.

    Args:
        X: Angular frequency array (omega) [rad/s].
        y: Target modulus data - can be:
            - Complex G* = G' + i·G'' (uses both components)
            - Real (N, 2) array of [G', G''] columns (uses both components)
            - Real |G*| magnitude (fits to magnitude)
        **kwargs: Options including:
            - gamma_0: Strain amplitude (default 0.01)
            - n_cycles: Number of cycles per frequency (default 5)

    Returns:
        Self with fitted parameters.
    """
    omega = jnp.asarray(X)
    y_arr = jnp.asarray(y)

    gamma_0 = kwargs.get("gamma_0", 0.01)
    n_cycles = kwargs.get("n_cycles", 5)

    # Cache protocol kwargs so model_function() can retrieve them during NUTS
    self._fit_gamma_0 = gamma_0
    self._fit_n_cycles = n_cycles

    # Determine if fitting to complex, (N, 2) [G', G''], or magnitude
    is_complex = jnp.iscomplexobj(y_arr)
    is_stacked = y_arr.ndim == 2 and y_arr.shape[1] == 2
    fit_components = is_complex or is_stacked

    # Pre-compute targets and normalization denominators for consistent
    # residual weighting. FIKH oscillation uses a time-domain simulation
    # (not an analytical G*), so the complex dispatch is handled manually
    # rather than through create_least_squares_objective.
    norm_floor = jnp.float64(1e-10)
    if is_complex:
        target_Gp, target_Gpp = jnp.real(y_arr), jnp.imag(y_arr)
    elif is_stacked:
        target_Gp, target_Gpp = y_arr[:, 0], y_arr[:, 1]
    else:
        target_mag = jnp.abs(y_arr)

    if fit_components:
        norm_Gp = jnp.maximum(jnp.abs(target_Gp), norm_floor)
        norm_Gpp = jnp.maximum(jnp.abs(target_Gpp), norm_floor)
    else:
        norm_mag = jnp.maximum(target_mag, norm_floor)

    p_names = list(self.parameters.keys())

    def objective(param_values):
        p_dict = dict(zip(p_names, param_values, strict=False))

        # Predict G* at each frequency using time-domain simulation
        G_star_pred = self._predict_oscillation_from_params(
            omega, p_dict, gamma_0, n_cycles
        )

        if fit_components:
            # Fit both G' and G'' by stacking normalized residuals
            return jnp.concatenate(
                [
                    (jnp.real(G_star_pred) - target_Gp) / norm_Gp,
                    (jnp.imag(G_star_pred) - target_Gpp) / norm_Gpp,
                ]
            )
        # Fit to magnitude |G*|
        return (jnp.abs(G_star_pred) - target_mag) / norm_mag

    nlsq_optimize(objective, self.parameters, **kwargs)
    return self


def _predict_oscillation_from_params(
    self,
    omega: jnp.ndarray,
    params: dict[str, Any],
    gamma_0: float = 0.01,
    n_cycles: int = 5,
) -> jnp.ndarray:
    """Predict complex modulus G* from parameter dictionary.

    Internal method used by both NLSQ fitting and Bayesian inference.
    F-004/F-024: Vectorized via jax.vmap over frequencies (replaces Python loop).

    Args:
        omega: Angular frequency array.
        params: Parameter dictionary.
        gamma_0: Strain amplitude.
        n_cycles: Number of cycles to simulate (at least 1).

    Returns:
        Complex modulus G* = G' + i·G'' for each frequency.

    Raises:
        ValueError: If ``n_cycles`` is smaller than 1.
    """
    from rheojax.models.fikh._kernels import (
        fikh_scan_kernel_isothermal,
        fikh_scan_kernel_thermal,
    )

    # With fewer than one cycle the "last cycle" window below would start at
    # a negative index and the 2×2 solve would be singular.
    if n_cycles < 1:
        raise ValueError(f"n_cycles must be >= 1, got {n_cycles}")

    alpha = params.get("alpha_structure", self.alpha_structure)

    # Use n_cycles * pts_per_cycle + 1 so that dt = period / pts_per_cycle
    # exactly, giving integer-period windows for the G*/G'' extraction.
    pts_per_cycle = 100
    n_pts = n_cycles * pts_per_cycle + 1
    # Last cycle: pts_per_cycle + 1 points spanning exactly one period
    last_cycle_start = (n_cycles - 1) * pts_per_cycle
    n_last = n_pts - last_cycle_start  # = pts_per_cycle + 1

    # Close over params/options so only omega varies
    include_thermal = self.include_thermal
    n_history = self.n_history

    def predict_single_omega(w):
        """Compute [G', G''] at a single frequency (vmappable)."""
        period = 2 * jnp.pi / w
        t = jnp.linspace(0.0, n_cycles * period, n_pts)
        strain = gamma_0 * jnp.sin(w * t)

        if include_thermal:
            T_init = params.get("T_env", params.get("T_ref", 298.15))
            stress, _ = fikh_scan_kernel_thermal(
                t,
                strain,
                n_history=n_history,
                alpha=alpha,
                use_viscosity=True,
                T_init=T_init,
                **params,
            )
        else:
            stress = fikh_scan_kernel_isothermal(
                t,
                strain,
                n_history=n_history,
                alpha=alpha,
                use_viscosity=True,
                **params,
            )

        # Extract last cycle via dynamic_slice (trace-safe)
        t_last = jax.lax.dynamic_slice(t, [last_cycle_start], [n_last])
        stress_last = jax.lax.dynamic_slice(stress, [last_cycle_start], [n_last])

        # Least-squares extraction of G' and G'' from last-cycle stress.
        #   σ(t) = G'·γ₀·sin(ωt) + G''·γ₀·cos(ωt)
        # This is exact regardless of window span and avoids the G'→G''
        # cross-talk that trapezoid Fourier integration suffers when the
        # window doesn't span an exact integer number of periods.
        sin_basis = gamma_0 * jnp.sin(w * t_last)
        cos_basis = gamma_0 * jnp.cos(w * t_last)
        A = jnp.column_stack([sin_basis, cos_basis])  # (n_last, 2)
        # Normal equations: (AᵀA)⁻¹ Aᵀ b — a 2×2 system
        ATA = A.T @ A  # (2, 2)
        ATb = A.T @ stress_last  # (2,)
        return jnp.linalg.solve(ATA, ATb)  # [G', G'']

    # Vectorize over all frequencies at once
    results = jax.vmap(predict_single_omega)(omega)  # (N_omega, 2)
    return results[:, 0] + 1j * results[:, 1]


# =========================================================================
# Prediction Methods
# =========================================================================

def _predict_from_params(
    self,
    times: jnp.ndarray,
    strains: jnp.ndarray,
    params: dict[str, Any],
) -> jnp.ndarray:
    """Predict stress using parameter dictionary.

    This is the core prediction method used by both NLSQ fitting
    and Bayesian inference. The user's (times, strains) grid is
    densified via ``_densify_grid_for_return_mapping`` before the scan
    kernel runs, then the result is subsampled back to the user's
    time points.

    Args:
        times: Time array.
        strains: Strain array.
        params: Parameter dictionary.

    Returns:
        Predicted stress array at the user's time points.
    """
    from rheojax.models.fikh._kernels import (
        fikh_scan_kernel_isothermal,
        fikh_scan_kernel_thermal,
    )

    # Extract alpha (can now be a traced value since it's not in static_argnums)
    alpha = params.get("alpha_structure", self.alpha_structure)

    t_dense, strain_dense, n_sub = self._densify_grid_for_return_mapping(
        times, strains
    )

    if self.include_thermal:
        T_init = params.get("T_env", params.get("T_ref", 298.15))
        sigma_dense, _ = fikh_scan_kernel_thermal(
            t_dense,
            strain_dense,
            n_history=self.n_history,
            alpha=alpha,
            use_viscosity=True,
            T_init=T_init,
            **params,
        )
    else:
        sigma_dense = fikh_scan_kernel_isothermal(
            t_dense,
            strain_dense,
            n_history=self.n_history,
            alpha=alpha,
            use_viscosity=True,
            **params,
        )

    # Keep every n_sub-th sample to get back onto the caller's grid.
    if n_sub > 1:
        return sigma_dense[::n_sub]
    return sigma_dense


def _predict(self, X: ArrayLike, **kwargs) -> ArrayLike:
    """Predict based on test_mode.

    The protocol is taken from ``kwargs["test_mode"]``, then from
    ``self._test_mode``, then defaults to ``"startup"``. Protocol values
    that are not passed explicitly fall back to the ones cached by the
    last fit (``_fit_gamma_dot``, ``_fit_sigma_applied``, ``_fit_sigma_0``,
    ``_fit_gamma_0``, ``_fit_n_cycles``) so that ``fit()`` followed by
    ``predict()`` simulates the same protocol; the hard-coded defaults are
    used only when nothing was cached.

    Args:
        X: Input data (shape depends on test_mode).
        **kwargs: Additional parameters.

    Returns:
        Predicted values.
    """
    test_mode = kwargs.get("test_mode")
    if test_mode is None:
        test_mode = getattr(self, "_test_mode", None)
    if test_mode is None:
        test_mode = "startup"
    mode = self._validate_test_mode(test_mode)

    params = self._get_params_dict()

    if mode == TestMode.FLOW_CURVE:
        from rheojax.models.fikh._kernels import fikh_flow_curve_steady_state

        gamma_dot = jnp.asarray(X)
        return fikh_flow_curve_steady_state(
            gamma_dot, include_thermal=self.include_thermal, **params
        )

    if mode in (TestMode.CREEP, TestMode.RELAXATION):
        t = jnp.asarray(X)
        gamma_dot = kwargs.get("gamma_dot", getattr(self, "_fit_gamma_dot", 0.0))
        sigma_applied = kwargs.get(
            "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
        )
        sigma_0 = kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 60.0))
        T_init = kwargs.get("T_init", None)
        return self._simulate_transient(
            t, params, mode.value, gamma_dot, sigma_applied, sigma_0, T_init
        )

    if mode == TestMode.OSCILLATION:
        # Frequency-domain SAOS: X is omega, return G*
        omega = jnp.asarray(X)
        gamma_0 = kwargs.get("gamma_0", getattr(self, "_fit_gamma_0", 0.01))
        n_cycles = kwargs.get("n_cycles", getattr(self, "_fit_n_cycles", 5))
        return self._predict_oscillation_from_params(
            omega, params, gamma_0, n_cycles
        )

    # Strain-driven protocols (startup, laos)
    times, strains = self._extract_time_strain(X, **kwargs)
    return self._predict_from_params(times, strains, params)


# =========================================================================
# Protocol-Specific Prediction Methods
# =========================================================================
[docs] def predict_flow_curve( self, gamma_dot: ArrayLike, T: float | None = None ) -> ArrayLike: """Predict steady-state flow curve. Args: gamma_dot: Shear rate array. T: Temperature (if thermal enabled). Returns: Stress array. """ return self._predict(gamma_dot, test_mode="flow_curve")
[docs] def predict_startup( self, t: ArrayLike, gamma_dot: float = 1.0, T_init: float | None = None, ) -> ArrayLike: """Predict startup shear response. Args: t: Time array. gamma_dot: Constant shear rate. T_init: Initial temperature. Returns: Stress vs time. """ params = self._get_params_dict() return self._simulate_transient( jnp.asarray(t), params, "startup", gamma_dot=gamma_dot, T_init=T_init )
[docs] def predict_relaxation( self, t: ArrayLike, sigma_0: float = 100.0, T_init: float | None = None, ) -> ArrayLike: """Predict stress relaxation. Args: t: Time array. sigma_0: Initial stress. T_init: Initial temperature. Returns: Stress vs time. """ params = self._get_params_dict() return self._simulate_transient( jnp.asarray(t), params, "relaxation", sigma_0=sigma_0, T_init=T_init )
[docs] def predict_creep( self, t: ArrayLike, sigma_applied: float = 50.0, T_init: float | None = None, ) -> ArrayLike: """Predict creep response. Args: t: Time array. sigma_applied: Applied constant stress. T_init: Initial temperature. Returns: Strain vs time. """ params = self._get_params_dict() return self._simulate_transient( jnp.asarray(t), params, "creep", sigma_applied=sigma_applied, T_init=T_init )
def predict_oscillation(
    self,
    omega: ArrayLike,
    gamma_0: float = 0.01,
    n_cycles: int = 5,
) -> jnp.ndarray:
    """Predict oscillatory response (SAOS G*, G', G'').

    Thin public wrapper: converts ``omega`` to a JAX array, reads the
    model's current parameters and hands both to
    ``_predict_oscillation_from_params``, which performs the per-frequency
    evaluation. For the nonlinear time-domain response, use
    ``predict_laos()``.

    Args:
        omega: Angular frequency array.
        gamma_0: Strain amplitude (should be small).
        n_cycles: Number of cycles to simulate.

    Returns:
        Complex modulus G* = G' + i·G'' for each frequency.
    """
    current_params = self._get_params_dict()
    frequencies = jnp.asarray(omega)
    return self._predict_oscillation_from_params(
        frequencies, current_params, gamma_0, n_cycles
    )
[docs] def predict_laos( self, t: ArrayLike, gamma_0: float = 1.0, omega: float = 1.0, T_init: float | None = None, strain: ArrayLike | None = None, ) -> dict[str, jnp.ndarray]: """Predict LAOS (Large Amplitude Oscillatory Shear) response. Integrates on the densified ``stable_dt`` grid (same as the fit path) so fit and predict stay in the linearization regime of the explicit return mapping. If ``strain`` is omitted, a clean sinusoid ``gamma_0·sin(omega·t)`` is synthesized; pass the measured strain array for fit/predict consistency on experimental data. Args: t: Time array at which to report the response. gamma_0: Strain amplitude used if ``strain`` is not given. omega: Angular frequency used if ``strain`` is not given. T_init: Initial temperature (thermal models only). strain: Optional measured strain array aligned with ``t``. When supplied, ``gamma_0`` / ``omega`` are ignored for the simulation and the measured trace drives the return mapping. Returns: Dictionary with 'time', 'strain', 'stress', and optionally 'temperature'. """ t_arr = jnp.asarray(t) if strain is not None: strain_arr = jnp.asarray(strain) else: strain_arr = gamma_0 * jnp.sin(omega * t_arr) params = self._get_params_dict() if not self.include_thermal: # Reuse the fit-path predictor so the integration grid is # densified identically to what NLSQ / NUTS optimize against. stress = self._predict_from_params(t_arr, strain_arr, params) return { "time": t_arr, "strain": strain_arr, "stress": stress, } # Thermal path: mirror the densification that ``_predict_from_params`` # applies, but keep the (stress, temperature) tuple the thermal # kernel returns. 
from rheojax.models.fikh._kernels import fikh_scan_kernel_thermal alpha = params.get("alpha_structure", self.alpha_structure) T_0 = T_init if T_init is not None else params.get("T_env", 298.15) t_dense, strain_dense, n_sub = self._densify_grid_for_return_mapping( t_arr, strain_arr ) stress_dense, temperature_dense = fikh_scan_kernel_thermal( t_dense, strain_dense, n_history=self.n_history, alpha=alpha, use_viscosity=True, T_init=T_0, **params, ) if n_sub > 1: stress = stress_dense[::n_sub] temperature = temperature_dense[::n_sub] else: stress = stress_dense temperature = temperature_dense return { "time": t_arr, "strain": strain_arr, "stress": stress, "temperature": temperature, }
# ========================================================================= # Bayesian Interface # =========================================================================
def model_function(
    self,
    X: ArrayLike,
    params: ArrayLike | dict[str, Any],
    test_mode: str | None = None,
    **kwargs,
) -> jnp.ndarray:
    """Model function for NumPyro Bayesian inference.

    Evaluates the model for an explicit parameter set instead of the
    parameters stored on the instance.

    The protocol is resolved in this order: the explicit ``test_mode``
    argument, ``_last_fit_kwargs["test_mode"]``, ``self._test_mode``, and
    finally ``"startup"``. Protocol values missing from ``kwargs`` fall
    back to the ones cached by the fit methods (``_fit_gamma_dot``,
    ``_fit_sigma_applied``, ``_fit_sigma_0``, ``_fit_gamma_0``,
    ``_fit_n_cycles``).

    Args:
        X: Input data.
        params: Parameter array (ordered like ``self.parameters``) or
            dictionary.
        test_mode: Protocol (resolved as described above if None).
        **kwargs: Protocol-specific arguments (e.g., strain, sigma_0).

    Returns:
        Predicted values. For the oscillation protocol this is a real
        (N, 2) array whose columns are G' and G''.
    """
    # Prefer explicit test_mode; fall back to _last_fit_kwargs
    # (set by fit()) over stale self._test_mode to avoid wrong NUTS likelihood.
    # Both attributes are read defensively: either may be missing or None
    # on an instance that has not been fitted yet.
    last_fit_kwargs = getattr(self, "_last_fit_kwargs", None) or {}
    stored_mode = getattr(self, "_test_mode", None)
    if test_mode is not None:
        mode = test_mode
    elif last_fit_kwargs.get("test_mode") is not None:
        mode = last_fit_kwargs["test_mode"]
    elif stored_mode is not None:
        mode = stored_mode
    else:
        mode = "startup"

    # Convert array to dict if needed
    if isinstance(params, (np.ndarray, jnp.ndarray)):
        param_names = list(self.parameters.keys())
        param_dict = dict(zip(param_names, params, strict=False))
    else:
        param_dict = dict(params)

    mode_enum = self._validate_test_mode(mode)

    if mode_enum == TestMode.FLOW_CURVE:
        from rheojax.models.fikh._kernels import fikh_flow_curve_steady_state

        gamma_dot = jnp.asarray(X)
        return fikh_flow_curve_steady_state(
            gamma_dot, include_thermal=self.include_thermal, **param_dict
        )

    if mode_enum in (TestMode.CREEP, TestMode.RELAXATION):
        t = jnp.asarray(X)
        gamma_dot = kwargs.get("gamma_dot", getattr(self, "_fit_gamma_dot", 0.0))
        sigma_applied = kwargs.get(
            "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0)
        )
        sigma_0 = kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 60.0))
        return self._simulate_transient(
            t, param_dict, mode_enum.value, gamma_dot, sigma_applied, sigma_0
        )

    if mode_enum == TestMode.OSCILLATION:
        # Frequency-domain SAOS: X is omega. The complex G* is split into
        # real columns [G', G''] for the Bayesian likelihood.
        omega = jnp.asarray(X)
        gamma_0 = kwargs.get("gamma_0", getattr(self, "_fit_gamma_0", 0.01))
        n_cycles = kwargs.get("n_cycles", getattr(self, "_fit_n_cycles", 5))
        G_star = self._predict_oscillation_from_params(
            omega, param_dict, gamma_0, n_cycles
        )
        return jnp.column_stack([jnp.real(G_star), jnp.imag(G_star)])

    # Strain-driven protocols (startup, laos)
    times, strains = self._extract_time_strain(X, **kwargs)
    return self._predict_from_params(times, strains, param_dict)
# ========================================================================= # Utility Methods # =========================================================================
def get_limiting_behavior(self) -> dict[str, Any]:
    """Summarize which limiting regime the current parameters sit in.

    Reads ``alpha_structure`` (and ``E_a`` when thermal coupling is on)
    from ``self.parameters`` and classifies them against fixed thresholds:
    α > 0.95 counts as near-integer order, α > 0.7 as weak memory.

    Returns:
        Dictionary with limiting cases and expected behavior.
    """
    frac_order = self.parameters.get_value("alpha_structure")

    # E_a is only looked up for thermal models.
    if self.include_thermal:
        activation_energy = self.parameters.get_value("E_a")
    else:
        activation_energy = 0.0

    if frac_order > 0.7:
        memory_label = "weak (near exponential)"
    else:
        memory_label = "strong (power-law)"

    report: dict[str, Any] = {}
    report["fractional_order"] = frac_order
    report["is_near_integer"] = frac_order > 0.95
    report["memory_type"] = memory_label
    report["thermal_coupling"] = self.include_thermal
    report["arrhenius_enabled"] = (
        activation_energy > 0 if self.include_thermal else False
    )
    report["limiting_case_alpha_1"] = "Classical MIKH behavior"
    report["limiting_case_E_a_0"] = "Isothermal FIKH behavior"
    return report
def precompile(
    self,
    test_mode: str = "relaxation",
    X=None,
    y=None,
    *,
    n_points: int = 100,
    verbose: bool = True,
) -> float:
    """Precompile JIT kernels for faster subsequent predictions.

    Triggers JAX JIT compilation of the core FIKH kernels by running
    a small dummy prediction (a linear strain ramp over ``n_points``
    samples). This is useful when you want to avoid the compilation
    overhead on first real prediction.

    Args:
        test_mode: Accepted for parent compatibility (unused).
        X: Accepted for parent compatibility (unused).
        y: Accepted for parent compatibility (unused).
        n_points: Number of time points for dummy data.
        verbose: Log the compilation time if True.

    Returns:
        Wall-clock time in seconds for compiling and running the dummy
        kernels.

    Example:
        >>> model = FIKH(include_thermal=True)
        >>> compile_time = model.precompile()  # Triggers JIT
        >>> # Now predictions will be fast
        >>> sigma = model.predict_startup(t_real, gamma_dot=1.0)
    """
    import time as time_module

    from rheojax.models.fikh._kernels import (
        fikh_scan_kernel_isothermal,
        fikh_scan_kernel_thermal,
    )

    # Create dummy data
    t_dummy = jnp.linspace(0, 10, n_points)
    strain_dummy = 0.1 * t_dummy  # Linear ramp

    params = self._get_params_dict()
    alpha = params.get("alpha_structure", self.alpha_structure)

    start = time_module.perf_counter()

    # Always compile isothermal kernel
    outputs = [
        fikh_scan_kernel_isothermal(
            t_dummy,
            strain_dummy,
            n_history=self.n_history,
            alpha=alpha,
            use_viscosity=True,
            **params,
        )
    ]

    # Compile thermal kernel if enabled
    if self.include_thermal:
        T_init = params.get("T_env", params.get("T_ref", 298.15))
        outputs.append(
            fikh_scan_kernel_thermal(
                t_dummy,
                strain_dummy,
                n_history=self.n_history,
                alpha=alpha,
                use_viscosity=True,
                T_init=T_init,
                **params,
            )
        )

    # JAX dispatches work asynchronously; wait for the results so the timer
    # covers the whole warm-up rather than just the dispatch.
    jax.block_until_ready(outputs)

    elapsed = time_module.perf_counter() - start

    if verbose:
        logger.info(
            "FIKH kernels precompiled",
            compile_time_s=f"{elapsed:.2f}",
            include_thermal=self.include_thermal,
        )

    return elapsed
def __repr__(self) -> str:
    """Return a compact constructor-like summary of the model.

    Shows the thermal flag, the current ``alpha_structure`` parameter value
    (three decimals) and the history buffer size.
    """
    frac_order = self.parameters.get_value("alpha_structure")
    fields = ", ".join(
        [
            f"include_thermal={self.include_thermal}",
            f"alpha_structure={frac_order:.3f}",
            f"n_history={self.n_history}",
        ]
    )
    return f"FIKH({fields})"