"""ITT-MCT Schematic F₁₂ Model.
The F₁₂ schematic model is a simplified Mode-Coupling Theory that captures
the essential physics of the glass transition with minimal parameters:
- Glass transition at v₂ = 4 (for v₁ = 0)
- Yield stress in glass state (ε > 0)
- Shear thinning from cage breaking
- Two-step relaxation (β and α processes)
Parameters
----------
v1 : float
Linear vertex coefficient (typically 0)
v2 : float
Quadratic vertex coefficient (glass transition at v₂_c = 4)
Gamma : float
Bare relaxation rate (1/s)
gamma_c : float
Critical strain for cage breaking (dimensionless)
G_inf : float
High-frequency modulus (Pa)
epsilon : float
Separation parameter ε = (v₂ - v₂_c)/v₂_c
References
----------
Götze W. (2009) "Complex Dynamics of Glass-Forming Liquids", Chapter 4
Fuchs M. & Cates M.E. (2002) Phys. Rev. Lett. 89, 248304
"""
from __future__ import annotations
from typing import Any, Literal
import numpy as np
from scipy.integrate import solve_ivp
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger
from rheojax.models.itt_mct._base import ITTMCTBase
from rheojax.models.itt_mct._kernels import (
compute_complex_modulus_from_correlator,
extract_laos_harmonics,
f12_equilibrium_correlator_rhs,
f12_memory,
f12_volterra_creep_rhs,
f12_volterra_laos_rhs,
f12_volterra_relaxation_rhs,
f12_volterra_startup_rhs,
)
from rheojax.utils.mct_kernels import (
glass_transition_criterion,
)
# Optional accelerated backend: diffrax-based batched flow-curve solver.
# When the import (or the availability probe) fails, flow curves fall back to
# the scipy path and a no-op precompile stub is installed below.
try:
    from rheojax.models.itt_mct._kernels_diffrax import (
        is_diffrax_available,
        precompile_flow_curve_solver,
        solve_flow_curve_batch,
    )

    _HAS_DIFFRAX = is_diffrax_available()
except ImportError:
    _HAS_DIFFRAX = False

    def precompile_flow_curve_solver(*_args, **_kwargs):  # type: ignore[misc]
        """No-op replacement used when diffrax cannot be imported.

        Returns 0.0, i.e. "no compilation time spent".
        """
        return 0.0
