"""Single-mode Transient Network Theory (TNT) model.
This module implements `TNTSingleMode`, a composable constitutive model for
associative polymers and physical gels with reversible crosslinks.
Composability
-------------
The single-mode TNT model supports multiple variant combinations via
constructor parameters:
- ``breakage``: Breakage rate function ("constant", "bell", "power_law",
"stretch_creation")
- ``stress_type``: Stress formula ("linear", "fene")
- ``xi``: Slip parameter (0 = upper-convected, >0 = Gordon-Schowalter)
These can be combined freely, e.g.::
TNTSingleMode(breakage="bell", stress_type="fene", xi=0.3)
Supported Protocols
-------------------
- FLOW_CURVE: Steady shear (analytical for constant, numerical otherwise)
- OSCILLATION: Small-amplitude oscillatory shear (analytical, Maxwell-like)
- STARTUP: Transient stress growth at constant rate (ODE)
- RELAXATION: Stress decay after cessation (analytical or ODE)
- CREEP: Strain evolution under constant stress (ODE)
- LAOS: Large-amplitude oscillatory shear (ODE)
Example
-------
>>> from rheojax.models.tnt import TNTSingleMode
>>> import numpy as np
>>>
>>> # Basic (Tanaka-Edwards) model
>>> model = TNTSingleMode()
>>>
>>> # Flow curve (analytical)
>>> gamma_dot = np.logspace(-2, 2, 50)
>>> sigma = model.predict(gamma_dot, test_mode='flow_curve')
>>>
>>> # Bell force-dependent variant
>>> model_bell = TNTSingleMode(breakage="bell")
>>>
>>> # FENE + non-affine composition
>>> model_full = TNTSingleMode(stress_type="fene", xi=0.2)
References
----------
- Green, M.S. & Tobolsky, A.V. (1946). J. Chem. Phys. 14, 80-92.
- Tanaka, F. & Edwards, S.F. (1992). Macromolecules 25, 1516-1523.
- Bell, G.I. (1978). Science 200, 618-627.
"""
from __future__ import annotations
import logging
from typing import Literal
import numpy as np
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.tnt._base import TNTBase
from rheojax.models.tnt._kernels import (
build_tnt_creep_ode_rhs,
build_tnt_laos_ode_rhs,
build_tnt_ode_rhs,
build_tnt_relaxation_ode_rhs,
tnt_base_relaxation_vec,
tnt_base_steady_conformation,
tnt_base_steady_n1_vec,
tnt_base_steady_stress_vec,
tnt_saos_moduli_vec,
)
# JAX and jax.numpy handles obtained through the project's import helper.
jax, jnp = safe_import_jax()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Unique sentinel meaning "keyword not supplied"; lets callers distinguish a
# missing keyword from a legitimate falsy value such as gamma_dot=0.0.
_MISSING = object()
# Breakage type alias
BreakageType = Literal["constant", "bell", "power_law", "stretch_creation"]
# Stress formula type alias
StressType = Literal["linear", "fene"]
[docs]
@ModelRegistry.register(
"tnt_single_mode",
protocols=[
Protocol.FLOW_CURVE,
Protocol.OSCILLATION,
Protocol.STARTUP,
Protocol.RELAXATION,
Protocol.CREEP,
Protocol.LAOS,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
@ModelRegistry.register(
"tnt",
protocols=[
Protocol.FLOW_CURVE,
Protocol.OSCILLATION,
Protocol.STARTUP,
Protocol.RELAXATION,
Protocol.CREEP,
Protocol.LAOS,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class TNTSingleMode(TNTBase):
"""Single-mode Transient Network Theory model.
Implements the Green-Tobolsky / Tanaka-Edwards transient network model
with composable physics variants. The conformation tensor S tracks the
average chain configuration between reversible crosslinks.
The constitutive equation is::
dS/dt = L·S + S·L^T + g₀·I - β(S)·S
Stress is computed from S via σ = G·f(S)·(S - I) + η_s·γ̇.
Parameters
----------
breakage : str, default "constant"
Breakage rate function:
- "constant": β = 1/τ_b (Tanaka-Edwards, UCM-like)
- "bell": β = (1/τ_b)·exp(ν·(stretch-1)) (force-dependent)
- "power_law": β = (1/τ_b)·stretch^m
- "stretch_creation": β = (1/τ_b), g₀ = (1+κ·stretch)/τ_b
stress_type : str, default "linear"
Stress formula:
- "linear": σ = G·(S - I) (Gaussian chains)
- "fene": σ = G·f(tr(S))·(S - I) (finitely extensible)
xi : float, default 0.0
Gordon-Schowalter slip parameter (0=upper-convected, 1=corotational)
Attributes
----------
parameters : ParameterSet
Model parameters
fitted_ : bool
Whether the model has been fitted
Examples
--------
Basic Tanaka-Edwards model:
>>> model = TNTSingleMode()
>>> gamma_dot = np.logspace(-2, 2, 50)
>>> sigma = model.predict(gamma_dot, test_mode='flow_curve')
Bell force-dependent breakage:
>>> model = TNTSingleMode(breakage="bell")
>>> # Now has additional parameter 'nu' (force sensitivity)
See Also
--------
TNTLoopBridge : Two-species loop-bridge kinetics
TNTCates : Living polymer (wormlike micelle) model
"""
[docs]
def __init__(
    self,
    breakage: BreakageType = "constant",
    stress_type: StressType = "linear",
    xi: float = 0.0,
):
    """Create a single-mode TNT model for the requested variant.

    Parameters
    ----------
    breakage : str, default "constant"
        Name of the breakage rate function.
    stress_type : str, default "linear"
        Name of the stress formula.
    xi : float, default 0.0
        Gordon-Schowalter slip parameter.
    """
    # The variant flags are recorded before the base initializer is called.
    self._breakage, self._stress_type, self._xi = breakage, stress_type, xi
    super().__init__()
    # Parameter set and ODE right-hand sides both depend on the flags above.
    self._setup_parameters()
    self._build_variant_ode_functions()
    # No protocol has been selected yet.
    self._test_mode = None
# =========================================================================
# Parameter Setup
# =========================================================================
def _setup_parameters(self):
    """Create ``self.parameters`` and register the variant's parameters.

    Registration order is: ``G``, ``tau_b``, ``eta_s``, then at most one
    breakage parameter (``nu`` for "bell", ``m_break`` for "power_law",
    ``kappa`` for "stretch_creation"), then ``L_max`` when the stress type
    is "fene". Other methods index the parameter array in this order.
    """
    self.parameters = ParameterSet()

    # Each spec is (name, value, bounds, units, description).
    core_specs = (
        (
            "G",
            1e3,
            (1e0, 1e8),
            "Pa",
            "Network modulus (elastic contribution from active chains)",
        ),
        (
            "tau_b",
            1.0,
            (1e-6, 1e4),
            "s",
            "Bond lifetime (mean time between detachment events)",
        ),
        (
            "eta_s",
            0.0,
            (0.0, 1e4),
            "Pa·s",
            "Solvent viscosity (Newtonian background contribution)",
        ),
    )
    breakage_specs = {
        "bell": (
            "nu",
            1.0,
            (0.01, 20.0),
            "dimensionless",
            "Force sensitivity (Bell model, higher = more shear-thinning)",
        ),
        "power_law": (
            "m_break",
            2.0,
            (0.5, 10.0),
            "dimensionless",
            "Breakage power-law exponent",
        ),
        "stretch_creation": (
            "kappa",
            0.5,
            (0.0, 5.0),
            "dimensionless",
            "Creation rate enhancement from chain stretch",
        ),
    }
    fene_spec = (
        "L_max",
        10.0,
        (2.0, 100.0),
        "dimensionless",
        "Maximum chain extensibility (FENE-P spring)",
    )

    selected = list(core_specs)
    breakage_spec = breakage_specs.get(self._breakage)
    if breakage_spec is not None:
        selected.append(breakage_spec)
    if self._stress_type == "fene":
        selected.append(fene_spec)

    for name, value, bounds, units, description in selected:
        self.parameters.add(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            description=description,
        )
# =========================================================================
# Property Accessors
# =========================================================================
@property
def G(self) -> float:
    """Network modulus G (Pa); 0.0 when the stored value is None."""
    stored = self.parameters.get_value("G")
    if stored is None:
        return 0.0
    return float(stored)
@property
def tau_b(self) -> float:
    """Bond lifetime τ_b (s); 0.0 when the stored value is None."""
    stored = self.parameters.get_value("tau_b")
    if stored is None:
        return 0.0
    return float(stored)
@property
def eta_s(self) -> float:
    """Solvent viscosity η_s (Pa·s); 0.0 when the stored value is None."""
    stored = self.parameters.get_value("eta_s")
    if stored is None:
        return 0.0
    return float(stored)
@property
def eta_0(self) -> float:
    """Zero-shear viscosity η₀ = G·τ_b + η_s (Pa·s)."""
    network_part = self.G * self.tau_b
    return network_part + self.eta_s
@property
def breakage(self) -> str:
    """Breakage rate function name chosen at construction."""
    selected = self._breakage
    return selected
@property
def stress_type(self) -> str:
    """Stress formula name chosen at construction."""
    selected = self._stress_type
    return selected
@property
def xi(self) -> float:
    """Slip parameter ξ chosen at construction."""
    slip = self._xi
    return slip
@property
def _is_basic(self) -> bool:
    """True only for constant breakage, linear stress and ξ exactly 0."""
    if self._breakage != "constant":
        return False
    if self._stress_type != "linear":
        return False
    return self._xi == 0.0
# =========================================================================
# Variant ODE Infrastructure
# =========================================================================
def _build_variant_ode_functions(self):
    """Build and store the four variant-specific ODE right-hand sides.

    Every builder receives the same three selectors: the breakage name,
    whether the stress type is "fene", and whether ξ is strictly positive.
    The results are cached on the instance for the startup/flow-curve,
    creep, LAOS and relaxation integrations respectively.
    """
    selectors = (self._breakage, self._stress_type == "fene", self._xi > 0)
    self._variant_ode = build_tnt_ode_rhs(*selectors)
    self._variant_creep_ode = build_tnt_creep_ode_rhs(*selectors)
    self._variant_laos_ode = build_tnt_laos_ode_rhs(*selectors)
    self._variant_relax_ode = build_tnt_relaxation_ode_rhs(*selectors)
def _get_variant_args(self) -> dict:
    """Collect variant parameters as plain floats from ``self.parameters``.

    Returns
    -------
    dict
        Keys ``nu``, ``m_break``, ``kappa``, ``L_max`` and ``xi``. Entries
        for inactive variants keep their placeholder values (0.0, or 10.0
        for ``L_max``); ``xi`` is always the constructor value.
    """

    def read(name, fallback):
        # A stored value of None falls back to the placeholder.
        stored = self.parameters.get_value(name)
        return fallback if stored is None else float(stored)

    variant = {
        "nu": 0.0,
        "m_break": 0.0,
        "kappa": 0.0,
        "L_max": 10.0,
        "xi": self._xi,
    }
    breakage_key = {
        "bell": "nu",
        "power_law": "m_break",
        "stretch_creation": "kappa",
    }.get(self._breakage)
    if breakage_key is not None:
        variant[breakage_key] = read(breakage_key, 0.0)
    if self._stress_type == "fene":
        variant["L_max"] = read("L_max", 10.0)
    return variant
def _unpack_variant_params(self, params) -> dict:
    """Read variant parameters out of a flat parameter array.

    Positions 0-2 hold ``G``, ``tau_b`` and ``eta_s``. The active breakage
    parameter (if any) follows at position 3, and ``L_max`` comes next when
    the stress type is "fene". Values are taken by indexing only, so
    ``params`` may be a traced array.

    Returns
    -------
    dict
        Keys ``nu``, ``m_break``, ``kappa``, ``L_max`` and ``xi``, with
        placeholder values for inactive variants.
    """
    unpacked = {
        "nu": 0.0,
        "m_break": 0.0,
        "kappa": 0.0,
        "L_max": 10.0,
        "xi": self._xi,
    }
    cursor = 3  # first slot after G, tau_b, eta_s
    breakage_key = {
        "bell": "nu",
        "power_law": "m_break",
        "stretch_creation": "kappa",
    }.get(self._breakage)
    if breakage_key is not None:
        unpacked[breakage_key] = params[cursor]
        cursor += 1
    if self._stress_type == "fene":
        unpacked["L_max"] = params[cursor]
    return unpacked
def _get_ode_solver(self):
    """Pick the explicit diffrax solver for this variant.

    The basic variant (constant breakage, linear stress, ξ = 0) gets
    ``Tsit5``; every other variant gets ``Dopri5``.

    Note (carried over from the original author): implicit solvers such as
    Kvaerno5 were found incompatible with the diffrax/lineax versions in
    use, failing with TracerBoolConversionError in LU.
    """
    solver_cls = diffrax.Tsit5 if self._is_basic else diffrax.Dopri5
    return solver_cls()
def _compute_stress_from_conformation(
    self, S_xx, S_yy, S_zz, S_xy, G, eta_s, gamma_dot, vp
):
    """Return total shear stress for the given conformation components.

    The elastic part is ``G·S_xy`` for linear stress. For "fene" it is
    scaled by ``L_max² / max(L_max² - tr(S), 1e-10)``, where the floor
    keeps the denominator positive. The solvent part ``eta_s·gamma_dot``
    is added in both cases. ``vp`` is only read (key ``"L_max"``) in the
    FENE branch.
    """
    solvent = eta_s * gamma_dot
    if self._stress_type != "fene":
        return G * S_xy + solvent
    L_sq = vp["L_max"] * vp["L_max"]
    trace = S_xx + S_yy + S_zz
    peterlin = L_sq / jnp.maximum(L_sq - trace, 1e-10)
    return G * peterlin * S_xy + solvent
# =========================================================================
# Core Interface Methods
# =========================================================================
def _fit(self, x, y, **kwargs):
    """Fit model to data using protocol-aware optimization.

    Parameters
    ----------
    x : array-like
        Independent variable (shear rate, frequency, or time)
    y : array-like
        Dependent variable (stress, modulus, or strain).
        For oscillation, accepts complex G* = G' + iG'',
        (N, 2) array [G', G''], or real |G*|.
    **kwargs
        Optional keys read here: ``test_mode``, ``gamma_dot``,
        ``sigma_applied``, ``gamma_0``, ``omega``, ``use_log_residuals``,
        ``use_jax``, ``method`` and ``max_iter``.

    Returns
    -------
    self
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    # Mode precedence: explicit kwarg, then remembered mode, then flow curve.
    _kw_mode = kwargs.get("test_mode")
    if _kw_mode is not None:
        test_mode = _kw_mode
    else:
        stored_mode = getattr(self, "_test_mode", None)
        test_mode = stored_mode if stored_mode is not None else "flow_curve"
    self._test_mode = test_mode

    x_arr = np.asarray(x)
    y_arr = np.asarray(y)
    x_jax = jnp.asarray(x, dtype=jnp.float64)
    # Preserve complex dtype for oscillation data (G* = G' + iG'')
    if np.iscomplexobj(y_arr):
        y_jax = jnp.asarray(y_arr, dtype=jnp.complex128)
    else:
        y_jax = jnp.asarray(y_arr, dtype=jnp.float64)

    # Store protocol-specific inputs so model_function can fall back on them.
    self._gamma_dot_applied = kwargs.get("gamma_dot")
    self._sigma_applied = kwargs.get("sigma_applied")
    self._gamma_0 = kwargs.get("gamma_0")
    self._omega_laos = kwargs.get("omega")

    # Smart initialization based on protocol
    if test_mode in ["flow_curve", "steady_shear", "rotation"]:
        self.initialize_from_flow_curve(x_arr, y_arr)
    elif test_mode == "oscillation":
        if (
            not np.iscomplexobj(y_arr)
            and y_arr.ndim == 2
            and y_arr.shape[1] == 2
        ):
            # Real (N, 2) layout [G', G'']: take the columns. Using
            # np.real/np.imag here would pass the whole 2-D array as G'
            # and all zeros as G''.
            g_storage_init, g_loss_init = y_arr[:, 0], y_arr[:, 1]
        else:
            # Complex G* (or real |G*|, whose imaginary part is zero).
            g_storage_init, g_loss_init = np.real(y_arr), np.imag(y_arr)
        self.initialize_from_saos(x_arr, g_storage_init, g_loss_init)
    elif test_mode == "relaxation":
        self.initialize_from_relaxation(x_arr, y_arr)
    elif test_mode in ("startup", "creep", "laos"):
        self.initialize_from_time_domain(
            x_arr,
            np.real(y_arr),
            gamma_dot=kwargs.get("gamma_dot"),
            sigma_applied=kwargs.get("sigma_applied"),
        )

    # Define model function for fitting
    def model_fn(x_fit, params):
        return self.model_function(x_fit, params, test_mode=test_mode)

    # Create objective and optimize. Log residuals default on only for the
    # literal "flow_curve" mode name.
    objective = create_least_squares_objective(
        model_fn,
        x_jax,
        y_jax,
        use_log_residuals=kwargs.get(
            "use_log_residuals", test_mode == "flow_curve"
        ),
    )

    # ODE-based protocols use diffrax with custom_vjp, incompatible with
    # NLSQ forward-mode AD. Default to scipy to avoid failed attempt overhead.
    _ode_protocols = {"startup", "relaxation", "creep", "laos"}
    _default_method = "scipy" if test_mode in _ode_protocols else "auto"
    # NOTE(review): nlsq_optimize presumably writes the optimum back into
    # self.parameters (the log line below reads it) — confirm.
    result = nlsq_optimize(
        objective,
        self.parameters,
        use_jax=kwargs.get("use_jax", True),
        method=kwargs.get("method", _default_method),
        max_iter=kwargs.get("max_iter", 2000),
    )
    self.fitted_ = True
    self._nlsq_result = result
    logger.info(
        f"Fitted TNTSingleMode ({self._breakage}): "
        f"G={self.G:.2e}, τ_b={self.tau_b:.2e}, η_s={self.eta_s:.2e}"
    )
    return self
def _predict(self, x, **kwargs):
    """Predict the response for ``x`` using the current parameter values.

    Parameters
    ----------
    x : array-like
        Independent variable.
    **kwargs
        ``test_mode`` selects the protocol. When it is absent or None, the
        remembered ``_test_mode`` is used, else "flow_curve". Any of
        ``gamma_dot``, ``sigma_applied``, ``gamma_0`` and ``omega`` that
        are present are stored on the instance before prediction.

    Returns
    -------
    jnp.ndarray
        Output of ``model_function``. For oscillation, an (N, 2) result is
        converted to complex ``G' + 1j·G''``.
    """
    requested = kwargs.get("test_mode")
    if requested is not None:
        test_mode = requested
    else:
        remembered = getattr(self, "_test_mode", None)
        test_mode = remembered if remembered is not None else "flow_curve"

    x_jax = jnp.asarray(x, dtype=jnp.float64)

    # Persist only the protocol inputs that were actually passed in.
    stored_inputs = (
        ("gamma_dot", "_gamma_dot_applied"),
        ("sigma_applied", "_sigma_applied"),
        ("gamma_0", "_gamma_0"),
        ("omega", "_omega_laos"),
    )
    for key, attribute in stored_inputs:
        if key in kwargs:
            setattr(self, attribute, kwargs[key])

    # Parameter array follows ParameterSet ordering.
    params = jnp.array(
        [float(self.parameters.get_value(name)) for name in self.parameters.keys()]
    )
    prediction = self.model_function(x_jax, params, test_mode=test_mode)

    is_moduli_pair = (
        test_mode == "oscillation"
        and prediction.ndim == 2
        and prediction.shape[1] == 2
    )
    if is_moduli_pair:
        prediction = prediction[:, 0] + 1j * prediction[:, 1]
    return prediction
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
    """Evaluate the model for an explicit parameter array.

    This is the function handed to the optimizer and to Bayesian
    inference; all model parameters come from ``params`` rather than from
    ``self.parameters``.

    Parameters
    ----------
    X : array-like
        Independent variable.
    params : array-like
        Parameter values in ParameterSet order: [G, tau_b, eta_s, ...].
    test_mode : str, optional
        Protocol name. Falls back to the remembered ``_test_mode``, then
        to "flow_curve".
    **kwargs
        ``gamma_dot``, ``sigma_applied``, ``gamma_0``, ``omega``. A key
        that is absent falls back to the value stored on the instance; a
        key that is present is used as given, even when falsy.

    Returns
    -------
    jnp.ndarray
        Predicted response. For "oscillation" this is an (N, 2) array of
        [G', G''].

    Raises
    ------
    ValueError
        If the selected protocol is missing a required input.
    """
    G, tau_b, eta_s = params[0], params[1], params[2]
    vp = self._unpack_variant_params(params)

    if test_mode is not None:
        mode = test_mode
    else:
        remembered = getattr(self, "_test_mode", None)
        mode = "flow_curve" if remembered is None else remembered

    X_jax = jnp.asarray(X, dtype=jnp.float64)

    def resolve(key, attribute):
        # The sentinel separates "not passed" from falsy values like 0.0.
        supplied = kwargs.get(key, _MISSING)
        if supplied is _MISSING:
            return getattr(self, attribute, None)
        return supplied

    gamma_dot = resolve("gamma_dot", "_gamma_dot_applied")
    sigma_applied = resolve("sigma_applied", "_sigma_applied")
    gamma_0 = resolve("gamma_0", "_gamma_0")
    omega = resolve("omega", "_omega_laos")

    if mode == "oscillation":
        # All TNT variants linearize to Maxwell in SAOS
        storage, loss = tnt_saos_moduli_vec(X_jax, G, tau_b, eta_s)
        return jnp.column_stack([storage, loss])

    if mode == "startup":
        if gamma_dot is None:
            raise ValueError("startup mode requires gamma_dot")
        return self._simulate_startup_internal(
            X_jax, G, tau_b, eta_s, gamma_dot, vp
        )

    if mode == "relaxation":
        if gamma_dot is None:
            raise ValueError("relaxation mode requires gamma_dot (pre-shear rate)")
        return self._simulate_relaxation_internal(
            X_jax, G, tau_b, eta_s, gamma_dot, vp
        )

    if mode == "creep":
        if sigma_applied is None:
            raise ValueError("creep mode requires sigma_applied")
        return self._simulate_creep_internal(
            X_jax, G, tau_b, eta_s, sigma_applied, vp
        )

    if mode == "laos":
        if gamma_0 is None or omega is None:
            raise ValueError("LAOS mode requires gamma_0 and omega")
        _, stress = self._simulate_laos_internal(
            X_jax, G, tau_b, eta_s, gamma_0, omega, vp
        )
        return stress

    # Flow curve, either requested by name or as the fallback for an
    # unrecognised mode (which is reported first).
    if mode not in ("flow_curve", "steady_shear", "rotation"):
        logger.warning(f"Unknown test_mode '{mode}', defaulting to flow_curve")
    if self._is_basic:
        return tnt_base_steady_stress_vec(X_jax, G, tau_b, eta_s)
    return self._variant_flow_curve_internal(X_jax, G, tau_b, eta_s, vp)
# =========================================================================
# Variant Flow Curve (ODE-to-steady-state)
# =========================================================================
def _variant_flow_curve_internal(
    self,
    gamma_dot_arr: jnp.ndarray,
    G: float,
    tau_b: float,
    eta_s: float,
    vp: dict,
) -> jnp.ndarray:
    """Return steady shear stress per shear rate by integrating to 50·τ_b.

    For each entry of ``gamma_dot_arr`` the variant ODE is integrated
    from the state [1, 1, 1, 0] up to ``t = 50·tau_b`` and only the final
    state is kept. The stress uses the same linear/FENE formula as
    ``_compute_stress_from_conformation``. Entries whose solve did not
    finish successfully are returned as NaN. The per-rate solve is mapped
    over the array with ``jax.vmap``.
    """
    kernel = self._variant_ode
    fene_active = self._stress_type == "fene"
    arg_order = ("gdot", "G", "tau_b", "nu", "m_break", "kappa", "L_max", "xi")

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    def steady_stress(rate):
        horizon = 50.0 * tau_b
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(rhs),
            diffrax.Tsit5(),
            0.0,
            horizon,
            tau_b / 20.0,
            jnp.array([1.0, 1.0, 1.0, 0.0], dtype=jnp.float64),
            args={"gdot": rate, "G": G, "tau_b": tau_b, **vp},
            saveat=diffrax.SaveAt(ts=jnp.array([horizon])),
            stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
            max_steps=500_000,
            throw=False,
        )
        final = solution.ys[0]
        if fene_active:
            L_sq = vp["L_max"] * vp["L_max"]
            trace = final[0] + final[1] + final[2]
            peterlin = L_sq / jnp.maximum(L_sq - trace, 1e-10)
            elastic = G * peterlin * final[3]
        else:
            elastic = G * final[3]
        total = elastic + eta_s * rate
        # throw=False: a failed solve is reported via the result code.
        return jnp.where(
            solution.result == diffrax.RESULTS.successful, total, jnp.nan
        )

    return jax.vmap(steady_stress)(gamma_dot_arr)
def _variant_steady_conformation(
    self,
    gamma_dot_arr: jnp.ndarray,
    G: float,
    tau_b: float,
    vp: dict,
) -> jnp.ndarray:
    """Return the conformation reached after integrating to 50·τ_b.

    For each entry of ``gamma_dot_arr`` the variant ODE is integrated
    from the state [1, 1, 1, 0] up to ``t = 50·tau_b``.

    Returns
    -------
    jnp.ndarray
        Shape (N, 4) with [S_xx, S_yy, S_zz, S_xy] per shear rate. Rows
        whose solve did not finish successfully are filled with NaN.
    """
    kernel = self._variant_ode
    arg_order = ("gdot", "G", "tau_b", "nu", "m_break", "kappa", "L_max", "xi")

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    def final_state(rate):
        horizon = 50.0 * tau_b
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(rhs),
            diffrax.Tsit5(),
            0.0,
            horizon,
            tau_b / 20.0,
            jnp.array([1.0, 1.0, 1.0, 0.0], dtype=jnp.float64),
            args={"gdot": rate, "G": G, "tau_b": tau_b, **vp},
            saveat=diffrax.SaveAt(ts=jnp.array([horizon])),
            stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
            max_steps=500_000,
            throw=False,
        )
        final = solution.ys[0]
        return jnp.where(
            solution.result == diffrax.RESULTS.successful,
            final,
            jnp.full_like(final, jnp.nan),
        )

    return jax.vmap(final_state)(gamma_dot_arr)
# =========================================================================
# Analytical / Hybrid Predictions
# =========================================================================
[docs]
def predict_flow_curve(
    self,
    gamma_dot: np.ndarray,
    return_components: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Predict steady shear stress and, optionally, viscosity and N₁.

    The basic variant uses the closed-form kernels. Every other variant
    integrates the ODE to steady state. When components are requested for
    a non-basic variant, the steady conformation is solved once and the
    stress, viscosity and N₁ are all derived from it, rather than running
    the same integration a second time for the stress.

    Parameters
    ----------
    gamma_dot : np.ndarray
        Shear rate array (1/s)
    return_components : bool, default False
        If True, return (sigma, eta, N1)

    Returns
    -------
    np.ndarray or tuple
        Shear stress σ (Pa), or (σ, η, N₁) if return_components=True.
        η is σ divided by the shear rate, with the rate floored at 1e-20.
    """
    gd = jnp.asarray(gamma_dot, dtype=jnp.float64)

    if self._is_basic:
        sigma = tnt_base_steady_stress_vec(gd, self.G, self.tau_b, self.eta_s)
        if not return_components:
            return np.asarray(sigma)
        eta = sigma / jnp.maximum(gd, 1e-20)
        N1 = tnt_base_steady_n1_vec(gd, self.G, self.tau_b)
        return np.asarray(sigma), np.asarray(eta), np.asarray(N1)

    # Variant: compute via ODE-to-steady-state
    vp = self._get_variant_args()
    if not return_components:
        sigma = self._variant_flow_curve_internal(
            gd, self.G, self.tau_b, self.eta_s, vp
        )
        return np.asarray(sigma)

    # One steady-state solve supplies every component. Failed solves come
    # back as NaN rows, and that NaN propagates into sigma, eta and N1.
    S_ss = self._variant_steady_conformation(gd, self.G, self.tau_b, vp)
    S_xx, S_yy, S_zz, S_xy = S_ss[:, 0], S_ss[:, 1], S_ss[:, 2], S_ss[:, 3]
    sigma = self._compute_stress_from_conformation(
        S_xx, S_yy, S_zz, S_xy, self.G, self.eta_s, gd, vp
    )
    eta = sigma / jnp.maximum(gd, 1e-20)
    if self._stress_type == "fene":
        tr_S = S_xx + S_yy + S_zz
        L2 = vp["L_max"] ** 2
        f = L2 / jnp.maximum(L2 - tr_S, 1e-10)
        N1 = self.G * f * (S_xx - S_yy)
    else:
        N1 = self.G * (S_xx - S_yy)
    return np.asarray(sigma), np.asarray(eta), np.asarray(N1)
[docs]
def predict_saos(
    self,
    omega: np.ndarray,
    return_components: bool = True,
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
    """Predict small-amplitude oscillatory moduli.

    The moduli come from ``tnt_saos_moduli_vec`` evaluated with the
    current G, τ_b and η_s.

    Parameters
    ----------
    omega : np.ndarray
        Angular frequency array (rad/s)
    return_components : bool, default True
        If True, return (G', G''); otherwise return |G*|.

    Returns
    -------
    tuple or np.ndarray
        (G', G'') as NumPy arrays, or the magnitude
        sqrt(max(G'² + G''², 1e-30)).
    """
    frequencies = jnp.asarray(omega, dtype=jnp.float64)
    storage, loss = tnt_saos_moduli_vec(
        frequencies, self.G, self.tau_b, self.eta_s
    )
    if not return_components:
        # The floor keeps the square root argument strictly positive.
        magnitude = jnp.sqrt(jnp.maximum(storage**2 + loss**2, 1e-30))
        return np.asarray(magnitude)
    return np.asarray(storage), np.asarray(loss)
[docs]
def predict_normal_stresses(
    self,
    gamma_dot: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Predict first and second normal stress differences.

    The basic variant takes N₁ from ``tnt_base_steady_n1_vec`` and sets
    N₂ to zero. Other variants use the ODE steady-state conformation:
    N₁ ∝ (S_xx - S_yy) and N₂ ∝ (S_yy - S_zz), with prefactor G, multiplied
    by the FENE factor when the stress type is "fene".

    Parameters
    ----------
    gamma_dot : np.ndarray
        Shear rate array (1/s)

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (N₁, N₂) in Pa
    """
    rates = jnp.asarray(gamma_dot, dtype=jnp.float64)

    if self._is_basic:
        first = tnt_base_steady_n1_vec(rates, self.G, self.tau_b)
        return np.asarray(first), np.asarray(jnp.zeros_like(first))

    # Variant: compute from steady-state conformation
    vp = self._get_variant_args()
    S_ss = self._variant_steady_conformation(rates, self.G, self.tau_b, vp)
    S_xx, S_yy, S_zz = S_ss[:, 0], S_ss[:, 1], S_ss[:, 2]

    prefactor = self.G
    if self._stress_type == "fene":
        L_sq = vp["L_max"] ** 2
        trace = S_xx + S_yy + S_zz
        prefactor = self.G * (L_sq / jnp.maximum(L_sq - trace, 1e-10))

    first = prefactor * (S_xx - S_yy)
    second = prefactor * (S_yy - S_zz)
    return np.asarray(first), np.asarray(second)
# =========================================================================
# ODE-Based Internal Simulations (for model_function)
# =========================================================================
def _simulate_startup_internal(
    self,
    t: jnp.ndarray,
    G: float,
    tau_b: float,
    eta_s: float,
    gamma_dot: float,
    vp: dict | None = None,
) -> jnp.ndarray:
    """Integrate startup flow and return total shear stress at times ``t``.

    The variant ODE is integrated from the state [1, 1, 1, 0] between
    ``t[0]`` and ``t[-1]`` at constant ``gamma_dot``, saving at every
    entry of ``t``. When ``vp`` is None, placeholder variant values are
    used (including ``xi = 0.0``). If the solve does not finish
    successfully, the whole result is NaN.
    """
    if vp is None:
        vp = {"nu": 0.0, "m_break": 0.0, "kappa": 0.0, "L_max": 10.0, "xi": 0.0}

    kernel = self._variant_ode
    arg_order = (
        "gamma_dot", "G", "tau_b", "nu", "m_break", "kappa", "L_max", "xi",
    )

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    t_start, t_stop = t[0], t[-1]
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_stop,
        # Initial step: the span split into at least 1000 pieces.
        (t_stop - t_start) / max(len(t), 1000),
        jnp.array([1.0, 1.0, 1.0, 0.0], dtype=jnp.float64),
        args={"gamma_dot": gamma_dot, "G": G, "tau_b": tau_b, **vp},
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )
    states = solution.ys
    stress = self._compute_stress_from_conformation(
        states[:, 0],
        states[:, 1],
        states[:, 2],
        states[:, 3],
        G,
        eta_s,
        gamma_dot,
        vp,
    )
    return jnp.where(
        solution.result == diffrax.RESULTS.successful,
        stress,
        jnp.full_like(stress, jnp.nan),
    )
def _simulate_relaxation_internal(
    self,
    t: jnp.ndarray,
    G: float,
    tau_b: float,
    eta_s: float,
    gamma_dot_preshear: float,
    vp: dict | None = None,
) -> jnp.ndarray:
    """Return relaxing shear stress at times ``t`` after a pre-shear.

    With constant breakage and linear stress, the result comes from
    ``tnt_base_relaxation_vec`` with initial stress
    ``G·tau_b·gamma_dot_preshear``. This branch does not look at ξ.
    Otherwise the relaxation ODE is integrated between ``t[0]`` and
    ``t[-1]`` and the stress is evaluated with a zero shear rate. If the
    solve does not finish successfully, the whole result is NaN.
    """
    if vp is None:
        vp = {"nu": 0.0, "m_break": 0.0, "kappa": 0.0, "L_max": 10.0, "xi": 0.0}

    closed_form = self._breakage == "constant" and self._stress_type == "linear"
    if closed_form:
        # Analytical: σ(t) = G·τ_b·γ̇·exp(-t/τ_b)
        initial_stress = G * tau_b * gamma_dot_preshear
        return tnt_base_relaxation_vec(t, initial_stress, tau_b)

    # NOTE(review): the initial state is the *base* steady conformation,
    # not the steady state of the active variant — confirm this is the
    # intended approximation.
    S_xx_0, S_yy_0, S_zz_0, S_xy_0 = tnt_base_steady_conformation(
        gamma_dot_preshear, tau_b
    )
    initial_state = jnp.array(
        [S_xx_0, S_yy_0, S_zz_0, S_xy_0], dtype=jnp.float64
    )

    kernel = self._variant_relax_ode
    arg_order = ("G", "tau_b", "nu", "m_break", "kappa", "L_max", "xi")

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    t_start, t_stop = t[0], t[-1]
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_stop,
        # Initial step: the span split into at least 1000 pieces.
        (t_stop - t_start) / max(len(t), 1000),
        initial_state,
        args={"G": G, "tau_b": tau_b, **vp},
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )
    states = solution.ys
    stress = self._compute_stress_from_conformation(
        states[:, 0],
        states[:, 1],
        states[:, 2],
        states[:, 3],
        G,
        eta_s,
        0.0,  # flow has stopped, so the solvent term vanishes
        vp,
    )
    return jnp.where(
        solution.result == diffrax.RESULTS.successful,
        stress,
        jnp.full_like(stress, jnp.nan),
    )
def _simulate_creep_internal(
    self,
    t: jnp.ndarray,
    G: float,
    tau_b: float,
    eta_s: float,
    sigma_applied: float,
    vp: dict | None = None,
) -> jnp.ndarray:
    """Integrate creep under constant stress and return strain γ(t).

    The creep ODE carries a five-component state that starts at
    [1, 1, 1, 0, 0]; the last component is returned as the strain. When
    ``vp`` is None, placeholder variant values are used. If the solve does
    not finish successfully, the whole result is NaN.
    """
    if vp is None:
        vp = {"nu": 0.0, "m_break": 0.0, "kappa": 0.0, "L_max": 10.0, "xi": 0.0}

    kernel = self._variant_creep_ode
    arg_order = (
        "sigma_applied", "G", "tau_b", "eta_s",
        "nu", "m_break", "kappa", "L_max", "xi",
    )

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    t_start, t_stop = t[0], t[-1]
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_stop,
        # Initial step: the span split into at least 1000 pieces.
        (t_stop - t_start) / max(len(t), 1000),
        jnp.array([1.0, 1.0, 1.0, 0.0, 0.0], dtype=jnp.float64),
        args={
            "sigma_applied": sigma_applied,
            "G": G,
            "tau_b": tau_b,
            "eta_s": eta_s,
            **vp,
        },
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )
    strain = solution.ys[:, 4]  # γ (strain)
    return jnp.where(
        solution.result == diffrax.RESULTS.successful,
        strain,
        jnp.full_like(strain, jnp.nan),
    )
def _simulate_laos_internal(
    self,
    t: jnp.ndarray,
    G: float,
    tau_b: float,
    eta_s: float,
    gamma_0: float,
    omega: float,
    vp: dict | None = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate LAOS and return ``(strain, stress)`` at times ``t``.

    The strain is ``gamma_0·sin(omega·t)``. The stress is evaluated from
    the integrated conformation with the instantaneous rate
    ``gamma_0·omega·cos(omega·t)`` in the solvent term. Integration starts
    from the state [1, 1, 1, 0]. If the solve does not finish
    successfully, the stress is all NaN; the strain is returned regardless.
    """
    if vp is None:
        vp = {"nu": 0.0, "m_break": 0.0, "kappa": 0.0, "L_max": 10.0, "xi": 0.0}

    kernel = self._variant_laos_ode
    arg_order = (
        "gamma_0", "omega", "G", "tau_b",
        "nu", "m_break", "kappa", "L_max", "xi",
    )

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    t_start, t_stop = t[0], t[-1]
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_stop,
        # Initial step: the span split into at least 1000 pieces.
        (t_stop - t_start) / max(len(t), 1000),
        jnp.array([1.0, 1.0, 1.0, 0.0], dtype=jnp.float64),
        args={
            "gamma_0": gamma_0,
            "omega": omega,
            "G": G,
            "tau_b": tau_b,
            **vp,
        },
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )

    phase = omega * t
    strain = gamma_0 * jnp.sin(phase)
    rate = gamma_0 * omega * jnp.cos(phase)

    states = solution.ys
    stress = self._compute_stress_from_conformation(
        states[:, 0],
        states[:, 1],
        states[:, 2],
        states[:, 3],
        G,
        eta_s,
        rate,
        vp,
    )
    stress = jnp.where(
        solution.result == diffrax.RESULTS.successful,
        stress,
        jnp.full_like(stress, jnp.nan),
    )
    return strain, stress
# =========================================================================
# Public Simulation Methods (return numpy arrays)
# =========================================================================
[docs]
def simulate_startup(
    self,
    t: np.ndarray,
    gamma_dot: float,
    return_full: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Simulate startup flow at constant shear rate.

    Uses the current parameter values and the solver chosen by
    ``_get_ode_solver``. The conformation history is also stored on the
    instance as ``_trajectory`` (keys ``t``, ``S_xx``, ``S_yy``, ``S_zz``,
    ``S_xy``). If the solve does not finish successfully, all outputs are
    NaN.

    Parameters
    ----------
    t : np.ndarray
        Time array (s)
    gamma_dot : float
        Applied shear rate (1/s)
    return_full : bool, default False
        If True, return the conformation components instead of stress.

    Returns
    -------
    np.ndarray or tuple
        Shear stress σ(t), or (S_xx, S_yy, S_xy, S_zz) if return_full.
        Note the tuple order: S_xy comes before S_zz.
    """
    times = jnp.asarray(t, dtype=jnp.float64)
    vp = self._get_variant_args()

    kernel = self._variant_ode
    arg_order = (
        "gamma_dot", "G", "tau_b", "nu", "m_break", "kappa", "L_max", "xi",
    )

    def rhs(time, state, packed):
        # Forward the packed arguments positionally in the kernel's order.
        return kernel(time, state, *(packed[key] for key in arg_order))

    t_start, t_stop = times[0], times[-1]
    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        self._get_ode_solver(),
        t_start,
        t_stop,
        # Initial step: the span split into at least 1000 pieces.
        (t_stop - t_start) / max(len(t), 1000),
        jnp.array([1.0, 1.0, 1.0, 0.0], dtype=jnp.float64),
        args={"gamma_dot": gamma_dot, "G": self.G, "tau_b": self.tau_b, **vp},
        saveat=diffrax.SaveAt(ts=times),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )

    # One mask over the whole state array; columns are split afterwards.
    states = jnp.where(
        solution.result == diffrax.RESULTS.successful,
        solution.ys,
        jnp.full_like(solution.ys, jnp.nan),
    )
    S_xx, S_yy, S_zz, S_xy = (states[:, column] for column in range(4))

    self._trajectory = {
        "t": np.asarray(times),
        "S_xx": np.asarray(S_xx),
        "S_yy": np.asarray(S_yy),
        "S_zz": np.asarray(S_zz),
        "S_xy": np.asarray(S_xy),
    }

    if return_full:
        return (
            np.asarray(S_xx),
            np.asarray(S_yy),
            np.asarray(S_xy),
            np.asarray(S_zz),
        )

    # Total stress: σ = G·f(S)·S_xy + η_s·γ̇ (f=1 for linear, FENE-P otherwise)
    total_stress = self._compute_stress_from_conformation(
        S_xx,
        S_yy,
        S_zz,
        S_xy,
        self.G,
        self.eta_s,
        gamma_dot,
        vp,
    )
    return np.asarray(total_stress)
[docs]
def simulate_relaxation(
    self,
    t: np.ndarray,
    gamma_dot_preshear: float,
    return_full: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Simulate stress relaxation after cessation of steady shear.

    For constant breakage + linear stress, relaxation is analytical:
    σ(t) = G·S_xy(0)·exp(-t/τ_b).
    For non-constant breakage or FENE stress, or whenever ``return_full``
    is requested, the relaxation ODE is integrated from the steady-state
    conformation reached under the pre-shear.

    Parameters
    ----------
    t : np.ndarray
        Time array (s), starting from t=0 (cessation)
    gamma_dot_preshear : float
        Shear rate before cessation (1/s)
    return_full : bool, default False
        If True, return full conformation tensor components

    Returns
    -------
    np.ndarray or tuple
        Relaxing stress σ(t), or (S_xx, S_yy, S_xy, S_zz) if return_full.
        On the ODE path all entries are NaN when the solver does not
        report success.

    Notes
    -----
    Only the ODE path updates ``self._trajectory``; the analytical path
    returns without touching it.
    """
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    vp = self._get_variant_args()

    if self._is_basic and not return_full:
        # Analytical relaxation for constant breakage + linear stress
        sigma_0 = self.G * self.tau_b * gamma_dot_preshear
        sigma = tnt_base_relaxation_vec(t_jax, sigma_0, self.tau_b)
        return np.asarray(sigma)

    # ODE-based relaxation: start from the pre-shear steady state.
    if self._is_basic:
        S_xx_0, S_yy_0, S_zz_0, S_xy_0 = tnt_base_steady_conformation(
            gamma_dot_preshear, self.tau_b
        )
    else:
        # The variant helper is evaluated for a one-element rate array; row 0
        # is read as (S_xx, S_yy, S_zz, S_xy), matching the ODE state order.
        ss = self._variant_steady_conformation(
            jnp.array([gamma_dot_preshear]), self.G, self.tau_b, vp
        )
        S_xx_0, S_yy_0, S_zz_0, S_xy_0 = (ss[0, 0], ss[0, 1], ss[0, 2], ss[0, 3])

    y0 = jnp.array([S_xx_0, S_yy_0, S_zz_0, S_xy_0], dtype=jnp.float64)
    variant_relax_ode = self._variant_relax_ode

    def ode_fn(ti, yi, args):
        return variant_relax_ode(
            ti,
            yi,
            args["G"],
            args["tau_b"],
            args["nu"],
            args["m_break"],
            args["kappa"],
            args["L_max"],
            args["xi"],
        )

    args = {"G": self.G, "tau_b": self.tau_b, **vp}
    term = diffrax.ODETerm(ode_fn)
    solver = self._get_ode_solver()
    stepsize_controller = diffrax.PIDController(rtol=1e-6, atol=1e-8)
    t0 = t_jax[0]
    t1 = t_jax[-1]
    # Initial step guess: at least 1000 subdivisions of the time window.
    dt0 = (t1 - t0) / max(len(t), 1000)
    saveat = diffrax.SaveAt(ts=t_jax)
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        args=args,
        saveat=saveat,
        stepsize_controller=stepsize_controller,
        max_steps=500_000,
        throw=False,
    )

    # throw=False means failures do not raise; blank the output instead.
    solved = sol.result == diffrax.RESULTS.successful

    def masked_component(k):
        column = sol.ys[:, k]
        return jnp.where(solved, column, jnp.nan * jnp.ones_like(column))

    S_xx = masked_component(0)
    S_yy = masked_component(1)
    S_zz = masked_component(2)
    S_xy = masked_component(3)

    self._trajectory = {
        "t": np.asarray(t_jax),
        "S_xx": np.asarray(S_xx),
        "S_yy": np.asarray(S_yy),
        "S_zz": np.asarray(S_zz),
        "S_xy": np.asarray(S_xy),
    }
    if return_full:
        return (
            np.asarray(S_xx),
            np.asarray(S_yy),
            np.asarray(S_xy),
            np.asarray(S_zz),
        )

    # Stress from conformation: σ = G·f(S)·S_xy + η_s·γ̇ (γ̇=0 in relaxation)
    sigma = self._compute_stress_from_conformation(
        S_xx,
        S_yy,
        S_zz,
        S_xy,
        self.G,
        self.eta_s,
        0.0,
        vp,
    )
    return np.asarray(sigma)
[docs]
def simulate_creep(
    self,
    t: np.ndarray,
    sigma_applied: float,
    return_rate: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Integrate the creep ODE under a constant applied stress.

    The state vector has five entries: the first four are read back as
    the conformation components (index 3 is S_xy) and index 4 as the
    accumulated strain γ. Integration starts from ``[1, 1, 1, 0, 0]``.
    The strain and S_xy histories are cached in ``self._trajectory``.

    Parameters
    ----------
    t : np.ndarray
        Time array (s); integration runs from ``t[0]`` to ``t[-1]``.
    sigma_applied : float
        Applied constant stress (Pa).
    return_rate : bool, default False
        If True, also return the shear rate γ̇(t) reconstructed from the
        stress balance ``(σ_applied − σ_elastic) / η_s``.

    Returns
    -------
    np.ndarray or tuple
        Strain γ(t), or ``(γ, γ̇)`` if ``return_rate`` is True. Strain is
        NaN when the solver does not report success.
    """
    times = jnp.asarray(t, dtype=jnp.float64)
    vp = self._get_variant_args()
    creep_rhs = self._variant_creep_ode

    def vector_field(ti, yi, p):
        # Unpack the parameter dict into the positional signature of the
        # variant creep right-hand side.
        return creep_rhs(
            ti,
            yi,
            p["sigma_applied"],
            p["G"],
            p["tau_b"],
            p["eta_s"],
            p["nu"],
            p["m_break"],
            p["kappa"],
            p["L_max"],
            p["xi"],
        )

    # Variant arguments are merged last so they take precedence on a clash.
    params = {
        "sigma_applied": sigma_applied,
        "G": self.G,
        "tau_b": self.tau_b,
        "eta_s": self.eta_s,
    }
    params.update(vp)

    initial_state = jnp.array([1.0, 1.0, 1.0, 0.0, 0.0], dtype=jnp.float64)
    term = diffrax.ODETerm(vector_field)
    solver = self._get_ode_solver()

    # Creep with nonlinear breakage/stress is stiff — non-basic variants get
    # relaxed tolerances and a 10x smaller initial step.
    basic = self._is_basic
    rtol, atol = (1e-6, 1e-8) if basic else (1e-4, 1e-6)
    controller = diffrax.PIDController(rtol=rtol, atol=atol)

    t_start = times[0]
    t_end = times[-1]
    step_hint = max(len(t), 1000) * (1 if basic else 10)
    first_step = (t_end - t_start) / step_hint

    sol = diffrax.diffeqsolve(
        term,
        solver,
        t_start,
        t_end,
        first_step,
        initial_state,
        args=params,
        saveat=diffrax.SaveAt(ts=times),
        stepsize_controller=controller,
        max_steps=1_000_000,
        throw=False,
    )

    # throw=False means failures do not raise; blank the output instead.
    solved = sol.result == diffrax.RESULTS.successful

    def masked_component(k):
        column = sol.ys[:, k]
        return jnp.where(solved, column, jnp.nan * jnp.ones_like(column))

    gamma = np.asarray(masked_component(4))
    S_xy = np.asarray(masked_component(3))

    self._trajectory = {
        "t": np.asarray(times),
        "gamma": gamma,
        "S_xy": S_xy,
    }

    if not return_rate:
        return gamma

    # Floor the solvent viscosity so the division below stays finite.
    eta_s_reg = max(self.eta_s, 1e-10 * self.G * self.tau_b)

    # For FENE: σ_elastic = G·f(trS)·S_xy, for linear: G·S_xy
    if self._stress_type == "fene":
        L2 = vp["L_max"] ** 2
        s_xx = np.asarray(sol.ys[:, 0])
        s_yy = np.asarray(sol.ys[:, 1])
        s_zz = np.asarray(sol.ys[:, 2])
        tr_S = s_xx + s_yy + s_zz
        # Denominator is clamped away from zero near full extension.
        f_pet = L2 / np.maximum(L2 - tr_S, 1e-10)
        sigma_elastic = self.G * f_pet * S_xy
    else:
        sigma_elastic = self.G * S_xy

    gamma_dot = (sigma_applied - sigma_elastic) / eta_s_reg
    return gamma, gamma_dot
[docs]
def simulate_laos(
    self,
    t: np.ndarray | None,
    gamma_0: float,
    omega: float,
    n_cycles: int | None = None,
) -> dict[str, np.ndarray]:
    """Simulate Large-Amplitude Oscillatory Shear (LAOS).

    The imposed strain rate is ``gamma_0 * omega * cos(omega * t)``.
    Strain and stress come from ``_simulate_laos_internal``; the result
    is also cached in ``self._trajectory``.

    Parameters
    ----------
    t : np.ndarray or None
        Time array (s). May be None only when ``n_cycles`` is given, in
        which case the grid is auto-generated.
    gamma_0 : float
        Strain amplitude (dimensionless)
    omega : float
        Angular frequency (rad/s)
    n_cycles : int, optional
        Number of oscillation cycles (overrides ``t``). The generated
        grid spans ``n_cycles`` periods with 200 points per cycle.

    Returns
    -------
    dict
        Dictionary with keys: 't', 'strain', 'stress', 'strain_rate'

    Raises
    ------
    ValueError
        If neither ``t`` nor ``n_cycles`` is provided, or if ``n_cycles``
        is less than 1.
    """
    if n_cycles is not None:
        if n_cycles < 1:
            # A non-positive count would yield an empty time grid that the
            # solver cannot integrate over.
            raise ValueError(f"n_cycles must be >= 1, got {n_cycles}")
        T = 2 * np.pi / omega
        t = np.linspace(0, n_cycles * T, n_cycles * 200)
    elif t is None:
        raise ValueError("Either `t` or `n_cycles` must be provided")

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    vp = self._get_variant_args()
    strain, stress = self._simulate_laos_internal(
        t_jax, self.G, self.tau_b, self.eta_s, gamma_0, omega, vp
    )
    strain_rate = gamma_0 * omega * jnp.cos(omega * t_jax)

    result = {
        "t": np.asarray(t_jax),
        "strain": np.asarray(strain),
        "stress": np.asarray(stress),
        "strain_rate": np.asarray(strain_rate),
    }
    # Cache a separate dict so callers editing the returned mapping do not
    # alter the stored trajectory's keys.
    self._trajectory = dict(result)
    return result
# =========================================================================
# Analysis Methods
# =========================================================================
[docs]
def get_overshoot_ratio(
    self,
    gamma_dot: float,
    t_max: float | None = None,
) -> tuple[float, float]:
    """Compute stress overshoot ratio in startup flow.

    Runs ``simulate_startup`` on a 1000-point uniform grid over
    ``[0, t_max]`` and compares the peak stress with the final value,
    which is taken as the steady state.

    For constant breakage, there is no overshoot (UCM behavior).
    Overshoot requires Bell or stretch-dependent breakage.

    Parameters
    ----------
    gamma_dot : float
        Shear rate (1/s). Negative rates are handled by measuring the
        stress along the shear direction.
    t_max : float, optional
        Maximum simulation time (default: 20·τ_b)

    Returns
    -------
    tuple[float, float]
        (overshoot_ratio, strain_at_overshoot). The ratio is 1.0 when the
        final stress is not positive along the shear direction. Both
        values are NaN when the startup simulation returned NaN (e.g. a
        failed ODE solve), rather than a misleading "no overshoot".
    """
    if t_max is None:
        t_max = 20 * self.tau_b
    t = np.linspace(0, t_max, 1000)
    sigma = np.asarray(self.simulate_startup(t, gamma_dot))

    if np.isnan(sigma).any():
        return float("nan"), float("nan")

    # Orient the stress along the shear direction so that a negative rate
    # (negative stress) is analysed exactly like its positive mirror image.
    direction = -1.0 if gamma_dot < 0 else 1.0
    oriented = direction * sigma

    peak_idx = int(np.argmax(oriented))
    sigma_max = oriented[peak_idx]
    strain_at_peak = gamma_dot * t[peak_idx]
    sigma_ss = oriented[-1]
    overshoot_ratio = sigma_max / sigma_ss if sigma_ss > 0 else 1.0
    return float(overshoot_ratio), float(strain_at_peak)
[docs]
def get_relaxation_spectrum(
    self,
    t: np.ndarray | None = None,
    n_points: int = 100,
) -> tuple[np.ndarray, np.ndarray]:
    """Get relaxation modulus G(t).

    For single-mode TNT: G(t) = G·exp(-t/τ_b)

    Parameters
    ----------
    t : np.ndarray, optional
        Time array (default: logspace from 0.01·τ_b to 100·τ_b). Any
        array-like (e.g. a list) is accepted and converted with
        ``np.asarray``; an existing ndarray is returned unchanged.
    n_points : int, default 100
        Number of points if t not provided

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (t, G(t))
    """
    if t is None:
        t = np.logspace(
            np.log10(0.01 * self.tau_b),
            np.log10(100 * self.tau_b),
            n_points,
        )
    else:
        # Plain sequences do not support unary minus / division; coerce so
        # list and tuple inputs work like ndarrays.
        t = np.asarray(t)
    G_t = self.G * np.exp(-t / self.tau_b)
    return t, G_t