"""Strain-Rate Frequency Superposition (SRFS) transform.
This module implements SRFS for collapsing flow curves at different shear rates
onto a master curve, analogous to time-temperature superposition (TTS) but based
on shear rate rather than temperature.
SRFS is particularly useful for soft glassy materials where the SGR model predicts
a power-law relationship between shift factor and shear rate:
a(gamma_dot) ~ (gamma_dot)^m
where m = (2 - x) depends on the noise temperature x.
Thixotropy kinetics and shear banding detection are also implemented for
complete characterization of complex flow behavior in soft glassy materials.
Physical Background:
- SRFS exploits the fact that flow curves at different reference shear rates
can be collapsed via horizontal shifting
- For SGR materials, the shift factor has power-law form determined by x
- Thixotropy arises from microstructure build-up (at rest) and breakdown (under shear)
- Shear banding occurs when the constitutive curve becomes non-monotonic
References:
- P. Sollich, Rheological constitutive equation for a model of soft glassy
materials, Physical Review E, 1998, 58(1), 738-759
- M. Wyss et al., Strain-rate frequency superposition: A rheological probe
of structural relaxation in soft materials, Physical Review Letters, 2007
"""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
import numpy as np
from rheojax.core.base import BaseTransform
from rheojax.core.data import RheoData
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
# Module logger
logger = get_logger(__name__)
# Annotation-only alias source: static type checkers see jax.numpy, while at
# runtime the name falls back to numpy so the aliases below can still resolve.
if TYPE_CHECKING:
    import jax.numpy as jnp_typing
else:  # pragma: no cover - typing fallback
    jnp_typing = np
