"""HVMLocal: Local (0D) Hybrid Vitrimer Model.
Single-point constitutive model for vitrimers with three subnetworks:
1. Permanent (P): covalent crosslinks, neo-Hookean elastic
2. Exchangeable (E): associative vitrimer bonds with BER kinetics
3. Dissociative (D): physical reversible bonds, standard Maxwell
Supports 6 rheological protocols:
- Flow curve (analytical: sigma_E → 0 at steady state)
- SAOS (analytical: two Maxwell modes + permanent plateau)
- Startup shear (analytical or ODE with TST feedback)
- Stress relaxation (analytical or ODE with TST feedback)
- Creep (ODE: implicit gamma_dot solve)
- LAOS (ODE: nonlinear oscillatory response)
Limiting Cases
--------------
- G_E=0, G_D=0: Neo-Hookean elastic solid
- G_P=0, G_E=0: Single Maxwell fluid (VLB)
- G_E=0: Zener/SLS (permanent spring in parallel with one Maxwell element)
- G_D=0, G_P=0: Pure vitrimer
- G_D=0: Partial vitrimer (Meng et al. 2019)
- Full: 3-network HVM
References
----------
- Vernerey, Long, & Brighenti (2017). JMPS 107, 1-20.
- Meng, Simon, Niu, McKenna, & Hallinan (2019). Macromolecules 52, 8.
"""
from __future__ import annotations
import logging
from typing import Literal
import numpy as np
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.hvm._base import HVMBase
from rheojax.models.hvm._kernels import (
hvm_ber_rate_constant,
hvm_creep_compliance_linear_vec,
hvm_normal_stress_1,
hvm_relaxation_modulus_vec,
hvm_saos_moduli_vec,
hvm_startup_stress_linear_vec,
hvm_steady_shear_stress_vec,
hvm_total_stress_shear,
)
from rheojax.models.hvm._kernels_diffrax import (
_mask_failed_solution_ys,
hvm_solve_creep,
hvm_solve_laos,
hvm_solve_relaxation,
hvm_solve_startup,
)
jax, jnp = safe_import_jax()
logger = logging.getLogger(__name__)
_MISSING = object()
[docs]
@ModelRegistry.register(
    "hvm_local",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.OSCILLATION,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
# "hvm" is a second registry name for the same class; it declares exactly the
# same protocols and deformation modes as "hvm_local" above.
@ModelRegistry.register(
    "hvm",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.OSCILLATION,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class HVMLocal(HVMBase):
    """Local (0D) Hybrid Vitrimer Model.

    A constitutive model for vitrimers combining:

    - Permanent network (P): covalent crosslinks, elastic
    - Exchangeable network (E): vitrimer bonds with TST kinetics
    - Dissociative network (D): physical bonds, Maxwell relaxation

    The class is registered in ``ModelRegistry`` under two names,
    ``"hvm_local"`` and ``"hvm"``, both declaring the flow-curve,
    oscillation, startup, relaxation, creep and LAOS protocols.

    Parameters
    ----------
    kinetics : {'stress', 'stretch'}, default 'stress'
        TST coupling mechanism for bond exchange rate
    include_damage : bool, default False
        Enable cooperative damage shielding
    include_dissociative : bool, default True
        Include dissociative (D) network

    Examples
    --------
    >>> from rheojax.models import HVMLocal
    >>> model = HVMLocal()
    >>> omega = np.logspace(-2, 2, 50)
    >>> G_prime, G_double_prime = model.predict_saos(omega)
    >>> # Partial vitrimer (Meng 2019)
    >>> model = HVMLocal(include_dissociative=False)
    >>> # With TST stress feedback
    >>> model = HVMLocal(kinetics='stress')
    >>> t = np.linspace(0, 10, 200)
    >>> result = model.simulate_startup(t, gamma_dot=1.0, return_full=True)
    """
[docs]
def __init__(
    self,
    kinetics: Literal["stress", "stretch"] = "stress",
    include_damage: bool = False,
    include_dissociative: bool = True,
):
    """Configure the model and reset all cached protocol inputs.

    Parameters
    ----------
    kinetics : {'stress', 'stretch'}, default 'stress'
        Forwarded to the base-class constructor.
    include_damage : bool, default False
        Forwarded to the base-class constructor.
    include_dissociative : bool, default True
        Forwarded to the base-class constructor.
    """
    options = {
        "kinetics": kinetics,
        "include_damage": include_damage,
        "include_dissociative": include_dissociative,
    }
    super().__init__(**options)
    self._setup_parameters()
    self._test_mode = None

    # The simulate_* methods and _fit store their protocol inputs here so
    # that model_function can fall back on them. They start out as None so
    # a fresh instance (one that never ran simulate_*) has the attributes.
    self._gamma_dot_applied: float | None = None
    self._sigma_applied: float | None = None
    self._gamma_0: float | None = None
    self._omega_laos: float | None = None

    logger.info("HVMLocal initialized", extra=options)
# =========================================================================
# Internal Helpers
# =========================================================================
def _resolve_test_mode(self, kwarg_value) -> str:
    """Pick the active test mode.

    Precedence: the explicit ``kwarg_value`` if it is not None, then the
    instance's ``_test_mode`` attribute if present and not None, and
    finally the literal ``"flow_curve"``.
    """
    if kwarg_value is None:
        # Only consult the cached mode when the caller gave nothing.
        kwarg_value = getattr(self, "_test_mode", None)
    return "flow_curve" if kwarg_value is None else kwarg_value
# =========================================================================
# Parameter Helpers
# =========================================================================
def _get_params_dict(self) -> dict[str, float]:
    """Return the parameter dict with the optional entries filled in.

    The dict produced by ``get_parameter_dict()`` is updated in place:
    ``G_D``, ``k_d_D``, ``Gamma_0`` and ``lambda_crit`` are added with
    fallback values only when they are not already present.
    """
    params = self.get_parameter_dict()
    fallbacks = (
        ("G_D", 0.0),
        ("k_d_D", 1.0),
        ("Gamma_0", 0.0),
        ("lambda_crit", 10.0),
    )
    for name, value in fallbacks:
        if name not in params:
            params[name] = value
    return params
def _get_k_ber_0(self) -> float:
    """Return the BER rate constant for the current nu_0, E_a and T as a float."""
    rate = hvm_ber_rate_constant(self.nu_0, self.E_a, self.T)
    return float(rate)
# =========================================================================
# Flow Curve (Analytical)
# =========================================================================
[docs]
def predict_flow_curve(
    self, gamma_dot: np.ndarray, return_components: bool = False
) -> np.ndarray | dict[str, np.ndarray]:
    """Predict steady-state flow curve.

    At steady state, mu^E → mu^E_nat, so sigma_E → 0.
    Only the D-network contributes viscous stress.

    Parameters
    ----------
    gamma_dot : array-like
        Shear rate array (1/s)
    return_components : bool, default False
        If True, return dict with subnetwork contributions

    Returns
    -------
    np.ndarray or dict
        Steady-state stress (Pa), or a dict with keys ``stress``,
        ``sigma_P``, ``sigma_E``, ``sigma_D`` and ``eta_eff``.
        ``sigma_P`` and ``sigma_E`` are reported as zero arrays and
        ``eta_eff`` is ``stress / gamma_dot``.
    """
    gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)
    sigma = hvm_steady_shear_stress_vec(
        gamma_dot_jax, self.G_P, self.G_D, self.k_d_D
    )
    if not return_components:
        return np.asarray(sigma)

    tiny = 1e-30
    eta_D = self.G_D / jnp.maximum(self.k_d_D, tiny)
    sigma_D = np.asarray(eta_D * gamma_dot_jax)

    # Keep the denominator away from zero without discarding its sign.
    # Clamping with max(gamma_dot, tiny) alone would divide every negative
    # rate by 1e-30 and report an absurd effective viscosity.
    safe_rate = jnp.where(
        gamma_dot_jax < 0.0,
        jnp.minimum(gamma_dot_jax, -tiny),
        jnp.maximum(gamma_dot_jax, tiny),
    )

    # Build the zero components from the floating-point stress array so an
    # integer-valued gamma_dot input does not yield integer zero arrays.
    return {
        "stress": np.asarray(sigma),
        "sigma_P": np.zeros_like(sigma_D),  # Elastic, not viscous
        "sigma_E": np.zeros_like(sigma_D),  # Relaxed at SS
        "sigma_D": sigma_D,
        "eta_eff": np.asarray(sigma / safe_rate),
    }
# =========================================================================
# SAOS (Analytical)
# =========================================================================
[docs]
def predict_saos(
    self, omega: np.ndarray, return_components: bool = True
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
    """Evaluate the small-amplitude oscillatory moduli.

    The moduli come from ``hvm_saos_moduli_vec``, fed with the three
    network moduli, the BER rate from ``_get_k_ber_0`` and ``k_d_D``.

    Parameters
    ----------
    omega : array-like
        Angular frequency array (rad/s)
    return_components : bool, default True
        True returns the pair (G', G''); False returns |G*|

    Returns
    -------
    tuple of (np.ndarray, np.ndarray) or np.ndarray
        (G', G'') or |G*|
    """
    omega_jax = jnp.asarray(omega, dtype=jnp.float64)
    exchange_rate = self._get_k_ber_0()
    storage, loss = hvm_saos_moduli_vec(
        omega_jax, self.G_P, self.G_E, self.G_D, exchange_rate, self.k_d_D
    )
    if not return_components:
        # Floor the squared magnitude before taking the root.
        magnitude_sq = jnp.maximum(storage**2 + loss**2, 1e-30)
        return np.asarray(jnp.sqrt(magnitude_sq))
    return np.asarray(storage), np.asarray(loss)
# =========================================================================
# Startup Shear
# =========================================================================
[docs]
def simulate_startup(
    self,
    t: np.ndarray,
    gamma_dot: float,
    return_full: bool = False,
) -> np.ndarray | dict[str, np.ndarray]:
    """Simulate startup shear flow.

    The response is always obtained by integrating the ODE system through
    ``hvm_solve_startup``; the shear stress is then rebuilt from each
    state row. ``gamma_dot`` is cached on the instance as
    ``_gamma_dot_applied``.

    Parameters
    ----------
    t : array-like
        Time array (s)
    gamma_dot : float
        Applied shear rate (1/s)
    return_full : bool, default False
        If True, return dict with all trajectories

    Returns
    -------
    np.ndarray or dict
        Stress array, or a dict with time, stress, strain, the E / E-natural
        / D tensor components, damage and N1. Stress is replaced by NaN
        when the solver does not report success; the other trajectories
        are returned as produced by the solver.

    Raises
    ------
    ValueError
        If the parameter dict or the solver's ``ys`` is None.
    """
    self._gamma_dot_applied = gamma_dot
    t_jax = jnp.asarray(t, dtype=jnp.float64)

    params = self._get_params_dict()
    if params is None:
        raise ValueError("params dict must not be None")

    sol = hvm_solve_startup(
        t_jax,
        gamma_dot,
        params,
        kinetics=self._kinetics,
        include_damage=self._include_damage,
        include_dissociative=self._include_dissociative,
    )
    ys = sol.ys  # (n_times, 11)
    if ys is None:
        raise ValueError("ODE solver returned None for ys")

    def shear_stress(state):
        # Columns used: 9 strain, 2 mu_E_xy, 5 mu_E_nat_xy, 8 mu_D_xy, 10 damage.
        return hvm_total_stress_shear(
            state[9],
            state[2],
            state[5],
            state[8],
            params["G_P"],
            params["G_E"],
            params.get("G_D", 0.0),
            state[10],
        )

    raw_stress = jax.vmap(shear_stress)(ys)
    solved = sol.result == diffrax.RESULTS.successful
    stress = jnp.where(solved, raw_stress, jnp.full_like(raw_stress, jnp.nan))

    if not return_full:
        return np.asarray(stress)

    def first_normal_stress(state):
        return hvm_normal_stress_1(
            state[0],
            state[1],
            state[3],
            state[4],
            state[6],
            state[7],
            params["G_E"],
            params.get("G_D", 0.0),
        )

    # Output key -> state-vector column, in the order the keys are emitted.
    columns = {
        "strain": 9,
        "mu_E_xx": 0,
        "mu_E_yy": 1,
        "mu_E_xy": 2,
        "mu_E_nat_xx": 3,
        "mu_E_nat_yy": 4,
        "mu_E_nat_xy": 5,
        "mu_D_xx": 6,
        "mu_D_yy": 7,
        "mu_D_xy": 8,
        "damage": 10,
    }
    full = {"time": np.asarray(t), "stress": np.asarray(stress)}
    for key, col in columns.items():
        full[key] = np.asarray(ys[:, col])
    full["N1"] = np.asarray(jax.vmap(first_normal_stress)(ys))
    return full
# =========================================================================
# Stress Relaxation
# =========================================================================
[docs]
def simulate_relaxation(
    self,
    t: np.ndarray,
    gamma_step: float = 1.0,
    return_full: bool = False,
) -> np.ndarray | dict[str, np.ndarray]:
    """Simulate stress relaxation after step strain.

    The state trajectory comes from ``hvm_solve_relaxation``; the shear
    stress is rebuilt from each state row and divided by the step strain
    to give the relaxation modulus.

    Parameters
    ----------
    t : array-like
        Time array after step (s)
    gamma_step : float, default 1.0
        Applied step strain (either sign)
    return_full : bool, default False
        If True, return full trajectory dict

    Returns
    -------
    np.ndarray or dict
        G(t) relaxation modulus, or a dict with time, G_t, stress,
        mu_E_xy, mu_E_nat_xy, mu_D_xy and damage. ``G_t`` and ``stress``
        are NaN when the solver does not report success.

    Raises
    ------
    ValueError
        If the parameter dict or the solver's ``ys`` is None.
    """
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    params = self._get_params_dict()
    if params is None:
        raise ValueError("params dict must not be None")

    sol = hvm_solve_relaxation(
        t_jax,
        gamma_step,
        params,
        kinetics=self._kinetics,
        include_damage=self._include_damage,
        include_dissociative=self._include_dissociative,
    )
    ys = sol.ys
    if ys is None:
        raise ValueError("ODE solver returned None for ys")

    stress = jax.vmap(
        lambda y: hvm_total_stress_shear(
            y[9],
            y[2],
            y[5],
            y[8],
            params["G_P"],
            params["G_E"],
            params.get("G_D", 0.0),
            y[10],
        )
    )(ys)

    # Mask the stress itself so that both "stress" and "G_t" are NaN after
    # a failed solve (previously only G_t was masked).
    stress = jnp.where(
        sol.result == diffrax.RESULTS.successful,
        stress,
        jnp.nan,
    )

    # G(t) = sigma(t) / gamma_step. The denominator is kept away from zero
    # but retains its sign: dividing by |gamma_step| would flip the sign
    # of the modulus for a negative step strain.
    tiny = 1e-30
    safe_step = jnp.where(
        gamma_step < 0,
        jnp.minimum(gamma_step, -tiny),
        jnp.maximum(gamma_step, tiny),
    )
    G_t = stress / safe_step

    if return_full:
        return {
            "time": np.asarray(t),
            "G_t": np.asarray(G_t),
            "stress": np.asarray(stress),
            "mu_E_xy": np.asarray(ys[:, 2]),
            "mu_E_nat_xy": np.asarray(ys[:, 5]),
            "mu_D_xy": np.asarray(ys[:, 8]),
            "damage": np.asarray(ys[:, 10]),
        }
    return np.asarray(G_t)
# =========================================================================
# Creep
# =========================================================================
[docs]
def simulate_creep(
    self,
    t: np.ndarray,
    sigma_0: float,
    return_full: bool = False,
) -> np.ndarray | dict[str, np.ndarray]:
    """Simulate creep under constant stress.

    The state trajectory comes from ``hvm_solve_creep``; the strain is
    read from state column 9. ``sigma_0`` is cached on the instance as
    ``_sigma_applied``.

    Parameters
    ----------
    t : array-like
        Time array (s)
    sigma_0 : float
        Applied constant stress (Pa), either sign
    return_full : bool, default False
        If True, return full trajectory dict

    Returns
    -------
    np.ndarray or dict
        Strain gamma(t), or a dict with time, strain, compliance,
        mu_E_xy, mu_E_nat_xy, mu_D_xy and damage. Strain (and therefore
        compliance) is NaN when the solver does not report success.

    Raises
    ------
    ValueError
        If the parameter dict or the solver's ``ys`` is None.
    """
    self._sigma_applied = sigma_0
    t_jax = jnp.asarray(t, dtype=jnp.float64)
    params = self._get_params_dict()
    if params is None:
        raise ValueError("params dict must not be None")

    sol = hvm_solve_creep(
        t_jax,
        sigma_0,
        params,
        kinetics=self._kinetics,
        include_damage=self._include_damage,
        include_dissociative=self._include_dissociative,
    )
    ys = sol.ys
    if ys is None:
        raise ValueError("ODE solver returned None for ys")

    gamma = jnp.where(
        sol.result == diffrax.RESULTS.successful,
        ys[:, 9],
        jnp.nan,
    )
    if not return_full:
        return np.asarray(gamma)

    # J(t) = gamma(t) / sigma_0. The denominator is kept away from zero but
    # retains its sign: dividing by |sigma_0| would report a negative
    # compliance whenever the applied stress is negative.
    tiny = 1e-30
    safe_stress = jnp.where(
        sigma_0 < 0,
        jnp.minimum(sigma_0, -tiny),
        jnp.maximum(sigma_0, tiny),
    )
    J_t = gamma / safe_stress
    return {
        "time": np.asarray(t),
        "strain": np.asarray(gamma),
        "compliance": np.asarray(J_t),
        "mu_E_xy": np.asarray(ys[:, 2]),
        "mu_E_nat_xy": np.asarray(ys[:, 5]),
        "mu_D_xy": np.asarray(ys[:, 8]),
        "damage": np.asarray(ys[:, 10]),
    }
# =========================================================================
# LAOS
# =========================================================================
[docs]
def simulate_laos(
    self,
    t: np.ndarray,
    gamma_0: float,
    omega: float,
) -> dict[str, np.ndarray]:
    """Simulate LAOS (Large Amplitude Oscillatory Shear).

    The state trajectory comes from ``hvm_solve_laos``. The imposed strain
    ``gamma_0 * sin(omega * t)`` and its rate are evaluated in closed form,
    while stress and N1 are rebuilt from each state row. ``gamma_0`` and
    ``omega`` are cached on the instance.

    Parameters
    ----------
    t : array-like
        Time array (s)
    gamma_0 : float
        Strain amplitude
    omega : float
        Angular frequency (rad/s)

    Returns
    -------
    dict
        Keys: time, strain, stress, gamma_dot, N1,
        mu_E_xy, mu_E_nat_xy, mu_D_xy, damage. Stress and N1 are NaN when
        the solver does not report success.

    Raises
    ------
    ValueError
        If the parameter dict or the solver's ``ys`` is None.
    """
    self._gamma_0 = gamma_0
    self._omega_laos = omega
    t_jax = jnp.asarray(t, dtype=jnp.float64)

    params = self._get_params_dict()
    if params is None:
        raise ValueError("params dict must not be None")

    sol = hvm_solve_laos(
        t_jax,
        gamma_0,
        omega,
        params,
        kinetics=self._kinetics,
        include_damage=self._include_damage,
        include_dissociative=self._include_dissociative,
    )
    ys = sol.ys
    if ys is None:
        raise ValueError("ODE solver returned None for ys")

    phase = omega * t_jax
    strain = gamma_0 * jnp.sin(phase)
    gamma_dot_arr = gamma_0 * omega * jnp.cos(phase)

    def shear_stress(state):
        return hvm_total_stress_shear(
            state[9],
            state[2],
            state[5],
            state[8],
            params["G_P"],
            params["G_E"],
            params.get("G_D", 0.0),
            state[10],
        )

    def first_normal_stress(state):
        return hvm_normal_stress_1(
            state[0],
            state[1],
            state[3],
            state[4],
            state[6],
            state[7],
            params["G_E"],
            params.get("G_D", 0.0),
        )

    stress = jax.vmap(shear_stress)(ys)
    N1 = jax.vmap(first_normal_stress)(ys)

    # Blank out the derived signals when the integration did not succeed.
    failed = sol.result != diffrax.RESULTS.successful
    stress = jnp.where(failed, jnp.nan, stress)
    N1 = jnp.where(failed, jnp.nan, N1)

    output = {
        "time": np.asarray(t),
        "strain": np.asarray(strain),
        "stress": np.asarray(stress),
        "gamma_dot": np.asarray(gamma_dot_arr),
        "N1": np.asarray(N1),
    }
    for key, col in (("mu_E_xy", 2), ("mu_E_nat_xy", 5), ("mu_D_xy", 8), ("damage", 10)):
        output[key] = np.asarray(ys[:, col])
    return output
# =========================================================================
# Normal Stresses
# =========================================================================
[docs]
def predict_normal_stresses(
    self, gamma_dot: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
    """Return the steady-state normal stress differences (N1, N2).

    Only the D-network term is evaluated here:

        N1 = 2 * G_D * (gamma_dot / k_d_D)^2
        N2 = 0

    Parameters
    ----------
    gamma_dot : array-like
        Shear rate array (1/s)

    Returns
    -------
    tuple of (np.ndarray, np.ndarray)
        (N1, N2) arrays (Pa); N2 is all zeros with the shape of N1
    """
    rates = jnp.asarray(gamma_dot, dtype=jnp.float64)
    # Floor k_d_D so a vanishing rate constant cannot divide by zero.
    dissociation_rate = jnp.maximum(self.k_d_D, 1e-30)
    weissenberg = rates / dissociation_rate
    first = 2.0 * self.G_D * weissenberg**2
    second = jnp.zeros_like(first)
    return np.asarray(first), np.asarray(second)
# =========================================================================
# LAOS Harmonic Extraction
# =========================================================================
# =========================================================================
# Fitting (NLSQ)
# =========================================================================
def _fit(self, x, y, **kwargs):
    """Fit model to data using protocol-aware optimization.

    The resolved test mode is stored on the instance, a least-squares
    objective is built around ``model_function`` and handed to
    ``nlsq_optimize``. Log residuals default to on for ``"flow_curve"``
    only; the optimizer method defaults to ``"scipy"`` for ``"laos"`` and
    ``"auto"`` otherwise.

    Parameters
    ----------
    x : array-like
        Independent variable
    y : array-like
        Dependent variable (complex input is kept complex)
    **kwargs
        test_mode, gamma_dot, sigma_applied, gamma_0, omega, etc.
        A protocol input that is omitted keeps whatever value is already
        cached on the instance; passing it explicitly (including ``None``)
        overwrites the cache.

    Returns
    -------
    self
    """
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    test_mode = self._resolve_test_mode(kwargs.get("test_mode"))
    self._test_mode = test_mode

    x_jax = jnp.asarray(x, dtype=jnp.float64)
    # Preserve complex dtype for oscillation data (G* = G' + iG'')
    y_arr = np.asarray(y)
    y_dtype = jnp.complex128 if np.iscomplexobj(y_arr) else jnp.float64
    y_jax = jnp.asarray(y_arr, dtype=y_dtype)

    # Store protocol-specific inputs, but only the ones actually supplied.
    # Unconditionally assigning kwargs.get(...) would overwrite a value
    # cached by an earlier simulate_*/fit call with None and break the
    # cached-value fallback that model_function relies on.
    cache_slots = (
        ("gamma_dot", "_gamma_dot_applied"),
        ("sigma_applied", "_sigma_applied"),
        ("gamma_0", "_gamma_0"),
        ("omega", "_omega_laos"),
    )
    for key, attr in cache_slots:
        if key in kwargs:
            setattr(self, attr, kwargs[key])

    # Filter out fitting-specific and BaseModel kwargs
    reserved = (
        "test_mode",
        "deformation_mode",
        "poisson_ratio",
        "use_log_residuals",
        "use_jax",
        "method",
        "max_iter",
        "use_multi_start",
        "n_starts",
        "perturb_factor",
    )
    fwd_kwargs = {k: v for k, v in kwargs.items() if k not in reserved}

    def model_fn(x_fit, params):
        return self.model_function(x_fit, params, test_mode=test_mode, **fwd_kwargs)

    objective = create_least_squares_objective(
        model_fn,
        x_jax,
        y_jax,
        use_log_residuals=kwargs.get(
            "use_log_residuals", test_mode == "flow_curve"
        ),
    )

    # ODE-based protocols use diffrax with custom_vjp, incompatible with
    # NLSQ forward-mode AD. Default to scipy to avoid failed attempt overhead.
    _ode_protocols = {"laos"}
    _default_method = "scipy" if test_mode in _ode_protocols else "auto"
    result = nlsq_optimize(
        objective,
        self.parameters,
        use_jax=kwargs.get("use_jax", True),
        method=kwargs.get("method", _default_method),
        max_iter=kwargs.get("max_iter", 2000),
    )
    self.fitted_ = True
    self._nlsq_result = result
    logger.info(
        f"Fitted HVMLocal: G_P={self.G_P:.2e}, G_E={self.G_E:.2e}, "
        f"G_D={self.G_D:.2e}"
    )
    return self
def _predict(self, X, **kwargs):
    """Evaluate ``model_function`` at the current parameter values.

    Parameters
    ----------
    X : array-like
        Independent variable
    **kwargs
        test_mode and protocol-specific parameters; ``test_mode``,
        ``deformation_mode`` and ``poisson_ratio`` are not forwarded as
        protocol inputs.

    Returns
    -------
    np.ndarray
        Predicted response. For ``"oscillation"`` an (N, 2) result of
        [G', G''] columns is folded into a complex array G' + i G''.
    """
    mode = self._resolve_test_mode(kwargs.get("test_mode"))

    current_values = [
        self.parameters.get_value(name) for name in self.parameters.keys()
    ]
    param_values = jnp.array(current_values, dtype=jnp.float64)

    routing_keys = ("test_mode", "deformation_mode", "poisson_ratio")
    passthrough = {
        key: value for key, value in kwargs.items() if key not in routing_keys
    }

    raw = self.model_function(X, param_values, test_mode=mode, **passthrough)
    prediction = np.asarray(raw)

    # Two-column moduli are presented as one complex modulus.
    is_moduli_pair = (
        mode == "oscillation" and prediction.ndim == 2 and prediction.shape[1] == 2
    )
    if is_moduli_pair:
        return prediction[:, 0] + 1j * prediction[:, 1]
    return prediction
# =========================================================================
# Model Function (NLSQ / NumPyro)
# =========================================================================
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
    """Dispatch to the prediction kernel for the requested test mode.

    Parameters are unpacked by position against the model's parameter
    names (cached on first use as ``_param_names``). Protocol inputs are
    read from ``kwargs`` first and fall back to the values cached on the
    instance.

    Parameters
    ----------
    X : array-like
        Independent variable. For ``"laos"`` a 2-D input is treated as
        stacked rows and row 0 is used as the time array.
    params : array-like
        Parameter values in ParameterSet order
    test_mode : str, optional
        Override stored test mode
    **kwargs
        Protocol-specific: gamma_dot, sigma_applied, gamma_0, omega

    Returns
    -------
    jnp.ndarray
        Predicted response; for ``"oscillation"`` the columns are
        [G', G'']. An unrecognized mode logs a warning and returns the
        flow-curve prediction.

    Raises
    ------
    ValueError
        If a mode is missing its required protocol input.
    """
    # Positional unpacking; the name list is computed once per instance.
    if not hasattr(self, "_param_names"):
        self._param_names = list(self.parameters.keys())
    named = dict(zip(self._param_names, params, strict=True))

    G_P, G_E = named["G_P"], named["G_E"]
    nu_0, E_a = named["nu_0"], named["E_a"]
    V_act, T = named["V_act"], named["T"]
    G_D = named.get("G_D", 0.0)
    k_d_D = named.get("k_d_D", 1.0)

    mode = self._resolve_test_mode(test_mode)
    X_jax = jnp.asarray(X, dtype=jnp.float64)

    def kwarg_or_cached(key, attr):
        # The sentinel distinguishes "not passed" from falsy values such
        # as gamma_dot=0.0, which must be honoured.
        supplied = kwargs.get(key, _MISSING)
        if supplied is _MISSING:
            return getattr(self, attr, None)
        return supplied

    gamma_dot = kwarg_or_cached("gamma_dot", "_gamma_dot_applied")
    sigma_applied = kwarg_or_cached("sigma_applied", "_sigma_applied")
    gamma_0 = kwarg_or_cached("gamma_0", "_gamma_0")
    omega = kwarg_or_cached("omega", "_omega_laos")

    k_ber_0 = hvm_ber_rate_constant(nu_0, E_a, T)

    if mode in ("flow_curve", "steady_shear", "rotation"):
        return hvm_steady_shear_stress_vec(X_jax, G_P, G_D, k_d_D)

    if mode == "oscillation":
        storage, loss = hvm_saos_moduli_vec(X_jax, G_P, G_E, G_D, k_ber_0, k_d_D)
        return jnp.column_stack([storage, loss])

    if mode == "startup":
        if gamma_dot is None:
            raise ValueError("startup mode requires gamma_dot")
        return hvm_startup_stress_linear_vec(
            X_jax, gamma_dot, G_P, G_E, G_D, k_ber_0, k_d_D
        )

    if mode == "relaxation":
        # No damage in linear model_function
        no_damage = 0.0
        return hvm_relaxation_modulus_vec(
            X_jax, G_P, G_E, G_D, k_ber_0, k_d_D, no_damage
        )

    if mode == "creep":
        if sigma_applied is None:
            raise ValueError("creep mode requires sigma_applied")
        compliance = hvm_creep_compliance_linear_vec(
            X_jax, G_P, G_E, G_D, k_ber_0, k_d_D
        )
        return sigma_applied * compliance

    if mode == "laos":
        if gamma_0 is None or omega is None:
            raise ValueError("LAOS mode requires gamma_0 and omega")
        # Extract time from (2, N) stacked [time, strain] input
        times = X_jax if X_jax.ndim != 2 else X_jax[0]
        # LAOS goes through the ODE solver; damage is switched off here.
        solver_params = {
            "G_P": G_P,
            "G_E": G_E,
            "G_D": G_D,
            "k_d_D": k_d_D,
            "nu_0": nu_0,
            "E_a": E_a,
            "V_act": V_act,
            "T": T,
            "Gamma_0": 0.0,
            "lambda_crit": 10.0,
        }
        sol = hvm_solve_laos(
            times,
            gamma_0,
            omega,
            solver_params,
            kinetics=self._kinetics,
            include_damage=False,
            include_dissociative=self._include_dissociative,
        )
        # Mask failed ODE solutions with NaN so Bayesian NaN guard rejects them
        states = _mask_failed_solution_ys(sol)

        def shear_stress(state):
            return hvm_total_stress_shear(
                state[9],
                state[2],
                state[5],
                state[8],
                G_P,
                G_E,
                G_D,
                state[10],
            )

        return jax.vmap(shear_stress)(states)

    logger.warning(f"Unknown test_mode '{mode}', defaulting to flow_curve")
    return hvm_steady_shear_stress_vec(X_jax, G_P, G_D, k_d_D)
# =========================================================================
# Factory Methods (Limiting Cases)
# =========================================================================
[docs]
@classmethod
def neo_hookean(cls, G_P: float = 1e4) -> HVMLocal:
    """Build the neo-Hookean elastic-solid limit (G_E=0, no D-network).

    Parameters
    ----------
    G_P : float
        Permanent network modulus (Pa)

    Returns
    -------
    HVMLocal
        Instance created with ``include_dissociative=False`` and with
        ``G_P`` set to the given value and ``G_E`` set to zero
    """
    model = cls(include_dissociative=False)
    for name, value in (("G_P", G_P), ("G_E", 0.0)):
        model.parameters.set_value(name, value)
    return model
[docs]
@classmethod
def maxwell(cls, G_D: float = 1e4, k_d_D: float = 1.0) -> HVMLocal:
    """Build the single-Maxwell-fluid limit (G_P=0, G_E=0).

    Parameters
    ----------
    G_D : float
        Network modulus (Pa)
    k_d_D : float
        Dissociation rate (1/s)

    Returns
    -------
    HVMLocal
        Instance created with ``include_dissociative=True`` whose P and E
        moduli are set to zero
    """
    model = cls(include_dissociative=True)
    assignments = {
        "G_P": 0.0,
        "G_E": 0.0,
        "G_D": G_D,
        "k_d_D": k_d_D,
    }
    for name, value in assignments.items():
        model.parameters.set_value(name, value)
    return model
[docs]
@classmethod
def zener(cls, G_P: float = 1e4, G_D: float = 1e4, k_d_D: float = 1.0) -> HVMLocal:
    """Build the Zener/SLS limit (G_E=0).

    Parameters
    ----------
    G_P : float
        Permanent network modulus (Pa)
    G_D : float
        Dissociative network modulus (Pa)
    k_d_D : float
        Dissociation rate (1/s)

    Returns
    -------
    HVMLocal
        Instance with P + D networks (``G_E`` set to zero)
    """
    model = cls(include_dissociative=True)
    names = ("G_P", "G_E", "G_D", "k_d_D")
    values = (G_P, 0.0, G_D, k_d_D)
    for name, value in zip(names, values):
        model.parameters.set_value(name, value)
    return model
[docs]
@classmethod
def pure_vitrimer(
    cls,
    G_E: float = 1e4,
    nu_0: float = 1e10,
    E_a: float = 80e3,
    V_act: float = 1e-5,
    T: float = 300.0,
) -> HVMLocal:
    """Build the pure-vitrimer limit (G_P=0, no D-network).

    Parameters
    ----------
    G_E : float
        Exchangeable network modulus (Pa)
    nu_0, E_a, V_act, T : float
        TST parameters

    Returns
    -------
    HVMLocal
        Instance created with ``include_dissociative=False`` whose
        ``G_P`` is set to zero
    """
    model = cls(include_dissociative=False)
    settings = [
        ("G_P", 0.0),
        ("G_E", G_E),
        ("nu_0", nu_0),
        ("E_a", E_a),
        ("V_act", V_act),
        ("T", T),
    ]
    for name, value in settings:
        model.parameters.set_value(name, value)
    return model
[docs]
@classmethod
def partial_vitrimer(
    cls,
    G_P: float = 1e4,
    G_E: float = 1e4,
    nu_0: float = 1e10,
    E_a: float = 80e3,
    V_act: float = 1e-5,
    T: float = 300.0,
) -> HVMLocal:
    """Build the partial-vitrimer limit (no D-network, Meng 2019).

    Parameters
    ----------
    G_P : float
        Permanent network modulus (Pa)
    G_E : float
        Exchangeable network modulus (Pa)
    nu_0, E_a, V_act, T : float
        TST parameters

    Returns
    -------
    HVMLocal
        Instance with P + E networks, created with
        ``include_dissociative=False``
    """
    model = cls(include_dissociative=False)
    overrides = dict(G_P=G_P, G_E=G_E, nu_0=nu_0, E_a=E_a, V_act=V_act, T=T)
    for name, value in overrides.items():
        model.parameters.set_value(name, value)
    return model