"""Non-Local Fluidity Model Implementation.
This module implements the Non-Local (1D PDE, Coussot-Ovarlez) Fluidity model
for yield-stress fluids with spatial diffusion, supporting shear banding analysis.
"""
from __future__ import annotations
from typing import Any, cast
import numpy as np
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.models.fluidity._base import FluidityBase
from rheojax.models.fluidity._kernels import (
banding_ratio,
fluidity_nonlocal_creep_pde_rhs,
fluidity_nonlocal_pde_rhs,
fluidity_nonlocal_steady_state,
shear_banding_cv,
)
# Safe JAX import via the project helper; binds the ``jax`` module and its
# ``jax.numpy`` namespace used throughout this module.
jax, jnp = safe_import_jax()
# Module-level logger.
logger = get_logger(__name__)
# Sentinel for distinguishing "not provided" from falsy values (FL-009):
# kwargs lookups use ``kwargs.get(key, _MISSING)`` and test ``is _MISSING`` so
# an explicit 0.0 / None supplied by the caller is not mistaken for "absent".
_MISSING = object()
# FL-006: kwargs to drop before forwarding to nlsq_optimize (they are consumed
# by this module or by create_least_squares_objective, not by the optimizer).
# Start from the central set and add model-specific extras so the two
# never drift apart (see _RHEOJAX_RESERVED_KWARGS in optimization.py).
from rheojax.utils.optimization import _RHEOJAX_RESERVED_KWARGS
_NLSQ_RESERVED = _RHEOJAX_RESERVED_KWARGS | {
    "use_log_residuals",
    "smart_init",
    "use_multi_start",
    "n_starts",
    "perturb_factor",
    "callback",
    "sigma_0",
}
# Filter for ODE/PDE protocols (transient and LAOS fits): same as
# _NLSQ_RESERVED but lets "method" through (it is presumably part of the
# central reserved set — the subtraction below only matters if so), so the
# caller's choice or the "scipy" default set by those fit routines reaches
# nlsq_optimize (diffrax's custom_vjp adjoint is incompatible with NLSQ's
# jacfwd-based Jacobian).
_NLSQ_RESERVED_ODE = _NLSQ_RESERVED - {"method"}
[docs]
@ModelRegistry.register(
    "fluidity_nonlocal",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.CREEP,
        Protocol.RELAXATION,
        Protocol.STARTUP,
        Protocol.OSCILLATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FluidityNonlocal(FluidityBase):
    """Non-Local (1D PDE) Fluidity Model for yield-stress fluids.

    Implements the Coussot-Ovarlez non-local fluidity model where the
    fluidity field f(y,t) evolves across the gap (y-direction) via:

        ∂f/∂t = (f_loc(σ) - f)/θ + ξ²∂²f/∂y²

    where:
        - f_loc(σ) is the local equilibrium fluidity from the HB flow curve
        - θ is the relaxation time
        - ξ is the cooperativity length (non-local diffusion)

    This captures shear banding: localized flow in yield-stress fluids
    where the cooperativity length ξ determines band width.

    Key features:
        - 1D Couette gap discretization (N_y uniformly spaced points)
        - Neumann (zero-flux) boundary conditions at walls
        - Diffrax Dopri5 solver (explicit, adaptive PID step control) for
          the PDE
        - Shear banding metrics: CV and max/min ratio

    Notes:
        Transient and LAOS fits integrate on a coarser grid (``fit_N_y``,
        default ``min(N_y, 32)``) without mutating ``N_y``/``dy``;
        predictions use the full grid. SAOS uses a Maxwell-limit
        approximation depending only on ``G`` and ``f_eq``.

    Attributes:
        N_y: Number of grid points across gap
        gap_width: Physical gap width (m)
        dy: Grid spacing, ``gap_width / (N_y - 1)`` (m)
    """
[docs]
def __init__(self, N_y: int = 64, gap_width: float = 1e-3):
    """Initialize Non-Local Fluidity Model.

    Args:
        N_y: Number of spatial grid points across the gap (default 64);
            must be >= 2.
        gap_width: Physical gap width in meters (default 1 mm); must be a
            positive, finite length.

    Raises:
        ValueError: If ``N_y < 2`` or ``gap_width`` is not positive and finite.
    """
    super().__init__()
    # FL-011: Guard against N_y < 2 which causes ZeroDivisionError in dy
    if N_y < 2:
        raise ValueError(f"N_y must be >= 2 for spatial discretization, got {N_y}")
    # Same reasoning for the gap: a zero gap gives dy == 0, and a negative or
    # NaN gap gives a meaningless grid spacing for the diffusion term.
    if not (np.isfinite(gap_width) and gap_width > 0):
        raise ValueError(
            f"gap_width must be a positive finite length, got {gap_width}"
        )
    self.N_y = N_y
    self.gap_width = gap_width
    self.dy = gap_width / (N_y - 1)
    # Register the cooperativity length xi on top of the base parameters.
    self._add_nonlocal_parameters()
    # Fluidity field history (time x grid), filled by concrete simulations.
    self._f_field_trajectory: np.ndarray | None = None
def _add_nonlocal_parameters(self):
    """Register the cooperativity length ``xi`` (metres) on ``self.parameters``.

    ``xi`` scales the ξ²∂²f/∂y² diffusion term of the fluidity PDE.
    """
    xi_spec = dict(
        name="xi",
        value=1e-5,
        bounds=(1e-9, 1e-3),
        units="m",
        description="Cooperativity length (non-local diffusion scale)",
    )
    self.parameters.add(**xi_spec)
# The nonlocal PDE RHS (fluidity_nonlocal_creep_pde_rhs /
# fluidity_nonlocal_pde_rhs) uses HB aging via f_loc(σ; tau_y, K, n_flow)
# and does NOT include the rejuvenation term a·|γ̇|^n·(f_inf - f) that
# the local ODE carries, so a and n_rejuv never enter the nonlocal
# residual. f_inf is passed to the steady-state kernel (inert there, see
# the flow_curve entry) and otherwise only sets the uniform initial
# f-field of the relaxation protocol; like f_eq for startup/creep it only
# fixes f(t0), which relaxes toward f_loc in ~θ, so it is treated as
# inactive. ξ only matters when the f-field develops spatial variation;
# with a uniform initial condition and Neumann BCs the field stays
# uniform and ∇²f ≡ 0 for all the single-protocol benchmarks in
# examples/fluidity (a non-uniform LAOS ``f_init`` would activate ξ).
# The identifiability map below reflects this — it OVERRIDES
# FluidityBase._IDENTIFIABILITY.
_IDENTIFIABILITY = {
    "flow_curve": {
        # HB steady state σ = τ_y + K·γ̇^n — transient parameters inert.
        "identifiable": ("tau_y", "K", "n_flow"),
        "product_degenerate": (),
        "inactive": ("G", "f_eq", "f_inf", "theta", "a", "n_rejuv", "xi"),
    },
    "startup": {
        # Rate-controlled: elastic backbone G enters via dΣ/dt = G(γ̇ - Σf);
        # f_loc uses HB params; θ sets relaxation. f_eq sets only initial
        # f-field (decays in ~θ); rejuvenation terms absent; field stays
        # uniform so ξ inert.
        "identifiable": ("G", "tau_y", "K", "n_flow", "theta"),
        "product_degenerate": (),
        "inactive": ("f_eq", "f_inf", "a", "n_rejuv", "xi"),
    },
    "relaxation": {
        # Stress decays via dΣ/dt = -G·Σ·f_avg with f relaxing toward
        # f_loc(σ). All three HB params shape the late-time plateau; θ
        # sets the decay rate. f_inf only seeds the initial f-field
        # (decays in ~θ); f_eq is not used by this protocol.
        "identifiable": ("G", "tau_y", "K", "n_flow", "theta"),
        "product_degenerate": (),
        "inactive": ("f_eq", "f_inf", "a", "n_rejuv", "xi"),
    },
    "creep": {
        # dγ/dt = σ·f_avg. Constant σ ⇒ G drops out entirely (no Maxwell
        # equation being integrated). f_eq only sets f(t0) which decays
        # to f_loc in ~θ; rejuvenation absent; field uniform ⇒ ξ inert.
        "identifiable": ("tau_y", "K", "n_flow", "theta"),
        "product_degenerate": (),
        "inactive": ("G", "f_eq", "f_inf", "a", "n_rejuv", "xi"),
    },
    "oscillation": {
        # SAOS uses the Maxwell-limit kernel _predict_saos_jit:
        # G'(ω), G''(ω) with τ_eff = 1/(G·f_eq). That kernel accepts but
        # discards θ (FL-005) and never sees the HB parameters, and
        # _fit_oscillation optimises only (G, f_eq) — so θ is inactive.
        "identifiable": ("G", "f_eq"),
        "product_degenerate": (),
        "inactive": (
            "tau_y", "K", "n_flow", "theta", "f_inf", "a", "n_rejuv", "xi",
        ),
    },
    "laos": {
        # Large-amplitude oscillation excites the full nonlinear response:
        # HB params and θ all active alongside G.
        "identifiable": ("G", "tau_y", "K", "n_flow", "theta"),
        "product_degenerate": (),
        "inactive": ("f_eq", "f_inf", "a", "n_rejuv", "xi"),
    },
}
def _fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    **kwargs,
) -> FluidityNonlocal:
    """Dispatch a fit to the protocol routine selected by ``test_mode``.

    ``test_mode`` is read from ``kwargs``; when it is absent the mode stored
    by a previous fit is reused. The ``"saos"`` alias is canonicalised to
    ``"oscillation"`` before it is stored on ``self._test_mode`` (FL-001).

    Args:
        X: Independent variable (time, frequency, or shear rate).
        y: Dependent variable (stress, modulus, viscosity or strain).
        **kwargs: Optimizer and protocol options, forwarded unchanged to the
            selected ``_fit_*`` routine.

    Returns:
        ``self`` with ``fitted_`` set to True.

    Raises:
        ValueError: If no test mode is available or it is not supported.
    """
    requested = kwargs.get("test_mode")
    if requested is None:
        requested = getattr(self, "_test_mode", None)
        if requested is None:
            raise ValueError("test_mode must be specified for Fluidity fitting")

    # FL-001: canonical name before anything is stored or logged.
    mode = "oscillation" if requested == "saos" else requested

    with log_fit(logger, model="FluidityNonlocal", data_shape=X.shape) as ctx:
        self._test_mode = cast(str, mode)
        ctx["test_mode"] = mode
        ctx["N_y"] = self.N_y

        if mode in ("steady_shear", "rotation", "flow_curve"):
            self._fit_flow_curve(X, y, **kwargs)
        elif mode in ("startup", "relaxation", "creep"):
            self._fit_transient(X, y, mode=mode, **kwargs)
        elif mode == "oscillation":
            self._fit_oscillation(X, y, **kwargs)
        elif mode == "laos":
            self._fit_laos(X, y, **kwargs)
        else:
            raise ValueError(f"Unsupported test_mode: {mode}")

    self.fitted_ = True
    return self
