Numerical Implementation
This page covers computational strategies for DMTA data fitting in RheoJAX, including JAX/JIT considerations, parameter bounds, convergence criteria, and memory management.
JIT Compilation Strategy
All RheoJAX models use JAX’s @jit decorator, so fitting runs as compiled code (GPU-accelerated when a GPU is available).
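The compile-once, reuse-afterwards behaviour, and the shape-triggered recompilation listed below, can be observed with plain JAX, independent of RheoJAX. In this minimal sketch a Python side effect inside a jitted function runs only while JAX traces (compiles) it, so the hypothetical counter trace_count records how many compilations occur; sum_sq_residuals is an illustrative stand-in for a model residual, not RheoJAX code.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only during tracing


@jax.jit
def sum_sq_residuals(pred, data):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return jnp.sum((pred - data) ** 2)


# Same shape twice: traced once, the second call reuses the cached executable
sum_sq_residuals(jnp.ones(206), jnp.zeros(206))
sum_sq_residuals(jnp.ones(206), jnp.zeros(206))

# New shape (481 points): triggers a second trace and compilation
sum_sq_residuals(jnp.ones(481), jnp.zeros(481))

print(trace_count)  # 2
```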
DMTA-specific considerations:
First fit is slow (~5–30 s): JIT compiles the model function, residual computation, and Jacobian. Subsequent fits with the same model and data shape reuse the compiled code (< 1 s).
Shape changes trigger recompilation: fitting a 206-point and a 481-point dataset back-to-back compiles twice. Group datasets by size when possible.
Float64 is mandatory: use safe_import_jax() (not import jax) to ensure 64-bit precision. 32-bit arithmetic causes convergence failures with the wide dynamic range of DMTA data (0.1–10 000 MPa).

```python
# Use this instead of `import jax` so that 64-bit precision is enabled
from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()
```
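To see why 32-bit arithmetic struggles with this dynamic range, consider a sum of squares that mixes glassy-scale terms (~10^10 Pa) with rubbery-scale terms (~10^5 Pa). In this illustrative NumPy sketch the magnitudes are round placeholder numbers, not measured data, and it does not show how RheoJAX actually weights or scales its residuals: in float32 the rubbery-scale squared term is smaller than half a unit in the last place of the glassy-scale one and is lost, while float64 keeps it.

```python
import numpy as np


def small_term_survives(dtype):
    """Return True if a rubbery-scale squared term still changes a glassy-scale sum."""
    glassy = dtype(1e10)   # ~10 GPa, glassy-plateau magnitude (illustrative)
    rubbery = dtype(1e5)   # ~0.1 MPa, rubbery-plateau magnitude (illustrative)
    big = glassy * glassy      # ~1e20 Pa^2
    small = rubbery * rubbery  # ~1e10 Pa^2
    return bool(big + small != big)


print(small_term_survives(np.float32))  # False: the rubbery term is lost
print(small_term_survives(np.float64))  # True: the rubbery term is kept
```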
Parameter Bounds for Tensile Data
DMTA data is in tensile modulus space (Pa), which is typically 2–3 \(\times\) larger than shear modulus space because \(E = 2(1+\nu)G\), with Poisson's ratio \(\nu\) between 0 and 0.5. Real polymer DMTA data spans:
| Material state | \(E'\) range | \(G'\) range | Scale factor |
|---|---|---|---|
| Glassy plateau | 1–10 GPa | 0.4–3.7 GPa | \(2(1+\nu) \approx 2.7\) |
| Rubbery plateau | 0.1–10 MPa | 0.03–3.3 MPa | \(2(1+\nu) = 3.0\) |
| Glass transition | 1 MPa – 5 GPa | 0.3 MPa – 1.9 GPa | varies with \(\nu(\omega)\) |
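The \(G'\) column follows from dividing \(E'\) by the scale factor \(2(1+\nu)\). A minimal sketch, assuming a constant Poisson's ratio per row (\(\nu = 0.35\) gives the glassy factor of about 2.7, \(\nu = 0.5\) the rubbery factor of 3.0); tensile_to_shear is an illustrative helper, not a RheoJAX function.

```python
def tensile_to_shear(E, nu):
    """Convert a tensile modulus E (Pa) to a shear modulus G (Pa) via E = 2(1 + nu) G."""
    return E / (2.0 * (1.0 + nu))


# Glassy plateau, nu ≈ 0.35 (scale factor 2.7): about 0.37 GPa and 3.7 GPa
print(tensile_to_shear(1e9, 0.35), tensile_to_shear(10e9, 0.35))
# Rubbery plateau, nu = 0.5 (scale factor 3.0): about 0.033 MPa and 3.3 MPa
print(tensile_to_shear(0.1e6, 0.5), tensile_to_shear(10e6, 0.5))
```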
Default bounds handling:
GeneralizedMaxwell(modulus_type='tensile') automatically uses wider bounds: \(E_i \in [10^{-3}, 10^{12}]\) Pa (vs. \(G_i \in [10^{-3}, 10^{9}]\) Pa for shear).
Other models (Zener, FZSS, etc.) with deformation_mode='tension' convert E* → G* at the fit() boundary, so their internal bounds (in G-space) are sufficient.
When bounds errors occur:
If a model raises ValueError: Value ... violates constraints, the fitted
value falls outside the parameter bounds. Fix this by widening the bounds and
the matching bounds constraint:

```python
# Widen the upper bound of a specific parameter (here G0) to 1e12 Pa
param = model.parameters["G0"]
param.bounds = (param.bounds[0], 1e12)

# Keep the parameter's "bounds" constraint in sync with the new bounds
for c in param.constraints:
    if c.type == "bounds":
        c.min_value, c.max_value = param.bounds
```
Element Minimisation and Mode Reduction
The GeneralizedMaxwell model supports automatic mode reduction via
optimization_factor. This creates internal sub-models with default
bounds for each candidate mode count.
Warning
When fitting real DMTA data with modulus_type='tensile', set
optimization_factor=None to disable element minimisation, whose internal
sub-models use the default bounds rather than the wider tensile bounds:

```python
# Tensile GMM fit with element minimisation disabled
gmm = GeneralizedMaxwell(n_modes=10, modulus_type='tensile')
gmm.fit(omega, E_star, test_mode='oscillation',
        optimization_factor=None)
```
Alternatively, use modulus_type='shear' with
deformation_mode='tension' to fit in G-space where default bounds
and element minimisation work correctly.
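Fitting in G-space means working with \(G^* = E^*/[2(1+\nu)]\). The sketch below shows that conversion on a complex modulus array under the simplifying assumption of a constant, real Poisson's ratio (\(\nu = 0.5\) here). It is not RheoJAX's internal conversion, which the bounds table notes can involve a frequency-dependent \(\nu(\omega)\); the helper names and the E* values are hypothetical placeholders.

```python
import numpy as np


def tensile_to_shear_complex(E_star, nu=0.5):
    """E* -> G* assuming a constant real Poisson's ratio nu (illustrative only)."""
    return np.asarray(E_star, dtype=complex) / (2.0 * (1.0 + nu))


def shear_to_tensile_complex(G_star, nu=0.5):
    """G* -> E*, the inverse of tensile_to_shear_complex under the same assumption."""
    return np.asarray(G_star, dtype=complex) * (2.0 * (1.0 + nu))


# Synthetic E* = E' + iE'' values in Pa (placeholders, not measured data)
E_star = np.array([3.0e6 + 0.3e6j, 3.0e8 + 1.2e8j, 3.0e9 + 0.1e9j])
G_star = tensile_to_shear_complex(E_star)

print(G_star.real)  # storage modulus G' (Pa)
print(G_star.imag)  # loss modulus G'' (Pa)
# A real scale factor leaves the loss tangent E''/E' = G''/G' unchanged
print(np.allclose(E_star.imag / E_star.real, G_star.imag / G_star.real))  # True
```

Because the scale factor is real here, tan δ is identical in E-space and G-space; only the magnitudes move into the range covered by the shear bounds.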
Convergence Criteria
NLSQ convergence for DMTA data:
| Parameter | Default | DMTA recommendation |
|---|---|---|
| Maximum iterations | 200 | 500–1000 (broad master curves need more iterations) |
|  | \(10^{-8}\) | \(10^{-8}\) (sufficient) |
|  | \(10^{-8}\) | \(10^{-8}\) (sufficient) |
| Number of Prony modes | 3 | 10–30 (match decades of data) |
Rule of thumb: use approximately 1 Prony mode per 3 decades of
frequency data. For a master curve spanning 20 decades, n_modes=7 is
a minimum; n_modes=15–20 gives excellent fits.
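The rule of thumb can be turned into a quick check on a frequency array. This is a minimal sketch; suggest_n_modes is a hypothetical helper, not part of RheoJAX, and because it only counts decades its output is the lower bound the rule describes, not the recommended count.

```python
import math

import numpy as np


def suggest_n_modes(omega, decades_per_mode=3.0):
    """Minimum Prony mode count: about 1 mode per `decades_per_mode` decades of omega."""
    omega = np.asarray(omega, dtype=float)
    decades = math.log10(omega.max() / omega.min())
    return max(1, math.ceil(decades / decades_per_mode))


omega = np.logspace(-10, 10, 481)  # 20-decade master curve (rad/s)
print(suggest_n_modes(omega))  # 7
```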
Bayesian Inference (NUTS) Settings
NUTS sampling for DMTA data follows the standard NLSQ → NUTS pipeline (see Workflow 3 in DMTA Workflows for a complete example):
| Setting | FAST_MODE | Production | Notes |
|---|---|---|---|
| Warmup iterations | 50 | 200–1000 | More warmup for multi-modal posteriors |
| Posterior samples | 100 | 500–2000 | Check ESS > 400 |
| Chains | 1 | 4 | Multi-chain for R-hat diagnostics |
| Target acceptance probability | 0.8 | 0.8–0.95 | Increase if divergences > 0 |
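The R-hat check requires more than one chain. For reference, this is a minimal NumPy sketch of the split-R-hat (Gelman–Rubin) statistic for a single parameter; split_rhat is a hypothetical helper, not a RheoJAX function, and in practice the diagnostics reported by the sampler or by ArviZ are preferable. Values close to 1 indicate that the chains agree; values well above 1 indicate chains that have not mixed.

```python
import numpy as np


def split_rhat(samples):
    """Split-R-hat for one parameter; samples has shape (n_chains, n_draws)."""
    samples = np.asarray(samples, dtype=float)
    half = samples.shape[1] // 2
    # Split each chain into two halves to also detect within-chain drift
    chains = np.concatenate([samples[:, :half], samples[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    W = chains.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)    # between-chain variance
    var_hat = (n - 1) / n * W + B / n          # pooled posterior variance estimate
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))            # 4 chains sampling the same target
stuck = mixed + np.arange(4)[:, None]         # each chain offset: not mixed
print(split_rhat(mixed))  # close to 1
print(split_rhat(stuck))  # well above 1
```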
Memory Management
Sequential DMTA model fits can exhaust memory (especially on 16 GB machines). Follow this pattern between fits:
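```python
import gc

from rheojax.core.jax_config import safe_import_jax

jax, jnp = safe_import_jax()

# Fit model 1
model1.fit(omega, E_star, test_mode='oscillation',
           deformation_mode='tension')
E_pred1 = model1.predict(omega, test_mode='oscillation')

# Clean up before the next fit
del model1
gc.collect()        # free unreferenced Python objects
jax.clear_caches()  # drop cached compiled (JIT) functions

# Fit model 2
model2.fit(omega, E_star, ...)
```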
For notebooks, also use plt.close('all') instead of plt.show() to
prevent figure accumulation in headless (CI) environments.
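When many models are compared in one session, the same clean-up can be wrapped in a loop. A minimal sketch under stated assumptions: fit_fns is a hypothetical mapping from a label to a callable that creates, fits, and predicts with one model and returns its predictions; predictions are copied into NumPy arrays so that no model object outlives its iteration. It imports jax directly only to call jax.clear_caches(); in RheoJAX code, obtain jax through safe_import_jax() as above.

```python
import gc

import jax  # in RheoJAX code, obtain jax via safe_import_jax() instead
import numpy as np


def run_sequential_fits(fit_fns, omega, E_star):
    """Run each fit in turn, keeping only host-side predictions between fits.

    fit_fns: dict mapping a label to a hypothetical callable
             (omega, E_star) -> predictions that creates and fits its own model.
    """
    results = {}
    for label, fit_fn in fit_fns.items():
        pred = fit_fn(omega, E_star)
        results[label] = np.array(pred)  # force a host-side copy of the predictions
        del pred
        gc.collect()        # release the model and buffers created inside fit_fn
        jax.clear_caches()  # drop compiled functions before the next model
    return results


# Usage with stand-in callables in place of real model fits
omega = np.logspace(-2, 2, 5)
E_star = omega * (1 + 0.1j)
results = run_sequential_fits(
    {"double": lambda w, E: 2 * E, "identity": lambda w, E: E}, omega, E_star
)
```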
FAST_MODE Guidelines
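All DMTA example notebooks support FAST_MODE, which is enabled by default (including in CI) unless the FAST_MODE environment variable is set to something other than 1: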
```python
import os

# FAST_MODE is on unless the environment sets FAST_MODE to a value other than '1'
FAST_MODE = os.environ.get('FAST_MODE', '1') == '1'
```
| Feature | FAST_MODE | Full mode |
|---|---|---|
| GMM modes |  |  |
| NUTS samples | 50 warmup + 100 samples | 200–1000 warmup + 500–2000 samples |
| FZSS/extra models | Skip | Include |
| Cross-domain validation | Skip or reduced | Full (requires \(n \geq 15\)) |
| Data subsampling | 200 points max | Full dataset |
Set FAST_MODE=0 for publication-quality results.
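The FAST_MODE switch and the 200-point subsampling can be wired up as below. This is a sketch only: the settings keys (nuts_warmup, nuts_samples, max_points) and the subsample helper are hypothetical names, the full-mode values are picked from within the table's ranges, and evenly spaced index selection is one simple subsampling choice that keeps both ends of a log-spaced master curve.

```python
import os

import numpy as np

FAST_MODE = os.environ.get('FAST_MODE', '1') == '1'

# Hypothetical settings dictionary with values chosen from the FAST_MODE table
SETTINGS = (
    {"nuts_warmup": 50, "nuts_samples": 100, "max_points": 200}
    if FAST_MODE
    else {"nuts_warmup": 1000, "nuts_samples": 2000, "max_points": None}
)


def subsample(omega, E_star, max_points):
    """Keep at most max_points points, evenly spaced in index (endpoints included)."""
    n = len(omega)
    if max_points is None or n <= max_points:
        return omega, E_star
    idx = np.unique(np.linspace(0, n - 1, max_points).round().astype(int))
    return omega[idx], E_star[idx]


omega = np.logspace(-3, 3, 481)                   # placeholder frequency grid (rad/s)
E_star = 1e6 * (1 + 1j) * np.ones_like(omega)     # placeholder E* values (Pa)
omega_fit, E_fit = subsample(omega, E_star, SETTINGS["max_points"])
```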
See also
DMTA Workflows — complete examples using the settings above
DMTA Model Selection & Applicability — model selection guide (complexity vs. expressiveness)
DMTA Theory & Conversion — E* ↔ G* conversion and bounds rationale