"""Soft Glassy Rheology (SGR) Conventional Model.
This module implements the conventional SGR model from Sollich 1998, describing
the rheology of soft glassy materials including foams, emulsions, pastes, and
colloidal suspensions through a statistical mechanics framework.
The SGR model is characterized by:
- Exponential distribution of energy traps: rho(E) ~ exp(-E)
- Effective noise temperature x controlling material phase:
* x < 1: Glass phase with yield stress (solid-like)
* 1 < x < 2: Power-law viscoelastic fluid with G' ~ G'' ~ omega^(x-1)
* x >= 2: Newtonian viscous liquid
- Mesoscopic elements undergoing thermally activated hopping between traps
- Characteristic power-law rheological responses
Constitutive Equations:
Equilibrium modulus:
G0(x) = integral_0^inf rho(E) * E * (1 - exp(-E/x)) dE
Frequency-dependent complex modulus:
G*(omega) = G_scale * Gp(x, omega*tau0)
where Gp(x, z) = integral_0^inf rho(E) * E * [i*z] / [i*z + exp(E/x)] dE
Model Parameters:
x (dimensionless): Effective noise temperature, range (0.5, 3.0), default 1.5
Controls material phase transition and power-law exponent
G0 (Pa): Modulus scale, range (1e-3, 1e9), default 1e3
Sets absolute magnitude of elastic response
tau0 (s): Attempt time, range (1e-9, 1e3), default 1e-3
Characteristic microscopic relaxation timescale
Physical Interpretation:
The noise temperature x quantifies the ratio of thermal fluctuations to
the energy scale of structural rearrangements. At x=1, the material undergoes
a glass transition from solid-like (x<1) to fluid-like (x>1) behavior. The
power-law exponent (x-1) directly reflects the breadth of the distribution
of relaxation times arising from the exponential trap distribution.
References:
- P. Sollich, Rheological constitutive equation for a model of soft glassy
materials, Physical Review E, 1998, 58(1), 738-759
- P. Sollich et al., Rheology of Soft Glassy Materials, Physical Review
Letters, 1997, 78(10), 2020-2023
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode, TestMode
from rheojax.logging import get_logger, log_fit
from rheojax.utils.sgr_kernels import G0, Gp
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
if TYPE_CHECKING: # pragma: no cover
import jax.numpy as jnp_typing
else:
jnp_typing = np
# Module logger
logger = get_logger(__name__)
[docs]
@ModelRegistry.register(
"sgr_conventional",
protocols=[
Protocol.FLOW_CURVE,
Protocol.CREEP,
Protocol.RELAXATION,
Protocol.STARTUP,
Protocol.OSCILLATION,
Protocol.LAOS,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class SGRConventional(BaseModel):
"""Soft Glassy Rheology (SGR) Conventional Model.
Statistical mechanics model for soft glassy materials (foams, emulsions,
pastes, colloidal suspensions) based on Sollich 1998. The model describes
rheological behavior through a population of mesoscopic elements trapped
in energy wells with exponential distribution rho(E) ~ exp(-E), undergoing
thermally activated hopping.
The effective noise temperature x controls material phase:
- x < 1: Glass with yield stress (solid-like)
- 1 < x < 2: Power-law viscoelastic fluid (G' ~ G'' ~ omega^(x-1))
- x >= 2: Newtonian liquid
Parameters:
x: Effective noise temperature (dimensionless), controls phase transition
G0: Modulus scale (Pa), sets absolute magnitude of elastic response
tau0: Attempt time (s), characteristic microscopic relaxation timescale
Attributes:
parameters: ParameterSet containing x, G0, tau0
Example:
>>> from rheojax.models.sgr_conventional import SGRConventional
>>> import numpy as np
>>> model = SGRConventional()
>>> omega = np.logspace(-2, 2, 50)
>>> # Fitting oscillation data
>>> model.fit(omega, G_star, test_mode='oscillation')
>>> G_star_pred = model.predict(omega)
>>> # Bayesian inference
>>> result = model.fit_bayesian(omega, G_star, num_samples=2000)
>>> intervals = model.get_credible_intervals(result.posterior_samples)
Notes:
- Inherits from BaseModel (includes BayesianMixin for NumPyro NUTS)
- Mode-aware Bayesian inference (stores test_mode for correct predictions)
- JAX-accelerated kernel functions for GPU computation
- Float64 precision critical for numerical stability near x=1
"""
[docs]
def __init__(self, dynamic_x: bool = False):
    """Create the SGR parameter set and reset all per-fit state.

    Three core parameters are always registered, in this order:
        - x (noise temperature): bounds (0.5, 3.0), default 1.5
        - G0 (modulus scale, Pa): bounds (1e-3, 1e9), default 1e3
        - tau0 (attempt time, s): bounds (1e-9, 1e3), default 1e-3

    With ``dynamic_x=True`` five more parameters are appended after them:
        - x_eq: equilibrium noise temperature at rest
        - alpha_aging: aging rate (drives x toward x_eq)
        - beta_rejuv: rejuvenation rate (shear-induced increase in x)
        - x_ss_A: steady-state amplitude
        - x_ss_n: steady-state power-law exponent

    Args:
        dynamic_x: If True, register the x(t)-evolution parameters as well.
            The flag is stored on ``self.dynamic_x``.
    """
    super().__init__()

    # Each row is (name, default, bounds, units, description).
    # x spans glass -> power-law fluid -> Newtonian; 1.5 sits mid power-law.
    # G0 spans soft gels to stiff pastes; tau0 spans molecular to macroscopic.
    core_specs = (
        (
            "x",
            1.5,
            (0.5, 3.0),
            "dimensionless",
            "Effective noise temperature (glass transition at x=1)",
        ),
        (
            "G0",
            1e3,
            (1e-3, 1e9),
            "Pa",
            "Modulus scale (absolute magnitude of elastic response)",
        ),
        (
            "tau0",
            1e-3,
            (1e-9, 1e3),
            "s",
            "Attempt time (microscopic relaxation timescale)",
        ),
    )

    # Extra rows used only when x is allowed to evolve in time.
    dynamic_specs = (
        (
            "x_eq",
            1.0,
            (0.5, 2.5),
            "dimensionless",
            "Equilibrium noise temperature at rest (aging target)",
        ),
        (
            "alpha_aging",
            0.1,
            (0.0, 10.0),
            "1/s",
            "Aging rate coefficient (drives x toward x_eq)",
        ),
        (
            "beta_rejuv",
            0.5,
            (0.0, 10.0),
            "dimensionless",
            "Rejuvenation rate coefficient (shear-induced x increase)",
        ),
        (
            "x_ss_A",
            0.5,
            (0.0, 2.0),
            "dimensionless",
            "Steady-state amplitude: x_ss = x_eq + A*(gamma_dot*tau0)^n",
        ),
        (
            "x_ss_n",
            0.3,
            (0.0, 1.0),
            "dimensionless",
            "Steady-state power-law exponent",
        ),
    )

    self.parameters = ParameterSet()
    for name, value, bounds, units, description in core_specs:
        self.parameters.add(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            description=description,
        )

    # Dynamic x(t) mode flag
    self.dynamic_x = dynamic_x

    if dynamic_x:
        for name, value, bounds, units, description in dynamic_specs:
            self.parameters.add(
                name=name,
                value=value,
                bounds=bounds,
                units=units,
                description=description,
            )

    # Storage for the x(t) trajectory in dynamic mode
    self._x_trajectory: np.ndarray | None = None
    self._t_trajectory: np.ndarray | None = None

    # Remembered test mode, used to route predictions after a fit
    self._test_mode: TestMode | str | None = None

    # LAOS settings (strain amplitude and angular frequency)
    self._gamma_0: float | None = None
    self._omega_laos: float | None = None
def _fit(
self,
X: np.ndarray,
y: np.ndarray,
**kwargs,
) -> SGRConventional:
"""Fit SGR model to data using NLSQ optimization.
Routes to appropriate fitting method based on test_mode. This is the
internal implementation required by BaseModel.
Args:
X: Independent variable (frequency for oscillation, time for relaxation)
y: Dependent variable (complex modulus, relaxation modulus, etc.)
**kwargs: NLSQ optimizer arguments (max_iter, ftol, xtol, gtol).
Must include test_mode ('oscillation', 'relaxation', 'creep', 'steady_shear').
Raises:
ValueError: If test_mode not provided or invalid
RuntimeError: If optimization fails to converge
Note:
NLSQ provides 5-270x speedup over scipy. Fitted parameters serve
as warm-start for Bayesian inference via fit_bayesian().
"""
# Detect test mode
test_mode = kwargs.pop("test_mode", None)
if test_mode is None:
raise ValueError("test_mode must be specified for SGR fitting")
with log_fit(logger, model="SGRConventional", data_shape=X.shape) as ctx:
try:
logger.info(
"Starting SGR Conventional model fit",
test_mode=test_mode,
n_points=len(X),
dynamic_x=self.dynamic_x,
)
logger.debug(
"Input data statistics",
x_range=(float(np.min(X)), float(np.max(X))),
y_range=(float(np.min(np.abs(y))), float(np.max(np.abs(y)))),
)
# Store test mode for mode-aware Bayesian inference
self._test_mode = test_mode
ctx["test_mode"] = test_mode
# Route to appropriate fitting method
if test_mode == "oscillation":
self._fit_oscillation_mode(X, y, **kwargs)
elif test_mode == "relaxation":
self._fit_relaxation_mode(X, y, **kwargs)
elif test_mode == "creep":
self._fit_creep_mode(X, y, **kwargs)
elif test_mode in ("steady_shear", "flow_curve"):
self._fit_steady_shear_mode(X, y, **kwargs)
elif test_mode == "laos":
self._fit_laos_mode(X, y, **kwargs)
elif test_mode == "startup":
self._fit_startup_mode(X, y, **kwargs)
else:
raise ValueError(
f"Unsupported test_mode: {test_mode}. "
f"SGR model supports 'oscillation', 'relaxation', 'creep', "
f"'steady_shear', 'laos', and 'startup'."
)
# Log final parameters
x_val = self.parameters.get_value("x")
G0_val = self.parameters.get_value("G0")
tau0_val = self.parameters.get_value("tau0")
ctx["x"] = x_val
ctx["G0"] = G0_val
ctx["tau0"] = tau0_val
ctx["phase_regime"] = self.get_phase_regime()
logger.info(
"SGR Conventional model fit completed",
x=x_val,
G0=G0_val,
tau0=tau0_val,
phase_regime=self.get_phase_regime(),
)
except Exception as e:
logger.error(
"SGR Conventional model fit failed",
test_mode=test_mode,
error=str(e),
exc_info=True,
)
raise
return self
def _fit_oscillation_mode(
    self,
    omega: np.ndarray,
    G_star: np.ndarray,
    **kwargs,
) -> None:
    """Fit the model to complex-modulus data G*(omega) with NLSQ.

    The data are brought into a real (M, 2) layout [G', G''], a
    least-squares objective is built around ``_predict_oscillation_jit``,
    and ``nlsq_optimize`` is run on ``self.parameters``.

    Args:
        omega: Angular frequency array (rad/s).
        G_star: Complex modulus data in one of these layouts:
            - complex array, split into real and imaginary columns
            - real array of shape (M, 2) with columns [G', G'']
            - real array of shape (2, M), which is transposed
        **kwargs: Optional settings read here:
            - use_log_residuals (default False)
            - use_jax (default True)
            - max_iter (default 1000)
            - ftol, xtol, gtol (default 1e-6 each)

    Raises:
        ValueError: If ``G_star`` matches none of the accepted layouts.
        RuntimeError: If the optimizer reports failure.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    frequencies = jnp.asarray(omega, dtype=jnp.float64)

    # Normalise the measured moduli to a real (M, 2) array
    measured = np.asarray(G_star)
    if np.iscomplexobj(measured):
        moduli = np.column_stack([np.real(measured), np.imag(measured)])
    elif measured.ndim == 2 and measured.shape[1] == 2:
        moduli = measured
    elif measured.ndim == 2 and measured.shape[0] == 2:
        # Transposed layout (2, M) -> (M, 2)
        moduli = measured.T
    else:
        raise ValueError(
            f"G_star must be complex (M,) or real (M, 2), got shape {measured.shape}"
        )
    targets = jnp.asarray(moduli, dtype=jnp.float64)

    def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        """Stateless prediction: params[0..2] are read as x, G0, tau0."""
        return self._predict_oscillation_jit(
            x_data, params[0], params[1], params[2]
        )

    objective = create_least_squares_objective(
        model_fn,
        frequencies,
        targets,
        normalize=True,
        use_log_residuals=kwargs.get("use_log_residuals", False),
    )

    solver_options = {
        "use_jax": kwargs.get("use_jax", True),
        "max_iter": kwargs.get("max_iter", 1000),
        "ftol": kwargs.get("ftol", 1e-6),
        "xtol": kwargs.get("xtol", 1e-6),
        "gtol": kwargs.get("gtol", 1e-6),
    }
    result = nlsq_optimize(objective, self.parameters, **solver_options)

    if not result.success:
        raise RuntimeError(
            f"SGR oscillation fitting failed: {result.message}. "
            "Try adjusting initial values, bounds, or use use_log_residuals=True."
        )

    x_fit = self.parameters.get_value("x")
    G0_fit = self.parameters.get_value("G0")
    tau0_fit = self.parameters.get_value("tau0")
    logger.debug(
        f"SGR oscillation fit converged: x={x_fit:.4f}, "
        f"G0={G0_fit:.2e}, "
        f"tau0={tau0_fit:.2e}, "
        f"cost={result.fun:.3e}"
    )
    self.fitted_ = True
def _fit_relaxation_mode(
    self,
    t: np.ndarray,
    G_t: np.ndarray,
    **kwargs,
) -> None:
    """Fit the model to relaxation-modulus data G(t) with NLSQ.

    Builds a least-squares objective around ``_predict_relaxation_jit`` and
    runs ``nlsq_optimize`` on ``self.parameters``. Log-space residuals are
    the default here because a power-law decay spans many decades.

    Args:
        t: Time array (s).
        G_t: Relaxation modulus array (Pa).
        **kwargs: Optional settings read here:
            - use_log_residuals (default True)
            - use_jax (default True)
            - max_iter (default 1000)
            - ftol, xtol, gtol (default 1e-6 each)

    Raises:
        RuntimeError: If the optimizer reports failure.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    times = jnp.asarray(t, dtype=jnp.float64)
    targets = jnp.asarray(G_t, dtype=jnp.float64)

    def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        """Stateless prediction: params[0..2] are read as x, G0, tau0."""
        return self._predict_relaxation_jit(
            x_data, params[0], params[1], params[2]
        )

    objective = create_least_squares_objective(
        model_fn,
        times,
        targets,
        normalize=True,
        use_log_residuals=kwargs.get("use_log_residuals", True),
    )

    solver_options = {
        "use_jax": kwargs.get("use_jax", True),
        "max_iter": kwargs.get("max_iter", 1000),
        "ftol": kwargs.get("ftol", 1e-6),
        "xtol": kwargs.get("xtol", 1e-6),
        "gtol": kwargs.get("gtol", 1e-6),
    }
    result = nlsq_optimize(objective, self.parameters, **solver_options)

    if not result.success:
        raise RuntimeError(
            f"SGR relaxation fitting failed: {result.message}. "
            "Try adjusting initial values or bounds."
        )

    x_fit = self.parameters.get_value("x")
    G0_fit = self.parameters.get_value("G0")
    tau0_fit = self.parameters.get_value("tau0")
    logger.debug(
        f"SGR relaxation fit converged: x={x_fit:.4f}, "
        f"G0={G0_fit:.2e}, "
        f"tau0={tau0_fit:.2e}, "
        f"cost={result.fun:.3e}"
    )
    self.fitted_ = True
def _fit_creep_mode(
    self,
    t: np.ndarray,
    J_t: np.ndarray,
    **kwargs,
) -> None:
    """Fit the model to creep-compliance data J(t) with NLSQ.

    Builds a least-squares objective around ``_predict_creep_jit`` and runs
    ``nlsq_optimize`` on ``self.parameters``. Log-space residuals are the
    default because compliance typically spans decades.

    Args:
        t: Time array (s).
        J_t: Creep compliance array (1/Pa).
        **kwargs: Optional settings read here:
            - use_log_residuals (default True)
            - use_jax (default True)
            - max_iter (default 1000)
            - ftol, xtol, gtol (default 1e-6 each)

    Raises:
        RuntimeError: If the optimizer reports failure.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    times = jnp.asarray(t, dtype=jnp.float64)
    targets = jnp.asarray(J_t, dtype=jnp.float64)

    def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        """Stateless prediction: params[0..2] are read as x, G0, tau0."""
        return self._predict_creep_jit(x_data, params[0], params[1], params[2])

    objective = create_least_squares_objective(
        model_fn,
        times,
        targets,
        normalize=True,
        use_log_residuals=kwargs.get("use_log_residuals", True),
    )

    solver_options = {
        "use_jax": kwargs.get("use_jax", True),
        "max_iter": kwargs.get("max_iter", 1000),
        "ftol": kwargs.get("ftol", 1e-6),
        "xtol": kwargs.get("xtol", 1e-6),
        "gtol": kwargs.get("gtol", 1e-6),
    }
    result = nlsq_optimize(objective, self.parameters, **solver_options)

    if not result.success:
        raise RuntimeError(
            f"SGR creep fitting failed: {result.message}. "
            "Try adjusting initial values or bounds."
        )

    x_fit = self.parameters.get_value("x")
    G0_fit = self.parameters.get_value("G0")
    tau0_fit = self.parameters.get_value("tau0")
    logger.debug(
        f"SGR creep fit converged: x={x_fit:.4f}, "
        f"G0={G0_fit:.2e}, "
        f"tau0={tau0_fit:.2e}, "
        f"cost={result.fun:.3e}"
    )
    self.fitted_ = True
def _fit_steady_shear_mode(
    self,
    gamma_dot: np.ndarray,
    eta: np.ndarray,
    **kwargs,
) -> None:
    """Fit the model to flow-curve data eta(gamma_dot) with NLSQ.

    Builds a least-squares objective around ``_predict_steady_shear_jit``
    and runs ``nlsq_optimize`` on ``self.parameters``. Log-space residuals
    are the default because viscosity typically spans decades.

    Args:
        gamma_dot: Shear rate array (1/s).
        eta: Viscosity array (Pa.s).
        **kwargs: Optional settings read here:
            - use_log_residuals (default True)
            - use_jax (default True)
            - max_iter (default 1000)
            - ftol, xtol, gtol (default 1e-6 each)

    Raises:
        RuntimeError: If the optimizer reports failure.
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    rates = jnp.asarray(gamma_dot, dtype=jnp.float64)
    targets = jnp.asarray(eta, dtype=jnp.float64)

    def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
        """Stateless prediction: params[0..2] are read as x, G0, tau0."""
        return self._predict_steady_shear_jit(
            x_data, params[0], params[1], params[2]
        )

    objective = create_least_squares_objective(
        model_fn,
        rates,
        targets,
        normalize=True,
        use_log_residuals=kwargs.get("use_log_residuals", True),
    )

    solver_options = {
        "use_jax": kwargs.get("use_jax", True),
        "max_iter": kwargs.get("max_iter", 1000),
        "ftol": kwargs.get("ftol", 1e-6),
        "xtol": kwargs.get("xtol", 1e-6),
        "gtol": kwargs.get("gtol", 1e-6),
    }
    result = nlsq_optimize(objective, self.parameters, **solver_options)

    if not result.success:
        raise RuntimeError(
            f"SGR steady shear fitting failed: {result.message}. "
            "Try adjusting initial values or bounds."
        )

    x_fit = self.parameters.get_value("x")
    G0_fit = self.parameters.get_value("G0")
    tau0_fit = self.parameters.get_value("tau0")
    logger.debug(
        f"SGR steady shear fit converged: x={x_fit:.4f}, "
        f"G0={G0_fit:.2e}, "
        f"tau0={tau0_fit:.2e}, "
        f"cost={result.fun:.3e}"
    )
    self.fitted_ = True
def _fit_laos_mode(
self,
X: np.ndarray,
y: np.ndarray,
**kwargs,
) -> None:
"""Fit SGR to LAOS data.
Uses Monte Carlo simulation for time-domain stress prediction,
then optimizes parameters to match measured stress.
Args:
X: Time array (s)
y: Stress array (Pa)
**kwargs: Required kwargs:
- gamma_0: Strain amplitude
- omega: Angular frequency (rad/s)
Optional kwargs:
- n_particles: Monte Carlo particle count (default 5000)
- max_iter: Optimizer max iterations (default 50)
- seed: Random seed for reproducibility (default 42)
Raises:
ValueError: If gamma_0 or omega not provided
"""
gamma_0 = kwargs.get("gamma_0")
omega = kwargs.get("omega")
if gamma_0 is None or omega is None:
raise ValueError("LAOS fitting requires gamma_0 and omega in kwargs")
if gamma_0 <= 0:
raise ValueError(f"gamma_0 must be positive, got {gamma_0}")
if gamma_0 is not None:
self._gamma_0 = gamma_0
if omega is not None:
self._omega_laos = omega
n_particles = kwargs.get("n_particles", 5000)
# Use SAOS approximation for small amplitude
if gamma_0 < 0.1:
logger.info(
f"Small strain amplitude gamma_0={gamma_0}. Using SAOS approximation."
)
# Use JAX-native FFT for JAX-First compliance
sigma = y
t = X
sigma_fft = jnp.fft.fft(jnp.asarray(sigma))
n = len(sigma)
fundamental_idx = int(omega * (t[-1] - t[0]) / (2 * np.pi))
fundamental_idx = max(1, min(fundamental_idx, n // 2 - 1))
G_star_amplitude = (
2.0 * float(jnp.abs(sigma_fft[fundamental_idx])) / (n * gamma_0)
)
phase = float(jnp.angle(sigma_fft[fundamental_idx]))
G_prime = G_star_amplitude * np.cos(phase)
G_double_prime = G_star_amplitude * np.sin(phase)
omega_single = np.array([omega])
G_star_single = np.array([[G_prime, G_double_prime]])
self._fit_oscillation_mode(omega_single, G_star_single, **kwargs)
else:
# Full MC-based LAOS fitting
self._fit_laos_mc(X, y, gamma_0, omega, n_particles, **kwargs)
def _fit_laos_mc(
self,
t: np.ndarray,
sigma: np.ndarray,
gamma_0: float,
omega: float,
n_particles: int,
**kwargs,
) -> None:
"""Full Monte Carlo-based LAOS fitting.
Runs MC simulations within optimization loop to match time-domain stress.
Args:
t: Time array (s)
sigma: Measured stress array (Pa)
gamma_0: Strain amplitude
omega: Angular frequency (rad/s)
n_particles: Number of MC particles
**kwargs: Optimizer arguments
Note:
Uses scipy.optimize.minimize (L-BFGS-B) because the objective function
calls Monte Carlo simulations which are stochastic and not JAX-traceable.
This is acceptable per Technical Guidelines as it's used only for large-
amplitude LAOS fitting, not the primary oscillation/relaxation modes.
"""
from scipy.optimize import minimize
from rheojax.utils.sgr_monte_carlo import simulate_oscillatory
logger.info(
f"Full MC-based LAOS fitting: {n_particles} particles, "
f"gamma_0={gamma_0}, omega={omega:.3f} rad/s"
)
# Determine simulation parameters from data
period = 2.0 * np.pi / omega
t_total = t[-1] - t[0]
n_cycles = max(1, int(t_total / period))
points_per_cycle = max(10, len(t) // n_cycles)
# Warm-start: estimate parameters from stress amplitude
sigma_max = np.max(np.abs(sigma))
G0_init = sigma_max / gamma_0
x_init = self.parameters.get_value("x")
tau0_init = self.parameters.get_value("tau0")
assert x_init is not None
assert tau0_init is not None
# Normalize target stress for residual calculation
sigma_norm = sigma / (sigma_max + 1e-12)
# Fixed random seed for reproducibility
seed = kwargs.get("seed", 42)
def objective(params):
"""Compute residual between MC stress and measured stress."""
x_val, log_G0, log_tau0 = params
G0_val = np.exp(log_G0)
tau0_val = np.exp(log_tau0)
x_val = np.clip(x_val, 0.5, 2.5)
try:
key = jax.random.PRNGKey(seed)
_, _, sigma_mc = simulate_oscillatory(
key=key,
gamma_0=gamma_0,
omega=omega,
n_cycles=n_cycles,
points_per_cycle=points_per_cycle,
x=x_val,
n_particles=n_particles,
k=G0_val,
Gamma0=1.0 / tau0_val,
xg=1.0,
)
t_mc = np.linspace(0, t_total, len(sigma_mc))
sigma_mc_interp = np.interp(t - t[0], t_mc, np.array(sigma_mc))
sigma_mc_max = np.max(np.abs(sigma_mc_interp)) + 1e-12
sigma_mc_norm = sigma_mc_interp / sigma_mc_max
residual = np.sum((sigma_mc_norm - sigma_norm) ** 2)
return residual
except Exception as e:
logger.warning(f"MC simulation failed: {e}")
return 1e10
x0 = np.array([x_init, np.log(G0_init), np.log(tau0_init)])
bounds = [
(0.5, 2.5),
(np.log(1e-3), np.log(1e9)),
(np.log(1e-9), np.log(1e3)),
]
max_iter = kwargs.get("max_iter", 50)
logger.info(f"Starting MC-LAOS optimization (max {max_iter} iterations)...")
result = minimize(
objective,
x0,
method="L-BFGS-B",
bounds=bounds,
options={"maxiter": max_iter, "disp": False},
)
x_opt, log_G0_opt, log_tau0_opt = result.x
self.parameters.set_value("x", float(x_opt))
self.parameters.set_value("G0", float(np.exp(log_G0_opt)))
self.parameters.set_value("tau0", float(np.exp(log_tau0_opt)))
if result.success:
logger.info(
f"MC-LAOS fit converged: x={x_opt:.4f}, "
f"G0={np.exp(log_G0_opt):.2e}, tau0={np.exp(log_tau0_opt):.2e}, "
f"cost={result.fun:.3e}"
)
else:
logger.warning(
f"MC-LAOS fit did not fully converge: {result.message}. "
f"Best: x={x_opt:.4f}, G0={np.exp(log_G0_opt):.2e}"
)
self.fitted_ = True
def _fit_startup_mode(
self,
t: np.ndarray,
eta_plus: np.ndarray,
**kwargs,
) -> None:
"""Fit SGR to startup flow data (stress growth coefficient).
Uses NLSQ-accelerated optimization to fit SGR parameters [x, G0, tau0]
to stress growth coefficient η⁺(t) data from startup flow experiments.
Args:
t: Time array (s)
eta_plus: Stress growth coefficient array (Pa·s).
If stress data is provided, must also pass gamma_dot in kwargs.
**kwargs: NLSQ optimizer arguments, plus:
- gamma_dot: Applied shear rate (required if y is stress, not η⁺)
- is_stress: If True, treat y as stress and divide by gamma_dot
Raises:
RuntimeError: If optimization fails to converge
"""
from rheojax.utils.optimization import (
create_least_squares_objective,
nlsq_optimize,
)
# Check if we need to convert stress to eta_plus
gamma_dot = kwargs.get("gamma_dot", 1.0)
is_stress = kwargs.get("is_stress", False)
if is_stress:
# Convert stress to stress growth coefficient
eta_plus_data = eta_plus / gamma_dot
else:
eta_plus_data = eta_plus
# Store gamma_dot for prediction
self._startup_gamma_dot = gamma_dot
# Convert inputs to JAX arrays
t_jax = jnp.asarray(t, dtype=jnp.float64)
eta_plus_jax = jnp.asarray(eta_plus_data, dtype=jnp.float64)
# Create model function for NLSQ
def model_fn(x_data: jnp.ndarray, params: jnp.ndarray) -> jnp.ndarray:
"""Stateless model function for optimization."""
x_param = params[0]
G0_param = params[1]
tau0_param = params[2]
return self._predict_startup_jit(
x_data, x_param, G0_param, tau0_param, gamma_dot
)
# Create residual function (log-space for spanning decades)
objective = create_least_squares_objective(
model_fn,
t_jax,
eta_plus_jax,
normalize=True,
use_log_residuals=kwargs.get("use_log_residuals", True),
)
# Run NLSQ optimization
result = nlsq_optimize(
objective,
self.parameters,
use_jax=kwargs.get("use_jax", True),
max_iter=kwargs.get("max_iter", 1000),
ftol=kwargs.get("ftol", 1e-6),
xtol=kwargs.get("xtol", 1e-6),
gtol=kwargs.get("gtol", 1e-6),
)
if not result.success:
raise RuntimeError(
f"SGR startup fitting failed: {result.message}. "
"Try adjusting initial values or bounds."
)
logger.debug(
f"SGR startup fit converged: x={self.parameters.get_value('x'):.4f}, "
f"G0={self.parameters.get_value('G0'):.2e}, "
f"tau0={self.parameters.get_value('tau0'):.2e}, "
f"cost={result.fun:.3e}"
)
self.fitted_ = True
@staticmethod
@jax.jit
def _predict_oscillation_jit(
    omega: jnp.ndarray,
    x: jax.Array | float,
    G0_scale: jax.Array | float,
    tau0: jax.Array | float,
) -> jnp.ndarray:
    """JIT-compiled storage and loss moduli G'(omega), G''(omega).

    Evaluates the ``Gp`` kernel at the dimensionless frequency
    ``omega * tau0`` and scales both returned components by ``G0_scale``:
        G'(omega)  = G0_scale * G_prime(x, omega*tau0)
        G''(omega) = G0_scale * G_double_prime(x, omega*tau0)

    Args:
        omega: Angular frequency array (rad/s)
        x: Effective noise temperature (dimensionless)
        G0_scale: Modulus scale (Pa)
        tau0: Attempt time (s)

    Returns:
        The two scaled components stacked along axis 1, i.e. [G', G'']
        with shape (M, 2) for a 1-D ``omega`` of length M.
    """
    # Kernel call yields the (storage, loss) pair at reduced frequency
    storage, loss = Gp(x, omega * tau0)

    # Scale each component, then place them side by side as columns
    return jnp.stack([G0_scale * storage, G0_scale * loss], axis=1)
@staticmethod
@jax.jit
def _predict_relaxation_jit(
    t: jnp.ndarray,
    x: jax.Array | float,
    G0_scale: jax.Array | float,
    tau0: jax.Array | float,
) -> jnp.ndarray:
    """JIT-compiled relaxation modulus G(t), closed-form approximation.

    Evaluates
        G(t) = G0_scale * G0(x) / (1 + t/tau0)^(x - 1)
    which gives a plateau G0_scale * G0(x) for t << tau0 and a power-law
    decay ~ (t/tau0)^(1 - x) for t >> tau0. The exponent 1 - x is the
    negative of the creep exponent x - 1 used in ``_predict_creep_jit``.

    Args:
        t: Time array (s)
        x: Effective noise temperature (dimensionless)
        G0_scale: Modulus scale (Pa)
        tau0: Attempt time (s)

    Returns:
        Relaxation modulus, same shape as ``t``.
    """
    # Reduced time, floored at a tiny positive value (also maps any
    # negative input onto the plateau)
    reduced_time = jnp.maximum(t / tau0, 1e-12)

    # Short-time plateau: modulus scale times the dimensionless G0(x)
    plateau = G0_scale * G0(x)

    # Smooth crossover from plateau to power-law decay
    return plateau / jnp.power(1.0 + reduced_time, x - 1.0)
@staticmethod
@jax.jit
def _predict_creep_jit(
    t: jnp.ndarray,
    x: jax.Array | float,
    G0_scale: jax.Array | float,
    tau0: jax.Array | float,
) -> jnp.ndarray:
    """JIT-compiled creep compliance J(t), closed-form approximation.

    Evaluates
        J(t) = (1 + t/tau0)^(x - 1) / (G0_scale * G0(x))
    i.e. the reciprocal of the relaxation form used in
    ``_predict_relaxation_jit``. For x > 1 this grows as t^(x - 1) at long
    times.

    Args:
        t: Time array (s)
        x: Effective noise temperature (dimensionless)
        G0_scale: Modulus scale (Pa)
        tau0: Attempt time (s)

    Returns:
        Creep compliance, same shape as ``t``.

    Notes:
        No explicit monotonicity step is applied here; whether J(t) grows
        or decays follows from the sign of x - 1.
    """
    # Reduced time, floored at a tiny positive value
    reduced_time = jnp.maximum(t / tau0, 1e-12)

    # Instantaneous modulus: scale times the dimensionless G0(x)
    modulus = G0_scale * G0(x)

    # Power-law growth with the creep exponent x - 1
    exponent = x - 1.0
    return jnp.power(1.0 + reduced_time, exponent) / modulus
# NOTE: SGR FLOW_CURVE returns steady-state viscosity eta(gamma_dot), not stress sigma.
# This is the physically meaningful SGR quantity. For stress, multiply by gamma_dot.
@staticmethod
@jax.jit
def _predict_steady_shear_jit(
    gamma_dot: jnp.ndarray,
    x: jax.Array | float,
    G0_scale: jax.Array | float,
    tau0: jax.Array | float,
) -> jnp.ndarray:
    """JIT-compiled steady-shear viscosity eta(gamma_dot).

    Evaluates the power law
        eta = G0_scale * tau0 * G0(x) * (gamma_dot * tau0)^(x - 2)
    so that eta is constant at x = 2 and decreases with shear rate
    (shear-thinning) for x < 2.

    Args:
        gamma_dot: Shear rate array (1/s)
        x: Effective noise temperature (dimensionless)
        G0_scale: Modulus scale (Pa)
        tau0: Attempt time (s)

    Returns:
        Viscosity, same shape as ``gamma_dot``.
    """
    # Reduced shear rate, floored so the negative power stays finite
    reduced_rate = jnp.maximum(gamma_dot * tau0, 1e-12)

    # Viscosity scale: modulus scale * attempt time * dimensionless G0(x)
    prefactor = G0_scale * tau0 * G0(x)

    # Shear-thinning exponent x - 2
    return prefactor * jnp.power(reduced_rate, x - 2.0)
def _predict(self, X: np.ndarray, **kwargs) -> np.ndarray:
    """Dispatch a prediction to the handler for the active test mode.

    The mode is taken from the ``test_mode`` keyword when it is given and
    not None, otherwise from ``self._test_mode``.

    Args:
        X: Independent variable passed unchanged to the selected handler
        **kwargs: Optional ``test_mode`` override; in LAOS mode also
            optional ``gamma_0`` and ``omega``, which are stored on the
            instance before predicting

    Returns:
        Whatever the selected ``_predict_*`` handler returns

    Raises:
        ValueError: If no test mode is available or the mode is unknown
    """
    override = kwargs.get("test_mode")
    test_mode = getattr(self, "_test_mode", None) if override is None else override
    if test_mode is None:
        raise ValueError("test_mode must be specified for prediction")

    # LAOS needs amplitude and frequency; remember them when supplied.
    if test_mode == "laos":
        for kwarg_name, attr_name in (
            ("gamma_0", "_gamma_0"),
            ("omega", "_omega_laos"),
        ):
            supplied = kwargs.get(kwarg_name)
            if supplied is not None:
                setattr(self, attr_name, supplied)

    # (accepted mode names, handler attribute) pairs, checked in order.
    routes = (
        (("oscillation",), "_predict_oscillation"),
        (("relaxation",), "_predict_relaxation"),
        (("creep",), "_predict_creep"),
        (("steady_shear", "flow_curve"), "_predict_steady_shear"),
        (("laos",), "_predict_laos"),
        (("startup",), "_predict_startup"),
    )
    for accepted, handler_name in routes:
        if test_mode in accepted:
            return getattr(self, handler_name)(X)

    raise ValueError(f"Unknown test_mode: {test_mode}")
def _predict_oscillation(self, omega: np.ndarray) -> np.ndarray:
    """Return the complex modulus G* = G' + iG'' for the given frequencies.

    Args:
        omega: Angular frequency array (rad/s)

    Returns:
        Complex array built from the two columns (G', G'') produced by the
        JIT kernel, one entry per frequency
    """
    read = self.parameters.get_value
    x = read("x")
    G0_scale = read("G0")
    tau0 = read("tau0")

    # The kernel yields an (N, 2) table: column 0 is G', column 1 is G''.
    moduli = np.array(
        self._predict_oscillation_jit(jnp.asarray(omega), x, G0_scale, tau0)
    )
    storage = moduli[:, 0]
    loss = moduli[:, 1]
    return storage + 1j * loss
def _predict_relaxation(self, t: np.ndarray) -> np.ndarray:
    """Return the relaxation modulus G(t) for the current parameters.

    Args:
        t: Time array (s)

    Returns:
        NumPy array holding the output of the JIT relaxation kernel (Pa)
    """
    read = self.parameters.get_value
    x = read("x")
    G0_scale = read("G0")
    tau0 = read("tau0")

    # Evaluate on a JAX array, then hand back a plain NumPy copy.
    modulus = self._predict_relaxation_jit(jnp.asarray(t), x, G0_scale, tau0)
    return np.array(modulus)
def _predict_creep(self, t: np.ndarray) -> np.ndarray:
    """Return the creep compliance J(t) for the current parameters.

    Args:
        t: Time array (s)

    Returns:
        NumPy array holding the output of the JIT creep kernel (1/Pa)
    """
    read = self.parameters.get_value
    x = read("x")
    G0_scale = read("G0")
    tau0 = read("tau0")

    # Evaluate on a JAX array, then hand back a plain NumPy copy.
    compliance = self._predict_creep_jit(jnp.asarray(t), x, G0_scale, tau0)
    return np.array(compliance)
def _predict_steady_shear(self, gamma_dot: np.ndarray) -> np.ndarray:
    """Return the steady-shear viscosity for the current parameters.

    Args:
        gamma_dot: Shear rate array (1/s)

    Returns:
        NumPy array holding the output of the JIT steady-shear kernel (Pa.s)
    """
    read = self.parameters.get_value
    x = read("x")
    G0_scale = read("G0")
    tau0 = read("tau0")

    # Evaluate on a JAX array, then hand back a plain NumPy copy.
    viscosity = self._predict_steady_shear_jit(
        jnp.asarray(gamma_dot), x, G0_scale, tau0
    )
    return np.array(viscosity)
def _predict_laos(self, X: np.ndarray) -> np.ndarray:
    """Return the LAOS stress for the stored amplitude and frequency.

    Runs ``simulate_laos`` for a single cycle sampled at ``len(X)`` points
    and returns only the stress; the values in ``X`` are not used beyond
    their count.

    Args:
        X: Time array or strain array

    Returns:
        Stress response array

    Raises:
        ValueError: If the stored strain amplitude or frequency is None
    """
    have_inputs = self._gamma_0 is not None and self._omega_laos is not None
    if not have_inputs:
        raise ValueError(
            "LAOS prediction requires gamma_0 and omega. "
            "Set via fit() with test_mode='laos' or simulate_laos()."
        )

    _, stress = self.simulate_laos(
        self._gamma_0,
        self._omega_laos,
        n_cycles=1,
        n_points_per_cycle=len(X),
    )
    return stress
@staticmethod
@jax.jit
def _predict_startup_jit(
    t: jnp.ndarray,
    x: jax.Array | float,
    G0_scale: jax.Array | float,
    tau0: jax.Array | float,
    gamma_dot: jax.Array | float,
) -> jnp.ndarray:
    """Evaluate the stress growth coefficient eta_plus(t) under JIT.

    eta_plus(t) = sigma(t) / gamma_dot is the time integral of the
    power-law relaxation modulus G(s) = G0_scale * G0(x) * (1 + s/tau0)^(1-x):

        eta_plus(t) = G0_scale * G0(x) * tau0 * [(1 + t/tau0)^(2-x) - 1] / (2 - x)

    and, in the singular case x = 2,

        eta_plus(t) = G0_scale * G0(x) * tau0 * ln(1 + t/tau0)

    Args:
        t: Time array (s)
        x: Effective noise temperature (dimensionless)
        G0_scale: Modulus scale (Pa)
        tau0: Attempt time (s)
        gamma_dot: Applied shear rate (1/s); accepted for interface
            symmetry but not used in the computation

    Returns:
        Stress growth coefficient eta_plus(t), one value per entry of ``t``.
        Multiply by ``gamma_dot`` to obtain the stress sigma(t).

    Notes:
        - The exponent is ``2 - x``: for x < 2 the expression keeps growing
          with time, for x > 2 it saturates at
          G0_scale * G0(x) * tau0 / (x - 2).
        - The dimensionless time t/tau0 is floored at 1e-12.
        - The logarithmic form is selected when |2 - x| < 1e-6.
    """
    from rheojax.utils.sgr_kernels import G0 as G0_func

    # Dimensionless equilibrium-modulus factor.
    modulus_factor = G0_func(x)

    # Dimensionless elapsed time, kept strictly positive.
    elapsed = jnp.maximum(t / tau0, 1e-12)
    base = 1.0 + elapsed

    # Exponent of the integrated power law.
    growth_exp = 2.0 - x

    def _log_branch(_):
        # Integral of (1 + s/tau0)^(-1): tau0 * ln(1 + t/tau0)
        return G0_scale * modulus_factor * tau0 * jnp.log(base)

    def _power_branch(_):
        # Integral of (1 + s/tau0)^(1-x): tau0 * [(1+t/tau0)^(2-x) - 1]/(2-x)
        return (
            G0_scale
            * modulus_factor
            * tau0
            * ((jnp.power(base, growth_exp) - 1.0) / growth_exp)
        )

    # lax.cond keeps the branch selection traceable under JIT.
    return jax.lax.cond(
        jnp.abs(growth_exp) < 1e-6,
        _log_branch,
        _power_branch,
        operand=None,
    )
def _predict_startup(self, t: np.ndarray) -> np.ndarray:
    """Return the startup stress growth coefficient for the current parameters.

    Args:
        t: Time array (s)

    Returns:
        NumPy array holding the output of the JIT startup kernel (Pa.s).
        The stress follows as sigma(t) = gamma_dot * eta_plus(t).

    Raises:
        RuntimeError: If no startup shear rate is cached on the instance
    """
    read = self.parameters.get_value
    x = read("x")
    G0_scale = read("G0")
    tau0 = read("tau0")

    # The applied rate is expected to be cached as ``_startup_gamma_dot``.
    gamma_dot = getattr(self, "_startup_gamma_dot", None)
    if gamma_dot is None:
        raise RuntimeError(
            "SGRConventional._predict_startup: gamma_dot not provided and "
            "_startup_gamma_dot not cached. Call fit() with startup data first."
        )

    growth = self._predict_startup_jit(jnp.asarray(t), x, G0_scale, tau0, gamma_dot)
    return np.array(growth)
[docs]
def get_phase_regime(self) -> str:
    """Classify the material phase from the noise temperature x.

    The three regimes are:
        - x < 1: 'glass' (solid-like with yield stress)
        - 1 <= x < 2: 'power-law' (G' ~ G'' ~ omega^(x-1))
        - otherwise: 'newtonian' (constant viscosity)

    Returns:
        One of 'glass', 'power-law' or 'newtonian'

    Example:
        >>> model = SGRConventional()
        >>> model.parameters.set_value('x', 0.8)
        >>> model.get_phase_regime()
        'glass'
    """
    x = self.parameters.get_value("x")
    assert x is not None

    if x < 1.0:
        return "glass"
    return "power-law" if x < 2.0 else "newtonian"
def _compute_x_ss(self, gamma_dot: float, tau0: float) -> float:
    """Return the steady-state noise temperature at a given shear rate.

    Evaluates

        x_ss = x_eq + x_ss_A * (gamma_dot * tau0)^x_ss_n

    using the ``x_eq``, ``x_ss_A`` and ``x_ss_n`` parameters.

    Args:
        gamma_dot: Shear rate (1/s)
        tau0: Attempt time (s)

    Returns:
        Steady-state noise temperature x_ss

    Raises:
        ValueError: If dynamic_x mode is not enabled
    """
    if not self.dynamic_x:
        raise ValueError("x_ss computation requires dynamic_x=True mode")

    read = self.parameters.get_value
    x_eq, A, n = read("x_eq"), read("x_ss_A"), read("x_ss_n")
    assert x_eq is not None
    assert A is not None
    assert n is not None

    # Power law in the dimensionless rate gamma_dot * tau0.
    return x_eq + A * (gamma_dot * tau0) ** n
@staticmethod
@jax.jit
def _dx_dt_jit(
x: jax.Array | float,
gamma_dot: jax.Array | float,
x_eq: jax.Array | float,
x_ss: jax.Array | float,
alpha_aging: jax.Array | float,
beta_rejuv: jax.Array | float,
) -> jax.Array | float:
"""JIT-compiled evolution equation for dx/dt.
The effective temperature x evolves according to:
dx/dt = -alpha_aging * (x - x_eq) + beta_rejuv * gamma_dot * (x_ss - x)
Terms:
- Aging term: -alpha_aging * (x - x_eq)
Drives x toward equilibrium x_eq at rest (gamma_dot = 0)
- Rejuvenation term: beta_rejuv * gamma_dot * (x_ss - x)
Shear-induced increase in x toward steady-state x_ss
Args:
x: Current effective temperature
gamma_dot: Current shear rate (1/s)
x_eq: Equilibrium temperature at rest
x_ss: Steady-state temperature at given shear rate
alpha_aging: Aging rate coefficient
beta_rejuv: Rejuvenation rate coefficient
Returns:
Time derivative dx/dt
"""
# Aging term (drives x -> x_eq)
aging = -alpha_aging * (x - x_eq)
# Rejuvenation term (shear drives x -> x_ss)
rejuvenation = beta_rejuv * gamma_dot * (x_ss - x)
return aging + rejuvenation
[docs]
def evolve_x(
    self,
    t: np.ndarray,
    gamma_dot: np.ndarray,
    x_initial: float,
) -> np.ndarray:
    """Evolve effective temperature x(t) via ODE integration.

    Integrates the evolution equation:

        dx/dt = -alpha_aging * (x - x_eq)
                + beta_rejuv * |gamma_dot(t)| * (x_ss(t) - x)

    with x_ss(t) = x_eq + x_ss_A * (|gamma_dot(t)| * tau0)^x_ss_n.

    The shear-rate history is linearly interpolated between the supplied
    samples and only its magnitude enters the equation, so reversing
    (negative) shear rates rejuvenate exactly like positive ones instead
    of producing NaN from a fractional power of a negative number.

    The ODE is solved with diffrax (Tsit5 with an adaptive PID step-size
    controller) and the solution is saved at the points of ``t``.

    Args:
        t: Time array (s)
        gamma_dot: Shear rate array (1/s), same shape as t
        x_initial: Initial effective temperature

    Returns:
        x_trajectory: Effective temperature x(t) at each time point. If the
        solver does not report success, every entry is NaN.

    Raises:
        ValueError: If dynamic_x mode is not enabled or array shapes mismatch

    Side effects:
        Stores the result in ``self._x_trajectory`` and the time array in
        ``self._t_trajectory``.

    Example:
        >>> model = SGRConventional(dynamic_x=True)
        >>> t = np.linspace(0, 10, 100)
        >>> gamma_dot = np.ones_like(t) * 5.0  # Constant shear
        >>> x_t = model.evolve_x(t, gamma_dot, x_initial=1.0)
    """
    if not self.dynamic_x:
        raise ValueError("evolve_x requires dynamic_x=True mode")
    if t.shape != gamma_dot.shape:
        raise ValueError(
            f"Time and shear rate arrays must have same shape: "
            f"t.shape={t.shape}, gamma_dot.shape={gamma_dot.shape}"
        )

    # Read every parameter once, outside the vector field.
    x_eq = self.parameters.get_value("x_eq")
    alpha_aging = self.parameters.get_value("alpha_aging")
    beta_rejuv = self.parameters.get_value("beta_rejuv")
    tau0 = self.parameters.get_value("tau0")
    x_ss_A = self.parameters.get_value("x_ss_A")
    x_ss_n = self.parameters.get_value("x_ss_n")

    t_jax = jnp.asarray(t)
    gamma_dot_jax = jnp.asarray(gamma_dot)

    # x_ss_A and x_ss_n are handed to the solver through ``args`` rather
    # than being looked up on self.parameters inside the vector field.
    ode_args = {"x_ss_A": x_ss_A, "x_ss_n": x_ss_n}

    def vector_field(t_val, x_val, args):
        """ODE vector field: dx/dt = f(t, x)."""
        # Linearly interpolate the signed rate first, then take the
        # magnitude, so a sign change between samples still passes
        # through zero rate.
        gamma_dot_current = jnp.abs(jnp.interp(t_val, t_jax, gamma_dot_jax))

        # Steady-state temperature at the current rate magnitude.
        gamma_dot_dim = gamma_dot_current * tau0
        x_ss_current = x_eq + args["x_ss_A"] * (gamma_dot_dim ** args["x_ss_n"])

        return self._dx_dt_jit(
            x_val, gamma_dot_current, x_eq, x_ss_current, alpha_aging, beta_rejuv
        )

    # Tsit5 (Runge-Kutta 5(4)) with adaptive stepping; the vector field is
    # wrapped in jax.checkpoint to limit memory in reverse-mode AD.
    term = diffrax.ODETerm(jax.checkpoint(vector_field))
    solver = diffrax.Tsit5()
    t0 = t_jax[0]
    t1 = t_jax[-1]
    # Initial step guess: one sample spacing, or a fixed fallback for a
    # single-point time array.
    dt0 = (t1 - t0) / len(t_jax) if len(t_jax) > 1 else 0.001

    saveat = diffrax.SaveAt(ts=t_jax)
    stepsize_controller = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8)

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0=x_initial,
        args=ode_args,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=100000,  # Safety limit
        throw=False,  # Failures are reported via sol.result, not raised
    )

    # Flatten to shape (N,) in case the solver returned a trailing axis.
    x_trajectory_jax = sol.ys
    if x_trajectory_jax.ndim > 1:
        x_trajectory_jax = x_trajectory_jax.squeeze()

    # Because throw=False, surface a failed solve as an all-NaN trajectory.
    x_trajectory_jax = jnp.where(
        sol.result == diffrax.RESULTS.successful,
        x_trajectory_jax,
        jnp.nan * jnp.ones_like(x_trajectory_jax),
    )

    x_trajectory = np.array(x_trajectory_jax)

    # Keep the latest trajectory available for later analysis.
    self._x_trajectory = x_trajectory
    self._t_trajectory = t
    return x_trajectory
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
    """Evaluate the SGR prediction for an explicit parameter vector.

    This is the entry point used for Bayesian inference (NumPyro NUTS):
    parameters come from ``params`` rather than from ``self.parameters``.

    Args:
        X: Independent variable (frequency for oscillation, time for relaxation)
        params: Indexable of parameter values; entries 0, 1, 2 are read as
            [x, G0, tau0]
        test_mode: Optional test mode override. If None, ``self._test_mode``
            is used, and if that is also None, 'oscillation'
        **kwargs: Protocol-specific arguments; ``gamma_dot`` is read in
            startup mode

    Returns:
        Output of the JIT kernel selected by the mode:
            - oscillation / laos: the oscillation kernel output
            - relaxation: the relaxation kernel output
            - creep: the creep kernel output
            - steady_shear / flow_curve: the steady-shear kernel output
            - startup: the startup kernel output

    Raises:
        RuntimeError: In startup mode when no shear rate is given via
            ``gamma_dot`` and none is cached as ``_startup_gamma_dot``
        ValueError: If the mode is not supported

    Example:
        >>> # During Bayesian inference, NumPyro calls:
        >>> predictions = model.model_function(omega, params_sample)
        >>> # Where params_sample = [x_sample, G0_sample, tau0_sample]
    """
    x, G0_scale, tau0 = params[0], params[1], params[2]

    # Explicit override wins, then the stored mode, then the default.
    mode = self._test_mode if test_mode is None else test_mode
    if mode is None:
        mode = "oscillation"

    X_jax = jnp.asarray(X)

    # Modes whose kernels share the (X, x, G0, tau0) signature. LAOS
    # inference reuses the oscillation kernel.
    four_arg_kernels = (
        (("oscillation",), "_predict_oscillation_jit"),
        (("relaxation",), "_predict_relaxation_jit"),
        (("creep",), "_predict_creep_jit"),
        (("steady_shear", "flow_curve"), "_predict_steady_shear_jit"),
        (("laos",), "_predict_oscillation_jit"),
    )
    for accepted, kernel_name in four_arg_kernels:
        if mode in accepted:
            return getattr(self, kernel_name)(X_jax, x, G0_scale, tau0)

    if mode != "startup":
        raise ValueError(f"Unsupported test mode: {mode}")

    # Startup additionally needs the applied rate. Priority: explicit
    # kwarg, then cached attribute. None is the sentinel so that an
    # explicit gamma_dot=0.0 is respected.
    gamma_dot = kwargs.get("gamma_dot")
    if gamma_dot is None:
        gamma_dot = getattr(self, "_startup_gamma_dot", None)
    if gamma_dot is None:
        raise RuntimeError(
            "SGRConventional.model_function: gamma_dot not provided and "
            "_startup_gamma_dot not cached. Call fit() with startup data first."
        )
    return self._predict_startup_jit(X_jax, x, G0_scale, tau0, gamma_dot)
# =========================================================================
# LAOS (Large Amplitude Oscillatory Shear) Methods - Task Group 6
# =========================================================================
[docs]
def simulate_laos(
    self,
    gamma_0: float,
    omega: float,
    n_cycles: int = 2,
    n_points_per_cycle: int = 256,
) -> tuple[np.ndarray, np.ndarray]:
    """Simulate the LAOS stress for a sinusoidal strain input.

    The imposed strain is gamma(t) = gamma_0 * sin(omega * t). The stress
    is built from the linear response at ``omega``,

        sigma(t) = G' * gamma(t) + (G'' / omega) * gamma_dot(t),

    and, only when gamma_0 > 0.1, multiplied by the softening factor
    1 - 0.1 * (|gamma(t)| / gamma_0)^2.

    The amplitude and frequency are stored on the instance as
    ``_gamma_0`` and ``_omega_laos``.

    Args:
        gamma_0: Strain amplitude (dimensionless)
        omega: Angular frequency (rad/s)
        n_cycles: Number of oscillation cycles to simulate
        n_points_per_cycle: Number of time points per cycle

    Returns:
        strain: Strain array gamma(t)
        stress: Stress array sigma(t)

    Example:
        >>> model = SGRConventional()
        >>> model.parameters.set_value("x", 1.5)
        >>> strain, stress = model.simulate_laos(gamma_0=0.1, omega=1.0)
    """
    # Remember the protocol for later LAOS predictions.
    self._gamma_0 = gamma_0
    self._omega_laos = omega

    read = self.parameters.get_value
    x = read("x")
    G0_scale = read("G0")
    tau0 = read("tau0")

    # Uniform time grid over n_cycles periods; the end point is excluded
    # so consecutive cycles do not duplicate a sample.
    period = 2.0 * np.pi / omega
    total_samples = n_cycles * n_points_per_cycle
    t = np.linspace(0, n_cycles * period, total_samples, endpoint=False)

    # Imposed strain and its analytic time derivative.
    strain = gamma_0 * np.sin(omega * t)
    strain_rate = gamma_0 * omega * np.cos(omega * t)

    # Storage and loss moduli at the driving frequency.
    moduli = self._predict_oscillation_jit(
        jnp.asarray(np.array([omega])), x, G0_scale, tau0
    )
    G_prime = float(moduli[0, 0])
    G_double_prime = float(moduli[0, 1])

    # Linear viscoelastic superposition of in-phase and out-of-phase parts.
    stress = G_prime * strain + (G_double_prime / omega) * strain_rate

    # Ad-hoc strain softening, applied only above the amplitude threshold.
    if gamma_0 > 0.1:
        softening = 1.0 - 0.1 * (np.abs(strain) / gamma_0) ** 2
        stress = stress * softening

    return strain, stress
[docs]
def compute_chebyshev_coefficients(
    self,
    strain: np.ndarray,
    stress: np.ndarray,
    gamma_0: float,
    omega: float,
    n_points_per_cycle: int = 256,
) -> dict:
    """Project a LAOS stress signal onto odd Chebyshev polynomials.

    The last ``n_points_per_cycle`` samples are taken as one cycle. The
    stress is projected onto T_1, T_3 and T_5 (Chebyshev polynomials of the
    first kind) evaluated at the normalized strain gamma/gamma_0 (elastic
    coefficients e_n) and at the normalized strain rate
    gamma_dot/(gamma_0 * omega) (viscous coefficients v_n). Each projection
    is ``2 * mean(stress * T_n(...))`` with uniform weighting over the
    cycle. The strain rate is obtained numerically with ``np.gradient``.

    Interpretation of the ratios:
        - e_3/e_1 > 0: Strain stiffening; < 0: Strain softening
        - v_3/v_1 > 0: Shear thickening; < 0: Shear thinning

    Args:
        strain: Strain array from simulate_laos()
        stress: Stress array from simulate_laos()
        gamma_0: Strain amplitude
        omega: Angular frequency
        n_points_per_cycle: Points per oscillation cycle

    Returns:
        Dictionary containing:
            - e_1, e_3, e_5: Elastic Chebyshev coefficients
            - v_1, v_3, v_5: Viscous Chebyshev coefficients
            - e_3_e_1, e_5_e_1, v_3_v_1, v_5_v_1: Ratios to the first
              coefficient, set to 0.0 when |first coefficient| <= 1e-12

    Example:
        >>> strain, stress = model.simulate_laos(gamma_0=0.5, omega=1.0)
        >>> chebyshev = model.compute_chebyshev_coefficients(
        ...     strain, stress, gamma_0=0.5, omega=1.0
        ... )
        >>> print(f"Strain stiffening ratio: {chebyshev['e_3_e_1']:.4f}")
    """
    # Work on the final cycle only.
    last_cycle = slice(-n_points_per_cycle, None)
    strain_cycle = strain[last_cycle]
    stress_cycle = stress[last_cycle]

    # Arguments of the Chebyshev basis, each scaled by its amplitude.
    strain_arg = strain_cycle / gamma_0
    sample_spacing = 2.0 * np.pi / (omega * n_points_per_cycle)
    strain_rate = np.gradient(strain_cycle, sample_spacing)
    rate_arg = strain_rate / (gamma_0 * omega)

    # T_1(u) = u, T_3(u) = 4u^3 - 3u, T_5(u) = 16u^5 - 20u^3 + 5u
    odd_polynomials = (
        lambda u: u,
        lambda u: 4 * u**3 - 3 * u,
        lambda u: 16 * u**5 - 20 * u**3 + 5 * u,
    )

    def project(argument, polynomial):
        # Uniformly weighted projection of the stress over the cycle.
        return 2.0 * np.mean(stress_cycle * polynomial(argument))

    e_1, e_3, e_5 = (project(strain_arg, poly) for poly in odd_polynomials)
    v_1, v_3, v_5 = (project(rate_arg, poly) for poly in odd_polynomials)

    chebyshev = {
        "e_1": e_1,
        "e_3": e_3,
        "e_5": e_5,
        "v_1": v_1,
        "v_3": v_3,
        "v_5": v_5,
    }

    # Ratios to the first-order coefficient; guarded against a vanishing
    # denominator.
    ratio_specs = (
        (e_1, e_3, e_5, "e_3_e_1", "e_5_e_1"),
        (v_1, v_3, v_5, "v_3_v_1", "v_5_v_1"),
    )
    for first, third, fifth, third_key, fifth_key in ratio_specs:
        has_first = abs(first) > 1e-12
        chebyshev[third_key] = third / first if has_first else 0.0
        chebyshev[fifth_key] = fifth / first if has_first else 0.0

    return chebyshev
[docs]
def get_lissajous_curve(
    self,
    gamma_0: float,
    omega: float,
    n_points: int = 256,
    normalized: bool = False,
) -> tuple[np.ndarray, np.ndarray]:
    """Generate one period of the stress-vs-strain Lissajous curve.

    Two cycles are simulated with ``simulate_laos`` and the last
    ``n_points`` samples are returned.

    Args:
        gamma_0: Strain amplitude
        omega: Angular frequency (rad/s)
        n_points: Number of points in curve
        normalized: If True, divide strain by ``gamma_0`` and stress by its
            maximum absolute value (stress is left as is when that maximum
            is not positive)

    Returns:
        strain: Strain array (one period)
        stress: Stress array (one period)

    Example:
        >>> strain, stress = model.get_lissajous_curve(gamma_0=0.1, omega=1.0)
        >>> plt.plot(strain, stress)  # Elastic Lissajous
    """
    full_strain, full_stress = self.simulate_laos(
        gamma_0, omega, n_cycles=2, n_points_per_cycle=n_points
    )

    # Keep only the final cycle.
    strain_cycle = full_strain[-n_points:]
    stress_cycle = full_stress[-n_points:]

    if not normalized:
        return strain_cycle, stress_cycle

    strain_cycle = strain_cycle / gamma_0
    peak_stress = np.max(np.abs(stress_cycle))
    if peak_stress > 0:
        stress_cycle = stress_cycle / peak_stress
    return strain_cycle, stress_cycle
# =========================================================================
# Thixotropy Extension Methods
# =========================================================================
[docs]
def enable_thixotropy(
    self,
    k_build: float = 0.1,
    k_break: float = 0.5,
    n_struct: float = 2.0,
) -> None:
    """Enable thixotropy modeling with structural parameter lambda(t).

    Registers the kinetic parameters ``k_build``, ``k_break`` and
    ``n_struct`` on ``self.parameters``. A parameter that is already
    registered is left untouched, so the corresponding argument is then
    ignored. The method also sets the thixotropy flag and clears any
    stored lambda trajectory.

    Model equations that use these parameters:
        d(lambda)/dt = k_build * (1 - lambda) - k_break * gamma_dot * lambda
        G_eff(t) = G0 * lambda(t)^n_struct
    where lambda = 1 is a fully built and lambda = 0 a fully broken
    structure.

    Args:
        k_build: Structure build-up rate (1/s), default 0.1
        k_break: Structure breakdown rate (dimensionless), default 0.5
        n_struct: Structural coupling exponent, default 2.0

    Example:
        >>> model = SGRConventional()
        >>> model.enable_thixotropy(k_build=0.1, k_break=0.5, n_struct=2.0)
        >>> # Now model can predict stress transients with thixotropy
    """
    # (name, initial value, bounds, units, description) per parameter.
    parameter_specs = (
        ("k_build", k_build, (0.0, 10.0), "1/s", "Structure build-up rate (1/s)"),
        (
            "k_break",
            k_break,
            (0.0, 10.0),
            "dimensionless",
            "Structure breakdown rate (shear-dependent)",
        ),
        (
            "n_struct",
            n_struct,
            (0.1, 5.0),
            "dimensionless",
            "Structural coupling exponent",
        ),
    )
    for name, value, bounds, units, description in parameter_specs:
        if name in self.parameters.keys():
            continue  # Already registered: keep the existing entry.
        self.parameters.add(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            description=description,
        )

    # Mark thixotropy as active and reset the stored lambda trajectory.
    self._thixotropy_enabled = True
    self._lambda_trajectory: np.ndarray | None = None
[docs]
def evolve_lambda(
    self,
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_initial: float = 1.0,
) -> np.ndarray:
    """Evolve structural parameter lambda(t) for given shear history.

    Integrates the thixotropy kinetics

        d(lambda)/dt = k_build * (1 - lambda) - k_break * |gamma_dot| * lambda

    step by step with ``jax.lax.scan``. Within each step the shear-rate
    magnitude at the end of the step is held constant and the linear ODE
    is advanced with its exponential solution; each new value is clipped
    to [0, 1]. The result is also stored in ``self._lambda_trajectory``.

    Args:
        t: Time array (s)
        gamma_dot: Shear rate array (1/s), same shape as t
        lambda_initial: Initial structural parameter [0, 1], default 1.0

    Returns:
        lambda_t: Structural parameter evolution, same shape as t; the
        first entry is ``lambda_initial``

    Raises:
        ValueError: If thixotropy not enabled or array shapes mismatch

    Example:
        >>> model = SGRConventional()
        >>> model.enable_thixotropy()
        >>> t = np.linspace(0, 10, 100)
        >>> gamma_dot = np.ones_like(t) * 10.0  # Constant shear
        >>> lambda_t = model.evolve_lambda(t, gamma_dot, lambda_initial=1.0)
    """
    if not getattr(self, "_thixotropy_enabled", False):
        raise ValueError("Thixotropy not enabled. Call enable_thixotropy() first.")
    if t.shape != gamma_dot.shape:
        raise ValueError(
            f"Time and shear rate arrays must have same shape: "
            f"t.shape={t.shape}, gamma_dot.shape={gamma_dot.shape}"
        )

    k_build = self.parameters.get_value("k_build")
    k_break = self.parameters.get_value("k_break")

    # Step sizes (first entry is zero by construction) and rate magnitudes.
    times = jnp.asarray(t)
    step_sizes = jnp.diff(times, prepend=times[0])
    rate_magnitude = jnp.abs(jnp.asarray(gamma_dot))

    def advance(lam, step_inputs):
        h, rate = step_inputs
        # Written as d(lambda)/dt = k_build - total_rate * lambda.
        total_rate = k_build + k_break * rate
        # Fixed point of the step; denominator floored to avoid 0/0.
        target = k_build / jnp.maximum(total_rate, 1e-30)
        relaxed = target + (lam - target) * jnp.exp(-total_rate * h)
        # Limit total_rate -> 0: plain linear growth.
        drifted = lam + k_build * h
        updated = jnp.where(total_rate > 1e-12, relaxed, drifted)
        updated = jnp.clip(updated, 0.0, 1.0)
        return updated, updated

    # The first sample is the initial condition, so scanning starts at
    # the second one.
    _, later_values = jax.lax.scan(
        advance, jnp.float64(lambda_initial), (step_sizes[1:], rate_magnitude[1:])
    )

    lambda_t = np.asarray(
        jnp.concatenate([jnp.array([lambda_initial]), later_values])
    )

    self._lambda_trajectory = lambda_t
    return lambda_t
[docs]
def predict_thixotropic_stress(
self,
t: np.ndarray,
gamma_dot: np.ndarray,
lambda_t: np.ndarray | None = None,
lambda_initial: float = 1.0,
) -> np.ndarray:
"""Predict stress response with thixotropic modulus.
The effective modulus is coupled to the structural parameter:
G_eff(t) = G0 * lambda(t)^n_struct
Args:
t: Time array (s)
gamma_dot: Shear rate array (1/s)
lambda_t: Pre-computed lambda trajectory, or None to compute
lambda_initial: Initial lambda if computing [0, 1], default 1.0
Returns:
sigma: Stress response (Pa)
Example:
>>> model = SGRConventional()
>>> model.enable_thixotropy()
>>> t = np.linspace(0, 10, 100)
>>> gamma_dot = np.ones_like(t) * 10.0
>>> sigma = model.predict_thixotropic_stress(t, gamma_dot)
"""
if not getattr(self, "_thixotropy_enabled", False):
raise ValueError("Thixotropy not enabled. Call enable_thixotropy() first.")
# Compute lambda trajectory if not provided
if lambda_t is None:
lambda_t = self.evolve_lambda(t, gamma_dot, lambda_initial)
# Get parameters
G0_val = self.parameters.get_value("G0")
tau0 = self.parameters.get_value("tau0")
x = self.parameters.get_value("x")
n_struct = self.parameters.get_value("n_struct")
assert G0_val is not None
assert tau0 is not None
assert x is not None
assert n_struct is not None
# Effective modulus from structure
G_eff = G0_val * np.power(lambda_t, n_struct)
# Viscosity from power-law (SGR-like)
gamma_dot_safe = np.maximum(np.abs(gamma_dot), 1e-12)
eta_factor = np.power(gamma_dot_safe * tau0, x - 2.0)
# Stress = G_eff * gamma_dot * tau0 * eta_factor
sigma = G_eff * gamma_dot * tau0 * eta_factor
return sigma
[docs]
def predict_stress_transient(
    self,
    t: np.ndarray,
    gamma_dot: np.ndarray,
    lambda_initial: float = 1.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Predict the stress transient for a shear-rate history.

    Evolves the structural parameter with ``evolve_lambda`` and feeds the
    resulting trajectory to ``predict_thixotropic_stress``. With a step up
    in rate this gives an overshoot (stress starts high while the
    structure is intact and decays as it breaks down); with a step down it
    gives an undershoot.

    Args:
        t: Time array (s)
        gamma_dot: Shear rate array (1/s), can include steps
        lambda_initial: Initial structural parameter [0, 1]

    Returns:
        sigma: Stress response (Pa)
        lambda_t: Structural parameter evolution

    Example:
        >>> model = SGRConventional()
        >>> model.enable_thixotropy()
        >>> t = np.linspace(0, 10, 100)
        >>> gamma_dot = np.ones_like(t)
        >>> gamma_dot[t >= 5] = 10.0  # Step up at t=5
        >>> sigma, lambda_t = model.predict_stress_transient(t, gamma_dot)
    """
    structure = self.evolve_lambda(t, gamma_dot, lambda_initial)
    return self.predict_thixotropic_stress(t, gamma_dot, structure), structure
# =========================================================================
# Shear Banding Detection Methods
# =========================================================================
[docs]
def detect_shear_banding(
    self,
    gamma_dot: np.ndarray | None = None,
    sigma: np.ndarray | None = None,
    n_points: int = 100,
    gamma_dot_range: tuple[float, float] = (1e-2, 1e2),
) -> tuple[bool, dict | None]:
    """Detect shear banding from constitutive curve.

    Builds (or accepts) a flow curve sigma(gamma_dot) and passes it to
    ``rheojax.transforms.srfs.detect_shear_banding``, which looks for a
    non-monotonic region (d sigma / d gamma_dot < 0).

    When ``sigma`` is not supplied it is computed as
    ``predict(gamma_dot) * gamma_dot`` with the model temporarily switched
    to steady-shear mode. The previously stored test mode is restored
    afterwards, even if prediction raises, so this diagnostic does not
    change how later ``predict``/``model_function`` calls are routed.

    Args:
        gamma_dot: Shear rate array (1/s). If None, a log-spaced grid over
            ``gamma_dot_range`` is used.
        sigma: Stress array (Pa). If None, computes from model.
        n_points: Number of points if computing flow curve
        gamma_dot_range: Range for computing flow curve if gamma_dot is None

    Returns:
        is_banding: True if shear banding detected
        banding_info: Dict with banding region info, or None

    Example:
        >>> model = SGRConventional()
        >>> model.parameters.set_value("x", 0.8)  # Glass regime
        >>> is_banding, info = model.detect_shear_banding()
    """
    # Imported here, as in the rest of this module's banding helpers.
    from rheojax.transforms.srfs import detect_shear_banding as _detect_banding

    if gamma_dot is None:
        low, high = gamma_dot_range[0], gamma_dot_range[1]
        gamma_dot = np.logspace(np.log10(low), np.log10(high), n_points)

    if sigma is None:
        # predict() routes on self._test_mode; switch it only for the
        # duration of this call and put the caller's mode back afterwards.
        previous_mode = getattr(self, "_test_mode", None)
        self._test_mode = "steady_shear"
        try:
            eta = self.predict(gamma_dot)
        finally:
            self._test_mode = previous_mode
        # The steady-shear prediction is a viscosity; convert to stress.
        sigma = eta * gamma_dot

    is_banding, banding_info = _detect_banding(gamma_dot, sigma, warn=True)
    return is_banding, banding_info
[docs]
def predict_banded_flow(
    self,
    gamma_dot_applied: float,
    gamma_dot: np.ndarray | None = None,
    sigma: np.ndarray | None = None,
    n_points: int = 100,
) -> dict | None:
    """Predict flow in shear banding regime with lever rule.

    Builds (or accepts) a flow curve sigma(gamma_dot) and passes it,
    together with the applied rate, to
    ``rheojax.transforms.srfs.compute_shear_band_coexistence``.

    When ``sigma`` is not supplied it is computed as
    ``predict(gamma_dot) * gamma_dot`` with the model temporarily switched
    to steady-shear mode. The previously stored test mode is restored
    afterwards, even if prediction raises, so this helper does not change
    how later ``predict``/``model_function`` calls are routed.

    Args:
        gamma_dot_applied: Applied average shear rate (1/s)
        gamma_dot: Shear rate array for flow curve. If None, a log-spaced
            grid from 1e-2 to 1e3 with ``n_points`` points is used.
        sigma: Stress array for flow curve. If None, computed.
        n_points: Number of points if computing flow curve

    Returns:
        coexistence: Dict with band coexistence info, or None

    Example:
        >>> model = SGRConventional()
        >>> model.parameters.set_value("x", 0.8)
        >>> coex = model.predict_banded_flow(gamma_dot_applied=1.0)
        >>> if coex:
        ...     print(f"Low band: {coex['fraction_low']:.2%}")
        ...     print(f"High band: {coex['fraction_high']:.2%}")
    """
    from rheojax.transforms.srfs import compute_shear_band_coexistence

    if gamma_dot is None:
        gamma_dot = np.logspace(-2, 3, n_points)

    if sigma is None:
        # predict() routes on self._test_mode; switch it only for the
        # duration of this call and put the caller's mode back afterwards.
        previous_mode = getattr(self, "_test_mode", None)
        self._test_mode = "steady_shear"
        try:
            eta = self.predict(gamma_dot)
        finally:
            self._test_mode = previous_mode
        # The steady-shear prediction is a viscosity; convert to stress.
        sigma = eta * gamma_dot

    coexistence = compute_shear_band_coexistence(
        gamma_dot, sigma, gamma_dot_applied
    )
    return coexistence