# =========================================================================
# Grid and Initial Conditions
# =========================================================================
def _get_grid_args(self, params: dict | None = None) -> dict:
    """Collect the grid settings needed by the PDE solver.

    Args:
        params: Parameter mapping to read ``xi`` from; when None the model's
            current ``get_parameter_dict()`` is used.

    Returns:
        Dict with keys ``"N_y"``, ``"dy"`` (from this instance) and ``"xi"``
        (from ``params``, falling back to 1e-5 when absent).
    """
    source = self.get_parameter_dict() if params is None else params
    return {"N_y": self.N_y, "dy": self.dy, "xi": source.get("xi", 1e-5)}
def _get_initial_f_field(
    self, f_init: float | None = None, N_y: int | None = None
) -> jnp.ndarray:
    """Return a spatially uniform fluidity field across the gap.

    Args:
        f_init: Fluidity value to fill with; None defers to
            ``self.get_initial_fluidity()``.
        N_y: Grid size override; None uses ``self.N_y``.

    Returns:
        1-D array of length ``N_y`` (or ``self.N_y``) filled with ``f_init``.
    """
    fill_value = self.get_initial_fluidity() if f_init is None else f_init
    size = self.N_y if N_y is None else N_y
    # ones * value (rather than full) keeps the float dtype even for int input
    return jnp.ones(size) * fill_value
