# Source code for rheojax.models.dmt.nonlocal_model

"""Nonlocal (1D) de Souza Mendes-Thompson (DMT) model.

Implements the spatially-resolved DMT model for shear banding analysis.
Adds fluidity diffusion to regularize the local model and allow for
heterogeneous flow profiles.

This model is appropriate for:
- Shear banding detection and characterization
- Gap-dependent rheology
- Cooperativity length scale estimation
"""

from __future__ import annotations

from typing import Literal

import numpy as np

from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger
from rheojax.models.dmt._base import DMTBase
from rheojax.models.dmt._kernels import (
    structure_evolution,
    viscosity_exponential,
    viscosity_herschel_bulkley_regularized,
)

# Safe JAX import
jax, jnp = safe_import_jax()

# Module logger
logger = get_logger(__name__)


@ModelRegistry.register(
    "dmt_nonlocal",
    # Only FLOW_CURVE is fully wired through both _fit and _predict.
    # _fit_startup raises NotImplementedError and _predict has no
    # startup/creep branches; advertising those protocols in the
    # registry caused downstream code (e.g. predict-without-fit canary)
    # to call into paths that raise "Unknown test_mode for prediction".
    protocols=[Protocol.FLOW_CURVE],
    deformation_modes=[DeformationMode.SHEAR],
)
class DMTNonlocal(DMTBase):
    r"""Nonlocal (1D) DMT model for shear banding analysis.

    Extends the local DMT model with spatial structure diffusion:

        ∂λ/∂t = (1-λ)/t_eq - a·λ·\|γ̇\|^c/t_eq + D_λ·∂²λ/∂y²

    The diffusion term introduces a cooperativity length scale:

        ξ ~ √(D_λ · t_eq)

    which regularizes the problem and sets the minimum width of shear bands.

    This model solves for:

    - λ(y, t): Structure field across the gap
    - v(y, t): Velocity profile (from momentum balance)
    - γ̇(y, t): Local shear rate

    On top of the parameters provided by the base class, the model registers
    one extra parameter, ``D_lambda`` (structure diffusion coefficient,
    m²/s).

    Parameters
    ----------
    closure : {"exponential", "herschel_bulkley"}, default "exponential"
        Viscosity closure type.
    include_elasticity : bool, default True
        Include Maxwell viscoelastic backbone.
    n_points : int, default 51
        Number of spatial grid points across the gap.
    gap_width : float, default 1e-3
        Gap width H [m] (e.g., for Couette cell).

    Attributes
    ----------
    n_points : int
        Spatial grid resolution
    gap_width : float
        Physical gap width [m]
    y : array
        Spatial coordinate array [m], ``n_points`` values from 0 to
        ``gap_width``
    dy : float
        Grid spacing ``gap_width / (n_points - 1)`` [m]

    Examples
    --------
    >>> from rheojax.models.dmt import DMTNonlocal
    >>>
    >>> # Create nonlocal model for banding analysis
    >>> model = DMTNonlocal(
    ...     closure="herschel_bulkley",
    ...     n_points=101,
    ...     gap_width=1e-3
    ... )
    >>>
    >>> # Simulate steady shear with banding
    >>> result = model.simulate_steady_shear(
    ...     gamma_dot_avg=10.0, t_end=1000.0
    ... )
    >>>
    >>> # Check for banding
    >>> banding_info = model.detect_banding(result)

    See Also
    --------
    DMTLocal : Local (0D) variant for homogeneous flow
    FluidityNonlocal : Simpler nonlocal fluidity model

    References
    ----------
    Coussot, P. et al. (2002). "Viscosity bifurcation in thixotropic,
    yielding fluids." J. Rheol. 46, 573-589.
    """