# Type aliases used for annotations (``type`` statement syntax).
type JaxArray = jnp_typing.ndarray
type ScalarOrArray = float | JaxArray
[docs]
@TransformRegistry.register("srfs", type=TransformType.SUPERPOSITION)
class SRFS(BaseTransform):
    """Strain-Rate Frequency Superposition (SRFS) transform.

    SRFS collapses flow curves measured at different shear rates onto a master
    curve by applying horizontal shift factors. This is analogous to
    time-temperature superposition (TTS) but uses shear rate rather than
    temperature.

    For SGR (Soft Glassy Rheology) materials, the shift factor follows::

        a(gamma_dot) = (gamma_dot / gamma_dot_ref)^m

    where m = (2 - x) and x is the noise temperature.

    Parameters
    ----------
    reference_gamma_dot : float, default=1.0
        Reference shear rate for the master curve (1/s). Must be positive.
    auto_shift : bool, default=False
        Flag requesting automatic shift-factor optimisation from data
        overlap. The value is stored on the instance, but no method of this
        class reads it, so ``x`` and ``tau0`` are always required.

    Attributes
    ----------
    reference_gamma_dot : float
        Reference shear rate
    shift_factors_ : dict[float, float] or None
        Shift factors keyed by each dataset's reference shear rate. ``None``
        until :meth:`create_mastercurve` has run.

    Examples
    --------
    >>> from rheojax.transforms.srfs import SRFS
    >>> from rheojax.core.data import RheoData
    >>>
    >>> # Create flow curve datasets at different reference shear rates
    >>> datasets = [
    ...     RheoData(x=gamma_dots_1, y=eta_1, metadata={'reference_gamma_dot': 0.1}),
    ...     RheoData(x=gamma_dots_2, y=eta_2, metadata={'reference_gamma_dot': 1.0}),
    ...     RheoData(x=gamma_dots_3, y=eta_3, metadata={'reference_gamma_dot': 10.0}),
    ... ]
    >>>
    >>> # Create SRFS transform
    >>> srfs = SRFS(reference_gamma_dot=1.0)
    >>>
    >>> # Apply SRFS shift (requires SGR parameters); ask for the shifts too
    >>> mastercurve, shift_factors = srfs.transform(
    ...     datasets, x=1.5, tau0=1e-3, return_shifts=True
    ... )

    Notes
    -----
    - Shift factors depend on SGR noise temperature x
    - For x < 1 (glass), shift behavior changes near yield stress
    - For x = 2 the exponent m is 0 and every shift factor equals 1. The
      exponent is not clamped, so x > 2 yields a negative exponent.
    """

    def __init__(
        self,
        reference_gamma_dot: float = 1.0,
        auto_shift: bool = False,
    ):
        """Initialize SRFS transform.

        Parameters
        ----------
        reference_gamma_dot : float
            Reference shear rate for the master curve
        auto_shift : bool
            Stored as ``_auto_shift``; not consulted by any method here.

        Raises
        ------
        ValueError
            If ``reference_gamma_dot`` is zero or negative.
        """
        if reference_gamma_dot <= 0.0:
            raise ValueError(
                f"reference_gamma_dot must be positive, got {reference_gamma_dot}. "
                "The reference shear rate is used as a divisor in shift-factor "
                "computation and cannot be zero or negative."
            )
        super().__init__()
        self.reference_gamma_dot = reference_gamma_dot
        self._auto_shift = auto_shift
        self.shift_factors_: dict[float, float] | None = None

    @staticmethod
    def _axis_range(values) -> tuple[float, float] | None:
        """Return ``(first, last)`` of *values* as floats, or None if empty."""
        if len(values) == 0:
            return None
        return float(values[0]), float(values[-1])

    def compute_shift_factor(
        self,
        gamma_dot: float,
        x: float,
        tau0: float,
    ) -> float:
        """Compute SRFS shift factor from SGR theory.

        For SGR materials, the shift factor follows a power-law::

            a(gamma_dot) = (gamma_dot / gamma_dot_ref)^m

        where m = (2 - x) for the power-law fluid regime (1 < x < 2).

        Parameters
        ----------
        gamma_dot : float
            Shear rate to compute shift for (1/s). Must be positive.
        x : float
            SGR noise temperature (dimensionless)
        tau0 : float
            SGR attempt time (s). It cancels out of the ratio and therefore
            only appears in the debug log.

        Returns
        -------
        float
            Shift factor a(gamma_dot)

        Raises
        ------
        ValueError
            If ``gamma_dot`` is not strictly positive (including NaN).

        Notes
        -----
        - For x = 1.5, exponent m = 0.5
        - For x = 2 (Newtonian), m = 0, shift factor = 1
        - For x < 1 (glass), behavior near yield stress is different
        """
        logger.debug(
            "Computing shift factor",
            gamma_dot=gamma_dot,
            x=x,
            tau0=tau0,
            reference_gamma_dot=self.reference_gamma_dot,
        )
        # A non-positive rate would give ZeroDivisionError (0 ** negative),
        # a complex result (negative ** fractional) or NaN for numpy scalars.
        # ``not (> 0)`` also rejects NaN.
        if not gamma_dot > 0.0:
            logger.error("Invalid shear rate for shift factor", gamma_dot=gamma_dot)
            raise ValueError(
                f"gamma_dot must be positive, got {gamma_dot}. "
                "The shift factor is a real power of "
                "gamma_dot / reference_gamma_dot."
            )

        # Exponent from SGR scaling: eta ~ gamma_dot^(x-2) => a ~ gamma_dot^(2-x)
        m = 2.0 - x

        # Shortcut: the reference curve itself is never shifted.
        if abs(gamma_dot - self.reference_gamma_dot) < 1e-12:
            logger.debug("Shear rate equals reference, shift factor = 1.0")
            return 1.0

        # a = (gamma_dot * tau0)^m / (gamma_dot_ref * tau0)^m
        #   = (gamma_dot / gamma_dot_ref)^m
        ratio = gamma_dot / self.reference_gamma_dot
        a_gamma_dot = ratio**m
        logger.debug(
            "Shift factor computed",
            exponent_m=m,
            ratio=ratio,
            shift_factor=float(a_gamma_dot),
        )
        return float(a_gamma_dot)

    def _transform_single(
        self,
        data: RheoData,
        x: float,
        tau0: float,
        shift_factor: float | None = None,
    ) -> RheoData:
        """Apply SRFS shift to a single dataset.

        Parameters
        ----------
        data : RheoData
            Single flow curve dataset; its metadata must contain
            ``'reference_gamma_dot'``.
        x : float
            SGR noise temperature
        tau0 : float
            SGR attempt time
        shift_factor : float, optional
            Pre-computed shift factor for this dataset. When omitted it is
            obtained from :meth:`compute_shift_factor`.

        Returns
        -------
        RheoData
            New dataset whose x axis is multiplied by the shift factor; y,
            units and domain are carried over unchanged.

        Raises
        ------
        ValueError
            If the metadata lacks ``'reference_gamma_dot'``.
        """
        meta = data.metadata or {}
        if "reference_gamma_dot" not in meta:
            logger.error(
                "Missing reference_gamma_dot in metadata",
                available_keys=list(meta.keys()),
            )
            raise ValueError(
                "reference_gamma_dot must be in metadata for SRFS shifting"
            )
        gamma_dot_ref = meta["reference_gamma_dot"]
        logger.debug(
            "Applying SRFS shift to single dataset",
            gamma_dot_ref=gamma_dot_ref,
            data_points=len(data.x),  # type: ignore[arg-type]
        )

        if shift_factor is None:
            a_gamma_dot = self.compute_shift_factor(gamma_dot_ref, x, tau0)
        else:
            a_gamma_dot = shift_factor

        # Horizontal shift of the shear-rate axis
        x_shifted = jnp.asarray(data.x) * a_gamma_dot

        new_metadata = {
            **meta,
            "transform": "srfs",
            "reference_gamma_dot_master": self.reference_gamma_dot,
            "shift_factor": float(a_gamma_dot),
            "sgr_x": x,
            "sgr_tau0": tau0,
        }
        logger.debug(
            "Single dataset shifted",
            shift_factor=float(a_gamma_dot),
            original_x_range=self._axis_range(data.x),
            shifted_x_range=self._axis_range(x_shifted),
        )
        return RheoData(
            x=x_shifted,
            y=data.y,
            x_units=data.x_units,
            y_units=data.y_units,
            domain=data.domain,
            metadata=new_metadata,
            validate=False,
        )

    def _transform(
        self,
        data: RheoData | list[RheoData],
        x: float | None = None,
        tau0: float | None = None,
        return_shifts: bool = False,
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Apply SRFS transformation.

        Parameters
        ----------
        data : RheoData or list of RheoData
            Single dataset or list of datasets to transform
        x : float, optional
            SGR noise temperature. Declared optional but always required.
        tau0 : float, optional
            SGR attempt time. Declared optional but always required.
        return_shifts : bool, default=False
            If True and ``data`` is a list, return the shift factors dict
            along with the mastercurve. Ignored for a single dataset.

        Returns
        -------
        RheoData or tuple
            If data is single RheoData: shifted dataset
            If data is list and return_shifts=True: (mastercurve, shift_factors)
            If data is list and return_shifts=False: mastercurve

        Raises
        ------
        ValueError
            If ``x`` or ``tau0`` is None.
        """
        is_single = isinstance(data, RheoData)
        logger.info(
            "Starting SRFS transformation",
            is_list=not is_single,
            n_datasets=1 if is_single else len(data),  # type: ignore[arg-type]
            reference_gamma_dot=self.reference_gamma_dot,
            sgr_x=x,
            sgr_tau0=tau0,
        )
        # Both the single and the list path need the SGR parameters.
        if x is None or tau0 is None:
            logger.error("Missing required SGR parameters for SRFS transformation")
            raise ValueError("x and tau0 are required for SRFS transformation")

        if isinstance(data, RheoData):
            return self._transform_single(data, x, tau0)
        return self.create_mastercurve(data, x, tau0, return_shifts=return_shifts)

    def create_mastercurve(
        self,
        datasets: list[RheoData],
        x: float,
        tau0: float,
        merge: bool = True,
        return_shifts: bool = False,
    ) -> RheoData | list[RheoData] | tuple[RheoData, dict[float, float]]:
        """Create SRFS master curve from multiple flow curve datasets.

        Parameters
        ----------
        datasets : list of RheoData
            Flow curves at different reference shear rates. Each must carry
            ``'reference_gamma_dot'`` in its metadata.
        x : float
            SGR noise temperature
        tau0 : float
            SGR attempt time
        merge : bool, default=True
            If True, merge all shifted data into single RheoData sorted by
            shifted x. If False, return the list of shifted datasets ordered
            by reference shear rate.
        return_shifts : bool, default=False
            If True, return shift factors dict with mastercurve. Only honoured
            when ``merge`` is True.

        Returns
        -------
        RheoData or list or tuple
            Master curve or list of shifted datasets, optionally with shifts

        Raises
        ------
        ValueError
            If a dataset lacks ``'reference_gamma_dot'`` metadata, or if
            ``merge`` is True and ``datasets`` is empty.

        Notes
        -----
        ``shift_factors_`` is overwritten on every call. Datasets sharing the
        same reference shear rate share one dictionary entry.
        """
        logger.info(
            "Creating SRFS master curve",
            n_datasets=len(datasets),
            sgr_x=x,
            sgr_tau0=tau0,
            merge=merge,
        )
        # Fail with a clear message instead of numpy's "need at least one
        # array to concatenate" further down.
        if merge and not datasets:
            logger.error("Cannot merge an empty list of datasets")
            raise ValueError(
                "At least one dataset is required to create a master curve"
            )

        # Collect reference shear rates, insisting every dataset has one
        ref_gamma_dots = []
        for dataset in datasets:
            dataset_meta = dataset.metadata or {}
            if "reference_gamma_dot" not in dataset_meta:
                logger.error(
                    "Missing reference_gamma_dot in dataset metadata",
                    available_keys=list(dataset_meta.keys()),
                )
                raise ValueError(
                    "All datasets must have 'reference_gamma_dot' in metadata"
                )
            ref_gamma_dots.append(dataset_meta["reference_gamma_dot"])
        logger.debug(
            "Reference shear rates extracted",
            ref_gamma_dots=ref_gamma_dots,
        )

        # Order by reference shear rate; stable so ties keep input order
        order = np.argsort(ref_gamma_dots, kind="stable")
        datasets = [datasets[i] for i in order]
        ref_gamma_dots = [ref_gamma_dots[i] for i in order]

        # Compute each shift factor once and reuse it for the actual shift
        shift_factors: dict[float, float] = {}
        shifted_datasets = []
        for dataset, gamma_dot_ref in zip(datasets, ref_gamma_dots, strict=True):
            a_gamma_dot = self.compute_shift_factor(gamma_dot_ref, x, tau0)
            shift_factors[gamma_dot_ref] = a_gamma_dot
            shifted_datasets.append(
                self._transform_single(dataset, x, tau0, shift_factor=a_gamma_dot)
            )
        logger.debug("Shift factors computed", shift_factors=shift_factors)
        self.shift_factors_ = shift_factors

        if not merge:
            logger.info(
                "SRFS transformation completed (no merge)",
                n_shifted_datasets=len(shifted_datasets),
            )
            return shifted_datasets

        # Merge all shifted data, remembering which curve each point came from
        x_parts = [np.asarray(shifted.x) for shifted in shifted_datasets]
        y_parts = [np.asarray(shifted.y) for shifted in shifted_datasets]
        merged_x = np.concatenate(x_parts)
        merged_y = np.concatenate(y_parts)
        merged_refs = np.repeat(
            np.asarray(ref_gamma_dots), [len(part) for part in x_parts]
        )

        sort_idx = np.argsort(merged_x, kind="stable")
        merged_x = merged_x[sort_idx]
        merged_y = merged_y[sort_idx]
        merged_refs = merged_refs[sort_idx]

        mastercurve_metadata = {
            "transform": "srfs",
            "reference_gamma_dot": self.reference_gamma_dot,
            "source_gamma_dots": ref_gamma_dots,
            "n_datasets": len(datasets),
            "source_refs": merged_refs,
            "shift_factors": shift_factors,
            "sgr_x": x,
            "sgr_tau0": tau0,
        }
        # Units and domain are taken from the first dataset (lowest
        # reference shear rate after sorting).
        first = datasets[0]
        mastercurve = RheoData(
            x=merged_x,
            y=merged_y,
            x_units=first.x_units,
            y_units=first.y_units,
            domain=first.domain,
            metadata=mastercurve_metadata,
            validate=False,
        )
        logger.info(
            "SRFS master curve created",
            total_points=len(merged_x),
            n_datasets=len(datasets),
            x_range=self._axis_range(merged_x),
        )
        if return_shifts:
            return mastercurve, shift_factors
        return mastercurve

    def get_shift_factors_array(
        self,
        gamma_dots: list[float] | np.ndarray | None = None,
        x: float | None = None,
        tau0: float | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Get shift factors as arrays for plotting.

        Parameters
        ----------
        gamma_dots : list or ndarray, optional
            Shear rates to compute shifts for. If None, uses stored values.
        x : float, optional
            SGR noise temperature (required if computing new shifts)
        tau0 : float, optional
            SGR attempt time (required if computing new shifts)

        Returns
        -------
        gamma_dots : ndarray
            Array of shear rates (sorted)
        shift_factors : ndarray
            Array of corresponding shift factors

        Raises
        ------
        ValueError
            If ``gamma_dots`` is None and no master curve has been created,
            or if ``gamma_dots`` is given without ``x`` and ``tau0``.
        """
        if gamma_dots is None:
            stored = self.shift_factors_
            if stored is None:
                raise ValueError(
                    "No shift factors available. Either provide gamma_dots or "
                    "create a mastercurve first."
                )
            gamma_dots_arr = np.array(sorted(stored.keys()))
            shifts_arr = np.array([stored[gd] for gd in gamma_dots_arr])
            return gamma_dots_arr, shifts_arr

        if x is None or tau0 is None:
            raise ValueError("x and tau0 required to compute shift factors")
        gamma_dots_arr = np.array(gamma_dots)
        gamma_dots_arr = gamma_dots_arr[np.argsort(gamma_dots_arr)]
        shifts_arr = np.array(
            [self.compute_shift_factor(float(gd), x, tau0) for gd in gamma_dots_arr]
        )
        return gamma_dots_arr, shifts_arr
# ============================================================================
# Shear Banding Detection Functions
# ============================================================================
def detect_shear_banding(
    gamma_dot: np.ndarray,
    sigma: np.ndarray,
    warn: bool = False,
    threshold: float = -0.01,
) -> tuple[bool, dict | None]:
    """Detect shear banding from non-monotonic constitutive curve.

    Shear banding occurs when the derivative d(sigma)/d(gamma_dot) < 0,
    indicating a region of mechanical instability where the material
    splits into bands with different local shear rates.

    Parameters
    ----------
    gamma_dot : ndarray
        Shear rate array (1/s). Any array-like is accepted; it does not need
        to be sorted.
    sigma : ndarray
        Stress array (Pa), same shape as ``gamma_dot``
    warn : bool, default=False
        If True, issue a warning when shear banding is detected
    threshold : float, default=-0.01
        Finite-difference slopes below this value count as negative
        (allows for numerical noise)

    Returns
    -------
    is_banding : bool
        True if shear banding is detected
    banding_info : dict or None
        None when no banding is detected or fewer than two points are
        given. Otherwise:
        - 'gamma_dot_low': Lower shear rate of banding region
        - 'gamma_dot_high': Upper shear rate of banding region
        - 'sigma_low': Stress at 'gamma_dot_low'
        - 'sigma_high': Stress at 'gamma_dot_high'
        - 'sigma_range': (min, max) of the two stresses above
        - 'negative_slope_fraction': Fraction of intervals with negative slope

    Raises
    ------
    ValueError
        If ``gamma_dot`` and ``sigma`` have different shapes.

    Examples
    --------
    >>> gamma_dot = np.logspace(-2, 2, 100)
    >>> sigma = gamma_dot ** 0.5  # Monotonic power-law
    >>> is_banding, info = detect_shear_banding(gamma_dot, sigma)
    >>> print(is_banding)  # False
    >>> # Non-monotonic curve
    >>> sigma_nm = sigma * (1 - 0.3 * np.exp(-((gamma_dot - 1)**2) / 0.1))
    >>> is_banding, info = detect_shear_banding(gamma_dot, sigma_nm)
    >>> print(is_banding)  # True
    """
    # Coerce first so lists/tuples support the fancy indexing below and
    # unsigned-integer inputs cannot wrap around inside np.diff.
    gamma_dot = np.asarray(gamma_dot, dtype=np.float64)
    sigma = np.asarray(sigma, dtype=np.float64)

    logger.debug(
        "Detecting shear banding",
        n_points=gamma_dot.size,
        threshold=threshold,
    )

    # A longer sigma would otherwise be silently mis-paired after sorting.
    if gamma_dot.shape != sigma.shape:
        logger.error(
            "Shape mismatch between shear rate and stress arrays",
            gamma_dot_shape=gamma_dot.shape,
            sigma_shape=sigma.shape,
        )
        raise ValueError(
            f"Shear rate and stress arrays must have same shape: "
            f"gamma_dot.shape={gamma_dot.shape}, sigma.shape={sigma.shape}"
        )

    if gamma_dot.size < 2:
        logger.debug("Insufficient data for banding detection (need >= 2 points)")
        return False, None

    # Sort by shear rate
    sort_idx = np.argsort(gamma_dot)
    gamma_dot = gamma_dot[sort_idx]
    sigma = sigma[sort_idx]

    # Finite-difference slope; the floor keeps duplicate rates from dividing
    # by zero (after sorting every difference is >= 0).
    d_sigma = np.diff(sigma)
    d_gamma_dot = np.maximum(np.diff(gamma_dot), 1e-20)
    derivative = d_sigma / d_gamma_dot

    negative_slope_mask = derivative < threshold
    if not np.any(negative_slope_mask):
        logger.debug("No shear banding detected (monotonic flow curve)")
        return False, None

    # Region spans from the start of the first negative interval to the end
    # of the last one. Interval i joins points i and i+1, so last+1 is
    # always a valid point index.
    negative_indices = np.where(negative_slope_mask)[0]
    first_neg_idx = negative_indices[0]
    last_neg_idx = negative_indices[-1]
    gamma_dot_low = gamma_dot[first_neg_idx]
    gamma_dot_high = gamma_dot[last_neg_idx + 1]
    sigma_low = sigma[first_neg_idx]
    sigma_high = sigma[last_neg_idx + 1]

    neg_fraction = np.sum(negative_slope_mask) / len(derivative)

    banding_info = {
        "gamma_dot_low": float(gamma_dot_low),
        "gamma_dot_high": float(gamma_dot_high),
        "sigma_low": float(sigma_low),
        "sigma_high": float(sigma_high),
        "sigma_range": (
            float(min(sigma_low, sigma_high)),
            float(max(sigma_low, sigma_high)),
        ),
        "negative_slope_fraction": float(neg_fraction),
    }
    logger.info(
        "Shear banding detected",
        gamma_dot_low=float(gamma_dot_low),
        gamma_dot_high=float(gamma_dot_high),
        negative_slope_fraction=float(neg_fraction),
    )
    if warn:
        warnings.warn(
            f"Shear banding detected in flow curve. "
            f"Non-monotonic region: gamma_dot = [{gamma_dot_low:.3g}, {gamma_dot_high:.3g}] 1/s. "
            f"This may indicate mechanical instability.",
            UserWarning,
            stacklevel=2,
        )
    return True, banding_info
def compute_shear_band_coexistence(
    gamma_dot: np.ndarray,
    sigma: np.ndarray,
    gamma_dot_applied: float,
) -> dict | None:
    """Estimate coexisting shear bands and their fractions via the lever rule.

    The banding region is taken from ``detect_shear_banding``. The stress
    plateau is approximated by the mean stress inside that region (a
    simplification; no equal-area Maxwell construction is performed). The
    coexisting rates are the rates at which the branches outside the region
    reach that plateau, found by linear interpolation.

    Parameters
    ----------
    gamma_dot : ndarray
        Shear rate array (1/s)
    sigma : ndarray
        Stress array (Pa)
    gamma_dot_applied : float
        Applied (average) shear rate (1/s)

    Returns
    -------
    coexistence : dict or None
        None if no banding is detected, if the applied rate lies outside the
        detected region, or if the two coexisting rates coincide. Otherwise:
        - 'gamma_dot_low': Shear rate in low-shear band
        - 'gamma_dot_high': Shear rate in high-shear band
        - 'fraction_low': Volume fraction of low-shear band
        - 'fraction_high': Volume fraction of high-shear band
        - 'stress_plateau': Common stress in banding regime

    Notes
    -----
    Lever rule::

        gamma_dot_applied = f_low * gamma_dot_low + f_high * gamma_dot_high

    with f_low + f_high = 1. Both fractions are clipped to [0, 1].
    """
    logger.debug(
        "Computing shear band coexistence",
        gamma_dot_applied=gamma_dot_applied,
        n_points=len(gamma_dot),
    )

    banding_found, region = detect_shear_banding(gamma_dot, sigma)
    if region is None or not banding_found:
        logger.debug("No shear banding detected, cannot compute coexistence")
        return None

    region_start = region["gamma_dot_low"]
    region_end = region["gamma_dot_high"]

    outside_region = (
        gamma_dot_applied < region_start or gamma_dot_applied > region_end
    )
    if outside_region:
        logger.debug(
            "Applied shear rate outside banding region",
            gamma_dot_applied=gamma_dot_applied,
            banding_region=(region_start, region_end),
        )
        return None

    # Work on rate-ordered copies of the curve
    order = np.argsort(gamma_dot)
    rates = gamma_dot[order]
    stresses = sigma[order]

    # Plateau estimate: mean stress over the points spanning the region
    first = np.searchsorted(rates, region_start)
    last = np.searchsorted(rates, region_end)
    region_stresses = stresses[first : last + 1]
    if len(region_stresses) == 0:
        return None
    stress_plateau = np.mean(region_stresses)

    # Low-rate band: interpolate on the branch below the region when it has
    # at least two points, otherwise fall back to the region edge.
    gamma_dot_low = region_start
    below = rates < region_start
    rates_below = rates[below]
    if len(rates_below) > 1:
        gamma_dot_low = np.interp(stress_plateau, stresses[below], rates_below)

    # High-rate band: same treatment on the branch above the region.
    gamma_dot_high = region_end
    above = rates > region_end
    rates_above = rates[above]
    if len(rates_above) > 1:
        gamma_dot_high = np.interp(stress_plateau, stresses[above], rates_above)

    # Lever rule:
    # f_low = (gamma_dot_high - gamma_dot_applied) / (gamma_dot_high - gamma_dot_low)
    delta_gamma = gamma_dot_high - gamma_dot_low
    if abs(delta_gamma) < 1e-12:
        return None
    raw_low = (gamma_dot_high - gamma_dot_applied) / delta_gamma
    raw_high = 1.0 - raw_low
    f_low = np.clip(raw_low, 0, 1)
    f_high = np.clip(raw_high, 0, 1)

    result = {
        "gamma_dot_low": float(gamma_dot_low),
        "gamma_dot_high": float(gamma_dot_high),
        "fraction_low": float(f_low),
        "fraction_high": float(f_high),
        "stress_plateau": float(stress_plateau),
    }
    logger.info(
        "Shear band coexistence computed",
        gamma_dot_low=float(gamma_dot_low),
        gamma_dot_high=float(gamma_dot_high),
        fraction_low=float(f_low),
        fraction_high=float(f_high),
        stress_plateau=float(stress_plateau),
    )
    return result
# ============================================================================
# Thixotropy Kinetics Functions
# ============================================================================
@jax.jit
def thixotropy_lambda_derivative(
    lambda_val: float,
    gamma_dot: float,
    k_build: float,
    k_break: float,
) -> float:
    """Return d(lambda)/dt for the structural parameter lambda.

    Lambda describes the internal microstructure: 1 is fully built, 0 is
    fully broken. Its rate of change is::

        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda

    Parameters
    ----------
    lambda_val : float
        Current structural parameter value [0, 1]
    gamma_dot : float
        Current shear rate (1/s)
    k_build : float
        Structure build-up rate (1/s)
    k_break : float
        Structure breakdown rate (dimensionless)

    Returns
    -------
    float
        Time derivative d(lambda)/dt
    """
    # Recovery pulls lambda toward 1; it is the only term left at rest.
    recovery = k_build * (1.0 - lambda_val)
    # Shear-driven destruction, proportional to rate and remaining structure.
    shear_rate_term = k_break * gamma_dot
    destruction = shear_rate_term * lambda_val
    return recovery - destruction
@jax.jit
def _thixotropy_scan_step(
lambda_prev: float,
inputs: tuple[float, float, float, float],
) -> tuple[float, float]:
"""Single step of thixotropy evolution for jax.lax.scan.
This is JIT-compiled and fused into a single kernel when used with scan,
eliminating per-iteration Python overhead.
Parameters
----------
lambda_prev : float
Previous structural parameter value
inputs : tuple
(gamma_dot_i, dt_i, k_build, k_break) for this timestep
Returns
-------
tuple
(lambda_new, lambda_new) - carry and output are the same
"""
gamma_dot_i, dt_i, k_build, k_break = inputs
# Compute derivative using inlined logic (avoids function call overhead)
build_up = k_build * (1.0 - lambda_prev)
breakdown = k_break * gamma_dot_i * lambda_prev
dlambda_dt = build_up - breakdown
# Euler step with clamping
lambda_new = lambda_prev + dlambda_dt * dt_i
lambda_new = jnp.clip(lambda_new, 0.0, 1.0)
return lambda_new, lambda_new
def evolve_thixotropy_lambda(
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_initial: float,
    k_build: float,
    k_break: float,
) -> np.ndarray:
    """Evolve structural parameter lambda(t) for given shear history.

    Integrates the thixotropy kinetics equation::

        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda

    with an explicit Euler scheme run through ``jax.lax.scan``. The step from
    ``t[i]`` to ``t[i+1]`` uses the shear rate sampled at ``t[i+1]``, and
    every updated value is clamped to [0, 1]. ``lambda_initial`` is placed
    in the output as given.

    Parameters
    ----------
    t : ndarray
        Time array (s). Any array-like is accepted.
    gamma_dot : ndarray
        Shear rate array (1/s), same shape as t
    lambda_initial : float
        Initial structural parameter [0, 1]
    k_build : float
        Structure build-up rate (1/s)
    k_break : float
        Structure breakdown rate (dimensionless)

    Returns
    -------
    lambda_t : ndarray
        Structural parameter evolution, same shape as t (empty for empty t)

    Raises
    ------
    ValueError
        If ``t`` and ``gamma_dot`` have different shapes.

    Warns
    -----
    UserWarning
        If the largest time step exceeds the forward-Euler stability limit
        ``2 / (k_build + k_break * max|gamma_dot|)``.
    """
    # Coerce first so list inputs support .shape and the reductions below.
    t = np.asarray(t, dtype=np.float64)
    gamma_dot = np.asarray(gamma_dot, dtype=np.float64)

    logger.debug(
        "Evolving thixotropy lambda",
        n_points=t.size,
        lambda_initial=lambda_initial,
        k_build=k_build,
        k_break=k_break,
    )

    # Validate shapes before anything reduces over gamma_dot, so a mismatch
    # is reported as such rather than as an unrelated numpy error.
    if t.shape != gamma_dot.shape:
        logger.error(
            "Shape mismatch between time and shear rate arrays",
            t_shape=t.shape,
            gamma_dot_shape=gamma_dot.shape,
        )
        raise ValueError(
            f"Time and shear rate arrays must have same shape: "
            f"t.shape={t.shape}, gamma_dot.shape={gamma_dot.shape}"
        )

    # Nothing to integrate; keep the "same shape as t" contract.
    if t.size == 0:
        return np.empty(t.shape, dtype=np.float64)

    # T-24: Forward Euler stability check for thixotropy ODE.
    # dt must be < 2 / max_eigenvalue to avoid oscillatory blow-up.
    if t.size > 1:
        max_eigenvalue = k_build + k_break * np.max(np.abs(gamma_dot))
        max_stable_dt = 2.0 / max(max_eigenvalue, 1e-30)
        max_dt = float(np.max(np.diff(t)))
        if max_dt > max_stable_dt:
            warnings.warn(
                f"Forward Euler may be unstable for thixotropy ODE: "
                f"max(dt)={max_dt:.3g} > stability limit={max_stable_dt:.3g}. "
                f"Consider using finer time steps or an implicit integrator.",
                stacklevel=2,
            )

    # Convert to JAX arrays for scan
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)

    # One dt per step; step i pairs dt[i] with gamma_dot[i + 1].
    dt = jnp.diff(t_jax)
    n_steps = len(dt)

    # scan iterates over the leading axis of every input, so the constant
    # rate parameters are broadcast to one entry per step.
    k_build_arr = jnp.full(n_steps, k_build, dtype=jnp.float64)
    k_break_arr = jnp.full(n_steps, k_break, dtype=jnp.float64)
    scan_inputs = (gamma_dot_jax[1:], dt, k_build_arr, k_break_arr)

    _, lambda_history = jax.lax.scan(
        _thixotropy_scan_step,
        jnp.float64(lambda_initial),  # Initial carry
        scan_inputs,  # Sequence of inputs
    )

    # Prepend initial value to get full history
    lambda_t = jnp.concatenate([jnp.array([lambda_initial]), lambda_history])
    # Convert back to numpy for compatibility
    lambda_t_np = np.asarray(lambda_t, dtype=np.float64)

    logger.debug(
        "Thixotropy evolution completed",
        lambda_final=float(lambda_t_np[-1]),
        lambda_min=float(np.min(lambda_t_np)),
        lambda_max=float(np.max(lambda_t_np)),
    )
    return lambda_t_np
