Source code for rheojax.models.vlb.nonlocal_model

"""Nonlocal VLB model with tensor diffusion for shear banding.

This module implements `VLBNonlocal`, a spatially-resolved (1D) extension
of the VLB framework where the distribution tensor mu varies across the
gap of a Couette geometry.

The PDE governing the distribution tensor is:

    dmu/dt = k_d(I - mu) + L·mu + mu·L^T + D_mu * nabla^2(mu)

where D_mu is the distribution tensor diffusivity (m²/s), a material
constant that sets the cooperativity length xi = sqrt(D_mu / k_d_0).

Shear banding arises when the Bell breakage rate creates a non-monotonic
constitutive curve (S-shaped sigma vs gamma_dot). The nonlocal diffusion
term regularizes the banding interface and sets its width.

Parameters
----------
breakage : str
    "constant" or "bell"
stress_type : str
    "linear" or "fene"
n_points : int
    Spatial grid points across gap (default 51)
gap_width : float
    Gap width in meters (default 1e-3)

References
----------
- Vernerey, F.J., Long, R. & Brighenti, R. (2017). JMPS 107, 1-20.
- Dhont, J.K.G. (1999). PRE 60, 4534.
"""

from __future__ import annotations

import logging
from typing import Literal

import numpy as np

from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax

diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.vlb._base import VLBBase
from rheojax.models.vlb._kernels import (
    laplacian_1d_neumann_vlb,
    vlb_breakage_bell,
    vlb_stress_fene_xy,
)

# JAX handles obtained through rheojax's import helper (presumably it also
# applies the project's JAX configuration, e.g. float64 — confirm in
# rheojax.core.jax_config).
jax, jnp = safe_import_jax()

# Module-level logger; VLBNonlocal logs its configuration on construction.
logger = logging.getLogger(__name__)

# Accepted values for the VLBNonlocal ``breakage`` and ``stress_type`` options.
BreakageType = Literal["constant", "bell"]
StressType = Literal["linear", "fene"]


