Source code for rheojax.transforms.srfs

"""Strain-Rate Frequency Superposition (SRFS) transform.

This module implements SRFS for collapsing flow curves at different shear rates
onto a master curve, analogous to time-temperature superposition (TTS) but based
on shear rate rather than temperature.

SRFS is particularly useful for soft glassy materials where the SGR model predicts
a power-law relationship between shift factor and shear rate:
    a(gamma_dot) ~ (gamma_dot)^m
where m = (2 - x) depends on the noise temperature x.

Thixotropy kinetics and shear banding detection are also implemented for
complete characterization of complex flow behavior in soft glassy materials.

Physical Background:
    - SRFS exploits the fact that flow curves at different reference shear rates
      can be collapsed via horizontal shifting
    - For SGR materials, the shift factor has power-law form determined by x
    - Thixotropy arises from microstructure build-up (at rest) and breakdown (under shear)
    - Shear banding occurs when the constitutive curve becomes non-monotonic

References:
    - P. Sollich, Rheological constitutive equation for a model of soft glassy
      materials, Physical Review E, 1998, 58(1), 738-759
    - H. M. Wyss et al., Strain-rate frequency superposition: A rheological probe
      of structural relaxation in soft materials, Physical Review Letters, 2007,
      98(23), 238303
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

import numpy as np

from rheojax.core.base import BaseTransform
from rheojax.core.data import RheoData
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger

# Safe JAX import (enforces float64). Both the ``jax`` package and the
# ``jax.numpy`` namespace are obtained through the project helper rather than
# imported directly.
jax, jnp = safe_import_jax()

# Module logger. Call sites in this module pass structured keyword arguments
# (e.g. ``logger.debug("msg", key=value)``).
logger = get_logger(__name__)

# Annotation-only array namespace: static type checkers see ``jax.numpy``,
# while at runtime NumPy stands in so the aliases below remain resolvable.
if TYPE_CHECKING:
    import jax.numpy as jnp_typing
else:  # pragma: no cover - typing fallback
    jnp_typing = np

# Type aliases declared with the ``type`` statement (Python 3.12+ syntax).
type JaxArray = jnp_typing.ndarray
type ScalarOrArray = float | JaxArray


@TransformRegistry.register("srfs", type=TransformType.SUPERPOSITION)
class SRFS(BaseTransform):
    """Strain-Rate Frequency Superposition (SRFS) transform.

    SRFS collapses flow curves measured at different shear rates onto a master
    curve by applying horizontal shift factors. This is analogous to
    time-temperature superposition (TTS) but uses shear rate rather than
    temperature.

    For SGR (Soft Glassy Rheology) materials, the shift factor follows:
        a(gamma_dot) = (gamma_dot / gamma_dot_ref)^m
    where m = (2 - x) and x is the noise temperature.

    Parameters
    ----------
    reference_gamma_dot : float, default=1.0
        Reference shear rate for the master curve (1/s). Must be positive.
    auto_shift : bool, default=False
        Request for automatic shift factors computed from data overlap.
        The flag is stored on the instance but is not consulted by any method
        of this class: ``x`` and ``tau0`` are always required.

    Attributes
    ----------
    reference_gamma_dot : float
        Reference shear rate
    shift_factors_ : dict[float, float] or None
        Shift factors keyed by each dataset's reference shear rate. ``None``
        until :meth:`create_mastercurve` has run.

    Examples
    --------
    >>> from rheojax.transforms.srfs import SRFS
    >>> from rheojax.core.data import RheoData
    >>>
    >>> # Create flow curve datasets at different reference shear rates
    >>> datasets = [
    ...     RheoData(x=gamma_dots_1, y=eta_1, metadata={'reference_gamma_dot': 0.1}),
    ...     RheoData(x=gamma_dots_2, y=eta_2, metadata={'reference_gamma_dot': 1.0}),
    ...     RheoData(x=gamma_dots_3, y=eta_3, metadata={'reference_gamma_dot': 10.0}),
    ... ]
    >>>
    >>> # Create SRFS transform
    >>> srfs = SRFS(reference_gamma_dot=1.0)
    >>>
    >>> # Apply SRFS shift (requires SGR parameters)
    >>> mastercurve, shift_factors = srfs.transform(
    ...     datasets, x=1.5, tau0=1e-3, return_shifts=True
    ... )

    Notes
    -----
    - Shift factors depend on SGR noise temperature x
    - For x < 1 (glass), shift behavior changes near yield stress
    - For x = 2 (Newtonian), m = 0 and every shift factor equals 1. The
      exponent is not clamped, so x > 2 yields a negative exponent.
    """

    def __init__(
        self,
        reference_gamma_dot: float = 1.0,
        auto_shift: bool = False,
    ):
        """Initialize SRFS transform.

        Parameters
        ----------
        reference_gamma_dot : float
            Reference shear rate for the master curve
        auto_shift : bool
            Whether automatic shift factors were requested (stored only)

        Raises
        ------
        ValueError
            If ``reference_gamma_dot`` is zero or negative.
        """
        if reference_gamma_dot <= 0.0:
            raise ValueError(
                f"reference_gamma_dot must be positive, got {reference_gamma_dot}. "
                "The reference shear rate is used as a divisor in shift-factor "
                "computation and cannot be zero or negative."
            )
        super().__init__()
        self.reference_gamma_dot = reference_gamma_dot
        self._auto_shift = auto_shift
        self.shift_factors_: dict[float, float] | None = None

    def compute_shift_factor(
        self,
        gamma_dot: float,
        x: float,
        tau0: float,
    ) -> float:
        """Compute SRFS shift factor from SGR theory.

        For SGR materials, the shift factor follows a power-law:
            a(gamma_dot) = (gamma_dot / gamma_dot_ref)^m
        where m = (2 - x) for the power-law fluid regime (1 < x < 2).

        Parameters
        ----------
        gamma_dot : float
            Shear rate to compute shift for (1/s). Must be positive.
        x : float
            SGR noise temperature (dimensionless)
        tau0 : float
            SGR attempt time (s). It cancels out of the ratio
            (gamma_dot * tau0)^m / (gamma_dot_ref * tau0)^m, so it only
            appears in the log record.

        Returns
        -------
        float
            Shift factor a(gamma_dot)

        Raises
        ------
        ValueError
            If ``gamma_dot`` is zero or negative.

        Notes
        -----
        - For x = 1.5, exponent m = 0.5
        - For x = 2 (Newtonian), m = 0, shift factor = 1
        - For x < 1 (glass), behavior near yield stress is different
        """
        logger.debug(
            "Computing shift factor",
            gamma_dot=gamma_dot,
            x=x,
            tau0=tau0,
            reference_gamma_dot=self.reference_gamma_dot,
        )

        # A non-positive rate has no meaningful power-law shift: a negative
        # base with fractional exponent turns complex, and zero either
        # collapses the whole curve onto x = 0 or divides by zero.
        if gamma_dot <= 0.0:
            raise ValueError(
                f"gamma_dot must be positive, got {gamma_dot}. "
                "Shift factors are a power law of gamma_dot / reference_gamma_dot."
            )

        # Shift exponent from SGR theory: viscosity scales as
        # eta ~ gamma_dot^(x - 2), so collapsing curves needs a ~ gamma_dot^(2 - x).
        m = 2.0 - x

        # The reference curve itself is never shifted.
        if abs(gamma_dot - self.reference_gamma_dot) < 1e-12:
            logger.debug("Shear rate equals reference, shift factor = 1.0")
            return 1.0

        ratio = gamma_dot / self.reference_gamma_dot
        a_gamma_dot = ratio**m

        logger.debug(
            "Shift factor computed",
            exponent_m=m,
            ratio=ratio,
            shift_factor=float(a_gamma_dot),
        )
        return float(a_gamma_dot)

    def _transform_single(
        self,
        data: RheoData,
        x: float,
        tau0: float,
    ) -> RheoData:
        """Apply SRFS shift to a single dataset.

        Parameters
        ----------
        data : RheoData
            Single flow curve dataset. Its metadata must contain
            ``'reference_gamma_dot'``.
        x : float
            SGR noise temperature
        tau0 : float
            SGR attempt time

        Returns
        -------
        RheoData
            New dataset whose x axis is multiplied by the shift factor. ``y``
            is passed through and the metadata gains the shift bookkeeping.

        Raises
        ------
        ValueError
            If ``'reference_gamma_dot'`` is missing from the metadata.
        """
        _meta = data.metadata or {}
        if "reference_gamma_dot" not in _meta:
            logger.error(
                "Missing reference_gamma_dot in metadata",
                available_keys=list(_meta.keys()),
            )
            raise ValueError(
                "reference_gamma_dot must be in metadata for SRFS shifting"
            )

        gamma_dot_ref = _meta["reference_gamma_dot"]
        n_points = len(data.x)  # type: ignore[arg-type]
        logger.debug(
            "Applying SRFS shift to single dataset",
            gamma_dot_ref=gamma_dot_ref,
            data_points=n_points,
        )

        a_gamma_dot = self.compute_shift_factor(gamma_dot_ref, x, tau0)

        # Horizontal shift: only the shear-rate axis is rescaled.
        x_shifted = jnp.asarray(data.x) * a_gamma_dot

        new_metadata = {
            **_meta,
            "transform": "srfs",
            "reference_gamma_dot_master": self.reference_gamma_dot,
            "shift_factor": float(a_gamma_dot),
            "sgr_x": x,
            "sgr_tau0": tau0,
        }

        # The ranges exist only for the log record; an empty dataset must not
        # fail just because there is no first/last point to report.
        if n_points:
            original_x_range = (float(data.x[0]), float(data.x[-1]))  # type: ignore[index]
            shifted_x_range = (float(x_shifted[0]), float(x_shifted[-1]))
        else:
            original_x_range = None
            shifted_x_range = None
        logger.debug(
            "Single dataset shifted",
            shift_factor=float(a_gamma_dot),
            original_x_range=original_x_range,
            shifted_x_range=shifted_x_range,
        )

        return RheoData(
            x=x_shifted,
            y=data.y,
            x_units=data.x_units,
            y_units=data.y_units,
            domain=data.domain,
            metadata=new_metadata,
            validate=False,
        )

    def _transform(
        self,
        data: RheoData | list[RheoData],
        x: float | None = None,
        tau0: float | None = None,
        return_shifts: bool = False,
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Apply SRFS transformation.

        Parameters
        ----------
        data : RheoData or list of RheoData
            Single dataset or list of datasets to transform
        x : float, optional
            SGR noise temperature (required)
        tau0 : float, optional
            SGR attempt time (required)
        return_shifts : bool, default=False
            If True, return shift factors dict along with mastercurve
            (ignored for a single dataset)

        Returns
        -------
        RheoData or tuple
            If data is single RheoData: shifted dataset
            If data is list and return_shifts=True: (mastercurve, shift_factors)
            If data is list and return_shifts=False: mastercurve

        Raises
        ------
        ValueError
            If ``x`` or ``tau0`` is None.
        """
        is_single = isinstance(data, RheoData)
        logger.info(
            "Starting SRFS transformation",
            is_list=not is_single,
            n_datasets=1 if is_single else len(data),  # type: ignore[arg-type]
            reference_gamma_dot=self.reference_gamma_dot,
            sgr_x=x,
            sgr_tau0=tau0,
        )

        # Both code paths need the SGR parameters, so validate once up front.
        if x is None or tau0 is None:
            logger.error("Missing required SGR parameters for SRFS transformation")
            raise ValueError("x and tau0 are required for SRFS transformation")

        if is_single:
            return self._transform_single(data, x, tau0)  # type: ignore[arg-type]
        return self.create_mastercurve(
            data, x, tau0, return_shifts=return_shifts  # type: ignore[arg-type]
        )

    def transform(
        self,
        data: RheoData | list[RheoData],
        x: float | None = None,
        tau0: float | None = None,
        return_shifts: bool = False,
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Apply SRFS transformation (public interface).

        Parameters
        ----------
        data : RheoData or list of RheoData
            Single dataset or list of datasets to transform
        x : float, optional
            SGR noise temperature
        tau0 : float, optional
            SGR attempt time
        return_shifts : bool, default=False
            If True, return shift factors dict along with mastercurve

        Returns
        -------
        RheoData or tuple
            Transformed data, optionally with shift factors

        Raises
        ------
        ValueError
            If ``x`` or ``tau0`` is None.
        """
        return self._transform(data, x=x, tau0=tau0, return_shifts=return_shifts)

    def create_mastercurve(
        self,
        datasets: list[RheoData],
        x: float,
        tau0: float,
        merge: bool = True,
        return_shifts: bool = False,
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Create SRFS master curve from multiple flow curve datasets.

        Parameters
        ----------
        datasets : list of RheoData
            Flow curves at different reference shear rates. Each must carry
            ``'reference_gamma_dot'`` in its metadata.
        x : float
            SGR noise temperature
        tau0 : float
            SGR attempt time
        merge : bool, default=True
            If True, merge all shifted data into single RheoData
        return_shifts : bool, default=False
            If True, return shift factors dict with mastercurve (only honoured
            when ``merge`` is True)

        Returns
        -------
        RheoData or list or tuple
            Master curve or list of shifted datasets, optionally with shifts

        Raises
        ------
        ValueError
            If a dataset lacks ``'reference_gamma_dot'``, or if ``merge`` is
            True and ``datasets`` is empty.
        """
        logger.info(
            "Creating SRFS master curve",
            n_datasets=len(datasets),
            sgr_x=x,
            sgr_tau0=tau0,
            merge=merge,
        )

        # Collect every reference rate first so a bad dataset is reported
        # before any shifting work is done.
        ref_gamma_dots = []
        for data in datasets:
            _dmeta = data.metadata or {}
            if "reference_gamma_dot" not in _dmeta:
                logger.error(
                    "Missing reference_gamma_dot in dataset metadata",
                    available_keys=list(_dmeta.keys()),
                )
                raise ValueError(
                    "All datasets must have 'reference_gamma_dot' in metadata"
                )
            ref_gamma_dots.append(_dmeta["reference_gamma_dot"])

        logger.debug(
            "Reference shear rates extracted",
            ref_gamma_dots=ref_gamma_dots,
        )

        # Order datasets by reference rate; a stable sort keeps the caller's
        # order for datasets that share a reference rate.
        order = np.argsort(ref_gamma_dots, kind="stable")
        datasets = [datasets[i] for i in order]
        ref_gamma_dots = [ref_gamma_dots[i] for i in order]

        logger.debug("Computing shift factors for all datasets")
        shift_factors = {
            ref: self.compute_shift_factor(ref, x, tau0) for ref in ref_gamma_dots
        }
        logger.debug("Shift factors computed", shift_factors=shift_factors)

        logger.debug("Applying shifts to all datasets")
        shifted_datasets = [
            self._transform_single(data, x, tau0) for data in datasets
        ]

        self.shift_factors_ = shift_factors

        if not merge:
            logger.info(
                "SRFS transformation completed (no merge)",
                n_shifted_datasets=len(shifted_datasets),
            )
            return shifted_datasets

        # Merging nothing is an error; say so explicitly instead of letting
        # np.concatenate fail with an unrelated message.
        if not shifted_datasets:
            raise ValueError(
                "Cannot create a master curve from an empty list of datasets"
            )

        all_x = [np.asarray(shifted.x) for shifted in shifted_datasets]
        all_y = [np.asarray(shifted.y) for shifted in shifted_datasets]
        # Tag each merged point with the reference rate of its source dataset.
        all_refs = []
        for x_data, ref in zip(all_x, ref_gamma_dots, strict=True):
            all_refs.extend([ref] * len(x_data))

        merged_x = np.concatenate(all_x)
        merged_y = np.concatenate(all_y)
        merged_refs = np.array(all_refs)

        sort_idx = np.argsort(merged_x, kind="stable")
        merged_x = merged_x[sort_idx]
        merged_y = merged_y[sort_idx]
        merged_refs = merged_refs[sort_idx]

        mastercurve_metadata = {
            "transform": "srfs",
            "reference_gamma_dot": self.reference_gamma_dot,
            "source_gamma_dots": ref_gamma_dots,
            "n_datasets": len(datasets),
            "source_refs": merged_refs,
            "shift_factors": shift_factors,
            "sgr_x": x,
            "sgr_tau0": tau0,
        }

        # Units and domain are taken from the first (lowest-rate) dataset.
        first = datasets[0]
        mastercurve = RheoData(
            x=merged_x,
            y=merged_y,
            x_units=first.x_units,
            y_units=first.y_units,
            domain=first.domain,
            metadata=mastercurve_metadata,
            validate=False,
        )

        # The range is log-only; datasets with zero points must not crash here.
        x_range = (
            (float(merged_x[0]), float(merged_x[-1])) if len(merged_x) else None
        )
        logger.info(
            "SRFS master curve created",
            total_points=len(merged_x),
            n_datasets=len(datasets),
            x_range=x_range,
        )

        if return_shifts:
            return mastercurve, shift_factors
        return mastercurve

    def get_shift_factors_array(
        self,
        gamma_dots: list[float] | np.ndarray | None = None,
        x: float | None = None,
        tau0: float | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get shift factors as arrays for plotting.

        Parameters
        ----------
        gamma_dots : list or ndarray, optional
            Shear rates to compute shifts for. If None, uses stored values.
        x : float, optional
            SGR noise temperature (required if computing new shifts)
        tau0 : float, optional
            SGR attempt time (required if computing new shifts)

        Returns
        -------
        gamma_dots : ndarray
            Array of shear rates (sorted)
        shift_factors : ndarray
            Array of corresponding shift factors

        Raises
        ------
        ValueError
            If ``gamma_dots`` is None and no master curve has been created, or
            if ``gamma_dots`` is given without both ``x`` and ``tau0``.
        """
        if gamma_dots is None:
            if self.shift_factors_ is None:
                raise ValueError(
                    "No shift factors available. Either provide gamma_dots or "
                    "create a mastercurve first."
                )
            # Look values up with the original dict keys rather than the NumPy
            # scalars produced by iterating the array.
            sorted_rates = sorted(self.shift_factors_)
            gamma_dots_arr = np.array(sorted_rates)
            shifts_arr = np.array([self.shift_factors_[gd] for gd in sorted_rates])
            return gamma_dots_arr, shifts_arr

        if x is None or tau0 is None:
            raise ValueError("x and tau0 required to compute shift factors")

        gamma_dots_arr = np.sort(np.asarray(gamma_dots))
        shifts_arr = np.array(
            [self.compute_shift_factor(float(gd), x, tau0) for gd in gamma_dots_arr]
        )
        return gamma_dots_arr, shifts_arr
# ============================================================================
# Shear Banding Detection Functions
# ============================================================================


def detect_shear_banding(
    gamma_dot: np.ndarray,
    sigma: np.ndarray,
    warn: bool = False,
    threshold: float = -0.01,
) -> tuple[bool, dict | None]:
    """Detect shear banding from non-monotonic constitutive curve.

    Shear banding occurs when the derivative d(sigma)/d(gamma_dot) < 0,
    indicating a region of mechanical instability where the material splits
    into bands with different local shear rates.

    Parameters
    ----------
    gamma_dot : ndarray
        Shear rate array (1/s). Any array-like is accepted; it does not need
        to be pre-sorted.
    sigma : ndarray
        Stress array (Pa), aligned with ``gamma_dot``
    warn : bool, default=False
        If True, issue a warning when shear banding is detected
    threshold : float, default=-0.01
        Threshold for detecting negative slope (allows for numerical noise).
        It is compared against the finite-difference slope
        d(sigma)/d(gamma_dot) in linear (not logarithmic) units.

    Returns
    -------
    is_banding : bool
        True if shear banding is detected
    banding_info : dict or None
        Information about the banding region if detected:
        - 'gamma_dot_low': Lower shear rate of banding region
        - 'gamma_dot_high': Upper shear rate of banding region
        - 'sigma_low': Stress at the lower bound
        - 'sigma_high': Stress at the upper bound
        - 'sigma_range': (min, max) of the two bounding stresses
        - 'negative_slope_fraction': Fraction of curve with negative slope

    Examples
    --------
    >>> gamma_dot = np.logspace(-2, 2, 100)
    >>> sigma = gamma_dot ** 0.5  # Monotonic power-law
    >>> is_banding, info = detect_shear_banding(gamma_dot, sigma)
    >>> print(is_banding)  # False

    >>> # Non-monotonic curve
    >>> sigma_nm = sigma * (1 - 0.3 * np.exp(-((gamma_dot - 1)**2) / 0.1))
    >>> is_banding, info = detect_shear_banding(gamma_dot, sigma_nm)
    >>> print(is_banding)  # True
    """
    # Accept lists/tuples as well as arrays; fancy indexing below needs arrays.
    gamma_dot = np.asarray(gamma_dot)
    sigma = np.asarray(sigma)

    logger.debug(
        "Detecting shear banding",
        n_points=len(gamma_dot),
        threshold=threshold,
    )

    if len(gamma_dot) < 2:
        logger.debug("Insufficient data for banding detection (need >= 2 points)")
        return False, None

    # Sort by shear rate
    sort_idx = np.argsort(gamma_dot)
    gamma_dot = gamma_dot[sort_idx]
    sigma = sigma[sort_idx]

    # Finite-difference slope; the floor keeps repeated shear rates from
    # dividing by zero.
    d_sigma = np.diff(sigma)
    d_gamma_dot = np.maximum(np.diff(gamma_dot), 1e-20)
    derivative = d_sigma / d_gamma_dot

    negative_slope_mask = derivative < threshold
    if not np.any(negative_slope_mask):
        logger.debug("No shear banding detected (monotonic flow curve)")
        return False, None

    # Bounds of the non-monotonic region: from the start of the first
    # negative-slope interval to the end of the last one.
    negative_indices = np.where(negative_slope_mask)[0]
    first_neg_idx = negative_indices[0]
    upper_idx = min(negative_indices[-1] + 1, len(gamma_dot) - 1)

    gamma_dot_low = gamma_dot[first_neg_idx]
    gamma_dot_high = gamma_dot[upper_idx]
    sigma_low = sigma[first_neg_idx]
    sigma_high = sigma[upper_idx]

    neg_fraction = np.sum(negative_slope_mask) / len(derivative)

    banding_info = {
        "gamma_dot_low": float(gamma_dot_low),
        "gamma_dot_high": float(gamma_dot_high),
        "sigma_low": float(sigma_low),
        "sigma_high": float(sigma_high),
        "sigma_range": (
            float(min(sigma_low, sigma_high)),
            float(max(sigma_low, sigma_high)),
        ),
        "negative_slope_fraction": float(neg_fraction),
    }

    logger.info(
        "Shear banding detected",
        gamma_dot_low=float(gamma_dot_low),
        gamma_dot_high=float(gamma_dot_high),
        negative_slope_fraction=float(neg_fraction),
    )

    if warn:
        warnings.warn(
            f"Shear banding detected in flow curve. "
            f"Non-monotonic region: gamma_dot = [{gamma_dot_low:.3g}, {gamma_dot_high:.3g}] 1/s. "
            f"This may indicate mechanical instability.",
            UserWarning,
            stacklevel=2,
        )

    return True, banding_info


def _interp_rate_at_stress(
    stress: float,
    sigma_branch: np.ndarray,
    gamma_dot_branch: np.ndarray,
) -> float:
    """Return the shear rate at which a branch of the flow curve reaches ``stress``.

    ``np.interp`` requires increasing sample points. A "stable" branch may
    still contain slightly negative slopes (anything above the detection
    threshold), so the samples are ordered by stress first. For a branch that
    is already monotonic this is a no-op.
    """
    order = np.argsort(sigma_branch, kind="stable")
    return np.interp(stress, sigma_branch[order], gamma_dot_branch[order])


def compute_shear_band_coexistence(
    gamma_dot: np.ndarray,
    sigma: np.ndarray,
    gamma_dot_applied: float,
) -> dict | None:
    """Compute shear band coexistence using lever rule.

    When shear banding occurs, the material splits into bands with different
    local shear rates (gamma_dot_low and gamma_dot_high) that coexist at a
    common stress plateau. The fraction of each band is determined by the
    lever rule from the applied average shear rate.

    Parameters
    ----------
    gamma_dot : ndarray
        Shear rate array (1/s)
    sigma : ndarray
        Stress array (Pa)
    gamma_dot_applied : float
        Applied (average) shear rate (1/s)

    Returns
    -------
    coexistence : dict or None
        Coexistence information if banding detected:
        - 'gamma_dot_low': Shear rate in low-shear band
        - 'gamma_dot_high': Shear rate in high-shear band
        - 'fraction_low': Volume fraction of low-shear band
        - 'fraction_high': Volume fraction of high-shear band
        - 'stress_plateau': Common stress in banding regime
        Returns None if no banding is detected, if the applied rate lies
        outside the banding region, or if the two coexisting rates coincide.

    Notes
    -----
    The lever rule states:
        gamma_dot_applied = f_low * gamma_dot_low + f_high * gamma_dot_high
    where f_low + f_high = 1.

    The stress plateau is estimated as the mean stress over the detected
    banding region. This is a simplification; an equal-area (Maxwell)
    construction is not performed.
    """
    gamma_dot = np.asarray(gamma_dot)
    sigma = np.asarray(sigma)

    logger.debug(
        "Computing shear band coexistence",
        gamma_dot_applied=gamma_dot_applied,
        n_points=len(gamma_dot),
    )

    is_banding, banding_info = detect_shear_banding(gamma_dot, sigma)
    if not is_banding or banding_info is None:
        logger.debug("No shear banding detected, cannot compute coexistence")
        return None

    gamma_dot_low_bound = banding_info["gamma_dot_low"]
    gamma_dot_high_bound = banding_info["gamma_dot_high"]

    # The lever rule only applies inside the banding region.
    if not (gamma_dot_low_bound <= gamma_dot_applied <= gamma_dot_high_bound):
        logger.debug(
            "Applied shear rate outside banding region",
            gamma_dot_applied=gamma_dot_applied,
            banding_region=(gamma_dot_low_bound, gamma_dot_high_bound),
        )
        return None

    sort_idx = np.argsort(gamma_dot)
    gamma_dot_sorted = gamma_dot[sort_idx]
    sigma_sorted = sigma[sort_idx]

    # Stress plateau: mean stress over the samples spanning the banding region.
    low_idx = np.searchsorted(gamma_dot_sorted, gamma_dot_low_bound)
    high_idx = np.searchsorted(gamma_dot_sorted, gamma_dot_high_bound)
    banding_slice = sigma_sorted[low_idx : high_idx + 1]
    if len(banding_slice) == 0:
        return None
    stress_plateau = np.mean(banding_slice)

    # Coexisting rates: where the plateau stress intersects the branches on
    # either side of the banding region. With fewer than two samples on a
    # branch there is nothing to interpolate, so the region bound is used.
    left_mask = gamma_dot_sorted < gamma_dot_low_bound
    if np.count_nonzero(left_mask) > 1:
        gamma_dot_low = _interp_rate_at_stress(
            stress_plateau, sigma_sorted[left_mask], gamma_dot_sorted[left_mask]
        )
    else:
        gamma_dot_low = gamma_dot_low_bound

    right_mask = gamma_dot_sorted > gamma_dot_high_bound
    if np.count_nonzero(right_mask) > 1:
        gamma_dot_high = _interp_rate_at_stress(
            stress_plateau, sigma_sorted[right_mask], gamma_dot_sorted[right_mask]
        )
    else:
        gamma_dot_high = gamma_dot_high_bound

    # Lever rule:
    #   gamma_dot_applied = f_low * gamma_dot_low + (1 - f_low) * gamma_dot_high
    delta_gamma = gamma_dot_high - gamma_dot_low
    if abs(delta_gamma) < 1e-12:
        return None

    f_low = (gamma_dot_high - gamma_dot_applied) / delta_gamma
    f_high = 1.0 - f_low

    # The interpolated band rates can fall on either side of the applied
    # rate, so the raw fractions are clamped to the physical range [0, 1].
    f_low = np.clip(f_low, 0, 1)
    f_high = np.clip(f_high, 0, 1)

    logger.info(
        "Shear band coexistence computed",
        gamma_dot_low=float(gamma_dot_low),
        gamma_dot_high=float(gamma_dot_high),
        fraction_low=float(f_low),
        fraction_high=float(f_high),
        stress_plateau=float(stress_plateau),
    )

    return {
        "gamma_dot_low": float(gamma_dot_low),
        "gamma_dot_high": float(gamma_dot_high),
        "fraction_low": float(f_low),
        "fraction_high": float(f_high),
        "stress_plateau": float(stress_plateau),
    }


# ============================================================================
# Thixotropy Kinetics Functions
# ============================================================================


@jax.jit
def thixotropy_lambda_derivative(
    lambda_val: float,
    gamma_dot: float,
    k_build: float,
    k_break: float,
) -> float:
    """Compute time derivative of structural parameter lambda.

    The structural parameter lambda represents the state of internal
    microstructure, with lambda = 1 being fully built and lambda = 0 being
    fully broken.

    Evolution equation:
        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda

    Parameters
    ----------
    lambda_val : float
        Current structural parameter value [0, 1]
    gamma_dot : float
        Current shear rate (1/s)
    k_build : float
        Structure build-up rate (1/s)
    k_break : float
        Structure breakdown rate (dimensionless)

    Returns
    -------
    float
        Time derivative d(lambda)/dt
    """
    # Build-up term: drives lambda toward 1 at rest
    build_up = k_build * (1.0 - lambda_val)
    # Breakdown term: shear destroys structure
    breakdown = k_break * gamma_dot * lambda_val
    return build_up - breakdown


@jax.jit
def _thixotropy_scan_step(
    lambda_prev: float,
    inputs: tuple[float, float, float, float],
) -> tuple[float, float]:
    """Single step of thixotropy evolution for jax.lax.scan.

    Parameters
    ----------
    lambda_prev : float
        Previous structural parameter value
    inputs : tuple
        (gamma_dot_i, dt_i, k_build, k_break) for this timestep

    Returns
    -------
    tuple
        (lambda_new, lambda_new) - carry and output are the same
    """
    gamma_dot_i, dt_i, k_build, k_break = inputs

    # Same kinetics as thixotropy_lambda_derivative, written inline.
    build_up = k_build * (1.0 - lambda_prev)
    breakdown = k_break * gamma_dot_i * lambda_prev
    dlambda_dt = build_up - breakdown

    # Euler step, clamped to the physical range [0, 1]
    lambda_new = jnp.clip(lambda_prev + dlambda_dt * dt_i, 0.0, 1.0)
    return lambda_new, lambda_new


def evolve_thixotropy_lambda(
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_initial: float,
    k_build: float,
    k_break: float,
) -> np.ndarray:
    """Evolve structural parameter lambda(t) for given shear history.

    Integrates the thixotropy kinetics equation:
        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda

    The time loop runs through ``jax.lax.scan`` rather than a Python loop.
    Each step from ``t[i]`` to ``t[i + 1]`` is an Euler update that uses the
    shear rate sampled at ``t[i + 1]``.

    Parameters
    ----------
    t : ndarray
        Time array (s)
    gamma_dot : ndarray
        Shear rate array (1/s), same shape as t
    lambda_initial : float
        Initial structural parameter [0, 1]
    k_build : float
        Structure build-up rate (1/s)
    k_break : float
        Structure breakdown rate (dimensionless)

    Returns
    -------
    lambda_t : ndarray
        Structural parameter evolution, same shape as t

    Raises
    ------
    ValueError
        If ``t`` and ``gamma_dot`` have different shapes.

    Warns
    -----
    UserWarning
        If the largest time step exceeds the Euler stability limit
        2 / (k_build + k_break * max|gamma_dot|).
    """
    # Accept array-likes; .shape and slicing below need real arrays.
    t = np.asarray(t)
    gamma_dot = np.asarray(gamma_dot)

    logger.debug(
        "Evolving thixotropy lambda",
        n_points=len(t),
        lambda_initial=lambda_initial,
        k_build=k_build,
        k_break=k_break,
    )

    # Validate before any numerical work so mismatched inputs are reported as
    # such rather than surfacing from the stability check.
    if t.shape != gamma_dot.shape:
        logger.error(
            "Shape mismatch between time and shear rate arrays",
            t_shape=t.shape,
            gamma_dot_shape=gamma_dot.shape,
        )
        raise ValueError(
            f"Time and shear rate arrays must have same shape: "
            f"t.shape={t.shape}, gamma_dot.shape={gamma_dot.shape}"
        )

    # No time points means no history: honour the "same shape as t" contract
    # instead of returning a lone initial value.
    if t.size == 0:
        return np.zeros(t.shape, dtype=np.float64)

    # T-24: Forward Euler stability check for thixotropy ODE.
    # dt must be < 2 / max_eigenvalue to avoid oscillatory blow-up.
    if len(t) > 1:
        max_eigenvalue = k_build + k_break * np.max(np.abs(gamma_dot))
        max_stable_dt = 2.0 / max(max_eigenvalue, 1e-30)
        max_dt = float(np.max(np.diff(t)))
        if max_dt > max_stable_dt:
            warnings.warn(
                f"Forward Euler may be unstable for thixotropy ODE: "
                f"max(dt)={max_dt:.3g} > stability limit={max_stable_dt:.3g}. "
                f"Consider using finer time steps or an implicit integrator.",
                stacklevel=2,
            )

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)

    # One step per interval between consecutive time points.
    dt = jnp.diff(t_jax)
    n_steps = len(dt)

    # scan consumes per-step sequences, so the scalar rate constants are
    # broadcast to one entry per step.
    k_build_arr = jnp.full(n_steps, k_build, dtype=jnp.float64)
    k_break_arr = jnp.full(n_steps, k_break, dtype=jnp.float64)
    scan_inputs = (gamma_dot_jax[1:], dt, k_build_arr, k_break_arr)

    _, lambda_history = jax.lax.scan(
        _thixotropy_scan_step,
        jnp.float64(lambda_initial),  # Initial carry
        scan_inputs,  # Sequence of inputs
    )

    # scan yields the values after each step; prepend the initial state to get
    # one value per time point.
    lambda_t = jnp.concatenate(
        [jnp.asarray([lambda_initial], dtype=jnp.float64), lambda_history]
    )

    # Convert back to numpy for compatibility
    lambda_t_np = np.asarray(lambda_t, dtype=np.float64)

    logger.debug(
        "Thixotropy evolution completed",
        lambda_final=float(lambda_t_np[-1]),
        lambda_min=float(np.min(lambda_t_np)),
        lambda_max=float(np.max(lambda_t_np)),
    )

    return lambda_t_np


def compute_thixotropic_stress(
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_t: np.ndarray,
    G0: float,
    tau0: float,
    x: float,
    n_struct: float = 2.0,
) -> np.ndarray:
    """Compute stress response with thixotropic modulus.

    The effective modulus is coupled to the structural parameter:
        G_eff(t) = G0 * lambda(t)^n_struct

    and the stress is
        sigma = G_eff * gamma_dot * tau0 * (|gamma_dot| * tau0)^(x - 2)

    Parameters
    ----------
    t : ndarray
        Time array (s). Only its length is used (for logging).
    gamma_dot : ndarray
        Shear rate array (1/s)
    lambda_t : ndarray
        Structural parameter array [0, 1]
    G0 : float
        Base modulus scale (Pa)
    tau0 : float
        Attempt time (s)
    x : float
        SGR noise temperature
    n_struct : float, default=2.0
        Structural coupling exponent

    Returns
    -------
    sigma : ndarray
        Stress response (Pa)
    """
    gamma_dot = np.asarray(gamma_dot)
    lambda_t = np.asarray(lambda_t)

    logger.debug(
        "Computing thixotropic stress",
        n_points=len(t),
        G0=G0,
        tau0=tau0,
        x=x,
        n_struct=n_struct,
    )

    # Effective modulus from structure
    G_eff = G0 * np.power(lambda_t, n_struct)

    # Power-law (SGR-like) viscosity factor. The rate magnitude is floored and
    # both exponent and result are clipped so that zero shear rate or an
    # extreme x cannot produce inf/NaN.
    gamma_dot_safe = np.maximum(np.abs(gamma_dot), 1e-12)
    exponent = np.clip(x - 2.0, -10.0, 10.0)
    eta_factor = np.power(gamma_dot_safe * tau0, exponent)
    eta_factor = np.clip(eta_factor, 1e-30, 1e30)

    # The signed shear rate is used here so the stress keeps the flow direction.
    sigma = G_eff * gamma_dot * tau0 * eta_factor

    # min/max are log-only and undefined for empty input; skip them rather
    # than fail.
    if np.size(sigma):
        logger.debug(
            "Thixotropic stress computed",
            sigma_min=float(np.min(sigma)),
            sigma_max=float(np.max(sigma)),
        )

    return sigma


__all__ = [
    "SRFS",
    "detect_shear_banding",
    "compute_shear_band_coexistence",
    "thixotropy_lambda_derivative",
    "evolve_thixotropy_lambda",
    "compute_thixotropic_stress",
]