# Source code for rheojax.models.fluidity.nonlocal_model

"""Non-Local Fluidity Model Implementation.

This module implements the Non-Local (1D PDE, Coussot-Ovarlez) Fluidity model
for yield-stress fluids with spatial diffusion, supporting shear banding analysis.
"""

from __future__ import annotations

from typing import Any, cast

import numpy as np

from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax

diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger, log_fit
from rheojax.models.fluidity._base import FluidityBase
from rheojax.models.fluidity._kernels import (
    banding_ratio,
    fluidity_nonlocal_creep_pde_rhs,
    fluidity_nonlocal_pde_rhs,
    fluidity_nonlocal_steady_state,
    shear_banding_cv,
)

# Safe JAX import
jax, jnp = safe_import_jax()

# Logger
logger = get_logger(__name__)

# Sentinel for distinguishing "not provided" from falsy values (FL-009)
_MISSING = object()

# FL-006: kwargs to pop before forwarding to nlsq_optimize.
# Start from the central set and add model-specific extras so the two
# never drift apart (see _RHEOJAX_RESERVED_KWARGS in optimization.py).
from rheojax.utils.optimization import _RHEOJAX_RESERVED_KWARGS

_NLSQ_RESERVED = _RHEOJAX_RESERVED_KWARGS | {
    "use_log_residuals",
    "smart_init",
    "use_multi_start",
    "n_starts",
    "perturb_factor",
    "callback",
    "sigma_0",
}

# Filter for ODE protocols — keeps "method" so the caller or the default
# "scipy" routing reaches nlsq_optimize (diffrax's custom_vjp adjoint is
# incompatible with NLSQ's jacfwd-based Jacobian).
_NLSQ_RESERVED_ODE = _NLSQ_RESERVED - {"method"}