@ModelRegistry.register(
    "vlb_nonlocal",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.STARTUP,
        Protocol.CREEP,
    ],
    deformation_modes=[DeformationMode.SHEAR],
)
class VLBNonlocal(VLBBase):
    """Nonlocal VLB with tensor diffusion for shear banding.

    Solves a 1D PDE across the gap of a Couette cell. The state at each
    spatial point is (mu_xx, mu_yy, mu_zz, mu_xy), plus a single wall stress
    Sigma (spatially uniform at low Reynolds number).

    Shear banding occurs when the Bell breakage rate creates a non-monotonic
    flow curve. The diffusion term D_mu * nabla^2(mu) regularizes the
    interface with width ~ xi = sqrt(D_mu / k_d_0).

    The local shear rate is recovered pointwise from the stress balance
    ``Sigma = sigma_elastic + eta_eff * gamma_dot`` with
    ``eta_eff = max(eta_s, 0.01 * G0 / k_d_0)``. The solvent viscosity is
    therefore floored at 1% of the zero-shear network viscosity G0 / k_d_0,
    which keeps the balance invertible when ``eta_s = 0``.

    Parameters
    ----------
    breakage : str, default "constant"
        "constant" or "bell"
    stress_type : str, default "linear"
        "linear" or "fene"
    n_points : int, default 51
        Spatial grid resolution (must be >= 2)
    gap_width : float, default 1e-3
        Gap width (m), must be positive

    Raises
    ------
    ValueError
        If ``breakage`` or ``stress_type`` is not a recognized option, if
        ``n_points < 2``, or if ``gap_width <= 0``.
    """

    # Recognized option values (mirror the BreakageType / StressType literals).
    _BREAKAGE_OPTIONS = ("constant", "bell")
    _STRESS_OPTIONS = ("linear", "fene")

    def __init__(
        self,
        breakage: BreakageType = "constant",
        stress_type: StressType = "linear",
        n_points: int = 51,
        gap_width: float = 1e-3,
    ):
        """Initialize VLBNonlocal model.

        Raises
        ------
        ValueError
            On an unknown ``breakage``/``stress_type`` (these used to fall
            back silently to constant/linear), ``n_points < 2`` (grid spacing
            undefined), or a non-positive ``gap_width``.
        """
        if breakage not in self._BREAKAGE_OPTIONS:
            raise ValueError(
                f"breakage must be one of {self._BREAKAGE_OPTIONS}, "
                f"got {breakage!r}"
            )
        if stress_type not in self._STRESS_OPTIONS:
            raise ValueError(
                f"stress_type must be one of {self._STRESS_OPTIONS}, "
                f"got {stress_type!r}"
            )
        if n_points < 2:
            raise ValueError(f"n_points must be >= 2, got {n_points}")
        if not gap_width > 0:
            raise ValueError(f"gap_width must be positive, got {gap_width}")

        self._breakage = breakage
        self._stress_type = stress_type
        self.n_points = n_points
        self.gap_width = gap_width
        super().__init__()
        self._setup_parameters()

        # Uniform spatial grid including both walls
        self.y = np.linspace(0, gap_width, n_points)
        self.dy = gap_width / (n_points - 1)
        self._test_mode = None

        logger.info(
            f"VLBNonlocal initialized: breakage={breakage}, "
            f"stress={stress_type}, n_points={n_points}"
        )

    # =========================================================================
    # Parameters
    # =========================================================================

    def _setup_parameters(self):
        """Initialize ParameterSet for nonlocal model.

        Always registers G0, k_d_0, eta_s and D_mu. ``nu`` is added only for
        Bell breakage, and ``L_max`` only for FENE stress.
        """
        self.parameters = ParameterSet()

        # Core parameters
        self.parameters.add(
            name="G0",
            value=1e3,
            bounds=(1e0, 1e8),
            units="Pa",
            description="Network modulus",
        )
        self.parameters.add(
            name="k_d_0",
            value=1.0,
            bounds=(1e-6, 1e6),
            units="1/s",
            description="Unstressed dissociation rate",
        )
        self.parameters.add(
            name="eta_s",
            value=0.0,
            bounds=(0.0, 1e4),
            units="Pa·s",
            description="Solvent viscosity",
        )

        # Nonlocal parameter
        self.parameters.add(
            name="D_mu",
            value=1e-8,
            bounds=(1e-14, 1e-4),
            units="m²/s",
            description="Distribution tensor diffusivity",
        )

        # Bell parameter
        if self._breakage == "bell":
            self.parameters.add(
                name="nu",
                value=3.0,
                bounds=(0.0, 20.0),
                units="dimensionless",
                description="Force sensitivity (Bell model)",
            )

        # FENE parameter
        if self._stress_type == "fene":
            self.parameters.add(
                name="L_max",
                value=10.0,
                bounds=(1.5, 1000.0),
                units="dimensionless",
                description="Maximum chain extensibility (FENE-P spring)",
            )

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def G0(self) -> float:
        """Network modulus G0 (Pa); falls back to 1e3 when unset."""
        val = self.parameters.get_value("G0")
        return float(val) if val is not None else 1e3

    @property
    def k_d_0(self) -> float:
        """Unstressed dissociation rate k_d_0 (1/s); falls back to 1.0 when unset."""
        val = self.parameters.get_value("k_d_0")
        return float(val) if val is not None else 1.0

    def get_cooperativity_length(self) -> float:
        """Cooperativity length xi = sqrt(D_mu / k_d_0).

        This sets the shear band interface width.

        Returns
        -------
        float
            Cooperativity length (m)
        """
        # P3-VLB-001: Use explicit None-check instead of or-sentinel so that a
        # legitimately small D_mu=0.0 (no diffusion) is not silently replaced.
        _d_mu = self.parameters.get_value("D_mu")
        D_mu = float(_d_mu if _d_mu is not None else 1e-8)
        return float(np.sqrt(D_mu / self.k_d_0))

    # =========================================================================
    # PDE Integration Core
    # =========================================================================

    @staticmethod
    def _effective_viscosity(G0, k_d_0, eta_s):
        """Viscosity used to invert the local stress balance.

        Returns ``max(eta_s, 0.01 * G0 / max(k_d_0, 1e-30))``. The floor keeps
        gamma_dot = (Sigma - sigma_elastic) / eta_eff finite when eta_s = 0.
        Accepts Python floats as well as traced JAX scalars.
        """
        return jnp.maximum(eta_s, 1e-2 * G0 / jnp.maximum(k_d_0, 1e-30))

    @staticmethod
    def _elastic_shear_stress(stress_type, mu_xx, mu_yy, mu_zz, mu_xy, G0, L_max):
        """Elastic shear stress at each grid point.

        ``stress_type`` is a plain Python string, static under tracing.
        "fene" applies ``vlb_stress_fene_xy`` pointwise. Otherwise the linear
        law ``G0 * mu_xy`` is used, and the normal components and ``L_max``
        are ignored.
        """
        if stress_type == "fene":
            return jax.vmap(
                lambda xx, yy, zz, xy: vlb_stress_fene_xy(xx, yy, zz, xy, G0, L_max)
            )(mu_xx, mu_yy, mu_zz, mu_xy)
        return G0 * mu_xy

    @classmethod
    def _mu_evolution(
        cls, breakage, stress_type, n, Sigma, mu_xx, mu_yy, mu_zz, mu_xy, args
    ):
        """Evaluate dmu/dt on the grid for a given wall stress.

        Implements dmu/dt = k_d (I - mu) + L.mu + mu.L^T + D_mu nabla^2 mu
        for simple shear. Only L_xy = gamma_dot is non-zero, and gamma_dot
        comes pointwise from the stress balance. This is shared by the
        rate-controlled and the stress-controlled (creep) right-hand sides.

        Parameters
        ----------
        breakage, stress_type : str
            Static model options ("bell" enables the Bell rate, "fene" the
            FENE-P stress).
        n : int
            Number of grid points (static).
        Sigma : scalar
            Wall stress (Pa), uniform across the gap.
        mu_xx, mu_yy, mu_zz, mu_xy : arrays of length n
            Distribution tensor components.
        args : dict
            Must contain "G0", "k_d_0", "eta_s", "D_mu", "nu", "L_max", "dy".

        Returns
        -------
        tuple
            (dmu_xx, dmu_yy, dmu_zz, dmu_xy, gamma_dot), each of length n.
        """
        G0 = args["G0"]
        k_d_0 = args["k_d_0"]
        D_mu = args["D_mu"]
        dy = args["dy"]

        # Local dissociation rate
        if breakage == "bell":
            nu = args["nu"]
            k_d = jax.vmap(
                lambda xx, yy, zz: vlb_breakage_bell(xx, yy, zz, k_d_0, nu)
            )(mu_xx, mu_yy, mu_zz)
        else:
            k_d = jnp.full(n, k_d_0)

        # Local shear rate from stress balance Sigma = sigma_elastic + eta_eff*gd
        sigma_elastic = cls._elastic_shear_stress(
            stress_type, mu_xx, mu_yy, mu_zz, mu_xy, G0, args["L_max"]
        )
        eta_eff = cls._effective_viscosity(G0, k_d_0, args["eta_s"])
        gamma_dot = (Sigma - sigma_elastic) / eta_eff

        # Local kinetics: relaxation toward I plus the simple-shear coupling
        # (xx <- 2 gd mu_xy, xy <- gd mu_yy).
        dmu_xx = k_d * (1.0 - mu_xx) + 2.0 * gamma_dot * mu_xy
        dmu_yy = k_d * (1.0 - mu_yy)
        dmu_zz = k_d * (1.0 - mu_zz)
        dmu_xy = -k_d * mu_xy + gamma_dot * mu_yy

        # Nonlocal diffusion (wall treatment as implemented by
        # laplacian_1d_neumann_vlb)
        dmu_xx = dmu_xx + D_mu * laplacian_1d_neumann_vlb(mu_xx, dy)
        dmu_yy = dmu_yy + D_mu * laplacian_1d_neumann_vlb(mu_yy, dy)
        dmu_zz = dmu_zz + D_mu * laplacian_1d_neumann_vlb(mu_zz, dy)
        dmu_xy = dmu_xy + D_mu * laplacian_1d_neumann_vlb(mu_xy, dy)

        return dmu_xx, dmu_yy, dmu_zz, dmu_xy, gamma_dot

    def _build_pde_rhs(self):
        """Build PDE RHS function for diffrax integration (rate control).

        State: [Sigma, mu_xx[0:N], mu_yy[0:N], mu_zz[0:N], mu_xy[0:N]]
        Total state size: 1 + 4*N

        The wall stress follows the proportional feedback law
        dSigma/dt = K * (gamma_dot_avg - mean(gamma_dot)) with K = 10 * G0.
        This drives the gap-averaged shear rate toward the imposed value.

        Returns
        -------
        callable
            ``rhs(t, state, args)`` wrapped with ``jax.checkpoint``. ``args``
            needs the keys listed in ``_mu_evolution`` plus "gamma_dot_avg".
        """
        breakage = self._breakage
        stress_type = self._stress_type
        n = self.n_points  # Closed over as static (required for JAX tracing)
        mu_evolution = self._mu_evolution

        @jax.jit
        def pde_rhs(t, state, args):
            # Unpack state (n is closed over as static)
            Sigma = state[0]
            mu_xx = state[1 : 1 + n]
            mu_yy = state[1 + n : 1 + 2 * n]
            mu_zz = state[1 + 2 * n : 1 + 3 * n]
            mu_xy = state[1 + 3 * n : 1 + 4 * n]

            dmu_xx, dmu_yy, dmu_zz, dmu_xy, gamma_dot = mu_evolution(
                breakage, stress_type, n, Sigma, mu_xx, mu_yy, mu_zz, mu_xy, args
            )

            # Stress feedback: enforce average shear rate = imposed value
            K = 10.0 * args["G0"]
            dSigma = K * (args["gamma_dot_avg"] - jnp.mean(gamma_dot))

            return jnp.concatenate(
                [
                    jnp.array([dSigma]),
                    dmu_xx,
                    dmu_yy,
                    dmu_zz,
                    dmu_xy,
                ]
            )

        return jax.checkpoint(pde_rhs)

    def _build_creep_rhs(self):
        """Build the RHS for stress-controlled creep.

        State: [mu_xx[0:N], mu_yy[0:N], mu_zz[0:N], mu_xy[0:N]] (size 4*N).
        The wall stress is held at ``args["Sigma"]`` with no feedback.

        Returns
        -------
        callable
            ``rhs(t, state, args)`` wrapped with ``jax.checkpoint`` to reduce
            VJP memory during reverse-mode AD.
        """
        breakage = self._breakage
        stress_type = self._stress_type
        n = self.n_points  # static for tracing
        mu_evolution = self._mu_evolution

        @jax.jit
        def creep_rhs(t, state, args):
            dmu_xx, dmu_yy, dmu_zz, dmu_xy, _ = mu_evolution(
                breakage,
                stress_type,
                n,
                args["Sigma"],
                state[0:n],
                state[n : 2 * n],
                state[2 * n : 3 * n],
                state[3 * n : 4 * n],
                args,
            )
            return jnp.concatenate([dmu_xx, dmu_yy, dmu_zz, dmu_xy])

        return jax.checkpoint(creep_rhs)

    def _build_initial_state(self, perturbation: float = 0.01) -> jnp.ndarray:
        """Build initial state with small perturbation for symmetry breaking.

        Parameters
        ----------
        perturbation : float
            Amplitude of the Gaussian noise added to mu_xx (fixed seed 42)

        Returns
        -------
        jnp.ndarray
            Initial state vector, shape (1 + 4*n_points,)
        """
        n = self.n_points

        # Initial guess for wall stress (callers may overwrite entry 0)
        Sigma_0 = self.G0 * 1.0

        # Uniform equilibrium with small noise for symmetry breaking
        key = jax.random.PRNGKey(42)
        noise = perturbation * jax.random.normal(key, shape=(n,))

        mu_xx_0 = jnp.ones(n) + noise
        mu_yy_0 = jnp.ones(n)
        mu_zz_0 = jnp.ones(n)
        mu_xy_0 = jnp.zeros(n)

        return jnp.concatenate(
            [
                jnp.array([Sigma_0]),
                mu_xx_0,
                mu_yy_0,
                mu_zz_0,
                mu_xy_0,
            ]
        )

    def _unpack_state(self, state: jnp.ndarray) -> dict:
        """Unpack a rate-controlled state vector into named fields."""
        n = self.n_points
        return {
            "Sigma": state[0],
            "mu_xx": state[1 : 1 + n],
            "mu_yy": state[1 + n : 1 + 2 * n],
            "mu_zz": state[1 + 2 * n : 1 + 3 * n],
            "mu_xy": state[1 + 3 * n : 1 + 4 * n],
        }

    def _compute_gamma_dot_profile(self, state_fields: dict) -> jnp.ndarray:
        """Compute the local shear rate profile from unpacked state fields.

        ``state_fields`` needs "Sigma" and "mu_xy". The normal components
        "mu_xx", "mu_yy" and "mu_zz" are required only for FENE stress.
        """
        G0 = self.G0
        k_d_0 = self.k_d_0
        _eta_s = self.parameters.get_value("eta_s")
        eta_s = float(_eta_s if _eta_s is not None else 0.0)

        # L_max is only registered for FENE models, so only read it then.
        L_max = 10.0
        if self._stress_type == "fene":
            _l_max = self.parameters.get_value("L_max")
            L_max = float(_l_max if _l_max is not None else 10.0)

        sigma_elastic = self._elastic_shear_stress(
            self._stress_type,
            state_fields.get("mu_xx"),
            state_fields.get("mu_yy"),
            state_fields.get("mu_zz"),
            state_fields["mu_xy"],
            G0,
            L_max,
        )
        eta_eff = self._effective_viscosity(G0, k_d_0, eta_s)
        return (state_fields["Sigma"] - sigma_elastic) / eta_eff

    def _base_args(self, params: dict) -> dict:
        """Pack the material/grid values shared by both RHS functions.

        All values are converted to float64 JAX scalars. ``nu`` and ``L_max``
        default to 0.0 and 10.0 when the model does not define them.
        """
        return {
            "G0": jnp.float64(params["G0"]),
            "k_d_0": jnp.float64(params["k_d_0"]),
            "eta_s": jnp.float64(params["eta_s"]),
            "D_mu": jnp.float64(params["D_mu"]),
            "nu": jnp.float64(params.get("nu", 0.0)),
            "L_max": jnp.float64(params.get("L_max", 10.0)),
            "dy": jnp.float64(self.dy),
        }

    @staticmethod
    def _output_times(t_end: float, dt: float) -> jnp.ndarray:
        """Uniform output grid from 0 to t_end with spacing close to dt.

        The step count is rounded rather than truncated, so float quotients
        such as 0.3 / 0.1 = 2.999... still give 3 intervals. At least one
        interval is used, so t_end is always saved.

        Raises
        ------
        ValueError
            If ``t_end`` or ``dt`` is not positive.
        """
        if not (t_end > 0 and dt > 0):
            raise ValueError(
                f"t_end and dt must be positive, got t_end={t_end}, dt={dt}"
            )
        n_steps = max(1, round(t_end / dt))
        return jnp.linspace(0.0, t_end, n_steps + 1)

    def _integrate(self, rhs, y0, args, t_end: float, dt: float):
        """Integrate ``rhs`` from 0 to t_end with adaptive Tsit5.

        Returns
        -------
        tuple
            (t_out as np.ndarray, ys as a JAX array of shape (N_t, state_size)).

        Notes
        -----
        The solve uses ``throw=False``, so a failed solve does not raise. A
        warning is logged instead when the saved states contain non-finite
        values, since the trailing outputs are then unreliable.
        """
        t_save = self._output_times(t_end, dt)
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(rhs),
            diffrax.Tsit5(),
            0.0,
            t_end,
            dt / 10.0,
            y0,
            args=args,
            saveat=diffrax.SaveAt(ts=t_save),
            stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
            max_steps=5_000_000,
            throw=False,
        )
        ys = sol.ys
        if not bool(jnp.all(jnp.isfinite(ys))):
            logger.warning(
                "VLBNonlocal: integration produced non-finite values "
                "(possible solver failure, e.g. max_steps reached); "
                "results after the failure point are unreliable."
            )
        return np.asarray(t_save), ys

    @staticmethod
    def _velocity_from_shear_rate(gamma_dot, dy: float) -> np.ndarray:
        """Velocity v(y) = integral_0^y gamma_dot dy' along the last axis.

        Uses a left Riemann sum with v = 0 at the first grid point. Works for
        a single profile (N_y,) or a stack of profiles (N_t, N_y).
        """
        gd = np.asarray(gamma_dot)
        wall = np.zeros(gd.shape[:-1] + (1,))
        return np.concatenate([wall, np.cumsum(gd[..., :-1], axis=-1) * dy], axis=-1)

    # =========================================================================
    # Simulation Methods
    # =========================================================================

    def simulate_steady_shear(
        self,
        gamma_dot_avg: float,
        t_end: float = 100.0,
        dt: float = 0.1,
        perturbation: float = 0.01,
    ) -> dict:
        """Simulate approach to steady state under imposed average shear rate.

        Parameters
        ----------
        gamma_dot_avg : float
            Imposed average shear rate (1/s)
        t_end : float
            Simulation end time (s), must be positive
        dt : float
            Output time step (s), must be positive
        perturbation : float
            Initial spatial noise amplitude

        Returns
        -------
        dict
            't': time array
            'y': spatial grid
            'mu_xy': mu_xy profiles (N_t, N_y)
            'gamma_dot': shear rate profiles (N_t, N_y)
            'stress': wall stress Sigma(t)

        Raises
        ------
        ValueError
            If ``t_end`` or ``dt`` is not positive.
        """
        n = self.n_points
        params = self.get_parameter_dict()

        args = self._base_args(params)
        args["gamma_dot_avg"] = jnp.float64(gamma_dot_avg)

        pde_rhs = self._build_pde_rhs()
        y0 = self._build_initial_state(perturbation)

        # Start Sigma at the zero-shear estimate eta_0 * gamma_dot_avg
        eta_0 = params["G0"] / params["k_d_0"]
        y0 = y0.at[0].set(eta_0 * gamma_dot_avg)

        t_out, ys = self._integrate(pde_rhs, y0, args, t_end, dt)

        # Post-process with exactly the parameter values the solver used.
        G0 = params["G0"]
        L_max = params.get("L_max", 10.0)
        eta_eff = self._effective_viscosity(G0, params["k_d_0"], params["eta_s"])
        stress_type = self._stress_type
        elastic_stress = self._elastic_shear_stress

        def _gamma_dot_single(state_i):
            sigma_el = elastic_stress(
                stress_type,
                state_i[1 : 1 + n],
                state_i[1 + n : 1 + 2 * n],
                state_i[1 + 2 * n : 1 + 3 * n],
                state_i[1 + 3 * n : 1 + 4 * n],
                G0,
                L_max,
            )
            return (state_i[0] - sigma_el) / eta_eff

        gamma_dot_all = jax.vmap(_gamma_dot_single)(ys)

        return {
            "t": t_out,
            "y": self.y,
            "mu_xy": np.asarray(ys[:, 1 + 3 * n : 1 + 4 * n]),
            "gamma_dot": np.asarray(gamma_dot_all),
            "stress": np.asarray(ys[:, 0]),
        }

    def simulate_startup(
        self,
        gamma_dot_avg: float,
        t_end: float = 50.0,
        dt: float = 0.05,
    ) -> dict:
        """Simulate startup from rest with banding evolution.

        Parameters
        ----------
        gamma_dot_avg : float
            Imposed average shear rate (1/s)
        t_end : float
            End time (s)
        dt : float
            Output time step (s)

        Returns
        -------
        dict
            Same format as simulate_steady_shear
        """
        return self.simulate_steady_shear(
            gamma_dot_avg, t_end=t_end, dt=dt, perturbation=0.01
        )

    def simulate_creep(
        self,
        sigma_0: float,
        t_end: float = 100.0,
        dt: float = 0.1,
    ) -> dict:
        """Simulate stress-controlled creep with spatial resolution.

        In creep, the stress Sigma is held fixed (no feedback).

        Parameters
        ----------
        sigma_0 : float
            Applied stress (Pa)
        t_end : float
            End time (s), must be positive
        dt : float
            Output time step (s), must be positive

        Returns
        -------
        dict
            't', 'y', 'gamma_dot', 'mu_xy', 'velocity', and 'stress' (constant
            sigma_0 at every output time)

        Raises
        ------
        ValueError
            If ``t_end`` or ``dt`` is not positive.
        """
        n = self.n_points
        params = self.get_parameter_dict()

        args = self._base_args(params)
        args["Sigma"] = jnp.float64(sigma_0)

        creep_rhs = self._build_creep_rhs()

        # Initial state: equilibrium + 1% noise on mu_xx (same seed as the
        # rate-controlled initial state); drop the Sigma slot.
        y0 = self._build_initial_state(perturbation=0.01)[1:]

        t_out, ys = self._integrate(creep_rhs, y0, args, t_end, dt)

        # Shear rate profiles from the fixed-stress balance
        G0 = params["G0"]
        L_max = params.get("L_max", 10.0)
        eta_eff = self._effective_viscosity(G0, params["k_d_0"], params["eta_s"])
        stress_type = self._stress_type
        elastic_stress = self._elastic_shear_stress

        def _creep_gamma_dot_single(state_i):
            sigma_el = elastic_stress(
                stress_type,
                state_i[:n],
                state_i[n : 2 * n],
                state_i[2 * n : 3 * n],
                state_i[3 * n : 4 * n],
                G0,
                L_max,
            )
            return (sigma_0 - sigma_el) / eta_eff

        gamma_dot_all = np.asarray(jax.vmap(_creep_gamma_dot_single)(ys))

        return {
            "t": t_out,
            "y": self.y,
            "mu_xy": np.asarray(ys[:, 3 * n : 4 * n]),
            "gamma_dot": gamma_dot_all,
            "velocity": self._velocity_from_shear_rate(gamma_dot_all, self.dy),
            "stress": np.full(len(t_out), sigma_0),
        }

    # =========================================================================
    # Banding Detection
    # =========================================================================

    def detect_banding(self, result: dict, threshold: float = 0.1) -> dict:
        """Detect shear banding from steady-state profiles.

        Parameters
        ----------
        result : dict
            Result from simulate_steady_shear()
        threshold : float
            Relative variation threshold for banding detection

        Returns
        -------
        dict
            'is_banding': bool (std/|mean| of the final profile > threshold)
            'relative_variation': std/|mean| of the final shear rate profile
            'band_contrast': max/min shear rate ratio
            'band_width': approximate band width (m)
            'band_width_fraction': band_width / gap_width
            'band_location': center of high-shear band (m)
            'gamma_dot_profile': final shear rate profile
        """
        gamma_dot_final = np.asarray(result["gamma_dot"][-1])

        mean_gd = float(np.mean(gamma_dot_final))
        std_gd = float(np.std(gamma_dot_final))
        # Normalize by |mean| so reversed flow (negative rates) is handled.
        relative_variation = std_gd / max(abs(mean_gd), 1e-10)
        is_banding = bool(relative_variation > threshold)

        band_contrast = float(
            np.max(gamma_dot_final) / max(np.min(gamma_dot_final), 1e-10)
        )

        # High-shear band: points more than one std above the mean
        high_shear_mask = gamma_dot_final > mean_gd + std_gd
        if np.any(high_shear_mask):
            band_indices = np.flatnonzero(high_shear_mask)
            band_width = float((band_indices[-1] - band_indices[0]) * self.dy)
            band_location = float(self.y[band_indices].mean())
        else:
            band_width = self.gap_width
            band_location = self.gap_width / 2

        return {
            "is_banding": is_banding,
            "relative_variation": relative_variation,
            "band_contrast": band_contrast,
            "band_width": band_width,
            "band_width_fraction": band_width / self.gap_width,
            "band_location": band_location,
            "gamma_dot_profile": gamma_dot_final,
        }

    def get_velocity_profile(self, result: dict) -> np.ndarray:
        """Compute velocity profile from final shear rate profile.

        v(y) = integral_0^y gamma_dot(y') dy' (left Riemann sum, v(0) = 0)

        Parameters
        ----------
        result : dict
            Result from simulate_steady_shear()

        Returns
        -------
        np.ndarray
            Velocity profile v(y)
        """
        return self._velocity_from_shear_rate(result["gamma_dot"][-1], self.dy)

    # =========================================================================
    # Visualization
    # =========================================================================

    def plot_profiles(self, result: dict, ax=None):
        """Plot spatial profiles (shear rate and mu_xy) and stress history.

        Parameters
        ----------
        result : dict
            Result from simulate_steady_shear()
        ax : sequence of 3 matplotlib axes, optional
            If None, creates a new 1x3 figure

        Returns
        -------
        matplotlib figure
        """
        import matplotlib.pyplot as plt

        if ax is None:
            fig, axes = plt.subplots(1, 3, figsize=(14, 4))
        else:
            fig = ax[0].get_figure()
            axes = ax

        y_mm = self.y * 1e3  # Convert to mm

        # Shear rate profile
        axes[0].plot(y_mm, result["gamma_dot"][-1])
        axes[0].set_xlabel("Position y (mm)")
        axes[0].set_ylabel("Shear rate (1/s)")
        axes[0].set_title("Shear Rate Profile")

        # mu_xy profile
        axes[1].plot(y_mm, result["mu_xy"][-1])
        axes[1].set_xlabel("Position y (mm)")
        axes[1].set_ylabel(r"$\mu_{xy}$")
        axes[1].set_title("Distribution Tensor Profile")

        # Stress evolution
        axes[2].plot(result["t"], result["stress"])
        axes[2].set_xlabel("Time (s)")
        axes[2].set_ylabel("Stress (Pa)")
        axes[2].set_title("Stress Evolution")

        # Lay out the figure that owns the axes (not whatever plt.gcf() is).
        fig.tight_layout()
        return fig

    # =========================================================================
    # Fit/Predict (minimal implementation)
    # =========================================================================

    def _fit(self, x, y, **kwargs):
        """Fit is not supported for nonlocal models (use simulate methods)."""
        raise NotImplementedError(
            "VLBNonlocal does not support _fit(). Use simulate_steady_shear() "
            "or simulate_startup() for direct simulation."
        )

    def _predict(self, X, **kwargs):
        """Predict is not directly supported for nonlocal models."""
        raise NotImplementedError(
            "VLBNonlocal does not support _predict(). Use simulate_steady_shear() "
            "for flow curve predictions."
        )

    # =========================================================================
    # Repr
    # =========================================================================

    def __repr__(self) -> str:
        return (
            f"VLBNonlocal(breakage={self._breakage!r}, "
            f"stress={self._stress_type!r}, "
            f"n_points={self.n_points}, gap={self.gap_width:.1e}m)"
        )