def _get_initial_state(
    self,
    mode: str,
    params: dict,
    sigma_0: float | None = None,
    N_y: int | None = None,
) -> jnp.ndarray:
    """Build the initial PDE state vector ``[head, f[0], ..., f[N_y-1]]``.

    ``head`` is the macroscopic variable integrated with the field: the
    accumulated strain γ for creep, the stress Σ for every other mode.

    - creep: γ = 0 and a uniform field at ``f_eq``.
    - relaxation: Σ = ``sigma_0`` (``tau_y`` when ``sigma_0`` is None) and a
      uniform field at ``f_inf`` (the sample has just been flowing).
    - anything else (startup, laos): Σ = 0 and a uniform field at ``f_eq``.

    Args:
        mode: Protocol name.
        params: Parameter mapping; ``f_eq`` and ``f_inf`` are always read.
        sigma_0: Initial stress for relaxation.
        N_y: Grid size override; None uses ``self.N_y``.

    Returns:
        State vector of length (grid size + 1).
    """
    f_eq, f_inf = params["f_eq"], params["f_inf"]
    if mode == "relaxation":
        head = sigma_0 if sigma_0 is not None else params["tau_y"]
        fill = f_inf
    else:
        # creep starts from zero strain, startup/laos from zero stress —
        # both from the rested (f_eq) fluidity state.
        head, fill = 0.0, f_eq
    field = self._get_initial_f_field(fill, N_y=N_y)
    return jnp.concatenate([jnp.array([head]), field])
# =========================================================================
# Flow Curve (Steady State)
# =========================================================================
def _fit_flow_curve(
    self, gamma_dot: np.ndarray, stress: np.ndarray, **kwargs
) -> None:
    """Fit the homogeneous (non-banding) steady-state flow curve.

    The steady state is evaluated by ``fluidity_nonlocal_steady_state``
    (Herschel-Bulkley, σ = τ_y + K·|γ̇|^n). Residuals are log-scaled unless
    ``use_log_residuals=False`` is supplied.

    Args:
        gamma_dot: Shear rate array (1/s).
        stress: Shear stress array (Pa).
        **kwargs: Optimizer options; keys in ``_NLSQ_RESERVED`` are dropped
            before calling ``nlsq_optimize`` (FL-006).
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    rates = jnp.asarray(gamma_dot, dtype=jnp.float64)
    observed = jnp.asarray(stress, dtype=jnp.float64)
    # Positional argument order expected by the steady-state kernel.
    kernel_order = ("G", "tau_y", "K", "n_flow", "f_eq", "f_inf", "theta")

    def model_fn(x_data, params):
        named = dict(zip(self.parameters.keys(), params, strict=True))
        return fluidity_nonlocal_steady_state(
            x_data, *(named[key] for key in kernel_order)
        )

    objective = create_least_squares_objective(
        model_fn,
        rates,
        observed,
        use_log_residuals=kwargs.get("use_log_residuals", True),
    )
    forwarded = {
        key: value for key, value in kwargs.items() if key not in _NLSQ_RESERVED
    }
    outcome = nlsq_optimize(objective, self.parameters, **forwarded)
    if not outcome.success:
        logger.warning(f"Fluidity flow curve fit warning: {outcome.message}")
# FL-013: _predict_flow_curve is not used by _predict() or model_function()
# (flow curve routing goes through fluidity_nonlocal_steady_state directly).
# Kept as a thin compatibility wrapper for external callers.
def _predict_flow_curve(self, gamma_dot: np.ndarray) -> np.ndarray:
    """Return the steady-state stress for ``gamma_dot`` as a NumPy array.

    Compatibility wrapper around ``_predict(..., test_mode="flow_curve")``.
    """
    stress = self._predict(gamma_dot, test_mode="flow_curve")
    return np.array(stress)
# =========================================================================
# Transient Protocols (Startup, Relaxation, Creep)
# =========================================================================
def _fit_transient(self, t: np.ndarray, y: np.ndarray, mode: str, **kwargs) -> None:
    """Fit a transient response (startup, relaxation or creep) with the PDE solver.

    Protocol inputs are popped from ``kwargs`` and stored on the instance so
    later predictions can reuse them. The PDE is integrated on a coarser grid
    (``fit_N_y``) passed explicitly to ``_simulate_pde``; ``self.N_y`` and
    ``self.dy`` are never mutated (FL-003).

    Args:
        t: Time array (s).
        y: Response data (stress for startup/relaxation, strain for creep).
        mode: 'startup', 'relaxation', or 'creep'.
        **kwargs: ``gamma_dot`` (required for startup), ``sigma_applied``
            (required for creep), ``sigma_0`` (optional relaxation start
            stress), ``fit_N_y`` (default ``min(self.N_y, 32)``),
            ``use_log_residuals`` (default False) and optimizer options.
            Keys in ``_NLSQ_RESERVED_ODE`` are not forwarded to
            ``nlsq_optimize``; ``method`` defaults to ``"scipy"``.

    Raises:
        ValueError: If ``fit_N_y`` < 2 or the protocol input required by
            ``mode`` is missing.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    # Keep a complex dtype if complex data is supplied; real data -> float64.
    y_arr = np.asarray(y)
    if np.iscomplexobj(y_arr):
        y_jax = jnp.asarray(y_arr, dtype=jnp.complex128)
    else:
        y_jax = jnp.asarray(y_arr, dtype=jnp.float64)

    # Extract protocol-specific inputs
    gamma_dot = kwargs.pop("gamma_dot", None)
    sigma_applied = kwargs.pop("sigma_applied", None)
    sigma_0 = kwargs.pop("sigma_0", None)

    # FL-003: Use local variables for coarser grid instead of mutating self
    # This avoids thread-safety issues where concurrent access could corrupt
    # self.N_y and self.dy during fitting
    fit_N_y = kwargs.pop("fit_N_y", min(self.N_y, 32))
    # FL-011 applies to the fit grid too: fit_N_y < 2 would divide by zero
    # (or produce a negative spacing) below.
    if fit_N_y < 2:
        raise ValueError(
            f"fit_N_y must be >= 2 for spatial discretization, got {fit_N_y}"
        )
    fit_dy = self.gap_width / (fit_N_y - 1)

    if mode == "startup" and gamma_dot is None:
        raise ValueError("startup mode requires gamma_dot in kwargs")
    if mode == "creep" and sigma_applied is None:
        raise ValueError("creep mode requires sigma_applied in kwargs")

    # Store for prediction
    self._gamma_dot_applied = gamma_dot
    self._sigma_applied = sigma_applied
    self._sigma_0 = sigma_0

    def model_fn(x_data, params):
        p_map = dict(zip(self.parameters.keys(), params, strict=True))
        return self._simulate_pde(
            x_data,
            p_map,
            mode,
            gamma_dot,
            sigma_applied,
            sigma_0,
            N_y=fit_N_y,
            dy=fit_dy,
        )

    # See FluidityLocal._fit_transient for the rationale: relative
    # residuals (normalize=True) blow up at the zero crossings /
    # zero starting points of creep and startup data.
    objective = create_least_squares_objective(
        model_fn,
        t_jax,
        y_jax,
        normalize=False,
        use_log_residuals=kwargs.get("use_log_residuals", False),
    )
    # Keep "method" so it reaches nlsq_optimize. Transient protocols use
    # a diffrax ODE (custom_vjp) — default to scipy if caller didn't pick.
    nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _NLSQ_RESERVED_ODE}
    nlsq_kwargs.setdefault("method", "scipy")
    result = nlsq_optimize(objective, self.parameters, **nlsq_kwargs)
    if not result.success:
        logger.warning(f"Fluidity transient fit warning: {result.message}")