[docs] def __init__( self, closure: Literal["exponential", "herschel_bulkley"] = "exponential", include_elasticity: bool = True, n_points: int = 51, gap_width: float = 1e-3, ): """Initialize DMTNonlocal model.""" self.n_points = n_points self.gap_width = gap_width super().__init__(closure=closure, include_elasticity=include_elasticity) # Add nonlocal-specific parameters self._add_nonlocal_parameters() # Spatial grid self.y = np.linspace(0, gap_width, n_points) self.dy = gap_width / (n_points - 1) logger.info( "DMTNonlocal initialized", closure=closure, include_elasticity=include_elasticity, n_points=n_points, gap_width=gap_width, )
def _add_nonlocal_parameters(self): """Add parameters specific to nonlocal model.""" # D_λ: Structure diffusion coefficient self.parameters.add( name="D_lambda", value=1e-6, bounds=(1e-10, 1e-2), units="m²/s", description="Structure diffusion coefficient (cooperativity)", )
def get_cooperativity_length(self) -> float:
    """Return the cooperativity length implied by the current parameters.

    The length is defined as ξ = √(D_λ · t_eq), using the current values
    of the ``D_lambda`` and ``t_eq`` parameters.

    Returns
    -------
    float
        Cooperativity length [m]
    """
    lookup = self.parameters.get_value
    diffusivity = lookup("D_lambda")
    relaxation_time = lookup("t_eq")
    return np.sqrt(diffusivity * relaxation_time)
