Source code for rheojax.models.vlb.multi_network

"""Multi-network VLB (Vernerey-Long-Brighenti) transient network model.

This module implements `VLBMultiNetwork`, a constitutive model for polymers
with N distinct populations of dynamic crosslinks, each with its own
modulus G_i and dissociation rate k_d_i, plus an optional permanent network
and solvent viscosity.

Key Physics
-----------
Multi-network VLB describes heterogeneous networks where:

- N independent transient populations with moduli G_i and rates k_d_i
- Each population has its own distribution tensor mu_i evolving independently
- Optional permanent network (k_d = 0, modulus G_e) for cross-linked gels
- Optional Newtonian solvent (viscosity eta_s)

Total stress:

    sigma = sum_i G_i * (mu_i - I) + G_e * (F F^T - I) + 2 * eta_s * D

This represents a superposition of Maxwell elements with molecular basis:
each mode corresponds to a physical population of crosslinks, not a
mathematical decomposition.

Supported Protocols
-------------------
- FLOW_CURVE: Newtonian (constant k_d), analytical superposition
- OSCILLATION: Multi-mode Maxwell G'(omega), G''(omega) (analytical)
- STARTUP: Analytical superposition of transient terms
- RELAXATION: Prony series G(t) = G_e + sum G_i * exp(-k_d_i * t)
- CREEP: ODE for multi-mode; analytical for 1 mode + permanent network without solvent (SLS)
- LAOS: Multi-mode ODE integration via diffrax

Example
-------
>>> from rheojax.models.vlb import VLBMultiNetwork
>>> import numpy as np
>>>
>>> # Two transient networks + permanent
>>> model = VLBMultiNetwork(n_modes=2, include_permanent=True)
>>>
>>> # SAOS (multi-mode Maxwell)
>>> omega = np.logspace(-2, 2, 50)
>>> G_star = model.predict(omega, test_mode='oscillation')

References
----------
- Vernerey, F.J., Long, R. & Brighenti, R. (2017). JMPS 107, 1-20.
"""

from __future__ import annotations

import logging

import numpy as np

from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax

diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.vlb._base import VLBBase
from rheojax.models.vlb._kernels import (
    vlb_creep_compliance_dual_vec,
    vlb_multi_relaxation_vec,
    vlb_multi_saos_vec,
    vlb_multi_startup_stress_vec,
    vlb_multi_steady_viscosity,
)

# JAX module handles obtained through the project's importer; ``jnp`` is used
# for all array math below.
# NOTE(review): presumably this also enables float64 support, which the
# ``jnp.float64`` casts in this module rely on — confirm in
# rheojax.core.jax_config.
jax, jnp = safe_import_jax()

logger = logging.getLogger(__name__)
# Sentinel that distinguishes "keyword not supplied" from an explicit falsy
# value (e.g. ``gamma_dot=0.0``) when resolving protocol inputs.
_MISSING = object()