def _simulate_pde(
    self,
    t: jnp.ndarray,
    params: dict,
    mode: str,
    gamma_dot: float | None,
    sigma_applied: float | None,
    sigma_0: float | None,
    N_y: int | None = None,
    dy: float | None = None,
) -> jnp.ndarray:
    """Integrate the 1D fluidity PDE with Diffrax and return the primary output.

    The state vector is ``[Σ or γ, f[0], ..., f[N_y-1]]`` (built by
    ``_get_initial_state``). When the solution is concrete (not being traced
    by JAX) the fluidity history ``sol.ys[:, 1:]`` is cached on
    ``self._f_field_trajectory`` for the shear-banding helpers; note that
    during fitting this history has the coarse fit-grid width.

    Args:
        t: Time array (saved output times; ``t[0]``/``t[-1]`` bound the solve).
        params: Parameter dictionary.
        mode: 'startup', 'relaxation', or 'creep'.
        gamma_dot: Applied shear rate (for startup).
        sigma_applied: Applied stress (for creep).
        sigma_0: Initial stress (for relaxation).
        N_y: Grid points override (FL-003 thread safety). If None, uses self.N_y.
        dy: Grid spacing override (FL-003 thread safety). If None, uses self.dy.

    Returns:
        Primary output at each ``t`` (stress for startup/relaxation, strain
        for creep); all NaN if the solver did not finish successfully.
    """
    # FL-003: Use local variables instead of self.N_y/self.dy for thread safety
    N_y_local = N_y if N_y is not None else self.N_y
    dy_local = dy if dy is not None else self.dy

    # Build args for PDE RHS
    # FL-012: Removed dead "N_y" key — PDE kernels infer N_y from state vector shape
    args = {
        "G": params["G"],
        "tau_y": params["tau_y"],
        "K": params["K"],
        "n_flow": params["n_flow"],
        "theta": params["theta"],
        "xi": params.get("xi", 1e-5),
        "dy": dy_local,
    }

    # Mode-specific setup
    if mode == "creep":
        pde_fn = fluidity_nonlocal_creep_pde_rhs
        args["sigma_applied"] = sigma_applied if sigma_applied is not None else 0.0
    else:
        pde_fn = fluidity_nonlocal_pde_rhs
        if mode == "startup":
            args["mode"] = 0  # rate_controlled
            args["gamma_dot"] = gamma_dot if gamma_dot is not None else 0.0
        else:  # relaxation: rate-controlled with zero imposed rate
            args["mode"] = 0  # rate_controlled
            args["gamma_dot"] = 0.0

    # Initial state (uses N_y_local for grid size)
    y0 = self._get_initial_state(mode, params, sigma_0, N_y=N_y_local)

    # Explicit Dopri5 with adaptive PID step control (an explicit solver
    # avoids implicit-solver tracing issues). Fine grids or large ξ can make
    # the diffusion term stiff, which the controller handles by shrinking
    # steps — hence the generous max_steps budget.
    term = diffrax.ODETerm(
        jax.checkpoint(lambda ti, yi, args_i: pde_fn(cast(float, ti), yi, args_i))
    )
    solver = diffrax.Dopri5()
    stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-7)
    t0 = t[0]
    t1 = t[-1]
    dt0 = (t1 - t0) / max(len(t), 1000)
    saveat = diffrax.SaveAt(ts=t)
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        args=args,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=10_000_000,
        throw=False,  # Return partial result on failure (for optimization)
    )

    # Store trajectory for analysis. Under JAX tracing (jit / jacfwd / NUTS)
    # converting a tracer to NumPy raises — that is expected, so skip quietly
    # instead of logging a warning on every traced evaluation. Any other
    # failure is unexpected and is still logged (FL-007).
    try:
        self._f_field_trajectory = np.array(sol.ys[:, 1:])
    except (
        jax.errors.TracerArrayConversionError,
        jax.errors.ConcretizationTypeError,
    ):
        pass
    except Exception as e:
        logger.warning("Could not store fluidity field trajectory: %s", e)

    # Extract primary variable (index 0)
    # For creep: strain; for startup/relaxation: stress
    result = sol.ys[:, 0]
    # Handle solver failure by returning NaN (optimization will avoid this)
    result = jnp.where(sol.result == diffrax.RESULTS.successful, result, jnp.nan)
    return result
def _predict_transient(
    self,
    t: np.ndarray,
    mode: str | None = None,
    sigma_0: float | None = None,
    gamma_dot: Any = _MISSING,
    sigma_applied: Any = _MISSING,
) -> np.ndarray:
    """Predict a transient response on the full (``self.N_y``) grid.

    Protocol inputs (``gamma_dot`` for startup, ``sigma_applied`` for
    creep, ``sigma_0`` for relaxation) are read from keyword arguments
    when supplied so ``predict()`` works without a prior ``fit()``.
    Any argument left as ``_MISSING`` (or ``sigma_0`` left as None) falls
    back to the instance attribute populated by ``_fit_*`` (legacy path).

    Raises:
        ValueError: If no mode is given and none was stored by a fit.
    """
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    p = self.get_parameter_dict()
    # getattr so an unfitted model reaches the ValueError below instead of
    # an AttributeError (consistent with _fit/_predict).
    mode = mode if mode is not None else getattr(self, "_test_mode", None)
    if mode is None:
        raise ValueError("Test mode not specified for prediction")

    if gamma_dot is _MISSING:
        gamma_dot = getattr(self, "_gamma_dot_applied", None)
    if sigma_applied is _MISSING:
        sigma_applied = getattr(self, "_sigma_applied", None)
    if sigma_0 is None:
        sigma_0 = getattr(self, "_sigma_0", None)

    result = self._simulate_pde(
        t_jax,
        p,
        mode,
        gamma_dot,
        sigma_applied,
        sigma_0,
    )
    return np.array(result)