# ========================================================================= # Required Abstract Methods # ========================================================================= def _fit(self, X: np.ndarray, y: np.ndarray, **kwargs) -> DMTNonlocal: """Fit model to data.""" test_mode = kwargs.get("test_mode", "flow_curve") # P1-DMTNL-001: Cache test_mode for Bayesian _resolve_test_mode() fallback self._test_mode = test_mode if test_mode in ("flow_curve", "rotation"): return self._fit_flow_curve(X, y, **kwargs) elif test_mode == "startup": return self._fit_startup(X, y, **kwargs) else: raise ValueError(f"Unknown test_mode: {test_mode}") def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """Predict model response.""" test_mode = kwargs.get("test_mode", "flow_curve") if test_mode in ("flow_curve", "rotation"): return self._predict_flow_curve(X, **kwargs) else: raise ValueError(f"Unknown test_mode for prediction: {test_mode}") # ========================================================================= # Spatial Operators # ========================================================================= def _laplacian(self, field: jnp.ndarray) -> jnp.ndarray: """Compute 1D Laplacian with Neumann BCs (∂field/∂y = 0 at walls). Uses second-order central differences. Parameters ---------- field : array Field values at grid points (shape: n_points) Returns ------- array Laplacian ∂²field/∂y² at each point """ dy_sq = self.dy**2 # Interior points: central difference lap = jnp.zeros_like(field) lap = lap.at[1:-1].set((field[:-2] - 2 * field[1:-1] + field[2:]) / dy_sq) # Neumann BCs: ∂field/∂y = 0 at y=0 and y=H # Ghost point approach: field[-1] = field[1], field[N] = field[N-2] lap = lap.at[0].set(2 * (field[1] - field[0]) / dy_sq) lap = lap.at[-1].set(2 * (field[-2] - field[-1]) / dy_sq) return lap def _compute_shear_rate_from_velocity(self, v: jnp.ndarray) -> jnp.ndarray: """Compute shear rate from velocity profile. 
γ̇(y) = dv/dy Parameters ---------- v : array Velocity profile v(y) Returns ------- array Shear rate profile γ̇(y) """ gamma_dot = jnp.zeros_like(v) # Central difference for interior gamma_dot = gamma_dot.at[1:-1].set((v[2:] - v[:-2]) / (2 * self.dy)) # One-sided for boundaries gamma_dot = gamma_dot.at[0].set((v[1] - v[0]) / self.dy) gamma_dot = gamma_dot.at[-1].set((v[-1] - v[-2]) / self.dy) return gamma_dot # ========================================================================= # Steady Shear Simulation # =========================================================================
def simulate_steady_shear(
    self,
    gamma_dot_avg: float,
    t_end: float,
    dt: float = 0.1,
    lam_init: float | np.ndarray = 1.0,
) -> dict[str, np.ndarray]:
    """Simulate approach to steady state under imposed average shear rate.

    For planar Couette: v(0) = 0, v(H) = V_wall = γ̇_avg · H

    The stress is treated as uniform across the gap (low Reynolds number).
    At every step it is taken as the gap average of η·γ̇, the structure is
    advanced with an explicit Euler step (local kinetics + diffusion,
    clipped to [0, 1]), and the velocity profile is rebuilt from
    γ̇ = Σ/η and rescaled so that v(H) = V_wall.

    Parameters
    ----------
    gamma_dot_avg : float
        Average (imposed) shear rate [1/s]
    t_end : float
        Simulation end time [s]
    dt : float
        Time step [s]
    lam_init : float or array
        Initial structure. Any scalar (Python or NumPy, int or float) gives
        a uniform profile; an array must have shape ``(n_points,)``.

    Returns
    -------
    dict
        't': time array, ``arange(n_steps) * dt``
        'y': spatial grid
        'lam': structure profiles λ(y, t_i)
        'gamma_dot': shear rate profiles γ̇(y, t_i)
        'velocity': velocity profiles v(y, t_i)
        'stress': gap-averaged stress Σ(t)

        Profiles stored at index ``i`` are the state *before* step ``i`` is
        applied, so index 0 holds the initial condition.

    Raises
    ------
    ValueError
        If ``lam_init`` is an array whose shape is not ``(n_points,)``.
    """
    n_steps = int(t_end / dt)
    params = self.get_parameter_dict()
    n_points = self.n_points

    # Wall velocity for imposed average shear rate
    V_wall = gamma_dot_avg * self.gap_width

    # Initial structure. Multiplying by ones() broadcasts any scalar
    # (int, float, NumPy scalar) to the grid and promotes integer input to
    # the default float dtype, so the scan carry keeps a stable dtype.
    lam_seed = jnp.asarray(lam_init)
    if lam_seed.ndim != 0 and lam_seed.shape != (n_points,):
        raise ValueError(
            f"lam_init must be a scalar or have shape ({n_points},), "
            f"got shape {lam_seed.shape}"
        )
    lam = jnp.ones(n_points) * lam_seed

    # Initialize velocity (linear profile for homogeneous flow)
    v = jnp.linspace(0, V_wall, n_points)

    # Close over constants for lax.scan (avoid self references)
    dy = self.dy
    dy_sq = dy**2
    D_lambda = params["D_lambda"]
    t_eq = params["t_eq"]
    a_param = params["a"]
    c_param = params["c"]

    # NOTE(review): the structure update below is explicit in time with a
    # central-difference Laplacian; such schemes are usually only stable for
    # D_lambda * dt / dy**2 <= 1/2. The defaults (D_lambda=1e-6, dt=0.1,
    # 51 points over 1 mm) appear to sit far above that; the clip to [0, 1]
    # bounds the result but may hide oscillations — confirm intended dt.

    def _laplacian_pure(field):
        """Laplacian with Neumann BCs (no self reference)."""
        lap = jnp.zeros_like(field)
        lap = lap.at[1:-1].set((field[:-2] - 2 * field[1:-1] + field[2:]) / dy_sq)
        lap = lap.at[0].set(2 * (field[1] - field[0]) / dy_sq)
        lap = lap.at[-1].set(2 * (field[-2] - field[-1]) / dy_sq)
        return lap

    def _shear_rate_pure(v_prof):
        """Shear rate from velocity (no self reference)."""
        gd = jnp.zeros_like(v_prof)
        gd = gd.at[1:-1].set((v_prof[2:] - v_prof[:-2]) / (2 * dy))
        gd = gd.at[0].set((v_prof[1] - v_prof[0]) / dy)
        gd = gd.at[-1].set((v_prof[-1] - v_prof[-2]) / dy)
        return gd

    # Build viscosity function based on closure type (branch outside scan)
    if self.closure == "exponential":
        eta_0 = params["eta_0"]
        eta_inf = params["eta_inf"]

        def _viscosity_fn(lam_arr, gd_arr):
            # The exponential closure depends on structure only; gi is
            # accepted so both closures share one call signature.
            return jax.vmap(
                lambda li, gi: viscosity_exponential(li, eta_0, eta_inf)
            )(lam_arr, gd_arr)

    else:
        tau_y0 = params["tau_y0"]
        K0 = params["K0"]
        n_flow = params["n_flow"]
        eta_inf_hb = params["eta_inf"]
        m1 = params["m1"]
        m2 = params["m2"]

        def _viscosity_fn(lam_arr, gd_arr):
            return jax.vmap(
                lambda li, gi: viscosity_herschel_bulkley_regularized(
                    li, gi, tau_y0, K0, n_flow, eta_inf_hb, m1, m2
                )
            )(lam_arr, gd_arr)

    def scan_step(carry, _):
        """Single time step (pure JAX, no host transfers)."""
        lam_c, v_c = carry

        # Shear rate from velocity
        gamma_dot_c = _shear_rate_pure(v_c)

        # Viscosity
        eta = _viscosity_fn(lam_c, gamma_dot_c)

        # Stress (uniform for low Re)
        stress = jnp.mean(eta * gamma_dot_c)

        # Structure evolution + diffusion
        local_rate = jax.vmap(
            lambda li, gi: structure_evolution(li, gi, t_eq, a_param, c_param)
        )(lam_c, gamma_dot_c)
        diffusion = D_lambda * _laplacian_pure(lam_c)
        lam_new = jnp.clip(lam_c + dt * (local_rate + diffusion), 0.0, 1.0)

        # Velocity reconstruction from uniform stress
        gamma_dot_target = stress / jnp.maximum(eta, 1e-10)
        v_raw = jnp.concatenate(
            [jnp.array([0.0]), jnp.cumsum(gamma_dot_target[:-1]) * dy]
        )
        # Guard: if v_raw[-1] is near-zero, scaling would amplify noise;
        # fall back to V_wall denominator (linear profile) instead.
        v_denom = jnp.where(v_raw[-1] > 1e-6 * V_wall, v_raw[-1], V_wall)
        v_new = v_raw * V_wall / v_denom

        outputs = (stress, lam_c, v_c, gamma_dot_c)
        return (lam_new, v_new), outputs

    # Run scan — single JIT-compiled loop, no host transfers.
    # The final carry is not needed: every recorded state is in the outputs.
    _, (stress_all, lam_all, v_all, gd_all) = jax.lax.scan(
        scan_step, (lam, v), None, length=n_steps
    )

    # Single vectorized transfer at the end
    t_out = np.arange(n_steps) * dt

    return {
        "t": t_out,
        "y": self.y,
        "lam": np.asarray(lam_all),
        "gamma_dot": np.asarray(gd_all),
        "velocity": np.asarray(v_all),
        "stress": np.asarray(stress_all),
    }
# =========================================================================
# Banding Detection
# =========================================================================
def detect_banding(
    self,
    result: dict,
    threshold: float = 0.1,
) -> dict:
    """Detect shear banding from steady-state profiles.

    A shear band is detected when the final shear rate profile shows
    significant spatial variation (standard deviation / mean > threshold).
    All metrics are computed on the shear-rate *magnitude*, so a flow
    driven in the negative direction is analysed the same way as its
    mirror image.

    Parameters
    ----------
    result : dict
        Result from simulate_steady_shear(); the last entries of
        ``result["gamma_dot"]`` and ``result["lam"]`` are used.
    threshold : float
        Relative variation threshold for banding detection

    Returns
    -------
    dict
        'is_banding': bool
        'relative_variation': std / mean of \|γ̇\| across the gap
        'band_contrast': max/min shear rate magnitude ratio
        'band_width': approximate band width [m]
        'band_width_fraction': band_width / gap_width
        'band_location': center of high-shear band [m]
        'gamma_dot_profile': final shear rate profile (signed, as given)
        'lam_profile': final structure profile
    """
    # Use final profiles
    gamma_dot_final = result["gamma_dot"][-1]
    lam_final = result["lam"][-1]

    # Work on magnitudes: with signed values a negative profile would hit
    # the 1e-10 floors below and yield a huge relative variation and a
    # meaningless (negative) contrast.
    rate_mag = np.abs(gamma_dot_final)

    # Compute variation metrics
    mean_gd = np.mean(rate_mag)
    std_gd = np.std(rate_mag)
    relative_variation = std_gd / max(mean_gd, 1e-10)

    is_banding = bool(relative_variation > threshold)

    # Band contrast (floor avoids division by zero for a stagnant node)
    band_contrast = np.max(rate_mag) / max(np.min(rate_mag), 1e-10)

    # Find band location and width
    # High-shear band is where |γ̇| > mean + std
    band_indices = np.flatnonzero(rate_mag > mean_gd + std_gd)
    if band_indices.size:
        # Extent between the first and last flagged node
        band_width = (band_indices[-1] - band_indices[0]) * self.dy
        band_location = self.y[band_indices].mean()
    else:
        # No node stands out: report the whole gap, centred
        band_width = self.gap_width
        band_location = self.gap_width / 2

    return {
        "is_banding": is_banding,
        "relative_variation": relative_variation,
        "band_contrast": band_contrast,
        "band_width": band_width,
        "band_width_fraction": band_width / self.gap_width,
        "band_location": band_location,
        "gamma_dot_profile": gamma_dot_final,
        "lam_profile": lam_final,
    }
# ========================================================================= # Flow Curve # ========================================================================= def _fit_flow_curve( self, gamma_dot: np.ndarray, stress: np.ndarray, **kwargs ) -> DMTNonlocal: """Fit to steady-state flow curve. For nonlocal model, this fits to the apparent (average) flow curve. """ # Use local model fit as approximation from rheojax.models.dmt.local import DMTLocal local_model = DMTLocal( closure=self.closure, include_elasticity=self.include_elasticity ) local_model._fit_flow_curve(gamma_dot, stress, **kwargs) # Copy parameters for name in local_model.parameters.keys(): if name in self.parameters.keys(): self.parameters.set_value(name, local_model.parameters.get_value(name)) self._fitted = True return self def _predict_flow_curve(self, gamma_dot: np.ndarray, **kwargs) -> np.ndarray: """Predict steady-state flow curve. Can either use: - Homogeneous (local) approximation - Full nonlocal simulation at each point (slow) """ use_local = kwargs.get("use_local_approximation", True) if use_local: # Use local approximation from rheojax.models.dmt.local import DMTLocal local_model = DMTLocal( closure=self.closure, include_elasticity=self.include_elasticity ) # Copy parameters for name in self.parameters.keys(): if name in local_model.parameters.keys(): local_model.parameters.set_value( name, self.parameters.get_value(name) ) return local_model._predict_flow_curve(gamma_dot) else: # Full nonlocal simulation stress = [] for gd in gamma_dot: result = self.simulate_steady_shear( gamma_dot_avg=float(gd), t_end=kwargs.get("t_equilibrate", 1000.0), dt=kwargs.get("dt", 0.1), ) stress.append(result["stress"][-1]) return np.array(stress) def _fit_startup(self, t: np.ndarray, stress: np.ndarray, **kwargs) -> DMTNonlocal: """Fit to startup transient.""" raise NotImplementedError("Startup fitting for nonlocal model not implemented") # 
========================================================================= # Visualization Helpers # =========================================================================
def plot_profiles(
    self,
    result: dict,
    time_indices: list[int] | None = None,
    figsize: tuple = (12, 4),
):
    """Plot structure, shear rate and velocity profiles.

    Parameters
    ----------
    result : dict
        Result from simulate_steady_shear()
    time_indices : list, optional
        Indices of time points to plot (default: first, middle and last
        recorded step). A single index is allowed.
    figsize : tuple
        Figure size

    Returns
    -------
    fig, axes
        Matplotlib figure and axes (three panels: λ, γ̇, v)
    """
    import matplotlib.pyplot as plt

    if time_indices is None:
        time_indices = [0, len(result["t"]) // 2, -1]

    fig, axes = plt.subplots(1, 3, figsize=figsize)

    y_mm = result["y"] * 1000  # Convert to mm

    # Curves are spread evenly over the colormap. With a single curve there
    # is no span to normalise by, so use 1 to avoid dividing by zero.
    color_span = max(len(time_indices) - 1, 1)

    for i, idx in enumerate(time_indices):
        t_val = result["t"][idx]
        color = plt.cm.viridis(i / color_span)
        label = f"t = {t_val:.1f} s"

        # Structure profile
        axes[0].plot(y_mm, result["lam"][idx], color=color, label=label)

        # Shear rate profile
        axes[1].plot(y_mm, result["gamma_dot"][idx], color=color, label=label)

        # Velocity profile
        axes[2].plot(
            y_mm,
            result["velocity"][idx] * 1000,  # Convert to mm/s
            color=color,
            label=label,
        )

    axes[0].set_xlabel("y [mm]")
    axes[0].set_ylabel("λ [-]")
    axes[0].set_title("Structure Profile")
    axes[0].legend()
    axes[0].set_ylim(0, 1)

    axes[1].set_xlabel("y [mm]")
    axes[1].set_ylabel("γ̇ [1/s]")
    axes[1].set_title("Shear Rate Profile")
    axes[1].legend()

    axes[2].set_xlabel("y [mm]")
    axes[2].set_ylabel("v [mm/s]")
    axes[2].set_title("Velocity Profile")
    axes[2].legend()

    plt.tight_layout()
    return fig, axes