def compute_thixotropic_stress(
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_t: np.ndarray,
    G0: float,
    tau0: float,
    x: float,
    n_struct: float = 2.0,
) -> np.ndarray:
    """Compute stress response with thixotropic modulus.

    The effective modulus is coupled to the structural parameter::

        G_eff(t) = G0 * lambda(t)^n_struct

    and the stress is::

        sigma = G_eff * gamma_dot * tau0 * (|gamma_dot| * tau0)^(x - 2)

    so the sign of ``gamma_dot`` carries through to ``sigma``.

    Parameters
    ----------
    t : ndarray
        Time array (s). Only its length is used, for logging.
    gamma_dot : ndarray
        Shear rate array (1/s)
    lambda_t : ndarray
        Structural parameter array [0, 1]
    G0 : float
        Base modulus scale (Pa)
    tau0 : float
        Attempt time (s)
    x : float
        SGR noise temperature
    n_struct : float, default=2.0
        Structural coupling exponent

    Returns
    -------
    sigma : ndarray
        Stress response (Pa)
    """
    logger.debug(
        "Computing thixotropic stress",
        n_points=len(t),
        G0=G0,
        tau0=tau0,
        x=x,
        n_struct=n_struct,
    )
    gamma_dot = np.asarray(gamma_dot)
    lambda_t = np.asarray(lambda_t)

    # Effective modulus from structure
    G_eff = G0 * np.power(lambda_t, n_struct)

    # Power-law (SGR-like) viscosity factor. The rate magnitude is floored
    # and both exponent and result are clipped so zero rates or extreme x
    # cannot produce inf/NaN.
    gamma_dot_safe = np.maximum(np.abs(gamma_dot), 1e-12)
    exponent = np.clip(x - 2.0, -10.0, 10.0)
    eta_factor = np.power(gamma_dot_safe * tau0, exponent)
    eta_factor = np.clip(eta_factor, 1e-30, 1e30)

    # Stress = G_eff * gamma_dot * tau0 * eta_factor
    sigma = G_eff * gamma_dot * tau0 * eta_factor

    # np.min/np.max raise on zero-size arrays; diagnostics must not turn an
    # empty input into an error.
    if np.size(sigma) > 0:
        logger.debug(
            "Thixotropic stress computed",
            sigma_min=float(np.min(sigma)),
            sigma_max=float(np.max(sigma)),
        )
    else:
        logger.debug("Thixotropic stress computed", sigma_min=None, sigma_max=None)
    return sigma
# Public API of this module. The scan-step helper `_thixotropy_scan_step`
# is intentionally not listed.
__all__ = [
    "SRFS",
    "detect_shear_banding",
    "compute_shear_band_coexistence",
    "thixotropy_lambda_derivative",
    "evolve_thixotropy_lambda",
    "compute_thixotropic_stress",
]