# =========================================================================
# Shear Banding Analysis
# =========================================================================
[docs]
def get_fluidity_profile(self, time_idx: int = -1) -> np.ndarray:
    """Return the stored fluidity field f(y) at one saved time index.

    Args:
        time_idx: Index into the saved time axis (default -1, the last one).

    Returns:
        Fluidity values across the gap from the most recent concrete
        simulation.

    Raises:
        ValueError: If no simulation trajectory has been stored yet.
    """
    trajectory = self._f_field_trajectory
    if trajectory is not None:
        return trajectory[time_idx]
    raise ValueError("No trajectory available. Run simulation first.")
[docs]
def get_shear_banding_metric(self, f_field: np.ndarray | None = None) -> float:
    """Return the coefficient of variation of the fluidity field.

    CV = std(f) / mean(f); CV > 0.3 typically indicates significant
    shear banding.

    Args:
        f_field: Fluidity field; None uses the final stored simulation state.

    Returns:
        Coefficient of variation (dimensionless) as a Python float.
    """
    profile = self.get_fluidity_profile(-1) if f_field is None else f_field
    cv = shear_banding_cv(jnp.asarray(profile, dtype=jnp.float64))
    return float(cv)
[docs]
def get_banding_ratio(self, f_field: np.ndarray | None = None) -> float:
    """Return the max/min fluidity ratio across the gap.

    A ratio above ~10 indicates strong flow localisation.

    Args:
        f_field: Fluidity field; None uses the final stored simulation state.

    Returns:
        Banding ratio (dimensionless) as a Python float.
    """
    profile = self.get_fluidity_profile(-1) if f_field is None else f_field
    ratio = banding_ratio(jnp.asarray(profile, dtype=jnp.float64))
    return float(ratio)
[docs]
def is_banding(
    self, f_field: np.ndarray | None = None, cv_threshold: float = 0.3
) -> bool:
    """Report whether the fluidity field is shear banded.

    Args:
        f_field: Fluidity field; None uses the final stored simulation state.
        cv_threshold: CV above which the flow is considered banded
            (default 0.3).

    Returns:
        True when the coefficient of variation exceeds ``cv_threshold``.
    """
    cv = self.get_shear_banding_metric(f_field)
    return cv > cv_threshold
# =========================================================================
# Oscillatory Protocols
# =========================================================================
def _fit_oscillation(self, X: np.ndarray, y: np.ndarray, **kwargs) -> None:
    """Fit SAOS moduli with the Maxwell-limit (linear viscoelastic) kernel.

    In the small-amplitude limit the bulk response approximates the local
    model, and only ``G`` and ``f_eq`` influence the residual. Optimising the
    full parameter vector gives a rank-2 Jacobian that makes NLSQ terminate
    prematurely on xtol, so a two-parameter ``ParameterSet`` is optimised
    and its values are copied back into ``self.parameters``.

    Args:
        X: Angular frequencies (rad/s).
        y: Complex ``G*`` or a real ``(M, 2)`` array of ``[G', G'']``.
        **kwargs: Optimizer options; keys in ``_NLSQ_RESERVED`` are dropped.

    Raises:
        ValueError: If ``y`` is neither complex nor shaped ``(M, 2)``.
        RuntimeError: If the optimizer leaves ``G`` or ``f_eq`` without a value.
    """
    from rheojax.core.parameters import ParameterSet
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    omega_jax = jnp.asarray(X, dtype=jnp.float64)

    moduli = np.asarray(y)
    if np.iscomplexobj(moduli):
        storage, loss = np.real(moduli), np.imag(moduli)
        stacked = np.column_stack([storage, loss])
    elif moduli.ndim == 2 and moduli.shape[1] == 2:
        storage, loss = moduli[:, 0], moduli[:, 1]
        stacked = moduli
    else:
        raise ValueError(f"G_star must be complex or (M, 2), got {moduli.shape}")
    target = jnp.asarray(stacked, dtype=jnp.float64)

    # Warm-start G and f_eq from the data before optimising.
    self._seed_saos_from_data(np.asarray(X, dtype=float), storage, loss)

    saos_params = ParameterSet()
    for key in ("G", "f_eq"):
        full = self.parameters[key]
        saos_params.add(
            name=key,
            value=full.value,
            bounds=full.bounds,
            units=full.units,
            description=full.description,
        )

    def model_fn(x_data, params):
        named = dict(zip(saos_params.keys(), params, strict=True))
        return self._predict_saos_jit(x_data, named["G"], named["f_eq"])

    objective = create_least_squares_objective(
        model_fn,
        omega_jax,
        target,
        normalize=True,
    )
    # FL-006: drop protocol/meta kwargs before forwarding to nlsq_optimize
    forwarded = {
        key: value for key, value in kwargs.items() if key not in _NLSQ_RESERVED
    }
    outcome = nlsq_optimize(objective, saos_params, **forwarded)
    if not outcome.success:
        logger.warning(f"Fluidity SAOS fit warning: {outcome.message}")

    fitted = {key: saos_params[key].value for key in ("G", "f_eq")}
    if any(value is None for value in fitted.values()):
        raise RuntimeError("NLSQ returned no value for SAOS parameters")
    for key, value in fitted.items():
        self.parameters.set_value(key, float(value))
