"""Multi-mode Giesekus nonlinear viscoelastic model.
This module implements `GiesekusMultiMode`, an extension of the single-mode
Giesekus model with N parallel relaxation modes.
Multi-Mode Superposition
------------------------
The total stress is the sum of N polymer modes plus a Newtonian solvent::
σ_total = σ_s + Σᵢ σ_p,i
where each mode i has its own parameters (η_p,i, λ_i, α_i).
For SAOS (linear regime)::
G'(ω) = Σᵢ G_i·(ωλ_i)² / (1 + (ωλ_i)²)
G''(ω) = Σᵢ G_i·(ωλ_i) / (1 + (ωλ_i)²) + η_s·ω
where G_i = η_p,i / λ_i.
Example
-------
>>> from rheojax.models.giesekus import GiesekusMultiMode
>>> import numpy as np
>>>
>>> # Create 3-mode model
>>> model = GiesekusMultiMode(n_modes=3)
>>>
>>> # Set mode parameters
>>> model.set_mode_params(0, eta_p=100.0, lambda_1=10.0, alpha=0.3)
>>> model.set_mode_params(1, eta_p=50.0, lambda_1=1.0, alpha=0.2)
>>> model.set_mode_params(2, eta_p=20.0, lambda_1=0.1, alpha=0.1)
>>>
>>> # Predict SAOS
>>> omega = np.logspace(-2, 2, 50)
>>> G_prime, G_double_prime = model.predict_saos(omega)
References
----------
- Giesekus, H. (1982). J. Non-Newtonian Fluid Mech. 11, 69-109.
- Bird, R.B. et al. (1987). Dynamics of Polymeric Liquids, Vol. 1.
"""
from __future__ import annotations
import logging
import numpy as np
from rheojax.core.base import BaseModel
from rheojax.core.inventory import Protocol
from rheojax.core.jax_config import lazy_import, safe_import_jax
diffrax = lazy_import("diffrax")
from rheojax.core.parameters import ParameterSet
from rheojax.core.registry import ModelRegistry
from rheojax.core.test_modes import DeformationMode
from rheojax.models.giesekus._kernels import (
giesekus_multimode_ode_rhs,
giesekus_multimode_saos_moduli,
giesekus_steady_shear_stress_vec,
)
jax, jnp = safe_import_jax()
logger = logging.getLogger(__name__)
[docs]
@ModelRegistry.register(
"giesekus_multi",
protocols=[
Protocol.FLOW_CURVE,
Protocol.OSCILLATION,
Protocol.STARTUP,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
@ModelRegistry.register(
"giesekus_multimode",
protocols=[
Protocol.FLOW_CURVE,
Protocol.OSCILLATION,
Protocol.STARTUP,
],
deformation_modes=[
DeformationMode.SHEAR,
DeformationMode.TENSION,
DeformationMode.BENDING,
DeformationMode.COMPRESSION,
],
)
class GiesekusMultiMode(BaseModel):
"""Multi-mode Giesekus nonlinear viscoelastic model.
This model extends the single-mode Giesekus with N parallel Maxwell
modes, each with its own relaxation time, viscosity, and mobility
factor.
The constitutive equation for each mode is::
τᵢ + λᵢ∇̂τᵢ + (αᵢλᵢ/η_p,i)τᵢ·τᵢ = 2η_p,i D
Total stress: σ = η_s·γ̇ + Σᵢ τᵢ
Parameters
----------
n_modes : int
Number of relaxation modes (N ≥ 1). Default: 3
Attributes
----------
parameters : ParameterSet
Model parameters including per-mode values
fitted_ : bool
Whether the model has been fitted
Notes
-----
The multi-mode model is particularly useful for:
1. Fitting broad SAOS spectra that single-mode cannot capture
2. Representing polydisperse polymer systems
3. Capturing multiple relaxation processes
Each mode can have different α_i values, allowing different
molecular weight fractions to exhibit different anisotropy.
See Also
--------
GiesekusSingleMode : Single relaxation time variant
GeneralizedMaxwell : Linear multi-mode Maxwell model
"""
[docs]
def __init__(self, n_modes: int = 3):
"""Initialize multi-mode Giesekus model.
Parameters
----------
n_modes : int, default 3
Number of relaxation modes (must be ≥ 1)
Raises
------
ValueError
If n_modes < 1
"""
super().__init__()
if n_modes < 1:
raise ValueError(f"n_modes must be ≥ 1, got {n_modes}")
self._n_modes = n_modes
self._test_mode = None
self._setup_parameters()
# Protocol-specific inputs
self._gamma_dot_applied: float | None = None
self._sigma_applied: float | None = None
self._gamma_0: float | None = None
self._omega_laos: float | None = None
# Internal storage
self._trajectory: dict[str, np.ndarray] | None = None
def _setup_parameters(self):
"""Initialize ParameterSet with multi-mode parameters.
Creates parameters:
- eta_s: Shared solvent viscosity
- eta_p_i: Polymer viscosity for mode i
- lambda_i: Relaxation time for mode i
- alpha_i: Mobility factor for mode i
"""
self.parameters = ParameterSet()
# Shared solvent viscosity
self.parameters.add(
name="eta_s",
value=0.0,
bounds=(0.0, 1e4),
units="Pa·s",
description="Solvent viscosity (Newtonian contribution)",
)
# Per-mode parameters with logarithmically spaced defaults
for i in range(self._n_modes):
# Viscosity (decreasing with mode number)
eta_default = 100.0 / (i + 1)
self.parameters.add(
name=f"eta_p_{i}",
value=eta_default,
bounds=(1e-6, 1e6),
units="Pa·s",
description=f"Polymer viscosity for mode {i}",
)
# Relaxation time (logarithmically spaced)
lambda_default = 10.0 ** (1 - i) # 10, 1, 0.1, ...
self.parameters.add(
name=f"lambda_{i}",
value=lambda_default,
bounds=(1e-8, 1e4),
units="s",
description=f"Relaxation time for mode {i}",
)
# Mobility factor (same default for all modes)
self.parameters.add(
name=f"alpha_{i}",
value=0.3,
bounds=(0.0, 0.5),
units="dimensionless",
description=f"Mobility factor for mode {i}",
)
# =========================================================================
# Properties
# =========================================================================
@property
def n_modes(self) -> int:
"""Get number of modes."""
return self._n_modes
@property
def eta_s(self) -> float:
"""Get solvent viscosity η_s (Pa·s)."""
val = self.parameters.get_value("eta_s")
assert val is not None
return float(val)
@property
def eta_0(self) -> float:
"""Get zero-shear viscosity η₀ = η_s + Σ η_p,i (Pa·s)."""
eta_p_total = sum(
float(self.parameters.get_value(f"eta_p_{i}") or 0.0)
for i in range(self._n_modes)
)
return self.eta_s + eta_p_total
[docs]
def get_mode_params(self, mode_idx: int) -> dict[str, float]:
"""Get parameters for a specific mode.
Parameters
----------
mode_idx : int
Mode index (0 to n_modes-1)
Returns
-------
dict[str, float]
Dictionary with keys 'eta_p', 'lambda_1', 'alpha'
"""
if mode_idx < 0 or mode_idx >= self._n_modes:
raise IndexError(f"Mode index {mode_idx} out of range [0, {self._n_modes})")
return {
"eta_p": float(self.parameters.get_value(f"eta_p_{mode_idx}") or 0.0),
"lambda_1": float(self.parameters.get_value(f"lambda_{mode_idx}") or 0.0),
"alpha": float(self.parameters.get_value(f"alpha_{mode_idx}") or 0.0),
}
[docs]
def set_mode_params(
self,
mode_idx: int,
eta_p: float | None = None,
lambda_1: float | None = None,
alpha: float | None = None,
) -> None:
"""Set parameters for a specific mode.
Parameters
----------
mode_idx : int
Mode index (0 to n_modes-1)
eta_p : float, optional
Polymer viscosity (Pa·s)
lambda_1 : float, optional
Relaxation time (s)
alpha : float, optional
Mobility factor (0 ≤ α ≤ 0.5)
"""
if mode_idx < 0 or mode_idx >= self._n_modes:
raise IndexError(f"Mode index {mode_idx} out of range [0, {self._n_modes})")
if eta_p is not None:
self.parameters.set_value(f"eta_p_{mode_idx}", eta_p)
if lambda_1 is not None:
self.parameters.set_value(f"lambda_{mode_idx}", lambda_1)
if alpha is not None:
self.parameters.set_value(f"alpha_{mode_idx}", alpha)
[docs]
def get_mode_arrays(self) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Get all mode parameters as JAX arrays.
Uses vectorized extraction via get_values() + slicing for ~3x speedup
over N individual get_value() calls.
Returns
-------
tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
(eta_p_modes, lambda_modes, alpha_modes), each shape (n_modes,)
"""
# Get all parameter values at once (single dict lookup traversal)
all_values = self.parameters.get_values() # shape: (1 + 3*n_modes,)
# Parameter layout: [eta_s, eta_p_0, lambda_0, alpha_0, eta_p_1, ...]
# Extract mode arrays using NumPy slicing (faster than list comprehension)
# eta_p: indices 1, 4, 7, ... (stride 3 starting from 1)
# lambda: indices 2, 5, 8, ... (stride 3 starting from 2)
# alpha: indices 3, 6, 9, ... (stride 3 starting from 3)
eta_p = jnp.asarray(all_values[1::3][: self._n_modes], dtype=jnp.float64)
lambda_vals = jnp.asarray(all_values[2::3][: self._n_modes], dtype=jnp.float64)
alpha = jnp.asarray(all_values[3::3][: self._n_modes], dtype=jnp.float64)
return eta_p, lambda_vals, alpha
# =========================================================================
# Core Interface Methods
# =========================================================================
def _fit(self, x, y, **kwargs):
"""Fit model to data.
Parameters
----------
x : array-like
Independent variable
y : array-like
Dependent variable
**kwargs
Additional arguments including test_mode
Returns
-------
self
"""
from rheojax.utils.optimization import (
create_least_squares_objective,
make_fd_differentiable,
nlsq_optimize,
)
_kw_mode = kwargs.get("test_mode")
test_mode = (
_kw_mode
if _kw_mode is not None
else (
getattr(self, "_test_mode", None)
if getattr(self, "_test_mode", None) is not None
else "oscillation"
)
)
self._test_mode = test_mode
self._gamma_dot_applied = kwargs.get("gamma_dot")
self._sigma_applied = kwargs.get("sigma_applied")
x_jax = jnp.asarray(x, dtype=jnp.float64)
# Handle complex G* for oscillation: split into [G', G''] columns
# to match model_function's (N, 2) real output format
if test_mode == "oscillation" and np.iscomplexobj(y):
y_real = np.asarray(y)
y_jax = jnp.column_stack(
[
jnp.asarray(y_real.real, dtype=jnp.float64),
jnp.asarray(y_real.imag, dtype=jnp.float64),
]
)
else:
y_jax = jnp.asarray(y, dtype=jnp.float64)
# Define model function for fitting (follows ParameterSet ordering)
def model_fn(x_fit, params):
"""Stateless model function for optimization."""
return self.model_function(x_fit, params, test_mode=test_mode)
# ODE-based protocols use diffrax (custom_vjp), incompatible with
# NLSQ's jacfwd. On GPU, wrap with FD-JVP (parallel perturbations).
# On CPU, vmap doesn't parallelize — scipy sequential FD is faster.
_ode_protocols = {"relaxation", "startup", "creep", "laos"}
_is_ode = test_mode in _ode_protocols
_on_gpu = jax.default_backend() != "cpu"
_fit_fn = make_fd_differentiable(model_fn) if _is_ode and _on_gpu else model_fn
# Create objective and optimize using ParameterSet
objective = create_least_squares_objective(
_fit_fn,
x_jax,
y_jax,
use_log_residuals=kwargs.get(
"use_log_residuals", test_mode == "flow_curve"
),
)
# On CPU, ODE protocols use scipy (sequential FD is faster than vmap).
# On GPU, the FD-JVP wrapper makes NLSQ work → use 'auto'.
_method = kwargs.get("method", "auto")
if _is_ode and not _on_gpu and _method in ("auto", "nlsq", "trf", "lm"):
_method = "scipy"
result = nlsq_optimize(
objective,
self.parameters,
use_jax=kwargs.get("use_jax", True),
method=_method,
max_iter=kwargs.get("max_iter", 2000),
)
self.fitted_ = True
self._nlsq_result = result
logger.info(f"Fitted {self._n_modes}-mode Giesekus: η₀={self.eta_0:.2e} Pa·s")
return self
def _predict(self, x, **kwargs):
"""Predict response.
Parameters
----------
x : array-like
Independent variable
**kwargs
Additional arguments including test_mode
Returns
-------
jnp.ndarray
Predicted response
"""
_kw_mode = kwargs.get("test_mode")
test_mode = (
_kw_mode
if _kw_mode is not None
else (
getattr(self, "_test_mode", None)
if getattr(self, "_test_mode", None) is not None
else "oscillation"
)
)
x_jax = jnp.asarray(x, dtype=jnp.float64)
param_names = list(self.parameters.keys())
params = jnp.array(
[self.parameters.get_value(n) for n in param_names], dtype=jnp.float64
)
# Forward kwargs (gamma_dot, sigma_applied, etc.) to model_function
predict_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ("test_mode", "deformation_mode", "poisson_ratio")
}
result = self.model_function(
x_jax, params, test_mode=test_mode, **predict_kwargs
)
# model_function returns (N,2) real [G', G''] for oscillation;
# convert to complex G* to match the established convention
if test_mode == "oscillation" and result.ndim == 2 and result.shape[1] == 2:
result = result[:, 0] + 1j * result[:, 1]
return result
[docs]
def model_function(self, X, params, test_mode=None, **kwargs):
"""NumPyro/BayesianMixin model function.
Parameters
----------
X : array-like
Independent variable
params : array-like
All parameter values in order
test_mode : str, optional
Override stored test mode
**kwargs : dict
Protocol-specific arguments (gamma_dot, sigma_applied, etc.)
Returns
-------
jnp.ndarray
Predicted response
"""
mode = (
test_mode
if test_mode is not None
else (
getattr(self, "_test_mode", None)
if getattr(self, "_test_mode", None) is not None
else "oscillation"
)
)
X_jax = jnp.asarray(X, dtype=jnp.float64)
# Parse parameters - interleaved order:
# [eta_s, eta_p_0, lambda_0, alpha_0, eta_p_1, lambda_1, alpha_1, ...]
eta_s = params[0]
# Use stride-3 slicing matching _setup_parameters() order
eta_p_modes = params[1::3][: self._n_modes]
lambda_modes = params[2::3][: self._n_modes]
alpha_modes = params[3::3][: self._n_modes]
if mode == "oscillation":
G_prime, G_double_prime = self._predict_saos_internal(
X_jax, eta_p_modes, lambda_modes, eta_s
)
# Return components for fitting to [G', G''] data
return jnp.column_stack([G_prime, G_double_prime])
elif mode in ["flow_curve", "steady_shear", "rotation"]:
return self._predict_flow_curve_internal(
X_jax, eta_p_modes, lambda_modes, alpha_modes, eta_s
)
elif mode == "startup":
# Get gamma_dot from kwargs or instance attribute
# Use `if in kwargs` to avoid swallowing valid 0.0 values
gamma_dot = (
kwargs["gamma_dot"]
if "gamma_dot" in kwargs
else self._gamma_dot_applied
)
if gamma_dot is None:
raise ValueError("startup mode requires gamma_dot")
return self._simulate_startup_internal(
X_jax, eta_p_modes, lambda_modes, alpha_modes, gamma_dot
)
else:
logger.warning(f"Unknown test_mode '{mode}', using oscillation")
G_prime, G_double_prime = self._predict_saos_internal(
X_jax, eta_p_modes, lambda_modes, eta_s
)
return jnp.column_stack([G_prime, G_double_prime])
# =========================================================================
# Analytical Predictions
# =========================================================================
def _predict_saos_internal(
self,
omega: jnp.ndarray,
eta_p_modes: jnp.ndarray,
lambda_modes: jnp.ndarray,
eta_s: float,
) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Internal SAOS prediction."""
# Vectorize over frequency
def saos_at_omega(w):
return giesekus_multimode_saos_moduli(w, eta_p_modes, lambda_modes, eta_s)
G_prime, G_double_prime = jax.vmap(saos_at_omega)(omega)
return G_prime, G_double_prime
[docs]
def predict_saos(
self,
omega: np.ndarray,
return_components: bool = True,
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
"""Predict SAOS storage and loss moduli.
Parameters
----------
omega : np.ndarray
Angular frequency array (rad/s)
return_components : bool, default True
If True, return (G', G'')
Returns
-------
tuple or np.ndarray
(G', G'') if return_components=True, else |G*|
"""
omega_jax = jnp.asarray(omega, dtype=jnp.float64)
eta_p_modes, lambda_modes, _ = self.get_mode_arrays()
G_prime, G_double_prime = self._predict_saos_internal(
omega_jax, eta_p_modes, lambda_modes, self.eta_s
)
if return_components:
return np.asarray(G_prime), np.asarray(G_double_prime)
G_star_mag = jnp.sqrt(jnp.maximum(G_prime**2 + G_double_prime**2, 1e-30))
return np.asarray(G_star_mag)
def _predict_flow_curve_internal(
self,
gamma_dot: jnp.ndarray,
eta_p_modes: jnp.ndarray,
lambda_modes: jnp.ndarray,
alpha_modes: jnp.ndarray,
eta_s: float,
) -> jnp.ndarray:
"""Internal flow curve prediction (steady shear).
For multi-mode Giesekus, we sum the contributions from each mode.
This is approximate for nonlinear superposition.
"""
# Sum contributions from each mode
# Note: This is exact only in linear regime; nonlinear coupling is neglected
def mode_contribution(i):
eta_p = eta_p_modes[i]
lambda_1 = lambda_modes[i]
alpha = alpha_modes[i]
return giesekus_steady_shear_stress_vec(
gamma_dot, eta_p, lambda_1, alpha, 0.0
)
stress_contributions = jax.vmap(mode_contribution)(jnp.arange(len(eta_p_modes)))
polymer_stress = jnp.sum(stress_contributions, axis=0)
# Add solvent contribution
total_stress = polymer_stress + eta_s * gamma_dot
return total_stress
[docs]
def predict_flow_curve(
self,
gamma_dot: np.ndarray,
return_components: bool = False,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""Predict steady shear stress.
Parameters
----------
gamma_dot : np.ndarray
Shear rate array (1/s)
return_components : bool, default False
If True, return (sigma, eta)
Returns
-------
np.ndarray or tuple
Shear stress σ (Pa), or (σ, η) if return_components=True
"""
gd = jnp.asarray(gamma_dot, dtype=jnp.float64)
eta_p_modes, lambda_modes, alpha_modes = self.get_mode_arrays()
sigma = self._predict_flow_curve_internal(
gd, eta_p_modes, lambda_modes, alpha_modes, self.eta_s
)
if return_components:
eta = sigma / jnp.maximum(gd, 1e-20)
return np.asarray(sigma), np.asarray(eta)
return np.asarray(sigma)
# =========================================================================
# ODE-Based Simulations
# =========================================================================
def _simulate_startup_internal(
self,
t: jnp.ndarray,
eta_p_modes: jnp.ndarray,
lambda_modes: jnp.ndarray,
alpha_modes: jnp.ndarray,
gamma_dot: float,
) -> jnp.ndarray:
"""Internal startup simulation."""
# State: [τ_xx^0, τ_yy^0, τ_xy^0, τ_zz^0, ..., τ_xx^N-1, ...]
# Total size: 4 * n_modes
y0 = jnp.zeros(4 * self._n_modes, dtype=jnp.float64)
def ode_fn(ti, yi, args):
return giesekus_multimode_ode_rhs(
ti,
yi,
args["gamma_dot"],
args["eta_p"],
args["lambda"],
args["alpha"],
)
args = {
"gamma_dot": gamma_dot,
"eta_p": eta_p_modes,
"lambda": lambda_modes,
"alpha": alpha_modes,
}
term = diffrax.ODETerm(jax.checkpoint(ode_fn))
solver = diffrax.Tsit5()
stepsize_controller = diffrax.PIDController(rtol=1e-6, atol=1e-8)
t0 = t[0]
t1 = t[-1]
dt0 = (t1 - t0) / max(len(t), 1000)
saveat = diffrax.SaveAt(ts=t)
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
args=args,
saveat=saveat,
stepsize_controller=stepsize_controller,
max_steps=100_000,
throw=False,
)
# Sum τ_xy from all modes (index 2 in each mode's 4-element block)
# Use vectorized index selection instead of a Python loop
mode_xy_indices = jnp.arange(self._n_modes) * 4 + 2 # [2, 6, 10, ...]
tau_xy_total = jnp.sum(sol.ys[:, mode_xy_indices], axis=1)
# Add solvent contribution
total_stress = tau_xy_total + self.eta_s * gamma_dot
# Handle solver failures
total_stress = jnp.where(
sol.result == diffrax.RESULTS.successful,
total_stress,
jnp.nan * jnp.ones_like(total_stress),
)
return total_stress
[docs]
def simulate_startup(
self,
t: np.ndarray,
gamma_dot: float,
return_full: bool = False,
) -> np.ndarray | dict[str, np.ndarray]:
"""Simulate startup flow at constant shear rate.
Parameters
----------
t : np.ndarray
Time array (s)
gamma_dot : float
Applied shear rate (1/s)
return_full : bool, default False
If True, return per-mode stresses
Returns
-------
np.ndarray or dict
Total shear stress, or dict with per-mode stresses
"""
t_jax = jnp.asarray(t, dtype=jnp.float64)
eta_p_modes, lambda_modes, alpha_modes = self.get_mode_arrays()
# State: 4 * n_modes
y0 = jnp.zeros(4 * self._n_modes, dtype=jnp.float64)
def ode_fn(ti, yi, args):
return giesekus_multimode_ode_rhs(
ti,
yi,
args["gamma_dot"],
args["eta_p"],
args["lambda"],
args["alpha"],
)
args = {
"gamma_dot": gamma_dot,
"eta_p": eta_p_modes,
"lambda": lambda_modes,
"alpha": alpha_modes,
}
term = diffrax.ODETerm(jax.checkpoint(ode_fn))
solver = diffrax.Tsit5()
stepsize_controller = diffrax.PIDController(rtol=1e-6, atol=1e-8)
t0 = t_jax[0]
t1 = t_jax[-1]
dt0 = (t1 - t0) / max(len(t), 1000)
saveat = diffrax.SaveAt(ts=t_jax)
sol = diffrax.diffeqsolve(
term,
solver,
t0,
t1,
dt0,
y0,
args=args,
saveat=saveat,
stepsize_controller=stepsize_controller,
max_steps=100_000,
throw=False,
)
if return_full:
result = {"t": np.asarray(t_jax)}
tau_xy_total = np.zeros(len(t))
for i in range(self._n_modes):
tau_xy_i = np.asarray(sol.ys[:, 4 * i + 2])
result[f"tau_xy_{i}"] = tau_xy_i
tau_xy_total += tau_xy_i
tau_xy_total_final = tau_xy_total + self.eta_s * gamma_dot
# Handle solver failures
tau_xy_total_final = np.where(
sol.result == diffrax.RESULTS.successful,
tau_xy_total_final,
np.nan * np.ones_like(tau_xy_total_final),
)
result["tau_xy_total"] = tau_xy_total_final
return result
# Sum τ_xy from all modes (vectorized over mode index)
mode_xy_indices = jnp.arange(self._n_modes) * 4 + 2 # [2, 6, 10, ...]
tau_xy_total = np.asarray(jnp.sum(sol.ys[:, mode_xy_indices], axis=1))
total_stress = tau_xy_total + self.eta_s * gamma_dot
# Handle solver failures
total_stress = np.where(
sol.result == diffrax.RESULTS.successful,
total_stress,
np.nan * np.ones_like(total_stress),
)
return total_stress
# =========================================================================
# Analysis Methods
# =========================================================================
[docs]
def get_relaxation_spectrum(self) -> tuple[np.ndarray, np.ndarray]:
"""Get discrete relaxation spectrum.
Returns
-------
tuple[np.ndarray, np.ndarray]
(lambda_i, G_i) where G_i = η_p,i / λ_i
"""
eta_p_modes, lambda_modes, _ = self.get_mode_arrays()
G_modes = eta_p_modes / lambda_modes
# Sort by relaxation time (descending)
sort_idx = jnp.argsort(lambda_modes)[::-1]
return np.asarray(lambda_modes[sort_idx]), np.asarray(G_modes[sort_idx])
[docs]
def get_continuous_spectrum(
self,
t: np.ndarray | None = None,
n_points: int = 200,
) -> tuple[np.ndarray, np.ndarray]:
"""Get continuous relaxation modulus G(t).
Parameters
----------
t : np.ndarray, optional
Time array
n_points : int, default 200
Number of points if t not provided
Returns
-------
tuple[np.ndarray, np.ndarray]
(t, G(t))
"""
eta_p_modes, lambda_modes, _ = self.get_mode_arrays()
if t is None:
lambda_min = float(jnp.min(lambda_modes))
lambda_max = float(jnp.max(lambda_modes))
t = np.logspace(
np.log10(0.01 * lambda_min),
np.log10(100 * lambda_max),
n_points,
)
t_jax = jnp.asarray(t, dtype=jnp.float64)
# G(t) = Σ G_i exp(-t/λ_i)
G_modes = eta_p_modes / lambda_modes
def G_at_t(t_val):
return jnp.sum(G_modes * jnp.exp(-t_val / lambda_modes))
G_t = jax.vmap(G_at_t)(t_jax)
return np.asarray(t), np.asarray(G_t)
# =========================================================================
# String Representation
# =========================================================================
[docs]
def __repr__(self) -> str:
"""Return string representation."""
return (
f"{self.__class__.__name__}("
f"n_modes={self._n_modes}, "
f"η₀={self.eta_0:.2e} Pa·s, "
f"η_s={self.eta_s:.2e} Pa·s)"
)