Source code for rheojax.models.fikh.fikh

"""FIKH (Fractional Isotropic-Kinematic Hardening) Model.

This module implements the FIKH model, a thixotropic elasto-viscoplastic
model with Caputo fractional derivative for structure evolution and
optional thermokinematic coupling.

Key Features:
    - Power-law memory in structure evolution (Caputo derivative)
    - Temperature-dependent viscosity and yield stress (Arrhenius)
    - Viscous heating with convective cooling
    - Armstrong-Frederick kinematic hardening

Mathematical Framework:
    Stress: σ_total = σ + η_inf·γ̇
    Maxwell relaxation: dσ/dt = G(γ̇ - γ̇ᵖ) - σ/τ
    Yield: |σ - α| ≤ σ_y(λ, T)
    Backstress: dα = C·dγᵖ - γ_dyn·|α|^(m-1)·α·|dγᵖ|
    Structure: D^α_C λ = (1-λ)/τ_thix - Γ·λ·|γ̇ᵖ|
    Temperature: ρc_p·dT/dt = χ·σ·γ̇ᵖ - h·(T-T_env)

Example:
    >>> from rheojax.models.fikh import FIKH
    >>> model = FIKH(include_thermal=True, alpha_structure=0.5)
    >>> model.fit(t, stress, test_mode='startup', strain=strain)
    >>> sigma_pred = model.predict(t_new, strain=strain_new)
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np

from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, Protocol, TestMode
from rheojax.logging import get_logger
from rheojax.models.fikh._base import FIKHBase
from rheojax.utils.optimization import nlsq_optimize

if TYPE_CHECKING:
    from numpy.typing import ArrayLike

jax, jnp = safe_import_jax()

logger = get_logger(__name__)


[docs] @ModelRegistry.register( "fikh", protocols=[ Protocol.FLOW_CURVE, Protocol.STARTUP, Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION, Protocol.LAOS, ], deformation_modes=[ DeformationMode.SHEAR, DeformationMode.TENSION, DeformationMode.BENDING, DeformationMode.COMPRESSION, ], ) class FIKH(FIKHBase): r"""Fractional Isotropic-Kinematic Hardening (FIKH) Model. A thixotropic elasto-viscoplastic model extending MIKH with: 1. Caputo fractional derivative for structure evolution (power-law memory). 2. Full thermokinematic coupling (Arrhenius + viscous heating). The fractional derivative captures memory effects in thixotropic recovery, where the structure remembers its history via a power-law kernel rather than simple exponential decay. Governing Equations: σ_total = σ + η_inf·γ̇ Stress Evolution (ODE): dσ/dt = G(γ̇ - γ̇ᵖ) - (G/η)σ Yield Surface: |σ - α| ≤ σ_y(λ, T) σ_y = σ_y0 + Δσ_y·λ^m_y · exp(E_y/R·(1/T - 1/T_ref)) Fractional Structure Evolution (Caputo): D^α_C λ = (1-λ)/τ_thix - Γ·λ·|γ̇ᵖ| Backstress (Armstrong-Frederick): dα = C·dγᵖ - γ_dyn·|α|^(m-1)·α·|dγᵖ| Temperature: ρc_p·dT/dt = χ·σ·γ̇ᵖ - h·(T - T_env) Parameters (22 with thermal): G: Shear modulus [Pa] eta: Maxwell viscosity [Pa·s] C: Kinematic hardening modulus [Pa] gamma_dyn: Dynamic recovery parameter [-] m: AF recovery exponent [-] sigma_y0: Minimal yield stress [Pa] delta_sigma_y: Structural yield contribution [Pa] tau_thix: Thixotropic time scale [s] Gamma: Breakdown coefficient [-] alpha_structure: Fractional order (0 < α < 1) [-] eta_inf: High-shear viscosity [Pa·s] mu_p: Plastic viscosity [Pa·s] T_ref: Reference temperature [K] E_a: Viscosity activation energy [J/mol] E_y: Yield stress activation energy [J/mol] m_y: Structure exponent for yield [-] rho_cp: Volumetric heat capacity [J/(m³·K)] chi: Taylor-Quinney coefficient [-] h: Heat transfer coefficient [W/(m²·K)] T_env: Environmental temperature [K] Limiting Behavior: α → 1: Recovers classical IKH/MIKH (exponential structure relaxation) E_a = E_y = 0: Isothermal behavior (temperature-independent) Example: >>> # Isothermal FIKH >>> model = FIKH(include_thermal=False, alpha_structure=0.7) >>> model.fit(omega, G_star, test_mode='oscillation') >>> # Thermal FIKH with startup >>> model = FIKH(include_thermal=True) >>> result = model.fit(t, stress, test_mode='startup', strain=strain) >>> sigma_pred = model.predict_startup(t_new, gamma_dot=1.0) """
[docs] def __init__( self, include_thermal: bool = True, include_isotropic_hardening: bool = False, alpha_structure: float = 0.5, n_history: int = 100, ): """Initialize FIKH model. Args: include_thermal: Enable thermokinematic coupling (Arrhenius + heating). include_isotropic_hardening: Enable isotropic hardening R. alpha_structure: Fractional order for structure (0 < α < 1). - α → 0: Strong memory (slow recovery) - α → 1: Weak memory (fast, exponential recovery) n_history: History buffer size for Caputo derivative. """ super().__init__( include_thermal=include_thermal, include_isotropic_hardening=include_isotropic_hardening, alpha_structure=alpha_structure, n_history=n_history, ) logger.debug( "Initialized FIKH model", include_thermal=include_thermal, alpha_structure=alpha_structure, )
# ========================================================================= # Fitting Methods # ========================================================================= def _fit(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH: """Fit model parameters using protocol-aware optimization. Args: X: Input data (depends on test_mode). y: Target data (stress or strain). **kwargs: Options including: - test_mode: Protocol type - gamma_dot: Shear rate (startup) - sigma_applied: Applied stress (creep) - sigma_0: Initial stress (relaxation) - strain: Strain array (if X is time only) Returns: Self with fitted parameters. """ test_mode = kwargs.get("test_mode", "startup") self._test_mode = test_mode mode = self._validate_test_mode(test_mode) if mode == TestMode.FLOW_CURVE: return self._fit_flow_curve(X, y, **kwargs) elif mode in (TestMode.CREEP, TestMode.RELAXATION): return self._fit_ode_formulation(X, y, **kwargs) elif mode == TestMode.STARTUP: # STARTUP and LAOS both use return mapping return self._fit_return_mapping(X, y, **kwargs) elif mode == TestMode.OSCILLATION: return self._fit_oscillation(X, y, **kwargs) else: return self._fit_return_mapping(X, y, **kwargs) def _fit_flow_curve(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH: """Fit to steady-state flow curve data.""" from rheojax.models.fikh._kernels import fikh_flow_curve_steady_state gamma_dot = jnp.asarray(X) sigma_target = jnp.asarray(y) def objective(param_values): p_names = list(self.parameters.keys()) p_dict = dict(zip(p_names, param_values, strict=False)) sigma_pred = fikh_flow_curve_steady_state( gamma_dot, include_thermal=self.include_thermal, **p_dict ) return sigma_pred - sigma_target nlsq_optimize(objective, self.parameters, **kwargs) return self def _fit_ode_formulation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH: """Fit using ODE formulation (creep/relaxation).""" t = jnp.asarray(X) y_target = jnp.asarray(y) test_mode = kwargs.get("test_mode", "relaxation") gamma_dot = kwargs.get("gamma_dot", 0.0) sigma_applied = kwargs.get("sigma_applied", 100.0) sigma_0 = kwargs.get("sigma_0", 100.0) T_init = kwargs.get("T_init", None) # Cache protocol kwargs so model_function() can retrieve them during NUTS self._fit_gamma_dot = gamma_dot self._fit_sigma_applied = sigma_applied self._fit_sigma_0 = sigma_0 def objective(param_values): p_names = list(self.parameters.keys()) p_dict = dict(zip(p_names, param_values, strict=False)) y_pred = self._simulate_transient( t, p_dict, test_mode, gamma_dot, sigma_applied, sigma_0, T_init ) return y_pred - y_target nlsq_optimize(objective, self.parameters, **kwargs) return self def _fit_return_mapping(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH: """Fit using return mapping (startup/LAOS).""" times, strains = self._extract_time_strain(X, **kwargs) sigma_target = jnp.asarray(y) def objective(param_values): p_names = list(self.parameters.keys()) p_dict = dict(zip(p_names, param_values, strict=False)) sigma_pred = self._predict_from_params(times, strains, p_dict) return sigma_pred - sigma_target nlsq_optimize(objective, self.parameters, **kwargs) return self def _fit_oscillation(self, X: ArrayLike, y: ArrayLike, **kwargs) -> FIKH: """Fit to oscillatory data (SAOS). This method fits to frequency-domain SAOS data by internally simulating time-domain oscillations at each frequency and extracting G* via Fourier. Args: X: Angular frequency array (omega) [rad/s]. y: Target modulus data - can be: - Complex G* = G' + i·G'' (uses both components) - Real |G*| magnitude (fits to magnitude) **kwargs: Options including: - gamma_0: Strain amplitude (default 0.01) - n_cycles: Number of cycles per frequency (default 5) Returns: Self with fitted parameters. """ omega = jnp.asarray(X) y_arr = jnp.asarray(y) gamma_0 = kwargs.get("gamma_0", 0.01) n_cycles = kwargs.get("n_cycles", 5) # Cache protocol kwargs so model_function() can retrieve them during NUTS self._fit_gamma_0 = gamma_0 self._fit_n_cycles = n_cycles # Determine if fitting to complex, (N, 2) [G', G''], or magnitude is_complex = jnp.iscomplexobj(y_arr) is_stacked = y_arr.ndim == 2 and y_arr.shape[1] == 2 def objective(param_values): p_names = list(self.parameters.keys()) p_dict = dict(zip(p_names, param_values, strict=False)) # Predict G* at each frequency using time-domain simulation G_star_pred = self._predict_oscillation_from_params( omega, p_dict, gamma_0, n_cycles ) if is_complex: # Fit both G' and G'' by stacking residuals residuals = jnp.concatenate( [ jnp.real(G_star_pred) - jnp.real(y_arr), jnp.imag(G_star_pred) - jnp.imag(y_arr), ] ) elif is_stacked: # (N, 2) array - [G', G''] format residuals = jnp.concatenate( [ jnp.real(G_star_pred) - y_arr[:, 0], jnp.imag(G_star_pred) - y_arr[:, 1], ] ) else: # Fit to magnitude |G*| residuals = jnp.abs(G_star_pred) - jnp.abs(y_arr) return residuals nlsq_optimize(objective, self.parameters, **kwargs) return self def _predict_oscillation_from_params( self, omega: jnp.ndarray, params: dict[str, Any], gamma_0: float = 0.01, n_cycles: int = 5, ) -> jnp.ndarray: """Predict complex modulus G* from parameter dictionary. Internal method used by both NLSQ fitting and Bayesian inference. F-004/F-024: Vectorized via jax.vmap over frequencies (replaces Python loop). Args: omega: Angular frequency array. params: Parameter dictionary. gamma_0: Strain amplitude. n_cycles: Number of cycles to simulate. Returns: Complex modulus G* = G' + i·G'' for each frequency. """ from rheojax.models.fikh._kernels import ( fikh_scan_kernel_isothermal, fikh_scan_kernel_thermal, ) alpha = params.get("alpha_structure", self.alpha_structure) n_pts = 100 * n_cycles # Static slice index for last cycle extraction last_cycle_start = n_pts * (n_cycles - 1) // n_cycles n_last = n_pts - last_cycle_start # Close over params/options so only omega varies include_thermal = self.include_thermal n_history = self.n_history def predict_single_omega(w): """Compute G* at a single frequency (vmappable).""" period = 2 * jnp.pi / w t = jnp.linspace(0.0, n_cycles * period, n_pts) strain = gamma_0 * jnp.sin(w * t) if include_thermal: T_init = params.get("T_env", params.get("T_ref", 298.15)) stress, _ = fikh_scan_kernel_thermal( t, strain, n_history=n_history, alpha=alpha, use_viscosity=True, T_init=T_init, **params, ) else: stress = fikh_scan_kernel_isothermal( t, strain, n_history=n_history, alpha=alpha, use_viscosity=True, **params, ) # Extract last cycle via dynamic_slice (trace-safe) t_last = jax.lax.dynamic_slice(t, [last_cycle_start], [n_last]) stress_last = jax.lax.dynamic_slice(stress, [last_cycle_start], [n_last]) # Fourier decomposition (first harmonic) # F-034: use dt from actual time points (not T_cycle / n_last) dt = t_last[1] - t_last[0] T_cycle = t_last[-1] - t_last[0] + dt # exact integration span G_prime = (2 / (gamma_0 * T_cycle)) * jnp.trapezoid( stress_last * jnp.sin(w * t_last), dx=dt ) G_double_prime = (2 / (gamma_0 * T_cycle)) * jnp.trapezoid( stress_last * jnp.cos(w * t_last), dx=dt ) return jnp.array([G_prime, G_double_prime]) # Vectorize over all frequencies at once results = jax.vmap(predict_single_omega)(omega) # (N_omega, 2) return results[:, 0] + 1j * results[:, 1] # ========================================================================= # Prediction Methods # ========================================================================= def _predict_from_params( self, times: jnp.ndarray, strains: jnp.ndarray, params: dict[str, Any], ) -> jnp.ndarray: """Predict stress using parameter dictionary. This is the core prediction method used by both NLSQ fitting and Bayesian inference. Args: times: Time array. strains: Strain array. params: Parameter dictionary. Returns: Predicted stress array. """ from rheojax.models.fikh._kernels import ( fikh_scan_kernel_isothermal, fikh_scan_kernel_thermal, ) # Extract alpha (can now be a traced value since it's not in static_argnums) alpha = params.get("alpha_structure", self.alpha_structure) if self.include_thermal: T_init = params.get("T_env", params.get("T_ref", 298.15)) sigma_series, _ = fikh_scan_kernel_thermal( times, strains, n_history=self.n_history, alpha=alpha, use_viscosity=True, T_init=T_init, **params, ) else: sigma_series = fikh_scan_kernel_isothermal( times, strains, n_history=self.n_history, alpha=alpha, use_viscosity=True, **params, ) return sigma_series def _predict(self, X: ArrayLike, **kwargs) -> ArrayLike: """Predict based on test_mode. Args: X: Input data (shape depends on test_mode). **kwargs: Additional parameters. Returns: Predicted values. """ _kw_mode = kwargs.get("test_mode") test_mode = ( _kw_mode if _kw_mode is not None else ( getattr(self, "_test_mode", None) if getattr(self, "_test_mode", None) is not None else "startup" ) ) mode = self._validate_test_mode(test_mode) params = self._get_params_dict() if mode == TestMode.FLOW_CURVE: from rheojax.models.fikh._kernels import fikh_flow_curve_steady_state gamma_dot = jnp.asarray(X) return fikh_flow_curve_steady_state( gamma_dot, include_thermal=self.include_thermal, **params ) elif mode in (TestMode.CREEP, TestMode.RELAXATION): t = jnp.asarray(X) gamma_dot = kwargs.get("gamma_dot", 0.0) sigma_applied = kwargs.get("sigma_applied", 100.0) sigma_0 = kwargs.get("sigma_0", 60.0) T_init = kwargs.get("T_init", None) return self._simulate_transient( t, params, mode.value, gamma_dot, sigma_applied, sigma_0, T_init ) elif mode == TestMode.OSCILLATION: # Frequency-domain SAOS: X is omega, return G* omega = jnp.asarray(X) gamma_0 = kwargs.get("gamma_0", 0.01) n_cycles = kwargs.get("n_cycles", 5) return self._predict_oscillation_from_params( omega, params, gamma_0, n_cycles ) else: # Strain-driven protocols (startup, laos) times, strains = self._extract_time_strain(X, **kwargs) return self._predict_from_params(times, strains, params) # ========================================================================= # Protocol-Specific Prediction Methods # =========================================================================
[docs] def predict_flow_curve( self, gamma_dot: ArrayLike, T: float | None = None ) -> ArrayLike: """Predict steady-state flow curve. Args: gamma_dot: Shear rate array. T: Temperature (if thermal enabled). Returns: Stress array. """ return self._predict(gamma_dot, test_mode="flow_curve")
[docs] def predict_startup( self, t: ArrayLike, gamma_dot: float = 1.0, T_init: float | None = None, ) -> ArrayLike: """Predict startup shear response. Args: t: Time array. gamma_dot: Constant shear rate. T_init: Initial temperature. Returns: Stress vs time. """ params = self._get_params_dict() return self._simulate_transient( jnp.asarray(t), params, "startup", gamma_dot=gamma_dot, T_init=T_init )
[docs] def predict_relaxation( self, t: ArrayLike, sigma_0: float = 100.0, T_init: float | None = None, ) -> ArrayLike: """Predict stress relaxation. Args: t: Time array. sigma_0: Initial stress. T_init: Initial temperature. Returns: Stress vs time. """ params = self._get_params_dict() return self._simulate_transient( jnp.asarray(t), params, "relaxation", sigma_0=sigma_0, T_init=T_init )
[docs] def predict_creep( self, t: ArrayLike, sigma_applied: float = 50.0, T_init: float | None = None, ) -> ArrayLike: """Predict creep response. Args: t: Time array. sigma_applied: Applied constant stress. T_init: Initial temperature. Returns: Strain vs time. """ params = self._get_params_dict() return self._simulate_transient( jnp.asarray(t), params, "creep", sigma_applied=sigma_applied, T_init=T_init )
[docs] def predict_oscillation( self, omega: ArrayLike, gamma_0: float = 0.01, n_cycles: int = 5, ) -> jnp.ndarray: """Predict oscillatory response (SAOS G*, G', G''). For small amplitudes, this uses the linearized response. For accurate nonlinear response, use predict_laos(). Args: omega: Angular frequency array. gamma_0: Strain amplitude (should be small). n_cycles: Number of cycles to simulate. Returns: Complex modulus G* = G' + i·G'' for each frequency. """ omega_arr = jnp.asarray(omega) params = self._get_params_dict() # Reuse the vectorized implementation from _predict_oscillation_from_params return self._predict_oscillation_from_params( omega_arr, params, gamma_0, n_cycles )
[docs] def predict_laos( self, t: ArrayLike, gamma_0: float = 1.0, omega: float = 1.0, T_init: float | None = None, ) -> dict[str, jnp.ndarray]: """Predict LAOS (Large Amplitude Oscillatory Shear) response. Args: t: Time array. gamma_0: Strain amplitude. omega: Angular frequency. T_init: Initial temperature. Returns: Dictionary with 'time', 'strain', 'stress', and optionally 'temperature'. """ t_arr = jnp.asarray(t) strain = gamma_0 * jnp.sin(omega * t_arr) params = self._get_params_dict() from rheojax.models.fikh._kernels import ( fikh_scan_kernel_isothermal, fikh_scan_kernel_thermal, ) alpha = params.get("alpha_structure", self.alpha_structure) if self.include_thermal: T_0 = T_init if T_init is not None else params.get("T_env", 298.15) stress, temperature = fikh_scan_kernel_thermal( t_arr, strain, n_history=self.n_history, alpha=alpha, use_viscosity=True, T_init=T_0, **params, ) return { "time": t_arr, "strain": strain, "stress": stress, "temperature": temperature, } else: stress = fikh_scan_kernel_isothermal( t_arr, strain, n_history=self.n_history, alpha=alpha, use_viscosity=True, **params, ) return { "time": t_arr, "strain": strain, "stress": stress, }
# ========================================================================= # Bayesian Interface # =========================================================================
[docs] def model_function( self, X: ArrayLike, params: ArrayLike | dict[str, Any], test_mode: str | None = None, **kwargs, ) -> jnp.ndarray: """Model function for NumPyro Bayesian inference. This method provides a pure function interface for Bayesian sampling, capturing the test_mode via closure for correct mode-aware inference. Args: X: Input data. params: Parameter array or dictionary. test_mode: Protocol (uses stored _test_mode if None). **kwargs: Protocol-specific arguments (e.g., strain, sigma_0). Returns: Predicted values. """ # Prefer explicit test_mode; fall back to _last_fit_kwargs # (set by fit()) over stale self._test_mode to avoid wrong NUTS likelihood if test_mode is not None: mode = test_mode elif getattr(self, "_last_fit_kwargs", {}).get("test_mode") is not None: mode = self._last_fit_kwargs["test_mode"] elif self._test_mode is not None: mode = self._test_mode else: mode = "startup" # Convert array to dict if needed if isinstance(params, (np.ndarray, jnp.ndarray)): param_names = list(self.parameters.keys()) param_dict = dict(zip(param_names, params, strict=False)) else: param_dict = dict(params) mode_enum = self._validate_test_mode(mode) if mode_enum == TestMode.FLOW_CURVE: from rheojax.models.fikh._kernels import fikh_flow_curve_steady_state gamma_dot = jnp.asarray(X) return fikh_flow_curve_steady_state( gamma_dot, include_thermal=self.include_thermal, **param_dict ) elif mode_enum in (TestMode.CREEP, TestMode.RELAXATION): t = jnp.asarray(X) gamma_dot = kwargs.get("gamma_dot", getattr(self, "_fit_gamma_dot", 0.0)) sigma_applied = kwargs.get( "sigma_applied", getattr(self, "_fit_sigma_applied", 100.0) ) sigma_0 = kwargs.get("sigma_0", getattr(self, "_fit_sigma_0", 60.0)) return self._simulate_transient( t, param_dict, mode_enum.value, gamma_dot, sigma_applied, sigma_0 ) elif mode_enum == TestMode.OSCILLATION: # Frequency-domain SAOS: X is omega, return |G*| for Bayesian fitting omega = jnp.asarray(X) gamma_0 = kwargs.get("gamma_0", getattr(self, "_fit_gamma_0", 0.01)) n_cycles = kwargs.get("n_cycles", getattr(self, "_fit_n_cycles", 5)) G_star = self._predict_oscillation_from_params( omega, param_dict, gamma_0, n_cycles ) return jnp.column_stack([jnp.real(G_star), jnp.imag(G_star)]) else: # Strain-driven protocols (startup, laos) times, strains = self._extract_time_strain(X, **kwargs) return self._predict_from_params(times, strains, param_dict)
# ========================================================================= # Utility Methods # =========================================================================
[docs] def get_limiting_behavior(self) -> dict[str, Any]: """Get limiting behavior diagnostics. Returns: Dictionary with limiting cases and expected behavior. """ alpha = self.parameters.get_value("alpha_structure") E_a = self.parameters.get_value("E_a") if self.include_thermal else 0.0 return { "fractional_order": alpha, "is_near_integer": alpha > 0.95, "memory_type": ( "weak (near exponential)" if alpha > 0.7 else "strong (power-law)" ), "thermal_coupling": self.include_thermal, "arrhenius_enabled": E_a > 0 if self.include_thermal else False, "limiting_case_alpha_1": "Classical MIKH behavior", "limiting_case_E_a_0": "Isothermal FIKH behavior", }
[docs] def precompile( self, test_mode: str = "relaxation", X=None, y=None, *, n_points: int = 100, verbose: bool = True, ) -> float: """Precompile JIT kernels for faster subsequent predictions. Triggers JAX JIT compilation of the core FIKH kernels by running a small dummy prediction. This is useful when you want to avoid the compilation overhead on first real prediction. Args: test_mode: Accepted for parent compatibility (unused). X: Accepted for parent compatibility (unused). y: Accepted for parent compatibility (unused). n_points: Number of time points for dummy data. verbose: Print compilation time if True. Returns: Compilation time in seconds. Example: >>> model = FIKH(include_thermal=True) >>> compile_time = model.precompile() # Triggers JIT >>> # Now predictions will be fast >>> sigma = model.predict_startup(t_real, gamma_dot=1.0) """ import time as time_module # Create dummy data t_dummy = jnp.linspace(0, 10, n_points) strain_dummy = 0.1 * t_dummy # Linear ramp params = self._get_params_dict() start = time_module.perf_counter() # Trigger isothermal kernel compilation from rheojax.models.fikh._kernels import ( fikh_scan_kernel_isothermal, fikh_scan_kernel_thermal, ) alpha = params.get("alpha_structure", self.alpha_structure) # Always compile isothermal kernel _ = fikh_scan_kernel_isothermal( t_dummy, strain_dummy, n_history=self.n_history, alpha=alpha, use_viscosity=True, **params, ) # Compile thermal kernel if enabled if self.include_thermal: T_init = params.get("T_env", params.get("T_ref", 298.15)) _ = fikh_scan_kernel_thermal( t_dummy, strain_dummy, n_history=self.n_history, alpha=alpha, use_viscosity=True, T_init=T_init, **params, ) elapsed = time_module.perf_counter() - start if verbose: logger.info( "FIKH kernels precompiled", compile_time_s=f"{elapsed:.2f}", include_thermal=self.include_thermal, ) return elapsed
[docs] def __repr__(self) -> str: """String representation.""" alpha = self.parameters.get_value("alpha_structure") return ( f"FIKH(include_thermal={self.include_thermal}, " f"alpha_structure={alpha:.3f}, n_history={self.n_history})" )