def _seed_saos_from_data(
    self,
    omega: np.ndarray,
    G_prime: np.ndarray,
    G_double_prime: np.ndarray,
) -> None:
    """Seed G and f_eq from SAOS data so NLSQ starts near the minimum.

    Maxwell-limit heuristics:
        - G <- largest G' among the highest ~20 % of frequencies (plateau).
        - tau_eff <- 1/omega* at the first G'/G'' crossover (log-log
          interpolated), else 1/omega at the G'' peak.
        - f_eq <- 1/(G * tau_eff).

    Best-effort: only points with finite, positive omega and finite G', G''
    are used. If fewer than two such points remain, or a seed comes out
    non-finite, the current parameter values are left untouched. Accepted
    seeds are clipped to the parameter bounds.
    """
    omega = np.asarray(omega, dtype=float)
    Gp = np.asarray(G_prime, dtype=float)
    Gpp = np.asarray(G_double_prime, dtype=float)

    # omega <= 0 would give log(-inf/nan) in the crossover search and a
    # ZeroDivisionError in the 1/omega peak fallback; NaN moduli would
    # propagate into the seeds.
    valid = np.isfinite(omega) & (omega > 0) & np.isfinite(Gp) & np.isfinite(Gpp)
    omega, Gp, Gpp = omega[valid], Gp[valid], Gpp[valid]
    if omega.size < 2:
        return

    order = np.argsort(omega)
    w_sorted = omega[order]
    Gp_sorted = Gp[order]
    Gpp_sorted = Gpp[order]

    n_top = max(1, len(Gp_sorted) // 5)
    G_seed = float(np.max(Gp_sorted[-n_top:]))

    # log(G'/G'') changes sign at the crossover; clamp to avoid log(0).
    with np.errstate(divide="ignore", invalid="ignore"):
        log_ratio = np.log(np.maximum(Gp_sorted, 1e-300)) - np.log(
            np.maximum(Gpp_sorted, 1e-300)
        )
    crossings = np.flatnonzero(np.diff(np.sign(log_ratio)) != 0)

    if crossings.size > 0:
        i = int(crossings[0])
        d0, d1 = log_ratio[i], log_ratio[i + 1]
        w0, w1 = np.log(w_sorted[i]), np.log(w_sorted[i + 1])
        if d1 != d0:
            # Linear interpolation of the zero of log_ratio in log(omega).
            w_cross = w0 + (0.0 - d0) * (w1 - w0) / (d1 - d0)
        else:
            w_cross = 0.5 * (w0 + w1)
        tau_seed = float(np.exp(-w_cross))
    else:
        i_peak = int(np.argmax(Gpp_sorted))
        tau_seed = 1.0 / float(w_sorted[i_peak])

    f_eq_seed = 1.0 / max(G_seed * tau_seed, 1e-30)
    if not (np.isfinite(G_seed) and np.isfinite(f_eq_seed)):
        return

    def _clipped(name: str, value: float) -> float:
        param = self.parameters[name]
        lo, hi = param.bounds if param.bounds else (-np.inf, np.inf)
        lo_v = lo if lo is not None else -np.inf
        hi_v = hi if hi is not None else np.inf
        return float(np.clip(value, lo_v, hi_v))

    self.parameters.set_value("G", _clipped("G", G_seed))
    self.parameters.set_value("f_eq", _clipped("f_eq", f_eq_seed))
# TODO (FL-010): _predict_saos_jit is duplicated in FluidityLocal.
# Consider extracting to a shared module-level function or into _base.py.
@staticmethod
@jax.jit
def _predict_saos_jit(
    omega: jnp.ndarray,
    G: float,
    f_eq: float,
    theta: float = 0.0,  # FL-005: dead parameter, kept for backward compatibility
) -> jnp.ndarray:
    """Maxwell-limit SAOS moduli.

    With tau_eff = 1/(G*f_eq) (a 1e-30 floor guards division by zero):
        G'  = G (omega tau)^2 / (1 + (omega tau)^2)
        G'' = G (omega tau)   / (1 + (omega tau)^2)

    Returns:
        Array stacking ``[G', G'']`` along axis 1 (shape ``(N, 2)`` for 1-D
        ``omega``).

    Note:
        ``theta`` is ignored (FL-005); it only remains so existing callers
        that pass it keep working.
    """
    del theta  # FL-005: explicitly unused
    relaxation_time = 1.0 / (G * f_eq + 1e-30)
    wt = omega * relaxation_time
    wt_sq = wt**2
    one_plus = 1.0 + wt_sq
    storage = G * wt_sq / one_plus
    loss = G * wt / one_plus
    return jnp.stack([storage, loss], axis=1)
def _fit_laos(self, t: np.ndarray, sigma: np.ndarray, **kwargs) -> None:
    """Fit LAOS stress data by integrating the full PDE.

    Args:
        t: Time array (s).
        sigma: Measured stress (Pa) at each time in ``t``.
        **kwargs: Must include ``gamma_0`` (strain amplitude) and ``omega``
            (rad/s). Optional ``f_init`` (initial fluidity profile) and
            ``fit_N_y`` (coarse fit grid, default ``min(self.N_y, 32)``).
            Remaining keys not in ``_NLSQ_RESERVED_ODE`` go to
            ``nlsq_optimize``; ``method`` defaults to ``"scipy"``.

    Raises:
        ValueError: If ``gamma_0``/``omega`` is missing or ``fit_N_y`` < 2.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    gamma_0 = kwargs.pop("gamma_0", None)
    omega = kwargs.pop("omega", None)
    f_init = kwargs.pop("f_init", None)
    if gamma_0 is None or omega is None:
        raise ValueError("LAOS fitting requires gamma_0 and omega")

    # Stored so predict()/model_function() can reuse the protocol inputs.
    self._gamma_0 = gamma_0
    self._omega_laos = omega
    self._laos_f_init = f_init

    # FL-003: Use local variables for coarser grid instead of mutating self
    fit_N_y = kwargs.pop("fit_N_y", min(self.N_y, 32))
    # FL-011 applies to the fit grid too: fit_N_y < 2 would divide by zero
    # (or produce a negative spacing) below.
    if fit_N_y < 2:
        raise ValueError(
            f"fit_N_y must be >= 2 for spatial discretization, got {fit_N_y}"
        )
    fit_dy = self.gap_width / (fit_N_y - 1)

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    sigma_jax = jnp.asarray(sigma, dtype=jnp.float64)

    def model_fn(x_data, params):
        p_map = dict(zip(self.parameters.keys(), params, strict=True))
        _, stress = self._simulate_laos_internal(
            x_data,
            p_map,
            gamma_0,
            omega,
            N_y=fit_N_y,
            dy=fit_dy,
            f_init=f_init,
        )
        return stress

    # See FluidityLocal._fit_laos for the rationale: LAOS stress
    # crosses zero, so relative residuals (normalize=True) blow up
    # at the zero crossings and pull the optimizer away from the
    # true parameters.
    objective = create_least_squares_objective(
        model_fn,
        t_jax,
        sigma_jax,
        normalize=False,
    )
    # Keep "method" so it reaches nlsq_optimize. LAOS uses a diffrax ODE
    # (custom_vjp) — default to scipy if caller didn't pick a method.
    nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _NLSQ_RESERVED_ODE}
    nlsq_kwargs.setdefault("method", "scipy")
    result = nlsq_optimize(objective, self.parameters, **nlsq_kwargs)
    if not result.success:
        logger.warning(f"Fluidity LAOS fit warning: {result.message}")
def _simulate_laos_internal(
    self,
    t: jnp.ndarray,
    params: dict,
    gamma_0: float,
    omega: float,
    N_y: int | None = None,
    dy: float | None = None,
    f_init: np.ndarray | float | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Simulate the LAOS response with the PDE solver.

    The imposed rate is γ̇(t) = γ0·ω·cos(ωt), i.e. strain γ0·sin(ωt).

    Args:
        t: Time array (saved output times; ``t[0]``/``t[-1]`` bound the solve).
        params: Parameter dictionary.
        gamma_0: Strain amplitude.
        omega: Angular frequency (rad/s).
        N_y: Grid points override (FL-003 thread safety). If None, uses self.N_y.
        dy: Grid spacing override (FL-003 thread safety). If None, uses self.dy.
        f_init: Optional initial fluidity profile across the gap. A profile
            whose length differs from the grid is linearly resampled onto it;
            a scalar acts as a uniform profile. Values are floored at 1e-20.
            If None, the uniform f_eq state from ``_get_initial_state`` is
            used.

    Returns:
        ``(strain, stress)``: the imposed strain and the computed stress at
        each ``t``; stress is all NaN if the solver did not succeed.
    """
    # FL-003: Use local variables instead of self.N_y/self.dy for thread safety
    N_y_local = N_y if N_y is not None else self.N_y
    dy_local = dy if dy is not None else self.dy

    # Base args
    # FL-012: Removed dead "N_y" key — PDE kernels infer N_y from state vector shape
    base_args = {
        "G": params["G"],
        "tau_y": params["tau_y"],
        "K": params["K"],
        "n_flow": params["n_flow"],
        "theta": params["theta"],
        "xi": params.get("xi", 1e-5),
        "dy": dy_local,
        "mode": 0,  # rate_controlled
    }

    # Initial state: use provided profile to seed spatial gradients,
    # or fall back to uniform f_eq (gives ∇²f≡0, identical to local model).
    if f_init is not None:
        # atleast_1d lets a scalar f_init act as a uniform profile instead of
        # failing on len() below.
        f_field_init = jnp.atleast_1d(jnp.asarray(f_init, dtype=jnp.float64))
        if f_field_init.shape[0] != N_y_local:
            x_src = np.linspace(0, 1, f_field_init.shape[0])
            x_dst = np.linspace(0, 1, N_y_local)
            f_field_init = jnp.asarray(
                np.interp(x_dst, x_src, np.asarray(f_field_init))
            )
        y0 = jnp.concatenate(
            [jnp.array([0.0]), jnp.maximum(f_field_init, 1e-20)]
        )
    else:
        y0 = self._get_initial_state("laos", params, N_y=N_y_local)

    # PDE with time-varying gamma_dot
    def laos_pde(ti, yi, args_i):
        gamma_dot_t = gamma_0 * omega * jnp.cos(omega * ti)
        args_with_rate = {**args_i, "gamma_dot": gamma_dot_t}
        return fluidity_nonlocal_pde_rhs(ti, yi, args_with_rate)

    term = diffrax.ODETerm(jax.checkpoint(laos_pde))
    solver = diffrax.Dopri5()
    stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-7)
    t0 = t[0]
    t1 = t[-1]
    dt0 = (t1 - t0) / max(len(t), 1000)
    saveat = diffrax.SaveAt(ts=t)
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        args=base_args,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=16_000_000,
        throw=False,  # Return partial result on failure (for optimization)
    )

    stress = sol.ys[:, 0]
    strain = gamma_0 * jnp.sin(omega * t)
    # Handle solver failure by returning NaN
    stress = jnp.where(sol.result == diffrax.RESULTS.successful, stress, jnp.nan)

    # Store trajectory only when the solution is concrete. Under JAX tracing
    # np.asarray(tracer) raises jax.errors.TracerArrayConversionError, which
    # (like ConcretizationTypeError) is a TypeError subclass and is caught here.
    try:
        self._f_field_trajectory = np.asarray(sol.ys[:, 1:])
    except (TypeError, jax.errors.ConcretizationTypeError):
        # During tracing there is nothing concrete to store.
        pass

    return strain, stress