# JAX module handles from the project's safe importer; `jnp` is used below to
# build the arrays handed to the MCT kernel functions.
jax, jnp = safe_import_jax()
# Module-level logger (used e.g. by `precompile` when diffrax is missing).
logger = get_logger(__name__)
[docs]
# Registered under "itt_mct_schematic"; the protocol list mirrors the
# `_predict_*` methods implemented on this class.
@ModelRegistry.register(
    "itt_mct_schematic",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.OSCILLATION,
        Protocol.STARTUP,
        Protocol.CREEP,
        Protocol.RELAXATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class ITTMCTSchematic(ITTMCTBase):
    """ITT-MCT Schematic F₁₂ Model.

    The F₁₂ model uses a quadratic memory kernel m(Φ) = v₁Φ + v₂Φ²
    to describe the cage effect in dense colloidal suspensions.
    The glass transition occurs when the non-ergodicity parameter f
    (long-time limit of Φ) becomes non-zero, which happens at v₂ = v₂_c = 4
    for v₁ = 0.

    Parameters
    ----------
    epsilon : float, optional
        Separation parameter. If provided, v₂ is set to achieve this ε.
        Mutually exclusive with ``v2``.
        ε < 0: fluid state
        ε = 0: critical point
        ε > 0: glass state
    v2 : float, optional
        Quadratic vertex coefficient. Alternative to epsilon (mutually
        exclusive with it).
    integration_method : {"volterra", "history"}, default "volterra"
        Integration method for memory kernel
    n_prony_modes : int, default 10
        Number of Prony modes for Volterra integration
    decorrelation_form : {"gaussian", "lorentzian"}, default "gaussian"
        Strain decorrelation function form:
        - "gaussian": h(γ) = exp(-(γ/γ_c)²) - faster exponential decay
        - "lorentzian": h(γ) = 1/(1+(γ/γ_c)²) - slower algebraic decay
    memory_form : {"simplified", "full"}, default "simplified"
        Memory kernel form:
        - "simplified": single decorrelation m(Φ) = h[γ_acc] × (v₁Φ + v₂Φ²)
        - "full": two-time decorrelation m(t,s,t₀) = h[γ(t,t₀)] × h[γ(t,s)] × (v₁Φ + v₂Φ²)
    stress_form : {"schematic", "microscopic"}, default "schematic"
        Stress computation form:
        - "schematic": σ = G_∞ × γ̇ × ∫ Φ² × h(γ) dt (standard schematic)
        - "microscopic": σ = (k_BT/60π²) × ∫dk k⁴ [S'/S²]² Φ² (structure factor weighted)
    phi_volume : float, optional
        Volume fraction for Percus-Yevick S(k). Required if stress_form="microscopic".
    k_BT : float, default 1.0
        Thermal energy k_B × T in Joules. Default 1.0 gives dimensionless stress.

    Attributes
    ----------
    parameters : ParameterSet
        Model parameters with the following:
        - v1: Linear vertex (default 0)
        - v2: Quadratic vertex (default 2.0, i.e. ε = -0.5 for v1 = 0: fluid state)
        - Gamma: Bare relaxation rate (default 1.0 s⁻¹)
        - gamma_c: Critical strain (default 0.1)
        - G_inf: High-frequency modulus (default 1e6 Pa)

    Examples
    --------
    >>> model = ITTMCTSchematic(epsilon=-0.1)  # Fluid state
    >>> model.get_glass_transition_info()
    {'is_glass': False, 'epsilon': -0.1, ...}
    >>> model = ITTMCTSchematic(epsilon=0.05)  # Glass state
    >>> sigma = model.predict(np.logspace(-3, 2, 50), test_mode='flow_curve')
    >>> # Shows yield stress at low rates
    >>> # Use Lorentzian decorrelation for materials with extended yielding
    >>> model = ITTMCTSchematic(epsilon=0.05, decorrelation_form="lorentzian")
    >>> # Use full two-time memory kernel (Fuchs & Cates 2002)
    >>> model = ITTMCTSchematic(epsilon=0.05, memory_form="full")
    >>> # Use microscopic stress with Percus-Yevick S(k)
    >>> model = ITTMCTSchematic(
    ...     epsilon=0.05,
    ...     stress_form="microscopic",
    ...     phi_volume=0.5,
    ...     k_BT=4.11e-21,  # Room temperature
    ... )
    """
[docs]
def __init__(
    self,
    epsilon: float | None = None,
    v2: float | None = None,
    integration_method: Literal["volterra", "history"] = "volterra",
    n_prony_modes: int = 10,
    decorrelation_form: Literal["gaussian", "lorentzian"] = "gaussian",
    memory_form: Literal["simplified", "full"] = "simplified",
    stress_form: Literal["schematic", "microscopic"] = "schematic",
    phi_volume: float | None = None,
    k_BT: float = 1.0,
):
    """Initialize F₁₂ Schematic Model.

    Parameters
    ----------
    epsilon : float, optional
        Separation parameter ε = (v₂ - v₂_c)/v₂_c.
        Mutually exclusive with v2.
    v2 : float, optional
        Direct vertex coefficient. Mutually exclusive with epsilon.
    integration_method : str, default "volterra"
        Integration method for memory kernel
    n_prony_modes : int, default 10
        Number of Prony modes
    decorrelation_form : {"gaussian", "lorentzian"}, default "gaussian"
        Form of the strain decorrelation function h(γ):
        - "gaussian": h(γ) = exp(-(γ/γ_c)²) - faster decay (default, Fuchs & Cates 2002)
        - "lorentzian": h(γ) = 1/(1+(γ/γ_c)²) - slower algebraic decay (Brader et al. 2008)
    memory_form : {"simplified", "full"}, default "simplified"
        Memory kernel form:
        - "simplified": single decorrelation m(Φ) = h[γ_acc] × (v₁Φ + v₂Φ²)
        - "full": two-time decorrelation m(t,s,t₀) = h[γ(t,t₀)] × h[γ(t,s)] × (v₁Φ + v₂Φ²)
    stress_form : {"schematic", "microscopic"}, default "schematic"
        Stress computation form:
        - "schematic": σ = G_∞ × γ̇ × ∫ Φ² × h(γ) dt (standard schematic)
        - "microscopic": σ = (k_BT/60π²) × ∫dk k⁴ [S'/S²]² Φ² (structure factor weighted)
    phi_volume : float, optional
        Volume fraction for Percus-Yevick S(k). Required if stress_form="microscopic".
    k_BT : float, default 1.0
        Thermal energy k_B × T in Joules. Default 1.0 gives dimensionless stress.
        Use 4.11e-21 J for T=298K with real units.

    Raises
    ------
    ValueError
        If both ``epsilon`` and ``v2`` are given, if a form argument is not
        one of its allowed values, or if ``stress_form="microscopic"`` is
        requested without ``phi_volume``.
    """
    # Reject conflicting state specifications before doing any work: the
    # microscopic prefactor and the base-class initialisation are not free,
    # and there is no point paying for them on an invalid call.
    if epsilon is not None and v2 is not None:
        raise ValueError("Specify either epsilon or v2, not both")

    # Store initialization parameters before parent __init__
    self._init_epsilon = epsilon
    self._init_v2 = v2

    # Validate decorrelation form
    if decorrelation_form not in ("gaussian", "lorentzian"):
        raise ValueError(
            f"decorrelation_form must be 'gaussian' or 'lorentzian', got {decorrelation_form!r}"
        )
    self._use_lorentzian = decorrelation_form == "lorentzian"
    self._decorrelation_form = decorrelation_form

    # Validate memory form
    if memory_form not in ("simplified", "full"):
        raise ValueError(
            f"memory_form must be 'simplified' or 'full', got {memory_form!r}"
        )
    self._memory_form = memory_form

    # Validate stress form
    if stress_form not in ("schematic", "microscopic"):
        raise ValueError(
            f"stress_form must be 'schematic' or 'microscopic', got {stress_form!r}"
        )
    if stress_form == "microscopic" and phi_volume is None:
        raise ValueError("phi_volume is required when stress_form='microscopic'")
    self._stress_form = stress_form
    self._phi_volume = phi_volume
    self._k_BT = k_BT

    # Pre-compute microscopic stress prefactor if needed; it depends only on
    # (phi_volume, k_BT), which are fixed for the lifetime of the instance.
    self._microscopic_stress_prefactor = None
    if stress_form == "microscopic":
        from rheojax.utils.mct_kernels import get_microscopic_stress_prefactor

        assert phi_volume is not None
        self._microscopic_stress_prefactor = get_microscopic_stress_prefactor(
            phi_volume, k_BT=k_BT
        )

    super().__init__(
        integration_method=integration_method,
        n_prony_modes=n_prony_modes,
    )

    # Track physics params for Prony cache invalidation
    self._prony_param_hash: tuple[float, ...] | None = None

    # Set v2 from epsilon or direct value (v2_c depends on the current v1).
    v1 = self.parameters.get_value("v1")
    assert v1 is not None
    v2_critical = self._get_v2_critical(v1)
    if epsilon is not None:
        v2_value = v2_critical * (1 + epsilon)
        self.parameters.set_value("v2", v2_value)
    elif v2 is not None:
        self.parameters.set_value("v2", v2)
def _setup_parameters(self) -> None:
    """Create the F₁₂ ParameterSet.

    Registers, in this order, the vertex coefficients ``v1`` and ``v2``, the
    bare relaxation rate ``Gamma``, the cage-breaking strain ``gamma_c`` and
    the high-frequency modulus ``G_inf``, each with its default value,
    bounds, units and description.
    """
    # (name, default, bounds, units, description)
    specs = (
        ("v1", 0.0, (0.0, 5.0), "-",
         "Linear vertex coefficient (typically 0 for F₁₂)"),
        # Default v2 = 2.0 places the model in the fluid state.
        ("v2", 2.0, (0.5, 10.0), "-",
         "Quadratic vertex coefficient (glass at v₂ > 4)"),
        ("Gamma", 1.0, (1e-6, 1e6), "1/s",
         "Bare relaxation rate"),
        ("gamma_c", 0.1, (0.01, 0.5), "-",
         "Critical strain for cage breaking"),
        ("G_inf", 1e6, (1.0, 1e12), "Pa",
         "High-frequency elastic modulus"),
    )
    self.parameters = ParameterSet()
    for name, default, bounds, units, description in specs:
        self.parameters.add(
            name=name,
            value=default,
            bounds=bounds,
            units=units,
            description=description,
        )
def _get_v2_critical(self, v1: float) -> float:
    """Return the critical v₂ on the F₁₂ glass-transition (A₂) line.

    The F₁₂ non-ergodicity parameter solves f/(1-f) = v₁f + v₂f². Its
    discontinuous (A₂) transition line is parametrised by the exponent
    parameter λ as v₁ = (2λ-1)/λ², v₂ = 1/λ². Eliminating λ gives

        v₂_c(v₁) = (1 + √(1 - v₁))²,   v₁ ≤ 1,

    which reduces to v₂_c = 4 at v₁ = 0 and to v₂_c = 1 at v₁ = 1, where the
    A₂ line terminates.

    For v₁ > 1 a non-zero solution of the fixed-point equation exists for
    every v₂ ≥ 0, so there is no finite A₂ crossing. The endpoint value 1.0
    is returned there so that ε stays finite and continuous, but ε is then
    not a meaningful distance to the transition.

    Parameters
    ----------
    v1 : float
        Linear vertex coefficient

    Returns
    -------
    float
        Critical v₂ value (always positive)
    """
    # Exact value for the standard F₂-like case.
    if abs(v1) < 1e-10:
        return 4.0
    # Beyond the A₂ endpoint: clamp to the endpoint value (see docstring).
    if v1 >= 1.0:
        return 1.0
    return float((1.0 + np.sqrt(1.0 - v1)) ** 2)
def _compute_equilibrium_correlator(
    self,
    t: jnp.ndarray,
) -> jnp.ndarray:
    """Compute equilibrium (quiescent) correlator Φ_eq(t).

    Solves the MCT equation without shear:
        ∂Φ/∂t + Γ[Φ + ∫₀^t m(Φ) ∂Φ/∂s ds] = 0,   Φ(0) = 1,
    with the memory integral carried by Prony modes (state vector
    [Φ, K₁, ..., K_n]). If no Prony modes exist yet, a bootstrap set is
    created and stored on the instance.

    Parameters
    ----------
    t : jnp.ndarray
        Output times. They are passed to ``solve_ivp`` as ``t_eval`` and
        therefore must be non-negative and increasing.

    Returns
    -------
    jnp.ndarray
        Equilibrium correlator Φ_eq(t), clipped to [0, 1], one value per
        entry of ``t``.

    Raises
    ------
    RuntimeError
        If the ODE integration fails. A failed ``solve_ivp`` returns fewer
        points than requested, which would silently misalign Φ_eq with ``t``.
    """
    v1 = self.parameters.get_value("v1")
    v2 = self.parameters.get_value("v2")
    Gamma = self.parameters.get_value("Gamma")
    assert v1 is not None
    assert v2 is not None
    assert Gamma is not None

    t_np = np.asarray(t, dtype=float)
    t_max = float(t_np.max())

    # Get or initialize Prony modes (bootstrap with power-law seed)
    if self._prony_amplitudes is None:
        tau_modes = np.logspace(-3, np.log10(t_max), self.n_prony_modes)
        # Power-law taper preserves weight at ALL timescales,
        # critical for MCT glass states where Φ_eq has a non-ergodic
        # plateau. Exponent a ≈ 0.3 matches MCT critical dynamics.
        g_modes = tau_modes ** (-0.3)
        g_modes /= g_modes.sum()  # Normalize
        self._prony_amplitudes = g_modes
        self._prony_times = tau_modes

    g = jnp.array(self._prony_amplitudes)
    tau = jnp.array(self._prony_times)
    n_modes = self.n_prony_modes

    # Initial state: [Φ, K₁, K₂, ..., K_n]
    state0 = np.zeros(1 + n_modes)
    state0[0] = 1.0  # Φ(0) = 1

    def rhs_numpy(t_val, state):
        """NumPy adapter around the JAX right-hand side for solve_ivp."""
        deriv = f12_equilibrium_correlator_rhs(
            jnp.array(state), t_val, v1, v2, Gamma, g, tau, n_modes
        )
        return np.array(deriv)

    sol = solve_ivp(
        rhs_numpy,
        [0, t_max],
        state0,
        t_eval=t_np,
        method="RK45",
        rtol=1e-8,
        atol=1e-10,
    )
    if not sol.success:
        raise RuntimeError(
            f"Equilibrium correlator integration failed: {sol.message}"
        )

    # Extract correlator and enforce physical bounds
    phi_eq = jnp.array(sol.y[0, :])
    return jnp.clip(phi_eq, 0.0, 1.0)
def _compute_memory_kernel(
    self,
    phi: jnp.ndarray,
) -> jnp.ndarray:
    """Evaluate the F₁₂ memory kernel m(Φ) = v₁Φ + v₂Φ².

    Parameters
    ----------
    phi : jnp.ndarray
        Correlator values at which to evaluate the kernel.

    Returns
    -------
    jnp.ndarray
        Kernel values, computed by ``f12_memory`` with the current v₁, v₂.
    """
    lookup = self.parameters.get_value
    return f12_memory(phi, lookup("v1"), lookup("v2"))
[docs]
def get_glass_transition_info(self) -> dict[str, Any]:
    """Report where the current (v₁, v₂) sit relative to the glass transition.

    Delegates to ``glass_transition_criterion`` with the current vertex
    coefficients.

    Returns
    -------
    dict
        Glass transition properties:
        - is_glass: bool
        - epsilon: separation parameter
        - v2_critical: critical v₂ value
        - f_neq: non-ergodicity parameter
        - lambda_exponent: MCT exponent parameter
    """
    linear, quadratic = (
        self.parameters.get_value(key) for key in ("v1", "v2")
    )
    assert linear is not None and quadratic is not None
    return glass_transition_criterion(linear, quadratic)
@property
def epsilon(self) -> float:
    """Separation parameter ε = (v₂ - v₂_c)/v₂_c, with v₂_c taken at the current v₁."""
    params = self.parameters
    v1 = params.get_value("v1")
    v2 = params.get_value("v2")
    assert v1 is not None
    assert v2 is not None
    critical = self._get_v2_critical(v1)
    return (v2 - critical) / critical

@epsilon.setter
def epsilon(self, value: float) -> None:
    """Move v₂ so that the separation parameter equals ``value``."""
    v1 = self.parameters.get_value("v1")
    assert v1 is not None
    self.parameters.set_value("v2", self._get_v2_critical(v1) * (1 + value))
# =========================================================================
# Protocol Implementations
# =========================================================================
def _predict_flow_curve(
    self,
    gamma_dot: np.ndarray,
    use_diffrax: bool | None = None,
    **kwargs,
) -> np.ndarray:
    """Predict steady-state flow curve σ(γ̇).

    Parameters
    ----------
    gamma_dot : np.ndarray
        Shear rate array (1/s). Converted to float, so integer input does not
        produce integer (truncated) stresses.
    use_diffrax : bool, optional
        Request diffrax (True) or scipy (False).
        If None (default), uses diffrax when available. If True is requested
        but diffrax is unavailable, a warning is logged and scipy is used.

    Returns
    -------
    np.ndarray
        Steady-state stress σ (Pa), float dtype
    """
    gamma_dot = np.asarray(gamma_dot, dtype=float)

    # Get parameters
    v1 = self.parameters.get_value("v1")
    v2 = self.parameters.get_value("v2")
    Gamma = self.parameters.get_value("Gamma")
    gamma_c = self.parameters.get_value("gamma_c")
    G_inf = self.parameters.get_value("G_inf")
    assert v1 is not None
    assert v2 is not None
    assert Gamma is not None
    assert gamma_c is not None
    assert G_inf is not None

    # Invalidate Prony cache if physics params changed, then init
    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()
    g = self._prony_amplitudes
    tau = self._prony_times
    assert g is not None
    assert tau is not None

    # Determine which solver to use
    if use_diffrax and not _HAS_DIFFRAX:
        logger.warning(
            "use_diffrax=True requested but diffrax is not available; "
            "falling back to the scipy flow-curve solver"
        )
    should_use_diffrax = use_diffrax if use_diffrax is not None else _HAS_DIFFRAX

    if should_use_diffrax and _HAS_DIFFRAX:
        return self._predict_flow_curve_diffrax(
            gamma_dot, v1, v2, Gamma, gamma_c, G_inf, g, tau
        )
    return self._predict_flow_curve_scipy(
        gamma_dot, v1, v2, Gamma, gamma_c, G_inf, g, tau
    )
def _predict_flow_curve_diffrax(
    self,
    gamma_dot: np.ndarray,
    v1: float,
    v2: float,
    Gamma: float,
    gamma_c: float,
    G_inf: float,
    g: np.ndarray,
    tau: np.ndarray,
) -> np.ndarray:
    """Fast flow curve prediction using diffrax + vmap.

    First call triggers JIT compilation; subsequent calls are fast.

    Zero shear rates (γ̇ < 1e-15) are handled analytically: the yield-stress
    estimate G_eff·γ_c·f² in the glass state, zero otherwise. All other rates
    are solved in one batched diffrax call. Any rate for which diffrax
    returns NaN is recomputed with the scipy single-rate solver.

    Returns
    -------
    np.ndarray
        Steady-state stress, float dtype, same shape as ``gamma_dot``.
    """
    # Float dtype: with an integer rate array, `zeros_like` would create an
    # integer stress array and silently truncate every assigned stress.
    gamma_dot = np.asarray(gamma_dot, dtype=float)

    # Handle zero shear rates separately (yield stress)
    mask_zero = gamma_dot < 1e-15
    mask_nonzero = ~mask_zero

    # Use microscopic prefactor if stress_form is microscopic
    G_eff = G_inf
    if (
        self._stress_form == "microscopic"
        and self._microscopic_stress_prefactor is not None
    ):
        G_eff = self._microscopic_stress_prefactor

    sigma = np.zeros(gamma_dot.shape, dtype=float)

    # Zero shear rate: yield stress if glass (sigma stays 0 for a fluid)
    if np.any(mask_zero):
        info = self.get_glass_transition_info()
        if info["is_glass"]:
            f_neq = info["f_neq"]
            sigma[mask_zero] = G_eff * gamma_c * f_neq * f_neq

    # Non-zero shear rates: batched diffrax solve
    if np.any(mask_nonzero):
        gamma_dot_nonzero = gamma_dot[mask_nonzero]
        # The ODE captures the yield stress through the non-zero plateau of
        # the advected correlator in the glass state, so no post-ODE sigma_y
        # term is added (matches the scipy path).
        sigma_nonzero = solve_flow_curve_batch(
            jnp.asarray(gamma_dot_nonzero),
            v1,
            v2,
            Gamma,
            gamma_c,
            G_eff,  # effective modulus (G_inf or microscopic prefactor)
            jnp.asarray(g),
            jnp.asarray(tau),
            self.n_prony_modes,
            self._use_lorentzian,
            self._memory_form,
        )
        # Writable float copy (np.asarray of a JAX array is read-only).
        sigma_arr = np.array(sigma_nonzero, dtype=float)

        # NaN fallback: diffrax may return NaN at low γ̇ in the glass state
        # (e.g. when its solver exhausts its step budget on stiff Prony
        # modes). scipy's solve_ivp has no step-count cap, so the slower
        # single-rate RK45 path can still finish these points.
        for idx in np.flatnonzero(np.isnan(sigma_arr)):
            sigma_arr[idx] = self._compute_steady_state_stress(
                float(gamma_dot_nonzero[idx])
            )

        sigma[mask_nonzero] = sigma_arr

    return sigma
def _predict_flow_curve_scipy(
    self,
    gamma_dot: np.ndarray,
    v1: float,
    v2: float,
    Gamma: float,
    gamma_c: float,
    G_inf: float,
    g: np.ndarray,
    tau: np.ndarray,
) -> np.ndarray:
    """Flow-curve prediction by per-rate scipy integration (fallback).

    Warning: much slower than the diffrax version, because each non-zero
    shear rate is integrated separately by ``_compute_steady_state_stress``.

    Parameters
    ----------
    gamma_dot : np.ndarray
        Shear rates (1/s); any shape, including a 0-d scalar.
    v1, v2, Gamma, g, tau
        Accepted for signature parity with the diffrax path; the per-rate
        solver re-reads model parameters and Prony modes itself.
    gamma_c : float
        Critical strain, used for the zero-rate yield-stress estimate.
    G_inf : float
        High-frequency modulus; replaced by the microscopic prefactor when
        ``stress_form="microscopic"``.

    Returns
    -------
    np.ndarray
        Steady-state stress, float dtype, same shape as ``gamma_dot``.
    """
    # Use microscopic prefactor if stress_form is microscopic
    G_eff = G_inf
    if (
        self._stress_form == "microscopic"
        and self._microscopic_stress_prefactor is not None
    ):
        G_eff = self._microscopic_stress_prefactor

    # Flat float working copy: integer input must not truncate the stresses,
    # and a 0-d input must still be iterable.
    rates = np.asarray(gamma_dot, dtype=float)
    flat_rates = rates.ravel()
    sigma = np.zeros(flat_rates.shape, dtype=float)

    zero_mask = flat_rates < 1e-15
    if np.any(zero_mask):
        # Zero shear rate: approximate yield stress if glass, zero if fluid.
        # The glass-state info does not depend on the rate: query it once.
        info = self.get_glass_transition_info()
        if info["is_glass"]:
            f_neq = info["f_neq"]
            sigma[zero_mask] = G_eff * gamma_c * f_neq * f_neq

    # Non-zero rates: integrate each one to steady state
    for i in np.flatnonzero(~zero_mask):
        sigma[i] = self._compute_steady_state_stress(flat_rates[i])

    return sigma.reshape(rates.shape)
def _compute_steady_state_stress(
self,
gamma_dot: float,
t_max: float | None = None,
) -> float:
"""Compute steady-state stress at a single shear rate.
The stress is computed as the time integral:
σ = G_eff * γ̇ * ∫₀^∞ Φ(t) * h(γ(t)) dt
Parameters
----------
gamma_dot : float
Shear rate
t_max : float, optional
Maximum integration time. If None, uses adaptive time.
Returns
-------
float
Steady-state stress
"""
v1 = self.parameters.get_value("v1")
v2 = self.parameters.get_value("v2")
Gamma = self.parameters.get_value("Gamma")
gamma_c = self.parameters.get_value("gamma_c")
G_inf = self.parameters.get_value("G_inf")
assert v1 is not None
assert v2 is not None
assert Gamma is not None
assert gamma_c is not None
assert G_inf is not None
# Use microscopic prefactor if stress_form is microscopic
G_eff = G_inf
if (
self._stress_form == "microscopic"
and self._microscopic_stress_prefactor is not None
):
G_eff = self._microscopic_stress_prefactor
self._check_prony_cache()
if self._prony_amplitudes is None:
self.initialize_prony_modes()
assert self._prony_amplitudes is not None
assert self._prony_times is not None
g = jnp.array(self._prony_amplitudes)
tau = jnp.array(self._prony_times)
# Adaptive integration time
if t_max is None:
tau_bare = 1.0 / Gamma
tau_shear = gamma_c / max(gamma_dot, 1e-10)
tau_eff = min(tau_bare, tau_shear)
t_max = 50.0 * tau_eff
t_max = max(10.0, min(t_max, 500.0))
# Initial state: [Φ, K₁..K_n, γ, σ_integral]
state0 = np.zeros(3 + self.n_prony_modes)
state0[0] = 1.0 # Φ(0) = 1
# γ(0) = 0, σ_integral(0) = 0
use_full_memory = self._memory_form == "full"
def rhs_numpy(t_val, state):
# Extract state
phi = state[0]
K = state[1 : 1 + self.n_prony_modes]
gamma_acc = state[1 + self.n_prony_modes]
# Strain decorrelation (use model's decorrelation form)
if self._use_lorentzian:
h_gamma = 1.0 / (1.0 + (gamma_acc / gamma_c) ** 2)
else:
h_gamma = np.exp(-((gamma_acc / gamma_c) ** 2))
phi_advected = phi * h_gamma
# Memory kernel
m_phi = v1 * phi_advected + v2 * phi_advected * phi_advected
# Memory integral from Prony modes
memory_integral = np.sum(K)
# MCT equation
dphi_dt = -Gamma * (phi + memory_integral)
# Prony mode evolution with memory form
if use_full_memory:
# Full two-time: mode-specific decorrelation
gamma_mode = gamma_dot * np.asarray(tau)
if self._use_lorentzian:
h_mode = 1.0 / (1.0 + (gamma_mode / gamma_c) ** 2)
else:
h_mode = np.exp(-((gamma_mode / gamma_c) ** 2))
dK_dt = -K / np.asarray(tau) + np.asarray(g) * m_phi * h_mode * dphi_dt
else:
dK_dt = -K / np.asarray(tau) + np.asarray(g) * m_phi * dphi_dt
# Strain accumulation
dgamma_dt = gamma_dot
# Stress integrand: d(σ_integral)/dt = G_eff × γ̇ × Φ_adv² (Fuchs & Cates 2002)
dsigma_dt = G_eff * gamma_dot * phi_advected * phi_advected
return np.concatenate([[dphi_dt], dK_dt, [dgamma_dt, dsigma_dt]])
# Integrate ODE
t_span = [0, t_max]
sol = solve_ivp(rhs_numpy, t_span, state0, method="RK45", rtol=1e-5, atol=1e-7)
# Extract stress integral from final state
sigma = sol.y[2 + self.n_prony_modes, -1]
return float(sigma)
def _predict_oscillation(
    self,
    omega: np.ndarray,
    return_components: bool = False,
    **kwargs,
) -> np.ndarray:
    """Predict linear viscoelastic moduli G*(ω).

    Parameters
    ----------
    omega : np.ndarray
        Angular frequency (rad/s); all entries must be strictly positive.
    return_components : bool, default False
        If True, return (G', G'') as shape (n, 2)

    Returns
    -------
    np.ndarray
        Complex modulus G* = G' + iG'' by default.
        If return_components=True, returns (n, 2) array [G', G''].
        Empty input gives an empty result of the corresponding kind.

    Raises
    ------
    ValueError
        If any frequency is zero, negative or NaN. The time grid is built
        from 100/ω_min, which is infinite or meaningless in that case.
    """
    omega = np.asarray(omega, dtype=float)
    if omega.size == 0:
        return np.empty((0, 2)) if return_components else np.empty(0, dtype=complex)
    if not np.all(omega > 0):
        raise ValueError("omega must contain only strictly positive frequencies (rad/s)")

    G_inf = self.parameters.get_value("G_inf")
    assert G_inf is not None

    # Ensure Prony modes are refined (not just bootstrap seed).
    # Without this, _compute_equilibrium_correlator uses crude bootstrap
    # modes → misses glass plateau → G'(ω) shows ω² instead of plateau.
    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()

    # Cover several periods of the slowest frequency. The floor keeps the
    # log grid (which starts at 1e-4 s) increasing for very high frequencies,
    # as required by solve_ivp's t_eval.
    omega_min = float(omega.min())
    t_max = max(100.0 / omega_min, 1e-3)
    t = np.logspace(-4, np.log10(t_max), 2000)

    # Compute equilibrium correlator (uses refined Prony modes)
    phi_eq = np.array(self._compute_equilibrium_correlator(jnp.array(t)))

    # Compute G*(ω) via Fourier transform
    G_prime, G_double_prime = compute_complex_modulus_from_correlator(
        jnp.array(omega),
        jnp.array(t),
        jnp.array(phi_eq),
        G_inf,
    )
    G_prime = np.array(G_prime)
    G_double_prime = np.array(G_double_prime)

    if return_components:
        return np.column_stack([G_prime, G_double_prime])
    return G_prime + 1j * G_double_prime
def _predict_startup(
    self,
    t: np.ndarray,
    gamma_dot: float = 1.0,
    **kwargs,
) -> np.ndarray:
    """Stress growth σ(t) after the sudden onset of a constant shear rate.

    Parameters
    ----------
    t : np.ndarray
        Output times (s), used as ``solve_ivp``'s ``t_eval``.
    gamma_dot : float, default 1.0
        Imposed shear rate (1/s)

    Returns
    -------
    np.ndarray
        Stress σ(t) (Pa), the last component of the integrated state.
    """
    t = np.asarray(t)
    v1, v2, Gamma, gamma_c, G_inf = (
        self.parameters.get_value(key)
        for key in ("v1", "v2", "Gamma", "gamma_c", "G_inf")
    )
    for value in (v1, v2, Gamma, gamma_c, G_inf):
        assert value is not None

    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()
    assert self._prony_amplitudes is not None
    assert self._prony_times is not None
    amplitudes = jnp.array(self._prony_amplitudes)
    relaxation_times = jnp.array(self._prony_times)
    n_modes = self.n_prony_modes

    # State layout [Φ, K₁..K_n, γ, σ]: everything at rest except Φ(0) = 1.
    initial = np.zeros(n_modes + 3)
    initial[0] = 1.0

    def _rhs(time, y):
        derivative = f12_volterra_startup_rhs(
            jnp.array(y),
            time,
            gamma_dot,
            v1,
            v2,
            Gamma,
            gamma_c,
            G_inf,
            amplitudes,
            relaxation_times,
            n_modes,
            self._use_lorentzian,
            memory_form=self._memory_form,
        )
        return np.array(derivative)

    solution = solve_ivp(
        _rhs,
        [0, t.max()],
        initial,
        t_eval=t,
        method="RK45",
        rtol=1e-6,
        atol=1e-8,
    )
    return solution.y[-1, :]
def _predict_creep(
    self,
    t: np.ndarray,
    sigma_applied: float = 1.0,
    **kwargs,
) -> np.ndarray:
    """Predict creep compliance J(t).

    Parameters
    ----------
    t : np.ndarray
        Time array (s), used as ``solve_ivp``'s ``t_eval``.
    sigma_applied : float, default 1.0
        Applied stress σ₀ (Pa); must be non-zero.

    Returns
    -------
    np.ndarray
        Creep compliance J(t) = γ(t)/σ₀ (1/Pa)

    Raises
    ------
    ValueError
        If ``sigma_applied`` is zero (J = γ/σ₀ is undefined).
    """
    # Fail before the ODE solve instead of returning NaN/inf compliance.
    if sigma_applied == 0:
        raise ValueError("sigma_applied must be non-zero: creep compliance J = γ/σ₀ is undefined for σ₀ = 0")

    t = np.asarray(t)
    v1 = self.parameters.get_value("v1")
    v2 = self.parameters.get_value("v2")
    Gamma = self.parameters.get_value("Gamma")
    gamma_c = self.parameters.get_value("gamma_c")
    G_inf = self.parameters.get_value("G_inf")
    assert v1 is not None
    assert v2 is not None
    assert Gamma is not None
    assert gamma_c is not None
    assert G_inf is not None

    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()
    assert self._prony_amplitudes is not None
    assert self._prony_times is not None
    g = jnp.array(self._prony_amplitudes)
    tau = jnp.array(self._prony_times)

    # Initial state: [Φ, K₁..K_n, γ, γ̇]
    # R10-SCH-001: correct creep IC.
    # At t=0⁺ there is an instantaneous elastic strain jump γ_e(0) = σ₀/G_inf
    # (the "elastic jump" in a Maxwell-type creep experiment).
    # The shear rate γ̇ starts at 0 — it evolves from the constraint equation.
    # Setting state0[-1] = sigma_applied/G_inf was assigning this elastic strain
    # to γ̇ instead of γ, giving a physically wrong initial condition.
    state0 = np.zeros(3 + self.n_prony_modes)
    state0[0] = 1.0  # Φ(0) = 1
    state0[-2] = sigma_applied / G_inf  # γ(0) = σ₀/G_inf (elastic jump)
    state0[-1] = 0.0  # γ̇(0) = 0 (rate starts from rest)

    def rhs_numpy(t_val, state):
        deriv = f12_volterra_creep_rhs(
            jnp.array(state),
            t_val,
            sigma_applied,
            v1,
            v2,
            Gamma,
            gamma_c,
            G_inf,
            g,
            tau,
            self.n_prony_modes,
            self._use_lorentzian,
            memory_form=self._memory_form,
        )
        return np.array(deriv)

    sol = solve_ivp(
        rhs_numpy,
        [0, t.max()],
        state0,
        t_eval=t,
        method="RK45",
        rtol=1e-6,
        atol=1e-8,
    )

    # Extract strain and compute compliance
    gamma = sol.y[-2, :]
    return gamma / sigma_applied
def _predict_relaxation(
    self,
    t: np.ndarray,
    gamma_pre: float = 0.01,
    **kwargs,
) -> np.ndarray:
    """Stress relaxation σ(t) after an instantaneous pre-strain.

    Parameters
    ----------
    t : np.ndarray
        Times (s) after the step, used as ``solve_ivp``'s ``t_eval``.
    gamma_pre : float, default 0.01
        Step (pre-shear) strain amplitude.

    Returns
    -------
    np.ndarray
        Relaxing stress σ(t) (Pa), the last component of the state.
    """
    t = np.asarray(t)
    v1, v2, Gamma, gamma_c, G_inf = (
        self.parameters.get_value(key)
        for key in ("v1", "v2", "Gamma", "gamma_c", "G_inf")
    )
    for value in (v1, v2, Gamma, gamma_c, G_inf):
        assert value is not None

    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()
    assert self._prony_amplitudes is not None
    assert self._prony_times is not None
    amplitudes = jnp.array(self._prony_amplitudes)
    relaxation_times = jnp.array(self._prony_times)
    n_modes = self.n_prony_modes

    # Step-strain initial condition: the applied strain partially destroys
    # the cage correlation, so Φ(0) = h(γ_pre) rather than 1.
    ratio_sq = (gamma_pre / gamma_c) ** 2
    h_pre = 1.0 / (1.0 + ratio_sq) if self._use_lorentzian else np.exp(-ratio_sq)

    # State layout [Φ, K₁..K_n, σ]
    initial = np.zeros(n_modes + 2)
    initial[0] = h_pre
    initial[n_modes + 1] = G_inf * gamma_pre * h_pre * h_pre  # σ(0) = G_∞ γ_pre h²

    def _rhs(time, y):
        derivative = f12_volterra_relaxation_rhs(
            jnp.array(y),
            time,
            gamma_pre,
            v1,
            v2,
            Gamma,
            gamma_c,
            G_inf,
            amplitudes,
            relaxation_times,
            n_modes,
            self._use_lorentzian,
            memory_form=self._memory_form,
        )
        return np.array(derivative)

    solution = solve_ivp(
        _rhs,
        [0, t.max()],
        initial,
        t_eval=t,
        method="RK45",
        rtol=1e-6,
        atol=1e-8,
    )
    # R10-SCH-002: no np.maximum clamp to a residual glass plateau. The
    # step-strain IC plus the Volterra memory already produce the long-time
    # plateau; clamping would double-count it and hide the transient dip
    # below the plateau that is characteristic of MCT glass states.
    return solution.y[-1, :]
def _predict_laos(
    self,
    t: np.ndarray,
    gamma_0: float = 0.1,
    omega: float = 1.0,
    **kwargs,
) -> np.ndarray:
    """Stress response σ(t) under large-amplitude oscillatory shear.

    Parameters
    ----------
    t : np.ndarray
        Output times (s), used as ``solve_ivp``'s ``t_eval``.
    gamma_0 : float, default 0.1
        Strain amplitude
    omega : float, default 1.0
        Angular frequency (rad/s)

    Returns
    -------
    np.ndarray
        Stress σ(t) (Pa), the last component of the integrated state.
    """
    t = np.asarray(t)
    v1, v2, Gamma, gamma_c, G_inf = (
        self.parameters.get_value(key)
        for key in ("v1", "v2", "Gamma", "gamma_c", "G_inf")
    )
    for value in (v1, v2, Gamma, gamma_c, G_inf):
        assert value is not None

    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()
    assert self._prony_amplitudes is not None
    assert self._prony_times is not None
    amplitudes = jnp.array(self._prony_amplitudes)
    relaxation_times = jnp.array(self._prony_times)
    n_modes = self.n_prony_modes

    # State layout [Φ, K₁..K_n, γ_acc, σ]: at rest except Φ(0) = 1.
    initial = np.zeros(n_modes + 3)
    initial[0] = 1.0

    def _rhs(time, y):
        derivative = f12_volterra_laos_rhs(
            jnp.array(y),
            time,
            gamma_0,
            omega,
            v1,
            v2,
            Gamma,
            gamma_c,
            G_inf,
            amplitudes,
            relaxation_times,
            n_modes,
            self._use_lorentzian,
            memory_form=self._memory_form,
        )
        return np.array(derivative)

    solution = solve_ivp(
        _rhs,
        [0, t.max()],
        initial,
        t_eval=t,
        method="RK45",
        rtol=1e-6,
        atol=1e-8,
    )
    return solution.y[-1, :]
[docs]
def get_laos_harmonics(
    self,
    t: np.ndarray,
    gamma_0: float = 0.1,
    omega: float = 1.0,
    n_harmonics: int = 5,
) -> tuple[np.ndarray, np.ndarray]:
    """Decompose the LAOS stress response into odd Fourier harmonics.

    Runs the LAOS prediction on ``t`` and hands the result to
    ``extract_laos_harmonics``.

    Parameters
    ----------
    t : np.ndarray
        Time array covering at least one full period
    gamma_0 : float
        Strain amplitude
    omega : float
        Angular frequency
    n_harmonics : int, default 5
        Number of odd harmonics to extract

    Returns
    -------
    sigma_prime_n : np.ndarray
        In-phase coefficients [σ'₁, σ'₃, σ'₅, ...]
    sigma_double_prime_n : np.ndarray
        Out-of-phase coefficients [σ''₁, σ''₃, σ''₅, ...]
    """
    response = self._predict_laos(t, gamma_0=gamma_0, omega=omega)
    in_phase, out_of_phase = extract_laos_harmonics(
        jnp.array(t),
        jnp.array(response),
        omega,
        n_harmonics=n_harmonics,
    )
    return np.array(in_phase), np.array(out_of_phase)
[docs]
def model_function(
self,
X: np.ndarray,
params: np.ndarray,
test_mode: str = None,
**kwargs,
) -> np.ndarray:
"""Static model function for Bayesian inference.
NOTE: Bayesian inference is not yet supported for ITT-MCT models.
The Prony decomposition depends on parameters (v1, v2) and would need
to be recomputed for each MCMC sample, which is computationally prohibitive.
Use NLSQ fitting with bootstrap resampling for uncertainty quantification.
Parameters
----------
X : np.ndarray
Independent variable
params : np.ndarray
Parameter array [v1, v2, Gamma, gamma_c, G_inf] in parameter order
test_mode : str, optional
Protocol type
**kwargs
Additional protocol-specific parameters
Raises
------
NotImplementedError
Always raised - Bayesian inference not supported for ITT-MCT models
"""
raise NotImplementedError(
"Bayesian inference is not yet supported for ITT-MCT models. "
"The model requires Prony decomposition that depends on parameters "
"(v1, v2), making MCMC sampling computationally prohibitive. "
"Use NLSQ fitting with bootstrap resampling for uncertainty quantification."
)
[docs]
def precompile(
    self,
    test_mode: str = "relaxation",
    X=None,
    y=None,
    **kwargs,
) -> float:
    """Pre-compile the diffrax ODE solver for fast subsequent calls.

    Triggers JIT compilation of the diffrax flow-curve solver for this
    model's ``n_prony_modes``, decorrelation form and memory form, so the
    first real prediction does not incur the compilation cost. Useful when
    predictable timing matters (interactive applications, benchmarks).

    Parameters
    ----------
    test_mode, X, y, **kwargs
        Accepted for interface compatibility; not used by this model.

    Returns
    -------
    float
        Compilation time in seconds (0.0 if diffrax not available)

    Examples
    --------
    >>> model = ITTMCTSchematic(epsilon=0.05)
    >>> compile_time = model.precompile()
    >>> print(f"Compilation took {compile_time:.1f}s")
    >>> # Now flow curve predictions will be fast
    >>> sigma = model.predict(gamma_dot, test_mode='flow_curve')

    Notes
    -----
    The first flow-curve prediction otherwise triggers JIT compilation,
    which may take tens of seconds. Only the diffrax flow-curve solver is
    affected; the other protocols (oscillation, startup, creep, relaxation,
    LAOS) integrate with scipy and need no precompilation.
    """
    if not _HAS_DIFFRAX:
        logger.warning("diffrax not available, precompilation skipped")
        return 0.0

    # Go through the same cache check as the predict methods so the
    # parameter hash is recorded together with the Prony modes built here.
    # Otherwise the hash stays None, a later parameter change is not
    # detected by _check_prony_cache, and stale modes would be reused.
    self._check_prony_cache()
    if self._prony_amplitudes is None:
        self.initialize_prony_modes()

    return precompile_flow_curve_solver(
        n_modes=self.n_prony_modes,
        use_lorentzian=self._use_lorentzian,
        memory_form=self._memory_form,
    )
def _check_prony_cache(self) -> None:
    """Drop cached Prony modes when (v1, v2, Gamma) changed since last check.

    The equilibrium correlator and memory kernel, and hence the Prony
    decomposition, depend on these three parameters (e.g. they change
    between NLSQ iterations). The first call only records the current
    values; later calls clear ``_prony_amplitudes``/``_prony_times`` when the
    recorded values differ. The current values are always recorded.
    """
    snapshot = tuple(
        self.parameters.get_value(name) for name in ("v1", "v2", "Gamma")
    )
    previous = self._prony_param_hash
    if previous is not None and previous != snapshot:
        self._prony_amplitudes = None
        self._prony_times = None
    self._prony_param_hash = snapshot
@property
def decorrelation_form(self) -> str:
    """Strain decorrelation form chosen at construction ("gaussian" or "lorentzian")."""
    form = self._decorrelation_form
    return form
@property
def memory_form(self) -> str:
    """Memory kernel form chosen at construction.

    Returns
    -------
    str
        "simplified": single decorrelation m(Φ) = h[γ_acc] × (v₁Φ + v₂Φ²)
        "full": two-time decorrelation m(t,s,t₀) = h[γ(t,t₀)] × h[γ(t,s)] × (v₁Φ + v₂Φ²)
    """
    form = self._memory_form
    return form
@property
def stress_form(self) -> str:
    """Stress computation form chosen at construction.

    Returns
    -------
    str
        "schematic": σ = G_∞ × γ̇ × ∫ Φ² × h(γ) dt
        "microscopic": σ = (k_BT/60π²) × ∫dk k⁴ [S'/S²]² Φ²
    """
    form = self._stress_form
    return form
[docs]
def __repr__(self) -> str:
    """Summarise state (ε, fluid/glass), v₂, model forms and G_inf."""
    info = self.get_glass_transition_info()
    phase = "glass" if info["is_glass"] else "fluid"
    v2 = self.parameters.get_value("v2")
    G_inf = self.parameters.get_value("G_inf")
    fields = [
        f"ε={info['epsilon']:.3f} [{phase}]",
        f"v₂={v2:.2f}",
        f"h(γ)={self._decorrelation_form}",
        f"m={self._memory_form}",
        f"σ={self._stress_form}",
        f"G_inf={G_inf:.2e} Pa",
    ]
    return "ITTMCTSchematic(" + ", ".join(fields) + ")"