"""HVNMLocal: Local (0D) Hybrid Vitrimer Nanocomposite Model.
Single-point constitutive model for NP-filled vitrimers with four subnetworks:
1. Permanent (P): covalent crosslinks, amplified by Guth-Gold X(phi)
2. Exchangeable (E): associative vitrimer bonds with matrix BER kinetics
3. Dissociative (D): physical reversible bonds, standard Maxwell
4. Interphase (I): NP-bound chains with distinct interfacial BER kinetics
Supports 6 rheological protocols:
- Flow curve (analytical: sigma_E -> 0, sigma_I -> 0 at steady state)
- SAOS (analytical: three Maxwell modes + amplified permanent plateau)
- Startup shear (ODE with dual TST feedback)
- Stress relaxation (ODE: quad-exponential + amplified plateau)
- Creep (ODE: implicit gamma_dot solve, 4-network stress balance)
- LAOS (ODE: nonlinear oscillatory response with Payne effect)
Limiting Cases
--------------
- phi=0: Recovers HVM exactly (primary validation criterion)
- G_E=0, G_D=0, G_I=0: Amplified neo-Hookean
- G_P=0, G_E=0, G_I=0: Maxwell fluid
- k_BER^int -> 0: Frozen interphase (dead layer)
- Full: 4-network HVNM
References
----------
- Vernerey, Long, & Brighenti (2017). JMPS 107, 1-20.
- Li, Zhao, Duan, Zhang, Liu (2024). Langmuir 40, 7550-7560.
- Karim, Vernerey, Sain (2025). Macromolecules 58, 4899-4912.
"""
from __future__ import annotations
import logging
from typing import Literal
import numpy as np
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.hvnm._base import HVNMBase
from rheojax.models.hvnm._kernels import (
hvnm_ber_rate_constant_interphase,
hvnm_ber_rate_constant_matrix,
hvnm_creep_compliance_linear_vec,
hvnm_effective_phi,
hvnm_guth_gold,
hvnm_interphase_fraction,
hvnm_interphase_modulus,
hvnm_relaxation_modulus_vec,
hvnm_relaxation_modulus_with_diffusion_vec,
hvnm_saos_moduli_vec,
hvnm_startup_stress_linear_vec,
hvnm_steady_shear_stress_vec,
hvnm_total_normal_stress_1,
hvnm_total_stress_shear,
)
from rheojax.models.hvnm._kernels_diffrax import (
_mask_failed_solution_ys,
hvnm_solve_creep,
hvnm_solve_laos,
hvnm_solve_relaxation,
hvnm_solve_startup,
)
# Module-level logger for model lifecycle / fitting messages.
logger = logging.getLogger(__name__)
# JAX module and jax.numpy, obtained through rheojax's guarded importer.
jax, jnp = safe_import_jax()
# Sentinel for "keyword not supplied", so falsy values such as
# gamma_dot=0.0 are not mistaken for a missing argument.
_MISSING = object()
@ModelRegistry.register(
    "hvnm_local",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.OSCILLATION,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
@ModelRegistry.register(
    "hvnm",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.OSCILLATION,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class HVNMLocal(HVNMBase):
    """Local (0D) Hybrid Vitrimer Nanocomposite Model.

    A constitutive model for NP-filled vitrimers combining:

    - Permanent network (P): covalent crosslinks, amplified by X(phi)
    - Exchangeable network (E): vitrimer bonds with matrix TST kinetics
    - Dissociative network (D): physical bonds, Maxwell relaxation
    - Interphase network (I): NP-bound chains with interfacial TST kinetics

    Parameters
    ----------
    kinetics : {'stress', 'stretch'}, default 'stress'
        TST coupling mechanism for bond exchange rates
    include_damage : bool, default False
        Enable matrix cooperative damage shielding
    include_dissociative : bool, default True
        Include dissociative (D) network
    include_interfacial_damage : bool, default False
        Enable interfacial damage with self-healing
    include_diffusion : bool, default False
        Enable diffusion-limited relaxation tails

    Examples
    --------
    >>> import numpy as np
    >>> from rheojax.models import HVNMLocal
    >>> model = HVNMLocal()
    >>> model.parameters.set_value("phi", 0.1)
    >>> omega = np.logspace(-2, 2, 50)
    >>> G_prime, G_double_prime = model.predict_saos(omega)
    >>> # Unfilled limit (recovers HVM)
    >>> model = HVNMLocal()
    >>> model.parameters.set_value("phi", 0.0)
    """

    # Glassy-layer thickness (m) fed to the interphase geometry kernels.
    # Fixed (not a fitted parameter); shared by the eager path
    # (_get_derived_params) and the traced path (model_function).
    _GLASSY_LAYER_DELTA_G = 1e-9
    # Floor that keeps divisors away from zero.
    _DENOM_FLOOR = 1e-30
    # test_mode aliases routed to the analytical steady-shear flow curve.
    _FLOW_CURVE_MODES = ("flow_curve", "steady_shear", "rotation")
    # Protocols whose forward model in model_function is a diffrax ODE solve.
    _ODE_FIT_MODES = frozenset({"laos"})
    # kwargs consumed by _fit / the base model that must not reach model_function.
    _FIT_CONTROL_KWARGS = frozenset(
        {
            "test_mode",
            "deformation_mode",
            "poisson_ratio",
            "use_log_residuals",
            "use_jax",
            "method",
            "max_iter",
            "use_multi_start",
            "n_starts",
            "perturb_factor",
        }
    )
    # kwargs consumed by _predict / the base model that must not reach model_function.
    _PREDICT_CONTROL_KWARGS = frozenset(
        {"test_mode", "deformation_mode", "poisson_ratio"}
    )

    def __init__(
        self,
        kinetics: Literal["stress", "stretch"] = "stress",
        include_damage: bool = False,
        include_dissociative: bool = True,
        include_interfacial_damage: bool = False,
        include_diffusion: bool = False,
    ):
        super().__init__(
            kinetics=kinetics,
            include_damage=include_damage,
            include_dissociative=include_dissociative,
            include_interfacial_damage=include_interfacial_damage,
            include_diffusion=include_diffusion,
        )
        self._setup_parameters()
        # Protocol state remembered between fit/predict/simulate calls.
        self._test_mode = None
        self._gamma_dot_applied = None
        self._sigma_applied = None
        self._gamma_0 = None
        self._omega_laos = None
        logger.info(
            "HVNMLocal initialized",
            extra={
                "kinetics": kinetics,
                "include_damage": include_damage,
                "include_dissociative": include_dissociative,
                "include_interfacial_damage": include_interfacial_damage,
                "include_diffusion": include_diffusion,
                "n_params": len(self.parameters),
            },
        )

    # =========================================================================
    # Internal helpers
    # =========================================================================

    def _select_test_mode(self, override: str | None = None) -> str:
        """Return ``override`` if given, else the stored test mode, else 'flow_curve'."""
        if override is not None:
            return override
        stored = getattr(self, "_test_mode", None)
        return stored if stored is not None else "flow_curve"

    def _protocol_kwarg(self, kwargs: dict, key: str, attr: str):
        """Return ``kwargs[key]`` if supplied (even if falsy), else ``getattr(self, attr)``."""
        value = kwargs.get(key, _MISSING)
        if value is _MISSING:
            return getattr(self, attr, None)
        return value

    @classmethod
    def _nonzero_denominator(cls, value):
        """Return ``value`` with |value| <= floor replaced by +floor.

        Unlike ``maximum(value, floor)`` / ``maximum(abs(value), floor)``,
        this keeps the sign of negative inputs, so ratios such as
        sigma/gamma_step or gamma/sigma_0 stay positive for negative loads.
        """
        return jnp.where(jnp.abs(value) > cls._DENOM_FLOOR, value, cls._DENOM_FLOOR)

    @staticmethod
    def _shear_stress_trajectory(
        ys,
        G_P,
        G_E,
        G_D,
        G_I_eff,
        X_phi,
        X_I,
        *,
        with_interfacial_damage: bool = True,
    ):
        """Evaluate total shear stress at every saved ODE state.

        ``ys`` rows are the 18-component HVNM state used by the diffrax
        kernels. Index layout (matching the ``return_full`` dict keys):
        0-2 mu_E (xx, yy, xy), 3-5 mu_E_nat, 6-8 mu_D, 9 gamma,
        10 matrix damage D, 11-13 mu_I, 14-16 mu_I_nat, 17 interfacial
        damage D_int.

        When ``with_interfacial_damage`` is False, D_int is taken as 0.0
        instead of being read from the state.
        """

        def _stress_at(y):
            D_int = y[17] if with_interfacial_damage else 0.0
            return hvnm_total_stress_shear(
                y[9],  # gamma
                y[2],  # mu_E_xy
                y[5],  # mu_E_nat_xy
                y[8],  # mu_D_xy
                y[13],  # mu_I_xy
                y[16],  # mu_I_nat_xy
                G_P,
                G_E,
                G_D,
                G_I_eff,
                X_phi,
                X_I,
                y[10],  # D (matrix damage)
                D_int,
            )

        return jax.vmap(_stress_at)(ys)

    @staticmethod
    def _apply_nc_kwargs(model: HVNMLocal, nc_kwargs: dict) -> None:
        """Set each nc_kwargs entry that names a model parameter; warn on the rest."""
        known = set(model.parameters.keys())
        for key, val in nc_kwargs.items():
            if key in known:
                model.parameters.set_value(key, val)
            else:
                # Kept non-fatal for backward compatibility, but typos such as
                # "r_np" should not disappear silently.
                logger.warning(
                    "Ignoring unknown HVNMLocal parameter %r (not in ParameterSet)",
                    key,
                )

    # =========================================================================
    # Parameter Helpers
    # =========================================================================

    def _get_params_dict(self) -> dict[str, float]:
        """Get parameters as a dict, filling defaults for optional params.

        Optional parameters (D-network, damage, healing) may be absent from
        the ParameterSet depending on the constructor flags; they are filled
        with inert defaults so downstream kernels always find them.
        """
        d = self.get_parameter_dict()
        d.setdefault("G_D", 0.0)
        d.setdefault("k_d_D", 1.0)
        d.setdefault("Gamma_0", 0.0)
        d.setdefault("lambda_crit", 10.0)
        d.setdefault("Gamma_0_int", 0.0)
        d.setdefault("lambda_crit_int", 10.0)
        d.setdefault("h_0", 0.0)
        d.setdefault("E_a_heal", 100e3)
        d.setdefault("n_h", 1.0)
        return d

    def _get_derived_params(self, p_dict: dict) -> dict[str, float]:
        """Compute derived NP geometry and kinetic quantities from a parameter dict.

        Uses only the hvnm_* kernels, so it is also usable on JAX-traced
        values (model_function calls it).

        Parameters
        ----------
        p_dict : dict
            Parameter name -> value mapping. Missing entries fall back to
            the defaults below.

        Returns
        -------
        dict
            Derived quantities: G_I_eff, X_phi, X_I, phi_I, k_BER_mat_0,
            k_BER_int_0.
        """
        phi = p_dict.get("phi", 0.0)
        R_NP = p_dict.get("R_NP", 20e-9)
        delta_m = p_dict.get("delta_m", 10e-9)
        G_E = p_dict.get("G_E", 0.0)
        beta_I = p_dict.get("beta_I", 3.0)
        nu_0 = p_dict.get("nu_0", 1e10)
        E_a = p_dict.get("E_a", 80e3)
        nu_0_int = p_dict.get("nu_0_int", 1e10)
        E_a_int = p_dict.get("E_a_int", 90e3)
        T = p_dict.get("T", 300.0)
        delta_g = self._GLASSY_LAYER_DELTA_G

        phi_I = hvnm_interphase_fraction(phi, R_NP, delta_g, delta_m)
        G_I_eff = hvnm_interphase_modulus(G_E, beta_I, phi_I)
        X_phi = hvnm_guth_gold(phi)
        # The interphase is amplified using an effective (NP + glassy layer) fraction.
        phi_eff = hvnm_effective_phi(phi, R_NP, delta_g)
        X_I = hvnm_guth_gold(phi_eff)
        k_BER_mat_0 = hvnm_ber_rate_constant_matrix(nu_0, E_a, T)
        k_BER_int_0 = hvnm_ber_rate_constant_interphase(nu_0_int, E_a_int, T)
        return {
            "G_I_eff": G_I_eff,
            "X_phi": X_phi,
            "X_I": X_I,
            "phi_I": phi_I,
            "k_BER_mat_0": k_BER_mat_0,
            "k_BER_int_0": k_BER_int_0,
        }

    def _get_ode_args(self, p_dict: dict | None = None) -> dict:
        """Build the complete ODE args dict (raw parameters plus derived quantities).

        Parameters
        ----------
        p_dict : dict, optional
            Parameter mapping; defaults to ``self._get_params_dict()``.
        """
        if p_dict is None:
            p_dict = self._get_params_dict()
        derived = self._get_derived_params(p_dict)
        args = {**p_dict, **derived}
        # Ensure all interphase params have defaults for the ODE kernels.
        args.setdefault("nu_0_int", 1e10)
        args.setdefault("E_a_int", 90e3)
        args.setdefault("V_act_int", 5e-6)
        args.setdefault("Gamma_0_int", 0.0)
        args.setdefault("lambda_crit_int", 10.0)
        args.setdefault("h_0", 0.0)
        args.setdefault("E_a_heal", 100e3)
        args.setdefault("n_h", 1.0)
        return args

    # =========================================================================
    # Flow Curve (Analytical)
    # =========================================================================

    def predict_flow_curve(
        self, gamma_dot: np.ndarray, return_components: bool = False
    ) -> np.ndarray | dict[str, np.ndarray]:
        """Predict the steady-state flow curve.

        At steady state, mu^E -> mu^E_nat and mu^I -> mu^I_nat, so
        sigma_E -> 0 and sigma_I -> 0. Only the D-network contributes
        viscous stress: sigma_D = eta_D * gamma_dot with eta_D = G_D / k_d_D.

        Parameters
        ----------
        gamma_dot : array-like
            Shear rate array (1/s)
        return_components : bool, default False
            If True, return a dict with subnetwork contributions. sigma_P,
            sigma_E and sigma_I are reported as zeros.

        Returns
        -------
        np.ndarray or dict
            Steady-state stress (Pa) or component dict

        Raises
        ------
        ValueError
            If G_D or k_d_D is None.
        """
        G_D = self.G_D
        k_d_D = self.k_d_D
        if G_D is None:
            raise ValueError("G_D must not be None")
        if k_d_D is None:
            raise ValueError("k_d_D must not be None")
        gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)
        sigma = hvnm_steady_shear_stress_vec(gamma_dot_jax, G_D, k_d_D)
        if not return_components:
            return np.asarray(sigma)

        eta_D = G_D / jnp.maximum(k_d_D, self._DENOM_FLOOR)
        sigma_D = eta_D * gamma_dot_jax
        shape = np.shape(gamma_dot)
        # Sign-preserving divisor: sigma and gamma_dot share sign, so the
        # apparent viscosity stays positive for negative shear rates.
        eta_eff = sigma / self._nonzero_denominator(gamma_dot_jax)
        return {
            "stress": np.asarray(sigma),
            "sigma_P": np.zeros(shape),
            "sigma_E": np.zeros(shape),
            "sigma_D": np.asarray(sigma_D),
            "sigma_I": np.zeros(shape),
            "eta_eff": np.asarray(eta_eff),
        }

    # =========================================================================
    # SAOS (Analytical)
    # =========================================================================

    def predict_saos(
        self, omega: np.ndarray, return_components: bool = True
    ) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Predict SAOS storage and loss moduli.

        Three Maxwell modes (E, D, I) plus the amplified permanent plateau (P),
        evaluated without damage (D = D_int = 0).

        Parameters
        ----------
        omega : array-like
            Angular frequency array (rad/s)
        return_components : bool, default True
            If True, return (G', G''); if False, return |G*|

        Returns
        -------
        tuple of (np.ndarray, np.ndarray) or np.ndarray
            (G', G'') or |G*|

        Raises
        ------
        ValueError
            If G_P, G_E, G_D or k_d_D is None.
        """
        G_P = self.G_P
        G_E = self.G_E
        G_D = self.G_D
        k_d_D = self.k_d_D
        if G_P is None or G_E is None:
            raise ValueError("G_P, G_E must not be None")
        if G_D is None:
            raise ValueError("G_D must not be None")
        if k_d_D is None:
            raise ValueError("k_d_D must not be None")
        omega_jax = jnp.asarray(omega, dtype=jnp.float64)
        d = self._get_derived_params(self._get_params_dict())
        G_prime, G_double_prime = hvnm_saos_moduli_vec(
            omega_jax,
            G_P,
            G_E,
            G_D,
            d["G_I_eff"],
            d["X_phi"],
            d["X_I"],
            d["k_BER_mat_0"],
            k_d_D,
            d["k_BER_int_0"],
            0.0,
            0.0,  # D=0, D_int=0
        )
        if return_components:
            return np.asarray(G_prime), np.asarray(G_double_prime)
        return np.asarray(
            jnp.sqrt(jnp.maximum(G_prime**2 + G_double_prime**2, self._DENOM_FLOOR))
        )

    # =========================================================================
    # Startup Shear
    # =========================================================================

    def simulate_startup(
        self,
        t: np.ndarray,
        gamma_dot: float,
        return_full: bool = False,
    ) -> np.ndarray | dict[str, np.ndarray]:
        """Simulate startup shear flow with dual TST feedback.

        Side effect: stores ``gamma_dot`` as the applied rate for later
        model_function calls.

        Parameters
        ----------
        t : array-like
            Time array (s)
        gamma_dot : float
            Applied shear rate (1/s)
        return_full : bool, default False
            If True, return dict with all trajectories

        Returns
        -------
        np.ndarray or dict
            Stress array (NaN everywhere if the ODE solve failed) or full
            trajectory dict

        Raises
        ------
        ValueError
            If G_P, G_E or G_D is None, or the solver returns no states.
        """
        G_P = self.G_P
        G_E = self.G_E
        G_D = self.G_D
        if G_P is None or G_E is None or G_D is None:
            raise ValueError("G_P, G_E, G_D must not be None")
        self._gamma_dot_applied = gamma_dot
        t_jax = jnp.asarray(t, dtype=jnp.float64)
        args = self._get_ode_args()
        sol = hvnm_solve_startup(
            t_jax,
            gamma_dot,
            args,
            kinetics=self._kinetics,
            include_damage=self._include_damage,
            include_dissociative=self._include_dissociative,
            include_interfacial_damage=self._include_interfacial_damage,
        )
        ys = sol.ys
        if ys is None:
            raise ValueError("ODE solver returned None for ys")

        stress = self._shear_stress_trajectory(
            ys, G_P, G_E, G_D, args["G_I_eff"], args["X_phi"], args["X_I"]
        )
        # A failed solve invalidates the whole trajectory.
        stress = jnp.where(
            sol.result == diffrax.RESULTS.successful, stress, jnp.nan
        )

        if return_full:
            return {
                "time": np.asarray(t),
                "stress": np.asarray(stress),
                "strain": np.asarray(ys[:, 9]),
                "mu_E_xx": np.asarray(ys[:, 0]),
                "mu_E_yy": np.asarray(ys[:, 1]),
                "mu_E_xy": np.asarray(ys[:, 2]),
                "mu_E_nat_xx": np.asarray(ys[:, 3]),
                "mu_E_nat_yy": np.asarray(ys[:, 4]),
                "mu_E_nat_xy": np.asarray(ys[:, 5]),
                "mu_D_xx": np.asarray(ys[:, 6]),
                "mu_D_yy": np.asarray(ys[:, 7]),
                "mu_D_xy": np.asarray(ys[:, 8]),
                "damage": np.asarray(ys[:, 10]),
                "mu_I_xx": np.asarray(ys[:, 11]),
                "mu_I_yy": np.asarray(ys[:, 12]),
                "mu_I_xy": np.asarray(ys[:, 13]),
                "mu_I_nat_xx": np.asarray(ys[:, 14]),
                "mu_I_nat_yy": np.asarray(ys[:, 15]),
                "mu_I_nat_xy": np.asarray(ys[:, 16]),
                # Always 18-component state; zero when damage not included.
                "damage_int": np.asarray(ys[:, 17]),
            }
        return np.asarray(stress)

    # =========================================================================
    # Stress Relaxation
    # =========================================================================

    def simulate_relaxation(
        self,
        t: np.ndarray,
        gamma_step: float = 1.0,
        return_full: bool = False,
    ) -> np.ndarray | dict[str, np.ndarray]:
        """Simulate stress relaxation after a step strain.

        Parameters
        ----------
        t : array-like
            Time array after step (s)
        gamma_step : float, default 1.0
            Applied step strain (may be negative)
        return_full : bool, default False
            If True, return full trajectory dict

        Returns
        -------
        np.ndarray or dict
            G(t) = sigma(t) / gamma_step (NaN everywhere if the ODE solve
            failed) or trajectory dict

        Raises
        ------
        ValueError
            If G_P, G_E or G_D is None, or the solver returns no states.
        """
        G_P = self.G_P
        G_E = self.G_E
        G_D = self.G_D
        if G_P is None or G_E is None or G_D is None:
            raise ValueError("G_P, G_E, G_D must not be None")
        t_jax = jnp.asarray(t, dtype=jnp.float64)
        args = self._get_ode_args()
        sol = hvnm_solve_relaxation(
            t_jax,
            gamma_step,
            args,
            kinetics=self._kinetics,
            include_damage=self._include_damage,
            include_dissociative=self._include_dissociative,
            include_interfacial_damage=self._include_interfacial_damage,
        )
        ys = sol.ys
        if ys is None:
            raise ValueError("ODE solver returned None for ys")

        stress = self._shear_stress_trajectory(
            ys, G_P, G_E, G_D, args["G_I_eff"], args["X_phi"], args["X_I"]
        )
        # Divide by the signed step so G(t) keeps its sign for negative steps
        # (the previous abs() divisor made G(t) negative there).
        G_t = stress / self._nonzero_denominator(gamma_step)
        G_t = jnp.where(sol.result == diffrax.RESULTS.successful, G_t, jnp.nan)

        if return_full:
            return {
                "time": np.asarray(t),
                "G_t": np.asarray(G_t),
                "stress": np.asarray(stress),
                "mu_E_xy": np.asarray(ys[:, 2]),
                "mu_E_nat_xy": np.asarray(ys[:, 5]),
                "mu_D_xy": np.asarray(ys[:, 8]),
                "mu_I_xy": np.asarray(ys[:, 13]),
                "mu_I_nat_xy": np.asarray(ys[:, 16]),
                "damage": np.asarray(ys[:, 10]),
                "damage_int": np.asarray(ys[:, 17]),
            }
        return np.asarray(G_t)

    # =========================================================================
    # Creep
    # =========================================================================

    def simulate_creep(
        self,
        t: np.ndarray,
        sigma_0: float,
        return_full: bool = False,
    ) -> np.ndarray | dict[str, np.ndarray]:
        """Simulate creep under constant stress.

        Side effect: stores ``sigma_0`` as the applied stress for later
        model_function calls.

        Parameters
        ----------
        t : array-like
            Time array (s)
        sigma_0 : float
            Applied constant stress (Pa, may be negative)
        return_full : bool, default False
            If True, return full trajectory dict

        Returns
        -------
        np.ndarray or dict
            Strain gamma(t) (NaN everywhere if the ODE solve failed) or
            trajectory dict including compliance J(t) = gamma(t) / sigma_0

        Raises
        ------
        ValueError
            If the solver returns no states.
        """
        self._sigma_applied = sigma_0
        t_jax = jnp.asarray(t, dtype=jnp.float64)
        args = self._get_ode_args()
        sol = hvnm_solve_creep(
            t_jax,
            sigma_0,
            args,
            kinetics=self._kinetics,
            include_damage=self._include_damage,
            include_dissociative=self._include_dissociative,
            include_interfacial_damage=self._include_interfacial_damage,
        )
        ys = sol.ys
        if ys is None:
            raise ValueError("ODE solver returned None for ys")

        gamma = jnp.where(
            sol.result == diffrax.RESULTS.successful, ys[:, 9], jnp.nan
        )
        if return_full:
            # Signed divisor keeps J(t) positive for negative applied stress.
            J_t = gamma / self._nonzero_denominator(sigma_0)
            return {
                "time": np.asarray(t),
                "strain": np.asarray(gamma),
                "compliance": np.asarray(J_t),
                "mu_E_xy": np.asarray(ys[:, 2]),
                "mu_E_nat_xy": np.asarray(ys[:, 5]),
                "mu_D_xy": np.asarray(ys[:, 8]),
                "mu_I_xy": np.asarray(ys[:, 13]),
                "mu_I_nat_xy": np.asarray(ys[:, 16]),
                "damage": np.asarray(ys[:, 10]),
                "damage_int": np.asarray(ys[:, 17]),
            }
        return np.asarray(gamma)

    # =========================================================================
    # LAOS
    # =========================================================================

    def simulate_laos(
        self,
        t: np.ndarray,
        gamma_0: float,
        omega: float,
    ) -> dict[str, np.ndarray]:
        """Simulate LAOS (Large Amplitude Oscillatory Shear).

        Side effect: stores ``gamma_0`` and ``omega`` for later
        model_function calls.

        Parameters
        ----------
        t : array-like
            Time array (s)
        gamma_0 : float
            Strain amplitude
        omega : float
            Angular frequency (rad/s)

        Returns
        -------
        dict
            Keys: time, strain, stress, gamma_dot, N1,
            mu_E_xy, mu_E_nat_xy, mu_D_xy, mu_I_xy, mu_I_nat_xy,
            damage, damage_int. stress and N1 are NaN if the solve failed.

        Raises
        ------
        ValueError
            If G_P, G_E or G_D is None, or the solver returns no states.
        """
        G_P = self.G_P
        G_E = self.G_E
        G_D = self.G_D
        if G_P is None or G_E is None or G_D is None:
            raise ValueError("G_P, G_E, G_D must not be None")
        self._gamma_0 = gamma_0
        self._omega_laos = omega
        t_jax = jnp.asarray(t, dtype=jnp.float64)
        args = self._get_ode_args()
        sol = hvnm_solve_laos(
            t_jax,
            gamma_0,
            omega,
            args,
            kinetics=self._kinetics,
            include_damage=self._include_damage,
            include_dissociative=self._include_dissociative,
            include_interfacial_damage=self._include_interfacial_damage,
        )
        ys = sol.ys
        if ys is None:
            raise ValueError("ODE solver returned None for ys")

        # Imposed kinematics (analytical, independent of the solve).
        strain = gamma_0 * jnp.sin(omega * t_jax)
        gamma_dot_arr = gamma_0 * omega * jnp.cos(omega * t_jax)

        X_phi = args["X_phi"]
        X_I = args["X_I"]
        G_I_eff = args["G_I_eff"]
        stress = self._shear_stress_trajectory(
            ys, G_P, G_E, G_D, G_I_eff, X_phi, X_I
        )

        def _n1_at(y):
            return hvnm_total_normal_stress_1(
                y[0],  # mu_E_xx
                y[1],  # mu_E_yy
                y[3],  # mu_E_nat_xx
                y[4],  # mu_E_nat_yy
                y[6],  # mu_D_xx
                y[7],  # mu_D_yy
                y[11],  # mu_I_xx
                y[12],  # mu_I_yy
                y[14],  # mu_I_nat_xx
                y[15],  # mu_I_nat_yy
                G_E,
                G_D,
                G_I_eff,
                X_I,
                y[17],  # D_int
            )

        N1 = jax.vmap(_n1_at)(ys)

        failed = sol.result != diffrax.RESULTS.successful
        stress = jnp.where(failed, jnp.nan, stress)
        N1 = jnp.where(failed, jnp.nan, N1)
        return {
            "time": np.asarray(t),
            "strain": np.asarray(strain),
            "stress": np.asarray(stress),
            "gamma_dot": np.asarray(gamma_dot_arr),
            "N1": np.asarray(N1),
            "mu_E_xy": np.asarray(ys[:, 2]),
            "mu_E_nat_xy": np.asarray(ys[:, 5]),
            "mu_D_xy": np.asarray(ys[:, 8]),
            "mu_I_xy": np.asarray(ys[:, 13]),
            "mu_I_nat_xy": np.asarray(ys[:, 16]),
            "damage": np.asarray(ys[:, 10]),
            "damage_int": np.asarray(ys[:, 17]),
        }

    # =========================================================================
    # Normal Stresses
    # =========================================================================

    def predict_normal_stresses(
        self, gamma_dot: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Predict steady-state normal stress differences.

        At steady state, E and I networks contribute zero normal stress.
        Only the D-network contributes N1:

            N1 = 2 * G_D * (gamma_dot / k_d_D)^2
            N2 = 0

        Parameters
        ----------
        gamma_dot : array-like
            Shear rate array (1/s)

        Returns
        -------
        tuple of (np.ndarray, np.ndarray)
            (N1, N2) arrays (Pa)

        Raises
        ------
        ValueError
            If G_D or k_d_D is None.
        """
        G_D = self.G_D
        k_d_D = self.k_d_D
        if G_D is None:
            raise ValueError("G_D must not be None")
        if k_d_D is None:
            raise ValueError("k_d_D must not be None")
        gamma_dot_jax = jnp.asarray(gamma_dot, dtype=jnp.float64)
        Wi_D = gamma_dot_jax / jnp.maximum(k_d_D, self._DENOM_FLOOR)
        N1 = 2.0 * G_D * Wi_D**2
        N2 = jnp.zeros_like(N1)
        return np.asarray(N1), np.asarray(N2)

    # =========================================================================
    # Payne Effect Parameters
    # =========================================================================

    def get_payne_parameters(self) -> dict[str, float]:
        """Extract Payne effect parameters from the model.

        The Payne effect manifests as a modulus drop with increasing
        strain amplitude, driven by interphase softening.

        Returns
        -------
        dict
            G_0: zero-strain modulus (X_phi*G_P + G_E + G_D + X_I*G_I_eff)
            G_inf: high-strain modulus (X_phi*G_P only)
            gamma_c: approximate critical strain, 1 / max(X_I, 1)

        Raises
        ------
        ValueError
            If G_P, G_E or G_D is None.
        """
        G_P = self.G_P
        G_E = self.G_E
        G_D = self.G_D
        if G_P is None or G_E is None or G_D is None:
            raise ValueError("G_P, G_E, G_D must not be None")
        d = self._get_derived_params(self._get_params_dict())
        G_I_amp = d["G_I_eff"] * d["X_I"]
        G_0 = G_P * d["X_phi"] + G_E + G_D + G_I_amp
        G_inf = G_P * d["X_phi"]  # Only permanent plateau at large strain
        gamma_c = 1.0 / jnp.maximum(d["X_I"], 1.0)  # Critical strain ~ 1/X_I
        return {"G_0": float(G_0), "G_inf": float(G_inf), "gamma_c": float(gamma_c)}

    # =========================================================================
    # Fitting (NLSQ)
    # =========================================================================

    def _fit(self, x, y, **kwargs):
        """Fit model to data using protocol-aware optimization.

        Parameters
        ----------
        x : array-like
            Independent variable
        y : array-like
            Dependent variable (complex G* is kept complex)
        **kwargs
            test_mode, gamma_dot, sigma_applied, gamma_0, omega, plus
            optimizer controls (use_log_residuals, use_jax, method, max_iter).

        Returns
        -------
        self
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        test_mode = self._select_test_mode(kwargs.get("test_mode"))
        self._test_mode = test_mode
        x_jax = jnp.asarray(x, dtype=jnp.float64)
        # Preserve complex dtype for oscillation data (G* = G' + iG'').
        y_arr = np.asarray(y)
        y_dtype = jnp.complex128 if np.iscomplexobj(y_arr) else jnp.float64
        y_jax = jnp.asarray(y_arr, dtype=y_dtype)

        # Remember protocol inputs (None when not supplied) for model_function.
        self._gamma_dot_applied = kwargs.get("gamma_dot")
        self._sigma_applied = kwargs.get("sigma_applied")
        self._gamma_0 = kwargs.get("gamma_0")
        self._omega_laos = kwargs.get("omega")

        fwd_kwargs = {
            k: v for k, v in kwargs.items() if k not in self._FIT_CONTROL_KWARGS
        }

        def model_fn(x_fit, params):
            return self.model_function(x_fit, params, test_mode=test_mode, **fwd_kwargs)

        # Log residuals by default for every flow-curve alias, matching the
        # routing in model_function.
        objective = create_least_squares_objective(
            model_fn,
            x_jax,
            y_jax,
            use_log_residuals=kwargs.get(
                "use_log_residuals", test_mode in self._FLOW_CURVE_MODES
            ),
        )
        # ODE-based protocols use diffrax with custom_vjp, incompatible with
        # NLSQ forward-mode AD. Default to scipy to avoid failed attempt overhead.
        default_method = "scipy" if test_mode in self._ODE_FIT_MODES else "auto"
        result = nlsq_optimize(
            objective,
            self.parameters,
            use_jax=kwargs.get("use_jax", True),
            method=kwargs.get("method", default_method),
            max_iter=kwargs.get("max_iter", 2000),
        )
        self.fitted_ = True
        self._nlsq_result = result
        logger.info(
            "Fitted HVNMLocal: G_P=%.2e, G_E=%.2e, phi=%.3f",
            self.G_P,
            self.G_E,
            self.phi,
        )
        return self

    def _predict(self, X, **kwargs):
        """Predict the response using the current (fitted) parameters.

        Oscillation output is converted from (N, 2) [G', G''] to complex G*.
        """
        test_mode = self._select_test_mode(kwargs.get("test_mode"))
        param_values = jnp.array(
            [self.parameters.get_value(n) for n in self.parameters.keys()],
            dtype=jnp.float64,
        )
        fwd_kwargs = {
            k: v for k, v in kwargs.items() if k not in self._PREDICT_CONTROL_KWARGS
        }
        result = np.asarray(
            self.model_function(X, param_values, test_mode=test_mode, **fwd_kwargs)
        )
        if test_mode == "oscillation" and result.ndim == 2 and result.shape[1] == 2:
            result = result[:, 0] + 1j * result[:, 1]
        return result

    # =========================================================================
    # Model Function (NLSQ / NumPyro)
    # =========================================================================

    def model_function(self, X, params, test_mode=None, **kwargs):
        """NumPyro/BayesianMixin model function for HVNM.

        Routes to the appropriate JAX-traceable prediction based on test_mode.
        Required by BayesianMixin for NumPyro NUTS sampling.

        Startup, relaxation and creep use the linear analytical kernels; LAOS
        uses the ODE solver with damage disabled.

        Parameters
        ----------
        X : array-like
            Independent variable. For LAOS, either a time array or a (2, N)
            stack whose first row is time.
        params : array-like
            Parameter values in ParameterSet order
        test_mode : str, optional
            Override stored test mode
        **kwargs
            Protocol-specific: gamma_dot, sigma_applied, gamma_0, omega.
            Each falls back to the value stored on the instance.

        Returns
        -------
        jnp.ndarray
            Predicted response ((N, 2) [G', G''] for oscillation)

        Raises
        ------
        ValueError
            If a protocol's required input (gamma_dot, sigma_applied,
            gamma_0/omega) is unavailable.
        """
        # Unpack parameters by position.
        p_names = list(self.parameters.keys())
        p_dict = dict(zip(p_names, params, strict=True))
        G_P = p_dict["G_P"]
        G_E = p_dict["G_E"]
        nu_0 = p_dict["nu_0"]
        E_a = p_dict["E_a"]
        V_act = p_dict["V_act"]
        T = p_dict["T"]
        G_D = p_dict.get("G_D", 0.0)
        k_d_D = p_dict.get("k_d_D", 1.0)
        # Interphase TST params (needed explicitly by the LAOS ODE args).
        nu_0_int = p_dict.get("nu_0_int", 1e10)
        E_a_int = p_dict.get("E_a_int", 90e3)
        V_act_int = p_dict.get("V_act_int", 5e-6)

        mode = self._select_test_mode(test_mode)
        X_jax = jnp.asarray(X, dtype=jnp.float64)

        gamma_dot = self._protocol_kwarg(kwargs, "gamma_dot", "_gamma_dot_applied")
        sigma_applied = self._protocol_kwarg(kwargs, "sigma_applied", "_sigma_applied")
        gamma_0 = self._protocol_kwarg(kwargs, "gamma_0", "_gamma_0")
        omega = self._protocol_kwarg(kwargs, "omega", "_omega_laos")

        # Derived quantities via the same (JAX-traceable) kernels as the
        # eager path, so both paths share defaults and delta_g.
        d = self._get_derived_params(p_dict)
        G_I_eff = d["G_I_eff"]
        X_phi = d["X_phi"]
        X_I = d["X_I"]
        k_BER_mat_0 = d["k_BER_mat_0"]
        k_BER_int_0 = d["k_BER_int_0"]

        if mode in self._FLOW_CURVE_MODES:
            return hvnm_steady_shear_stress_vec(X_jax, G_D, k_d_D)

        if mode == "oscillation":
            G_prime, G_double_prime = hvnm_saos_moduli_vec(
                X_jax,
                G_P,
                G_E,
                G_D,
                G_I_eff,
                X_phi,
                X_I,
                k_BER_mat_0,
                k_d_D,
                k_BER_int_0,
                0.0,
                0.0,  # D=0, D_int=0
            )
            return jnp.column_stack([G_prime, G_double_prime])

        if mode == "startup":
            if gamma_dot is None:
                raise ValueError("startup mode requires gamma_dot")
            return hvnm_startup_stress_linear_vec(
                X_jax,
                gamma_dot,
                G_P,
                G_E,
                G_D,
                G_I_eff,
                X_phi,
                X_I,
                k_BER_mat_0,
                k_d_D,
                k_BER_int_0,
                0.0,  # D_int=0
            )

        if mode == "relaxation":
            if self._include_diffusion:
                return hvnm_relaxation_modulus_with_diffusion_vec(
                    X_jax,
                    G_P,
                    G_E,
                    G_D,
                    G_I_eff,
                    X_phi,
                    X_I,
                    k_BER_mat_0,
                    k_d_D,
                    k_BER_int_0,
                    p_dict.get("k_diff_0_mat", 0.0),
                    p_dict.get("k_diff_0_int", 0.0),
                    0.0,
                    0.0,  # D=0, D_int=0
                )
            return hvnm_relaxation_modulus_vec(
                X_jax,
                G_P,
                G_E,
                G_D,
                G_I_eff,
                X_phi,
                X_I,
                k_BER_mat_0,
                k_d_D,
                k_BER_int_0,
                0.0,
                0.0,  # D=0, D_int=0
            )

        if mode == "creep":
            if sigma_applied is None:
                raise ValueError("creep mode requires sigma_applied")
            J = hvnm_creep_compliance_linear_vec(
                X_jax,
                G_P,
                G_E,
                G_D,
                G_I_eff,
                X_phi,
                X_I,
                k_BER_mat_0,
                k_d_D,
                k_BER_int_0,
            )
            return sigma_applied * J

        if mode == "laos":
            if gamma_0 is None or omega is None:
                raise ValueError("LAOS mode requires gamma_0 and omega")
            # Extract time from (2, N) stacked [time, strain] input.
            t_arr = X_jax[0] if X_jax.ndim == 2 else X_jax
            # Damage/healing terms are pinned to inert values because the
            # solve below runs with include_damage/interfacial_damage=False.
            params_dict = {
                "G_P": G_P,
                "G_E": G_E,
                "G_D": G_D,
                "k_d_D": k_d_D,
                "nu_0": nu_0,
                "E_a": E_a,
                "V_act": V_act,
                "T": T,
                "G_I_eff": G_I_eff,
                "X_phi": X_phi,
                "X_I": X_I,
                "nu_0_int": nu_0_int,
                "E_a_int": E_a_int,
                "V_act_int": V_act_int,
                "Gamma_0": 0.0,
                "lambda_crit": 10.0,
                "Gamma_0_int": 0.0,
                "lambda_crit_int": 10.0,
                "h_0": 0.0,
                "E_a_heal": 100e3,
                "n_h": 1.0,
            }
            sol = hvnm_solve_laos(
                t_arr,
                gamma_0,
                omega,
                params_dict,
                kinetics=self._kinetics,
                include_damage=False,
                include_dissociative=self._include_dissociative,
                include_interfacial_damage=False,
            )
            # Mask failed ODE solutions with NaN so Bayesian NaN guard rejects them.
            ys = _mask_failed_solution_ys(sol)
            return self._shear_stress_trajectory(
                ys,
                G_P,
                G_E,
                G_D,
                G_I_eff,
                X_phi,
                X_I,
                with_interfacial_damage=False,
            )

        logger.warning(f"Unknown test_mode '{mode}', defaulting to flow_curve")
        return hvnm_steady_shear_stress_vec(X_jax, G_D, k_d_D)

    # =========================================================================
    # Factory Methods (Limiting Cases)
    # =========================================================================

    @classmethod
    def unfilled_vitrimer(
        cls,
        G_P: float = 1e4,
        G_E: float = 1e4,
        G_D: float = 1e3,
        nu_0: float = 1e10,
        E_a: float = 80e3,
        V_act: float = 1e-5,
        T: float = 300.0,
        k_d_D: float = 1.0,
    ) -> HVNMLocal:
        """Create an unfilled vitrimer (phi=0, recovers HVM exactly).

        Parameters
        ----------
        G_P, G_E, G_D : float
            Subnetwork moduli (Pa)
        nu_0, E_a, V_act, T : float
            TST parameters
        k_d_D : float
            Dissociative rate (1/s)

        Returns
        -------
        HVNMLocal
            Model with phi=0 (no interphase contribution)
        """
        model = cls(include_dissociative=True)
        model.parameters.set_value("G_P", G_P)
        model.parameters.set_value("G_E", G_E)
        model.parameters.set_value("G_D", G_D)
        model.parameters.set_value("nu_0", nu_0)
        model.parameters.set_value("E_a", E_a)
        model.parameters.set_value("V_act", V_act)
        model.parameters.set_value("T", T)
        model.parameters.set_value("k_d_D", k_d_D)
        model.parameters.set_value("phi", 0.0)
        return model

    @classmethod
    def filled_elastomer(
        cls,
        G_P: float = 1e4,
        phi: float = 0.1,
        R_NP: float = 20e-9,
        delta_m: float = 10e-9,
    ) -> HVNMLocal:
        """Create a filled elastomer (no exchange networks).

        Parameters
        ----------
        G_P : float
            Permanent network modulus (Pa)
        phi : float
            NP volume fraction
        R_NP : float
            NP radius (m)
        delta_m : float
            Mobile interphase thickness (m)

        Returns
        -------
        HVNMLocal
            Model with only amplified P-network (no E, D, or active I)
        """
        model = cls(include_dissociative=False)
        model.parameters.set_value("G_P", G_P)
        model.parameters.set_value("G_E", 0.0)
        model.parameters.set_value("phi", phi)
        model.parameters.set_value("R_NP", R_NP)
        model.parameters.set_value("delta_m", delta_m)
        return model

    @classmethod
    def partial_vitrimer_nc(
        cls,
        G_P: float = 1e4,
        G_E: float = 1e4,
        phi: float = 0.1,
        nu_0: float = 1e10,
        E_a: float = 80e3,
        V_act: float = 1e-5,
        T: float = 300.0,
        **nc_kwargs,
    ) -> HVNMLocal:
        """Create a partial vitrimer nanocomposite (no D-network).

        Parameters
        ----------
        G_P, G_E : float
            Network moduli (Pa)
        phi : float
            NP volume fraction
        nu_0, E_a, V_act, T : float
            TST parameters
        **nc_kwargs
            NP geometry / interphase parameters, e.g. R_NP, delta_m, beta_I,
            nu_0_int, E_a_int, V_act_int. Unknown names are ignored with a
            logged warning.

        Returns
        -------
        HVNMLocal
            Model with P + E + I networks (no D)
        """
        model = cls(include_dissociative=False)
        model.parameters.set_value("G_P", G_P)
        model.parameters.set_value("G_E", G_E)
        model.parameters.set_value("phi", phi)
        model.parameters.set_value("nu_0", nu_0)
        model.parameters.set_value("E_a", E_a)
        model.parameters.set_value("V_act", V_act)
        model.parameters.set_value("T", T)
        cls._apply_nc_kwargs(model, nc_kwargs)
        return model

    @classmethod
    def conventional_filled_rubber(
        cls,
        G_P: float = 1e4,
        phi: float = 0.1,
        R_NP: float = 20e-9,
        delta_m: float = 10e-9,
        G_D: float = 1e3,
        k_d_D: float = 1.0,
    ) -> HVNMLocal:
        """Create a conventional filled rubber (no E-network, frozen interphase).

        NOTE(review): G_I_eff is computed from G_E via hvnm_interphase_modulus,
        so with G_E=0 the interphase contribution may vanish entirely rather
        than being "frozen" -- confirm against the kernel.

        Parameters
        ----------
        G_P : float
            Permanent network modulus (Pa)
        phi : float
            NP volume fraction
        R_NP, delta_m : float
            NP geometry (m)
        G_D : float
            Dissociative modulus (Pa)
        k_d_D : float
            Dissociative rate (1/s)

        Returns
        -------
        HVNMLocal
            Model with P + D + frozen I (no exchange)
        """
        model = cls(include_dissociative=True)
        model.parameters.set_value("G_P", G_P)
        model.parameters.set_value("G_E", 0.0)
        model.parameters.set_value("G_D", G_D)
        model.parameters.set_value("k_d_D", k_d_D)
        model.parameters.set_value("phi", phi)
        model.parameters.set_value("R_NP", R_NP)
        model.parameters.set_value("delta_m", delta_m)
        return model

    @classmethod
    def matrix_only_exchange(
        cls,
        G_P: float = 1e4,
        G_E: float = 1e4,
        phi: float = 0.1,
        nu_0: float = 1e10,
        E_a: float = 80e3,
        V_act: float = 1e-5,
        T: float = 300.0,
        **nc_kwargs,
    ) -> HVNMLocal:
        """Create a model with a frozen interphase (k_BER^int ~ 0).

        The interphase acts as a dead (non-exchanging) reinforcement layer.
        The freeze is done by setting E_a_int to 250 kJ/mol; an explicit
        ``E_a_int`` in nc_kwargs is applied afterwards and overrides it.

        Parameters
        ----------
        G_P, G_E : float
            Network moduli (Pa)
        phi : float
            NP volume fraction
        nu_0, E_a, V_act, T : float
            Matrix TST parameters
        **nc_kwargs
            NP geometry, e.g. R_NP, delta_m, beta_I. Unknown names are
            ignored with a logged warning.

        Returns
        -------
        HVNMLocal
            Model with active matrix exchange, frozen interphase
        """
        model = cls(include_dissociative=True)
        model.parameters.set_value("G_P", G_P)
        model.parameters.set_value("G_E", G_E)
        model.parameters.set_value("phi", phi)
        model.parameters.set_value("nu_0", nu_0)
        model.parameters.set_value("E_a", E_a)
        model.parameters.set_value("V_act", V_act)
        model.parameters.set_value("T", T)
        # Freeze interphase: max barrier within bounds -> k_BER^int ~ 0.
        model.parameters.set_value("E_a_int", 250e3)  # Max allowed -> negligible rate
        cls._apply_nc_kwargs(model, nc_kwargs)
        return model