@ModelRegistry.register(
    "fluidity_nonlocal",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.CREEP,
        Protocol.RELAXATION,
        Protocol.STARTUP,
        Protocol.OSCILLATION,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class FluidityNonlocal(FluidityBase):
    """Non-Local (1D PDE) Fluidity Model for yield-stress fluids.

    Implements the Coussot-Ovarlez non-local fluidity model where the
    fluidity field f(y, t) evolves across the gap (y-direction) via:

        ∂f/∂t = (f_loc(σ) - f)/θ + ξ²∂²f/∂y²

    where:
        - f_loc(σ) is the local equilibrium fluidity from the HB flow curve
        - θ is the relaxation time
        - ξ is the cooperativity length (non-local diffusion)

    This captures shear banding: localized flow in yield-stress fluids
    where the cooperativity length ξ determines the band width.

    Key features:
        - 1D Couette gap discretization (N_y points)
        - Neumann (zero-flux) boundary conditions at walls
        - Diffrax Dopri5 solver (explicit, adaptive step size) for the PDE
        - Shear banding metrics: CV and max/min ratio

    Attributes:
        N_y: Number of grid points across gap
        gap_width: Physical gap width (m)
        dy: Grid spacing ``gap_width / (N_y - 1)`` (m)
    """

    def __init__(self, N_y: int = 64, gap_width: float = 1e-3):
        """Initialize Non-Local Fluidity Model.

        Args:
            N_y: Number of spatial grid points (default 64). Must be >= 2.
            gap_width: Physical gap width in meters (default 1 mm). Must be > 0.

        Raises:
            ValueError: If ``N_y < 2`` or ``gap_width`` is not positive.
        """
        super().__init__()

        # FL-011: Guard against N_y < 2 which causes ZeroDivisionError in dy
        if N_y < 2:
            raise ValueError(f"N_y must be >= 2 for spatial discretization, got {N_y}")
        # dy is the finite-difference spacing handed to the PDE kernels; a
        # zero/negative (or NaN) gap would make it non-physical.
        if not gap_width > 0:
            raise ValueError(f"gap_width must be > 0, got {gap_width}")

        self.N_y = N_y
        self.gap_width = gap_width
        self.dy = gap_width / (N_y - 1)

        # Add non-local specific parameter
        self._add_nonlocal_parameters()

        # Storage for fluidity field trajectory, filled by the PDE simulators
        # whenever they run on concrete (non-traced) arrays.
        self._f_field_trajectory: np.ndarray | None = None

    def _add_nonlocal_parameters(self):
        """Register the cooperativity length ``xi`` on ``self.parameters``."""
        # xi: Cooperativity length (m)
        self.parameters.add(
            name="xi",
            value=1e-5,
            bounds=(1e-9, 1e-3),
            units="m",
            description="Cooperativity length (non-local diffusion scale)",
        )

    # The nonlocal PDE RHS (fluidity_nonlocal_creep_pde_rhs /
    # fluidity_nonlocal_pde_rhs) uses HB aging via f_loc(σ; tau_y, K, n_flow)
    # and does NOT include the rejuvenation term a·|γ̇|^n·(f_inf - f) that
    # the local ODE carries. Consequently a, n_rejuv, and f_inf never enter
    # the nonlocal residual. ξ only matters when the f-field develops
    # spatial variation; with a uniform initial condition and Neumann BCs
    # the field stays uniform and ∇²f ≡ 0 for all the single-protocol
    # benchmarks in examples/fluidity. The identifiability map below
    # reflects this — it OVERRIDES FluidityBase._IDENTIFIABILITY.
    _IDENTIFIABILITY = {
        "flow_curve": {
            # HB steady state σ = τ_y + K·γ̇^n — transient parameters inert.
            "identifiable": ("tau_y", "K", "n_flow"),
            "product_degenerate": (),
            "inactive": ("G", "f_eq", "f_inf", "theta", "a", "n_rejuv", "xi"),
        },
        "startup": {
            # Rate-controlled: elastic backbone G enters via dΣ/dt = G(γ̇ - Σf);
            # f_loc uses HB params; θ sets relaxation. f_eq sets only initial
            # f-field (decays in ~θ); rejuvenation terms absent; field stays
            # uniform so ξ inert.
            "identifiable": ("G", "tau_y", "K", "n_flow", "theta"),
            "product_degenerate": (),
            "inactive": ("f_eq", "f_inf", "a", "n_rejuv", "xi"),
        },
        "relaxation": {
            # Stress decays via dΣ/dt = -G·Σ·f_avg with f relaxing toward
            # f_loc(σ). All three HB params shape the late-time plateau; θ
            # sets the decay rate.
            "identifiable": ("G", "tau_y", "K", "n_flow", "theta"),
            "product_degenerate": (),
            "inactive": ("f_eq", "f_inf", "a", "n_rejuv", "xi"),
        },
        "creep": {
            # dγ/dt = σ·f_avg. Constant σ ⇒ G drops out entirely (no Maxwell
            # equation being integrated). f_eq only sets f(t0) which decays
            # to f_loc in ~θ; rejuvenation absent; field uniform ⇒ ξ inert.
            "identifiable": ("tau_y", "K", "n_flow", "theta"),
            "product_degenerate": (),
            "inactive": ("G", "f_eq", "f_inf", "a", "n_rejuv", "xi"),
        },
        "oscillation": {
            # SAOS linearisation around f_eq gives the Maxwell-like response
            # G'(ω), G''(ω) with τ_eff = 1/(G·f_eq). _predict_saos_jit uses
            # only G and f_eq (its theta argument is discarded) and
            # _fit_oscillation optimises exactly that reduced (G, f_eq) set,
            # so θ and the HB parameters never enter the SAOS residual.
            "identifiable": ("G", "f_eq"),
            "product_degenerate": (),
            "inactive": (
                "tau_y",
                "K",
                "n_flow",
                "theta",
                "f_inf",
                "a",
                "n_rejuv",
                "xi",
            ),
        },
        "laos": {
            # Large-amplitude oscillation excites the full nonlinear response:
            # HB params and θ all active alongside G.
            "identifiable": ("G", "tau_y", "K", "n_flow", "theta"),
            "product_degenerate": (),
            "inactive": ("f_eq", "f_inf", "a", "n_rejuv", "xi"),
        },
    }

    def _fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ) -> FluidityNonlocal:
        """Fit Non-Local Fluidity model to data.

        Args:
            X: Independent variable (time, frequency, or shear rate)
            y: Dependent variable (stress, modulus, viscosity)
            **kwargs: Optimizer and protocol options. ``test_mode`` selects the
                protocol; when omitted, a previously stored ``self._test_mode``
                is reused.

        Returns:
            ``self``, with ``fitted_`` set to True.

        Raises:
            ValueError: If no test mode is available or it is unsupported.
        """
        test_mode = kwargs.get("test_mode")
        if test_mode is None:
            test_mode = getattr(self, "_test_mode", None)
        if test_mode is None:
            raise ValueError("test_mode must be specified for Fluidity fitting")

        # FL-001: Normalize aliases early so self._test_mode is canonical
        if test_mode == "saos":
            test_mode = "oscillation"

        with log_fit(logger, model="FluidityNonlocal", data_shape=X.shape) as ctx:
            self._test_mode = cast(str, test_mode)
            ctx["test_mode"] = test_mode
            ctx["N_y"] = self.N_y

            if test_mode in ("steady_shear", "rotation", "flow_curve"):
                self._fit_flow_curve(X, y, **kwargs)
            elif test_mode in ("startup", "relaxation", "creep"):
                self._fit_transient(X, y, mode=test_mode, **kwargs)
            elif test_mode == "oscillation":
                self._fit_oscillation(X, y, **kwargs)
            elif test_mode == "laos":
                self._fit_laos(X, y, **kwargs)
            else:
                raise ValueError(f"Unsupported test_mode: {test_mode}")

        self.fitted_ = True
        return self

    # =========================================================================
    # Grid and Initial Conditions
    # =========================================================================

    def _get_grid_args(self, params: dict | None = None) -> dict:
        """Get grid-related arguments for PDE solver.

        Args:
            params: Optional parameter dictionary (defaults to the current
                parameter values).

        Returns:
            Dictionary with ``N_y``, ``dy`` and ``xi`` (``xi`` falls back to
            1e-5 when absent from ``params``).
        """
        if params is None:
            params = self.get_parameter_dict()

        return {
            "N_y": self.N_y,
            "dy": self.dy,
            "xi": params.get("xi", 1e-5),
        }

    def _get_initial_f_field(
        self, f_init: float | None = None, N_y: int | None = None
    ) -> jnp.ndarray:
        """Get initial fluidity field (uniform across gap).

        Args:
            f_init: Initial fluidity value. If None, uses
                ``self.get_initial_fluidity()``.
            N_y: Number of grid points override. If None, uses self.N_y.

        Returns:
            Fluidity field array, shape (N_y,)
        """
        if f_init is None:
            f_init = self.get_initial_fluidity()
        n = N_y if N_y is not None else self.N_y
        return jnp.ones(n) * f_init

    def _get_initial_state(
        self,
        mode: str,
        params: dict,
        sigma_0: float | None = None,
        N_y: int | None = None,
    ) -> jnp.ndarray:
        """Get initial state vector for PDE solver.

        State vector: [Σ (or γ for creep), f[0], f[1], ..., f[N_y-1]]

        Args:
            mode: 'startup', 'relaxation', 'creep', or 'laos'
            params: Parameter dictionary (reads ``f_eq``, ``f_inf`` and, for
                relaxation without ``sigma_0``, ``tau_y``)
            sigma_0: Initial stress for relaxation (defaults to ``tau_y``)
            N_y: Number of grid points override. If None, uses self.N_y.

        Returns:
            Initial state vector of length ``N_y + 1``
        """
        f_eq = params["f_eq"]
        f_inf = params["f_inf"]

        if mode == "creep":
            # State: [γ, f_field] - strain starts at 0
            f_field = self._get_initial_f_field(f_eq, N_y=N_y)
            return jnp.concatenate([jnp.array([0.0]), f_field])
        elif mode == "relaxation":
            # State: [Σ, f_field] - stress starts at sigma_0, f at f_inf
            sigma_init = sigma_0 if sigma_0 is not None else params["tau_y"]
            f_field = self._get_initial_f_field(f_inf, N_y=N_y)  # Just flowed
            return jnp.concatenate([jnp.array([sigma_init]), f_field])
        else:  # startup or laos
            # State: [Σ, f_field] - stress at 0, f at f_eq
            f_field = self._get_initial_f_field(f_eq, N_y=N_y)
            return jnp.concatenate([jnp.array([0.0]), f_field])

    def _pop_fit_grid(self, kwargs: dict) -> tuple[int, float]:
        """Pop ``fit_N_y`` from *kwargs* and return the coarse fitting grid.

        FL-003: the coarse grid is returned as local values instead of being
        written to ``self.N_y``/``self.dy``, so fitting never mutates the grid
        used for prediction (and concurrent access cannot observe a
        half-updated grid).

        Args:
            kwargs: The calling ``_fit_*`` method's kwargs; ``fit_N_y`` is
                removed from it in place so it never reaches the optimizer.

        Returns:
            ``(fit_N_y, fit_dy)``; ``fit_N_y`` defaults to ``min(self.N_y, 32)``.

        Raises:
            ValueError: If ``fit_N_y < 2`` (the spacing would divide by zero).
        """
        fit_N_y = kwargs.pop("fit_N_y", min(self.N_y, 32))
        if fit_N_y < 2:
            raise ValueError(
                f"fit_N_y must be >= 2 for spatial discretization, got {fit_N_y}"
            )
        return fit_N_y, self.gap_width / (fit_N_y - 1)

    # =========================================================================
    # Flow Curve (Steady State)
    # =========================================================================

    def _fit_flow_curve(
        self, gamma_dot: np.ndarray, stress: np.ndarray, **kwargs
    ) -> None:
        """Fit steady-state flow curve.

        For homogeneous (non-banding) steady state, uses HB:
        σ = τ_y + K*|γ̇|^n

        Args:
            gamma_dot: Shear rate array (1/s)
            stress: Shear stress array (Pa)
            **kwargs: Optimizer options (``use_log_residuals`` defaults to True)
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)
        stress_jax = jnp.asarray(stress, dtype=jnp.float64)

        def model_fn(x_data, params):
            p_map = dict(zip(self.parameters.keys(), params, strict=True))
            return fluidity_nonlocal_steady_state(
                x_data,
                p_map["G"],
                p_map["tau_y"],
                p_map["K"],
                p_map["n_flow"],
                p_map["f_eq"],
                p_map["f_inf"],
                p_map["theta"],
            )

        objective = create_least_squares_objective(
            model_fn,
            gamma_dot_jax,
            stress_jax,
            use_log_residuals=kwargs.get("use_log_residuals", True),
        )

        # FL-006: Pop protocol/meta kwargs before forwarding to nlsq_optimize
        nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _NLSQ_RESERVED}
        result = nlsq_optimize(objective, self.parameters, **nlsq_kwargs)

        if not result.success:
            logger.warning(f"Fluidity flow curve fit warning: {result.message}")

    # FL-013: _predict_flow_curve is not used by _predict() or model_function()
    # (flow curve routing goes through fluidity_nonlocal_steady_state directly).
    # Kept as a thin compatibility wrapper for external callers.
    def _predict_flow_curve(self, gamma_dot: np.ndarray) -> np.ndarray:
        """Predict steady-state flow curve (compatibility wrapper)."""
        return np.array(self._predict(gamma_dot, test_mode="flow_curve"))

    # =========================================================================
    # Transient Protocols (Startup, Relaxation, Creep)
    # =========================================================================

    def _fit_transient(self, t: np.ndarray, y: np.ndarray, mode: str, **kwargs) -> None:
        """Fit transient response using PDE solver.

        Args:
            t: Time array (s)
            y: Response data (stress for startup/relaxation, strain for creep)
            mode: 'startup', 'relaxation', or 'creep'
            **kwargs: Protocol inputs (``gamma_dot`` for startup,
                ``sigma_applied`` for creep, ``sigma_0`` for relaxation,
                ``fit_N_y`` for the coarse fitting grid) and optimizer options.

        Raises:
            ValueError: If the protocol input required by ``mode`` is missing
                or ``fit_N_y < 2``.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        # Keep a complex dtype if one is supplied instead of silently
        # discarding the imaginary part.
        y_arr = np.asarray(y)
        if np.iscomplexobj(y_arr):
            y_jax = jnp.asarray(y_arr, dtype=jnp.complex128)
        else:
            y_jax = jnp.asarray(y_arr, dtype=jnp.float64)

        # Extract protocol-specific inputs
        gamma_dot = kwargs.pop("gamma_dot", None)
        sigma_applied = kwargs.pop("sigma_applied", None)
        sigma_0 = kwargs.pop("sigma_0", None)

        # FL-003: coarse fitting grid as local values (self is not mutated)
        fit_N_y, fit_dy = self._pop_fit_grid(kwargs)

        if mode == "startup" and gamma_dot is None:
            raise ValueError("startup mode requires gamma_dot in kwargs")
        if mode == "creep" and sigma_applied is None:
            raise ValueError("creep mode requires sigma_applied in kwargs")

        # Store for prediction
        self._gamma_dot_applied = gamma_dot
        self._sigma_applied = sigma_applied
        self._sigma_0 = sigma_0

        def model_fn(x_data, params):
            p_map = dict(zip(self.parameters.keys(), params, strict=True))
            return self._simulate_pde(
                x_data,
                p_map,
                mode,
                gamma_dot,
                sigma_applied,
                sigma_0,
                N_y=fit_N_y,
                dy=fit_dy,
            )

        # See FluidityLocal._fit_transient for the rationale: relative
        # residuals (normalize=True) blow up at the zero crossings /
        # zero starting points of creep and startup data.
        objective = create_least_squares_objective(
            model_fn,
            t_jax,
            y_jax,
            normalize=False,
            use_log_residuals=kwargs.get("use_log_residuals", False),
        )

        # Keep "method" so it reaches nlsq_optimize. Transient protocols use
        # a diffrax ODE (custom_vjp) — default to scipy if caller didn't pick.
        nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _NLSQ_RESERVED_ODE}
        nlsq_kwargs.setdefault("method", "scipy")
        result = nlsq_optimize(objective, self.parameters, **nlsq_kwargs)

        if not result.success:
            logger.warning(f"Fluidity transient fit warning: {result.message}")

    @staticmethod
    def _solve_pde(term, t, y0, args: dict, max_steps: int):
        """Integrate a discretised fluidity PDE and save the state at ``t``.

        Shared by the transient and LAOS simulators so both use identical
        solver settings: explicit Dopri5 with an adaptive PID step controller
        (rtol=1e-5, atol=1e-7), initial step ``(t[-1] - t[0]) /
        max(len(t), 1000)``, and ``throw=False`` so a failed solve returns a
        partial solution (callers replace it with NaN) instead of raising
        inside the optimizer.

        Args:
            term: ``diffrax.ODETerm`` wrapping the PDE right-hand side.
            t: Save times; the first and last entries are the integration span.
            y0: Initial state vector.
            args: Argument dict forwarded to the RHS.
            max_steps: Upper bound on solver steps.

        Returns:
            The diffrax solution object.
        """
        t0 = t[0]
        t1 = t[-1]
        dt0 = (t1 - t0) / max(len(t), 1000)
        return diffrax.diffeqsolve(
            term,
            diffrax.Dopri5(),
            t0,
            t1,
            dt0,
            y0,
            args=args,
            saveat=diffrax.SaveAt(ts=t),
            stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
            max_steps=max_steps,
            throw=False,  # Return partial result on failure (for optimization)
        )

    def _store_f_trajectory(self, ys) -> None:
        """Cache the fluidity columns ``ys[:, 1:]`` for shear-banding analysis.

        Under JAX tracing (e.g. NLSQ Jacobians or NUTS) ``ys`` cannot be
        converted to NumPy; that case is expected and skipped silently
        (FL-008 — JAX's tracer-conversion errors are ``TypeError``
        subclasses). Any other failure is logged rather than swallowed
        (FL-007).
        """
        try:
            self._f_field_trajectory = np.array(ys[:, 1:])
        except (TypeError, jax.errors.ConcretizationTypeError):
            # Traced values: nothing concrete to store.
            pass
        except Exception as e:
            logger.warning("Could not store fluidity field trajectory: %s", e)

    def _simulate_pde(
        self,
        t: jnp.ndarray,
        params: dict,
        mode: str,
        gamma_dot: float | None,
        sigma_applied: float | None,
        sigma_0: float | None,
        N_y: int | None = None,
        dy: float | None = None,
    ) -> jnp.ndarray:
        """Simulate PDE response using Diffrax.

        Args:
            t: Time array
            params: Parameter dictionary
            mode: 'startup', 'relaxation', or 'creep'
            gamma_dot: Applied shear rate (for startup; None → 0.0)
            sigma_applied: Applied stress (for creep; None → 0.0)
            sigma_0: Initial stress (for relaxation; None → tau_y)
            N_y: Grid points override (FL-003 thread safety). If None, uses self.N_y.
            dy: Grid spacing override (FL-003 thread safety). If None, uses self.dy.

        Returns:
            Primary output (stress for startup/relaxation, strain for creep);
            all NaN if the solver did not finish successfully.
        """
        # FL-003: Use local variables instead of self.N_y/self.dy for thread safety
        N_y_local = N_y if N_y is not None else self.N_y
        dy_local = dy if dy is not None else self.dy

        # Build args for PDE RHS
        # FL-012: Removed dead "N_y" key — PDE kernels infer N_y from state vector shape
        args = {
            "G": params["G"],
            "tau_y": params["tau_y"],
            "K": params["K"],
            "n_flow": params["n_flow"],
            "theta": params["theta"],
            "xi": params.get("xi", 1e-5),
            "dy": dy_local,
        }

        # Mode-specific setup
        if mode == "creep":
            pde_fn = fluidity_nonlocal_creep_pde_rhs
            args["sigma_applied"] = sigma_applied if sigma_applied is not None else 0.0
        else:
            pde_fn = fluidity_nonlocal_pde_rhs
            args["mode"] = 0  # rate_controlled
            # Startup imposes the requested rate; relaxation holds γ̇ = 0.
            if mode == "startup":
                args["gamma_dot"] = gamma_dot if gamma_dot is not None else 0.0
            else:
                args["gamma_dot"] = 0.0

        # Initial state (uses N_y_local for grid size)
        y0 = self._get_initial_state(mode, params, sigma_0, N_y=N_y_local)

        # jax.checkpoint trades recomputation for memory when differentiating
        # through the RHS.
        term = diffrax.ODETerm(
            jax.checkpoint(lambda ti, yi, args_i: pde_fn(cast(float, ti), yi, args_i))
        )
        sol = self._solve_pde(term, t, y0, args, max_steps=10_000_000)

        # Store trajectory for analysis (skipped under JAX tracing)
        self._store_f_trajectory(sol.ys)

        # Extract primary variable (index 0)
        # For creep: strain; for startup/relaxation: stress
        result = sol.ys[:, 0]

        # Handle solver failure by returning NaN (optimization will avoid this)
        return jnp.where(sol.result == diffrax.RESULTS.successful, result, jnp.nan)

    def _predict_transient(
        self,
        t: np.ndarray,
        mode: str | None = None,
        sigma_0: float | None = None,
        gamma_dot: Any = _MISSING,
        sigma_applied: Any = _MISSING,
    ) -> np.ndarray:
        """Predict transient response.

        Protocol inputs (``gamma_dot`` for startup, ``sigma_applied`` for
        creep, ``sigma_0`` for relaxation) are read from keyword arguments
        when supplied so ``predict()`` works without a prior ``fit()``.
        Any argument left as ``_MISSING`` (or ``sigma_0=None``) falls back
        to the instance attribute populated by ``_fit_*`` (legacy path).

        Raises:
            ValueError: If no mode is given and none is stored.
        """
        t_jax = jnp.asarray(t, dtype=jnp.float64)
        p = self.get_parameter_dict()

        mode = mode if mode is not None else self._test_mode
        if mode is None:
            raise ValueError("Test mode not specified for prediction")

        if gamma_dot is _MISSING:
            gamma_dot = getattr(self, "_gamma_dot_applied", None)
        if sigma_applied is _MISSING:
            sigma_applied = getattr(self, "_sigma_applied", None)
        if sigma_0 is None:
            sigma_0 = getattr(self, "_sigma_0", None)

        result = self._simulate_pde(
            t_jax,
            p,
            mode,
            gamma_dot,
            sigma_applied,
            sigma_0,
        )
        return np.array(result)

    # =========================================================================
    # Shear Banding Analysis
    # =========================================================================

    def get_fluidity_profile(self, time_idx: int = -1) -> np.ndarray:
        """Get fluidity profile at specified time index.

        Args:
            time_idx: Time index (default -1 for final time)

        Returns:
            Fluidity field across the gap from the last stored simulation,
            shape (grid points of that simulation,)

        Raises:
            ValueError: If no trajectory has been stored yet.
        """
        if self._f_field_trajectory is None:
            raise ValueError("No trajectory available. Run simulation first.")

        return self._f_field_trajectory[time_idx]

    def get_shear_banding_metric(self, f_field: np.ndarray | None = None) -> float:
        """Compute coefficient of variation as shear banding metric.

        CV = std(f) / mean(f)
        CV > 0.3 typically indicates significant shear banding.

        Args:
            f_field: Fluidity field. If None, uses final simulation state.

        Returns:
            Coefficient of variation (dimensionless)
        """
        if f_field is None:
            f_field = self.get_fluidity_profile(-1)

        f_jax = jnp.asarray(f_field, dtype=jnp.float64)
        return float(shear_banding_cv(f_jax))

    def get_banding_ratio(self, f_field: np.ndarray | None = None) -> float:
        """Compute max/min fluidity ratio as banding metric.

        ratio > 10 indicates strong localization.

        Args:
            f_field: Fluidity field. If None, uses final simulation state.

        Returns:
            Banding ratio (dimensionless)
        """
        if f_field is None:
            f_field = self.get_fluidity_profile(-1)

        f_jax = jnp.asarray(f_field, dtype=jnp.float64)
        return float(banding_ratio(f_jax))

    def is_banding(
        self, f_field: np.ndarray | None = None, cv_threshold: float = 0.3
    ) -> bool:
        """Check if shear banding is occurring.

        Args:
            f_field: Fluidity field. If None, uses final simulation state.
            cv_threshold: CV threshold for banding (default 0.3)

        Returns:
            True if CV > threshold
        """
        return self.get_shear_banding_metric(f_field) > cv_threshold

    # =========================================================================
    # Oscillatory Protocols
    # =========================================================================

    def _fit_oscillation(self, X: np.ndarray, y: np.ndarray, **kwargs) -> None:
        """Fit SAOS data using linear viscoelastic approximation.

        For small amplitude, bulk response approximates Local model. Only G
        and f_eq affect the Maxwell-limit residual; optimizing the full
        parameter set produces a rank-2 Jacobian and causes NLSQ to
        terminate prematurely on xtol. We fit the reduced (G, f_eq) set.

        Args:
            X: Angular frequencies ω (rad/s)
            y: Complex G* or an (M, 2) array of [G', G'']
            **kwargs: Optimizer options

        Raises:
            ValueError: If ``y`` is neither complex nor shaped (M, 2).
            RuntimeError: If the optimizer leaves G or f_eq unset.
        """
        from rheojax.core.parameters import ParameterSet
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        omega_jax = jnp.asarray(X, dtype=jnp.float64)

        # Handle G_star format
        G_star_np = np.asarray(y)
        if np.iscomplexobj(G_star_np):
            G_prime_np = np.real(G_star_np)
            G_dp_np = np.imag(G_star_np)
            G_star_2d = np.column_stack([G_prime_np, G_dp_np])
        elif G_star_np.ndim == 2 and G_star_np.shape[1] == 2:
            G_prime_np = G_star_np[:, 0]
            G_dp_np = G_star_np[:, 1]
            G_star_2d = G_star_np
        else:
            raise ValueError(f"G_star must be complex or (M, 2), got {G_star_np.shape}")

        G_star_jax = jnp.asarray(G_star_2d, dtype=jnp.float64)

        # Data-driven warm-start for G and f_eq.
        self._seed_saos_from_data(np.asarray(X, dtype=float), G_prime_np, G_dp_np)

        reduced = ParameterSet()
        for name in ("G", "f_eq"):
            src = self.parameters[name]
            reduced.add(
                name=name,
                value=src.value,
                bounds=src.bounds,
                units=src.units,
                description=src.description,
            )

        def model_fn(x_data, params):
            p_map = dict(zip(reduced.keys(), params, strict=True))
            return self._predict_saos_jit(
                x_data,
                p_map["G"],
                p_map["f_eq"],
            )

        objective = create_least_squares_objective(
            model_fn,
            omega_jax,
            G_star_jax,
            normalize=True,
        )

        # FL-006: Pop protocol/meta kwargs before forwarding to nlsq_optimize
        nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _NLSQ_RESERVED}
        result = nlsq_optimize(objective, reduced, **nlsq_kwargs)

        if not result.success:
            logger.warning(f"Fluidity SAOS fit warning: {result.message}")

        G_fit = reduced["G"].value
        f_eq_fit = reduced["f_eq"].value
        if G_fit is None or f_eq_fit is None:
            raise RuntimeError("NLSQ returned no value for SAOS parameters")
        self.parameters.set_value("G", float(G_fit))
        self.parameters.set_value("f_eq", float(f_eq_fit))

    def _seed_saos_from_data(
        self,
        omega: np.ndarray,
        G_prime: np.ndarray,
        G_double_prime: np.ndarray,
    ) -> None:
        """Seed G and f_eq from SAOS data so NLSQ starts near the minimum.

        G ← high-ω G' plateau (max of the top fifth of G');
        tau_eff ← 1/ω* at the G'/G'' crossover (else 1/ω at the G'' peak);
        f_eq ← 1/(G·tau_eff). All seeds clipped to bounds.

        This is best effort: points with ω <= 0 or non-finite values are
        ignored, and the current parameter values are kept when fewer than
        two usable points remain or the seeds are not finite.
        """
        omega = np.asarray(omega, dtype=float)
        Gp = np.asarray(G_prime, dtype=float)
        Gpp = np.asarray(G_double_prime, dtype=float)

        # Drop points that would poison the log-space interpolation
        # (log(ω <= 0) is -inf/NaN; NaN/inf moduli propagate into the seeds).
        valid = np.isfinite(omega) & (omega > 0) & np.isfinite(Gp) & np.isfinite(Gpp)
        omega, Gp, Gpp = omega[valid], Gp[valid], Gpp[valid]
        if omega.size < 2:
            return

        order = np.argsort(omega)
        omega_s, Gp_s, Gpp_s = omega[order], Gp[order], Gpp[order]

        n_top = max(1, len(Gp_s) // 5)
        G_seed = float(np.max(Gp_s[-n_top:]))

        with np.errstate(divide="ignore", invalid="ignore"):
            diff = np.log(np.maximum(Gp_s, 1e-300)) - np.log(np.maximum(Gpp_s, 1e-300))
        sign_change = np.where(np.diff(np.sign(diff)) != 0)[0]
        if sign_change.size > 0:
            # Linear interpolation of log(G'/G'') = 0 in log(ω).
            i = int(sign_change[0])
            d0, d1 = diff[i], diff[i + 1]
            w0, w1 = np.log(omega_s[i]), np.log(omega_s[i + 1])
            if d1 != d0:
                w_cross = w0 + (0.0 - d0) * (w1 - w0) / (d1 - d0)
            else:
                w_cross = 0.5 * (w0 + w1)
            tau_seed = float(np.exp(-w_cross))
        else:
            i_peak = int(np.argmax(Gpp_s))
            tau_seed = 1.0 / float(omega_s[i_peak])

        f_eq_seed = 1.0 / max(G_seed * tau_seed, 1e-30)
        if not (np.isfinite(G_seed) and np.isfinite(f_eq_seed)):
            return

        def _clipped(name: str, value: float) -> float:
            param = self.parameters[name]
            lo, hi = param.bounds if param.bounds else (-np.inf, np.inf)
            lo_v = lo if lo is not None else -np.inf
            hi_v = hi if hi is not None else np.inf
            return float(np.clip(value, lo_v, hi_v))

        self.parameters.set_value("G", _clipped("G", G_seed))
        self.parameters.set_value("f_eq", _clipped("f_eq", f_eq_seed))

    # TODO (FL-010): _predict_saos_jit is duplicated in FluidityLocal.
    # Consider extracting to a shared module-level function or into _base.py.
    @staticmethod
    @jax.jit
    def _predict_saos_jit(
        omega: jnp.ndarray,
        G: float,
        f_eq: float,
        theta: float = 0.0,  # FL-005: dead parameter, kept for backward compatibility
    ) -> jnp.ndarray:
        """SAOS prediction using linear viscoelastic approximation.

        Maxwell response with τ_eff = 1/(G·f_eq):
        G' = G(ωτ)²/(1+(ωτ)²), G'' = G(ωτ)/(1+(ωτ)²).

        Note: theta parameter is unused (FL-005) but kept for backward
        compatibility with external callers.

        Returns:
            Array of shape (len(omega), 2) holding [G', G''].
        """
        del theta  # FL-005: explicitly unused
        tau_eff = 1.0 / (G * f_eq + 1e-30)
        omega_tau = omega * tau_eff
        denom = 1.0 + omega_tau**2

        G_prime = G * omega_tau**2 / denom
        G_double_prime = G * omega_tau / denom

        return jnp.stack([G_prime, G_double_prime], axis=1)

    def _fit_laos(self, t: np.ndarray, sigma: np.ndarray, **kwargs) -> None:
        """Fit LAOS data using full PDE integration.

        Args:
            t: Time array (s)
            sigma: Measured stress (Pa)
            **kwargs: Must include ``gamma_0`` and ``omega``; optional
                ``f_init`` (initial fluidity profile) and ``fit_N_y``; the
                rest are optimizer options.

        Raises:
            ValueError: If ``gamma_0``/``omega`` are missing or ``fit_N_y < 2``.
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        gamma_0 = kwargs.pop("gamma_0", None)
        omega = kwargs.pop("omega", None)
        f_init = kwargs.pop("f_init", None)

        if gamma_0 is None or omega is None:
            raise ValueError("LAOS fitting requires gamma_0 and omega")

        self._gamma_0 = gamma_0
        self._omega_laos = omega
        self._laos_f_init = f_init

        # FL-003: coarse fitting grid as local values (self is not mutated)
        fit_N_y, fit_dy = self._pop_fit_grid(kwargs)

        t_jax = jnp.asarray(t, dtype=jnp.float64)
        sigma_jax = jnp.asarray(sigma, dtype=jnp.float64)

        def model_fn(x_data, params):
            p_map = dict(zip(self.parameters.keys(), params, strict=True))
            _, stress = self._simulate_laos_internal(
                x_data,
                p_map,
                gamma_0,
                omega,
                N_y=fit_N_y,
                dy=fit_dy,
                f_init=f_init,
            )
            return stress

        # See FluidityLocal._fit_laos for the rationale: LAOS stress
        # crosses zero, so relative residuals (normalize=True) blow up
        # at the zero crossings and pull the optimizer away from the
        # true parameters.
        objective = create_least_squares_objective(
            model_fn,
            t_jax,
            sigma_jax,
            normalize=False,
        )

        # Keep "method" so it reaches nlsq_optimize. LAOS uses a diffrax ODE
        # (custom_vjp) — default to scipy if caller didn't pick a method.
        nlsq_kwargs = {k: v for k, v in kwargs.items() if k not in _NLSQ_RESERVED_ODE}
        nlsq_kwargs.setdefault("method", "scipy")
        result = nlsq_optimize(objective, self.parameters, **nlsq_kwargs)

        if not result.success:
            logger.warning(f"Fluidity LAOS fit warning: {result.message}")

    def _simulate_laos_internal(
        self,
        t: jnp.ndarray,
        params: dict,
        gamma_0: float,
        omega: float,
        N_y: int | None = None,
        dy: float | None = None,
        f_init: np.ndarray | None = None,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Simulate LAOS response using PDE solver.

        Args:
            t: Time array
            params: Parameter dictionary
            gamma_0: Strain amplitude
            omega: Angular frequency
            N_y: Grid points override (FL-003 thread safety). If None, uses self.N_y.
            dy: Grid spacing override (FL-003 thread safety). If None, uses self.dy.
            f_init: Optional initial fluidity profile; linearly interpolated
                onto the grid if its length differs and floored at 1e-20.

        Returns:
            ``(strain, stress)`` where strain = γ₀ sin(ωt) and stress is NaN
            everywhere if the solver failed.
        """
        # FL-003: Use local variables instead of self.N_y/self.dy for thread safety
        N_y_local = N_y if N_y is not None else self.N_y
        dy_local = dy if dy is not None else self.dy

        # Base args
        # FL-012: Removed dead "N_y" key — PDE kernels infer N_y from state vector shape
        base_args = {
            "G": params["G"],
            "tau_y": params["tau_y"],
            "K": params["K"],
            "n_flow": params["n_flow"],
            "theta": params["theta"],
            "xi": params.get("xi", 1e-5),
            "dy": dy_local,
            "mode": 0,  # rate_controlled
        }

        # Initial state: use provided profile to seed spatial gradients,
        # or fall back to uniform f_eq (gives ∇²f≡0, identical to local model).
        if f_init is not None:
            f_field_init = jnp.asarray(f_init, dtype=jnp.float64)
            if len(f_field_init) != N_y_local:
                x_src = np.linspace(0, 1, len(f_field_init))
                x_dst = np.linspace(0, 1, N_y_local)
                f_field_init = jnp.asarray(
                    np.interp(x_dst, x_src, np.asarray(f_field_init))
                )
            y0 = jnp.concatenate(
                [jnp.array([0.0]), jnp.maximum(f_field_init, 1e-20)]
            )
        else:
            y0 = self._get_initial_state("laos", params, N_y=N_y_local)

        # PDE with time-varying gamma_dot = d/dt[γ₀ sin(ωt)]
        def laos_pde(ti, yi, args_i):
            gamma_dot_t = gamma_0 * omega * jnp.cos(omega * ti)
            args_with_rate = {**args_i, "gamma_dot": gamma_dot_t}
            return fluidity_nonlocal_pde_rhs(ti, yi, args_with_rate)

        term = diffrax.ODETerm(jax.checkpoint(laos_pde))
        sol = self._solve_pde(term, t, y0, base_args, max_steps=16_000_000)

        stress = sol.ys[:, 0]
        strain = gamma_0 * jnp.sin(omega * t)

        # Handle solver failure by returning NaN
        stress = jnp.where(sol.result == diffrax.RESULTS.successful, stress, jnp.nan)

        # Store trajectory only when not tracing (concrete arrays)
        self._store_f_trajectory(sol.ys)

        return strain, stress

    def simulate_laos(
        self,
        gamma_0: float,
        omega: float,
        n_cycles: int = 2,
        n_points_per_cycle: int = 256,
        f_init: np.ndarray | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Simulate LAOS response.

        Args:
            gamma_0: Strain amplitude
            omega: Angular frequency (rad/s), must be > 0
            n_cycles: Number of oscillation cycles (>= 1)
            n_points_per_cycle: Points per cycle (>= 1)
            f_init: Optional initial fluidity profile across the gap

        Returns:
            (strain, stress) arrays

        Raises:
            ValueError: If ``omega`` is not positive or the sampling is empty.
        """
        if not omega > 0:
            raise ValueError(f"omega must be > 0 for LAOS simulation, got {omega}")
        if n_cycles < 1 or n_points_per_cycle < 1:
            raise ValueError(
                "n_cycles and n_points_per_cycle must be >= 1, got "
                f"{n_cycles} and {n_points_per_cycle}"
            )

        self._gamma_0 = gamma_0
        self._omega_laos = omega

        period = 2.0 * np.pi / omega
        t_max = n_cycles * period
        n_points = n_cycles * n_points_per_cycle
        t = np.linspace(0, t_max, n_points, endpoint=False)
        t_jax = jnp.asarray(t, dtype=jnp.float64)

        p = self.get_parameter_dict()
        strain, stress = self._simulate_laos_internal(
            t_jax, p, gamma_0, omega, f_init=f_init
        )

        return np.array(strain), np.array(stress)

    # =========================================================================
    # Bayesian / Model Function Interface
    # =========================================================================

    def _kwarg_or_attr(self, kwargs: dict, key: str, attr: str) -> Any:
        """Return ``kwargs[key]`` if present (even when falsy), else ``self.<attr>``.

        FL-009: a sentinel distinguishes "not passed" from falsy values such
        as 0.0; a missing attribute yields None.
        """
        value = kwargs.get(key, _MISSING)
        return getattr(self, attr, None) if value is _MISSING else value

    def model_function(self, X, params, test_mode=None, **kwargs):
        """NumPyro/BayesianMixin model function.

        Accepts protocol-specific kwargs (gamma_dot, sigma_applied, sigma_0,
        gamma_0, omega, f_init); any not supplied fall back to the values
        stored by the last fit.
        """
        p_values = dict(zip(self.parameters.keys(), params, strict=True))
        mode = test_mode if test_mode is not None else self._test_mode
        if mode is None:
            mode = "oscillation"
        # FL-001: Normalize aliases
        if mode == "saos":
            mode = "oscillation"

        X_jax = jnp.asarray(X, dtype=jnp.float64)

        if mode in ("steady_shear", "rotation", "flow_curve"):
            return fluidity_nonlocal_steady_state(
                X_jax,
                p_values["G"],
                p_values["tau_y"],
                p_values["K"],
                p_values["n_flow"],
                p_values["f_eq"],
                p_values["f_inf"],
                p_values["theta"],
            )
        elif mode == "oscillation":
            return self._predict_saos_jit(
                X_jax,
                p_values["G"],
                p_values["f_eq"],
            )
        elif mode in ("startup", "relaxation", "creep"):
            # Use the same protocol inputs as the fit (incl. sigma_0 for
            # relaxation) so posterior and point estimates agree.
            return self._simulate_pde(
                X_jax,
                p_values,
                mode,
                self._kwarg_or_attr(kwargs, "gamma_dot", "_gamma_dot_applied"),
                self._kwarg_or_attr(kwargs, "sigma_applied", "_sigma_applied"),
                self._kwarg_or_attr(kwargs, "sigma_0", "_sigma_0"),
            )
        elif mode == "laos":
            gamma_0 = self._kwarg_or_attr(kwargs, "gamma_0", "_gamma_0")
            omega = self._kwarg_or_attr(kwargs, "omega", "_omega_laos")
            if gamma_0 is None or omega is None:
                raise ValueError("LAOS mode requires gamma_0 and omega")
            f_init = kwargs.get("f_init", getattr(self, "_laos_f_init", None))
            _, stress = self._simulate_laos_internal(
                X_jax, p_values, gamma_0, omega, f_init=f_init
            )
            return stress

        return jnp.zeros_like(X_jax)

    # =========================================================================
    # Prediction Interface
    # =========================================================================

    def _predict(self, X: np.ndarray, **kwargs: Any) -> np.ndarray:
        """Predict based on fitted state.

        ``test_mode`` comes from kwargs or the stored ``self._test_mode``;
        protocol inputs (gamma_dot, sigma_applied, sigma_0, gamma_0, omega,
        f_init) may likewise be passed as kwargs. Oscillation predictions are
        returned as complex G* = G' + iG''. An unrecognised mode logs a
        warning and returns zeros.

        Raises:
            ValueError: If no test mode is available, or LAOS inputs are missing.
        """
        X_jax = jnp.asarray(X, dtype=jnp.float64)
        p = self.get_parameter_dict()

        # Get test_mode from kwargs or instance attribute
        _kw_mode = kwargs.get("test_mode")
        test_mode = (
            _kw_mode if _kw_mode is not None else getattr(self, "_test_mode", None)
        )
        if test_mode is None:
            raise ValueError("test_mode must be specified for prediction")
        # FL-001: Normalize aliases
        if test_mode == "saos":
            test_mode = "oscillation"

        if test_mode in ("steady_shear", "rotation", "flow_curve"):
            result = fluidity_nonlocal_steady_state(
                X_jax,
                p["G"],
                p["tau_y"],
                p["K"],
                p["n_flow"],
                p["f_eq"],
                p["f_inf"],
                p["theta"],
            )
            return np.array(result)

        elif test_mode == "oscillation":
            result = self._predict_saos_jit(
                X_jax,
                p["G"],
                p["f_eq"],
            )
            # Convert (N,2) [G', G''] to complex G* for consistent API
            result = np.array(result)
            return result[:, 0] + 1j * result[:, 1]

        elif test_mode in ("startup", "relaxation", "creep"):
            return self._predict_transient(
                X,
                mode=test_mode,
                sigma_0=kwargs.get("sigma_0"),
                gamma_dot=kwargs.get("gamma_dot", _MISSING),
                sigma_applied=kwargs.get("sigma_applied", _MISSING),
            )

        elif test_mode == "laos":
            # Lazily fall back to stored values; never touch attributes that
            # may not exist when the caller supplies them explicitly.
            gamma_0 = self._kwarg_or_attr(kwargs, "gamma_0", "_gamma_0")
            omega = self._kwarg_or_attr(kwargs, "omega", "_omega_laos")
            f_init = kwargs.get("f_init", getattr(self, "_laos_f_init", None))
            if gamma_0 is None or omega is None:
                raise ValueError("LAOS prediction requires gamma_0 and omega")
            _, stress = self._simulate_laos_internal(
                X_jax, p, gamma_0, omega, f_init=f_init
            )
            return np.array(stress)

        logger.warning(
            "Unknown test_mode %r for FluidityNonlocal prediction; returning zeros",
            test_mode,
        )
        return np.zeros_like(X)