Utilities (rheojax.utils)
The utils module provides numerical utilities for rheological analysis, including special functions, optimization tools, fit quality metrics, device detection, and modulus conversion for DMTA support.
Mittag-Leffler Functions
JAX-compatible Mittag-Leffler function implementations.
This module provides efficient, JAX-compatible implementations of the Mittag-Leffler function using a hybrid strategy:
Taylor series for small arguments (|z| ≤ 8)
Asymptotic expansions for large arguments (|z| > 8):
  exponential expansion for positive z (creep-mode growth)
  inverse power-law expansion for negative z (relaxation-mode decay)
This approach avoids the numerical instability of Padé approximations near alpha=beta and correctly captures the exponential growth for positive arguments.
References
Garrappa, R. (2015). Numerical evaluation of two and three parameter Mittag-Leffler functions. SIAM Journal on Numerical Analysis, 53(3), 1350-1369.
Haubold, H. J., Mathai, A. M., & Saxena, R. K. (2011). Mittag-Leffler functions and their applications. Journal of Applied Mathematics, 2011.
- rheojax.utils.mittag_leffler.mittag_leffler_e(z, alpha)
One-parameter Mittag-Leffler function E_α(z).
E_α(z) = E_{α,1}(z)
- rheojax.utils.mittag_leffler.mittag_leffler_e2(z, alpha, beta)
Two-parameter Mittag-Leffler function E_{α,β}(z).
Uses a hybrid evaluation strategy:
|z| <= 8: Taylor Series (Kahan summation)
z > 8: Positive Asymptotic Expansion (Exponential growth)
z < -8: Negative Asymptotic Expansion (Algebraic decay)
Smooth blending at boundaries for gradient stability.
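As a rough illustration of the two large-argument branches listed above, the sketch below evaluates the standard asymptotic expansions of E_{α,β}(z) for real z and 0 < α < 1, using NumPy/SciPy rather than JAX. The name ml_asymptotic_sketch and the fixed n_terms truncation are hypothetical; the library's own term counts and the smooth blending between branches are not reproduced here.

```python
import numpy as np
from scipy.special import erfc, erfcx, rgamma


def ml_asymptotic_sketch(z, alpha, beta=1.0, n_terms=10):
    """Hypothetical large-|z| evaluation of E_{alpha,beta}(z) for real z, 0 < alpha < 1.

    z > 0: (1/alpha) * z**((1 - beta)/alpha) * exp(z**(1/alpha)) plus the algebraic tail.
    z < 0: algebraic tail only, -sum_{k=1..n_terms} z**(-k) / Gamma(beta - alpha*k).
    rgamma(x) = 1/Gamma(x) is 0 at the poles of Gamma, so those terms vanish cleanly.
    """
    z = float(z)
    k = np.arange(1, n_terms + 1)
    algebraic = -np.sum(z ** (-k) * rgamma(beta - alpha * k))
    if z > 0:
        growth = (1.0 / alpha) * z ** ((1.0 - beta) / alpha) * np.exp(z ** (1.0 / alpha))
        return growth + algebraic
    return algebraic


# alpha = 1/2, beta = 1 has closed forms: E(z) = exp(z**2) * erfc(-z), E(-x) = erfcx(x)
print(ml_asymptotic_sketch(-10.0, 0.5), erfcx(10.0))
print(ml_asymptotic_sketch(9.0, 0.5) / (np.exp(81.0) * erfc(-9.0)))
```

For α = 1/2 the closed forms E_{1/2}(z) = exp(z²) erfc(-z) and E_{1/2}(-x) = erfcx(x) give an independent check: the first line prints two nearly equal values and the second prints a ratio of about 1.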
The Mittag-Leffler function is essential for fractional calculus in rheology. This module provides JAX-compatible implementations with high accuracy.
Aliases
- rheojax.utils.mittag_leffler.ml_e
Alias for mittag_leffler_e()
- rheojax.utils.mittag_leffler.ml_e2
Alias for mittag_leffler_e2()
Mathematical Background
One-Parameter Function
The one-parameter Mittag-Leffler function is defined as:
\[E_\alpha(z) = \sum_{k=0}^{\infty} \frac{z^k}{\Gamma(\alpha k + 1)}\]
where \(\Gamma\) is the gamma function and \(0 < \alpha \leq 2\).
Special cases:
\(\alpha = 1\): \(E_1(z) = e^z\) (exponential function)
\(\alpha = 2\): \(E_2(z^2) = \cosh(z)\) (hyperbolic cosine)
Two-Parameter Function
The two-parameter generalization:
\[E_{\alpha,\beta}(z) = \sum_{k=0}^{\infty} \frac{z^k}{\Gamma(\alpha k + \beta)}\]
Special cases:
\(\beta = 1\): \(E_{\alpha,1}(z) = E_\alpha(z)\) (one-parameter form)
\(\alpha = \beta = 1\): \(E_{1,1}(z) = e^z\) (exponential)
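The definitions and special cases above can be checked numerically with a truncated series. This is a minimal NumPy/SciPy sketch (the name ml_taylor_sketch and the 100-term cutoff are assumptions), not the library's Kahan-summed JAX implementation. For large negative z its terms grow and cancel catastrophically, which is the reason the library switches to asymptotic expansions beyond |z| = 8.

```python
import numpy as np
from scipy.special import erfc, rgamma


def ml_taylor_sketch(z, alpha, beta=1.0, n_terms=100):
    """Truncated series E_{alpha,beta}(z) ~ sum_{k<n_terms} z**k / Gamma(alpha*k + beta).

    Plain summation for illustration only; accurate for moderate |z|. For large
    negative z the terms become huge and cancel, losing all precision.
    """
    k = np.arange(n_terms)
    return float(np.sum(np.float64(z) ** k * rgamma(alpha * k + beta)))


print(ml_taylor_sketch(2.0, 1.0), np.exp(2.0))                # E_1(z) = e^z
print(ml_taylor_sketch(1.5 ** 2, 2.0), np.cosh(1.5))          # E_2(z^2) = cosh(z)
print(ml_taylor_sketch(0.5, 0.5), np.exp(0.25) * erfc(-0.5))  # E_{1/2}(0.5), about 1.9524
```

The last line uses the closed form E_{1/2}(z) = exp(z²) erfc(-z) as an independent reference.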
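Implementation Details
The implementation uses the hybrid strategy described in the module overview rather than Padé approximations:
Method: Taylor series with Kahan summation for \(|z| \leq 8\); asymptotic expansions for \(|z| > 8\) (exponential growth for \(z > 8\), algebraic decay for \(z < -8\))
Stability: smooth blending at the branch boundaries for gradient stability
Performance: JIT-compiled with JAX for speed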
Examples
Basic Usage
import jax.numpy as jnp
from rheojax.utils.mittag_leffler import mittag_leffler_e, mittag_leffler_e2
# Single value
result = mittag_leffler_e(0.5, alpha=0.5)
print(result)  # ~1.9524, since E_{1/2}(z) = exp(z**2) * erfc(-z)
# Array of values
z = jnp.linspace(0, 2, 10)
results = mittag_leffler_e(z, alpha=0.8)
# Two-parameter form
result2 = mittag_leffler_e2(0.5, alpha=0.5, beta=1.0)
Fractional Relaxation Modulus
import jax.numpy as jnp
from rheojax.utils.mittag_leffler import mittag_leffler_e
def fractional_maxwell_relaxation(t, E, tau, alpha):
    """Relaxation modulus for fractional Maxwell model.

    Parameters
    ----------
    t : array
        Time values
    E : float
        Elastic modulus (Pa)
    tau : float
        Relaxation time (s)
    alpha : float
        Fractional order (0 < alpha < 1)

    Returns
    -------
    G : array
        Relaxation modulus G(t)
    """
    return E * mittag_leffler_e(-(t / tau)**alpha, alpha)
# Compute relaxation modulus
time = jnp.logspace(-2, 2, 100)
G = fractional_maxwell_relaxation(time, E=1000, tau=1.0, alpha=0.5)
JIT Compilation
import jax
import jax.numpy as jnp
from rheojax.utils.mittag_leffler import mittag_leffler_e
# JIT compile function (alpha must be static)
@jax.jit
def compute_ml(z):
    return mittag_leffler_e(z, alpha=0.5)
# Use compiled function
z = jnp.linspace(0, 5, 1000)
result = compute_ml(z) # Fast computation
Prony Series Functions
Prony series utilities for Generalized Maxwell Model parameter identification.
This module provides utilities for working with Prony series representations of viscoelastic relaxation moduli:
E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ)
Key capabilities:
Parameter validation and bounds checking
Dynamic ParameterSet creation for N modes
Log-space transforms for wide time-scale ranges
Element minimization (optimal N selection)
R² goodness-of-fit metric computation
Softmax penalty for constrained optimization
References
Park, S. W., & Schapery, R. A. (1999). Methods of interconversion between linear viscoelastic material functions. Part I—A numerical method based on Prony series. International Journal of Solids and Structures, 36(11), 1653-1675.
- rheojax.utils.prony.validate_prony_parameters(E_inf, E_i, tau_i)
Validate Prony series parameters for physical consistency.
Checks:
E_inf ≥ 0 (equilibrium modulus non-negative)
All Eᵢ > 0 (positive mode strengths)
All τᵢ > 0 (positive relaxation times)
Same number of Eᵢ and τᵢ elements
- Parameters:
E_inf (float) – Equilibrium modulus (Pa)
E_i (numpy.typing.ArrayLike) – Array of mode strengths (Pa)
tau_i (numpy.typing.ArrayLike) – Array of relaxation times (s)
- Returns:
Tuple of validation status and error message
- Return type:
Example
>>> E_inf = 1e3
>>> E_i = np.array([1e5, 1e4, 1e3])
>>> tau_i = np.array([1e-2, 1e-1, 1.0])
>>> valid, msg = validate_prony_parameters(E_inf, E_i, tau_i)
>>> print(valid)
True
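The four checks listed for validate_prony_parameters can be expressed in a few lines. The sketch below is a hypothetical re-implementation for illustration only: the function name, the order of the checks, and the message strings are assumptions, not the library's.

```python
import numpy as np


def validate_prony_sketch(E_inf, E_i, tau_i):
    """Hypothetical version of the four Prony checks; returns (is_valid, message)."""
    E_i = np.atleast_1d(np.asarray(E_i, dtype=float))
    tau_i = np.atleast_1d(np.asarray(tau_i, dtype=float))
    if E_i.shape != tau_i.shape:  # same number of E_i and tau_i
        return False, f"{E_i.size} mode strengths but {tau_i.size} relaxation times"
    if E_inf < 0:  # equilibrium modulus non-negative
        return False, "E_inf must be >= 0"
    if np.any(E_i <= 0):  # positive mode strengths
        return False, "all E_i must be > 0"
    if np.any(tau_i <= 0):  # positive relaxation times
        return False, "all tau_i must be > 0"
    return True, ""


print(validate_prony_sketch(1e3, [1e5, 1e4, 1e3], [1e-2, 1e-1, 1.0]))  # (True, '')
print(validate_prony_sketch(1e3, [1e5, -1e4], [1e-2, 1e-1]))            # (False, 'all E_i must be > 0')
```

Checking the lengths first avoids comparing arrays of mismatched size in the later checks.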
- rheojax.utils.prony.create_prony_parameter_set(n_modes, modulus_type='shear')
Create ParameterSet for N-mode Prony series.
Dynamically generates parameters:
E_inf (or G_inf for shear): Equilibrium modulus
E_1…E_N (or G_1…G_N): Mode strengths
tau_1…tau_N: Relaxation times
- Parameters:
- Return type:
ParameterSet
- Returns:
ParameterSet with 2N+1 parameters configured for Prony series
- Raises:
ValueError – If n_modes < 1 or modulus_type invalid
Example
>>> params = create_prony_parameter_set(n_modes=3, modulus_type='shear')
>>> list(params.keys())
['G_inf', 'G_1', 'G_2', 'G_3', 'tau_1', 'tau_2', 'tau_3']
- rheojax.utils.prony.tau_to_log_tau(tau_i)
Transform relaxation times to log10 space.
Useful for optimization over wide time-scale ranges (e.g., 1e-6 to 1e6 s). Log-space optimization provides more uniform parameter sensitivity.
- Parameters:
tau_i (numpy.typing.ArrayLike) – Array of relaxation times (s)
- Returns:
Log10-transformed relaxation times
- Return type:
numpy.typing.ArrayLike
Example
>>> tau = np.array([1e-3, 1e-1, 1e1, 1e3])
>>> log_tau = tau_to_log_tau(tau)
>>> print(log_tau)
[-3. -1. 1. 3.]
- rheojax.utils.prony.log_tau_to_tau(log_tau_i)
Transform log-space relaxation times back to linear space.
Inverse of tau_to_log_tau().
- Parameters:
log_tau_i (numpy.typing.ArrayLike) – Array of log10(tau) values
- Returns:
Relaxation times (s)
- Return type:
numpy.typing.ArrayLike
Example
>>> log_tau = np.array([-3., -1., 1., 3.])
>>> tau = log_tau_to_tau(log_tau)
>>> print(tau)
[1.e-03 1.e-01 1.e+01 1.e+03]
- rheojax.utils.prony.compute_r_squared(y_true, y_pred)
Compute R² coefficient of determination.
R² = 1 - SS_res / SS_tot where SS_res = Σ(y_true - y_pred)², SS_tot = Σ(y_true - mean(y_true))²
R² ∈ (-∞, 1], with R²=1 being perfect fit.
- Parameters:
y_true (numpy.typing.ArrayLike) – True values
y_pred (numpy.typing.ArrayLike) – Predicted values
- Return type:
- Returns:
R² coefficient (1.0 = perfect fit, 0.0 = mean baseline, <0 = worse than mean)
Example
>>> y_true = np.array([1., 2., 3., 4., 5.])
>>> y_pred = np.array([1.1, 2.0, 2.9, 4.1, 5.0])
>>> r2 = compute_r_squared(y_true, y_pred)
>>> print(f"{r2:.4f}")
0.9970
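To make the R² formula concrete, the sketch below recomputes the example by hand with NumPy (r_squared_sketch is a hypothetical helper, not the library function). It does not guard against SS_tot = 0, which occurs for constant data.

```python
import numpy as np


def r_squared_sketch(y_true, y_pred):
    """R^2 = 1 - SS_res / SS_tot, following the formula above."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = np.array([1.1, 2.0, 2.9, 4.1, 5.0])
# SS_res = 0.01 + 0 + 0.01 + 0.01 + 0 = 0.03, SS_tot = 4 + 1 + 0 + 1 + 4 = 10
print(f"{r_squared_sketch(y_true, y_pred):.4f}")  # 0.9970
```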
- rheojax.utils.prony.iterative_n_reduction(fit_results_dict)
Track R² vs N for element minimization visualization.
- Parameters:
fit_results_dict (dict[int, float]) – Dictionary mapping n_modes → R² value. Example: {10: 0.998, 9: 0.997, 8: 0.995, …}
- Returns:
'n_modes': Array of N values (sorted ascending)
'r2': Array of R² values corresponding to each N
'r2_min': Minimum R² across all fits
'r2_max': Maximum R² across all fits
- Return type:
Example
>>> results = {10: 0.998, 8: 0.995, 6: 0.990, 4: 0.980, 2: 0.950}
>>> diagnostics = iterative_n_reduction(results)
>>> print(diagnostics['n_modes'])
[ 2 4 6 8 10]
>>> print(diagnostics['r2'])
[0.95 0.98 0.99 0.995 0.998]
- rheojax.utils.prony.select_optimal_n(r2_values, optimization_factor=1.5)
Select optimal number of modes using R² threshold criterion.
Algorithm:
1. Find the maximum R² across all N: R²_max (best achievable fit)
2. Compute the R² degradation tolerance: ΔR² = (1 - R²_max) × (optimization_factor - 1.0)
3. Set the threshold: R²_threshold = R²_max - ΔR²
4. Select the smallest N where R²_N ≥ R²_threshold
Interpretation:
optimization_factor = 1.0: ΔR² = 0, so only an N that attains R²_max is accepted (strictest setting, least parsimony)
optimization_factor = 1.5: R² may fall by up to 50% of the gap (1 - R²_max) (balances quality and parsimony)
optimization_factor = 2.0: R² may fall by up to 100% of that gap (more parsimony)
For optimization_factor > 1, some degradation from the best fit is accepted in exchange for fewer parameters: a higher factor tolerates more degradation and therefore favors a simpler model.
- Parameters:
- Return type:
- Returns:
Optimal number of modes (N_opt)
- Raises:
ValueError – If optimization_factor < 1.0 or r2_values empty
Example
>>> r2 = {5: 0.998, 3: 0.9975, 2: 0.980, 1: 0.900}
>>> # R²_max = 0.998, degradation room = 1 - 0.998 = 0.002
>>> # factor=1.5: ΔR² = 0.002 × 0.5 = 0.001, threshold = 0.997
>>> # Smallest N with R² ≥ 0.997: N=3 (0.9975), since N=2 gives only 0.980
>>> n_opt = select_optimal_n(r2, optimization_factor=1.5)
>>> print(n_opt)
3
>>> # factor=1.0: ΔR² = 0, threshold = 0.998, need N=5
>>> n_opt = select_optimal_n(r2, optimization_factor=1.0)
>>> print(n_opt)
5
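The threshold rule described for select_optimal_n can be sketched directly from its four algorithm steps. This is a hypothetical re-implementation that follows the documented algorithm; the library's own edge-case handling may differ.

```python
def select_optimal_n_sketch(r2_values, optimization_factor=1.5):
    """Threshold rule from the algorithm above (hypothetical re-implementation)."""
    if not r2_values:
        raise ValueError("r2_values is empty")
    if optimization_factor < 1.0:
        raise ValueError("optimization_factor must be >= 1.0")
    r2_max = max(r2_values.values())                          # step 1
    delta = (1.0 - r2_max) * (optimization_factor - 1.0)      # step 2
    threshold = r2_max - delta                                # step 3
    # Step 4: smallest N that clears the threshold (the N attaining r2_max always does).
    return min(n for n, r2 in r2_values.items() if r2 >= threshold)


r2 = {5: 0.998, 3: 0.9975, 2: 0.980, 1: 0.900}
print(select_optimal_n_sketch(r2, 1.5))  # 3 (threshold 0.997)
print(select_optimal_n_sketch(r2, 1.0))  # 5 (threshold 0.998)
```

Because ΔR² is never negative, the mode count that achieves R²_max always passes, so the min() call never sees an empty sequence.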
- rheojax.utils.prony.softmax_penalty(E_i, scale=1.0)
Compute softmax penalty for negative moduli in Step 1 fitting.
This differentiable penalty encourages positive Eᵢ values during unconstrained optimization. It approaches zero when all Eᵢ >> scale, and increases smoothly as Eᵢ becomes negative.
Penalty = scale × Σᵢ log(1 + exp(-Eᵢ/scale))
- Parameters:
E_i (numpy.typing.ArrayLike) – Array of mode strengths (Pa)
scale (float) – Smoothness parameter (default 1.0). Larger values give a smoother penalty but weaker enforcement.
- Returns:
Penalty value (≥ 0, differentiable, JAX array or scalar)
Note
Returns JAX array for gradient compatibility. Do not convert to Python float() when used in JAX-traced functions.
Example
>>> E_i = np.array([1e5, 1e4, -1e3])  # One negative mode
>>> penalty = softmax_penalty(E_i, scale=1e3)
>>> print(f"{penalty:.2f}")  # Dominated by the negative mode: 1e3 × log(1 + e)
1313.31
>>> E_i_pos = np.array([1e5, 1e4, 1e3])  # All positive
>>> penalty_pos = softmax_penalty(E_i_pos, scale=1e3)
>>> print(f"{penalty_pos:.2e}")  # Smaller but nonzero: E_3 = 1e3 is not >> scale
3.13e+02
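The penalty formula is the softplus function scaled by scale. A minimal JAX sketch (the name softplus_penalty_sketch is hypothetical, not the library function) can use jnp.logaddexp, which evaluates log(1 + exp(x)) without overflow, and jax.grad to show the penalty's gradient.

```python
import jax
import jax.numpy as jnp


def softplus_penalty_sketch(E_i, scale=1.0):
    """scale * sum_i log(1 + exp(-E_i / scale)), written with logaddexp for stability."""
    E_i = jnp.asarray(E_i)
    # jnp.logaddexp(0, x) = log(exp(0) + exp(x)) = log(1 + exp(x)), without overflow
    return scale * jnp.sum(jnp.logaddexp(0.0, -E_i / scale))


E_mixed = jnp.array([1e5, 1e4, -1e3])
E_pos = jnp.array([1e5, 1e4, 1e3])
print(float(softplus_penalty_sketch(E_mixed, scale=1e3)))  # about 1313.31
print(float(softplus_penalty_sketch(E_pos, scale=1e3)))    # about 313.31

# Gradient w.r.t. E_i is -sigmoid(-E_i / scale) per mode
g = jax.grad(softplus_penalty_sketch)(E_mixed, 1e3)
print(g)  # last entry is about -0.731
```

Each mode contributes a gradient of -sigmoid(-Eᵢ/scale), so negative moduli are pushed upward while strongly positive ones feel almost no force.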
- rheojax.utils.prony.warm_start_from_n_modes(params_n, n_target, modulus_type='shear')
Extract warm-start parameters for reduced-mode fit from N-mode solution.
Used in element minimization to initialize the (N-1)-mode fit from the N-mode solution. Provides intelligent parameter extraction for faster convergence in successive NLSQ fits during element search optimization.
Algorithm:
1. Extract E_inf, E_i, tau_i from N-mode params
2. If n_target < N: Truncate to first n_target modes (keep strongest modes)
3. If n_target > N: Pad with zeros/default values (edge case, typically not used)
4. If n_target == N: Return params unchanged
Parameter Layout:
params_n format: [E_inf, E_1, E_2, …, E_N, tau_1, tau_2, …, tau_N]
Total length: 2*N + 1
- Parameters:
params_n (numpy.typing.ArrayLike) – Fitted parameters from N-mode optimization. Shape: (2*N + 1,) where N is current number of modes
n_target (int) – Target number of modes for next fit (typically N-1)
modulus_type (str) – 'shear' (G) or 'tensile' (E); currently not used, but kept for API consistency
- Return type:
numpy.typing.ArrayLike
- Returns:
Initial parameters for n_target-mode fit. Shape: (2*n_target + 1,)
- Raises:
ValueError – If n_target < 1 or params_n has invalid length
Example
>>> # 5-mode fit result
>>> params_5 = np.array([1e3, 1e6, 5e5, 2e5, 8e4, 3e4,  # E_inf, E_1..E_5
...                      1e-2, 1e-1, 1.0, 1e1, 1e2])    # tau_1..tau_5
>>> # Warm-start for 4-mode fit (truncate weakest mode E_5)
>>> params_4 = warm_start_from_n_modes(params_5, n_target=4)
>>> print(params_4.shape)  # 2*4 + 1 = 9 parameters: E_inf, E_1..E_4, tau_1..tau_4
(9,)
>>> print(params_4)
[1.e+03 1.e+06 5.e+05 2.e+05 8.e+04 1.e-02 1.e-01 1.e+00 1.e+01]
Notes
Truncation assumes modes are ordered by importance (strongest first)
For GMM fitting, this ordering is typically achieved by sorting by E_i
Warm-start can provide 2-5x speedup in element minimization
Compilation reuse provides additional speedup when combined with this
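The truncation step of the warm-start algorithm amounts to slicing the flat parameter vector. The hypothetical sketch below covers only n_target ≤ N and deliberately raises for the padding case, whose default values are not specified here.

```python
import numpy as np


def warm_start_sketch(params_n, n_target):
    """Truncate a flat [E_inf, E_1..E_N, tau_1..tau_N] vector to n_target modes."""
    params_n = np.asarray(params_n, dtype=float)
    if params_n.ndim != 1 or params_n.size < 3 or (params_n.size - 1) % 2 != 0:
        raise ValueError("params_n must have length 2*N + 1 with N >= 1")
    n_modes = (params_n.size - 1) // 2
    if not 1 <= n_target <= n_modes:
        raise ValueError("this sketch only truncates: need 1 <= n_target <= N")
    E_inf = params_n[:1]
    E_i = params_n[1:n_modes + 1]
    tau_i = params_n[n_modes + 1:]
    # Keep the first n_target modes, assuming modes are ordered strongest-first.
    return np.concatenate([E_inf, E_i[:n_target], tau_i[:n_target]])


params_5 = np.array([1e3, 1e6, 5e5, 2e5, 8e4, 3e4, 1e-2, 1e-1, 1.0, 1e1, 1e2])
params_4 = warm_start_sketch(params_5, 4)
print(params_4.shape)  # (9,)
```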
The prony module provides utilities for Prony series representation of multi-mode viscoelastic behavior, supporting the Generalized Maxwell Model (GMM).
Functions
- rheojax.utils.prony.create_prony_parameter_set(n_modes, modulus_type='shear')
Creates ParameterSet for N-mode Prony series with dynamic parameter generation.
- rheojax.utils.prony.select_optimal_n(r2_values, optimization_factor=1.5)
Element minimization algorithm with warm-start optimization (v0.4.0+). Achieves 2-5x speedup through successive fits and compilation reuse.
- rheojax.utils.prony.compute_r_squared(y_true, y_pred)
Computes R² coefficient of determination for model goodness-of-fit.
- rheojax.utils.prony.softmax_penalty(E_i, scale=1.0)
Softmax penalty for physical constraints in NLSQ optimization.
Mathematical Background
The Prony series represents multi-mode relaxation:
\[E(t) = E_\infty + \sum_{i=1}^{N} E_i \exp(-t/\tau_i)\]
Parameters (2N+1 total):
\(E_\infty\) (or \(G_\infty\)): Equilibrium modulus
\(E_i\) (or \(G_i\)): Mode i strength
\(\tau_i\): Mode i relaxation time
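A direct vectorized evaluation of the Prony series above can be sketched in JAX by broadcasting time against modes. The helper prony_modulus is hypothetical and is not the GeneralizedMaxwell model's internal implementation.

```python
import jax
import jax.numpy as jnp


def prony_modulus(t, E_inf, E_i, tau_i):
    """E(t) = E_inf + sum_i E_i * exp(-t / tau_i), vectorized over times and modes."""
    t = jnp.asarray(t)[..., None]  # add a trailing mode axis so t broadcasts against tau_i
    return E_inf + jnp.sum(E_i * jnp.exp(-t / tau_i), axis=-1)


E_i = jnp.array([1e5, 1e4, 1e3])      # mode strengths (Pa)
tau_i = jnp.array([1e-2, 1e-1, 1.0])  # relaxation times (s)
t = jnp.logspace(-3, 2, 6)
G = jax.jit(prony_modulus)(t, 1e3, E_i, tau_i)
print(G)  # decreases toward E_inf = 1e3 at long times
```

At t = 0 every exponential equals 1, so the modulus is E_∞ + ΣEᵢ; at long times only E_∞ remains.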
Element Minimization Algorithm (v0.4.0):
Fit N-mode model with NLSQ optimization
Compute R² for current N
Initialize the (N-1)-mode fit from the optimal N-mode parameters (warm-start)
Repeat with decreasing N until R² falls below the threshold R²_max - (1 - R²_max) × (optimization_factor - 1)
Return optimal N with minimal elements
Performance Optimization:
Warm-Start: Each N initialized from N+1 parameters (avoids cold-start overhead)
Compilation Reuse: Cached residual functions across iterations
Early Termination: Stops when R² < threshold (default: 1 - 1.5×(1 - R²_max))
Speedup: 2-5x measured improvement (20-50s → 4-25s for N=10 search)
Examples
Create Prony Parameter Set
from rheojax.utils.prony import create_prony_parameter_set
# Create 3-mode Prony series (shear modulus)
params = create_prony_parameter_set(n_modes=3, modulus_type='shear')
print(list(params.keys()))
# ['G_inf', 'G_1', 'G_2', 'G_3', 'tau_1', 'tau_2', 'tau_3']
# Create 5-mode Prony series (tensile modulus)
params_tensile = create_prony_parameter_set(n_modes=5, modulus_type='tensile')
print(list(params_tensile.keys()))
# ['E_inf', 'E_1', ..., 'E_5', 'tau_1', ..., 'tau_5']
Element Minimization with Warm-Start
from rheojax.models.generalized_maxwell import GeneralizedMaxwell
import numpy as np
# Create model with maximum modes
model = GeneralizedMaxwell(n_modes=10, modulus_type='shear')
# Generate relaxation data
t = np.logspace(-3, 2, 100)
G_data = ... # Experimental relaxation modulus
# Fit with automatic element minimization (v0.4.0+ warm-start)
model.fit(
t, G_data,
test_mode='relaxation',
optimization_factor=1.5  # R² degradation tolerance factor (see select_optimal_n)
)
# Check optimal number of modes (auto-reduced from 10)
print(f"Optimal modes: {model._n_modes}") # e.g., 3
# Element minimization uses warm-start for 2-5x speedup:
# - Fit N=10: 5.2s
# - Fit N=9: 0.8s (warm-started from N=10)
# - Fit N=8: 0.7s (warm-started from N=9)
# - ...
# Total: ~10s vs ~50s cold-start
See Also
Models API - GeneralizedMaxwell model using Prony series
Core Module (rheojax.core) - ParameterSet for parameter management
Generalized Maxwell Model (Multi-Mode) - GMM handbook with examples
Optimization
Optimization utilities for parameter fitting using NLSQ.
This module provides GPU-accelerated optimization using the NLSQ package (https://github.com/imewei/NLSQ). NLSQ provides 5-270x speedup over scipy through JAX JIT compilation and automatic differentiation.
Critical: This module imports NLSQ, which must be imported before JAX to enable float64 precision mode. The rheojax package handles this automatically in __init__.py.
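Outside rheojax, JAX's standard switch for double precision is the jax_enable_x64 configuration flag. The minimal sketch below sets it directly, without NLSQ, before any arrays are created; how NLSQ itself enables float64 is not shown here.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64
```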
Example
>>> from rheojax.core.parameters import ParameterSet
>>> from rheojax.utils.optimization import nlsq_optimize
>>>
>>> # Set up parameters
>>> params = ParameterSet()
>>> params.add("x", value=1.0, bounds=(0, 10))
>>>
>>> # Define objective function
>>> def objective(values):
... x = values[0]
... return (x - 5.0) ** 2
>>>
>>> # Optimize
>>> result = nlsq_optimize(objective, params, use_jax=True)
>>> print(f"Optimal x: {result.x[0]:.4f}")
- class rheojax.utils.optimization.ResidualFunction(fn, normalization_weights=None, y_data=None, use_log_residuals=False)
Bases: object
Callable wrapper for residual functions that carries normalization metadata.
This replaces the fragile pattern of attaching _normalization_weights as a function attribute (which breaks if the function is wrapped by decorators, functools.wraps, jax.jit, etc.). The class is fully transparent to callers: it behaves like a plain function but safely exposes the weights.
The _y_data slot also carries the original dependent-variable array so downstream code (nlsq_optimize, scipy/DE fallback paths) can attach it to OptimizationResult and recover correct R²/RMSE/AIC/BIC. Without this, those paths leave y_data=None and the r_squared property silently returns None, masking successful fits as failures.
- rheojax.utils.optimization.make_fd_differentiable(fn, eps=1e-07)
Wrap a function with a finite-difference custom JVP.
This enables jax.jacfwd (forward-mode AD) for functions that cannot be traced by JAX's autodiff, e.g. diffrax ODE solvers, which use custom_vjp and are therefore incompatible with jacfwd.
The wrapper computes JVPs via central differences: (f(x+εv) - f(x-εv)) / (2ε). When combined with jax.jacfwd, this effectively computes the full Jacobian via vmap'd perturbations in a single batched XLA call, which is much faster than scipy's sequential finite differences.
- Parameters:
- Return type:
- Returns:
A function with identical signature but a custom JVP rule for params.
Example:
# Before: NLSQ fails with TypeError on ODE models
objective = create_least_squares_objective(model_fn, x, y)
# After: finite-difference JVP makes it NLSQ-compatible
objective = create_least_squares_objective(
    make_fd_differentiable(model_fn), x, y
)
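The central-difference JVP described for make_fd_differentiable can be sketched with jax.custom_jvp. This sketch assumes a function of the form fn(params, x) and differentiates only with respect to params; fd_jvp_sketch is a hypothetical name, and the library's actual wrapper signature is not documented above. float64 is enabled because a step of 1e-7 is lost to float32 rounding.

```python
import jax
import jax.numpy as jnp

# A 1e-7 step is lost to float32 rounding, so this sketch runs in float64.
jax.config.update("jax_enable_x64", True)


def fd_jvp_sketch(fn, eps=1e-7):
    """Hypothetical wrapper: give fn(params, x) a central-difference JVP in params."""

    @jax.custom_jvp
    def wrapped(params, x):
        return fn(params, x)

    @wrapped.defjvp
    def wrapped_jvp(primals, tangents):
        params, x = primals
        d_params, _ = tangents  # x is treated as fixed data; its tangent is ignored
        y = fn(params, x)
        # J @ v  ~  (f(p + eps*v) - f(p - eps*v)) / (2*eps)
        dy = (fn(params + eps * d_params, x) - fn(params - eps * d_params, x)) / (2.0 * eps)
        return y, dy

    return wrapped


def model(params, x):
    return params[0] * jnp.exp(-params[1] * x)


x = jnp.linspace(0.0, 2.0, 5)
params = jnp.array([2.0, 1.3])
J_fd = jax.jacfwd(fd_jvp_sketch(model))(params, x)  # FD columns batched via vmap
J_ad = jax.jacfwd(model)(params, x)                 # exact reference
print(J_fd.shape, float(jnp.max(jnp.abs(J_fd - J_ad))))
```

The model here is itself traceable, so jax.jacfwd on the plain function serves as a reference; in practice such a wrapper is meant for functions that jacfwd cannot handle directly.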
- rheojax.utils.optimization.nlsq_optimize(objective, parameters, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, workflow='auto', auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, compute_covariance=True, **kwargs)
Optimize objective function using NLSQ (GPU-accelerated).
This function provides GPU-accelerated nonlinear least squares optimization using the NLSQ package. It achieves 5-270x speedup over scipy through JAX JIT compilation and automatic differentiation.
The objective function should accept parameter values as a 1D array and return a scalar value (minimization) or vector of residuals (least squares).
- Parameters:
objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize. Takes parameter values as array and returns scalar or residual vector. Should use jax.numpy for operations to enable GPU acceleration and automatic differentiation.
parameters (ParameterSet) – ParameterSet with initial values and bounds
method (str) – Optimization method. Options:
"auto": Automatically select based on bounds (default)
"trf": Trust Region Reflective (supports bounds)
"lm": Levenberg-Marquardt (no bounds)
"scipy": Use SciPy's least_squares directly (bypasses NLSQ). Use this for models that use Diffrax ODE solvers, which are incompatible with NLSQ's forward-mode autodiff.
For the NLSQ-backed options, NLSQ internally selects the best algorithm regardless of this parameter.
use_jax (bool) – Whether to use JAX for gradient computation (default: True). Should always be True for GPU acceleration and float64 precision.
max_iter (int) – Maximum number of iterations (default: 1000)
ftol (float) – Function tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ's mixed precision management.
xtol (float) – Parameter tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ's mixed precision management.
gtol (float) – Gradient tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ's mixed precision management.
workflow (str) – NLSQ 0.6.6 workflow selection (default: "auto"):
"auto": Memory-aware local optimization (default)
"auto_global": Global optimization with bounds exploration
"hpc": HPC mode with checkpointing support
auto_bounds (bool) – Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics.
stability (str | bool) – Numerical stability checks (default: False):
'auto': Check and automatically fix stability issues
'check': Check and warn but don't fix
False: Skip stability checks
fallback (bool) – Enable NLSQ's native fallback strategies (default: False). When True, tries alternative approaches if optimization fails. Note: RheoJAX also has its own SciPy fallback independent of this.
compute_diagnostics (bool) – Compute model health diagnostics (default: False). When True, result.diagnostics includes identifiability analysis, gradient health, and other diagnostic information.
compute_covariance (bool) – Whether to compute the parameter covariance matrix (default: True). The covariance is derived from an SVD of the Jacobian at the solution. Set False to skip this step when confidence intervals and parameter uncertainties are not needed, avoiding one full SVD per fit.
**kwargs – Additional arguments passed to nlsq.LeastSquares.least_squares
- Return type:
OptimizationResult
- Returns:
OptimizationResult with optimal parameters, convergence info, and optional diagnostics (when compute_diagnostics=True).
- Raises:
ValueError – If objective is not callable or parameters is not ParameterSet
Example
>>> from rheojax.core.parameters import ParameterSet
>>> params = ParameterSet()
>>> params.add("a", value=1.0, bounds=(0, 10))
>>> params.add("b", value=1.0, bounds=(0, 10))
>>>
>>> def objective(values):
...     a, b = values
...     return (a - 5.0) ** 2 + (b - 3.0) ** 2
>>>
>>> result = nlsq_optimize(objective, params)
>>> print(result.x)  # Should be close to [5.0, 3.0]
>>>
>>> # With NLSQ 0.6.6 features
>>> result = nlsq_optimize(
...     objective, params,
...     workflow="auto_global",  # Global optimization
...     stability="auto",  # Auto-fix stability issues
...     compute_diagnostics=True,  # Get diagnostics
... )
>>> print(result.diagnostics)
Notes
This function automatically handles float64 precision through NLSQ
JAX JIT compilation provides 5-270x speedup over scipy
Automatic differentiation eliminates need for manual Jacobian
Bounds are automatically extracted from ParameterSet
Parameters are updated in-place with optimal values
- rheojax.utils.optimization.nlsq_multistart_optimize(objective, parameters, n_starts=5, perturb_factor=0.3, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, verbose=False, parallel=True, n_workers=None, y_data=None, **kwargs)
Multi-start optimization to escape local minima.
For complex objective functions (e.g., mastercurves with 10+ decades), single optimization runs may converge to poor local minima even from good initial guesses. This function performs multiple optimization runs from different starting points and returns the best result.
- Strategy:
First attempt: Use current parameter values (from smart initialization)
Additional attempts: Random perturbations around initial values (parallel)
Return result with lowest final cost (best fit)
Performance: With parallel=True (default), achieves 2-4x speedup for 5-10 starts by running optimizations concurrently. JAX releases the GIL during computation, enabling effective thread-based parallelism.
- Parameters:
objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize
parameters (ParameterSet) – ParameterSet with initial values and bounds
n_starts (int) – Total number of starts, including the initial unperturbed one (default: 5)
perturb_factor (float) – Perturbation factor for random starts (default: 0.3). Parameters are perturbed by ± perturb_factor * (value or range)
method (str) – Optimization method (default: "auto")
use_jax (bool) – Whether to use JAX (default: True)
max_iter (int) – Max iterations per start (default: 1000)
ftol (float) – Function tolerance (default: 1e-6)
xtol (float) – Parameter tolerance (default: 1e-6)
gtol (float) – Gradient tolerance (default: 1e-6)
verbose (bool) – Print progress messages (default: False)
parallel (bool) – Run additional starts in parallel (default: True)
n_workers (int | None) – Number of parallel workers (default: min(n_starts-1, 4))
**kwargs – Additional arguments for nlsq_optimize
- Return type:
OptimizationResult
- Returns:
OptimizationResult with best parameters from all starts
Example
>>> # For mastercurve data (12+ decades)
>>> result = nlsq_multistart_optimize(
...     objective, parameters, n_starts=5, verbose=True
... )
>>> print(f"Best cost: {result.fun:.3e}")
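The multi-start strategy of nlsq_multistart_optimize (unperturbed first start, perturbed additional starts, keep the lowest cost) can be sketched with SciPy's least_squares standing in for NLSQ. The perturbation rule, the seed, and the name multistart_sketch are assumptions, and unlike the library this sketch runs the starts sequentially rather than in parallel threads.

```python
import numpy as np
from scipy.optimize import least_squares


def multistart_sketch(residuals, x0, bounds, n_starts=5, perturb_factor=0.3, seed=0):
    """Run least_squares from x0 plus (n_starts - 1) perturbed starts; keep the lowest cost."""
    rng = np.random.default_rng(seed)
    x0 = np.asarray(x0, dtype=float)
    lower, upper = (np.asarray(b, dtype=float) for b in bounds)
    starts = [x0]  # first attempt: the current values, unperturbed
    for _ in range(n_starts - 1):
        # Perturb each value by up to +/- perturb_factor * |value|, then clip into bounds.
        trial = x0 * (1.0 + perturb_factor * rng.uniform(-1.0, 1.0, size=x0.shape))
        starts.append(np.clip(trial, lower, upper))
    results = [least_squares(residuals, s, bounds=(lower, upper)) for s in starts]
    return min(results, key=lambda r: r.cost)  # best fit = lowest final cost


x = np.linspace(0.0, 3.0, 30)
y = 2.0 * np.exp(-1.3 * x)
best = multistart_sketch(lambda p: p[0] * np.exp(-p[1] * x) - y,
                         x0=[1.0, 0.5], bounds=([0.0, 0.0], [10.0, 5.0]))
print(best.x, best.cost)
```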
- rheojax.utils.optimization.nlsq_curve_fit(model_fn, x_data, y_data, parameters, auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, multistart=False, n_starts=10, workflow='auto', **kwargs)[source]¶
Curve fitting using NLSQ 0.6.6 curve_fit() API with advanced features.
This function provides access to NLSQ 0.6.6’s enhanced curve_fit() features including auto-bounds, stability checks, fallback strategies, model diagnostics, and workflow selection. It returns an OptimizationResult with CurveFitResult-compatible statistical properties (r_squared, rmse, aic, bic, prediction_interval, etc.).
- Parameters:
model_fn (
Callable[[ndarray,ndarray],ndarray]) – Model function f(x, params_array) -> y_pred. Takes x_data and parameter array, returns predictions.x_data (
ndarray) – Independent variable datay_data (
ndarray) – Dependent variable data (observations)parameters (
ParameterSet) – ParameterSet with initial values and boundsauto_bounds (
bool) – Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics.stability (
str|bool) – Numerical stability checks (default: False). - ‘auto’: Check and automatically fix stability issues - ‘check’: Check and warn but don’t fix - False: Skip stability checksfallback (
bool) – Enable automatic fallback strategies (default: False). When True, tries alternative approaches if optimization fails.compute_diagnostics (
bool) – Compute model health diagnostics (default: False). When True, result includes identifiability analysis, gradient health, etc.multistart (
bool) – Enable multi-start optimization (default: False). When True, explores multiple starting points to find global optimum.n_starts (
int) – Number of starting points for multi-start (default: 10).workflow (
str) –NLSQ 0.6.6 workflow selection (default: “auto”):
”auto”: Memory-aware local optimization (default)
”auto_global”: Global optimization with bounds exploration
”hpc”: HPC mode with checkpointing support
**kwargs – Additional arguments passed to nlsq.curve_fit()
- Returns:
r_squared, adj_r_squared, rmse, mae, aic, bic
confidence_intervals(alpha) method
prediction_interval(x_new, alpha) method (NLSQ 0.6.6 native)
get_parameter_uncertainties() method
- Return type:
OptimizationResult
Example
>>> from rheojax.core.parameters import ParameterSet >>> from rheojax.utils.optimization import nlsq_curve_fit >>> >>> def model(x, params): ... a, b = params ... return a * np.exp(-b * x) >>> >>> params = ParameterSet() >>> params.add("a", value=1.0, bounds=(0, 10)) >>> params.add("b", value=0.5, bounds=(0, 5)) >>> >>> result = nlsq_curve_fit( ... model, x_data, y_data, params, ... auto_bounds=True, ... stability='auto', ... fallback=True, ... compute_diagnostics=True, ... ) >>> print(f"R² = {result.r_squared:.4f}") >>> print(f"RMSE = {result.rmse:.4f}") >>> ci = result.confidence_intervals(alpha=0.95) >>> >>> # Prediction intervals (NLSQ 0.6.6) >>> pi = result.prediction_interval(x_new, alpha=0.95) >>> print(f"95% PI: [{pi[0, 0]:.3f}, {pi[0, 1]:.3f}]")
Notes
This function uses nlsq.curve_fit() directly (not LeastSquares.least_squares()).
The model function signature is f(x, params_array), not f(x, *params).
Results delegate to the native CurveFitResult for prediction_interval() calls.
Results include all CurveFitResult properties for model comparison.
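If an existing model is written in the f(x, *params) style used by scipy.optimize.curve_fit, a thin wrapper can adapt it to the f(x, params_array) signature expected here. A minimal sketch; the helper name as_array_model is hypothetical and not part of RheoJAX:

```python
import numpy as np


def as_array_model(f_star):
    """Adapt f(x, *params) to f(x, params_array) (hypothetical helper)."""
    def wrapped(x, params):
        # Unpack the 1D parameter array into positional arguments
        return f_star(x, *params)
    return wrapped


def exp_decay(x, a, b):
    # Model written in the f(x, *params) convention
    return a * np.exp(-b * x)


model = as_array_model(exp_decay)
x = np.linspace(0.0, 2.0, 5)
y = model(x, np.array([2.0, 0.5]))
```

For a model that will be traced by JAX, the wrapped function should itself use jax.numpy instead of numpy.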
- rheojax.utils.optimization.optimize_with_bounds(objective, x0, bounds, use_jax=True, **kwargs)[source]¶
Optimize objective function with parameter bounds.
Lower-level optimization function that works with arrays instead of ParameterSet. Useful for custom optimization workflows.
- Parameters:
objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize
x0 (ndarray) – Initial parameter values
bounds (list[tuple[float | None, float | None]]) – List of (min, max) tuples for each parameter
use_jax (bool) – Whether to use JAX for gradients (default: True)
**kwargs – Additional arguments passed to nlsq_optimize
- Return type:
OptimizationResult
- Returns:
OptimizationResult with optimal parameters
Example
>>> import numpy as np
>>> def objective(x):
...     return x[0]**2 + x[1]**2
>>> result = optimize_with_bounds(
...     objective,
...     x0=np.array([1.0, 1.0]),
...     bounds=[(0, 5), (0, 5)]
... )
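Least-squares solvers such as scipy.optimize.least_squares take bounds as a pair of lower and upper arrays, with ±inf marking an unbounded side. The sketch below shows one way the per-parameter (min, max) tuples, including None entries, could be translated into that form. It is an illustration under that assumption, not RheoJAX's internal code, and split_bounds is a hypothetical name:

```python
import numpy as np


def split_bounds(bounds):
    """Convert [(min, max), ...] with None entries to (lower, upper) arrays."""
    lower = np.array([-np.inf if lo is None else float(lo) for lo, _ in bounds])
    upper = np.array([np.inf if hi is None else float(hi) for _, hi in bounds])
    # Sanity check: every lower bound must not exceed its upper bound
    if np.any(lower > upper):
        raise ValueError("each lower bound must be <= its upper bound")
    return lower, upper


lower, upper = split_bounds([(0, 5), (None, 10.0), (1e-3, None)])
```

The resulting pair can be passed as bounds=(lower, upper) to scipy.optimize.least_squares.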
- rheojax.utils.optimization.residual_sum_of_squares(y_true, y_pred, normalize=True)[source]¶
Compute residual sum of squares (RSS).
Handles both real and complex data. For complex data (e.g., oscillatory shear with G′ + iG″), computes the RSS of the real and imaginary parts separately and returns their sum.
- Parameters:
y_true (numpy.typing.ArrayLike) – True values (real or complex)
y_pred (numpy.typing.ArrayLike) – Predicted values (real or complex)
normalize (bool) – Whether to normalize residuals by y_true (relative error; default: True)
- Return type:
- Returns:
RSS value (scalar, maintains float64 precision)
Example
>>> import numpy as np
>>> y_true = np.array([1.0, 2.0, 3.0])
>>> y_pred = np.array([1.1, 2.1, 2.9])
>>> rss = residual_sum_of_squares(y_true, y_pred)
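A minimal NumPy sketch of the computation described above. It assumes that normalize=True divides each residual by |y_true|; the exact normalization RheoJAX applies is not specified here, and rss_sketch is a hypothetical name:

```python
import numpy as np


def rss_sketch(y_true, y_pred, normalize=True):
    """Residual sum of squares; complex inputs contribute real and imaginary parts."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    resid = y_pred - y_true
    if normalize:
        # Assumed relative error: divide by the magnitude of the observations
        resid = resid / np.abs(y_true)
    # Re(r)^2 + Im(r)^2 sums the real-part RSS and the imaginary-part RSS
    return float(np.sum(resid.real**2 + resid.imag**2))


rss_real = rss_sketch([1.0, 2.0, 3.0], [1.1, 2.1, 2.9], normalize=False)
rss_complex = rss_sketch([1 + 1j], [2 + 3j], normalize=False)
```

For real arrays, .imag is an all-zero array, so the same expression covers both cases.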
- rheojax.utils.optimization.create_least_squares_objective(model_fn, x_data, y_data, normalize=True, use_log_residuals=False)[source]¶
Create residual function for NLSQ least-squares fitting.
IMPORTANT: This now returns a RESIDUAL FUNCTION (vector output), not a scalar objective. NLSQ minimizes sum(residuals²), so this provides per-point residuals to the optimizer, which enables proper gradient computation and weighting.
For complex data (e.g., G* = G′ + iG″), returns stacked real and imaginary residuals: [real₁, …, real_N, imag₁, …, imag_N] with shape (2N,).
For real data, returns residuals with shape (N,).
Log-space residuals (NEW): For rheological data spanning many decades (e.g., mastercurves covering 8+ decades), use use_log_residuals=True to compute residuals in log10 space. This weights relative deviations equally across all decades and prevents the optimizer from being biased toward high-modulus regions.
- Parameters:
model_fn (Callable[[ndarray, ndarray], ndarray]) – Model function that takes (x_data, parameters) and returns predictions
x_data (ndarray) – Independent variable data
y_data (ndarray) – Dependent variable data (observations, may be complex)
normalize (bool) – Whether to use relative error (default: True)
use_log_residuals (bool) – Whether to compute residuals in log10 space (default: False). Recommended for data spanning >8 decades. Formula: residual = log10(abs(y_pred)) - log10(abs(y_data))
- Return type:
- Returns:
Residual function that takes parameters and returns residual vector
Example
>>> import numpy as np
>>> def linear_model(x, params):
...     a, b = params
...     return a * x + b
>>> x = np.array([1, 2, 3, 4, 5])
>>> y = np.array([2.1, 4.0, 5.9, 8.1, 10.0])
>>> residual_fn = create_least_squares_objective(linear_model, x, y)
>>> # residual_fn can be passed to nlsq_optimize, which receives the full residual vector
>>>
>>> # For mastercurve data (wide frequency range); model_fn, omega, G_star are placeholders:
>>> residual_fn_log = create_least_squares_objective(
...     model_fn, omega, G_star, use_log_residuals=True
... )
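The residual layout described above can be sketched in a few lines of NumPy. This is an illustration only: make_residual_fn is a hypothetical name, it omits the normalize option, and in log mode it simply applies the documented formula to |y| (how RheoJAX treats complex data in log mode is not specified here). The usage part fits a single-mode Maxwell modulus G*(ω) = G·iωτ/(1 + iωτ):

```python
import numpy as np


def make_residual_fn(model_fn, x_data, y_data, use_log_residuals=False):
    """Return params -> residual vector, stacking real/imag parts for complex data."""
    y_data = np.asarray(y_data)

    def residuals(params):
        y_pred = model_fn(x_data, params)
        if use_log_residuals:
            # log10(|y_pred|) - log10(|y_data|), as in the formula above
            r = np.log10(np.abs(y_pred)) - np.log10(np.abs(y_data))
            return np.asarray(r, dtype=float)
        r = y_pred - y_data
        if np.iscomplexobj(r):
            # Shape (2N,): [real_1..real_N, imag_1..imag_N]
            return np.concatenate([r.real, r.imag])
        return r

    return residuals


def maxwell(omega, p):
    # Single-mode Maxwell complex modulus
    G, tau = p
    return G * 1j * omega * tau / (1 + 1j * omega * tau)


omega = np.logspace(-2, 2, 5)
G_star = maxwell(omega, [1e3, 1.0])
fn = make_residual_fn(maxwell, omega, G_star)
fn_log = make_residual_fn(maxwell, omega, G_star, use_log_residuals=True)
```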
Functions and Classes¶
- class rheojax.utils.optimization.OptimizationResult(x, fun, jac=None, pcov=None, success=False, message='', nit=0, nfev=0, njev=0, optimality=None, active_mask=None, cost=None, grad=None, nlsq_result=None, residuals=None, y_data=None, n_data=None, diagnostics=None, _curve_fit_result=None, _model_fn=None, _x_data=None, _is_complex_split=False, _normalization_weights=None, _use_log_residuals=False)[source]
Bases: object
Result from optimization with NLSQ 0.6.6 CurveFitResult-compatible properties.
This dataclass stores the results of NLSQ optimization, including optimal parameter values, objective function value, convergence information, and statistical metrics compatible with NLSQ 0.6.6’s CurveFitResult.
- x
Optimal parameter values (float64 array)
- fun
Objective function value at optimum (RSS = sum of squared residuals)
- jac
Jacobian at the optimum (the gradient, for a scalar objective)
- pcov
Parameter covariance matrix (n_params x n_params)
- success
Whether optimization converged successfully
- message
Status message from optimizer
- nit
Number of iterations
- nfev
Number of function evaluations
- njev
Number of Jacobian evaluations
- optimality
Optimality metric (gradient norm)
- active_mask
Active bound constraints at solution
- cost
Final cost value
- grad
Final gradient
- nlsq_result
Full NLSQ result dictionary (for advanced diagnostics)
- residuals
Residual vector (y_data - y_pred) for statistical metrics
- y_data
Original dependent variable data (for R² computation)
- n_data
Number of data points (for AIC/BIC computation)
- diagnostics
Model health diagnostics (NLSQ 0.6.6, when compute_diagnostics=True)
- Statistical Properties (NLSQ 0.6.6 CurveFitResult compatible):
r_squared: Coefficient of determination (R²)
adj_r_squared: Adjusted R² accounting for number of parameters
rmse: Root mean squared error
mae: Mean absolute error
aic: Akaike Information Criterion
bic: Bayesian Information Criterion
- confidence_intervals(alpha)[source]
Compute parameter confidence intervals
- prediction_interval(x_new, alpha)[source]
Compute prediction intervals (NLSQ 0.6.6)
- get_parameter_uncertainties()[source]
Get standard errors from covariance diagonal
Result container for optimization.
- property r_squared: float | None
Coefficient of determination (R²).
Measures goodness of fit. Range: (-∞, 1], where 1 indicates a perfect fit.
R² = 1 - SS_res / SS_tot
where SS_res = sum((y - y_pred)²) and SS_tot = sum((y - y_mean)²)
- Returns:
R² value, or None if residuals/y_data not available
- property adj_r_squared: float | None
Adjusted R² accounting for number of parameters.
Adj R² = 1 - (1 - R²) * (n - 1) / (n - p - 1)
where n is the number of data points and p is the number of parameters.
- Returns:
Adjusted R² value, or None if cannot be computed
- property rmse: float | None
Root mean squared error.
RMSE = sqrt(mean(residuals²))
- Returns:
RMSE value, or None if residuals not available
- property mae: float | None
Mean absolute error.
MAE = mean(abs(residuals))
More robust to outliers than RMSE.
- Returns:
MAE value, or None if residuals not available
- property aic: float | None
Akaike Information Criterion.
AIC = 2k + n*ln(RSS/n)
where k is the number of parameters, n the number of data points, and RSS the residual sum of squares.
Lower is better. Used for model selection.
- Returns:
AIC value, or None if cannot be computed
- property bic: float | None
Bayesian Information Criterion.
BIC = k*ln(n) + n*ln(RSS/n)
where k is the number of parameters, n the number of data points, and RSS the residual sum of squares.
Lower is better. For n ≥ 8 (where ln(n) > 2), penalizes model complexity more heavily than AIC.
- Returns:
BIC value, or None if cannot be computed
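All of these metrics are simple functions of the residual vector. A minimal NumPy sketch that evaluates them with exactly the formulas above, taking residuals as y - y_pred and k (= p) as the number of fitted parameters; fit_statistics is a hypothetical helper, not a RheoJAX function:

```python
import numpy as np


def fit_statistics(y, y_pred, n_params):
    """R², adjusted R², RMSE, MAE, AIC and BIC from the formulas above."""
    y = np.asarray(y, dtype=float)
    residuals = y - np.asarray(y_pred, dtype=float)
    n, k = residuals.size, n_params
    ss_res = float(np.sum(residuals**2))            # RSS
    ss_tot = float(np.sum((y - y.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot
    return {
        "r_squared": r2,
        "adj_r_squared": 1.0 - (1.0 - r2) * (n - 1) / (n - k - 1),
        "rmse": float(np.sqrt(np.mean(residuals**2))),
        "mae": float(np.mean(np.abs(residuals))),
        "aic": 2 * k + n * np.log(ss_res / n),
        "bic": k * np.log(n) + n * np.log(ss_res / n),
    }


stats = fit_statistics([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.2, 3.8], n_params=1)
```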
- confidence_intervals(alpha=0.95)[source]
Compute parameter confidence intervals.
- Parameters:
alpha (float) – Confidence level (default: 0.95 for a 95% CI).
- Returns:
intervals – Array of shape (n_params, 2) with [lower, upper] bounds for each parameter, or None if covariance not available.
- Return type:
Examples
>>> result = nlsq_optimize(objective, params)
>>> ci = result.confidence_intervals(alpha=0.95)
>>> if ci is not None:
...     for i, (lower, upper) in enumerate(ci):
...         print(f"Parameter {i}: [{lower:.3f}, {upper:.3f}]")
- get_parameter_uncertainties()[source]
Get standard errors for parameters from covariance diagonal.
- Returns:
uncertainties – Standard errors for each parameter, or None if covariance not available.
- Return type:
Examples
>>> result = nlsq_optimize(objective, params)
>>> std_errs = result.get_parameter_uncertainties()
>>> if std_errs is not None:
...     for i, se in enumerate(std_errs):
...         print(f"Parameter {i}: {result.x[i]:.4f} ± {se:.4f}")
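Standard errors are the square roots of the covariance diagonal. The source does not state which distribution confidence_intervals() uses, so the sketch below assumes a normal approximation (using Python's statistics.NormalDist for the critical value); a Student t quantile with n - p degrees of freedom is a common alternative for small samples. The function name is hypothetical:

```python
import numpy as np
from statistics import NormalDist


def uncertainties_and_ci(x, pcov, level=0.95):
    """Standard errors and normal-approximation confidence intervals."""
    x = np.asarray(x, dtype=float)
    std_errs = np.sqrt(np.diag(pcov))
    # Two-sided critical value, about 1.96 for a 95% interval
    z = NormalDist().inv_cdf(0.5 + level / 2.0)
    ci = np.column_stack([x - z * std_errs, x + z * std_errs])  # (n_params, 2)
    return std_errs, ci


se, ci = uncertainties_and_ci([5.0, 3.0], np.diag([0.04, 0.01]))
```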
- prediction_interval(x_new=None, alpha=0.95)[source]
Compute prediction intervals for new x values.
Prediction intervals account for both parameter uncertainty and observation noise, providing bounds where future observations are expected to fall with the specified probability.
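- Parameters:
x_new – New x values at which to compute intervals (default: None).
alpha (float) – Confidence level (default: 0.95 for a 95% interval).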
- Returns:
intervals – Array of shape (n_points, 2) with [lower, upper] bounds for each point, or None if prediction intervals cannot be computed.
- Return type:
Notes
When a native NLSQ CurveFitResult is available (from nlsq_curve_fit), this method delegates to NLSQ’s prediction_interval for accuracy. Otherwise, it falls back to a manual computation using covariance propagation.
Examples
>>> result = nlsq_curve_fit(model, x_data, y_data, params)
>>> pi = result.prediction_interval(x_new, alpha=0.95)
>>> if pi is not None:
...     for i, (lower, upper) in enumerate(pi):
...         print(f"x={x_new[i]:.2f}: [{lower:.3f}, {upper:.3f}]")
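The manual fallback mentioned in the Notes uses covariance propagation. A minimal sketch of that idea under stated assumptions: the Jacobian of the model with respect to the parameters is approximated by central finite differences, the observation-noise variance sigma2 is supplied by the caller (for example RSS/(n - p)), and a normal critical value is used. The exact formula RheoJAX uses may differ, and the function name is hypothetical:

```python
import numpy as np
from statistics import NormalDist


def prediction_interval_sketch(model_fn, x_new, popt, pcov, sigma2,
                               level=0.95, eps=1e-6):
    """Bounds of shape (n_points, 2) from var = diag(J pcov J^T) + sigma2."""
    popt = np.asarray(popt, dtype=float)
    x_new = np.asarray(x_new, dtype=float)
    y_hat = model_fn(x_new, popt)
    # Central finite-difference Jacobian dy/dparams, shape (n_points, n_params)
    jac = np.empty((x_new.size, popt.size))
    for j in range(popt.size):
        step = np.zeros_like(popt)
        step[j] = eps * max(1.0, abs(popt[j]))
        diff = model_fn(x_new, popt + step) - model_fn(x_new, popt - step)
        jac[:, j] = diff / (2 * step[j])
    # Parameter uncertainty propagated to predictions, plus observation noise
    var = np.einsum("ij,jk,ik->i", jac, pcov, jac) + sigma2
    half = NormalDist().inv_cdf(0.5 + level / 2.0) * np.sqrt(var)
    return np.column_stack([y_hat - half, y_hat + half])


def linear(x, p):
    return p[0] * x + p[1]


pi = prediction_interval_sketch(linear, [0.0, 1.0, 2.0], [2.0, 1.0],
                                np.diag([0.01, 0.0]), sigma2=0.25)
```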
- classmethod from_curve_fit_result(curve_fit_result, y_data=None, model_fn=None, x_data=None)[source]
Create OptimizationResult from NLSQ 0.6.6 CurveFitResult.
This factory method preserves the native CurveFitResult for property delegation, enabling access to NLSQ 0.6.6’s statistical methods like prediction_interval() without reimplementation.
- Parameters:
curve_fit_result (Any) – Result from nlsq.curve_fit() call.
y_data (ndarray | None) – Original dependent variable data (for complex data handling).
model_fn (Callable | None) – Model function f(x, params) for prediction intervals.
x_data (ndarray | None) – Original independent variable data for prediction intervals.
- Returns:
result – Result with native delegation to CurveFitResult properties.
- Return type:
OptimizationResult
Examples
>>> curve_fit_result = nlsq.curve_fit(model_fn, x, y, p0=p0)
>>> result = OptimizationResult.from_curve_fit_result(
...     curve_fit_result, y_data=y, model_fn=model_fn, x_data=x
... )
>>> print(result.r_squared)  # Delegates to native
>>> pi = result.prediction_interval(x_new)  # Delegates to native
- classmethod from_nlsq(nlsq_result, residuals=None, y_data=None, compute_covariance=True)[source]
Create OptimizationResult from NLSQ result dictionary.
- Parameters:
nlsq_result (dict[str, Any]) – Result dictionary from nlsq.LeastSquares.least_squares
residuals (ndarray | None) – Optional residual vector for covariance scaling and metrics
y_data (ndarray | None) – Optional original y data for R² computation
compute_covariance (bool) – Whether to compute the covariance matrix via SVD (default: True). Set False to skip the SVD when CIs are not needed.
- Return type:
OptimizationResult
- Returns:
OptimizationResult instance with fields extracted from NLSQ result
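compute_covariance refers to deriving the covariance from an SVD of the Jacobian. One standard way to do this (the approach scipy.optimize.curve_fit takes) is sketched below: with J = U S Vᵀ, (JᵀJ)⁻¹ = V S⁻² Vᵀ, scaled by the residual variance RSS/(n - p). Whether RheoJAX uses exactly this scaling and singular-value cutoff is an assumption, and covariance_from_jacobian is a hypothetical name:

```python
import numpy as np


def covariance_from_jacobian(jac, residuals):
    """pcov = s^2 (J^T J)^{-1} computed via SVD, with s^2 = RSS / (n - p)."""
    jac = np.asarray(jac, dtype=float)
    residuals = np.asarray(residuals, dtype=float)
    n, p = jac.shape
    _, s, vt = np.linalg.svd(jac, full_matrices=False)
    # Drop numerically zero singular values (curve_fit-style cutoff)
    threshold = np.finfo(float).eps * max(jac.shape) * s[0]
    keep = s > threshold
    s, vt = s[keep], vt[keep]
    pcov = (vt.T / s**2) @ vt  # V S^-2 V^T == (J^T J)^{-1} for full-rank J
    rss = float(np.sum(residuals**2))
    return pcov * rss / (n - p)


rng = np.random.default_rng(0)
J = rng.normal(size=(10, 3))
r = rng.normal(size=10)
pcov = covariance_from_jacobian(J, r)
```

Working from the SVD avoids forming JᵀJ explicitly, which squares the condition number of the problem.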
- __init__(x, fun, jac=None, pcov=None, success=False, message='', nit=0, nfev=0, njev=0, optimality=None, active_mask=None, cost=None, grad=None, nlsq_result=None, residuals=None, y_data=None, n_data=None, diagnostics=None, _curve_fit_result=None, _model_fn=None, _x_data=None, _is_complex_split=False, _normalization_weights=None, _use_log_residuals=False)
- rheojax.utils.optimization.nlsq_optimize(objective, parameters, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, workflow='auto', auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, compute_covariance=True, **kwargs)[source]
Optimize objective function using NLSQ (GPU-accelerated).
This function provides GPU-accelerated nonlinear least squares optimization using the NLSQ package. Through JAX JIT compilation and automatic differentiation it can reach 5-270x speedups over SciPy, depending on problem size and hardware.
The objective function should accept parameter values as a 1D array and return a scalar value (minimization) or vector of residuals (least squares).
- Parameters:
objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize. Takes parameter values as an array and returns a scalar or residual vector. Should use jax.numpy operations to enable GPU acceleration and automatic differentiation.
parameters (ParameterSet) – ParameterSet with initial values and bounds
method (str) – Optimization method. Options:
“auto”: Automatically select based on bounds (default)
“trf”: Trust Region Reflective (supports bounds)
“lm”: Levenberg-Marquardt (no bounds)
“scipy”: Use SciPy’s least_squares directly (bypasses NLSQ). Use this for models that use Diffrax ODE solvers, which are incompatible with NLSQ’s forward-mode autodiff.
For the NLSQ-backed options, NLSQ internally selects the best algorithm regardless of this parameter.
use_jax (bool) – Whether to use JAX for gradient computation (default: True). Should always be True for GPU acceleration and float64 precision.
max_iter (int) – Maximum number of iterations (default: 1000)
ftol (float) – Function tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.
xtol (float) – Parameter tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.
gtol (float) – Gradient tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.
workflow (str) – NLSQ 0.6.6 workflow selection (default: “auto”):
“auto”: Memory-aware local optimization (default)
“auto_global”: Global optimization with bounds exploration
“hpc”: HPC mode with checkpointing support
auto_bounds (bool) – Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics.
stability (str | bool) – Numerical stability checks (default: False):
‘auto’: Check and automatically fix stability issues
‘check’: Check and warn but don’t fix
False: Skip stability checks
fallback (bool) – Enable NLSQ’s native fallback strategies (default: False). When True, tries alternative approaches if optimization fails. Note: RheoJAX also has its own SciPy fallback, independent of this option.
compute_diagnostics (bool) – Compute model health diagnostics (default: False). When True, result.diagnostics includes identifiability analysis, gradient health, and other diagnostic information.
compute_covariance (bool) – Whether to compute the parameter covariance matrix (default: True). The covariance is derived from an SVD of the Jacobian at the solution. Set False to skip this step when confidence intervals and parameter uncertainties are not needed, avoiding one full SVD per fit.
**kwargs – Additional arguments passed to nlsq.LeastSquares.least_squares
- Return type:
OptimizationResult
- Returns:
OptimizationResult with optimal parameters, convergence info, and optional diagnostics (when compute_diagnostics=True).
- Raises:
ValueError – If objective is not callable or parameters is not ParameterSet
Example
>>> from rheojax.core.parameters import ParameterSet
>>> from rheojax.utils.optimization import nlsq_optimize
>>> params = ParameterSet()
>>> params.add("a", value=1.0, bounds=(0, 10))
>>> params.add("b", value=1.0, bounds=(0, 10))
>>>
>>> def objective(values):
...     a, b = values
...     return (a - 5.0) ** 2 + (b - 3.0) ** 2
>>>
>>> result = nlsq_optimize(objective, params)
>>> print(result.x)  # Should be close to [5.0, 3.0]
>>>
>>> # With NLSQ 0.6.6 features
>>> result = nlsq_optimize(
...     objective, params,
...     workflow="auto_global",  # Global optimization
...     stability="auto",  # Auto-fix stability issues
...     compute_diagnostics=True  # Get diagnostics
... )
>>> print(result.diagnostics)
Notes
This function automatically handles float64 precision through NLSQ
JAX JIT compilation can provide 5-270x speedups over SciPy, depending on problem size and hardware
Automatic differentiation eliminates need for manual Jacobian
Bounds are automatically extracted from ParameterSet
Parameters are updated in-place with optimal values
Aliases¶
- rheojax.utils.optimization.optimize¶
Alias for nlsq_optimize()
- rheojax.utils.optimization.fit_parameters¶
Alias for nlsq_optimize()
Optimization Methods¶
nlsq_optimize() accepts the following values for method:
“auto”: Automatically select based on bounds (default)
“trf”: Trust Region Reflective (supports bounds)
“lm”: Levenberg-Marquardt (no bounds)
“scipy”: SciPy’s least_squares directly, bypassing NLSQ (for models that use Diffrax ODE solvers)
JAX Gradient Computation¶
When use_jax=True, gradients are computed using JAX automatic differentiation.
This provides exact gradients (up to floating-point precision) compared to numerical finite-difference approximations, leading to faster and more robust optimization.
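A small illustration of the difference, assuming only JAX itself: jax.grad returns the derivative of the objective exactly (up to rounding), whereas a forward finite difference carries a truncation error that scales with the step size h. For this quadratic the forward-difference error is exactly h per component:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, matching the float64 workflow


def objective(p):
    # Quadratic bowl with its minimum at (5, 3)
    return (p[0] - 5.0) ** 2 + (p[1] - 3.0) ** 2


p = jnp.array([1.0, 1.0])
grad_ad = jax.grad(objective)(p)  # exact: [2(p0 - 5), 2(p1 - 3)] = [-8, -4]

# Forward finite difference for comparison: error is +h here
h = 1e-3
grad_fd = jnp.array(
    [(objective(p + h * jnp.eye(2)[i]) - objective(p)) / h for i in range(2)]
)
```

Because the objective must be traceable by JAX, it uses jax.numpy operations only.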
Examples¶
Basic Optimization¶
from rheojax.core.parameters import ParameterSet
from rheojax.utils.optimization import nlsq_optimize
import jax.numpy as jnp
# Define objective function
def objective(params):
    x, y = params
    return (x - 5.0)**2 + (y - 3.0)**2
# Set up parameters
params = ParameterSet()
params.add("x", value=0.0, bounds=(-10, 10))
params.add("y", value=0.0, bounds=(-10, 10))
# Optimize
result = nlsq_optimize(objective, params, use_jax=True)
print(f"Optimal: x={result.x[0]:.4f}, y={result.x[1]:.4f}")
print(f"Function value: {result.fun:.6f}")
print(f"Success: {result.success}")
Model Fitting¶
import jax.numpy as jnp
from rheojax.core.parameters import ParameterSet
from rheojax.utils.optimization import nlsq_optimize
# Experimental data
t_exp = jnp.array([0.1, 0.5, 1.0, 2.0, 5.0, 10.0])
stress_exp = jnp.array([1000.0, 800.0, 650.0, 500.0, 320.0, 200.0])
# Maxwell model
def maxwell_model(t, params):
    E, tau = params
    return E * jnp.exp(-t / tau)
# Objective: return the residual vector; NLSQ minimizes its sum of squares
def objective(params):
    predictions = maxwell_model(t_exp, params)
    return predictions - stress_exp
# Set up parameters
params = ParameterSet()
params.add("E", value=1000, bounds=(100, 5000))
params.add("tau", value=1.0, bounds=(0.1, 100))
# Fit model (Trust Region Reflective supports bounds)
result = nlsq_optimize(
    objective,
    params,
    use_jax=True,
    method="trf"
)
# Extract fitted parameters
E_fit = params.get_value("E")
tau_fit = params.get_value("tau")
print(f"Fitted: E={E_fit:.1f} Pa, tau={tau_fit:.2f} s")
Custom Objective with Least Squares¶
import jax.numpy as jnp
from rheojax.core.parameters import ParameterSet
from rheojax.utils.optimization import create_least_squares_objective, nlsq_optimize
# Define model function
def power_law(shear_rate, params):
    K, n = params
    return K * shear_rate**n
# Data
shear_rate = jnp.logspace(-2, 2, 50)
viscosity = 100 * shear_rate**(-0.7)
# Create objective
objective = create_least_squares_objective(
power_law,
shear_rate,
viscosity,
normalize=True # Use relative error
)
# Set up parameters
params = ParameterSet()
params.add("K", value=100, bounds=(1, 1000))
params.add("n", value=-0.5, bounds=(-2, 0))
# Optimize
result = nlsq_optimize(objective, params, use_jax=True)
Optimization with Constraints¶
from rheojax.core.parameters import ParameterConstraint, ParameterSet
# Add relative constraint: tau1 < tau2
params = ParameterSet()
params.add("tau1", value=1.0, bounds=(0.1, 100))
tau2_constraint = ParameterConstraint(
type="relative",
relation="greater_than",
other_param="tau1"
)
params.add(
"tau2",
value=10.0,
bounds=(0.1, 100),
constraints=[tau2_constraint]
)
Monitoring Optimization¶
from rheojax.core.parameters import ParameterOptimizer
# Create optimizer with tracking
optimizer = ParameterOptimizer(
parameters=params,
use_jax=True,
track_history=True
)
optimizer.set_objective(objective)
# Define callback
def callback(iteration, values, obj_value):
    if iteration % 10 == 0:
        print(f"Iteration {iteration}: f={obj_value:.6f}")
optimizer.set_callback(callback)
# Run optimization (integrate with scipy)
# result = optimizer.optimize()
# Get history
history = optimizer.get_history()
for entry in history:
    print(f"Iter {entry['iteration']}: {entry['objective']}")
Performance Tips¶
Use JAX gradients: set use_jax=True for faster optimization
Choose an appropriate method: L-BFGS-B for bounded problems, BFGS for unbounded ones
Scale parameters: normalize them to similar magnitudes (roughly the 0.1-10 range)
Provide a good initial guess: the closer it is to the optimum, the faster the convergence
Set tolerances: adjust ftol, xtol, and gtol to trade speed against accuracy
# Good parameter scaling: both parameters are O(1)
params = ParameterSet()
params.add("E", value=1.0, bounds=(0.1, 10))  # E in units of 1000 Pa
params.add("tau", value=1.0, bounds=(0.1, 10))  # tau in seconds
# In the objective, unscale before evaluating the model:
def objective(params_scaled):
    E = params_scaled[0] * 1000  # Convert back to Pa
    tau = params_scaled[1]  # Already in seconds
    predictions = E * jnp.exp(-t_exp / tau)
    return jnp.sum((predictions - stress_exp) ** 2)
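For single-exponential models such as the Maxwell example above, a good initial guess can be obtained cheaply: taking logs turns \(\sigma = E e^{-t/\tau}\) into a straight line in \(t\). A minimal NumPy sketch; `maxwell_initial_guess` is a hypothetical helper rather than a rheojax function, and because the line is fitted in log space its result is only a starting point for the nonlinear fit:

```python
import numpy as np

# Same data as the Model Fitting example above
t_exp = np.array([0.1, 0.5, 1.0, 2.0, 5.0, 10.0])
stress_exp = np.array([1000.0, 800.0, 650.0, 500.0, 320.0, 200.0])

def maxwell_initial_guess(t, stress):
    """Estimate (E, tau) for stress = E * exp(-t / tau).

    log(stress) = log(E) - t / tau is a straight line in t, so an
    ordinary least-squares line gives slope = -1/tau and intercept = log(E).
    """
    slope, intercept = np.polyfit(t, np.log(stress), 1)
    return float(np.exp(intercept)), float(-1.0 / slope)

E0, tau0 = maxwell_initial_guess(t_exp, stress_exp)
print(f"Initial guess: E={E0:.0f} Pa, tau={tau0:.2f} s")
```

The resulting values can then be passed as `value=` when adding the parameters.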
See Also¶
Core Module (rheojax.core) - Parameter system
../user_guide/getting_started - Basic usage examples
SciPy optimize - Optimization algorithms
JAX autodiff - Automatic differentiation
Model-Data Compatibility¶
Model-data compatibility checking for rheological models.
This module provides functions to assess whether a given model is appropriate for a dataset based on the underlying physics and data characteristics.
The compatibility checker helps users understand when model failures are expected due to physics mismatch rather than optimization issues.
The compatibility module provides intelligent detection of when rheological models are inappropriate for experimental data based on underlying physics. This helps users understand when model failures are expected due to physics mismatch rather than optimization issues.
Enums¶
- class rheojax.utils.compatibility.DecayType(*values)[source]
Bases: Enum
Types of relaxation decay behavior:
EXPONENTIAL: Simple Maxwell-like exp(-t/tau)
POWER_LAW: Power-law t^(-alpha) (gel-like)
STRETCHED: Stretched exponential exp(-(t/tau)^beta)
MITTAG_LEFFLER: Mittag-Leffler E_alpha(-(t/tau)^alpha) (fractional)
MULTI_MODE: Multiple relaxation modes
UNKNOWN: Cannot determine
- EXPONENTIAL = 'exponential'
- POWER_LAW = 'power_law'
- STRETCHED = 'stretched'
- MITTAG_LEFFLER = 'mittag_leffler'
- MULTI_MODE = 'multi_mode'
- UNKNOWN = 'unknown'
- class rheojax.utils.compatibility.MaterialType(*values)[source]
Bases: Enum
Types of material behavior:
SOLID: Solid-like (finite equilibrium modulus)
LIQUID: Liquid-like (zero equilibrium modulus, flows)
GEL: Gel-like (power-law relaxation)
VISCOELASTIC_SOLID: Viscoelastic solid
VISCOELASTIC_LIQUID: Viscoelastic liquid
UNKNOWN: Cannot determine
- SOLID = 'solid'
- LIQUID = 'liquid'
- GEL = 'gel'
- VISCOELASTIC_SOLID = 'viscoelastic_solid'
- VISCOELASTIC_LIQUID = 'viscoelastic_liquid'
- UNKNOWN = 'unknown'
Functions¶
- rheojax.utils.compatibility.detect_decay_type(t, G_t)[source]
Detect the type of relaxation decay from time-domain data.
Analyzes the decay pattern to determine if it follows exponential, power-law, stretched exponential, or Mittag-Leffler behavior.
- Parameters:
- Returns:
Detected decay type
- Return type:
DecayType
Analyzes relaxation modulus data to determine the type of decay pattern. Uses linear regression on log-transformed data to identify exponential, power-law, stretched exponential, or Mittag-Leffler behavior.
- rheojax.utils.compatibility.detect_material_type(t=None, G_t=None, omega=None, G_star=None)[source]
Detect the material type from relaxation or oscillation data.
- Parameters:
- Returns:
Detected material type
- Return type:
MaterialType
Classifies material behavior from relaxation or oscillation data. Detects solid-like, liquid-like, gel-like, or viscoelastic behavior based on equilibrium modulus or low-frequency response.
- rheojax.utils.compatibility.check_model_compatibility(model, t=None, G_t=None, omega=None, G_star=None, test_mode=None)[source]
Check if a model is compatible with the given data.
This function analyzes the data characteristics and compares them with the model’s underlying physics to assess compatibility.
- Parameters:
- Returns:
Dictionary with compatibility information:
- ‘compatible’: bool, whether the model is likely compatible
- ‘confidence’: float, confidence level (0-1)
- ‘decay_type’: DecayType, detected decay pattern
- ‘material_type’: MaterialType, detected material behavior
- ‘warnings’: list[str], compatibility warnings
- ‘recommendations’: list[str], suggested alternative models
- Return type:
dict
Comprehensive compatibility check comparing model physics with data characteristics. Returns detailed compatibility information including warnings and model recommendations.
- rheojax.utils.compatibility.format_compatibility_message(compatibility)[source]
Format compatibility check results as a user-friendly message.
- Parameters:
compatibility (dict) – Compatibility check results from check_model_compatibility()
- Returns:
Formatted message
- Return type:
str
Formats compatibility check results as a human-readable message with warnings, detected characteristics, and alternative model recommendations.
Decay Detection Algorithm¶
The decay type detection uses statistical analysis on log-transformed data:
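Exponential Decay Detection
Linear regression on \(\log(G)\) vs \(t\). For Maxwell-like decay \(G(t) = G_0 e^{-t/\tau}\), so \(\log G = \log G_0 - t/\tau\) is linear in \(t\).
High \(R^2\) (> 0.90) indicates exponential decay (Maxwell-like behavior).
Power-Law Decay Detection
Linear regression on \(\log(G)\) vs \(\log(t)\). For gel-like decay \(G(t) = G_0 t^{-\alpha}\), so \(\log G = \log G_0 - \alpha \log t\).
High \(R^2\) (> 0.90) indicates power-law decay (gel-like behavior).
Stretched Exponential Detection
Linear regression on \(\log(-\log(G/G_0))\) vs \(\log(t)\). For \(G(t) = G_0 \exp[-(t/\tau)^\beta]\), \(\log(-\log(G/G_0)) = \beta \log t - \beta \log \tau\).
High \(R^2\) (> 0.90) indicates stretched exponential behavior.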
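The three regressions above can be sketched with NumPy as follows. This is an illustrative reimplementation, not rheojax's `detect_decay_type`: the helper names, taking the first sample as \(G_0\), assuming `t` is sorted ascending, and the "highest \(R^2\) above 0.90 wins" rule are all assumptions:

```python
import numpy as np

def r_squared(x, y):
    """R^2 of the least-squares line y ~ slope * x + intercept."""
    slope, intercept = np.polyfit(x, y, 1)
    ss_res = np.sum((y - (slope * x + intercept)) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return float(1.0 - ss_res / ss_tot)

def decay_scores(t, G):
    """R^2 for each linearization (hypothetical helper; t ascending)."""
    t = np.asarray(t, dtype=float)
    G = np.asarray(G, dtype=float)
    log_G = np.log(G)
    scores = {
        "exponential": r_squared(t, log_G),        # log G vs t
        "power_law": r_squared(np.log(t), log_G),  # log G vs log t
    }
    # Stretched exponential: log(-log(G/G0)) vs log t. G0 is taken as the
    # first sample, which is then dropped because log(-log(1)) is undefined;
    # only points with 0 < G/G0 < 1 are kept.
    ratio = G[1:] / G[0]
    keep = (ratio > 0) & (ratio < 1)
    scores["stretched"] = r_squared(np.log(t[1:][keep]), np.log(-np.log(ratio[keep])))
    return scores

def classify_decay(t, G, threshold=0.90):
    """Return the best-scoring decay type if its R^2 exceeds the threshold."""
    scores = decay_scores(t, G)
    best = max(scores, key=scores.get)
    return best if scores[best] > threshold else "unknown"

t = np.logspace(-2, 2, 100)
print(classify_decay(t, 1e5 * np.exp(-t / 1.0)))  # exponential
print(classify_decay(t, 1e5 * t**(-0.5)))         # power_law
```

Because exponential data are exactly linear in the first regression, it scores essentially 1.0 there, while the stretched-exponential linearization with an estimated \(G_0\) is only approximately linear.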
Material Type Classification¶
From Relaxation Data
Material type is determined by the decay ratio:
Solid-like: decay ratio > 0.5 (significant equilibrium modulus)
Liquid-like: decay ratio < 0.1 (nearly complete relaxation)
Power-law materials: Classified based on decay type regardless of ratio
From Oscillation Data
Material type is determined by low-frequency behavior:
Solid: \(G' > G''\) at lowest frequency (elastic dominant)
Liquid: \(G'' > G'\) at lowest frequency (viscous dominant)
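The classification rules above can be sketched in a few lines of NumPy. This is not rheojax's `detect_material_type`: the helper names are hypothetical, the decay ratio is assumed to be \(G(t_{\max})/G(t_{\min})\), ratios between 0.1 and 0.5 are reported as undetermined because the text gives no rule for them, and the power-law special case is left out:

```python
import numpy as np

def material_from_relaxation(t, G_t):
    """Classify by decay ratio, assumed here to be G(t_max) / G(t_min)."""
    order = np.argsort(t)
    G_sorted = np.asarray(G_t, dtype=float)[order]
    decay_ratio = G_sorted[-1] / G_sorted[0]
    if decay_ratio > 0.5:
        return "solid-like"
    if decay_ratio < 0.1:
        return "liquid-like"
    return "undetermined"  # no rule is given for ratios in [0.1, 0.5]

def material_from_oscillation(omega, G_prime, G_double_prime):
    """Compare G' and G'' at the lowest measured frequency."""
    i = int(np.argmin(omega))
    return "solid" if G_prime[i] > G_double_prime[i] else "liquid"

t = np.logspace(-2, 2, 50)
print(material_from_relaxation(t, 5e4 + 5e4 * np.exp(-t)))  # solid-like
print(material_from_relaxation(t, 1e5 * np.exp(-t)))        # liquid-like

omega = np.logspace(-2, 2, 50)
print(material_from_oscillation(omega, 1e5 * np.ones(50), 1e3 * omega**0.5))  # solid
```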
Examples¶
Basic Compatibility Check¶
from rheojax.models.fractional_zener_ss import FractionalZenerSolidSolid
from rheojax.utils.compatibility import (
check_model_compatibility,
format_compatibility_message
)
import numpy as np
# Generate exponential decay data (Maxwell-like)
t = np.logspace(-2, 2, 50)
G_t = 1e5 * np.exp(-t / 1.0)
# Check if FZSS is appropriate
model = FractionalZenerSolidSolid()
compatibility = check_model_compatibility(
model, t=t, G_t=G_t, test_mode='relaxation'
)
# Print human-readable report
print(format_compatibility_message(compatibility))
# Output:
# WARNING: Model may not be appropriate for this data
# Confidence: 90%
# Detected decay: exponential
# Material type: viscoelastic_liquid
#
# Warnings:
# - FZSS model expects Mittag-Leffler (power-law) relaxation,
# but data shows exponential decay.
#
# Recommended alternative models:
# - Maxwell
# - Zener
Automatic Checking During Fit¶
from rheojax.models import Maxwell
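import numpy as np
# Exponential relaxation data (Maxwell-like)
t = np.logspace(-2, 2, 50)
G_data = 1e5 * np.exp(-t / 1.0)
# Enable automatic compatibility checking
model = Maxwell()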
model.fit(
t, G_data,
test_mode='relaxation',
check_compatibility=True # Warns if model-data mismatch
)
# If incompatible, warning is logged and enhanced error
# messages provide physics-based explanations
Detecting Decay Type¶
from rheojax.utils.compatibility import detect_decay_type, DecayType
import numpy as np
# Power-law decay (gel-like)
t = np.logspace(-2, 2, 100)
G_gel = 1e5 * t**(-0.5)
decay_type = detect_decay_type(t, G_gel)
print(decay_type) # DecayType.POWER_LAW
# Exponential decay (Maxwell-like)
G_maxwell = 1e5 * np.exp(-t / 1.0)
decay_type = detect_decay_type(t, G_maxwell)
print(decay_type) # DecayType.EXPONENTIAL
Classifying Material Type¶
from rheojax.utils.compatibility import detect_material_type
import numpy as np
# Solid-like material (finite equilibrium modulus)
t = np.logspace(-2, 2, 50)
G_solid = 5e4 + 5e4 * np.exp(-t / 1.0) # Ge + Gm*exp(-t/tau)
material_type = detect_material_type(t=t, G_t=G_solid)
print(material_type) # MaterialType.VISCOELASTIC_SOLID
# Liquid-like material (no equilibrium modulus)
G_liquid = 1e5 * np.exp(-t / 1.0)
material_type = detect_material_type(t=t, G_t=G_liquid)
print(material_type) # MaterialType.VISCOELASTIC_LIQUID
Oscillation Data Analysis¶
from rheojax.utils.compatibility import (
check_model_compatibility,
detect_material_type
)
from rheojax.models.fractional_maxwell_liquid import FractionalMaxwellLiquid
import numpy as np
# Oscillation data (G', G")
omega = np.logspace(-2, 2, 50)
G_prime = 1e5 * np.ones(50) # Constant storage modulus
G_double_prime = 1e3 * omega**0.5 # Frequency-dependent loss
G_star = np.column_stack([G_prime, G_double_prime])
# Detect material type
material_type = detect_material_type(omega=omega, G_star=G_star)
print(material_type) # MaterialType.SOLID (G' > G" at low freq)
# Check model compatibility
model = FractionalMaxwellLiquid()
compatibility = check_model_compatibility(
model,
omega=omega,
G_star=G_star,
test_mode='oscillation'
)
if not compatibility['compatible']:
    print(f"Confidence: {compatibility['confidence']}")
    print(f"Warnings: {compatibility['warnings']}")
    print(f"Try instead: {compatibility['recommendations']}")
Enhanced Error Messages¶
import numpy as np
from rheojax.models.fractional_zener_ss import FractionalZenerSolidSolid
# Generate exponential data (incompatible with FZSS)
np.random.seed(42)
t = np.logspace(-2, 2, 50)
G_t = 1e5 * np.exp(-t / 1.0) + np.random.normal(0, 1000, size=len(t))
model = FractionalZenerSolidSolid()
try:
    # Fit will fail with enhanced error message
    model.fit(t, G_t, test_mode='relaxation', max_iter=100)
except RuntimeError as e:
    print(e)
    # Output includes:
    # - Original optimization error
    # - Detected decay type and material type
    # - Physics-based explanation of mismatch
    # - Recommended alternative models
    # - Guidance that failures are normal in model comparison
Model Compatibility Rules¶
Fractional Zener Solid-Solid (FZSS)
Expects: Mittag-Leffler or power-law relaxation with finite equilibrium modulus
Incompatible with: Exponential decay (use Maxwell/Zener instead)
Incompatible with: Liquid-like behavior (use FractionalMaxwellLiquid)
Fractional Maxwell Liquid (FML)
Expects: Liquid-like behavior (no equilibrium modulus)
Incompatible with: Solid-like materials (use FZSS or FractionalKelvinVoigt)
Fractional Maxwell Gel (FMG)
Expects: Power-law relaxation (gel-like)
Incompatible with: Exponential decay (use Maxwell instead)
Maxwell Model
Expects: Exponential decay
Incompatible with: Power-law decay (use FMG or FZSS)
Zener Model
Expects: Exponential decay with equilibrium modulus
Incompatible with: Power-law decay (use FZSS)
Fractional Kelvin-Voigt
Expects: Solid-like behavior
Incompatible with: Liquid-like behavior (use FractionalMaxwellLiquid)
Performance Considerations¶
Fast detection: < 1 ms for typical datasets (50-100 points)
Minimal overhead: can be enabled during fitting without a noticeable performance cost
Robust to noise: Uses statistical regression with confidence thresholds
Automatic test mode detection: Works with relaxation, creep, and oscillation data
Use Cases¶
Model Selection: Identify appropriate models before fitting
Error Diagnosis: Understand why optimization failed
Automated Pipelines: Filter incompatible model-data combinations
Model Comparison: Expect some models to fail (this is normal!)
Educational: Learn about rheological model physics
See Also¶
Core Module (rheojax.core) - BaseModel integration with check_compatibility parameter
../user_guide/model_selection - Comprehensive model selection guide
../user_guide/getting_started - Basic usage examples
Data Quality Analysis¶
Data quality and range detection utilities.
This module provides utilities for detecting data characteristics that affect optimization quality, such as very wide frequency ranges (mastercurves).
- rheojax.utils.data_quality.check_nan_inf(data, label='data')[source]¶
Check for NaN/Inf values and return a diagnostic dictionary.
- Parameters:
- Returns:
‘label’: The provided label string.
’n_nan’: Number of NaN values.
’n_inf’: Number of Inf values (±∞).
’has_issues’: True if any NaN or Inf is present.
’fraction_clean’: Fraction of finite values in [0, 1].
- Return type:
Example
>>> arr = np.array([1.0, np.nan, np.inf, 2.0])
>>> result = check_nan_inf(arr, label="G_star")
>>> result['n_nan']
1
>>> result['has_issues']
True
- rheojax.utils.data_quality.check_monotonicity(x, threshold=0.95)[source]¶
Check whether an array is approximately monotonic.
An array is considered monotonic if at least threshold fraction of consecutive differences share the same sign.
- Parameters:
- Returns:
‘is_monotonic’: True if the dominant direction exceeds threshold.
’direction’: ‘increasing’, ‘decreasing’, ‘constant’, or ‘mixed’.
’fraction’: Fraction of steps in the dominant direction.
- Return type:
Example
>>> x = np.array([1.0, 2.0, 3.0, 2.9, 4.0])
>>> result = check_monotonicity(x, threshold=0.95)
>>> result['direction']
'increasing'
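The rule stated above (at least `threshold` of consecutive differences share a sign) can be sketched as follows. `monotonicity_sketch` is a hypothetical stand-in for `check_monotonicity`; how rheojax treats flat steps, ties, and the 'constant' and 'mixed' labels is not documented here, so those branches are assumptions. On the docstring's example, 3 of 4 steps (0.75) go up, so the direction is 'increasing' but the 0.95 threshold is not met:

```python
import numpy as np

def monotonicity_sketch(x, threshold=0.95):
    """Fraction of consecutive steps that share the dominant sign."""
    steps = np.sign(np.diff(np.asarray(x, dtype=float)))
    n = len(steps)
    n_up = int(np.sum(steps > 0))
    n_down = int(np.sum(steps < 0))
    if n == 0 or (n_up == 0 and n_down == 0):
        # Assumption: no rising or falling steps means a constant array
        return {"is_monotonic": True, "direction": "constant", "fraction": 1.0}
    if n_up == n_down:
        # Assumption: a tie has no dominant direction
        return {"is_monotonic": False, "direction": "mixed", "fraction": n_up / n}
    direction = "increasing" if n_up > n_down else "decreasing"
    fraction = max(n_up, n_down) / n
    return {"is_monotonic": fraction >= threshold, "direction": direction, "fraction": fraction}

print(monotonicity_sketch([1.0, 2.0, 3.0, 2.9, 4.0]))
```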
The data quality module provides intelligent analysis of experimental data characteristics to optimize fitting strategies, especially for wide-range frequency or time-domain data spanning multiple decades.
Functions¶
- rheojax.utils.data_quality.detect_data_range_decades(x)[source]
Detect the range of data in decades (log10 scale).
- Parameters:
x (ndarray) – Data array (e.g., frequency, time)
- Return type:
float
- Returns:
Range in decades (log10(max/min))
Example
>>> freq = np.array([1e-8, 1e-6, 1e-4, 1e4])
>>> decades = detect_data_range_decades(freq)
>>> print(f"{decades:.1f} decades")
12.0 decades
Detects the number of decades spanned by the independent variable (frequency, time, or shear rate). This helps identify when multi-start or log-residuals optimization may be beneficial.
- rheojax.utils.data_quality.check_wide_frequency_range(x, threshold_decades=8.0, warn=True, recommend_log_residuals=True)[source]
Check if data has a very wide frequency/time range (e.g., mastercurve).
Wide-range data (>8 decades) can cause optimization problems:
- Optimizer bias toward high-value regions
- Poor parameter recovery
- Convergence to local minima
Recommended solutions:
- Use log-space residuals (use_log_residuals=True)
- Fit to a subset of the data for initialization
- Use multi-start optimization
- Parameters:
- Returns:
‘is_wide_range’: True if range > threshold
’decades’: Actual range in decades
’recommendation’: Recommended action (or empty string)
- Return type:
Example
>>> omega = np.logspace(-8, 4, 100)  # 12 decades (mastercurve)
>>> result = check_wide_frequency_range(omega)
>>> if result['is_wide_range']:
...     print(f"Wide range detected: {result['decades']:.1f} decades")
...     print(result['recommendation'])
Checks whether frequency- or time-domain data span a wide range (more than threshold_decades decades, 8 by default) and therefore need special optimization techniques.
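The decade count is simply \(\log_{10}(x_{\max}/x_{\min})\), and the wide-range test compares it with `threshold_decades`. A minimal NumPy sketch with hypothetical helper names; ignoring non-positive values is an assumption, since a logarithmic span is undefined for them:

```python
import numpy as np

def decades_spanned(x):
    """log10(max/min) over the strictly positive entries of x."""
    x = np.asarray(x, dtype=float)
    positive = x[x > 0]
    return float(np.log10(positive.max() / positive.min()))

def is_wide_range(x, threshold_decades=8.0):
    """True when the data span more than threshold_decades decades."""
    return decades_spanned(x) > threshold_decades

print(round(decades_spanned(np.logspace(-8, 4, 100)), 6))  # 12.0
print(is_wide_range(np.logspace(-8, 4, 100)))  # True
print(is_wide_range(np.logspace(-1, 3, 80)))   # False
```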
- rheojax.utils.data_quality.suggest_optimization_strategy(x, y, test_mode=None)[source]
Suggest optimization strategy based on data characteristics.
Analyzes data range, complexity, and test mode to recommend:
- Whether to use log-residuals
- Whether to use multi-start optimization
- Whether to use subset initialization
- Parameters:
- Returns:
‘use_log_residuals’: Recommended for wide ranges
’use_multi_start’: Recommended for complex landscapes
’use_subset_init’: Recommended for very wide ranges
’rationale’: Explanation of recommendations
- Return type:
Example
>>> omega = np.logspace(-8, 4, 100)
>>> G_star = ...  # Complex modulus data
>>> strategy = suggest_optimization_strategy(omega, G_star, 'oscillation')
>>> print(strategy['rationale'])
Provides intelligent recommendations for optimization strategy based on data characteristics (range, domain, test mode). Returns configuration for:
Use of log-residuals (for wide-range data)
Multi-start optimization (for complex landscapes)
Recommended number of random starts
Wide-Range Data Challenges¶
When experimental data span many decades (e.g., frequency from 0.01 to 1000 Hz), standard least-squares fitting can encounter problems:
Problem: Linear residuals \(\sum (y_{\text{pred}} - y_{\text{exp}})^2\) are dominated by high-magnitude points, causing poor fits at low values.
Example: For \(G'\) spanning 100 Pa to 1e6 Pa, the same 100 % relative error gives very different squared residuals:
High-frequency point (1e6 Pa): squared residual ~ 1e12 Pa²
Low-frequency point (100 Pa): squared residual ~ 1e4 Pa²
The optimizer therefore focuses on the high-frequency region and largely ignores the low-frequency one
Solutions:
Log-Residuals: Minimize \(\sum (\log y_{\text{pred}} - \log y_{\text{exp}})^2\)
Balances contributions across decades
Approximately equivalent to minimizing relative error when errors are small
Enabled automatically for wide-range data (see the decade thresholds under Integration with BaseModel)
Multi-Start Optimization: Run multiple optimizations from random initial points
Escapes local minima
Finds global optimum more reliably
Recommended for complex model landscapes
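The imbalance described above, and how log-residuals remove it, can be checked numerically. A minimal NumPy sketch with made-up values (not rheojax code): two points four decades apart, both predicted with a 100 % error:

```python
import numpy as np

# Two measurements four decades apart, both overpredicted by a factor of 2
y_exp = np.array([1e2, 1e6])  # Pa
y_pred = 2.0 * y_exp

linear_terms = (y_pred - y_exp) ** 2               # dominated by the 1e6 Pa point
log_terms = (np.log(y_pred) - np.log(y_exp)) ** 2  # identical for both points

print(linear_terms)  # [1.e+04 1.e+12]
print(log_terms)     # both equal log(2)**2
```

In the linear sum the high-modulus point outweighs the low one by a factor of 1e8; in log space they contribute equally.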
Detection and Recommendations¶
The module automatically detects when to use these strategies:
from rheojax.utils.data_quality import suggest_optimization_strategy
import numpy as np
# Wide-range oscillation data
omega = np.logspace(-2, 3, 100) # 5 decades
G_star = ... # Complex modulus data
strategy = suggest_optimization_strategy(
    x=omega,
    y=G_star,
    test_mode='oscillation'
)
print(f"Use log-residuals: {strategy['use_log_residuals']}") # True
print(f"Use multi-start: {strategy['use_multi_start']}") # True
print(f"Number of starts: {strategy['n_starts']}") # 10
# Apply recommendations to model fitting
from rheojax.models.fractional_maxwell_liquid import FractionalMaxwellLiquid
model = FractionalMaxwellLiquid()
model.fit(
omega, G_star,
test_mode='oscillation',
use_log_residuals=strategy['use_log_residuals'],
multi_start=strategy['use_multi_start'],
n_starts=strategy['n_starts']
)
Examples¶
Detect Data Range¶
from rheojax.utils.data_quality import detect_data_range_decades
import numpy as np
# Narrow range (2 decades)
freq_narrow = np.logspace(0, 2, 50)
decades = detect_data_range_decades(freq_narrow)
print(f"Range: {decades:.1f} decades") # 2.0 decades
# Wide range (5 decades)
freq_wide = np.logspace(-2, 3, 100)
decades = detect_data_range_decades(freq_wide)
print(f"Range: {decades:.1f} decades") # 5.0 decades
Check Frequency Range¶
from rheojax.utils.data_quality import check_wide_frequency_range
import numpy as np
omega = np.logspace(-1, 3, 80) # 4 decades
result = check_wide_frequency_range(omega)
print(f"Is wide range: {result['is_wide_range']}") # False: below the default 8-decade threshold
print(f"Decades: {result['decades']:.2f}") # 4.00
# A 12-decade mastercurve exceeds the threshold
omega_very_wide = np.logspace(-8, 4, 100)
result = check_wide_frequency_range(omega_very_wide)
print(f"Is wide range: {result['is_wide_range']}") # True
print(result['recommendation'])
Get Complete Strategy¶
from rheojax.utils.data_quality import suggest_optimization_strategy
import numpy as np
# Time-domain relaxation data
time = np.logspace(-3, 2, 100) # 5 decades
G_t = ... # Relaxation modulus
strategy = suggest_optimization_strategy(
    x=time,
    y=G_t,
    test_mode='relaxation'
)
# Use strategy with BaseModel
from rheojax.models.fractional_maxwell_liquid import FractionalMaxwellLiquid
model = FractionalMaxwellLiquid()
model.fit(
    time, G_t,
    test_mode='relaxation',
    use_log_residuals=strategy['use_log_residuals'],
    multi_start=strategy['use_multi_start'],
)
Integration with BaseModel¶
The BaseModel._fit() method automatically uses suggest_optimization_strategy() when no explicit optimization configuration is provided, so wide-range data receive suitable settings without manual tuning.
Automatic behavior:
Data < 3 decades: Standard least-squares
Data 3-5 decades: Log-residuals enabled
Data > 5 decades: Log-residuals + multi-start (10-20 starts)
Performance Considerations¶
Detection overhead: < 0.1 ms (negligible)
Log-residuals: Same computational cost as linear
Multi-start: N times slower (N = number of starts), but more robust
Memory: Minimal additional memory for multi-start
Use Cases¶
Wide-range oscillation data: Master curves spanning 8+ decades
Time-temperature superposition: Combined data across temperatures
Multi-technique fitting: Combining relaxation + oscillation data
Fractional models: Complex parameter landscapes benefit from multi-start
Automated pipelines: Robust fitting without manual tuning
See Also¶
Core Module (rheojax.core) - BaseModel integration with automatic strategy selection
../user_guide/getting_started - Basic fitting examples
rheojax.utils.optimization - Optimization functions using these strategies
Modulus Conversion (DMTA Support)¶
Modulus conversion utilities for DMTA/DMA data analysis.
This module provides functions to convert between shear modulus G* (measured by rotational rheometers) and Young’s modulus E* (measured by DMTA/DMA instruments in tension, bending, or compression).
The fundamental relationship from isotropic linear elasticity:
E*(w) = 2(1 + v) * G*(w)
where v is the Poisson’s ratio of the material.
Example
>>> from rheojax.utils.modulus_conversion import convert_modulus
>>> from rheojax.core.test_modes import DeformationMode
>>>
>>> # Convert E* (DMTA) to G* (shear) for rubber (v=0.5 -> factor=3)
>>> G_star = convert_modulus(E_star, DeformationMode.TENSION, DeformationMode.SHEAR, poisson_ratio=0.5)
>>>
>>> # Use preset materials
>>> from rheojax.utils.modulus_conversion import POISSON_PRESETS
>>> nu = POISSON_PRESETS["glassy_polymer"] # 0.35
The modulus conversion module provides utilities for converting between shear modulus G* and Young’s modulus E* for DMTA/DMA data analysis.
The fundamental relationship from isotropic linear elasticity is
\[E^*(\omega) = 2(1 + \nu)\, G^*(\omega)\]
where \(\nu\) is the Poisson’s ratio of the material.
Functions¶
- rheojax.utils.modulus_conversion.convert_modulus(data, from_mode, to_mode, poisson_ratio=0.5)[source]
Convert modulus data between shear (G*) and tensile (E*) representations.
Applies the isotropic elasticity relationship E* = 2(1+v) * G*. Works with both real and complex arrays, and both NumPy and JAX arrays.
- Parameters:
data (ndarray | Any) – Modulus data array (real or complex). Can be a NumPy or JAX array.
from_mode (DeformationMode | str) – Source deformation mode (e.g., “tension”, “shear”)
to_mode (DeformationMode | str) – Target deformation mode (e.g., “shear”, “tension”)
poisson_ratio (float) – Poisson’s ratio of the material (default: 0.5 for rubber)
- Return type:
- Returns:
Converted modulus data in the same array type as input
- Raises:
ValueError – If Poisson’s ratio is out of bounds or modes are invalid
Example
>>> E_star = np.array([1e9 + 1e8j, 2e9 + 2e8j])  # E* in Pa
>>> G_star = convert_modulus(E_star, "tension", "shear", poisson_ratio=0.5)
>>> # G_star ≈ E_star / 3 for rubber
Array-level conversion between E* and G* using Poisson’s ratio.
Converts complex modulus arrays in either direction:
tension→shear: \(G^* = E^* / [2(1+\nu)]\)
shear→tension: \(E^* = 2(1+\nu) \, G^*\)
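Both directions reduce to a single factor \(2(1+\nu)\). A minimal NumPy sketch of the array-level conversion, not rheojax's `convert_modulus`; the helper names are hypothetical, and the bounds check assumes the isotropic limits \(-1 < \nu \le 0.5\), which may differ from the bounds rheojax enforces:

```python
import numpy as np

def shear_to_tension_factor(poisson_ratio):
    """Return 2(1 + nu), so that E* = factor * G* for an isotropic material.

    Assumed validity range for isotropic solids: -1 < nu <= 0.5.
    """
    if not (-1.0 < poisson_ratio <= 0.5):
        raise ValueError(f"Poisson's ratio {poisson_ratio} outside (-1, 0.5]")
    return 2.0 * (1.0 + poisson_ratio)

def tension_to_shear(E_star, poisson_ratio=0.5):
    """G* = E* / [2(1 + nu)]; works for real or complex arrays."""
    return np.asarray(E_star) / shear_to_tension_factor(poisson_ratio)

def shear_to_tension(G_star, poisson_ratio=0.5):
    """E* = 2(1 + nu) G*."""
    return np.asarray(G_star) * shear_to_tension_factor(poisson_ratio)

E_star = np.array([3e9 + 3e8j, 6e9 + 6e8j])
G_star = tension_to_shear(E_star, poisson_ratio=0.5)  # divides by 3
print(G_star)
print(np.allclose(shear_to_tension(G_star, 0.5), E_star))  # True (round trip)
```

Because the factor is a real scalar, the conversion scales \(E'\) and \(E''\) equally and leaves the loss tangent unchanged.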
- rheojax.utils.modulus_conversion.convert_rheodata(data, to_mode, poisson_ratio=0.5)[source]
Convert RheoData between shear and tensile modulus representations.
Creates a new RheoData with converted y-values and updated metadata. The original RheoData is not modified.
- Parameters:
data (RheoData) – Source RheoData object
to_mode (DeformationMode | str) – Target deformation mode
poisson_ratio (float) – Poisson’s ratio of the material
- Return type:
- Returns:
New RheoData with converted modulus values and updated metadata
Example
>>> from rheojax.core.data import RheoData
>>> # DMTA data in tension
>>> dmta_data = RheoData(x=omega, y=E_star,
...                      metadata={"deformation_mode": "tension"})
>>> # Convert to shear for model fitting
>>> shear_data = convert_rheodata(dmta_data, "shear", poisson_ratio=0.5)
RheoData-level conversion with automatic metadata update (units, deformation_mode).
Material Presets¶
- rheojax.utils.modulus_conversion.POISSON_PRESETS¶
Dictionary of common Poisson’s ratio values by material class:
Material | \(\nu\) | Notes
rubber/elastomer | 0.50 | Incompressible (E = 3G)
hydrogel | 0.50 | Water-swollen networks
semicrystalline | 0.40 | PE, PP, PA, PET
thermoset | 0.38 | Epoxies, polyesters
glassy_polymer | 0.35 | PS, PMMA, PC below Tg
metal | 0.30 | Steel, aluminum
foam | 0.30 | Open-cell foams
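The presets translate into E/G ratios through \(E^* = 2(1+\nu)\,G^*\). A short sketch using a plain dict keyed by the table's material labels; this is an assumption, since the actual POISSON_PRESETS keys may differ apart from "glassy_polymer", which the module docstring uses:

```python
# Poisson's ratios copied from the table above; keys are the table's labels
presets = {
    "rubber/elastomer": 0.50,
    "hydrogel": 0.50,
    "semicrystalline": 0.40,
    "thermoset": 0.38,
    "glassy_polymer": 0.35,
    "metal": 0.30,
    "foam": 0.30,
}

# Conversion factor E*/G* = 2(1 + nu) for each material class
factors = {material: 2.0 * (1.0 + nu) for material, nu in presets.items()}

for material, factor in factors.items():
    print(f"{material:18s} nu={presets[material]:.2f}  E/G={factor:.2f}")
```

The spread is modest (2.6 to 3.0), but using the incompressible value 0.5 for a glassy polymer would still overestimate \(G^*\) by about 10 %.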
Examples¶
Array-Level Conversion¶
import numpy as np
from rheojax.utils.modulus_conversion import convert_modulus, POISSON_PRESETS
# DMTA data: E* from tensile DMA
omega = np.logspace(-2, 2, 50)
E_prime = 3e9 * np.ones(50) # E' (Pa)
E_double_prime = 1e8 * omega**0.3 # E'' (Pa)
E_star = E_prime + 1j * E_double_prime
# Convert to G* for rubber (v=0.5, factor=3)
G_star = convert_modulus(E_star, "tension", "shear", poisson_ratio=0.5)
# G* = E* / 3 for rubber
# Use preset Poisson's ratio
nu = POISSON_PRESETS["glassy_polymer"] # 0.35
G_star = convert_modulus(E_star, "tension", "shear", poisson_ratio=nu)
Fitting DMTA Data Directly¶
from rheojax.models import Maxwell
# All models accept deformation_mode in fit()/predict()
model = Maxwell()
model.fit(
omega, E_star,
test_mode='oscillation',
deformation_mode='tension',
poisson_ratio=0.5,
)
# predict() returns E* when fitted with tensile deformation_mode
E_pred = model.predict(omega, test_mode='oscillation')
Device Utilities¶
GPU detection and warning utilities for RheoJAX (System CUDA version).
This module provides utilities to detect GPU availability and warn users when they have GPU hardware available but are using CPU-only JAX.
- rheojax.utils.device.get_recommended_package()[source]¶
Get recommended JAX package based on system CUDA.
- rheojax.utils.device.check_gpu_availability(warn=True)[source]¶
Check if GPU is available but not being used by JAX.
Prints a helpful warning if GPU hardware and system CUDA are detected but JAX is running in CPU-only mode.
- Parameters:
warn (bool) – If True, print a warning when a GPU is available but not used. Default is True.
- Returns:
True if GPU is being used by JAX, False otherwise.
- Return type:
bool
Examples
Call this at package initialization or in CLI entry points:
>>> from rheojax.utils.device import check_gpu_availability
>>> check_gpu_availability()  # Prints warning if GPU detected but not used
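A minimal, standalone approximation of the "is JAX running on a GPU?" part of this check, using only the public jax.default_backend() call. It does not reproduce the hardware or CUDA detection or the warning text of check_gpu_availability(), and the function name is hypothetical.

```python
def jax_is_using_gpu():
    """Return True if JAX is importable and its default backend is a GPU."""
    try:
        import jax
    except ImportError:
        return False  # JAX not installed, so certainly not on a GPU
    # jax.default_backend() returns the platform name of the default backend,
    # e.g. "cpu", "gpu" or "tpu".
    return jax.default_backend() == "gpu"


if not jax_is_using_gpu():
    print("JAX is running without a GPU backend")
```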
- rheojax.utils.device.get_device_info()[source]¶
Get comprehensive device information.
- Returns:
Dictionary with:
- jax_version: JAX version string
- jax_backend: Current backend (cpu, gpu)
- devices: List of device strings
- gpu_count: Number of GPU devices
- using_gpu: Whether JAX is using a GPU (bool)
- gpu_hardware: GPU name
- gpu_sm_version: SM version (float)
- system_cuda_version: System CUDA version string
- system_cuda_major: System CUDA major version (int)
- recommended_package: Recommended JAX package
- plugin_issues: List of detected JAX CUDA plugin issues
- Return type:
dict
- rheojax.utils.device.get_gpu_memory_info()[source]¶
Get GPU memory information using nvidia-smi.
- Returns:
Dictionary with keys:
- 'total_mb': Total GPU memory in MB
- 'used_mb': Used GPU memory in MB
- 'free_mb': Free GPU memory in MB
- 'utilization_percent': GPU utilization percentage
Returns an empty dict if nvidia-smi is not available.
- Return type:
dict
Examples
>>> from rheojax.utils.device import get_gpu_memory_info
>>> info = get_gpu_memory_info()
>>> if info:
...     print(f"GPU Memory: {info['used_mb']}/{info['total_mb']} MB")
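One way to obtain the same four fields is nvidia-smi's CSV query mode. This sketch assumes that approach (the actual query issued by get_gpu_memory_info() is not shown in this documentation), and the helper names are hypothetical. Note that nvidia-smi reports memory in MiB.

```python
import shutil
import subprocess

FIELDS = ("total_mb", "used_mb", "free_mb", "utilization_percent")
QUERY = "memory.total,memory.used,memory.free,utilization.gpu"


def parse_nvidia_smi_line(line):
    """Parse one line of `--format=csv,noheader,nounits` output into a dict."""
    values = [float(v) for v in line.split(",")]
    return dict(zip(FIELDS, values))


def query_gpu_memory():
    """Memory info for the first GPU, or {} when nvidia-smi is missing or fails."""
    if shutil.which("nvidia-smi") is None:
        return {}
    try:
        result = subprocess.run(
            ["nvidia-smi", f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
            capture_output=True, text=True, check=True, timeout=10,
        )
        lines = [ln for ln in result.stdout.splitlines() if ln.strip()]
        return parse_nvidia_smi_line(lines[0]) if lines else {}
    except (OSError, subprocess.SubprocessError, ValueError):
        # ValueError covers non-numeric fields such as "[N/A]".
        return {}
```

Returning an empty dict on any failure mirrors the documented behaviour, so callers can use the same `if info:` guard as in the example above.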
- rheojax.utils.device.print_device_summary()[source]¶
Print a summary of available compute devices.
Displays:
- JAX version
- Available devices (CPU/GPU)
- GPU memory info (if available)
- Warning if GPU hardware is detected but not being used
- Return type:
None
Examples
>>> from rheojax.utils.device import print_device_summary
>>> print_device_summary()
JAX Device Summary
==================
JAX version: 0.8.0
Devices: [CpuDevice(id=0)]
Using: CPU-only
The device module provides GPU detection and diagnostic utilities for RheoJAX. These functions help users identify available compute resources and configure JAX for optimal performance.
Functions¶
- rheojax.utils.device.check_gpu_availability(warn=True)[source]
Checks if GPU hardware and CUDA are available but JAX is running CPU-only. Prints a helpful message with installation instructions if GPU is not being used. Also checks for plugin conflicts even when GPU is working.
- rheojax.utils.device.check_plugin_conflicts()[source]
Check for known JAX CUDA plugin conflicts.
Detects known JAX CUDA plugin issues: dual cuda12/cuda13 plugins installed simultaneously (causes PJRT registration conflicts) or plugin/jaxlib version mismatches (causes silent CPU fallback). Returns a list of issue descriptions.
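Both checks can be approximated with importlib.metadata. This is a sketch under stated assumptions: the distribution names jax-cuda12-plugin and jax-cuda13-plugin and the rule that a plugin's version must equal jaxlib's are taken as given, and find_plugin_conflicts() is a hypothetical helper, not the RheoJAX implementation.

```python
from importlib import metadata

CUDA_PLUGINS = ("jax-cuda12-plugin", "jax-cuda13-plugin")


def installed_versions(names):
    """Return {distribution name: version} for the names that are installed."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            pass  # not installed
    return found


def find_plugin_conflicts(versions):
    """Hypothetical checks on a {package: version} mapping; returns issue strings."""
    issues = []
    plugins = [p for p in CUDA_PLUGINS if p in versions]
    if len(plugins) > 1:
        # Two CUDA plugins can both try to register a PJRT backend.
        issues.append("Both " + " and ".join(plugins) + " are installed")
    jaxlib = versions.get("jaxlib")
    for p in plugins:
        if jaxlib is not None and versions[p] != jaxlib:
            # Assumed rule: plugin and jaxlib versions must match exactly.
            issues.append(f"{p} {versions[p]} does not match jaxlib {jaxlib}")
    return issues


issues = find_plugin_conflicts(installed_versions(CUDA_PLUGINS + ("jaxlib",)))
for issue in issues:
    print(f"Issue: {issue}")
```

Separating the metadata lookup from the checks keeps the conflict logic testable with a plain dictionary.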
- rheojax.utils.device.get_device_info()[source]
Returns comprehensive device information including JAX version, backend, device list, GPU hardware name, SM version, system CUDA version, recommended package, and any plugin issues.
- rheojax.utils.device.get_gpu_memory_info()[source]
Queries nvidia-smi for GPU memory utilization (total, used, free, utilization %). Returns empty dict on systems without NVIDIA GPUs.
- rheojax.utils.device.print_device_summary()[source]
Prints a formatted summary of the compute environment (JAX version, devices, GPU memory). Useful at the start of scripts and notebooks.
Examples¶
Quick Environment Check¶
from rheojax.utils import print_device_summary, check_gpu_availability
# Print full device summary at script start
print_device_summary()
# JAX Device Summary
# ==================
# JAX version: 0.8.3
# Devices: [CpuDevice(id=0)]
# Using: CPU-only
# Programmatic check
using_gpu = check_gpu_availability(warn=False)
if not using_gpu:
print("Running on CPU — consider installing JAX with CUDA support")
Detailed Device Info¶
from rheojax.utils import get_device_info, get_gpu_memory_info
info = get_device_info()
print(f"JAX {info['jax_version']} on {info['jax_backend']}")
print(f"GPU hardware: {info['gpu_hardware'] or 'None'}")
print(f"System CUDA: {info['system_cuda_version'] or 'Not found'}")
# Check for plugin issues
if info['plugin_issues']:
for issue in info['plugin_issues']:
print(f"WARNING: {issue}")
# GPU memory (if available)
mem = get_gpu_memory_info()
if mem:
print(f"GPU Memory: {mem['used_mb']}/{mem['total_mb']} MB")
Plugin Conflict Detection¶
from rheojax.utils import check_plugin_conflicts
issues = check_plugin_conflicts()
if issues:
for issue in issues:
print(f"Issue: {issue}")
print("Fix: pip uninstall -y jax jaxlib "
"jax-cuda12-plugin jax-cuda12-pjrt "
"jax-cuda13-plugin jax-cuda13-pjrt")
print("Then: make install-jax-gpu")
else:
print("No plugin conflicts detected")
See Also¶
Development Status & Performance - Technology stack and GPU requirements
make install-jax-gpu - Automated GPU JAX installation
make gpu-diagnose - Diagnose common GPU issues (plugin conflicts, version mismatches)
make gpu-check - Verify GPU backend, devices, and SVD computation
Fit Quality Metrics¶
Fit quality metrics for rheological model evaluation.
This module provides functions to compute standard statistical metrics for evaluating model fit quality.
- rheojax.utils.metrics.compute_fit_quality(y_true, y_pred)[source]¶
Compute R² and RMSE fit quality metrics.
- Parameters:
y_true (numpy.typing.ArrayLike) – Observed (ground truth) values.
y_pred (numpy.typing.ArrayLike) – Predicted values from the model.
- Returns:
Dictionary containing:
- 'R2': Coefficient of determination (R²)
- 'RMSE': Root mean squared error
- 'nrmse': Normalized RMSE (RMSE / range of y_true)
- Return type:
dict
Examples
>>> from rheojax.utils.metrics import compute_fit_quality
>>> y_true = [1.0, 2.0, 3.0, 4.0]
>>> y_pred = [1.1, 1.9, 3.1, 3.9]
>>> metrics = compute_fit_quality(y_true, y_pred)
>>> metrics['R2'] > 0.99
True
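The three metrics follow the standard definitions: R² = 1 - SS_res/SS_tot, RMSE = sqrt(mean(residual²)), and NRMSE = RMSE / (max(y_true) - min(y_true)). The NumPy sketch below assumes exactly these definitions (fit_quality() is a hypothetical stand-in for compute_fit_quality()) and applies them to a small six-point dataset.

```python
import numpy as np


def fit_quality(y_true, y_pred):
    """R^2, RMSE and range-normalized RMSE for real-valued data (flattened)."""
    y_true = np.ravel(np.asarray(y_true, dtype=float))
    y_pred = np.ravel(np.asarray(y_pred, dtype=float))
    resid = y_true - y_pred
    ss_res = np.sum(resid**2)                        # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total sum of squares
    rmse = np.sqrt(np.mean(resid**2))
    return {
        "R2": 1.0 - ss_res / ss_tot,
        "RMSE": rmse,
        "nrmse": rmse / (y_true.max() - y_true.min()),
    }


m = fit_quality([1000, 800, 650, 500, 320, 200], [1010, 790, 660, 495, 325, 195])
# m["R2"] ~ 0.99916, m["RMSE"] ~ 7.906, m["nrmse"] ~ 0.00988
```

Here SS_res = 375 and SS_tot ≈ 448083, so even residuals of 5 to 10 units leave R² well above 0.999; NRMSE is often the more informative number for comparing fits across datasets.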
- rheojax.utils.metrics.r2_complex(y_true, y_pred)[source]¶
Compute R² for complex-valued data using magnitudes.
- Parameters:
y_true (numpy.typing.ArrayLike) – Observed complex values.
y_pred (numpy.typing.ArrayLike) – Predicted complex values.
- Returns:
Coefficient of determination computed on magnitudes |G*|.
- Return type:
float
Note
This metric evaluates magnitude fit only. Phase errors (e.g., correct |G*| but wrong tan(δ)) are not captured. For phase-sensitive evaluation, use r2_complex_components(), which averages R² over the real and imaginary components independently.
- rheojax.utils.metrics.r2_complex_components(y_true, y_pred)[source]¶
Compute R² for complex data using separate real and imaginary components.
Returns the arithmetic mean of R²(real) and R²(imag), capturing both magnitude and phase accuracy. A model that fits |G*| perfectly but has the wrong phase angle will score lower here than with r2_complex().
- Parameters:
y_true (numpy.typing.ArrayLike) – Observed complex values (e.g., G* = G' + i·G'').
y_pred (numpy.typing.ArrayLike) – Predicted complex values.
- Returns:
Average R² across real (G') and imaginary (G'') components.
- Return type:
float
Examples
>>> import numpy as np
>>> from rheojax.utils.metrics import r2_complex_components
>>> omega = np.logspace(-2, 2, 50)
>>> G_star = omega * 1j  # Pure viscous
>>> r2_complex_components(G_star, G_star)
1.0
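The difference between the two complex metrics is easiest to see with a prediction that has the right |G*| but the wrong phase. The sketch below assumes the standard R² definition plus an assumed convention for zero-variance data (R² = 1 if the constant data are reproduced exactly, else 0), which is needed for inputs like the pure-viscous example, whose real part is identically zero. The function names are illustrative, not the RheoJAX API.

```python
import numpy as np


def _r2(y_true, y_pred):
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    if ss_tot == 0.0:
        # Assumed convention for constant data: perfect only if residuals vanish.
        return 1.0 if ss_res == 0.0 else 0.0
    return float(1.0 - ss_res / ss_tot)


def r2_magnitude(G_true, G_pred):
    """R^2 on |G*| only, so phase errors are invisible."""
    return _r2(np.abs(G_true), np.abs(G_pred))


def r2_components(G_true, G_pred):
    """Mean of R^2 on the real (G') and imaginary (G'') parts."""
    G_true, G_pred = np.asarray(G_true), np.asarray(G_pred)
    return 0.5 * (_r2(G_true.real, G_pred.real) + _r2(G_true.imag, G_pred.imag))


G_data = np.array([1 + 1j, 2 + 3j, 4 + 1j])
G_bad_phase = np.conj(G_data)  # same |G*|, but G'' has the wrong sign

r2_mag = r2_magnitude(G_data, G_bad_phase)    # 1.0: magnitudes are identical
r2_comp = r2_components(G_data, G_bad_phase)  # about -7.25: G'' is badly wrong
```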
The metrics module provides standard statistical measures for evaluating model fit quality after NLSQ optimization or Bayesian inference.
Functions¶
- rheojax.utils.metrics.compute_fit_quality(y_true, y_pred)[source]
Computes R², RMSE, and normalized RMSE for real-valued data. Handles multi-dimensional arrays (flattened automatically).
- rheojax.utils.metrics.r2_complex(y_true, y_pred)[source]
Computes R² for complex-valued modulus data (G*, E*) using magnitude comparison. Useful for oscillation fits where predictions are complex.
Examples¶
Real-Valued Fit Assessment¶
import numpy as np
from rheojax.utils import compute_fit_quality
# After model fitting
y_data = np.array([1000, 800, 650, 500, 320, 200])
y_pred = np.array([1010, 790, 660, 495, 325, 195])
metrics = compute_fit_quality(y_data, y_pred)
print(f"R² = {metrics['R2']:.4f}")  # 0.9992
print(f"RMSE = {metrics['RMSE']:.1f}") # 7.9 Pa
print(f"NRMSE = {metrics['nrmse']:.4f}") # 0.0099
Complex Modulus Fit Assessment¶
import numpy as np
from rheojax.utils import r2_complex
# Oscillation data (G' + iG'')
G_star_data = np.array([1e5 + 1e3j, 5e4 + 5e3j, 1e4 + 2e4j])
G_star_pred = np.array([9.9e4 + 1.1e3j, 5.1e4 + 4.9e3j, 1.05e4 + 2.1e4j])
r2 = r2_complex(G_star_data, G_star_pred)
print(f"R² (magnitude) = {r2:.4f}")
See Also¶
Core Module (rheojax.core) - BaseModel; its fit() returns optimization results with fit statistics
Fitting Strategies and Troubleshooting - Fitting strategy guide
EPM Kernels¶
Kernels for Elasto-Plastic Models (EPM).
This module implements the core physics kernels for scalar EPM simulations using JAX. It includes the FFT-based elastic propagator for stress redistribution, logic for plastic events (dual-mode: hard/smooth), and the full time-stepping kernel.
- rheojax.utils.epm_kernels.make_propagator_q(L_x, L_y, shear_modulus=1.0)[source]¶
Create the quadrupolar Eshelby propagator in Fourier space.
G(q) = -2 * mu * (qx * qy)^2 / |q|^4 for q != 0, and G(0) = 0.
Reference: Talamali et al. (2011), Phys. Rev. E 84, 016115, Eq. (6).
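A NumPy sketch of this propagator on the half-spectrum grid used by rfft2, assuming unit lattice spacing so that the wavevectors are 2π·fftfreq(L). The function name is hypothetical; the RheoJAX version is written in JAX.

```python
import numpy as np


def eshelby_propagator_q(L_x, L_y, mu=1.0):
    """G(q) = -2 mu (qx qy)^2 / |q|^4, G(0) = 0, on an (L_x, L_y//2 + 1) grid."""
    qx = 2.0 * np.pi * np.fft.fftfreq(L_x)   # full axis, length L_x
    qy = 2.0 * np.pi * np.fft.rfftfreq(L_y)  # half axis, length L_y//2 + 1
    QX, QY = np.meshgrid(qx, qy, indexing="ij")
    q2 = QX**2 + QY**2
    safe_q4 = np.where(q2 > 0, q2**2, 1.0)   # avoid 0/0 at q = 0
    G = -2.0 * mu * (QX * QY) ** 2 / safe_q4
    G[0, 0] = 0.0                            # no q = 0 (mean-field) contribution
    return G


G = eshelby_propagator_q(32, 32)
```

Since (qx·qy)²/|q|⁴ ≤ 1/4, the propagator is bounded between -mu/2 (along the diagonals |qx| = |qy|) and 0 (along the axes), which gives the characteristic four-lobed pattern.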
- rheojax.utils.epm_kernels.solve_elastic_propagator(plastic_strain_rate, propagator_q)[source]¶
Solve for the elastic stress redistribution rate using FFT.
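Calculates sigma_dot_el = G * epsilon_dot_pl, where * is a convolution in real space, evaluated as a pointwise product in Fourier space: sigma_dot_el = IFFT[G(q) · FFT(epsilon_dot_pl)].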
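In Fourier space the convolution becomes a pointwise product, so the redistribution reduces to rfft2, multiply, irfft2. A NumPy sketch, assuming a periodic grid and the rfft2 half-spectrum layout (both helper names are hypothetical):

```python
import numpy as np


def eshelby_propagator_q(L_x, L_y, mu=1.0):
    """Quadrupolar propagator -2 mu (qx qy)^2 / |q|^4 on the rfft2 grid, G(0) = 0."""
    qx = 2.0 * np.pi * np.fft.fftfreq(L_x)
    qy = 2.0 * np.pi * np.fft.rfftfreq(L_y)
    QX, QY = np.meshgrid(qx, qy, indexing="ij")
    q2 = QX**2 + QY**2
    return -2.0 * mu * (QX * QY) ** 2 / np.where(q2 > 0, q2**2, 1.0)


def elastic_redistribution_rate(plastic_strain_rate, propagator_q):
    """sigma_dot_el = IFFT2[ G(q) * FFT2(gamma_dot_p) ] on a periodic grid."""
    field = np.asarray(plastic_strain_rate, dtype=float)
    rate_q = np.fft.rfft2(field) * propagator_q  # pointwise product in q-space
    return np.fft.irfft2(rate_q, s=field.shape)  # back to real space


L = 32
G_q = eshelby_propagator_q(L, L)
event = np.zeros((L, L))
event[0, 0] = 1.0  # a single plastic event at the origin
sigma_dot_el = elastic_redistribution_rate(event, G_q)
```

Because G(0) = 0, the redistribution has zero spatial mean and a uniform plastic strain rate produces no redistribution at all; the yielding site itself is unloaded (negative value at the origin).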
- rheojax.utils.epm_kernels.compute_plastic_strain_rate(stress, yield_thresholds, fluidity=1.0, smooth=False, smoothing_width=0.1, n_fluid=1.0, sigma_c_mean=1.0, fluidity_form='overstress')[source]¶
Compute the local plastic strain rate.
Supports two activation modes and three constitutive laws selected by fluidity_form:
Hard activation: gamma_dot_p = f(sigma) * Theta(|sigma| - sigma_c)
Smooth activation: gamma_dot_p = f(sigma) * 0.5 * (1 + tanh((|sigma| - sigma_c)/w))
The constitutive law f(sigma) depends on fluidity_form:
- “linear” (classical Bingham):
f(sigma) = sigma / tau_pl
High-rate asymptote: stress ~ gamma_dot * tau_pl. No yield-stress baseline in the asymptote.
- “power” (power-law fluidity, soft-glassy rheology):
f(sigma) = sign(sigma) * |sigma / sigma_c_mean|^n_fluid * sigma_c_mean / tau_pl
High-rate asymptote: stress ~ sigma_c_mean * (gamma_dot * tau_pl / sigma_c_mean)^(1/n_fluid). Shear-thinning (for n_fluid > 1) but no additive yield-stress baseline.
- “overstress” (Herschel-Bulkley, DEFAULT):
f(sigma) = sign(sigma) * (|sigma| - sigma_c_mean)_+^n_fluid / (sigma_c_mean^(n_fluid-1) * tau_pl)
Only stress above the threshold contributes to plastic flow. High-rate asymptote: stress ~ sigma_c_mean + sigma_c_mean * (gamma_dot * tau_pl / sigma_c_mean)^(1/n_fluid). This is the full Herschel-Bulkley form sigma = sigma_y + K * gamma_dot^n_HB with sigma_y = sigma_c_mean and n_HB = 1/n_fluid. Recommended for HB-like flow curves (emulsions, gels, foams, yield-stress fluids in general).
At n_fluid = 1, “power” reduces to “linear”; “overstress” at n_fluid = 1 gives a Bingham fluid with explicit yield stress (sigma = sigma_c_mean + gamma_dot * tau_pl).
- Parameters:
stress (Array) – Local stress field.
yield_thresholds (Array) – Local yield thresholds.
fluidity (float) – Inverse plastic timescale (1/tau_pl).
smooth (bool) – Whether to use the differentiable smooth approximation.
smoothing_width (float) – Width parameter w for smoothing.
n_fluid (float) – Power-law / HB exponent. The implied HB flow exponent is n_HB = 1/n_fluid.
sigma_c_mean (float) – Mean yield threshold, used as the scale for the power-law forms.
fluidity_form (str) – One of "linear", "power", or "overstress". Default "overstress".
- Returns:
Plastic strain rate field.
- Return type:
Array
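A NumPy transcription of the two activation modes and three fluidity forms as written above, with fluidity = 1/tau_pl. It is a sketch for checking the formulas (for example, that the steady states reproduce the stated high-rate asymptotes), not the JAX kernel itself; the function name is hypothetical.

```python
import numpy as np


def plastic_strain_rate(stress, thresholds, fluidity=1.0, smooth=False,
                        width=0.1, n_fluid=1.0, sigma_c_mean=1.0,
                        form="overstress"):
    s = np.asarray(stress, dtype=float)
    c = np.asarray(thresholds, dtype=float)
    # Activation: hard Heaviside step or smooth tanh ramp around |s| = c.
    if smooth:
        act = 0.5 * (1.0 + np.tanh((np.abs(s) - c) / width))
    else:
        act = np.where(np.abs(s) > c, 1.0, 0.0)
    # Constitutive law f(sigma); fluidity plays the role of 1 / tau_pl.
    if form == "linear":
        f = fluidity * s
    elif form == "power":
        f = np.sign(s) * np.abs(s / sigma_c_mean) ** n_fluid * sigma_c_mean * fluidity
    elif form == "overstress":
        over = np.maximum(np.abs(s) - sigma_c_mean, 0.0)
        f = np.sign(s) * over**n_fluid * fluidity / sigma_c_mean ** (n_fluid - 1.0)
    else:
        raise ValueError(f"unknown fluidity form: {form!r}")
    return f * act


s = np.array([2.0, 0.5, -2.0])
c = np.ones(3)
rates = plastic_strain_rate(s, c)  # overstress, n_fluid = 1: [1, 0, -1]
```

For example, with n_fluid = 2, sigma_c_mean = 1 and tau_pl = 1, the overstress form gives a rate of 4 at stress 3 = 1 + 4^(1/2), matching the Herschel-Bulkley asymptote.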
- rheojax.utils.epm_kernels.update_yield_thresholds(key, active_mask, current_thresholds, mean=1.0, std=0.1)[source]¶
Renew yield thresholds for active sites.
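How new thresholds are drawn is not spelled out beyond the mean and std arguments. Assuming they are Gaussian draws that replace the thresholds only where active_mask is True, a NumPy sketch could look like the following. It uses a NumPy Generator in place of a JAX PRNG key, and the function name is hypothetical. Whether RheoJAX truncates negative draws is not stated, and this sketch does not truncate.

```python
import numpy as np


def renew_thresholds(rng, active_mask, thresholds, mean=1.0, std=0.1):
    """Redraw thresholds from N(mean, std) where active_mask is True; keep others."""
    thresholds = np.asarray(thresholds, dtype=float)
    fresh = rng.normal(mean, std, size=thresholds.shape)
    return np.where(active_mask, fresh, thresholds)


rng = np.random.default_rng(0)
old = np.full((4, 4), 5.0)
mask = np.zeros((4, 4), dtype=bool)
mask[1, 2] = mask[3, 0] = True  # two sites that just yielded
new = renew_thresholds(rng, mask, old)
```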
- rheojax.utils.epm_kernels.epm_step(state, propagator_q, shear_rate, dt, params, smooth=False, fluidity_form='overstress')[source]¶
Perform one full EPM time step.
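Dynamics: sigma_dot = mu * gamma_dot - mu * gamma_dot_p + G * gamma_dot_p, where G * gamma_dot_p is the elastic redistribution of the plastic strain rate by the propagator (a convolution, evaluated in Fourier space).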
- Parameters:
state (tuple[Array, Array, float, Array]) – Tuple (stress, yield_thresholds, accumulated_strain, key).
propagator_q (Array) – Precomputed propagator.
shear_rate (float) – Macroscopic imposed shear rate gamma_dot.
dt (float) – Time step size.
params (dict) – Dictionary of model parameters (mu, tau_pl, sigma_c_mean, etc.).
smooth (bool) – Use smooth yielding (for inference) vs hard yielding (for simulation).
- Returns:
Updated state tuple.
- Return type:
tuple[Array, Array, float, Array]
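The update rule above can be sketched as a single explicit (forward Euler) step. This NumPy version assumes hard activation, the "overstress" law with n_fluid = 1, and a periodic L × L grid. It leaves the thresholds unchanged (renewal is the job of update_yield_thresholds()) and omits the PRNG key. All names are hypothetical.

```python
import numpy as np


def eshelby_propagator_q(L, mu=1.0):
    """G(q) = -2 mu (qx qy)^2 / |q|^4 with G(0) = 0, on the rfft2 grid."""
    qx = 2.0 * np.pi * np.fft.fftfreq(L)
    qy = 2.0 * np.pi * np.fft.rfftfreq(L)
    QX, QY = np.meshgrid(qx, qy, indexing="ij")
    q2 = QX**2 + QY**2
    return -2.0 * mu * (QX * QY) ** 2 / np.where(q2 > 0, q2**2, 1.0)


def epm_euler_step(stress, thresholds, strain, G_q, shear_rate, dt,
                   mu=1.0, tau_pl=1.0, sigma_c_mean=1.0):
    """One forward-Euler step of
    sigma_dot = mu*gamma_dot - mu*gamma_dot_p + (G * gamma_dot_p),
    using hard activation and the n_fluid = 1 overstress law.
    """
    active = np.abs(stress) > thresholds
    over = np.maximum(np.abs(stress) - sigma_c_mean, 0.0)
    gdot_p = np.where(active, np.sign(stress) * over / tau_pl, 0.0)
    # Elastic redistribution: pointwise product with G(q) in Fourier space.
    redistribution = np.fft.irfft2(np.fft.rfft2(gdot_p) * G_q, s=stress.shape)
    sigma_dot = mu * shear_rate - mu * gdot_p + redistribution
    return stress + dt * sigma_dot, thresholds, strain + shear_rate * dt


L, dt = 16, 0.01
G_q = eshelby_propagator_q(L)
stress = np.zeros((L, L))
stress[0, 0] = 2.0  # one overloaded site
thresholds = np.ones((L, L))
new_stress, _, new_strain = epm_euler_step(stress, thresholds, 0.0, G_q,
                                           shear_rate=0.0, dt=dt)
```

Since the redistribution term has zero mean, the mean stress changes only through mu * (gamma_dot - mean(gamma_dot_p)), which is a convenient conservation check for any implementation.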
The EPM kernels module provides low-level JAX functions for Elasto-Plastic Model simulations, including the Eshelby propagator and plastic event logic.
Functions¶
- rheojax.utils.epm_kernels.make_propagator_q(L_x, L_y, shear_modulus=1.0)[source]
Creates the quadrupolar Eshelby propagator in Fourier space.
- Returns:
2D array of the propagator in Fourier space, with shape (L_x, L_y // 2 + 1).
- rheojax.utils.epm_kernels.solve_elastic_propagator(plastic_strain_rate, propagator_q)[source]
Solves for elastic stress redistribution using FFT.
- Returns:
2D array of elastic stress redistribution rate.
- rheojax.utils.epm_kernels.compute_plastic_strain_rate(stress, yield_thresholds, fluidity=1.0, smooth=False, smoothing_width=0.1, n_fluid=1.0, sigma_c_mean=1.0, fluidity_form='overstress')[source]
Computes local plastic strain rate (supports Hard and Smooth modes).
- rheojax.utils.epm_kernels.epm_step(state, propagator_q, shear_rate, dt, params, smooth=False, fluidity_form='overstress')[source]
Performs one full time step of the EPM dynamics.