@ModelRegistry.register(
    "vlb_multi_network",
    protocols=[
        Protocol.FLOW_CURVE,
        Protocol.OSCILLATION,
        Protocol.STARTUP,
        Protocol.RELAXATION,
        Protocol.CREEP,
        Protocol.LAOS,
    ],
    deformation_modes=[
        DeformationMode.SHEAR,
        DeformationMode.TENSION,
        DeformationMode.BENDING,
        DeformationMode.COMPRESSION,
    ],
)
class VLBMultiNetwork(VLBBase):
    """Multi-network VLB model: N transient + optional permanent + solvent.

    Implements a network with N independent transient crosslink populations,
    each with modulus G_i and dissociation rate k_d_i. The total stress is a
    superposition of N Maxwell modes, plus an optional permanent network
    (modulus G_e) and a Newtonian solvent (viscosity eta_s).

    Parameters
    ----------
    n_modes : int, default 2
        Number of distinct transient network populations (N >= 1)
    include_permanent : bool, default False
        Whether to include a permanent (elastic) network (G_e)

    Attributes
    ----------
    parameters : ParameterSet
        Model parameters: [G_0, k_d_0, G_1, k_d_1, ..., eta_s, (G_e)]
    fitted_ : bool
        Whether the model has been fitted
    _n_modes : int
        Number of transient modes

    Notes
    -----
    Parameter ordering:
    [G_0, k_d_0, G_1, k_d_1, ..., G_{N-1}, k_d_{N-1}, eta_s, (G_e)]

    Total parameter count: 2N + 1 (without permanent) or 2N + 2 (with permanent)

    See Also
    --------
    VLBLocal : Single transient network (2 parameters)
    """

    # (keyword, instance attribute) pairs for protocol-specific inputs. They
    # are remembered by fit/predict and used by ``model_function`` as
    # fallbacks when the keyword is not passed explicitly.
    _PROTOCOL_INPUTS = (
        ("gamma_dot", "_gamma_dot_applied"),
        ("sigma_applied", "_sigma_applied"),
        ("gamma_0", "_gamma_0"),
        ("omega", "_omega_laos"),
    )

    def __init__(
        self,
        n_modes: int = 2,
        include_permanent: bool = False,
    ):
        """Initialize multi-network VLB model.

        Parameters
        ----------
        n_modes : int, default 2
            Number of transient network populations (must be >= 1)
        include_permanent : bool, default False
            Include permanent elastic network

        Raises
        ------
        ValueError
            If ``n_modes < 1``.
        """
        if n_modes < 1:
            raise ValueError(f"n_modes must be >= 1, got {n_modes}")
        self._n_modes = n_modes
        self._include_permanent = include_permanent
        super().__init__()
        self._setup_parameters()
        self._test_mode = None

    # =========================================================================
    # Parameter Setup
    # =========================================================================

    def _setup_parameters(self):
        """Initialize ParameterSet with multi-network parameters.

        Parameters are organized as:
        [G_0, k_d_0, G_1, k_d_1, ..., G_{N-1}, k_d_{N-1}, eta_s, (G_e)]

        Default values:
        - G_i = 1000.0 / N (equal distribution), floored at the 1 Pa lower bound
        - k_d_i = 10^(i) (1.0, 10.0, 100.0, ...), capped at the 1e6 1/s bound
        - eta_s = 0.0 (no solvent viscosity)
        - G_e = 0.0 (no permanent network, if included)
        """
        n = self._n_modes
        self.parameters = ParameterSet()
        for i in range(n):
            # Mode modulus; the floor keeps the default inside the bounds
            # when n_modes > 1000.
            self.parameters.add(
                name=f"G_{i}",
                value=max(1000.0 / n, 1e0),
                bounds=(1e0, 1e8),
                units="Pa",
                description=f"Network modulus for transient population {i}",
            )
            # Mode dissociation rate (logarithmically spaced); 10**i would
            # exceed the 1e6 upper bound from mode 7 onward, so cap it.
            self.parameters.add(
                name=f"k_d_{i}",
                value=min(10.0**i, 1e6),
                bounds=(1e-6, 1e6),
                units="1/s",
                description=f"Dissociation rate for population {i}",
            )

        # Solvent viscosity
        self.parameters.add(
            name="eta_s",
            value=0.0,
            bounds=(0.0, 1e4),
            units="Pa·s",
            description="Solvent viscosity (Newtonian background)",
        )

        # Optional permanent network
        if self._include_permanent:
            self.parameters.add(
                name="G_e",
                value=0.0,
                bounds=(0.0, 1e8),
                units="Pa",
                description="Permanent (elastic) network modulus",
            )

    # =========================================================================
    # Properties
    # =========================================================================

    @property
    def n_modes(self) -> int:
        """Number of transient network modes."""
        return self._n_modes

    @property
    def include_permanent(self) -> bool:
        """Whether a permanent network is included."""
        return self._include_permanent

    @property
    def G_e(self) -> float:
        """Permanent network modulus (Pa). 0 if not included."""
        if not self._include_permanent:
            return 0.0
        val = self.parameters.get_value("G_e")
        return float(val) if val is not None else 0.0

    @property
    def eta_s(self) -> float:
        """Solvent viscosity (Pa*s)."""
        val = self.parameters.get_value("eta_s")
        return float(val) if val is not None else 0.0

    @property
    def G_total(self) -> float:
        """Total modulus: sum G_i + G_e."""
        G, _ = self._get_mode_arrays_numpy()
        return float(np.sum(G)) + self.G_e

    @property
    def eta_0(self) -> float:
        """Zero-shear viscosity: sum G_i/k_d_i + eta_s."""
        G, kd = self._get_mode_arrays_numpy()
        return float(np.sum(G / kd)) + self.eta_s

    # =========================================================================
    # Mode Array Helpers
    # =========================================================================

    def _get_mode_arrays_numpy(self) -> tuple[np.ndarray, np.ndarray]:
        """Get mode arrays as numpy arrays.

        Returns
        -------
        tuple of (np.ndarray, np.ndarray)
            (G_modes, kd_modes) each shape (N,); parameters whose value is
            None are skipped.
        """
        G_arr = np.array(
            [
                float(v)
                for i in range(self._n_modes)
                if (v := self.parameters.get_value(f"G_{i}")) is not None
            ]
        )
        kd_arr = np.array(
            [
                float(v)
                for i in range(self._n_modes)
                if (v := self.parameters.get_value(f"k_d_{i}")) is not None
            ]
        )
        return G_arr, kd_arr

    def _unpack_mode_params(
        self, params: jnp.ndarray
    ) -> tuple[jnp.ndarray, jnp.ndarray, float, float]:
        """Unpack mode parameters from a flat JAX array.

        Parameters
        ----------
        params : jnp.ndarray
            Flat parameter array [G_0, k_d_0, G_1, k_d_1, ..., eta_s, (G_e)]

        Returns
        -------
        tuple
            (G_modes, kd_modes, eta_s, G_e) where G_modes and kd_modes
            are shape (N,) arrays; G_e is 0.0 without a permanent network.
        """
        N = self._n_modes
        G_modes = params[0 : 2 * N : 2]  # G_0, G_1, ...
        kd_modes = params[1 : 2 * N : 2]  # k_d_0, k_d_1, ...
        eta_s = params[2 * N]
        G_e = params[2 * N + 1] if self._include_permanent else 0.0
        return G_modes, kd_modes, eta_s, G_e

    def _select_test_mode(self, test_mode=None) -> str:
        """Return the explicit test_mode, else the stored one, else 'flow_curve'."""
        if test_mode is not None:
            return test_mode
        stored = getattr(self, "_test_mode", None)
        return stored if stored is not None else "flow_curve"

    def _lookup_protocol_input(self, kwargs: dict, key: str, attr: str):
        """Return ``kwargs[key]`` if supplied (even if falsy), else ``self.attr``.

        Uses the module sentinel so explicit values such as ``gamma_dot=0.0``
        are honored instead of falling back to the stored attribute.
        """
        value = kwargs.get(key, _MISSING)
        return value if value is not _MISSING else getattr(self, attr, None)

    @staticmethod
    def _split_saos_target(y_np: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Split SAOS data into (G', G'').

        Accepts complex G* = G' + iG'' or a real (N, 2) array of [G', G'']
        columns; any other real array is treated as G' with zero G''.
        """
        if not np.iscomplexobj(y_np) and y_np.ndim == 2 and y_np.shape[1] == 2:
            return y_np[:, 0], y_np[:, 1]
        return np.real(y_np), np.imag(y_np)

    # =========================================================================
    # Fitting
    # =========================================================================

    def _fit(self, x, y, **kwargs):
        """Fit model to data using protocol-aware optimization.

        Parameters
        ----------
        x : array-like
            Independent variable
        y : array-like
            Dependent variable. For oscillation, complex G* or a real
            (N, 2) array of [G', G''].
        **kwargs
            Additional arguments including test_mode, protocol inputs
            (gamma_dot, sigma_applied, gamma_0, omega) and optimizer options
            (use_log_residuals, use_jax, method, max_iter).

        Returns
        -------
        self
        """
        from rheojax.utils.optimization import (
            create_least_squares_objective,
            nlsq_optimize,
        )

        test_mode = self._select_test_mode(kwargs.get("test_mode"))
        self._test_mode = test_mode

        x_jax = jnp.asarray(x, dtype=jnp.float64)
        # Preserve complex dtype for oscillation data (G* = G' + iG'').
        # create_least_squares_objective handles complex y_data + (N,2) y_pred
        # by fitting G' and G'' independently (stacked residuals).
        y_np = np.asarray(y)
        if np.iscomplexobj(y_np):
            y_jax = jnp.asarray(y_np, dtype=jnp.complex128)
        else:
            y_jax = jnp.asarray(y_np, dtype=jnp.float64)

        # Store protocol-specific inputs (None when not supplied)
        for key, attr in self._PROTOCOL_INPUTS:
            setattr(self, attr, kwargs.get(key))

        # Smart initialization from the SAOS data
        if test_mode == "oscillation":
            G_p_data, G_pp_data = self._split_saos_target(y_np)
            self.initialize_from_saos(np.asarray(x), G_p_data, G_pp_data)

        # Forward only model-relevant kwargs (optimizer options and test_mode
        # are consumed here and would otherwise be passed twice).
        optimizer_keys = {
            "test_mode",
            "use_log_residuals",
            "use_jax",
            "method",
            "max_iter",
            "use_multi_start",
            "n_starts",
            "perturb_factor",
        }
        fwd_kwargs = {k: v for k, v in kwargs.items() if k not in optimizer_keys}

        def model_fn(x_fit, params):
            return self.model_function(x_fit, params, test_mode=test_mode, **fwd_kwargs)

        objective = create_least_squares_objective(
            model_fn,
            x_jax,
            y_jax,
            use_log_residuals=kwargs.get(
                "use_log_residuals", test_mode == "flow_curve"
            ),
        )

        # ODE-based protocols use diffrax with custom_vjp, incompatible with
        # NLSQ forward-mode AD. Default to scipy to avoid failed attempt overhead.
        default_method = "scipy" if test_mode in {"creep", "laos"} else "auto"
        result = nlsq_optimize(
            objective,
            self.parameters,
            use_jax=kwargs.get("use_jax", True),
            method=kwargs.get("method", default_method),
            max_iter=kwargs.get("max_iter", 2000),
        )

        self.fitted_ = True
        self._nlsq_result = result

        logger.info(
            "Fitted VLBMultiNetwork (%d modes): eta_0=%.2e",
            self._n_modes,
            self.eta_0,
        )
        return self

    # =========================================================================
    # Prediction
    # =========================================================================

    def _predict(self, x, **kwargs):
        """Predict response using protocol-aware dispatch.

        Parameters
        ----------
        x : array-like
            Independent variable
        **kwargs
            Additional arguments; supplied protocol inputs are stored for
            later calls.

        Returns
        -------
        jnp.ndarray
            Prediction; complex G* for oscillation.
        """
        test_mode = self._select_test_mode(kwargs.get("test_mode"))
        x_jax = jnp.asarray(x, dtype=jnp.float64)

        for key, attr in self._PROTOCOL_INPUTS:
            if key in kwargs:
                setattr(self, attr, kwargs[key])

        params = jnp.array(
            [float(self.parameters.get_value(name)) for name in self.parameters.keys()]
        )

        # Drop keys already handled here or not understood by model_function
        fwd_kwargs = {
            k: v
            for k, v in kwargs.items()
            if k not in ("test_mode", "deformation_mode", "poisson_ratio")
        }
        result = self.model_function(x_jax, params, test_mode=test_mode, **fwd_kwargs)

        # model_function returns (N,2) [G', G''] for oscillation;
        # convert to complex G* for consistent API
        if test_mode == "oscillation" and result.ndim == 2 and result.shape[1] == 2:
            result = result[:, 0] + 1j * result[:, 1]

        return result

    # =========================================================================
    # Model Function (NLSQ / NumPyro)
    # =========================================================================

    def model_function(self, X, params, test_mode=None, **kwargs):
        """NumPyro/BayesianMixin model function.

        Routes to appropriate prediction based on test_mode.

        Parameters
        ----------
        X : array-like
            Independent variable
        params : array-like
            Parameter values: [G_0, k_d_0, ..., G_{N-1}, k_d_{N-1}, eta_s, (G_e)]
        test_mode : str, optional
            Override stored test mode
        **kwargs
            Protocol-specific parameters (gamma_dot, sigma_applied, gamma_0,
            omega); stored values are used when omitted.

        Returns
        -------
        jnp.ndarray
            Predicted response ((N, 2) [G', G''] for oscillation)

        Raises
        ------
        ValueError
            If a protocol's required input is missing.
        """
        G_modes, kd_modes, eta_s, G_e = self._unpack_mode_params(params)
        mode = self._select_test_mode(test_mode)

        gamma_dot = self._lookup_protocol_input(
            kwargs, "gamma_dot", "_gamma_dot_applied"
        )
        sigma_applied = self._lookup_protocol_input(
            kwargs, "sigma_applied", "_sigma_applied"
        )
        gamma_0 = self._lookup_protocol_input(kwargs, "gamma_0", "_gamma_0")
        omega = self._lookup_protocol_input(kwargs, "omega", "_omega_laos")

        X_jax = jnp.asarray(X, dtype=jnp.float64)

        if mode in ("flow_curve", "steady_shear", "rotation"):
            return self._predict_flow_curve_internal(X_jax, G_modes, kd_modes, eta_s)

        if mode == "oscillation":
            G_prime, G_double_prime = vlb_multi_saos_vec(
                X_jax, G_modes, kd_modes, G_e, eta_s
            )
            return jnp.column_stack([G_prime, G_double_prime])

        if mode == "startup":
            if gamma_dot is None:
                raise ValueError("startup mode requires gamma_dot")
            return vlb_multi_startup_stress_vec(
                X_jax, gamma_dot, G_modes, kd_modes, G_e, eta_s
            )

        if mode == "relaxation":
            return vlb_multi_relaxation_vec(X_jax, G_modes, kd_modes, G_e)

        if mode == "creep":
            if sigma_applied is None:
                raise ValueError("creep mode requires sigma_applied")
            return self._simulate_creep_internal(
                X_jax, G_modes, kd_modes, eta_s, G_e, sigma_applied
            )

        if mode == "laos":
            if gamma_0 is None or omega is None:
                raise ValueError("LAOS mode requires gamma_0 and omega")
            _, stress = self._simulate_laos_internal(
                X_jax, G_modes, kd_modes, eta_s, gamma_0, omega, G_e=G_e
            )
            return stress

        logger.warning("Unknown test_mode '%s', defaulting to flow_curve", mode)
        return self._predict_flow_curve_internal(X_jax, G_modes, kd_modes, eta_s)

    # =========================================================================
    # Analytical Predictions
    # =========================================================================

    def _predict_flow_curve_internal(
        self,
        gamma_dot: jnp.ndarray,
        G_modes: jnp.ndarray,
        kd_modes: jnp.ndarray,
        eta_s: float,
    ) -> jnp.ndarray:
        """Analytical steady shear stress for multi-network VLB.

        sigma = (sum G_i / k_d_i + eta_s) * gamma_dot = eta_0 * gamma_dot

        Newtonian for constant k_d (no shear thinning).
        """
        eta_0 = vlb_multi_steady_viscosity(G_modes, kd_modes, eta_s)
        return eta_0 * gamma_dot

    # =========================================================================
    # Creep Internal (ODE-based for general case)
    # =========================================================================

    def _simulate_creep_internal(
        self,
        t: jnp.ndarray,
        G_modes: jnp.ndarray,
        kd_modes: jnp.ndarray,
        eta_s: float,
        G_e: float,
        sigma_0: float,
    ) -> jnp.ndarray:
        """Internal creep simulation under a constant applied shear stress.

        For N=1 + permanent network without solvent (SLS), uses the analytical
        compliance when parameter values are concrete. Otherwise integrates

            d(mu_xy_i)/dt = gamma_dot - k_d_i * mu_xy_i

        (exact in simple shear, where mu_yy stays 1), with gamma_dot fixed by
        the stress balance

            sigma_0 = sum_i G_i * mu_xy_i + G_e * gamma + eta_s * gamma_dot

        - eta_s > 0: gamma_dot = (sigma_0 - elastic stress) / eta_s, starting
          from rest, because the parallel solvent dashpot forbids a strain jump
          under finite stress.
        - eta_s ~ 0: differentiating the balance in time gives
          gamma_dot = sum_i G_i k_d_i mu_xy_i / (sum_i G_i + G_e), starting from
          the instantaneous elastic jump sigma_0 / (sum_i G_i + G_e). This
          avoids the stiff limit (sigma_0 - elastic) / eta_s with eta_s -> 0.

        Returns strain array gamma(t); all entries are NaN if the solver fails.
        """
        N = self._n_modes
        # Solvent viscosities (Pa*s) at or below this are treated as zero.
        eta_tol = 1e-10

        # Analytical SLS shortcut. The decision needs concrete values; under
        # JAX tracing (jit/grad/vmap) the comparisons cannot become Python
        # bools, so fall through to the ODE, which solves the same problem.
        try:
            use_sls = N == 1 and bool(G_e > 1e-30) and not bool(eta_s > eta_tol)
        except jax.errors.ConcretizationTypeError:
            use_sls = False
        if use_sls:
            J = vlb_creep_compliance_dual_vec(t, G_modes[0], kd_modes[0], G_e)
            return sigma_0 * J

        G_total = jnp.maximum(jnp.sum(G_modes) + G_e, 1e-30)
        # Trace-safe switch between the two stress-balance formulations.
        has_solvent = jnp.asarray(eta_s) > eta_tol

        # State: [gamma, mu_xy_0, mu_xy_1, ..., mu_xy_{N-1}]
        def ode_fn(ti, yi, args):
            gamma = yi[0]
            mu_xy = yi[1:]
            G_i = args["G_modes"]
            k_i = args["kd_modes"]
            elastic = jnp.sum(G_i * mu_xy) + args["G_e"] * gamma
            # eta_safe keeps the unused branch finite (and its gradient too).
            gdot_viscous = (args["sigma_0"] - elastic) / args["eta_safe"]
            gdot_elastic = jnp.sum(G_i * k_i * mu_xy) / args["G_total"]
            gdot = jnp.where(args["has_solvent"], gdot_viscous, gdot_elastic)
            dmu_xy = -k_i * mu_xy + gdot
            return jnp.concatenate([jnp.array([gdot]), dmu_xy])

        # Initial conditions: rest with solvent, affine elastic jump without.
        gamma_init = jnp.where(has_solvent, 0.0, sigma_0 / G_total)
        y0 = jnp.concatenate([jnp.array([gamma_init]), gamma_init * jnp.ones(N)])

        args = {
            "G_modes": G_modes,
            "kd_modes": kd_modes,
            "G_e": G_e,
            "G_total": G_total,
            "eta_safe": jnp.maximum(eta_s, eta_tol),
            "has_solvent": has_solvent,
            "sigma_0": sigma_0,
        }

        # Wrap with checkpoint to reduce VJP memory during NUTS reverse-mode AD
        term = diffrax.ODETerm(jax.checkpoint(ode_fn))
        solver = diffrax.Tsit5()
        controller = diffrax.PIDController(rtol=1e-6, atol=1e-8)

        t0 = t[0]
        t1 = t[-1]
        dt0 = (t1 - t0) / max(len(t), 1000)

        saveat = diffrax.SaveAt(ts=t)

        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=args,
            saveat=saveat,
            stepsize_controller=controller,
            max_steps=500_000,
            throw=False,
        )

        gamma = sol.ys[:, 0]
        return jnp.where(
            sol.result == diffrax.RESULTS.successful,
            gamma,
            jnp.nan * jnp.ones_like(gamma),
        )

    # =========================================================================
    # LAOS Internal (ODE-based)
    # =========================================================================

    def _simulate_laos_internal(
        self,
        t: jnp.ndarray,
        G_modes: jnp.ndarray,
        kd_modes: jnp.ndarray,
        eta_s: float,
        gamma_0: float,
        omega: float,
        G_e: float = 0.0,
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Internal LAOS simulation for multi-network.

        Imposed strain gamma(t) = gamma_0 * sin(omega * t); the networks start
        at equilibrium at t[0] (consistent with the strain when t[0] = 0).

        State vector: [mu_xy_0, ..., mu_xy_{N-1}, mu_xx_0, ..., mu_xx_{N-1}]
        (mu_yy stays at 1 for each mode since yy decouples)

        Shear stress = sum_i G_i mu_xy_i + G_e * gamma + eta_s * gamma_dot,
        where ``G_e`` is the permanent-network modulus (default 0.0).

        Returns (strain, stress) arrays; stress is NaN if the solver fails.
        """
        N = self._n_modes

        def ode_fn(ti, yi, args):
            mu_xy = yi[:N]
            mu_xx = yi[N:]
            w = args["omega"]
            gdot = args["gamma_0"] * w * jnp.cos(w * ti)
            # Upper-convected evolution; mu_yy = 1 so d(mu_xy)/dt gains gdot
            dmu_xy = -args["kd_modes"] * mu_xy + gdot
            dmu_xx = args["kd_modes"] * (1.0 - mu_xx) + 2.0 * gdot * mu_xy
            return jnp.concatenate([dmu_xy, dmu_xx])

        args = {
            "gamma_0": gamma_0,
            "omega": omega,
            "kd_modes": kd_modes,
        }

        # Initial state: equilibrium (mu_xy = 0, mu_xx = 1)
        y0 = jnp.concatenate([jnp.zeros(N), jnp.ones(N)])

        # Wrap with checkpoint to reduce VJP memory during NUTS reverse-mode AD
        term = diffrax.ODETerm(jax.checkpoint(ode_fn))
        solver = diffrax.Tsit5()
        controller = diffrax.PIDController(rtol=1e-6, atol=1e-8)

        t0 = t[0]
        t1 = t[-1]
        dt0 = (t1 - t0) / max(len(t), 1000)

        saveat = diffrax.SaveAt(ts=t)

        sol = diffrax.diffeqsolve(
            term,
            solver,
            t0,
            t1,
            dt0,
            y0,
            args=args,
            saveat=saveat,
            stepsize_controller=controller,
            max_steps=500_000,
            throw=False,
        )

        mu_xy_all = sol.ys[:, :N]  # shape (T, N)
        gamma_dot_t = gamma_0 * omega * jnp.cos(omega * t)
        strain = gamma_0 * jnp.sin(omega * t)
        stress = (
            jnp.sum(G_modes[None, :] * mu_xy_all, axis=1)
            + G_e * strain  # permanent network: G_e * (B - I)_xy = G_e * gamma
            + eta_s * gamma_dot_t
        )

        stress = jnp.where(
            sol.result == diffrax.RESULTS.successful,
            stress,
            jnp.nan * jnp.ones_like(stress),
        )

        return strain, stress

    # =========================================================================
    # Public Methods
    # =========================================================================

    def predict_saos(
        self,
        omega: np.ndarray,
        return_components: bool = True,
    ) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
        """Predict multi-mode SAOS moduli.

        Parameters
        ----------
        omega : np.ndarray
            Angular frequency (rad/s)
        return_components : bool, default True
            If True, return (G', G'')

        Returns
        -------
        tuple or np.ndarray
            (G', G'') or |G*|
        """
        w = jnp.asarray(omega, dtype=jnp.float64)
        G_np, kd_np = self._get_mode_arrays_numpy()
        G_modes = jnp.asarray(G_np, dtype=jnp.float64)
        kd_modes = jnp.asarray(kd_np, dtype=jnp.float64)

        G_p, G_pp = vlb_multi_saos_vec(w, G_modes, kd_modes, self.G_e, self.eta_s)

        if return_components:
            return np.asarray(G_p), np.asarray(G_pp)
        return np.asarray(jnp.sqrt(G_p**2 + G_pp**2 + 1e-30))

    def get_relaxation_spectrum(self) -> list[tuple[float, float]]:
        """Get relaxation spectrum as list of (G, tau) pairs.

        Returns
        -------
        list[tuple[float, float]]
            [(G_i, 1/k_d_i)] for each transient mode
        """
        G_np, kd_np = self._get_mode_arrays_numpy()
        return [(float(G), 1.0 / float(kd)) for G, kd in zip(G_np, kd_np)]

    def __repr__(self) -> str:
        """Return string representation."""
        perm_str = "+permanent" if self._include_permanent else ""
        return (
            f"VLBMultiNetwork(n_modes={self._n_modes}{perm_str}, "
            f"eta_0={self.eta_0:.2e})"
        )