[docs]
def simulate_laos(
    self,
    gamma_0: float,
    omega: float,
    n_cycles: int = 2,
    n_points_per_cycle: int = 256,
    f_init: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Simulate the LAOS response on the full (``self.N_y``) grid.

    The time axis covers ``n_cycles`` periods (2π/ω each) with
    ``n_points_per_cycle`` samples per period, starting at t = 0 and
    excluding the end point. ``gamma_0``/``omega`` are stored so a later
    ``predict(test_mode="laos")`` can reuse them.

    Args:
        gamma_0: Strain amplitude.
        omega: Angular frequency (rad/s); must be > 0.
        n_cycles: Number of oscillation cycles (>= 1).
        n_points_per_cycle: Points per cycle (>= 1).
        f_init: Optional initial fluidity profile across the gap (resampled
            to the grid if its length differs); None starts from uniform f_eq.

    Returns:
        (strain, stress) NumPy arrays.

    Raises:
        ValueError: If ``omega`` is not positive, or ``n_cycles`` /
            ``n_points_per_cycle`` is below 1.
    """
    # omega == 0 divides by zero and omega < 0 reverses the time axis;
    # zero points would leave t empty and fail on t[0] inside the solver.
    if not omega > 0:
        raise ValueError(f"omega must be > 0 for LAOS simulation, got {omega}")
    if n_cycles < 1 or n_points_per_cycle < 1:
        raise ValueError(
            "n_cycles and n_points_per_cycle must be >= 1, got "
            f"{n_cycles} and {n_points_per_cycle}"
        )

    self._gamma_0 = gamma_0
    self._omega_laos = omega

    period = 2.0 * np.pi / omega
    t_max = n_cycles * period
    n_points = n_cycles * n_points_per_cycle
    t = np.linspace(0, t_max, n_points, endpoint=False)
    t_jax = jnp.asarray(t, dtype=jnp.float64)

    p = self.get_parameter_dict()
    strain, stress = self._simulate_laos_internal(
        t_jax, p, gamma_0, omega, f_init=f_init
    )
    return np.array(strain), np.array(stress)
# =========================================================================
# Bayesian / Model Function Interface
# =========================================================================
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
    """NumPyro/BayesianMixin model function.

    Evaluates the model at ``X`` for an explicit parameter vector ``params``
    (ordered like ``self.parameters.keys()``), so it can be traced by JAX.
    Protocol inputs (``gamma_dot``, ``sigma_applied``, ``sigma_0``,
    ``gamma_0``, ``omega``, ``f_init``) are taken from ``kwargs`` when given
    and otherwise from the values stored by the last fit. The sentinel lookup
    (FL-009) means explicit falsy values such as 0.0 are honoured.

    Args:
        X: Independent variable for the selected protocol.
        params: Parameter values in ``self.parameters`` order.
        test_mode: Protocol; defaults to the stored mode, then "oscillation".
        **kwargs: Protocol inputs (see above).

    Returns:
        JAX array: stress (flow curve / startup / relaxation / LAOS), strain
        (creep), or an ``(N, 2)`` ``[G', G'']`` array (oscillation).
        Unrecognised modes return zeros shaped like ``X``.

    Raises:
        ValueError: If LAOS inputs ``gamma_0``/``omega`` are unavailable.
    """
    p_values = dict(zip(self.parameters.keys(), params, strict=True))
    mode = test_mode if test_mode is not None else getattr(self, "_test_mode", None)
    if mode is None:
        mode = "oscillation"
    # FL-001: Normalize aliases
    if mode == "saos":
        mode = "oscillation"
    X_jax = jnp.asarray(X, dtype=jnp.float64)

    # FL-009: Use sentinel pattern to avoid swallowing falsy values (e.g. 0.0)
    gamma_dot = kwargs.get("gamma_dot", _MISSING)
    if gamma_dot is _MISSING:
        gamma_dot = getattr(self, "_gamma_dot_applied", None)
    sigma_applied = kwargs.get("sigma_applied", _MISSING)
    if sigma_applied is _MISSING:
        sigma_applied = getattr(self, "_sigma_applied", None)
    # Relaxation must start from the same stress used at fit time; passing
    # None here would silently restart it at tau_y (cf. _predict_transient).
    sigma_0 = kwargs.get("sigma_0", _MISSING)
    if sigma_0 is _MISSING:
        sigma_0 = getattr(self, "_sigma_0", None)
    gamma_0 = kwargs.get("gamma_0", _MISSING)
    if gamma_0 is _MISSING:
        gamma_0 = getattr(self, "_gamma_0", None)
    omega = kwargs.get("omega", _MISSING)
    if omega is _MISSING:
        omega = getattr(self, "_omega_laos", None)

    if mode in ["steady_shear", "rotation", "flow_curve"]:
        return fluidity_nonlocal_steady_state(
            X_jax,
            p_values["G"],
            p_values["tau_y"],
            p_values["K"],
            p_values["n_flow"],
            p_values["f_eq"],
            p_values["f_inf"],
            p_values["theta"],
        )
    elif mode == "oscillation":
        return self._predict_saos_jit(
            X_jax,
            p_values["G"],
            p_values["f_eq"],
        )
    elif mode in ["startup", "relaxation", "creep"]:
        return self._simulate_pde(
            X_jax,
            p_values,
            mode,
            gamma_dot,
            sigma_applied,
            sigma_0,
        )
    elif mode == "laos":
        if gamma_0 is None or omega is None:
            raise ValueError("LAOS mode requires gamma_0 and omega")
        f_init = kwargs.get("f_init", getattr(self, "_laos_f_init", None))
        _, stress = self._simulate_laos_internal(
            X_jax, p_values, gamma_0, omega, f_init=f_init
        )
        return stress
    return jnp.zeros_like(X_jax)
# =========================================================================
# Prediction Interface
# =========================================================================
def _predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
    """Predict the response for ``test_mode`` with the current parameters.

    ``test_mode`` comes from ``kwargs`` or, failing that, the mode stored by
    the last fit. Protocol inputs (``gamma_dot``, ``sigma_applied``,
    ``sigma_0``, ``gamma_0``, ``omega``, ``f_init``) may be passed as kwargs
    and otherwise fall back to the values stored by ``fit``/``simulate_laos``.

    Returns:
        NumPy array: stress (flow curve / startup / relaxation / LAOS), strain
        (creep), or complex G* = G' + iG'' (oscillation). Unrecognised modes
        return zeros shaped like ``X``.

    Raises:
        ValueError: If no test mode is available, or LAOS inputs are missing.
    """
    X_jax = jnp.asarray(X, dtype=jnp.float64)
    p = self.get_parameter_dict()

    # Get test_mode from kwargs or instance attribute
    _kw_mode = kwargs.get("test_mode")
    test_mode = (
        _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
    )
    if test_mode is None:
        raise ValueError("test_mode must be specified for prediction")
    # FL-001: Normalize aliases
    if test_mode == "saos":
        test_mode = "oscillation"

    if test_mode in ["steady_shear", "rotation", "flow_curve"]:
        result = fluidity_nonlocal_steady_state(
            X_jax,
            p["G"],
            p["tau_y"],
            p["K"],
            p["n_flow"],
            p["f_eq"],
            p["f_inf"],
            p["theta"],
        )
        return np.array(result)
    elif test_mode == "oscillation":
        result = self._predict_saos_jit(
            X_jax,
            p["G"],
            p["f_eq"],
        )
        # Convert (N,2) [G', G''] to complex G* for consistent API
        result = np.array(result)
        return result[:, 0] + 1j * result[:, 1]
    elif test_mode in ["startup", "relaxation", "creep"]:
        return self._predict_transient(
            X,
            mode=test_mode,
            sigma_0=kwargs.get("sigma_0"),
            gamma_dot=kwargs.get("gamma_dot", _MISSING),
            sigma_applied=kwargs.get("sigma_applied", _MISSING),
        )
    elif test_mode == "laos":
        # Fall back to stored inputs via getattr so a model that was never
        # fitted / simulated reaches the ValueError below instead of raising
        # AttributeError while evaluating the default.
        gamma_0 = kwargs.get("gamma_0", getattr(self, "_gamma_0", None))
        omega = kwargs.get("omega", getattr(self, "_omega_laos", None))
        f_init = kwargs.get("f_init", getattr(self, "_laos_f_init", None))
        if gamma_0 is None or omega is None:
            raise ValueError("LAOS prediction requires gamma_0 and omega")
        _, stress = self._simulate_laos_internal(
            X_jax, p, gamma_0, omega, f_init=f_init
        )
        return np.array(stress)
    return np.zeros_like(X)