"""Multi-species Transient Network Theory (TNT) model.
This module implements `TNTMultiSpecies`, a constitutive model for networks
with N distinct bond types that relax independently.
Key Physics
-----------
Multi-species TNT describes networks with heterogeneous physical crosslinks:
- N independent bond populations (species)
- Each species i has its own modulus G_i and lifetime τ_b_i
- Each species has its own conformation tensor S_i that evolves independently
- Total stress is the sum over all species: σ = Σ G_i·S_xy_i + η_s·γ̇
The constitutive equations for each species i::
dS_i/dt = L·S_i + S_i·L^T + (1/τ_b_i)·I - (1/τ_b_i)·S_i
This represents a superposition of N Maxwell modes (multi-mode UCM),
commonly used to model:
- Polydisperse systems with broad relaxation spectra
- Multicomponent associative networks
- Complex hierarchical structures
Supported Protocols
-------------------
- FLOW_CURVE: Steady shear (analytical)
- OSCILLATION: Small-amplitude oscillatory shear (analytical)
- STARTUP: Transient stress growth (ODE)
- RELAXATION: Stress decay after cessation (analytical/ODE)
- CREEP: Strain evolution under constant stress (ODE)
- LAOS: Large-amplitude oscillatory shear (ODE)
Example
-------
>>> from rheojax.models.tnt import TNTMultiSpecies
>>> import numpy as np
>>>
>>> # Two-species network
>>> model = TNTMultiSpecies(n_species=2)
>>>
>>> # Flow curve (analytical superposition)
>>> gamma_dot = np.logspace(-2, 2, 50)
>>> sigma = model.predict(gamma_dot, test_mode='flow_curve')
>>>
>>> # SAOS (analytical Maxwell superposition)
>>> omega = np.logspace(-2, 2, 50)
>>> G_star = model.predict(omega, test_mode='oscillation',
...                        return_components=True)
>>>
>>> # Startup (ODE with 2N state variables)
>>> t = np.linspace(0, 10, 200)
>>> model._gamma_dot_applied = 1.0
>>> sigma_t = model.predict(t, test_mode='startup')
References
----------
- Likhtman, A.E. & Graham, R.S. (2003). J. Non-Newt. Fluid Mech. 114, 1-12.
- Graham, R.S., Likhtman, A.E., McLeish, T.C.B. & Milner, S.T. (2003).
J. Rheol. 47, 1171-1200.
"""
from __future__ import annotations
import logging
import numpy as np
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.tnt._base import TNTBase
from rheojax.models.tnt._kernels import (
tnt_multimode_ode_rhs,
tnt_multimode_relaxation_vec,
tnt_multimode_saos_moduli_vec,
)
jax, jnp = safe_import_jax()
logger = logging.getLogger(__name__)
_MISSING = object()
[docs]
@ModelRegistry.register(
"tnt_multi_species",
protocols=["flow_curve", "oscillation", "startup", "relaxation", "creep", "laos"],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class TNTMultiSpecies(TNTBase):
"""Multi-species Transient Network Theory model.
Implements a network with N independent bond populations, each with
its own modulus G_i and lifetime τ_b_i. The total stress is a
superposition of N Maxwell modes.
This is equivalent to a generalized Maxwell model where each mode
represents a distinct physical crosslink species rather than a
mathematical decomposition.
Parameters
----------
n_species : int, default 2
Number of distinct bond species (N ≥ 1)
Attributes
----------
parameters : ParameterSet
Model parameters: [G_0, tau_b_0, G_1, tau_b_1, ..., G_{N-1},
tau_b_{N-1}, eta_s]
fitted_ : bool
Whether the model has been fitted
_n_species : int
Number of species
Notes
-----
The state vector has 4*N components:
[S_xx_0, S_yy_0, S_zz_0, S_xy_0, ..., S_xy_{N-1}]
Each species evolves independently via the upper-convected derivative
with constant breakage/creation rates.
See Also
--------
TNTSingleMode : Single-mode TNT with variant breakage rates
GeneralizedMaxwell : Mathematical multi-mode decomposition
"""
[docs]
def __init__(self, n_species: int = 2):
    """Initialize multi-species TNT model.

    Parameters
    ----------
    n_species : int, default 2
        Number of distinct bond species (must be an integer ≥ 1)

    Raises
    ------
    TypeError
        If ``n_species`` is not an integer (Python or NumPy).
    ValueError
        If ``n_species`` is smaller than 1.
    """
    # The species count drives range() loops and parameter slicing, so a
    # float such as 2.5 must be rejected here instead of failing later with
    # an unrelated "float cannot be interpreted as an integer" error.
    if not isinstance(n_species, (int, np.integer)):
        raise TypeError(
            f"n_species must be an integer, got {type(n_species).__name__}"
        )
    if n_species < 1:
        raise ValueError(f"n_species must be ≥ 1, got {n_species}")
    # Normalise NumPy integers to a plain int.
    self._n_species = int(n_species)
    super().__init__()
    self._setup_parameters()
    self._test_mode = None
# =========================================================================
# Parameter Setup
# =========================================================================
def _setup_parameters(self):
    """Create the ParameterSet holding 2N+1 entries (N = number of species).

    Registration order::

        [G_0, tau_b_0, G_1, tau_b_1, ..., G_{N-1}, tau_b_{N-1}, eta_s]

    Default values:

    - G_i = 1000.0 / N (total modulus shared equally)
    - tau_b_i = 10^(-i) (one decade apart: 1.0, 0.1, 0.01, ...)
    - eta_s = 0.0 (no solvent viscosity)
    """
    species_count = self._n_species
    modulus_share = 1000.0 / species_count

    self.parameters = ParameterSet()
    registry = self.parameters

    for idx in range(species_count):
        # Modulus of this species
        registry.add(
            name=f"G_{idx}",
            value=modulus_share,
            bounds=(1e0, 1e8),
            units="Pa",
            description=f"Network modulus for bond species {idx}",
        )
        # Lifetime of this species: each one a decade shorter than the last
        registry.add(
            name=f"tau_b_{idx}",
            value=10.0 ** (-idx),
            bounds=(1e-6, 1e4),
            units="s",
            description=f"Bond lifetime for species {idx}",
        )

    # One solvent viscosity shared by all species, registered last
    registry.add(
        name="eta_s",
        value=0.0,
        bounds=(0.0, 1e4),
        units="Pa·s",
        description="Solvent viscosity (Newtonian background)",
    )
[docs]
def initialize_from_creep(
    self,
    t: np.ndarray,
    strain: np.ndarray,
    sigma_applied: float = 1.0,
) -> None:
    """Initialize parameters from creep data for numerical stability.

    For multi-species, sets eta_s to prevent infinite initial shear rate.
    The data are sorted by time first, so unsorted input gives the same
    estimates as sorted input. With fewer than two samples no rate can be
    estimated; the current parameter values are then left untouched and a
    warning is logged.

    Parameters
    ----------
    t : np.ndarray
        Time array (s)
    strain : np.ndarray
        Strain array
    sigma_applied : float
        Applied constant stress (Pa)
    """
    t = np.asarray(t)
    strain = np.asarray(strain)

    if t.size < 2:
        # A single point gives 0/0 below and would write NaN parameters.
        logger.warning(
            "Creep initialization skipped: need at least 2 data points, got %d",
            t.size,
        )
        return

    # Sort by time so "late-time" slices really are the late-time data
    # (mirrors initialize_from_relaxation).
    order = np.argsort(t)
    t = t[order]
    strain = strain[order]

    # Estimate strain rate from data (slope at long times)
    n_points = len(t)
    if n_points > 10:
        n_late = max(3, n_points // 3)
        gamma_dot_est = np.polyfit(t[-n_late:], strain[-n_late:], 1)[0]
    else:
        span = t[-1] - t[0]
        gamma_dot_est = (strain[-1] - strain[0]) / span if span > 0 else 0.0
    if not np.isfinite(gamma_dot_est):
        # Degenerate data: fall through to the 1e-10 rate floor below.
        gamma_dot_est = 0.0

    # Estimate zero-shear viscosity (rate floored to avoid division by zero)
    eta_0_est = sigma_applied / max(abs(gamma_dot_est), 1e-10)

    # Set eta_s for numerical stability
    eta_s_est = max(0.01 * eta_0_est, 1e-6 * sigma_applied)

    # Characteristic time: the sample a quarter of the way through the record
    t_char = t[n_points // 4] if n_points > 4 else t[-1] / 4
    tau_b_est = max(t_char, 0.1)

    # G_total = η₀ / τ_b
    G_total_est = max(eta_0_est / tau_b_est, 10.0)

    # Share G equally; space lifetimes one decade apart. Values are clipped
    # to the same ranges used as parameter bounds in _setup_parameters.
    G_share = G_total_est / self._n_species
    for i in range(self._n_species):
        tau_i = tau_b_est * (10.0 ** (-i))
        self.parameters.set_value(f"G_{i}", np.clip(G_share, 1e0, 1e8))
        self.parameters.set_value(f"tau_b_{i}", np.clip(tau_i, 1e-6, 1e4))
    self.parameters.set_value("eta_s", np.clip(eta_s_est, 1e-10, 1e4))

    logger.debug(
        f"Creep initialization: G_total={G_total_est:.3e} Pa, "
        f"tau_b_0={tau_b_est:.3e} s, η_s={eta_s_est:.3e} Pa·s"
    )
[docs]
def initialize_from_relaxation(
    self,
    t: np.ndarray,
    modulus: np.ndarray,
) -> None:
    """Seed the mode parameters from stress relaxation data.

    The earliest modulus sample is taken as the total modulus and split
    equally over the species. The longest lifetime is the first time the
    modulus falls below 1/e of that value. If it never does, a log-linear
    fit over the first half of the record is used (more than 5 samples),
    otherwise a third of the last time. Lifetimes of successive species are
    spaced one decade apart.

    Parameters
    ----------
    t : np.ndarray
        Time array (s)
    modulus : np.ndarray
        Relaxation modulus G(t) (Pa)
    """
    times = np.asarray(t)
    g_of_t = np.asarray(modulus)

    # Put the samples in chronological order
    order = np.argsort(times)
    times, g_of_t = times[order], g_of_t[order]

    # Earliest sample stands in for the initial modulus
    plateau = g_of_t[0]

    # Indices where the modulus has dropped below plateau / e
    below = np.nonzero(g_of_t < plateau / np.e)[0]
    if len(below) > 0:
        lifetime = times[below[0]]
    elif len(times) > 5:
        half = len(times) // 2
        log_g = np.log(np.maximum(g_of_t, 1e-20))
        slope = np.polyfit(times[:half], log_g[:half], 1)[0]
        # A non-decaying fit falls back to the last recorded time
        lifetime = -1.0 / slope if slope < 0 else times[-1]
    else:
        lifetime = times[-1] / 3

    # Equal modulus per species, decade-spaced lifetimes, clipped to the
    # same ranges used as parameter bounds
    per_species = plateau / self._n_species
    for idx in range(self._n_species):
        self.parameters.set_value(f"G_{idx}", np.clip(per_species, 1e0, 1e8))
        self.parameters.set_value(
            f"tau_b_{idx}", np.clip(lifetime * (10.0 ** (-idx)), 1e-6, 1e4)
        )

    logger.debug(
        f"Relaxation initialization: G_0={plateau:.3e} Pa, τ_b_0={lifetime:.3e} s"
    )
# =========================================================================
# Property Accessors
# =========================================================================
@property
def n_species(self) -> int:
    """Number of bond species N, as stored at construction."""
    species_count = self._n_species
    return species_count
def _get_mode_arrays(self) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Collect the per-species moduli and lifetimes as JAX arrays.

    A parameter whose stored value is ``None`` contributes 0.0.

    Returns
    -------
    tuple[jnp.ndarray, jnp.ndarray]
        (G_modes, tau_modes) with shape (N,)
    """

    def _read(prefix: str) -> list[float]:
        # One lookup per species, in species order
        values = []
        for idx in range(self._n_species):
            stored = self.parameters.get_value(f"{prefix}_{idx}")
            values.append(0.0 if stored is None else float(stored))
        return values

    moduli = jnp.array(_read("G"), dtype=jnp.float64)
    lifetimes = jnp.array(_read("tau_b"), dtype=jnp.float64)
    return moduli, lifetimes
@property
def eta_s(self) -> float:
    """Solvent viscosity η_s (Pa·s); 0.0 when the parameter has no value."""
    stored = self.parameters.get_value("eta_s")
    if stored is None:
        return 0.0
    return float(stored)
@property
def G_total(self) -> float:
    """Total modulus G_total = Σ G_i (Pa), summed over all species."""
    moduli = self._get_mode_arrays()[0]
    total = jnp.sum(moduli)
    return float(total)
@property
def eta_0(self) -> float:
    """Zero-shear viscosity η₀ = Σ G_i·τ_b_i + η_s (Pa·s)."""
    moduli, lifetimes = self._get_mode_arrays()
    network_viscosity = jnp.sum(moduli * lifetimes)
    return float(network_viscosity + self.eta_s)
# =========================================================================
# Equilibrium State
# =========================================================================
# =========================================================================
# Core Interface Methods
# =========================================================================
def _fit(self, x, y, **kwargs):
    """Fit the model to data with protocol-aware initialization.

    The test mode is taken from ``kwargs["test_mode"]``, else from the mode
    stored on the instance, else ``"flow_curve"``; the resolved mode is
    stored back on the instance.

    Parameters
    ----------
    x : array-like
        Independent variable (shear rate, frequency, or time)
    y : array-like
        Dependent variable (stress, modulus, or strain); complex input is
        kept complex
    **kwargs
        Optional ``test_mode``, protocol inputs (``gamma_dot``,
        ``sigma_applied``, ``gamma_0``, ``omega``) and optimizer options
        (``use_log_residuals``, ``use_jax``, ``method``, ``max_iter``)

    Returns
    -------
    self
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    requested_mode = kwargs.get("test_mode")
    if requested_mode is not None:
        test_mode = requested_mode
    else:
        remembered_mode = getattr(self, "_test_mode", None)
        test_mode = remembered_mode if remembered_mode is not None else "flow_curve"
    self._test_mode = test_mode

    x_jax = jnp.asarray(x, dtype=jnp.float64)

    # Complex targets (G* = G' + iG'') must not be cast to real
    y_np = np.asarray(y)
    target_dtype = jnp.complex128 if np.iscomplexobj(y_np) else jnp.float64
    y_jax = jnp.asarray(y_np, dtype=target_dtype)

    # Record the protocol inputs; anything not supplied is reset to None
    for attr_name, key in (
        ("_gamma_dot_applied", "gamma_dot"),
        ("_sigma_applied", "sigma_applied"),
        ("_gamma_0", "gamma_0"),
        ("_omega_laos", "omega"),
    ):
        setattr(self, attr_name, kwargs.get(key))

    # Protocol-specific starting values; other modes keep the current ones
    x_np = np.asarray(x)
    if test_mode in ("flow_curve", "steady_shear", "rotation"):
        self.initialize_from_flow_curve(x_np, y_np)
    elif test_mode == "oscillation":
        self.initialize_from_saos(x_np, np.real(y_np), np.imag(y_np))
    elif test_mode == "creep":
        self.initialize_from_creep(
            x_np,
            y_np,
            sigma_applied=kwargs.get("sigma_applied", 1.0),
        )
    elif test_mode == "relaxation":
        self.initialize_from_relaxation(x_np, y_np)

    def model_fn(x_fit, params):
        # Stateless prediction pinned to the resolved test mode
        return self.model_function(x_fit, params, test_mode=test_mode)

    # NOTE(review): the log-residual default only matches "flow_curve",
    # not the "steady_shear"/"rotation" aliases above - confirm intent.
    log_residuals = kwargs.get("use_log_residuals", test_mode == "flow_curve")
    objective = create_least_squares_objective(
        model_fn,
        x_jax,
        y_jax,
        use_log_residuals=log_residuals,
    )

    # ODE-based protocols use diffrax with custom_vjp, incompatible with
    # NLSQ forward-mode AD, so they default to scipy.
    default_method = "scipy" if test_mode in {"startup", "creep", "laos"} else "auto"
    result = nlsq_optimize(
        objective,
        self.parameters,
        use_jax=kwargs.get("use_jax", True),
        method=kwargs.get("method", default_method),
        max_iter=kwargs.get("max_iter", 2000),
    )
    self.fitted_ = True
    self._nlsq_result = result

    G_modes, tau_modes = self._get_mode_arrays()
    logger.info(
        f"Fitted TNTMultiSpecies (N={self._n_species}): "
        f"G_total={float(jnp.sum(G_modes)):.2e} Pa, "
        f"tau_range=[{float(jnp.min(tau_modes)):.2e}, {float(jnp.max(tau_modes)):.2e}] s, "
        f"η_s={self.eta_s:.2e} Pa·s"
    )
    return self
def _predict(self, x, **kwargs):
    """Predict the response for the active protocol.

    The test mode is taken from ``kwargs["test_mode"]``, else from the mode
    stored on the instance, else ``"flow_curve"``. Protocol inputs present
    in ``kwargs`` are stored on the instance; absent ones are left as is.

    Parameters
    ----------
    x : array-like
        Independent variable
    **kwargs
        Additional arguments including test_mode, gamma_dot, sigma_applied,
        gamma_0, omega

    Returns
    -------
    jnp.ndarray
        Predicted response; complex G* = G' + iG'' for oscillation
    """
    requested_mode = kwargs.get("test_mode")
    if requested_mode is not None:
        test_mode = requested_mode
    else:
        remembered_mode = getattr(self, "_test_mode", None)
        test_mode = remembered_mode if remembered_mode is not None else "flow_curve"

    x_jax = jnp.asarray(x, dtype=jnp.float64)

    # Store only the protocol inputs the caller actually passed
    for key, attr_name in (
        ("gamma_dot", "_gamma_dot_applied"),
        ("sigma_applied", "_sigma_applied"),
        ("gamma_0", "_gamma_0"),
        ("omega", "_omega_laos"),
    ):
        if key in kwargs:
            setattr(self, attr_name, kwargs[key])

    # Flat parameter vector in ParameterSet order (G_0, tau_b_0, ..., eta_s)
    names = self.parameters.keys()
    params = jnp.array([float(self.parameters.get_value(name)) for name in names])

    result = self.model_function(x_jax, params, test_mode=test_mode)

    # Oscillation comes back as two columns [G', G'']; fold into complex G*
    has_moduli_columns = (
        test_mode == "oscillation" and result.ndim == 2 and result.shape[1] == 2
    )
    if has_moduli_columns:
        return result[:, 0] + 1j * result[:, 1]
    return result
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
    """Stateless prediction used for NLSQ optimization and Bayesian inference.

    Splits the flat parameter vector into per-species arrays and dispatches
    on the test mode. Protocol inputs (gamma_dot, sigma_applied, gamma_0,
    omega) are read from ``kwargs`` when given, otherwise from the values
    stored on the instance.

    Parameters
    ----------
    X : array-like
        Independent variable
    params : array-like
        Parameter values: [G_0, tau_b_0, G_1, tau_b_1, ..., G_{N-1},
        tau_b_{N-1}, eta_s]
        Total length: 2*N + 1
    test_mode : str, optional
        Override stored test mode; falls back to "flow_curve" when neither
        is set

    Returns
    -------
    jnp.ndarray
        Predicted response. Oscillation returns two columns [G', G''].

    Raises
    ------
    ValueError
        If a protocol input required by the selected mode is missing.
    """
    n_modes = self._n_species
    # Moduli sit at even slots and lifetimes at odd slots; eta_s is last
    G_modes = params[0 : 2 * n_modes : 2]
    tau_modes = params[1 : 2 * n_modes : 2]
    eta_s = params[2 * n_modes]

    if test_mode is not None:
        mode = test_mode
    else:
        remembered_mode = getattr(self, "_test_mode", None)
        mode = remembered_mode if remembered_mode is not None else "flow_curve"

    def _protocol_input(key, attr_name):
        # The sentinel keeps explicit falsy values such as gamma_dot=0.0
        supplied = kwargs.get(key, _MISSING)
        if supplied is _MISSING:
            return getattr(self, attr_name, None)
        return supplied

    gamma_dot = _protocol_input("gamma_dot", "_gamma_dot_applied")
    sigma_applied = _protocol_input("sigma_applied", "_sigma_applied")
    gamma_0 = _protocol_input("gamma_0", "_gamma_0")
    omega = _protocol_input("omega", "_omega_laos")

    X_jax = jnp.asarray(X, dtype=jnp.float64)

    if mode in ("flow_curve", "steady_shear", "rotation"):
        return self._predict_flow_curve_internal(X_jax, G_modes, tau_modes, eta_s)

    if mode == "oscillation":
        storage, loss = tnt_multimode_saos_moduli_vec(
            X_jax, G_modes, tau_modes, eta_s
        )
        return jnp.column_stack([storage, loss])

    if mode == "startup":
        if gamma_dot is None:
            raise ValueError("startup mode requires gamma_dot")
        return self._simulate_startup_internal(
            X_jax, G_modes, tau_modes, eta_s, gamma_dot
        )

    if mode == "relaxation":
        if gamma_dot is None:
            raise ValueError("relaxation mode requires gamma_dot (pre-shear rate)")
        return self._simulate_relaxation_internal(
            X_jax, G_modes, tau_modes, eta_s, gamma_dot
        )

    if mode == "creep":
        if sigma_applied is None:
            raise ValueError("creep mode requires sigma_applied")
        return self._simulate_creep_internal(
            X_jax, G_modes, tau_modes, eta_s, sigma_applied
        )

    if mode == "laos":
        if gamma_0 is None or omega is None:
            raise ValueError("LAOS mode requires gamma_0 and omega")
        # Only the stress half of (strain, stress) is the model output
        return self._simulate_laos_internal(
            X_jax, G_modes, tau_modes, eta_s, gamma_0, omega
        )[1]

    # Unrecognized modes are treated as a flow curve, with a warning
    logger.warning(f"Unknown test_mode '{mode}', defaulting to flow_curve")
    return self._predict_flow_curve_internal(X_jax, G_modes, tau_modes, eta_s)
# =========================================================================
# Analytical Predictions
# =========================================================================
def _predict_flow_curve_internal(
    self,
    gamma_dot: jnp.ndarray,
    G_modes: jnp.ndarray,
    tau_modes: jnp.ndarray,
    eta_s: float,
) -> jnp.ndarray:
    """Steady shear stress of the multi-species network (closed form).

    σ = (Σ G_i·τ_b_i + η_s)·γ̇ = η₀·γ̇

    With a constant breakage rate (UCM) the steady conformation tensor has
    S_xy = τ·γ̇, so every species contributes G·τ·γ̇ and the response is
    Newtonian (no shear thinning).

    Parameters
    ----------
    gamma_dot : jnp.ndarray
        Shear rate array (1/s)
    G_modes : jnp.ndarray
        Mode moduli, shape (N,)
    tau_modes : jnp.ndarray
        Mode relaxation times, shape (N,)
    eta_s : float
        Solvent viscosity (Pa·s)

    Returns
    -------
    jnp.ndarray
        Shear stress σ (Pa)
    """
    network_viscosity = jnp.sum(G_modes * tau_modes)
    zero_shear_viscosity = network_viscosity + eta_s
    return zero_shear_viscosity * gamma_dot
[docs]
def predict_flow_curve(
    self,
    gamma_dot: np.ndarray,
    return_components: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Predict steady shear stress, viscosity and first normal stress difference.

    Multi-mode UCM (constant breakage rate) results:

        σ  = (Σ G_i·τ_b_i + η_s)·γ̇ = η₀·γ̇
        η  = σ / γ̇ (taken as η₀ where γ̇ = 0)
        N₁ = Σ 2·G_i·(τ_b_i·γ̇)²

    These agree with ``_predict_flow_curve_internal`` and
    ``predict_normal_stresses``.

    Parameters
    ----------
    gamma_dot : np.ndarray
        Shear rate array (1/s); must be 1-D when return_components=True
    return_components : bool, default False
        If True, return (sigma, eta, N1)

    Returns
    -------
    np.ndarray or tuple
        Shear stress σ (Pa), or (σ, η, N₁) if return_components=True
    """
    gd = jnp.asarray(gamma_dot, dtype=jnp.float64)
    G_modes, tau_modes = self._get_mode_arrays()
    sigma = self._predict_flow_curve_internal(gd, G_modes, tau_modes, self.eta_s)
    if not return_components:
        return np.asarray(sigma)

    # Viscosity: divide only where the rate is non-zero, so zero rates give
    # the zero-shear value and negative rates are not clamped to a tiny
    # positive divisor.
    eta_0 = jnp.sum(G_modes * tau_modes) + self.eta_s
    has_rate = gd != 0.0
    safe_rate = jnp.where(has_rate, gd, 1.0)
    eta = jnp.where(has_rate, sigma / safe_rate, eta_0)

    # N1 from the UCM steady state S_xx - S_yy = 2·(τγ̇)² per mode
    wi = tau_modes[:, None] * gd[None, :]  # (N, M) Weissenberg numbers
    N1 = jnp.sum(2.0 * G_modes[:, None] * wi * wi, axis=0)
    return np.asarray(sigma), np.asarray(eta), np.asarray(N1)
[docs]
def predict_saos(
    self,
    omega: np.ndarray,
    return_components: bool = True,
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
    """Predict small-amplitude oscillatory shear moduli.

    The moduli come from ``tnt_multimode_saos_moduli_vec``, which is
    expected to evaluate the Maxwell superposition:

        G'(ω)  = Σ G_i·(ωτ_i)² / (1 + (ωτ_i)²)
        G''(ω) = Σ G_i·(ωτ_i) / (1 + (ωτ_i)²) + η_s·ω

    Parameters
    ----------
    omega : np.ndarray
        Angular frequency array (rad/s)
    return_components : bool, default True
        If True, return (G', G'')

    Returns
    -------
    tuple or np.ndarray
        (G', G'') if return_components=True, else |G*|
    """
    frequencies = jnp.asarray(omega, dtype=jnp.float64)
    moduli, lifetimes = self._get_mode_arrays()
    storage, loss = tnt_multimode_saos_moduli_vec(
        frequencies, moduli, lifetimes, self.eta_s
    )
    if not return_components:
        # The 1e-30 floor keeps the square root away from an exact zero
        magnitude = jnp.sqrt(jnp.maximum(storage**2 + loss**2, 1e-30))
        return np.asarray(magnitude)
    return np.asarray(storage), np.asarray(loss)
[docs]
def predict_normal_stresses(
    self,
    gamma_dot: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Predict the first and second normal stress differences in steady shear.

    Multi-mode UCM (conformation tensor) results:

        N₁ = Σ 2·G_i·(τ_b_i·γ̇)²
        N₂ = 0 (upper-convected derivative)

    Parameters
    ----------
    gamma_dot : np.ndarray
        Shear rate array (1/s)

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (N₁, N₂) in Pa
    """
    rates = jnp.asarray(gamma_dot, dtype=jnp.float64)
    moduli, lifetimes = self._get_mode_arrays()

    # Per-mode Weissenberg number squared, shape (N, M);
    # UCM conformation gives S_xx - S_yy = 2(τγ̇)²
    weissenberg = lifetimes[:, None] * rates[None, :]
    weissenberg_sq = weissenberg * weissenberg
    first_difference = jnp.sum(2.0 * moduli[:, None] * weissenberg_sq, axis=0)
    second_difference = jnp.zeros_like(first_difference)
    return np.asarray(first_difference), np.asarray(second_difference)
# =========================================================================
# ODE-Based Internal Simulations (for model_function)
# =========================================================================
def _simulate_startup_internal(
    self,
    t: jnp.ndarray,
    G_modes: jnp.ndarray,
    tau_modes: jnp.ndarray,
    eta_s: float,
    gamma_dot: float,
) -> jnp.ndarray:
    """Integrate startup of steady shear for ``model_function``.

    Integrates the multi-mode conformation ODE from the state returned by
    ``get_equilibrium_conformation_multimode()`` at a constant shear rate and
    returns σ_xy(t) = Σ G_i·S_xy_i + η_s·γ̇.

    Parameters
    ----------
    t : jnp.ndarray
        Time array (s); the solution is saved at these times
    G_modes : jnp.ndarray
        Mode moduli (Pa), shape (N,)
    tau_modes : jnp.ndarray
        Mode relaxation times (s), shape (N,)
    eta_s : float
        Solvent viscosity (Pa·s)
    gamma_dot : float
        Applied shear rate (1/s)

    Returns
    -------
    jnp.ndarray
        Shear stress σ(t) (Pa); all NaN if the solver did not succeed
    """

    def rhs(time, state, params):
        # Every mode sees the same constant applied rate
        return tnt_multimode_ode_rhs(
            time, state, params["gamma_dot"], params["G_modes"], params["tau_modes"]
        )

    solver_args = {
        "gamma_dot": gamma_dot,
        "G_modes": G_modes,
        "tau_modes": tau_modes,
    }
    initial_state = self.get_equilibrium_conformation_multimode()

    t_start, t_end = t[0], t[-1]
    # First step guess: the span split into at least 1000 pieces
    first_step = (t_end - t_start) / max(len(t), 1000)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_end,
        first_step,
        initial_state,
        args=solver_args,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )

    # S_xy of mode i is stored at state index 4*i + 3
    shear_conformation = solution.ys[:, 3::4]  # (T, N)
    total_stress = (
        jnp.sum(G_modes[None, :] * shear_conformation, axis=1) + eta_s * gamma_dot
    )

    # throw=False: a failed solve is reported as NaN instead of raising
    solved = solution.result == diffrax.RESULTS.successful
    return jnp.where(solved, total_stress, jnp.full_like(total_stress, jnp.nan))
def _simulate_relaxation_internal(
    self,
    t: jnp.ndarray,
    G_modes: jnp.ndarray,
    tau_modes: jnp.ndarray,
    eta_s: float,
    gamma_dot_preshear: float,
) -> jnp.ndarray:
    """Stress relaxation after cessation of steady shear, for ``model_function``.

    Analytical multi-mode relaxation:

        σ(t) = Σ σ₀_i·exp(-t/τ_b_i),   σ₀_i = G_i·τ_b_i·γ̇

    σ₀_i is the steady-shear stress of species i for the constant-breakage
    (UCM) model, where S_xy_i = τ_b_i·γ̇. This matches the stress reported
    by ``_predict_flow_curve_internal``.

    Parameters
    ----------
    t : jnp.ndarray
        Time array (s)
    G_modes : jnp.ndarray
        Mode moduli, shape (N,)
    tau_modes : jnp.ndarray
        Mode relaxation times, shape (N,)
    eta_s : float
        Solvent viscosity (Pa·s). Unused: the solvent stress η_s·γ̇ vanishes
        once the flow stops.
    gamma_dot_preshear : float
        Pre-shear rate (1/s)

    Returns
    -------
    jnp.ndarray
        Relaxing stress σ(t) (Pa)
    """
    # Per-mode stress at the moment of cessation (UCM steady state)
    sigma_0_modes = G_modes * tau_modes * gamma_dot_preshear
    return tnt_multimode_relaxation_vec(t, sigma_0_modes, tau_modes)
def _simulate_creep_internal(
    self,
    t: jnp.ndarray,
    G_modes: jnp.ndarray,
    tau_modes: jnp.ndarray,
    eta_s: float,
    sigma_applied: float,
) -> jnp.ndarray:
    """Integrate creep under constant stress for ``model_function``.

    State: [S_xx_0, S_yy_0, S_zz_0, S_xy_0, ..., S_xy_{N-1}, γ] = 4N+1

    The applied stress is held constant,

        σ = Σ G_i·S_xy_i + η_s·γ̇

    so the shear rate follows from the stress balance:

        γ̇ = (σ - Σ G_i·S_xy_i) / η_s

    η_s is floored at 1e-10·max(G_i·τ_b_i) to keep the division finite.

    Parameters
    ----------
    t : jnp.ndarray
        Time array (s); the solution is saved at these times
    G_modes : jnp.ndarray
        Mode moduli, shape (N,)
    tau_modes : jnp.ndarray
        Mode relaxation times, shape (N,)
    eta_s : float
        Solvent viscosity (Pa·s)
    sigma_applied : float
        Applied constant stress (Pa)

    Returns
    -------
    jnp.ndarray
        Strain γ(t); all NaN if the solver did not succeed
    """

    def rhs(time, state, params):
        moduli = params["G_modes"]
        lifetimes = params["tau_modes"]
        n_conformation = 4 * moduli.shape[0]
        conformation = state[:n_conformation]

        # Elastic stress carried by the network (S_xy at every 4th slot)
        elastic_stress = jnp.sum(moduli * conformation[3::4])

        # Whatever stress the network does not carry is taken up by the
        # (regularized) solvent, which sets the instantaneous shear rate
        solvent_viscosity = jnp.maximum(
            params["eta_s"], 1e-10 * jnp.max(moduli * lifetimes)
        )
        shear_rate = (params["sigma_applied"] - elastic_stress) / solvent_viscosity

        conformation_rate = tnt_multimode_ode_rhs(
            time, conformation, shear_rate, moduli, lifetimes
        )
        # dγ/dt is the shear rate itself
        return jnp.concatenate([conformation_rate, jnp.array([shear_rate])])

    solver_args = {
        "sigma_applied": sigma_applied,
        "G_modes": G_modes,
        "tau_modes": tau_modes,
        "eta_s": eta_s,
    }

    # Equilibrium conformation for every mode, with zero strain appended
    initial_state = jnp.concatenate(
        [self.get_equilibrium_conformation_multimode(), jnp.array([0.0])]
    )

    t_start, t_end = t[0], t[-1]
    first_step = (t_end - t_start) / max(len(t), 1000)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_end,
        first_step,
        initial_state,
        args=solver_args,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )

    strain = solution.ys[:, -1]  # γ is the last state component
    solved = solution.result == diffrax.RESULTS.successful
    return jnp.where(solved, strain, jnp.full_like(strain, jnp.nan))
def _simulate_laos_internal(
    self,
    t: jnp.ndarray,
    G_modes: jnp.ndarray,
    tau_modes: jnp.ndarray,
    eta_s: float,
    gamma_0: float,
    omega: float,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Integrate large-amplitude oscillatory shear for ``model_function``.

    State: [S_xx_0, S_yy_0, S_zz_0, S_xy_0, ..., S_xy_{N-1}] = 4N

    Imposed deformation: γ(t) = γ₀·sin(ωt), γ̇(t) = γ₀·ω·cos(ωt)

    Parameters
    ----------
    t : jnp.ndarray
        Time array (s); the solution is saved at these times
    G_modes : jnp.ndarray
        Mode moduli, shape (N,)
    tau_modes : jnp.ndarray
        Mode relaxation times, shape (N,)
    eta_s : float
        Solvent viscosity (Pa·s)
    gamma_0 : float
        Strain amplitude
    omega : float
        Angular frequency (rad/s)

    Returns
    -------
    tuple[jnp.ndarray, jnp.ndarray]
        (strain, stress) arrays; both all NaN if the solver did not succeed
    """

    def rhs(time, state, params):
        # Instantaneous rate of the sinusoidal strain
        rate_now = params["gamma_0"] * params["omega"] * jnp.cos(params["omega"] * time)
        return tnt_multimode_ode_rhs(
            time, state, rate_now, params["G_modes"], params["tau_modes"]
        )

    solver_args = {
        "gamma_0": gamma_0,
        "omega": omega,
        "G_modes": G_modes,
        "tau_modes": tau_modes,
    }
    initial_state = self.get_equilibrium_conformation_multimode()

    t_start, t_end = t[0], t[-1]
    first_step = (t_end - t_start) / max(len(t), 1000)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_end,
        first_step,
        initial_state,
        args=solver_args,
        saveat=diffrax.SaveAt(ts=t),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )

    # S_xy of mode i is stored at state index 4*i + 3
    shear_conformation = solution.ys[:, 3::4]  # (T, N)

    # Strain and rate are known in closed form; stress adds the solvent term
    phase = omega * t
    strain = gamma_0 * jnp.sin(phase)
    rate = gamma_0 * omega * jnp.cos(phase)
    stress = jnp.sum(G_modes[None, :] * shear_conformation, axis=1) + eta_s * rate

    # A failed solve is reported as NaN in both outputs
    solved = solution.result == diffrax.RESULTS.successful
    strain = jnp.where(solved, strain, jnp.full_like(strain, jnp.nan))
    stress = jnp.where(solved, stress, jnp.full_like(stress, jnp.nan))
    return strain, stress
# =========================================================================
# Public Simulation Methods (return numpy arrays)
# =========================================================================
[docs]
def simulate_startup(
    self,
    t: np.ndarray,
    gamma_dot: float,
    return_full: bool = False,
) -> np.ndarray | dict[str, np.ndarray]:
    """Simulate startup of steady shear with the current parameters.

    The per-mode conformation history is stored on ``self._trajectory``
    (keys 't', 'S_xx', 'S_yy', 'S_zz', 'S_xy').

    Parameters
    ----------
    t : np.ndarray
        Time array (s)
    gamma_dot : float
        Applied shear rate (1/s)
    return_full : bool, default False
        If True, return full conformation tensor for all modes

    Returns
    -------
    np.ndarray or dict
        Shear stress σ(t), or dict with 'S_xx', 'S_yy', 'S_xy', 'S_zz'
        (each shape (T, N)) if return_full=True. Values are NaN if the
        solver did not succeed.
    """
    times = jnp.asarray(t, dtype=jnp.float64)
    G_modes, tau_modes = self._get_mode_arrays()

    def rhs(time, state, params):
        return tnt_multimode_ode_rhs(
            time, state, params["gamma_dot"], params["G_modes"], params["tau_modes"]
        )

    solver_args = {
        "gamma_dot": gamma_dot,
        "G_modes": G_modes,
        "tau_modes": tau_modes,
    }
    initial_state = self.get_equilibrium_conformation_multimode()

    t_start, t_end = times[0], times[-1]
    first_step = (t_end - t_start) / max(len(t), 1000)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_end,
        first_step,
        initial_state,
        args=solver_args,
        saveat=diffrax.SaveAt(ts=times),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )
    solved = solution.result == diffrax.RESULTS.successful

    def component(offset):
        # Each mode owns 4 consecutive state slots: S_xx, S_yy, S_zz, S_xy.
        # A failed solve is reported as NaN.
        values = solution.ys[:, offset::4]  # (T, N)
        return jnp.where(solved, values, jnp.full_like(values, jnp.nan))

    S_xx_modes = component(0)
    S_yy_modes = component(1)
    S_zz_modes = component(2)
    S_xy_modes = component(3)

    self._trajectory = {
        "t": np.asarray(times),
        "S_xx": np.asarray(S_xx_modes),
        "S_yy": np.asarray(S_yy_modes),
        "S_zz": np.asarray(S_zz_modes),
        "S_xy": np.asarray(S_xy_modes),
    }

    if return_full:
        return {
            "S_xx": np.asarray(S_xx_modes),
            "S_yy": np.asarray(S_yy_modes),
            "S_xy": np.asarray(S_xy_modes),
            "S_zz": np.asarray(S_zz_modes),
        }

    # Total stress: σ = Σ G_i·S_xy_i + η_s·γ̇
    elastic_stress = jnp.sum(G_modes[None, :] * S_xy_modes, axis=1)
    return np.asarray(elastic_stress + self.eta_s * gamma_dot)
[docs]
def simulate_relaxation(
    self,
    t: np.ndarray,
    gamma_dot_preshear: float,
) -> np.ndarray:
    """Simulate stress relaxation after cessation of steady shear.

    Analytical multi-mode relaxation:

        σ(t) = Σ σ₀_i·exp(-t/τ_b_i),   σ₀_i = G_i·τ_b_i·γ̇

    σ₀_i is the steady-shear stress of species i for the constant-breakage
    (UCM) model, where S_xy_i = τ_b_i·γ̇. The result is also stored on
    ``self._trajectory`` (keys 't', 'sigma').

    Parameters
    ----------
    t : np.ndarray
        Time array (s), starting from t=0 (cessation)
    gamma_dot_preshear : float
        Shear rate before cessation (1/s)

    Returns
    -------
    np.ndarray
        Relaxing stress σ(t) (Pa)
    """
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    G_modes, tau_modes = self._get_mode_arrays()

    # Per-mode stress at the moment of cessation (UCM steady state)
    sigma_0_modes = G_modes * tau_modes * gamma_dot_preshear
    sigma = np.asarray(tnt_multimode_relaxation_vec(t_jax, sigma_0_modes, tau_modes))

    self._trajectory = {
        "t": np.asarray(t_jax),
        "sigma": sigma,
    }
    return sigma
[docs]
def simulate_creep(
    self,
    t: np.ndarray,
    sigma_applied: float,
    return_rate: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
    """Simulate creep under constant stress with the current parameters.

    The shear rate follows from the stress balance
    γ̇ = (σ - Σ G_i·S_xy_i) / η_s, with η_s floored at 1e-10·max(G_i·τ_b_i).
    The strain history is stored on ``self._trajectory`` (keys 't', 'gamma').

    Parameters
    ----------
    t : np.ndarray
        Time array (s)
    sigma_applied : float
        Applied constant stress (Pa)
    return_rate : bool, default False
        If True, also return shear rate γ̇(t)

    Returns
    -------
    np.ndarray or tuple
        Strain γ(t), or (γ, γ̇) if return_rate=True. Values are NaN if the
        solver did not succeed.
    """
    times = jnp.asarray(t, dtype=jnp.float64)
    G_modes, tau_modes = self._get_mode_arrays()

    def rhs(time, state, params):
        moduli = params["G_modes"]
        lifetimes = params["tau_modes"]
        conformation = state[: 4 * moduli.shape[0]]

        # Stress carried by the network (S_xy at every 4th slot)
        elastic_stress = jnp.sum(moduli * conformation[3::4])

        # The remainder is carried by the (regularized) solvent
        solvent_viscosity = jnp.maximum(
            params["eta_s"], 1e-10 * jnp.max(moduli * lifetimes)
        )
        shear_rate = (params["sigma_applied"] - elastic_stress) / solvent_viscosity

        conformation_rate = tnt_multimode_ode_rhs(
            time, conformation, shear_rate, moduli, lifetimes
        )
        return jnp.concatenate([conformation_rate, jnp.array([shear_rate])])

    solver_args = {
        "sigma_applied": sigma_applied,
        "G_modes": G_modes,
        "tau_modes": tau_modes,
        "eta_s": self.eta_s,
    }

    # Equilibrium conformation for every mode, with zero strain appended
    initial_state = jnp.concatenate(
        [self.get_equilibrium_conformation_multimode(), jnp.array([0.0])]
    )

    t_start, t_end = times[0], times[-1]
    first_step = (t_end - t_start) / max(len(t), 1000)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(rhs),
        diffrax.Tsit5(),
        t_start,
        t_end,
        first_step,
        initial_state,
        args=solver_args,
        saveat=diffrax.SaveAt(ts=times),
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
        max_steps=500_000,
        throw=False,
    )
    solved = solution.result == diffrax.RESULTS.successful

    strain_history = solution.ys[:, -1]  # γ is the last state component
    strain_history = jnp.where(
        solved, strain_history, jnp.full_like(strain_history, jnp.nan)
    )
    gamma = np.asarray(strain_history)

    self._trajectory = {
        "t": np.asarray(times),
        "gamma": gamma,
    }

    if not return_rate:
        return gamma

    # Rebuild the rate from the saved conformation via the same stress balance
    shear_conformation = solution.ys[:, 3::4]
    shear_conformation = jnp.where(
        solved, shear_conformation, jnp.full_like(shear_conformation, jnp.nan)
    )
    elastic_history = jnp.sum(G_modes[None, :] * shear_conformation, axis=1)
    viscosity_floor = 1e-10 * float(jnp.max(G_modes * tau_modes))
    solvent_viscosity = max(self.eta_s, viscosity_floor)
    rate_history = (sigma_applied - elastic_history) / solvent_viscosity
    return gamma, np.asarray(rate_history)
[docs]
def simulate_laos(
    self,
    t: np.ndarray,
    gamma_0: float,
    omega: float,
    n_cycles: int | None = None,
) -> dict[str, np.ndarray]:
    """Simulate Large-Amplitude Oscillatory Shear (LAOS).

    The result is also stored on ``self._trajectory``.

    Parameters
    ----------
    t : np.ndarray
        Time array (s). May be None only when ``n_cycles`` is given, in
        which case the time grid is generated automatically.
    gamma_0 : float
        Strain amplitude (dimensionless)
    omega : float
        Angular frequency (rad/s)
    n_cycles : int, optional
        Number of oscillation cycles (overrides t); the grid spans
        ``n_cycles`` periods with 200 points per cycle

    Returns
    -------
    dict
        Dictionary with keys: 't', 'strain', 'stress', 'strain_rate'

    Raises
    ------
    TypeError
        If neither ``t`` nor ``n_cycles`` is provided.
    """
    if n_cycles is not None:
        period = 2 * np.pi / omega
        t = np.linspace(0, n_cycles * period, n_cycles * 200)
    elif t is None:
        # Without this guard the failure surfaces inside the array
        # conversion with an unhelpful message.
        raise TypeError(
            "simulate_laos() needs a time array `t` or a cycle count `n_cycles`"
        )

    t_jax = jnp.asarray(t, dtype=jnp.float64)
    G_modes, tau_modes = self._get_mode_arrays()
    strain, stress = self._simulate_laos_internal(
        t_jax, G_modes, tau_modes, self.eta_s, gamma_0, omega
    )
    strain_rate = gamma_0 * omega * jnp.cos(omega * t_jax)

    result = {
        "t": np.asarray(t_jax),
        "strain": np.asarray(strain),
        "stress": np.asarray(stress),
        "strain_rate": np.asarray(strain_rate),
    }
    # Keep a separate dict so edits to the returned mapping's keys do not
    # alter the stored trajectory.
    self._trajectory = dict(result)
    return result
# =========================================================================
# Analysis Methods
# =========================================================================
[docs]
def get_relaxation_spectrum(
    self,
    t: np.ndarray | None = None,
    n_points: int = 100,
) -> tuple[np.ndarray, np.ndarray]:
    """Evaluate the linear relaxation modulus G(t).

    For multi-species TNT: G(t) = Σ G_i·exp(-t/τ_b_i)

    Parameters
    ----------
    t : np.ndarray, optional
        Time array (default: logspace from 0.01·min(τ) to 100·max(τ))
    n_points : int, default 100
        Number of points if t not provided

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        (t, G(t)); ``t`` is handed back as supplied (or as generated)
    """
    moduli, lifetimes = self._get_mode_arrays()

    if t is None:
        # Cover two decades either side of the lifetime spectrum
        shortest = 0.01 * float(jnp.min(lifetimes))
        longest = 100 * float(jnp.max(lifetimes))
        t = np.logspace(np.log10(shortest), np.log10(longest), n_points)

    times = jnp.asarray(t, dtype=jnp.float64)

    # One exponential decay per species, shape (N, T), weighted and summed
    decay = jnp.exp(-times[None, :] / lifetimes[:, None])
    modulus = jnp.sum(moduli[:, None] * decay, axis=0)
    return t, np.asarray(modulus)
# =========================================================================
# String Representation
# =========================================================================
[docs]
def __repr__(self) -> str:
    """Summarize species count, total modulus, lifetime range and η_s."""
    moduli, lifetimes = self._get_mode_arrays()
    total_modulus = float(jnp.sum(moduli))
    shortest = float(jnp.min(lifetimes))
    longest = float(jnp.max(lifetimes))
    pieces = [
        f"TNTMultiSpecies(n_species={self._n_species}, ",
        f"G_total={total_modulus:.2e} Pa, ",
        f"tau_range=[{shortest:.2e}, {longest:.2e}] s, ",
        f"η_s={self.eta_s:.2e} Pa·s)",
    ]
    return "".join(pieces)