Utilities (rheojax.utils)

The utils module provides numerical utilities for rheological analysis, including special functions, optimization tools, fit quality metrics, device detection, and modulus conversion for DMTA support.

Mittag-Leffler Functions

JAX-compatible Mittag-Leffler function implementations.

This module provides efficient, JAX-compatible implementations of the Mittag-Leffler function using a hybrid strategy:

  1. Taylor series for small arguments (|z| < 8)

  2. Asymptotic expansions for large arguments (|z| > 8)

    • Exponential expansion for positive z (Creep mode growth)

    • Inverse power law expansion for negative z (Relaxation mode decay)

This approach avoids the numerical instability of Padé approximations near alpha=beta and correctly models the exponential growth for positive arguments.
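For large negative arguments, the inverse power law branch can be illustrated with the standard algebraic expansion E_{α,β}(z) ≈ -Σ_{k=1}^{K} z^(-k) / Γ(β - αk), which holds as z → -∞ for 0 < α < 2. The pure-Python sketch below is not the rheojax implementation. The helper names rgamma and ml_asymptotic_negative and the fixed term count are assumptions made for illustration. It compares the truncated series with the closed form E_{1/2}(-x) = exp(x²) erfc(x).

```python
import math


def rgamma(x: float) -> float:
    """Reciprocal gamma function 1/Gamma(x); zero at the poles x = 0, -1, -2, ..."""
    if x <= 0 and float(x).is_integer():
        return 0.0
    return 1.0 / math.gamma(x)


def ml_asymptotic_negative(z: float, alpha: float, beta: float = 1.0,
                           n_terms: int = 10) -> float:
    """Truncated algebraic expansion of E_{alpha,beta}(z) for large negative z (sketch)."""
    if not (z < 0 and 0 < alpha < 2):
        raise ValueError("expects z < 0 and 0 < alpha < 2")
    total = 0.0
    for k in range(1, n_terms + 1):
        # Term k decays like |z|**(-k); terms sitting on Gamma poles contribute zero.
        total -= z ** (-k) * rgamma(beta - alpha * k)
    return total


# Closed form for alpha = 1/2: E_{1/2}(-x) = exp(x**2) * erfc(x)
x = 10.0
approx = ml_asymptotic_negative(-x, alpha=0.5)
exact = math.exp(x * x) * math.erfc(x)
print(approx, exact)
```

Terms where β - αk is a non-positive integer vanish because 1/Γ is zero there. For α = β = 1 every term vanishes, which is consistent with e^z being exponentially small for large negative z.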

References

  • Garrappa, R. (2015). Numerical evaluation of two and three parameter Mittag-Leffler functions. SIAM Journal on Numerical Analysis, 53(3), 1350-1369.

  • Haubold, H. J., Mathai, A. M., & Saxena, R. K. (2011). Mittag-Leffler functions and their applications. Journal of Applied Mathematics, 2011.

rheojax.utils.mittag_leffler.mittag_leffler_e(z, alpha)

One-parameter Mittag-Leffler function E_α(z).

E_α(z) = E_{α,1}(z)

Parameters:
  • z (float | Array) – Argument(s) of the Mittag-Leffler function.

  • alpha (float) – Order parameter, must be real and positive (0 < alpha <= 2).

Returns:

Value(s) of E_α(z).

Return type:

float | Array

rheojax.utils.mittag_leffler.mittag_leffler_e2(z, alpha, beta)

Two-parameter Mittag-Leffler function E_{α,β}(z).

Uses a hybrid evaluation strategy:

  • |z| <= 8: Taylor Series (Kahan summation)

  • z > 8: Positive Asymptotic Expansion (Exponential growth)

  • z < -8: Negative Asymptotic Expansion (Algebraic decay)

  • Smooth blending at boundaries for gradient stability.

Parameters:
  • z (float | Array) – Argument(s) of the Mittag-Leffler function.

  • alpha (float) – First parameter (0 < alpha <= 2).

  • beta (float) – Second parameter.

Returns:

Value(s) of E_{α,β}(z).

Return type:

float | Array
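The small-argument branch is a truncated Taylor series. A minimal pure-Python sketch of that idea follows, assuming α > 0 and β > 0. The name ml_taylor, the log-space evaluation of each term, and the stopping rule are choices made here, not taken from rheojax; the accumulation uses Kahan compensated summation as described above.

```python
import math


def ml_taylor(z: float, alpha: float, beta: float = 1.0,
              max_terms: int = 1000, tol: float = 1e-17) -> float:
    """Truncated Taylor series of E_{alpha,beta}(z) with Kahan summation.

    Sketch only; assumes alpha > 0 and beta > 0 so every Gamma argument is positive.
    """
    if z == 0.0:
        return 1.0 / math.gamma(beta)
    log_abs_z = math.log(abs(z))
    total = 0.0
    comp = 0.0  # Kahan compensation: low-order bits lost by the previous addition
    for k in range(max_terms):
        # |z|**k / Gamma(alpha*k + beta), evaluated in log space to avoid overflow
        mag = math.exp(k * log_abs_z - math.lgamma(alpha * k + beta))
        term = -mag if (z < 0 and k % 2 == 1) else mag
        y = term - comp
        t = total + y
        comp = (t - total) - y
        total = t
        if k > 0 and mag < tol * abs(total):
            break
    return total


print(ml_taylor(0.5, alpha=1.0), math.exp(0.5))   # E_1(z) = exp(z)
print(ml_taylor(4.0, alpha=2.0), math.cosh(2.0))  # E_2(z**2) = cosh(z)
```

For large negative z with small α the alternating terms grow very large before they shrink and then cancel, which loses many digits in double precision; that is one reason to switch to an asymptotic expansion beyond |z| ≈ 8.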


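The Mittag-Leffler function plays the role in fractional viscoelasticity that the exponential plays in classical models: it appears in the closed-form relaxation and creep responses of models such as the fractional Maxwell model. This module provides JAX-compatible implementations that work under jax.jit and, thanks to smooth branch blending, with gradient-based fitting.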

Functions

rheojax.utils.mittag_leffler.mittag_leffler_e(z, alpha)

One-parameter Mittag-Leffler function E_α(z).

rheojax.utils.mittag_leffler.mittag_leffler_e2(z, alpha, beta)

Two-parameter Mittag-Leffler function E_{α,β}(z).

Aliases

rheojax.utils.mittag_leffler.ml_e

Alias for mittag_leffler_e()

rheojax.utils.mittag_leffler.ml_e2

Alias for mittag_leffler_e2()

Mathematical Background

One-Parameter Function

The one-parameter Mittag-Leffler function is defined as:

\[E_\alpha(z) = \sum_{k=0}^{\infty} \frac{z^k}{\Gamma(\alpha k + 1)}\]

where \(\Gamma\) is the gamma function. The series converges for every \(z\) when \(\alpha > 0\); this implementation supports \(0 < \alpha \leq 2\).

Special cases:

  • \(\alpha = 1\): \(E_1(z) = e^z\) (exponential function)

  • \(\alpha = 2\): \(E_2(z^2) = \cosh(z)\) (hyperbolic cosine)

Two-Parameter Function

The two-parameter generalization:

\[E_{\alpha,\beta}(z) = \sum_{k=0}^{\infty} \frac{z^k}{\Gamma(\alpha k + \beta)}\]

Special cases:

  • \(\beta = 1\): \(E_{\alpha,1}(z) = E_\alpha(z)\) (one-parameter form)

  • \(\alpha = \beta = 1\): \(E_{1,1}(z) = e^z\) (exponential)

Implementation Details

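The implementation uses a hybrid evaluation strategy instead of a Padé approximant:

  • Method: Taylor series with Kahan summation for \(|z| \leq 8\); asymptotic expansions for \(|z| > 8\) (exponential growth for \(z > 0\), algebraic decay for \(z < 0\))

  • Blending: Smooth blending at the branch boundaries for gradient stability

  • Stability: Avoids the numerical instability of Padé approximations near \(\alpha = \beta\) and captures the exponential growth for positive arguments

  • Performance: JIT-compiled with JAX for speed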

Examples

Basic Usage

import jax.numpy as jnp
from rheojax.utils.mittag_leffler import mittag_leffler_e, mittag_leffler_e2

# Single value
result = mittag_leffler_e(0.5, alpha=0.5)
print(result)  # E_0.5(0.5) = exp(0.25) * erfc(-0.5) ≈ 1.9524

# Array of values
z = jnp.linspace(0, 2, 10)
results = mittag_leffler_e(z, alpha=0.8)

# Two-parameter form
result2 = mittag_leffler_e2(0.5, alpha=0.5, beta=1.0)

Fractional Relaxation Modulus

import jax.numpy as jnp
from rheojax.utils.mittag_leffler import mittag_leffler_e

def fractional_maxwell_relaxation(t, E, tau, alpha):
    """Relaxation modulus for fractional Maxwell model.

    Parameters
    ----------
    t : array
        Time values
    E : float
        Elastic modulus (Pa)
    tau : float
        Relaxation time (s)
    alpha : float
        Fractional order (0 < alpha < 1)

    Returns
    -------
    G : array
        Relaxation modulus G(t)
    """
    return E * mittag_leffler_e(-(t / tau)**alpha, alpha)

# Compute relaxation modulus
time = jnp.logspace(-2, 2, 100)
G = fractional_maxwell_relaxation(time, E=1000, tau=1.0, alpha=0.5)

JIT Compilation

import jax
import jax.numpy as jnp
from rheojax.utils.mittag_leffler import mittag_leffler_e

# JIT-compile a wrapper; alpha must be static, so it is fixed as a Python constant
@jax.jit
def compute_ml(z):
    return mittag_leffler_e(z, alpha=0.5)

# Use compiled function
z = jnp.linspace(0, 5, 1000)
result = compute_ml(z)  # Fast computation

Prony Series Functions

Prony series utilities for Generalized Maxwell Model parameter identification.

This module provides utilities for working with Prony series representations of viscoelastic relaxation moduli:

E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ)

Key capabilities:

  • Parameter validation and bounds checking

  • Dynamic ParameterSet creation for N modes

  • Log-space transforms for wide time-scale ranges

  • Element minimization (optimal N selection)

  • R² goodness-of-fit metric computation

  • Softmax penalty for constrained optimization
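Evaluating the series itself is a single broadcast operation. The NumPy sketch below assumes a 1-D array of times; prony_modulus is a hypothetical helper, not part of rheojax.utils.prony.

```python
import numpy as np


def prony_modulus(t, E_inf, E_i, tau_i):
    """E(t) = E_inf + sum_i E_i * exp(-t / tau_i) for a 1-D array of times (sketch)."""
    t = np.asarray(t, dtype=float)[:, None]          # shape (M, 1)
    E_i = np.asarray(E_i, dtype=float)[None, :]      # shape (1, N)
    tau_i = np.asarray(tau_i, dtype=float)[None, :]  # shape (1, N)
    # Broadcast to (M, N), then sum the N modes at each time point.
    return E_inf + np.sum(E_i * np.exp(-t / tau_i), axis=1)


t = np.logspace(-3, 2, 6)
E_t = prony_modulus(t, E_inf=1e3, E_i=[1e5, 1e4, 1e3], tau_i=[1e-2, 1e-1, 1.0])
print(E_t)  # decreases monotonically toward E_inf = 1e3
```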

References

  • Park, S. W., & Schapery, R. A. (1999). Methods of interconversion between linear viscoelastic material functions. Part I—A numerical method based on Prony series. International Journal of Solids and Structures, 36(11), 1653-1675.

  • pyvisco: https://github.com/saintsfan342000/pyvisco

type rheojax.utils.prony.ArrayLike = ndarray

rheojax.utils.prony.validate_prony_parameters(E_inf, E_i, tau_i)

Validate Prony series parameters for physical consistency.

Checks:

  • E_inf ≥ 0 (equilibrium modulus non-negative)

  • All Eᵢ > 0 (positive mode strengths)

  • All τᵢ > 0 (positive relaxation times)

  • Same number of Eᵢ and τᵢ elements

Parameters:
  • E_inf (float) – Equilibrium modulus (Pa)

  • E_i (numpy.typing.ArrayLike) – Array of mode strengths (Pa)

  • tau_i (numpy.typing.ArrayLike) – Array of relaxation times (s)

Returns:

Tuple of validation status and error message

Return type:

tuple[bool, str]

Example

>>> E_inf = 1e3
>>> E_i = np.array([1e5, 1e4, 1e3])
>>> tau_i = np.array([1e-2, 1e-1, 1.0])
>>> valid, msg = validate_prony_parameters(E_inf, E_i, tau_i)
>>> print(valid)
True
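The four documented checks can be sketched as follows. This is a hypothetical re-implementation for illustration only: the helper name check_prony_parameters and the exact messages are assumptions and will differ from what validate_prony_parameters returns.

```python
import numpy as np


def check_prony_parameters(E_inf, E_i, tau_i):
    """Hypothetical sketch of the documented checks; returns (is_valid, message)."""
    E_i = np.atleast_1d(np.asarray(E_i, dtype=float))
    tau_i = np.atleast_1d(np.asarray(tau_i, dtype=float))
    # Check the element counts first so the comparisons below are well defined.
    if E_i.shape != tau_i.shape:
        return False, f"E_i has {E_i.size} elements but tau_i has {tau_i.size}"
    if E_inf < 0:
        return False, "E_inf must be non-negative"
    if np.any(E_i <= 0):
        return False, "all E_i must be positive"
    if np.any(tau_i <= 0):
        return False, "all tau_i must be positive"
    return True, ""


print(check_prony_parameters(1e3, [1e5, 1e4, 1e3], [1e-2, 1e-1, 1.0]))
print(check_prony_parameters(1e3, [1e5, -1e4], [1e-2, 1e-1]))
```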

rheojax.utils.prony.create_prony_parameter_set(n_modes, modulus_type='shear')

Create ParameterSet for N-mode Prony series.

Dynamically generates parameters:

  • E_inf (or G_inf for shear): Equilibrium modulus

  • E_1…E_N (or G_1…G_N): Mode strengths

  • tau_1…tau_N: Relaxation times

Parameters:
  • n_modes (int) – Number of Maxwell modes (N ≥ 1)

  • modulus_type (str) – ‘shear’ for G(t) or ‘tensile’ for E(t)

Return type:

ParameterSet

Returns:

ParameterSet with 2N+1 parameters configured for Prony series

Raises:

ValueError – If n_modes < 1 or modulus_type invalid

Example

>>> params = create_prony_parameter_set(n_modes=3, modulus_type='shear')
>>> list(params.keys())
['G_inf', 'G_1', 'G_2', 'G_3', 'tau_1', 'tau_2', 'tau_3']
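The 2N+1 naming order shown above can be reproduced with a few list comprehensions. prony_parameter_names is a hypothetical helper that only returns names; the documented function returns a configured ParameterSet.

```python
def prony_parameter_names(n_modes: int, modulus_type: str = "shear") -> list:
    """Hypothetical helper listing the 2N+1 Prony parameter names in documented order."""
    prefixes = {"shear": "G", "tensile": "E"}
    if n_modes < 1:
        raise ValueError("n_modes must be >= 1")
    if modulus_type not in prefixes:
        raise ValueError("modulus_type must be 'shear' or 'tensile'")
    p = prefixes[modulus_type]
    # Equilibrium modulus, then N mode strengths, then N relaxation times.
    return ([f"{p}_inf"]
            + [f"{p}_{i}" for i in range(1, n_modes + 1)]
            + [f"tau_{i}" for i in range(1, n_modes + 1)])


print(prony_parameter_names(3, "shear"))
```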

rheojax.utils.prony.tau_to_log_tau(tau_i)

Transform relaxation times to log-space.

Useful for optimization over wide time-scale ranges (e.g., 1e-6 to 1e6 s). Log-space optimization provides more uniform parameter sensitivity.

Parameters:

tau_i (numpy.typing.ArrayLike) – Array of relaxation times (s)

Returns:

Log-transformed relaxation times

Return type:

numpy.typing.ArrayLike

Example

>>> tau = np.array([1e-3, 1e-1, 1e1, 1e3])
>>> log_tau = tau_to_log_tau(tau)
>>> print(log_tau)
[-3. -1.  1.  3.]

rheojax.utils.prony.log_tau_to_tau(log_tau_i)

Transform log-space relaxation times back to linear space.

Inverse of tau_to_log_tau().

Parameters:

log_tau_i (numpy.typing.ArrayLike) – Array of log10(tau) values

Returns:

Relaxation times (s)

Return type:

numpy.typing.ArrayLike

Example

>>> log_tau = np.array([-3., -1., 1., 3.])
>>> tau = log_tau_to_tau(log_tau)
>>> print(tau)
[1.e-03 1.e-01 1.e+01 1.e+03]

rheojax.utils.prony.compute_r_squared(y_true, y_pred)

Compute R² coefficient of determination.

R² = 1 - SS_res / SS_tot where SS_res = Σ(y_true - y_pred)², SS_tot = Σ(y_true - mean(y_true))²

R² ∈ (-∞, 1], with R²=1 being perfect fit.

Parameters:
  • y_true (numpy.typing.ArrayLike) – True values

  • y_pred (numpy.typing.ArrayLike) – Predicted values

Return type:

float

Returns:

R² coefficient (1.0 = perfect fit, 0.0 = mean baseline, <0 = worse than mean)

Example

>>> y_true = np.array([1., 2., 3., 4., 5.])
>>> y_pred = np.array([1.1, 2.0, 2.9, 4.1, 5.0])
>>> r2 = compute_r_squared(y_true, y_pred)
>>> print(f"{r2:.4f}")
0.9970
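The documented formula translates directly to NumPy. In this sketch (r_squared is a hypothetical name), the example data give SS_res = 0.03 and SS_tot = 10, so R² = 1 - 0.003 = 0.997. The sketch does not guard against a constant y_true, where SS_tot = 0.

```python
import numpy as np


def r_squared(y_true, y_pred) -> float:
    """R^2 = 1 - SS_res / SS_tot, following the documented formula (sketch)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    return float(1.0 - ss_res / ss_tot)


print(f"{r_squared([1., 2., 3., 4., 5.], [1.1, 2.0, 2.9, 4.1, 5.0]):.4f}")  # 0.9970
```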

rheojax.utils.prony.iterative_n_reduction(fit_results_dict)

Track R² vs N for element minimization visualization.

Parameters:

fit_results_dict (dict[int, float]) – Dictionary mapping n_modes → R² value. Example: {10: 0.998, 9: 0.997, 8: 0.995, …}

Returns:

Dictionary with keys:

  • ‘n_modes’: Array of N values (sorted ascending)

  • ‘r2’: Array of R² values corresponding to each N

  • ‘r2_min’: Minimum R² across all fits

  • ‘r2_max’: Maximum R² across all fits

Return type:

dict[str, numpy.typing.ArrayLike]

Example

>>> results = {10: 0.998, 8: 0.995, 6: 0.990, 4: 0.980, 2: 0.950}
>>> diagnostics = iterative_n_reduction(results)
>>> print(diagnostics['n_modes'])
[ 2  4  6  8 10]
>>> print(diagnostics['r2'])
[0.95  0.98  0.99  0.995 0.998]
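The bookkeeping amounts to sorting the dictionary by mode count. r2_vs_n is a hypothetical sketch that mirrors the documented return keys; it is not the rheojax source.

```python
import numpy as np


def r2_vs_n(fit_results):
    """Hypothetical sketch: turn {n_modes: R^2} into sorted diagnostic arrays."""
    if not fit_results:
        raise ValueError("fit_results must not be empty")
    keys = sorted(fit_results)  # ascending number of modes
    r2 = np.array([fit_results[n] for n in keys], dtype=float)
    return {
        "n_modes": np.array(keys, dtype=int),
        "r2": r2,
        "r2_min": float(r2.min()),
        "r2_max": float(r2.max()),
    }


diag = r2_vs_n({10: 0.998, 8: 0.995, 6: 0.990, 4: 0.980, 2: 0.950})
print(diag["n_modes"], diag["r2"])
```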

rheojax.utils.prony.select_optimal_n(r2_values, optimization_factor=1.5)

Select optimal number of modes using R² threshold criterion.

Algorithm:

  1. Find the maximum R² across all N: R²_max (best achievable fit)

  2. Compute the R² degradation tolerance: ΔR² = (1 - R²_max) × (optimization_factor - 1.0)

  3. Set the threshold: R²_threshold = R²_max - ΔR²

  4. Select the smallest N where R²_N ≥ R²_threshold

Interpretation:

  • optimization_factor = 1.0: ΔR² = 0, so R² ≥ R²_max is required (no parsimony; only the best fit is accepted)

  • optimization_factor = 1.5: R² may fall by up to 50% of the residual gap (1 - R²_max) (balances quality and parsimony)

  • optimization_factor = 2.0: R² may fall by up to 100% of the residual gap (strongly favors fewer modes)

For optimization_factor > 1, some degradation from the best fit is accepted in exchange for fewer parameters: a higher factor is more tolerant of degradation and yields a simpler model.

Parameters:
  • r2_values (dict[int, float]) – Dictionary mapping n_modes → R² value

  • optimization_factor (float) – Parsimony factor (≥ 1.0):

    • 1.0: No degradation allowed (requires the best R²)

    • 1.5 (default): R² may drop by up to 50% of (1 - R²_max)

    • 2.0: R² may drop by up to 100% of (1 - R²_max)

Return type:

int

Returns:

Optimal number of modes (N_opt)

Raises:

ValueError – If optimization_factor < 1.0 or r2_values empty

Example

>>> r2 = {5: 0.998, 3: 0.9975, 2: 0.980, 1: 0.900}
>>> # R²_max = 0.998, degradation room = 1 - 0.998 = 0.002
>>> # factor=1.5: ΔR² = 0.002 × 0.5 = 0.001, threshold = 0.997
>>> # Smallest N with R² ≥ 0.997: N=3 (R² = 0.9975)
>>> n_opt = select_optimal_n(r2, optimization_factor=1.5)
>>> print(n_opt)
3
>>> # factor=1.0: ΔR² = 0, threshold = 0.998, need N=5
>>> n_opt = select_optimal_n(r2, optimization_factor=1.0)
>>> print(n_opt)
5
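The four-step threshold rule is short enough to sketch in plain Python. select_n is a hypothetical name; the sketch follows the documented algorithm and is not copied from rheojax. Because the comparison is R²_N ≥ R²_threshold, a mode count is only selected if its R² actually clears the threshold; for instance, 0.995 does not clear 0.997.

```python
def select_n(r2_values: dict, optimization_factor: float = 1.5) -> int:
    """Sketch of the documented R^2 threshold rule (not the rheojax source)."""
    if not r2_values:
        raise ValueError("r2_values must not be empty")
    if optimization_factor < 1.0:
        raise ValueError("optimization_factor must be >= 1.0")
    r2_max = max(r2_values.values())
    delta = (1.0 - r2_max) * (optimization_factor - 1.0)  # allowed degradation
    threshold = r2_max - delta
    # Smallest N whose R^2 clears the threshold; the best fit always qualifies.
    return min(n for n, r2 in r2_values.items() if r2 >= threshold)


r2 = {5: 0.998, 3: 0.9975, 2: 0.980, 1: 0.900}
print(select_n(r2, 1.5), select_n(r2, 1.0))  # 3 5
```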

rheojax.utils.prony.softmax_penalty(E_i, scale=1.0)

Compute softmax penalty for negative moduli in Step 1 fitting.

This differentiable penalty encourages positive Eᵢ values during unconstrained optimization. Each term is a softplus of -Eᵢ/scale: it approaches zero when Eᵢ ≫ scale, equals scale × ln 2 at Eᵢ = 0, and grows roughly linearly in |Eᵢ| as Eᵢ becomes more negative.

Penalty = scale × Σᵢ log(1 + exp(-Eᵢ/scale))

Parameters:
  • E_i (numpy.typing.ArrayLike) – Array of mode strengths (Pa)

  • scale (float) – Smoothness parameter (default 1.0). Larger values give smoother penalty but weaker enforcement.

Returns:

Penalty value (≥ 0, differentiable, JAX array or scalar)

Note

Returns JAX array for gradient compatibility. Do not convert to Python float() when used in JAX-traced functions.

Example

>>> E_i = np.array([1e5, 1e4, -1e3])  # One negative mode
>>> penalty = softmax_penalty(E_i, scale=1e3)
>>> print(f"{penalty:.2f}")
1313.31  # Dominated by the negative mode E_3 = -1e3
>>> E_i_pos = np.array([1e5, 1e4, 1e3])  # All positive
>>> penalty_pos = softmax_penalty(E_i_pos, scale=1e3)
>>> print(f"{penalty_pos:.2e}")
3.13e+02  # Mostly from E_3 = scale; terms vanish only for Eᵢ ≫ scale
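The documented formula is a sum of softplus terms log(1 + exp(x)) with x = -Eᵢ/scale. The pure-Python sketch below uses the numerically stable form max(x, 0) + log1p(exp(-|x|)). The names softplus and softplus_penalty are hypothetical, and unlike the library function (which returns a JAX array) it returns a Python float. It gives 1313.31 for the first example input and 3.13e+02 for the second.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def softplus_penalty(E_i, scale: float = 1.0) -> float:
    """scale * sum_i log(1 + exp(-E_i / scale)), per the documented formula (sketch)."""
    return scale * sum(softplus(-e / scale) for e in E_i)


print(f"{softplus_penalty([1e5, 1e4, -1e3], scale=1e3):.2f}")  # 1313.31
print(f"{softplus_penalty([1e5, 1e4, 1e3], scale=1e3):.2e}")   # 3.13e+02
```

Each mode sitting exactly at Eᵢ = 0 contributes scale × ln 2 ≈ 0.693 × scale.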

rheojax.utils.prony.warm_start_from_n_modes(params_n, n_target, modulus_type='shear')

Extract warm-start parameters for reduced-mode fit from N-mode solution.

Used in element minimization to initialize N-1 mode fit from N mode solution. Provides intelligent parameter extraction for faster convergence in successive NLSQ fits during element search optimization.

Algorithm:

  1. Extract E_inf, E_i, tau_i from the N-mode params

  2. If n_target < N: Truncate to the first n_target modes (these are the strongest modes when modes are ordered strongest first)

  3. If n_target > N: Pad with zeros/default values (edge case, typically not used)

  4. If n_target == N: Return params unchanged

Parameter Layout:

  • params_n format: [E_inf, E_1, E_2, …, E_N, tau_1, tau_2, …, tau_N]

  • Total length: 2*N + 1

Parameters:
  • params_n (numpy.typing.ArrayLike) – Fitted parameters from the N-mode optimization, with shape (2*N + 1,), where N is the current number of modes

  • n_target (int) – Target number of modes for next fit (typically N-1)

  • modulus_type (str) – ‘shear’ (G) or ‘tensile’ (E) - currently not used, but kept for API consistency

Return type:

numpy.typing.ArrayLike

Returns:

Initial parameters for the n_target-mode fit, with shape (2*n_target + 1,)

Raises:

ValueError – If n_target < 1 or params_n has invalid length

Example

>>> # 5-mode fit result
>>> params_5 = np.array([1e3, 1e6, 5e5, 2e5, 8e4, 3e4,  # E_inf, E_1..E_5
...                      1e-2, 1e-1, 1.0, 1e1, 1e2])    # tau_1..tau_5
>>> # Warm-start for 4-mode fit (truncate weakest mode E_5)
>>> params_4 = warm_start_from_n_modes(params_5, n_target=4)
>>> print(params_4.shape)
(9,)  # 2*4 + 1 = 9 parameters
>>> # E_inf, E_1..E_4, tau_1..tau_4
>>> print(params_4)
[1.e+03 1.e+06 5.e+05 2.e+05 8.e+04 1.e-02 1.e-01 1.e+00 1.e+01]

Notes

  • Truncation assumes modes are ordered by importance (strongest first)

  • For GMM fitting, this ordering is typically achieved by sorting by E_i

  • Warm-start can provide 2-5x speedup in element minimization

  • Compilation reuse provides additional speedup when combined with this
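The truncation branch (n_target ≤ N) can be sketched with array slicing on the [E_inf, E_1..E_N, tau_1..tau_N] layout. truncate_prony_params is a hypothetical helper: it does not implement the padding branch and, like the documented function, assumes the modes are already ordered strongest first.

```python
import numpy as np


def truncate_prony_params(params_n, n_target):
    """Keep E_inf plus the first n_target (E_i, tau_i) pairs (sketch)."""
    params_n = np.asarray(params_n, dtype=float)
    if params_n.size < 3 or params_n.size % 2 == 0:
        raise ValueError("params_n must have length 2*N + 1 with N >= 1")
    n = (params_n.size - 1) // 2
    if not 1 <= n_target <= n:
        raise ValueError("this sketch only supports 1 <= n_target <= N")
    E_inf = params_n[:1]         # [E_inf]
    E_i = params_n[1:1 + n]      # [E_1 .. E_N]
    tau_i = params_n[1 + n:]     # [tau_1 .. tau_N]
    return np.concatenate([E_inf, E_i[:n_target], tau_i[:n_target]])


params_5 = np.array([1e3, 1e6, 5e5, 2e5, 8e4, 3e4, 1e-2, 1e-1, 1.0, 1e1, 1e2])
print(truncate_prony_params(params_5, 4))
```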

The prony module provides utilities for Prony series representation of multi-mode viscoelastic behavior, supporting the Generalized Maxwell Model (GMM).

Functions

rheojax.utils.prony.create_prony_parameter_set(n_modes, modulus_type='shear')

Creates ParameterSet for N-mode Prony series with dynamic parameter generation.

rheojax.utils.prony.select_optimal_n(r2_values, optimization_factor=1.5)

Element minimization algorithm with warm-start optimization (v0.4.0+). Achieves 2-5x speedup through successive fits and compilation reuse.

rheojax.utils.prony.compute_r_squared(y_true, y_pred)

Computes R² coefficient of determination for model goodness-of-fit.

rheojax.utils.prony.softmax_penalty(E_i, scale=1.0)

Softmax penalty for physical constraints in NLSQ optimization.

Mathematical Background

The Prony series represents multi-mode relaxation:

\[E(t) = E_\infty + \sum_{i=1}^{N} E_i \exp(-t/\tau_i)\]

Parameters (2N+1 total):

  • \(E_\infty\) (or \(G_\infty\)): Equilibrium modulus

  • \(E_i\) (or \(G_i\)): Mode i strength

  • \(\tau_i\): Mode i relaxation time

Element Minimization Algorithm (v0.4.0):

  1. Fit N-mode model with NLSQ optimization

  2. Compute R² for current N

  3. Initialize (N-1)-mode fit from optimal N-mode parameters (warm-start)

  4. Continue until R² falls below R²_threshold = R²_max - (1 - R²_max) × (optimization_factor - 1)

  5. Return the smallest N whose R² still meets the threshold

Performance Optimization:

  • Warm-Start: Each N initialized from N+1 parameters (avoids cold-start overhead)

  • Compilation Reuse: Cached residual functions across iterations

  • Early Termination: Stops when R² < threshold (default: 1 - 1.5×(1 - R²_max))

  • Speedup: 2-5x measured improvement (20-50s → 4-25s for N=10 search)

Examples

Create Prony Parameter Set

from rheojax.utils.prony import create_prony_parameter_set

# Create 3-mode Prony series (shear modulus)
params = create_prony_parameter_set(n_modes=3, modulus_type='shear')
print(list(params.keys()))
# ['G_inf', 'G_1', 'G_2', 'G_3', 'tau_1', 'tau_2', 'tau_3']

# Create 5-mode Prony series (tensile modulus)
params_tensile = create_prony_parameter_set(n_modes=5, modulus_type='tensile')
print(list(params_tensile.keys()))
# ['E_inf', 'E_1', ..., 'E_5', 'tau_1', ..., 'tau_5']

Element Minimization with Warm-Start

from rheojax.models.generalized_maxwell import GeneralizedMaxwell
import numpy as np

# Create model with maximum modes
model = GeneralizedMaxwell(n_modes=10, modulus_type='shear')

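# Relaxation data: a synthetic 3-mode signal for illustration
# (replace with the measured relaxation modulus)
t = np.logspace(-3, 2, 100)
G_data = (1e3 + 1e5 * np.exp(-t / 1e-2) + 1e4 * np.exp(-t / 1e-1)
          + 1e3 * np.exp(-t / 1.0))

# Fit with automatic element minimization (v0.4.0+ warm-start)
model.fit(
    t, G_data,
    test_mode='relaxation',
    optimization_factor=1.5  # parsimony factor for the R² threshold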
)

# Check optimal number of modes (auto-reduced from 10)
print(f"Optimal modes: {model._n_modes}")  # e.g., 3

# Element minimization uses warm-start for 2-5x speedup (illustrative timings):
# - Fit N=10: 5.2s
# - Fit N=9: 0.8s (warm-started from N=10)
# - Fit N=8: 0.7s (warm-started from N=9)
# - ...
# Total: ~10s vs ~50s cold-start


Optimization

Optimization utilities for parameter fitting using NLSQ.

This module provides GPU-accelerated optimization using the NLSQ package (https://github.com/imewei/NLSQ). NLSQ provides 5-270x speedup over scipy through JAX JIT compilation and automatic differentiation.

Critical: This module imports NLSQ, which must be imported before JAX to enable float64 precision mode. The rheojax package handles this automatically in __init__.py.

Example

>>> from rheojax.core.parameters import ParameterSet
>>> from rheojax.utils.optimization import nlsq_optimize
>>>
>>> # Set up parameters
>>> params = ParameterSet()
>>> params.add("x", value=1.0, bounds=(0, 10))
>>>
>>> # Define objective function
>>> def objective(values):
...     x = values[0]
...     return (x - 5.0) ** 2
>>>
>>> # Optimize
>>> result = nlsq_optimize(objective, params, use_jax=True)
>>> print(f"Optimal x: {result.x[0]:.4f}")

class rheojax.utils.optimization.ResidualFunction(fn, normalization_weights=None, y_data=None, use_log_residuals=False)

Bases: object

Callable wrapper for residual functions that carries normalization metadata.

This replaces the fragile pattern of attaching _normalization_weights as a function attribute (which breaks if the function is wrapped by decorators, functools.wraps, jax.jit, etc.). The class is fully transparent to callers — it behaves like a plain function but safely exposes the weights.

The _y_data slot also carries the original dependent-variable array so downstream code (nlsq_optimize, scipy/DE fallback paths) can attach it to OptimizationResult and recover correct R²/RMSE/AIC/BIC. Without this, those paths leave y_data=None and the r_squared property silently returns None, masking successful fits as failures.

__init__(fn, normalization_weights=None, y_data=None, use_log_residuals=False)
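The motivation above (metadata stored as a function attribute is lost when the function is re-wrapped) can be illustrated in plain Python. ResidualWithMetadata and plain_decorator are hypothetical names; the sketch mirrors the documented constructor arguments but is not the rheojax class.

```python
class ResidualWithMetadata:
    """Hypothetical sketch of a callable that carries metadata in slots."""

    __slots__ = ("fn", "normalization_weights", "y_data", "use_log_residuals")

    def __init__(self, fn, normalization_weights=None, y_data=None,
                 use_log_residuals=False):
        self.fn = fn
        self.normalization_weights = normalization_weights
        self.y_data = y_data
        self.use_log_residuals = use_log_residuals

    def __call__(self, *args, **kwargs):
        # Calling the object behaves exactly like calling fn directly.
        return self.fn(*args, **kwargs)


def plain_decorator(f):
    # A decorator that does not copy f.__dict__ (no functools.wraps).
    def inner(*args, **kwargs):
        return f(*args, **kwargs)
    return inner


def residuals(p):
    return [p[0] - 5.0, p[1] - 3.0]


residuals.weights = [1.0, 2.0]        # fragile: stored on the function object
wrapped = plain_decorator(residuals)  # the attribute does not survive re-wrapping
print(hasattr(wrapped, "weights"))    # False

safe = ResidualWithMetadata(plain_decorator(residuals), normalization_weights=[1.0, 2.0])
print(safe.normalization_weights, safe([5.0, 3.0]))  # [1.0, 2.0] [0.0, 0.0]
```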

rheojax.utils.optimization.make_fd_differentiable(fn, eps=1e-07)

Wrap a function with a finite-difference custom JVP.

This enables jax.jacfwd (forward-mode AD) for functions that cannot be traced by JAX’s autodiff — e.g. diffrax ODE solvers which use custom_vjp and are therefore incompatible with jacfwd.

The wrapper computes JVPs via central differences: (f(x+εv) - f(x-εv)) / (2ε). When combined with jax.jacfwd, this computes the full Jacobian from vmap’d perturbations in a single batched XLA call, which is much faster than SciPy’s sequential finite differences.

Parameters:
  • fn (Callable) – Function (x_data, params) -> predictions. Only the params argument (index 1) is differentiated; x_data passes through.

  • eps (float) – Perturbation size for central differences.

Return type:

Callable

Returns:

A function with identical signature but a custom JVP rule for params.

Example:

# Before: NLSQ fails with TypeError on ODE models
objective = create_least_squares_objective(model_fn, x, y)

# After: finite-difference JVP makes it NLSQ-compatible
objective = create_least_squares_objective(
    make_fd_differentiable(model_fn), x, y
)
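The central-difference JVP described above can be sketched with jax.custom_jvp. This is a minimal illustration, not the rheojax implementation: the name fd_differentiable is hypothetical, float64 is enabled explicitly (central differences with eps = 1e-7 are not meaningful in float32), and the tangent with respect to x_data is simply ignored. A closed-form single-mode model stands in for an ODE-based one so the result can be checked against exact autodiff.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # central differences need float64


def fd_differentiable(fn, eps=1e-7):
    """Give fn(x_data, params) a finite-difference JVP in params (sketch)."""

    @jax.custom_jvp
    def wrapped(x_data, params):
        return fn(x_data, params)

    @wrapped.defjvp
    def wrapped_jvp(primals, tangents):
        x_data, params = primals
        _, dparams = tangents  # tangent w.r.t. x_data is ignored
        y = fn(x_data, params)
        # (f(p + eps*v) - f(p - eps*v)) / (2*eps) approximates J(p) @ v
        dy = (fn(x_data, params + eps * dparams)
              - fn(x_data, params - eps * dparams)) / (2.0 * eps)
        return y, dy

    return wrapped


def model(t, p):
    # Single-mode relaxation G(t) = G0 * exp(-t / tau)
    return p[0] * jnp.exp(-t / p[1])


t = jnp.linspace(0.0, 5.0, 20)
p = jnp.array([100.0, 1.5])
J_fd = jax.jacfwd(fd_differentiable(model), argnums=1)(t, p)
J_ad = jax.jacfwd(model, argnums=1)(t, p)
print(J_fd.shape, float(jnp.max(jnp.abs(J_fd - J_ad))))
```

jax.jacfwd pushes one basis tangent per parameter through the custom rule under vmap, so the 2·P perturbed model evaluations are batched rather than run one by one.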

rheojax.utils.optimization.nlsq_optimize(objective, parameters, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, workflow='auto', auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, compute_covariance=True, **kwargs)

Optimize objective function using NLSQ (GPU-accelerated).

This function provides GPU-accelerated nonlinear least squares optimization using the NLSQ package. It achieves 5-270x speedup over scipy through JAX JIT compilation and automatic differentiation.

The objective function should accept parameter values as a 1D array and return a scalar value (minimization) or vector of residuals (least squares).

Parameters:
  • objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize. Takes parameter values as array and returns scalar or residual vector. Should use jax.numpy for operations to enable GPU acceleration and automatic differentiation.

  • parameters (ParameterSet) – ParameterSet with initial values and bounds

  • method (str) –

    Optimization method. Options:

    • “auto”: Automatically select based on bounds (default)

    • “trf”: Trust Region Reflective (supports bounds)

    • “lm”: Levenberg-Marquardt (no bounds)

    • “scipy”: Use SciPy’s least_squares directly (bypasses NLSQ). Use this for models that use Diffrax ODE solvers, which are incompatible with NLSQ’s forward-mode autodiff.

    NLSQ internally selects the best algorithm regardless of this parameter.

  • use_jax (bool) – Whether to use JAX for gradient computation (default: True). Should always be True for GPU acceleration and float64 precision.

  • max_iter (int) – Maximum number of iterations (default: 1000)

  • ftol (float) – Function tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.

  • xtol (float) – Parameter tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.

  • gtol (float) – Gradient tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.

  • workflow (str) –

    NLSQ 0.6.6 workflow selection (default: “auto”):

    • “auto”: Memory-aware local optimization (default)

    • “auto_global”: Global optimization with bounds exploration

    • “hpc”: HPC mode with checkpointing support

  • auto_bounds (bool) – Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics.

  • stability (str | bool) –

    Numerical stability checks (default: False):

    • ‘auto’: Check and automatically fix stability issues

    • ‘check’: Check and warn but don’t fix

    • False: Skip stability checks

  • fallback (bool) – Enable NLSQ’s native fallback strategies (default: False). When True, tries alternative approaches if optimization fails. Note: RheoJAX also has its own SciPy fallback independent of this.

  • compute_diagnostics (bool) – Compute model health diagnostics (default: False). When True, result.diagnostics includes identifiability analysis, gradient health, and other diagnostic information.

  • compute_covariance (bool) – Whether to compute the parameter covariance matrix (default: True). The covariance is derived from an SVD of the Jacobian at the solution. Set False to skip this step when confidence intervals and parameter uncertainties are not needed, avoiding one full SVD per fit.

  • **kwargs – Additional arguments passed to nlsq.LeastSquares.least_squares

Return type:

OptimizationResult

Returns:

OptimizationResult with optimal parameters, convergence info, and optional diagnostics (when compute_diagnostics=True).

Raises:

ValueError – If objective is not callable or parameters is not ParameterSet

Example

>>> from rheojax.core.parameters import ParameterSet
>>> params = ParameterSet()
>>> params.add("a", value=1.0, bounds=(0, 10))
>>> params.add("b", value=1.0, bounds=(0, 10))
>>>
>>> def objective(values):
...     a, b = values
...     return (a - 5.0) ** 2 + (b - 3.0) ** 2
>>>
>>> result = nlsq_optimize(objective, params)
>>> print(result.x)  # Should be close to [5.0, 3.0]
>>>
>>> # With NLSQ 0.6.6 features
>>> result = nlsq_optimize(
...     objective, params,
...     workflow="auto_global",  # Global optimization
...     stability="auto",        # Auto-fix stability issues
...     compute_diagnostics=True # Get diagnostics
... )
>>> print(result.diagnostics)

Notes

  • This function automatically handles float64 precision through NLSQ

  • JAX JIT compilation provides 5-270x speedup over scipy

  • Automatic differentiation eliminates need for manual Jacobian

  • Bounds are automatically extracted from ParameterSet

  • Parameters are updated in-place with optimal values

rheojax.utils.optimization.nlsq_multistart_optimize(objective, parameters, n_starts=5, perturb_factor=0.3, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, verbose=False, parallel=True, n_workers=None, y_data=None, **kwargs)

Multi-start optimization to escape local minima.

For complex objective functions (e.g., mastercurves with 10+ decades), single optimization runs may converge to poor local minima even from good initial guesses. This function performs multiple optimization runs from different starting points and returns the best result.

Strategy:
  1. First attempt: Use current parameter values (from smart initialization)

  2. Additional attempts: Random perturbations around initial values (parallel)

  3. Return result with lowest final cost (best fit)

Performance: With parallel=True (default), achieves 2-4x speedup for 5-10 starts by running optimizations concurrently. JAX releases the GIL during computation, enabling effective thread-based parallelism.

Parameters:
  • objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize

  • parameters (ParameterSet) – ParameterSet with initial values and bounds

  • n_starts (int) – Number of random starts (default: 5)

  • perturb_factor (float) – Perturbation factor for random starts (default: 0.3). Parameters are perturbed by ±perturb_factor * (value or range).

  • method (str) – Optimization method (default: “auto”)

  • use_jax (bool) – Whether to use JAX (default: True)

  • max_iter (int) – Max iterations per start (default: 1000)

  • ftol (float) – Function tolerance (default: 1e-6)

  • xtol (float) – Parameter tolerance (default: 1e-6)

  • gtol (float) – Gradient tolerance (default: 1e-6)

  • verbose (bool) – Print progress messages (default: False)

  • parallel (bool) – Run additional starts in parallel (default: True)

  • n_workers (int | None) – Number of parallel workers (default: min(n_starts-1, 4))

  • **kwargs – Additional arguments for nlsq_optimize

Return type:

OptimizationResult

Returns:

OptimizationResult with best parameters from all starts

Example

>>> # For mastercurve data (12+ decades)
>>> result = nlsq_multistart_optimize(
...     objective, parameters, n_starts=5, verbose=True
... )
>>> print(f"Best cost: {result.fun:.3e}")
rheojax.utils.optimization.nlsq_curve_fit(model_fn, x_data, y_data, parameters, auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, multistart=False, n_starts=10, workflow='auto', **kwargs)[source]

Curve fitting using NLSQ 0.6.6 curve_fit() API with advanced features.

This function provides access to NLSQ 0.6.6’s enhanced curve_fit() features including auto-bounds, stability checks, fallback strategies, model diagnostics, and workflow selection. It returns an OptimizationResult with CurveFitResult-compatible statistical properties (r_squared, rmse, aic, bic, prediction_interval, etc.).

Parameters:
  • model_fn (Callable[[ndarray, ndarray], ndarray]) – Model function f(x, params_array) -> y_pred. Takes x_data and parameter array, returns predictions.

  • x_data (ndarray) – Independent variable data

  • y_data (ndarray) – Dependent variable data (observations)

  • parameters (ParameterSet) – ParameterSet with initial values and bounds

  • auto_bounds (bool) – Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics.

  • stability (str | bool) –

    Numerical stability checks (default: False):

    • ‘auto’: Check and automatically fix stability issues

    • ‘check’: Check and warn but don’t fix

    • False: Skip stability checks

  • fallback (bool) – Enable automatic fallback strategies (default: False). When True, tries alternative approaches if optimization fails.

  • compute_diagnostics (bool) – Compute model health diagnostics (default: False). When True, result includes identifiability analysis, gradient health, etc.

  • multistart (bool) – Enable multi-start optimization (default: False). When True, explores multiple starting points to find global optimum.

  • n_starts (int) – Number of starting points for multi-start (default: 10).

  • workflow (str) –

    NLSQ 0.6.6 workflow selection (default: “auto”):

    • “auto”: Memory-aware local optimization (default)

    • “auto_global”: Global optimization with bounds exploration

    • “hpc”: HPC mode with checkpointing support

  • **kwargs – Additional arguments passed to nlsq.curve_fit()

Returns:

OptimizationResult exposing CurveFitResult-compatible statistics and methods:

  • r_squared, adj_r_squared, rmse, mae, aic, bic

  • confidence_intervals(alpha) method

  • prediction_interval(x_new, alpha) method (NLSQ 0.6.6 native)

  • get_parameter_uncertainties() method

Return type:

OptimizationResult

Example

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from rheojax.core.parameters import ParameterSet
>>> from rheojax.utils.optimization import nlsq_curve_fit
>>>
>>> def model(x, params):
...     a, b = params
...     return a * jnp.exp(-b * x)  # jax.numpy keeps the model traceable
>>>
>>> # Synthetic example data
>>> rng = np.random.default_rng(0)
>>> x_data = np.linspace(0.0, 5.0, 50)
>>> y_data = 2.0 * np.exp(-0.8 * x_data) + 0.01 * rng.standard_normal(50)
>>> x_new = np.array([1.0, 2.0])
>>>
>>> params = ParameterSet()
>>> params.add("a", value=1.0, bounds=(0, 10))
>>> params.add("b", value=0.5, bounds=(0, 5))
>>>
>>> result = nlsq_curve_fit(
...     model, x_data, y_data, params,
...     auto_bounds=True,
...     stability='auto',
...     fallback=True,
...     compute_diagnostics=True,
... )
>>> print(f"R² = {result.r_squared:.4f}")
>>> print(f"RMSE = {result.rmse:.4f}")
>>> ci = result.confidence_intervals(alpha=0.95)
>>>
>>> # Prediction intervals (NLSQ 0.6.6)
>>> pi = result.prediction_interval(x_new, alpha=0.95)
>>> print(f"95% PI: [{pi[0, 0]:.3f}, {pi[0, 1]:.3f}]")

Notes

  • This function uses nlsq.curve_fit() directly (not LeastSquares.least_squares())

  • The model function signature is f(x, params_array) not f(x, *params)

  • Results delegate to native CurveFitResult for prediction_interval() calls

  • Results include all CurveFitResult properties for model comparison
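Because nlsq_curve_fit expects f(x, params_array), a SciPy-style model written as f(x, *params) has to be adapted first. A minimal wrapper, where the helper name as_array_signature and the example model decay_star are hypothetical:

```python
import numpy as np


def as_array_signature(f_star):
    """Adapt a SciPy-style model f(x, *params) to f(x, params_array)."""
    def model(x, params):
        # Unpack the parameter array into positional arguments.
        return f_star(x, *params)
    return model


def decay_star(x, a, b):
    # SciPy-style signature: parameters passed individually.
    # (Inside a real JAX-traced fit, use jax.numpy instead of numpy.)
    return a * np.exp(-b * x)


decay = as_array_signature(decay_star)
```

The wrapped decay can then be passed as model_fn, since it takes the whole parameter array as its second argument.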

rheojax.utils.optimization.optimize_with_bounds(objective, x0, bounds, use_jax=True, **kwargs)[source]

Optimize objective function with parameter bounds.

Lower-level optimization function that works with arrays instead of ParameterSet. Useful for custom optimization workflows.

Parameters:
  • objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize

  • x0 (ndarray) – Initial parameter values

  • bounds (list[tuple[float | None, float | None]]) – List of (min, max) tuples for each parameter

  • use_jax (bool) – Whether to use JAX for gradients (default: True)

  • **kwargs – Additional arguments passed to nlsq_optimize

Return type:

OptimizationResult

Returns:

OptimizationResult with optimal parameters

Example

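>>> import numpy as np
>>> from rheojax.utils.optimization import optimize_with_bounds
>>> def objective(x):
...     return x[0]**2 + x[1]**2
>>> result = optimize_with_bounds(
...     objective,
...     x0=np.array([1.0, 1.0]),
...     bounds=[(0, 5), (0, 5)]
... )
>>> print(result.x)  # Should approach the lower bounds [0.0, 0.0]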
rheojax.utils.optimization.residual_sum_of_squares(y_true, y_pred, normalize=True)[source]

Compute residual sum of squares (RSS).

Handles both real and complex data correctly. For complex data (e.g., oscillatory shear with G’ + iG”), computes RSS for both real and imaginary parts separately and returns the sum.

Parameters:
  • y_true (numpy.typing.ArrayLike) – True values (real or complex)

  • y_pred (numpy.typing.ArrayLike) – Predicted values (real or complex)

  • normalize (bool) – Whether to normalize by y_true, i.e. use relative error (default: True)

Return type:

float

Returns:

RSS value (scalar, maintains float64 precision)

Example

>>> import numpy as np
>>> from rheojax.utils.optimization import residual_sum_of_squares
>>> y_true = np.array([1.0, 2.0, 3.0])
>>> y_pred = np.array([1.1, 2.1, 2.9])
>>> rss = residual_sum_of_squares(y_true, y_pred)
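One way the real/complex handling described above can be computed in NumPy is sketched here. It is not the library code: the name rss_sketch is hypothetical, and the relative-error form (dividing each real or imaginary residual by the matching component of y_true) is an assumption about what normalize=True does.

```python
import numpy as np


def rss_sketch(y_true, y_pred, normalize=True):
    """Hypothetical RSS for real or complex data (e.g. G* = G' + iG'')."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if np.iscomplexobj(y_true) or np.iscomplexobj(y_pred):
        # Treat real and imaginary parts as separate components.
        true_parts = [np.real(y_true), np.imag(y_true)]
        pred_parts = [np.real(y_pred), np.imag(y_pred)]
    else:
        true_parts, pred_parts = [y_true], [y_pred]

    total = 0.0
    for t, p in zip(true_parts, pred_parts):
        r = (p - t).astype(np.float64)
        if normalize:
            r = r / t  # assumed relative error; t must be nonzero
        total += float(np.sum(r ** 2))
    return total
```

With the arrays from the example above and normalize=False, this gives 0.1² + 0.1² + 0.1² = 0.03.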
rheojax.utils.optimization.create_least_squares_objective(model_fn, x_data, y_data, normalize=True, use_log_residuals=False)[source]

Create residual function for NLSQ least-squares fitting.

IMPORTANT: This function returns a RESIDUAL FUNCTION (vector output), not a scalar objective. NLSQ minimizes sum(residuals²), so supplying per-point residuals lets the optimizer compute proper gradients and apply weighting.

For complex data (e.g., G* = G’ + iG”), returns stacked real and imaginary residuals: [real₁, …, real_n, imag₁, …, imag_n] with shape (2N,).

For real data, returns residuals with shape (N,).

Log-space residuals: For rheological data spanning many decades (e.g., mastercurves covering 8+ decades), set use_log_residuals=True to compute residuals in log10 space. This weights relative errors equally across all frequency ranges and prevents the optimizer from being biased toward high-modulus regions.

Parameters:
  • model_fn (Callable[[ndarray, ndarray], ndarray]) – Model function that takes (x_data, parameters) and returns predictions

  • x_data (ndarray) – Independent variable data

  • y_data (ndarray) – Dependent variable data (observations, may be complex)

  • normalize (bool) – Whether to use relative error (default: True)

  • use_log_residuals (bool) – Whether to compute residuals in log10 space (default: False). Recommended for data spanning >8 decades. Formula: residual = log10(abs(y_pred)) - log10(abs(y_data))

Return type:

ResidualFunction

Returns:

Residual function that takes parameters and returns residual vector

Example

>>> import numpy as np
>>> from rheojax.utils.optimization import create_least_squares_objective
>>> def linear_model(x, params):
...     a, b = params
...     return a * x + b
>>> x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
>>> y = np.array([2.1, 4.0, 5.9, 8.1, 10.0])
>>> residual_fn = create_least_squares_objective(linear_model, x, y)
>>> # residual_fn(params) returns a residual vector suitable for nlsq_optimize
>>>
>>> # For mastercurve data (wide frequency range), where model_fn, omega,
>>> # and G_star stand for your own model and data:
>>> residual_fn_log = create_least_squares_objective(
...     model_fn, omega, G_star, use_log_residuals=True
... )
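The residual layout and the log10 option can be sketched as follows. make_residual_fn is a hypothetical stand-in that omits normalize, and applying log10(abs(...)) to the stacked real and imaginary parts of complex data is an assumption (the library may instead take the modulus |G*|).

```python
import numpy as np


def make_residual_fn(model_fn, x_data, y_data, use_log_residuals=False):
    """Hypothetical vector-valued residual function (normalize omitted)."""
    y_data = np.asarray(y_data)

    def stack(y):
        # Complex data: [real_1..real_N, imag_1..imag_N] -> shape (2N,)
        if np.iscomplexobj(y):
            return np.concatenate([np.real(y), np.imag(y)])
        return y  # real data -> shape (N,)

    obs = stack(y_data)

    def residual_fn(params):
        pred = stack(np.asarray(model_fn(x_data, params)))
        if use_log_residuals:
            # Equal relative errors give equal residuals, whatever the magnitude.
            return np.log10(np.abs(pred)) - np.log10(np.abs(obs))
        return pred - obs

    return residual_fn
```

A 10% overestimate at 10² Pa and at 10⁸ Pa yields the same log residual, log10(1.1), whereas the linear residuals differ by six orders of magnitude, which is why the log option avoids bias toward the high-modulus end.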
rheojax.utils.optimization.optimize(objective, parameters, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, workflow='auto', auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, compute_covariance=True, **kwargs)

Alias of nlsq_optimize(): it takes the same arguments, behaves identically, and returns the same OptimizationResult. Refer to the nlsq_optimize() entry for the full parameter reference, example, and notes.

rheojax.utils.optimization.fit_parameters(objective, parameters, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, workflow='auto', auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, compute_covariance=True, **kwargs)

Alias of nlsq_optimize(): it takes the same arguments, behaves identically, and returns the same OptimizationResult. Refer to the nlsq_optimize() entry for the full parameter reference, example, and notes.

Functions and Classes

class rheojax.utils.optimization.OptimizationResult(x, fun, jac=None, pcov=None, success=False, message='', nit=0, nfev=0, njev=0, optimality=None, active_mask=None, cost=None, grad=None, nlsq_result=None, residuals=None, y_data=None, n_data=None, diagnostics=None, _curve_fit_result=None, _model_fn=None, _x_data=None, _is_complex_split=False, _normalization_weights=None, _use_log_residuals=False)[source]

Bases: object

Result from optimization with NLSQ 0.6.6 CurveFitResult-compatible properties.

This dataclass stores the results of NLSQ optimization, including optimal parameter values, objective function value, convergence information, and statistical metrics compatible with NLSQ 0.6.6’s CurveFitResult.

x

Optimal parameter values (float64 array)

fun

Objective function value at optimum (RSS = sum of squared residuals)

jac

Jacobian at the optimum (the gradient when the objective is scalar)

pcov

Parameter covariance matrix (n_params x n_params)

success

Whether optimization converged successfully

message

Status message from optimizer

nit

Number of iterations

nfev

Number of function evaluations

njev

Number of Jacobian evaluations

optimality

Optimality metric (gradient norm)

active_mask

Active bound constraints at solution

cost

Final cost value

grad

Final gradient

nlsq_result

Full NLSQ result dictionary (for advanced diagnostics)

residuals

Residual vector (y_data - y_pred) for statistical metrics

y_data

Original dependent variable data (for R² computation)

n_data

Number of data points (for AIC/BIC computation)

diagnostics

Model health diagnostics (NLSQ 0.6.6, when compute_diagnostics=True)

Statistical Properties (NLSQ 0.6.6 CurveFitResult compatible):

  • r_squared: Coefficient of determination (R²)

  • adj_r_squared: Adjusted R² accounting for number of parameters

  • rmse: Root mean squared error

  • mae: Mean absolute error

  • aic: Akaike Information Criterion

  • bic: Bayesian Information Criterion

confidence_intervals(alpha)[source]

Compute parameter confidence intervals

prediction_interval(x_new, alpha)[source]

Compute prediction intervals (NLSQ 0.6.6)

get_parameter_uncertainties()[source]

Get standard errors from covariance diagonal

Result container for optimization.

pcov: ndarray | None = None
residuals: ndarray | None = None
y_data: ndarray | None = None
n_data: int | None = None
diagnostics: dict[str, Any] | None = None
property r_squared: float | None

Coefficient of determination (R²).

Measures goodness of fit. Range: (-∞, 1], where 1 is perfect fit.

R² = 1 - SS_res / SS_tot

where SS_res = sum((y - y_pred)²) and SS_tot = sum((y - y_mean)²)

Returns:

R² value, or None if residuals/y_data not available

property adj_r_squared: float | None

Adjusted R² accounting for number of parameters.

Adj R² = 1 - (1 - R²) * (n - 1) / (n - p - 1)

where n is number of data points and p is number of parameters.

Returns:

Adjusted R² value, or None if cannot be computed

property rmse: float | None

Root mean squared error.

RMSE = sqrt(mean(residuals²))

Returns:

RMSE value, or None if residuals not available

property mae: float | None

Mean absolute error.

MAE = mean(abs(residuals))

More robust to outliers than RMSE.

Returns:

MAE value, or None if residuals not available

property aic: float | None

Akaike Information Criterion.

AIC = 2k + n*ln(RSS/n)

where k is number of parameters, n is number of data points, and RSS is residual sum of squares.

Lower is better. Used for model selection.

Returns:

AIC value, or None if cannot be computed

property bic: float | None

Bayesian Information Criterion.

BIC = k*ln(n) + n*ln(RSS/n)

where k is number of parameters, n is number of data points, and RSS is residual sum of squares.

Lower is better. Penalizes model complexity more strongly than AIC whenever ln(n) > 2, i.e., for n ≥ 8 data points.

Returns:

BIC value, or None if cannot be computed
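For reference, the formulas above can be evaluated directly from data and predictions. fit_metrics is a hypothetical helper that mirrors the stated definitions (with k = p = number of parameters). It is not the OptimizationResult implementation, which may delegate to NLSQ.

```python
import numpy as np


def fit_metrics(y, y_pred, n_params):
    """Hypothetical helper applying the R², RMSE, MAE, AIC and BIC formulas."""
    y = np.asarray(y, dtype=float)
    resid = y - np.asarray(y_pred, dtype=float)
    n, k = y.size, n_params
    rss = float(np.sum(resid ** 2))
    ss_tot = float(np.sum((y - y.mean()) ** 2))
    r2 = 1.0 - rss / ss_tot
    # Adjusted R² is undefined when n - k - 1 <= 0.
    adj = 1.0 - (1.0 - r2) * (n - 1) / (n - k - 1) if n - k - 1 > 0 else None
    return {
        "r_squared": r2,
        "adj_r_squared": adj,
        "rmse": float(np.sqrt(np.mean(resid ** 2))),
        "mae": float(np.mean(np.abs(resid))),
        "aic": 2 * k + n * np.log(rss / n),
        "bic": k * np.log(n) + n * np.log(rss / n),
    }
```

For y = [1, 2, 3, 4] and predictions [1.1, 1.9, 3.1, 3.9] with two parameters, RSS = 0.04 and SS_tot = 5, so R² = 0.992 and RMSE = MAE = 0.1.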

confidence_intervals(alpha=0.95)[source]

Compute parameter confidence intervals.

Parameters:

alpha (float) – Confidence level (default: 0.95 for 95% CI).

Returns:

intervals – Array of shape (n_params, 2) with [lower, upper] bounds for each parameter, or None if covariance not available.

Return type:

ndarray | None

Examples

>>> result = nlsq_optimize(objective, params)
>>> ci = result.confidence_intervals(alpha=0.95)
>>> if ci is not None:
...     for i, (lower, upper) in enumerate(ci):
...         print(f"Parameter {i}: [{lower:.3f}, {upper:.3f}]")
get_parameter_uncertainties()[source]

Get standard errors for parameters from covariance diagonal.

Returns:

uncertainties – Standard errors for each parameter, or None if covariance not available.

Return type:

ndarray | None

Examples

>>> result = nlsq_optimize(objective, params)
>>> std_errs = result.get_parameter_uncertainties()
>>> if std_errs is not None:
...     for i, se in enumerate(std_errs):
...         print(f"Parameter {i}: {result.x[i]:.4f} ± {se:.4f}")
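A sketch of how standard errors and confidence intervals can be derived from pcov. The helper names are hypothetical, and the Student-t critical value with n_data - n_params degrees of freedom is an assumption; the library may use a different distribution or degrees of freedom.

```python
import numpy as np
from scipy import stats


def standard_errors(pcov):
    """Standard errors: square roots of the covariance diagonal."""
    return np.sqrt(np.diag(np.asarray(pcov, dtype=float)))


def confidence_intervals_sketch(x, pcov, n_data, confidence=0.95):
    """Hypothetical (n_params, 2) array of [lower, upper] bounds."""
    x = np.asarray(x, dtype=float)
    se = standard_errors(pcov)
    dof = max(n_data - x.size, 1)  # assumed degrees of freedom
    t_crit = stats.t.ppf(0.5 + confidence / 2.0, dof)
    return np.column_stack([x - t_crit * se, x + t_crit * se])
```

Each interval is symmetric about the fitted value, with half-width t_crit × standard error.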
prediction_interval(x_new=None, alpha=0.95)[source]

Compute prediction intervals for new x values.

Prediction intervals account for both parameter uncertainty and observation noise, providing bounds where future observations are expected to fall with the specified probability.

Parameters:
  • x_new (ndarray | None) – New x values for prediction. If None, uses original x_data.

  • alpha (float) – Confidence level for intervals (default: 0.95 for 95% PI).

Returns:

intervals – Array of shape (n_points, 2) with [lower, upper] bounds for each point, or None if prediction intervals cannot be computed.

Return type:

ndarray | None

Notes

When a native NLSQ CurveFitResult is available (from nlsq_curve_fit), this method delegates to NLSQ’s prediction_interval for accuracy. Otherwise, it falls back to a manual computation using covariance propagation.

Examples

>>> result = nlsq_curve_fit(model, x_data, y_data, params)
>>> pi = result.prediction_interval(x_new, alpha=0.95)
>>> if pi is not None:
...     for i, (lower, upper) in enumerate(pi):
...         print(f"x={x_new[i]:.2f}: [{lower:.3f}, {upper:.3f}]")
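The manual fallback mentioned in the Notes can be sketched with first-order (delta-method) covariance propagation: the prediction variance is diag(J · pcov · Jᵀ) plus the observation-noise variance. Everything here is an illustrative assumption: the helper name, the central finite-difference Jacobian (a JAX implementation would more likely use automatic differentiation), the noise variance sigma2 (for example RSS / (n - p)), and the Student-t critical value.

```python
import numpy as np
from scipy import stats


def prediction_interval_sketch(model_fn, x_new, params, pcov, sigma2, dof,
                               confidence=0.95, eps=1e-6):
    """Hypothetical prediction interval via delta-method propagation."""
    x_new = np.asarray(x_new, dtype=float)
    params = np.asarray(params, dtype=float)
    y0 = np.asarray(model_fn(x_new, params), dtype=float)

    # Central finite-difference Jacobian dy/dparams, shape (n_points, n_params).
    jac = np.empty((y0.size, params.size))
    for j in range(params.size):
        step = np.zeros_like(params)
        step[j] = eps * max(1.0, abs(params[j]))
        y_plus = np.asarray(model_fn(x_new, params + step), dtype=float)
        y_minus = np.asarray(model_fn(x_new, params - step), dtype=float)
        jac[:, j] = (y_plus - y_minus) / (2.0 * step[j])

    # Parameter uncertainty (diag of J pcov J^T) plus observation noise.
    var = np.einsum("ij,jk,ik->i", jac, pcov, jac) + sigma2
    half = stats.t.ppf(0.5 + confidence / 2.0, dof) * np.sqrt(var)
    return np.column_stack([y0 - half, y0 + half])
```

The result has the same (n_points, 2) layout as prediction_interval().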
classmethod from_curve_fit_result(curve_fit_result, y_data=None, model_fn=None, x_data=None)[source]

Create OptimizationResult from NLSQ 0.6.6 CurveFitResult.

This factory method preserves the native CurveFitResult for property delegation, enabling access to NLSQ 0.6.6’s statistical methods like prediction_interval() without reimplementation.

Parameters:
  • curve_fit_result (Any) – Result from nlsq.curve_fit() call.

  • y_data (ndarray | None) – Original dependent variable data (for complex data handling).

  • model_fn (Callable | None) – Model function f(x, params) for prediction intervals.

  • x_data (ndarray | None) – Original independent variable data for prediction intervals.

Returns:

result – Result with native delegation to CurveFitResult properties.

Return type:

OptimizationResult

Examples

>>> curve_fit_result = nlsq.curve_fit(model_fn, x, y, p0=p0)
>>> result = OptimizationResult.from_curve_fit_result(
...     curve_fit_result, y_data=y, model_fn=model_fn, x_data=x
... )
>>> print(result.r_squared)  # Delegates to native
>>> pi = result.prediction_interval(x_new)  # Delegates to native
classmethod from_nlsq(nlsq_result, residuals=None, y_data=None, compute_covariance=True)[source]

Create OptimizationResult from NLSQ result dictionary.

Parameters:
  • nlsq_result (dict[str, Any]) – Result dictionary from nlsq.LeastSquares.least_squares

  • residuals (ndarray | None) – Optional residual vector for covariance scaling and metrics

  • y_data (ndarray | None) – Optional original y data for R² computation

  • compute_covariance (bool) – Whether to compute the covariance matrix via SVD (default: True). Set False to skip SVD when CIs are not needed.

Return type:

OptimizationResult

Returns:

OptimizationResult instance with fields extracted from NLSQ result
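from_nlsq (and the compute_covariance option) describe a covariance derived from an SVD of the Jacobian. A common construction, similar in spirit to scipy.optimize.curve_fit, is sketched below. The helper name, the singular-value threshold, and the RSS / (n - p) scaling are assumptions rather than confirmed RheoJAX behavior.

```python
import numpy as np


def covariance_from_jacobian(jac, residuals):
    """Hypothetical pcov = s^2 * (J^T J)^+ computed via SVD of the Jacobian."""
    jac = np.asarray(jac, dtype=float)
    residuals = np.asarray(residuals, dtype=float)
    n, p = jac.shape
    _, s, vt = np.linalg.svd(jac, full_matrices=False)
    # Discard singular values that are negligible relative to the largest one.
    threshold = np.finfo(float).eps * max(n, p) * s[0]
    keep = s > threshold
    inv_s2 = np.zeros_like(s)
    inv_s2[keep] = 1.0 / s[keep] ** 2
    # V diag(1/s^2) V^T is the pseudo-inverse of J^T J.
    pcov = (vt.T * inv_s2) @ vt
    dof = n - p
    if dof > 0:
        # Scale by the residual variance estimate RSS / (n - p).
        pcov = pcov * float(residuals @ residuals) / dof
    return pcov
```

Skipping this step (compute_covariance=False) avoids the SVD entirely, which is why it saves time when confidence intervals are not needed.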

__init__(x, fun, jac=None, pcov=None, success=False, message='', nit=0, nfev=0, njev=0, optimality=None, active_mask=None, cost=None, grad=None, nlsq_result=None, residuals=None, y_data=None, n_data=None, diagnostics=None, _curve_fit_result=None, _model_fn=None, _x_data=None, _is_complex_split=False, _normalization_weights=None, _use_log_residuals=False)
rheojax.utils.optimization.nlsq_optimize(objective, parameters, method='auto', use_jax=True, max_iter=1000, ftol=1e-06, xtol=1e-06, gtol=1e-06, workflow='auto', auto_bounds=False, stability=False, fallback=False, compute_diagnostics=False, compute_covariance=True, **kwargs)[source]

Optimize objective function using NLSQ (GPU-accelerated).

This function provides GPU-accelerated nonlinear least squares optimization using the NLSQ package. Through JAX JIT compilation and automatic differentiation, it can achieve 5-270x speedups over SciPy, depending on problem size and hardware.

The objective function should accept parameter values as a 1D array and return a scalar value (minimization) or vector of residuals (least squares).

Parameters:
  • objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize. Takes parameter values as array and returns scalar or residual vector. Should use jax.numpy for operations to enable GPU acceleration and automatic differentiation.

  • parameters (ParameterSet) – ParameterSet with initial values and bounds

  • method (str) –

    Optimization method. Options:

    • “auto”: Automatically select based on bounds (default)

    • “trf”: Trust Region Reflective (supports bounds)

    • “lm”: Levenberg-Marquardt (no bounds)

    • “scipy”: Use SciPy’s least_squares directly (bypasses NLSQ). Use this for models built on Diffrax ODE solvers, which are incompatible with NLSQ’s forward-mode autodiff.

    For the NLSQ-backed options, NLSQ internally selects the best algorithm regardless of this parameter.

  • use_jax (bool) – Whether to use JAX for gradient computation (default: True). Should always be True for GPU acceleration and float64 precision.

  • max_iter (int) – Maximum number of iterations (default: 1000)

  • ftol (float) – Function tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.

  • xtol (float) – Parameter tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.

  • gtol (float) – Gradient tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ’s mixed precision management.

  • workflow (str) –

    NLSQ 0.6.6 workflow selection (default: “auto”):

    • “auto”: Memory-aware local optimization (default)

    • “auto_global”: Global optimization with bounds exploration

    • “hpc”: HPC mode with checkpointing support

  • auto_bounds (bool) – Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics.

  • stability (str | bool) –

    Numerical stability checks (default: False):

    • ‘auto’: Check and automatically fix stability issues

    • ‘check’: Check and warn but don’t fix

    • False: Skip stability checks

  • fallback (bool) – Enable NLSQ’s native fallback strategies (default: False). When True, tries alternative approaches if optimization fails. Note: RheoJAX also has its own SciPy fallback independent of this.

  • compute_diagnostics (bool) – Compute model health diagnostics (default: False). When True, result.diagnostics includes identifiability analysis, gradient health, and other diagnostic information.

  • compute_covariance (bool) – Whether to compute the parameter covariance matrix (default: True). The covariance is derived from an SVD of the Jacobian at the solution. Set False to skip this step when confidence intervals and parameter uncertainties are not needed, avoiding one full SVD per fit.

  • **kwargs – Additional arguments passed to nlsq.LeastSquares.least_squares

Return type:

OptimizationResult

Returns:

OptimizationResult with optimal parameters, convergence info, and optional diagnostics (when compute_diagnostics=True).

Raises:

ValueError – If objective is not callable or parameters is not ParameterSet

Example

>>> from rheojax.core.parameters import ParameterSet
>>> from rheojax.utils.optimization import nlsq_optimize
>>> params = ParameterSet()
>>> params.add("a", value=1.0, bounds=(0, 10))
>>> params.add("b", value=1.0, bounds=(0, 10))
>>>
>>> def objective(values):
...     a, b = values
...     return (a - 5.0) ** 2 + (b - 3.0) ** 2
>>>
>>> result = nlsq_optimize(objective, params)
>>> print(result.x)  # Should be close to [5.0, 3.0]
>>>
>>> # With NLSQ 0.6.6 features
>>> result = nlsq_optimize(
...     objective, params,
...     workflow="auto_global",    # Global optimization
...     stability="auto",          # Auto-fix stability issues
...     compute_diagnostics=True,  # Get diagnostics
... )
>>> print(result.diagnostics)

Notes

  • This function automatically handles float64 precision through NLSQ

  • JAX JIT compilation can provide 5-270x speedups over SciPy, depending on problem size and hardware

  • Automatic differentiation eliminates need for manual Jacobian

  • Bounds are automatically extracted from ParameterSet

  • Parameters are updated in-place with optimal values

rheojax.utils.optimization.optimize_with_bounds(objective, x0, bounds, use_jax=True, **kwargs)[source]

Optimize objective function with parameter bounds.

Lower-level optimization function that works with arrays instead of ParameterSet. Useful for custom optimization workflows.

Parameters:
  • objective (Callable[[ndarray], float | ndarray]) – Objective function to minimize

  • x0 (ndarray) – Initial parameter values

  • bounds (list[tuple[float | None, float | None]]) – List of (min, max) tuples for each parameter

  • use_jax (bool) – Whether to use JAX for gradients (default: True)

  • **kwargs – Additional arguments passed to nlsq_optimize

Return type:

OptimizationResult

Returns:

OptimizationResult with optimal parameters

Example

>>> import numpy as np
>>> from rheojax.utils.optimization import optimize_with_bounds
>>> def objective(x):
...     return x[0]**2 + x[1]**2
>>> result = optimize_with_bounds(
...     objective,
...     x0=np.array([1.0, 1.0]),
...     bounds=[(0, 5), (0, 5)]
... )

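The bounds argument uses (min, max) tuples, with None for an open side. Many solvers instead expect separate lower/upper arrays with ±inf for missing limits. The sketch below shows that conversion under this assumption; bounds_to_arrays is a hypothetical helper (not part of rheojax), and scipy.optimize.least_squares appears only to show where such arrays plug in.

```python
import numpy as np
from scipy.optimize import least_squares


def bounds_to_arrays(bounds):
    """Hypothetical helper: [(min, max), ...] with None -> (lower, upper) arrays."""
    lower = np.array([-np.inf if lo is None else float(lo) for lo, _ in bounds])
    upper = np.array([np.inf if hi is None else float(hi) for _, hi in bounds])
    if np.any(lower > upper):
        raise ValueError("each lower bound must not exceed its upper bound")
    return lower, upper


# Demo: residuals x - target, with the first parameter confined to [0, 5]
lower, upper = bounds_to_arrays([(0, 5), (None, 5)])
target = np.array([-1.0, 2.0])
res = least_squares(lambda x: x - target, x0=np.array([1.0, 1.0]), bounds=(lower, upper))
print(res.x)  # first component is pushed to its lower bound (0), second reaches 2
```

Using ±inf rather than None keeps the bounds as plain float arrays, which vectorized bound checks and JAX-traced code handle without special cases.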
rheojax.utils.optimization.residual_sum_of_squares(y_true, y_pred, normalize=True)[source]

Compute residual sum of squares (RSS).

Handles both real and complex data correctly. For complex data (e.g., oscillatory shear with G' + iG''), computes the RSS of the real and imaginary parts separately and returns their sum.

Parameters:
  • y_true (numpy.typing.ArrayLike) – True values (real or complex)

  • y_pred (numpy.typing.ArrayLike) – Predicted values (real or complex)

  • normalize (bool) – Whether to normalize by y_true (relative error)

Return type:

float

Returns:

RSS value (scalar, maintains float64 precision)

Example

>>> import numpy as np
>>> from rheojax.utils.optimization import residual_sum_of_squares
>>> y_true = np.array([1.0, 2.0, 3.0])
>>> y_pred = np.array([1.1, 2.1, 2.9])
>>> rss = residual_sum_of_squares(y_true, y_pred)

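A minimal NumPy sketch of the computation described above, assuming the real and imaginary parts are simply concatenated before squaring. rss_sketch is a hypothetical name, and its normalization rule (divide each component by the magnitude of the matching true component, leaving zero components unscaled) is an assumption; the docstring only says the error is taken relative to y_true.

```python
import numpy as np


def rss_sketch(y_true, y_pred, normalize=True):
    """Hypothetical RSS over real and imaginary parts (not rheojax's code)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Stack real and imaginary parts; real input just contributes zero imaginary parts
    t = np.concatenate([y_true.real, y_true.imag]).astype(float)
    p = np.concatenate([y_pred.real, y_pred.imag]).astype(float)
    r = p - t
    if normalize:
        # Assumed relative error: divide by |true component|, leave zero components as-is
        r = np.divide(r, np.abs(t), out=r.copy(), where=t != 0)
    return float(np.sum(r ** 2))


print(rss_sketch([1.0, 2.0, 3.0], [1.1, 2.1, 2.9], normalize=False))
print(rss_sketch([2.0 + 4.0j], [3.0 + 2.0j]))  # real and imaginary errors add up
```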
rheojax.utils.optimization.create_least_squares_objective(model_fn, x_data, y_data, normalize=True, use_log_residuals=False)[source]

Create residual function for NLSQ least-squares fitting.

IMPORTANT: This now returns a RESIDUAL FUNCTION (vector output), not a scalar objective. NLSQ minimizes sum(residuals²), so this provides per-point residuals to the optimizer, which enables proper gradient computation and weighting.

For complex data (e.g., G* = G' + iG''), returns stacked real and imaginary residuals [real₁, …, real_N, imag₁, …, imag_N] with shape (2N,).

For real data, returns residuals with shape (N,).

Log-space residuals (NEW): For rheological data spanning many decades (e.g., mastercurves with 8+ decades), use use_log_residuals=True to compute residuals in log10 space. This gives equal weight to all frequency ranges and prevents optimizer bias toward high-modulus regions.

Parameters:
  • model_fn (Callable[[ndarray, ndarray], ndarray]) – Model function that takes (x_data, parameters) and returns predictions

  • x_data (ndarray) – Independent variable data

  • y_data (ndarray) – Dependent variable data (observations, may be complex)

  • normalize (bool) – Whether to use relative error (default: True)

  • use_log_residuals (bool) – Whether to compute residuals in log10 space (default: False). Recommended for data spanning >8 decades. Formula: residual = log10(abs(y_pred)) - log10(abs(y_data))

Return type:

ResidualFunction

Returns:

Residual function that takes parameters and returns residual vector

Example

>>> import numpy as np
>>> from rheojax.utils.optimization import create_least_squares_objective
>>> def linear_model(x, params):
...     a, b = params
...     return a * x + b
>>> x = np.array([1, 2, 3, 4, 5])
>>> y = np.array([2.1, 4.0, 5.9, 8.1, 10.0])
>>> residual_fn = create_least_squares_objective(linear_model, x, y)
>>> # Now use with nlsq_optimize - it receives proper residual vector
>>>
>>> # For mastercurve data (wide frequency range); model_fn, omega and
>>> # G_star stand for your own model function and data:
>>> residual_fn_log = create_least_squares_objective(
...     model_fn, omega, G_star, use_log_residuals=True
... )
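The residual layout described above (stacked real and imaginary parts for complex data, optional relative or log10 residuals) can be sketched as a small closure factory. make_residual_fn is hypothetical, and two details are assumptions rather than documented behavior: log residuals are applied to the real and imaginary components separately, and zero-valued observations fall back to absolute error when normalize=True.

```python
import numpy as np


def make_residual_fn(model_fn, x_data, y_data, normalize=True, use_log_residuals=False):
    """Hypothetical residual-function factory (illustration, not rheojax's code)."""
    is_complex = np.iscomplexobj(y_data)

    def to_components(y):
        # Complex values -> [real_1..real_N, imag_1..imag_N], shape (2N,)
        y = np.asarray(y)
        if is_complex:
            return np.concatenate([y.real, y.imag]).astype(float)
        return y.astype(float)

    y_obs = to_components(y_data)

    def residual_fn(params):
        y_fit = to_components(model_fn(x_data, params))
        if use_log_residuals:
            # log10(|y_pred|) - log10(|y_data|), applied component-wise (assumption)
            return np.log10(np.abs(y_fit)) - np.log10(np.abs(y_obs))
        r = y_fit - y_obs
        if normalize:
            # Relative error; zero-valued observations fall back to absolute error
            r = np.divide(r, np.abs(y_obs), out=r.copy(), where=y_obs != 0)
        return r

    return residual_fn
```

Returning a vector rather than a scalar leaves the squaring and summing to the least-squares solver, which can then use the per-point structure (for example, a Jacobian of the residuals).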

Aliases

rheojax.utils.optimization.optimize

Alias for nlsq_optimize()

rheojax.utils.optimization.fit_parameters

Alias for nlsq_optimize()

Optimization Methods

The following scipy.optimize methods are supported:

  • “L-BFGS-B”: L-BFGS algorithm with bounds (default for bounded problems)

  • “TNC”: Truncated Newton with bounds

  • “SLSQP”: Sequential Least Squares Programming

  • “trust-constr”: Trust-region constrained optimization

  • “BFGS”: Broyden-Fletcher-Goldfarb-Shanno (default for unbounded)

JAX Gradient Computation

When use_jax=True, gradients are computed using JAX automatic differentiation:

\[\nabla f(x) = \left[\frac{\partial f}{\partial x_1}, \ldots, \frac{\partial f}{\partial x_n}\right]\]

This provides exact gradients (up to floating-point precision) compared to numerical finite-difference approximations, leading to faster and more robust optimization.
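To illustrate how an exact gradient is used, the sketch below passes a hand-written gradient to SciPy's bounded L-BFGS-B and compares it with a forward finite-difference estimate from scipy.optimize.approx_fprime. The gradient is derived by hand as a stand-in for what jax.grad would compute automatically; this is plain SciPy, not rheojax's internal code path.

```python
import numpy as np
from scipy.optimize import approx_fprime, minimize


def objective(x):
    # Same quadratic as the Basic Optimization example: minimum at (5, 3)
    return (x[0] - 5.0) ** 2 + (x[1] - 3.0) ** 2


def gradient(x):
    # Exact gradient, derived by hand; jax.grad(objective) would give the same
    return np.array([2.0 * (x[0] - 5.0), 2.0 * (x[1] - 3.0)])


x0 = np.array([0.0, 0.0])
exact = gradient(x0)
fd = approx_fprime(x0, objective, 1e-6)  # forward differences, step 1e-6
print("exact:", exact, "finite-difference:", fd)

result = minimize(objective, x0, jac=gradient, method="L-BFGS-B",
                  bounds=[(-10, 10), (-10, 10)])
print(result.x, result.fun)
```

The finite-difference estimate is off by roughly the step size (about 1e-6 here), which is the truncation error that automatic differentiation avoids.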

Examples

Basic Optimization

from rheojax.core.parameters import ParameterSet
from rheojax.utils.optimization import nlsq_optimize
import jax.numpy as jnp

# Define objective function
def objective(params):
    x, y = params
    return (x - 5.0)**2 + (y - 3.0)**2

# Set up parameters
params = ParameterSet()
params.add("x", value=0.0, bounds=(-10, 10))
params.add("y", value=0.0, bounds=(-10, 10))

# Optimize
result = nlsq_optimize(objective, params, use_jax=True)
print(f"Optimal: x={result.x[0]:.4f}, y={result.x[1]:.4f}")
print(f"Function value: {result.fun:.6f}")
print(f"Success: {result.success}")

Model Fitting

import jax.numpy as jnp
from rheojax.core.parameters import ParameterSet
from rheojax.utils.optimization import nlsq_optimize

# Experimental data
t_exp = jnp.array([0.1, 0.5, 1.0, 2.0, 5.0, 10.0])
stress_exp = jnp.array([1000, 800, 650, 500, 320, 200])

# Maxwell model
def maxwell_model(t, params):
    E, tau = params
    return E * jnp.exp(-t / tau)

# Objective: minimize residuals
def objective(params):
    predictions = maxwell_model(t_exp, params)
    residuals = predictions - stress_exp
    return jnp.sum(residuals**2)

# Set up parameters
params = ParameterSet()
params.add("E", value=1000, bounds=(100, 5000))
params.add("tau", value=1.0, bounds=(0.1, 100))

# Fit model
result = nlsq_optimize(
    objective,
    params,
    use_jax=True,
    method="L-BFGS-B"
)

# Extract fitted parameters
E_fit = params.get_value("E")
tau_fit = params.get_value("tau")
print(f"Fitted: E={E_fit:.1f} Pa, tau={tau_fit:.2f} s")

Custom Objective with Least Squares

import jax.numpy as jnp
from rheojax.core.parameters import ParameterSet
from rheojax.utils.optimization import (
    create_least_squares_objective,
    nlsq_optimize,
)

# Define model function
def power_law(shear_rate, params):
    K, n = params
    return K * shear_rate**n

# Data
shear_rate = jnp.logspace(-2, 2, 50)
viscosity = 100 * shear_rate**(-0.7)

# Create the residual function (returns a residual vector)
objective = create_least_squares_objective(
    power_law,
    shear_rate,
    viscosity,
    normalize=True  # Use relative error
)

# Set up parameters
params = ParameterSet()
params.add("K", value=100, bounds=(1, 1000))
params.add("n", value=-0.5, bounds=(-2, 0))

# Optimize
result = nlsq_optimize(objective, params, use_jax=True)

Optimization with Constraints

from rheojax.core.parameters import ParameterConstraint, ParameterSet

# Add relative constraint: tau1 < tau2
params = ParameterSet()
params.add("tau1", value=1.0, bounds=(0.1, 100))

tau2_constraint = ParameterConstraint(
    type="relative",
    relation="greater_than",
    other_param="tau1"
)
params.add(
    "tau2",
    value=10.0,
    bounds=(0.1, 100),
    constraints=[tau2_constraint]
)

Monitoring Optimization

from rheojax.core.parameters import ParameterOptimizer

# Create optimizer with tracking
optimizer = ParameterOptimizer(
    parameters=params,
    use_jax=True,
    track_history=True
)

optimizer.set_objective(objective)

# Define callback
def callback(iteration, values, obj_value):
    if iteration % 10 == 0:
        print(f"Iteration {iteration}: f={obj_value:.6f}")

optimizer.set_callback(callback)

# Run optimization (integrate with scipy)
# result = optimizer.optimize()

# Get history
history = optimizer.get_history()
for entry in history:
    print(f"Iter {entry['iteration']}: {entry['objective']}")

Performance Tips

  1. Use JAX gradients: Set use_jax=True for faster optimization

  2. Choose appropriate method: L-BFGS-B for bounds, BFGS for unbounded

  3. Scale parameters: Normalize to similar magnitudes (0.1-10 range)

  4. Provide good initial guess: Closer to optimum = faster convergence

  5. Set tolerances: Adjust ftol, xtol, gtol for speed vs accuracy

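import jax.numpy as jnp
from rheojax.core.parameters import ParameterSet

# Good parameter scaling: both parameters are O(1)
params = ParameterSet()
params.add("E", value=1.0, bounds=(0.1, 10))    # E in units of 1000 Pa
params.add("tau", value=1.0, bounds=(0.1, 10))  # tau in s

# In the objective, unscale back to physical units
# (t_exp and stress_exp are the data from the Model Fitting example)
def objective(params_scaled):
    E = params_scaled[0] * 1000  # Convert back to Pa
    tau = params_scaled[1]       # Already in seconds
    predictions = E * jnp.exp(-t_exp / tau)
    return jnp.sum((predictions - stress_exp) ** 2)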
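For tip 4 with the Maxwell model from the Model Fitting example, a starting point can be obtained from a straight-line fit of ln(stress) against t, since ln σ = ln E - t/τ. maxwell_initial_guess is a hypothetical helper, and its fallback when the fitted slope is not negative (use the time span as τ) is an arbitrary choice.

```python
import numpy as np


def maxwell_initial_guess(t, stress):
    """Hypothetical helper: (E0, tau0) for stress = E * exp(-t / tau)."""
    t = np.asarray(t, dtype=float)
    stress = np.asarray(stress, dtype=float)
    keep = stress > 0  # the logarithm needs positive values
    slope, intercept = np.polyfit(t[keep], np.log(stress[keep]), 1)
    E0 = float(np.exp(intercept))  # intercept = ln(E)
    # slope = -1/tau; fall back to the time span if the data do not decay
    tau0 = float(-1.0 / slope) if slope < 0 else float(t.max() - t.min())
    return E0, tau0


t_exp = np.array([0.1, 0.5, 1.0, 2.0, 5.0, 10.0])
stress_exp = np.array([1000.0, 800.0, 650.0, 500.0, 320.0, 200.0])
E0, tau0 = maxwell_initial_guess(t_exp, stress_exp)
print(f"Initial guess: E0={E0:.0f} Pa, tau0={tau0:.2f} s")
```

If you fit the scaled parameters shown above, divide E0 by 1000 before using it as the initial value.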


Model-Data Compatibility

Model-data compatibility checking for rheological models.

This module provides functions to assess whether a given model is appropriate for a dataset based on the underlying physics and data characteristics.

The compatibility checker helps users understand when model failures are expected due to physics mismatch rather than optimization issues.

The compatibility module detects when a rheological model is inappropriate for experimental data, based on the model's underlying physics.

Enums

class rheojax.utils.compatibility.DecayType(*values)[source]

Bases: Enum

Types of relaxation decay behavior:

  • EXPONENTIAL: Simple Maxwell-like exp(-t/tau)

  • POWER_LAW: Power-law t^(-alpha) (gel-like)

  • STRETCHED: Stretched exponential exp(-(t/tau)^beta)

  • MITTAG_LEFFLER: Mittag-Leffler E_alpha(-(t/tau)^alpha) (fractional)

  • MULTI_MODE: Multiple relaxation modes

  • UNKNOWN: Cannot determine

EXPONENTIAL = 'exponential'
POWER_LAW = 'power_law'
STRETCHED = 'stretched'
MITTAG_LEFFLER = 'mittag_leffler'
MULTI_MODE = 'multi_mode'
UNKNOWN = 'unknown'
class rheojax.utils.compatibility.MaterialType(*values)[source]

Bases: Enum

Types of material behavior:

  • SOLID: Solid-like (finite equilibrium modulus)

  • LIQUID: Liquid-like (zero equilibrium modulus, flows)

  • GEL: Gel-like (power-law relaxation)

  • VISCOELASTIC_SOLID: Viscoelastic solid

  • VISCOELASTIC_LIQUID: Viscoelastic liquid

  • UNKNOWN: Cannot determine

SOLID = 'solid'
LIQUID = 'liquid'
GEL = 'gel'
VISCOELASTIC_SOLID = 'viscoelastic_solid'
VISCOELASTIC_LIQUID = 'viscoelastic_liquid'
UNKNOWN = 'unknown'

Functions

rheojax.utils.compatibility.detect_decay_type(t, G_t)[source]

Detect the type of relaxation decay from time-domain data.

Analyzes the decay pattern to determine if it follows exponential, power-law, stretched exponential, or Mittag-Leffler behavior.

Parameters:
  • t (ndarray) – Time array (s)

  • G_t (ndarray) – Relaxation modulus array (Pa)

Returns:

Detected decay type

Return type:

DecayType

Analyzes relaxation modulus data to determine the type of decay pattern. Uses linear regression on log-transformed data to identify exponential, power-law, stretched exponential, or Mittag-Leffler behavior.

rheojax.utils.compatibility.detect_material_type(t=None, G_t=None, omega=None, G_star=None)[source]

Detect the material type from relaxation or oscillation data.

Parameters:
  • t (ndarray | None) – Time array for relaxation data (s)

  • G_t (ndarray | None) – Relaxation modulus (Pa)

  • omega (ndarray | None) – Frequency array for oscillation data (rad/s)

  • G_star (ndarray | None) – Complex modulus array with shape (N, 2) where [:, 0] is G' and [:, 1] is G''

Returns:

Detected material type

Return type:

MaterialType

Classifies material behavior from relaxation or oscillation data. Detects solid-like, liquid-like, gel-like, or viscoelastic behavior based on equilibrium modulus or low-frequency response.

rheojax.utils.compatibility.check_model_compatibility(model, t=None, G_t=None, omega=None, G_star=None, test_mode=None)[source]

Check if a model is compatible with the given data.

This function analyzes the data characteristics and compares them with the model’s underlying physics to assess compatibility.

Parameters:
  • model (BaseModel) – The rheological model to check

  • t (ndarray | None) – Time array for relaxation data (s)

  • G_t (ndarray | None) – Relaxation modulus (Pa)

  • omega (ndarray | None) – Frequency array for oscillation data (rad/s)

  • G_star (ndarray | None) – Complex modulus array with shape (N, 2)

  • test_mode (str | None) – Test mode (‘relaxation’, ‘creep’, ‘oscillation’)

Returns:

Dictionary with compatibility information:

  • 'compatible': bool, whether the model is likely compatible

  • 'confidence': float, confidence level (0-1)

  • 'decay_type': DecayType, detected decay pattern

  • 'material_type': MaterialType, detected material behavior

  • 'warnings': list[str], compatibility warnings

  • 'recommendations': list[str], suggested alternative models

Return type:

dict[str, Any]

Comprehensive compatibility check comparing model physics with data characteristics. Returns detailed compatibility information including warnings and model recommendations.

rheojax.utils.compatibility.format_compatibility_message(compatibility)[source]

Format compatibility check results as a user-friendly message.

Parameters:

compatibility (dict) – Compatibility check results from check_model_compatibility()

Returns:

Formatted message

Return type:

str

Formats compatibility check results as a human-readable message with warnings, detected characteristics, and alternative model recommendations.

Decay Detection Algorithm

The decay type detection uses statistical analysis on log-transformed data:

Exponential Decay Detection

Linear regression on \(\log(G)\) vs \(t\):

\[\log G(t) = \log G_0 - \frac{t}{\tau}\]

High \(R^2\) (> 0.90) indicates exponential decay (Maxwell-like behavior).

Power-Law Decay Detection

Linear regression on \(\log(G)\) vs \(\log(t)\):

\[\log G(t) = \log G_0 - \alpha \log t\]

High \(R^2\) (> 0.90) indicates power-law decay (gel-like behavior).

Stretched Exponential Detection

Linear regression on \(\log(-\log(G/G_0))\) vs \(\log(t)\):

\[\log\left(-\log\frac{G(t)}{G_0}\right) = \beta \log t + \text{const}\]

High \(R^2\) (> 0.90) indicates stretched exponential behavior.
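The regression tests above can be sketched in a few lines of NumPy. classify_decay is hypothetical: it covers only the exponential and power-law linearizations (the stretched-exponential test needs an estimate of G_0 and is left out) and picks whichever straight-line fit has the highest R² above the 0.90 threshold. rheojax's own decision logic may differ.

```python
import numpy as np


def r_squared(x, y):
    """R^2 of a straight-line least-squares fit of y against x."""
    slope, intercept = np.polyfit(x, y, 1)
    ss_res = np.sum((y - (slope * x + intercept)) ** 2)
    ss_tot = np.sum((y - np.mean(y)) ** 2)
    return 0.0 if ss_tot == 0 else float(1.0 - ss_res / ss_tot)


def classify_decay(t, G, threshold=0.90):
    """Hypothetical decay classifier: best linearization with R^2 > threshold."""
    t = np.asarray(t, dtype=float)
    G = np.asarray(G, dtype=float)
    keep = (t > 0) & (G > 0)  # both logarithms need positive values
    t, log_G = t[keep], np.log(G[keep])
    scores = {
        "exponential": r_squared(t, log_G),        # log G vs t
        "power_law": r_squared(np.log(t), log_G),  # log G vs log t
    }
    best = max(scores, key=scores.get)
    return best if scores[best] > threshold else "unknown"


t = np.logspace(-2, 2, 100)
print(classify_decay(t, 1e5 * t ** -0.5))
print(classify_decay(t, 1e5 * np.exp(-t / 1.0)))
```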

Material Type Classification

From Relaxation Data

Material type is determined by the decay ratio:

\[\text{decay ratio} = \frac{\text{mean}(G_{\text{final}})}{\text{mean}(G_{\text{initial}})}\]
  • Solid-like: decay ratio > 0.5 (significant equilibrium modulus)

  • Liquid-like: decay ratio < 0.1 (nearly complete relaxation)

  • Power-law materials: Classified based on decay type regardless of ratio

From Oscillation Data

Material type is determined by low-frequency behavior:

  • Solid: \(G' > G''\) at lowest frequency (elastic dominant)

  • Liquid: \(G'' > G'\) at lowest frequency (viscous dominant)
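Both classification rules can be written down directly. In this hypothetical sketch, the size of the initial and final windows (n_edge=5 points) is an assumption, since the text does not say how many points enter mean(G_initial) and mean(G_final), and ratios between 0.1 and 0.5, which the rules above leave open, are returned as "unknown".

```python
import numpy as np


def classify_relaxation(G_t, n_edge=5):
    """Hypothetical decay-ratio rule: mean of last n_edge / first n_edge points."""
    G_t = np.asarray(G_t, dtype=float)
    ratio = np.mean(G_t[-n_edge:]) / np.mean(G_t[:n_edge])
    if ratio > 0.5:
        return "solid"
    if ratio < 0.1:
        return "liquid"
    return "unknown"  # 0.1 <= ratio <= 0.5 is not covered by the rules


def classify_oscillation(omega, G_star):
    """Hypothetical low-frequency rule; G_star has shape (N, 2) = [G', G'']."""
    G_star = np.asarray(G_star, dtype=float)
    i = int(np.argmin(omega))  # index of the lowest frequency
    return "solid" if G_star[i, 0] > G_star[i, 1] else "liquid"


t = np.logspace(-2, 2, 50)
print(classify_relaxation(8e4 + 2e4 * np.exp(-t / 1.0)))
omega = np.logspace(-2, 2, 50)
print(classify_oscillation(omega, np.column_stack([1e5 * np.ones(50), 1e3 * omega**0.5])))
```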

Examples

Basic Compatibility Check

from rheojax.models.fractional_zener_ss import FractionalZenerSolidSolid
from rheojax.utils.compatibility import (
    check_model_compatibility,
    format_compatibility_message
)
import numpy as np

# Generate exponential decay data (Maxwell-like)
t = np.logspace(-2, 2, 50)
G_t = 1e5 * np.exp(-t / 1.0)

# Check if FZSS is appropriate
model = FractionalZenerSolidSolid()
compatibility = check_model_compatibility(
    model, t=t, G_t=G_t, test_mode='relaxation'
)

# Print human-readable report
print(format_compatibility_message(compatibility))
# Output:
# WARNING: Model may not be appropriate for this data
#   Confidence: 90%
#   Detected decay: exponential
#   Material type: viscoelastic_liquid
#
# Warnings:
#   - FZSS model expects Mittag-Leffler (power-law) relaxation,
#     but data shows exponential decay.
#
# Recommended alternative models:
#   - Maxwell
#   - Zener

Automatic Checking During Fit

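from rheojax.models import Maxwell
import numpy as np

# Relaxation data (Maxwell-like exponential decay)
t = np.logspace(-2, 2, 50)
G_data = 1e5 * np.exp(-t / 1.0)

# Enable automatic compatibility checking
model = Maxwell()
model.fit(
    t, G_data,
    test_mode='relaxation',
    check_compatibility=True  # Warns if model-data mismatch
)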

# If incompatible, warning is logged and enhanced error
# messages provide physics-based explanations

Detecting Decay Type

from rheojax.utils.compatibility import detect_decay_type, DecayType
import numpy as np

# Power-law decay (gel-like)
t = np.logspace(-2, 2, 100)
G_gel = 1e5 * t**(-0.5)

decay_type = detect_decay_type(t, G_gel)
print(decay_type)  # DecayType.POWER_LAW

# Exponential decay (Maxwell-like)
G_maxwell = 1e5 * np.exp(-t / 1.0)
decay_type = detect_decay_type(t, G_maxwell)
print(decay_type)  # DecayType.EXPONENTIAL

Classifying Material Type

from rheojax.utils.compatibility import detect_material_type
import numpy as np

# Solid-like material (finite equilibrium modulus)
t = np.logspace(-2, 2, 50)
G_solid = 5e4 + 5e4 * np.exp(-t / 1.0)  # Ge + Gm*exp(-t/tau)

material_type = detect_material_type(t=t, G_t=G_solid)
print(material_type)  # MaterialType.VISCOELASTIC_SOLID

# Liquid-like material (no equilibrium modulus)
G_liquid = 1e5 * np.exp(-t / 1.0)
material_type = detect_material_type(t=t, G_t=G_liquid)
print(material_type)  # MaterialType.VISCOELASTIC_LIQUID

Oscillation Data Analysis

from rheojax.utils.compatibility import (
    check_model_compatibility,
    detect_material_type
)
from rheojax.models.fractional_maxwell_liquid import FractionalMaxwellLiquid
import numpy as np

# Oscillation data (G', G'')
omega = np.logspace(-2, 2, 50)
G_prime = 1e5 * np.ones(50)  # Constant storage modulus
G_double_prime = 1e3 * omega**0.5  # Frequency-dependent loss

G_star = np.column_stack([G_prime, G_double_prime])

# Detect material type
material_type = detect_material_type(omega=omega, G_star=G_star)
print(material_type)  # MaterialType.SOLID (G' > G'' at low freq)

# Check model compatibility
model = FractionalMaxwellLiquid()
compatibility = check_model_compatibility(
    model,
    omega=omega,
    G_star=G_star,
    test_mode='oscillation'
)

if not compatibility['compatible']:
    print(f"Confidence: {compatibility['confidence']}")
    print(f"Warnings: {compatibility['warnings']}")
    print(f"Try instead: {compatibility['recommendations']}")

Enhanced Error Messages

import numpy as np
from rheojax.models.fractional_zener_ss import FractionalZenerSolidSolid

# Generate exponential data (incompatible with FZSS)
np.random.seed(42)
t = np.logspace(-2, 2, 50)
G_t = 1e5 * np.exp(-t / 1.0) + np.random.normal(0, 1000, size=len(t))

model = FractionalZenerSolidSolid()

try:
    # If the fit fails, the RuntimeError carries an enhanced message
    model.fit(t, G_t, test_mode='relaxation', max_iter=100)
except RuntimeError as e:
    print(e)
    # Output includes:
    # - Original optimization error
    # - Detected decay type and material type
    # - Physics-based explanation of mismatch
    # - Recommended alternative models
    # - Guidance that failures are normal in model comparison

Model Compatibility Rules

Fractional Zener Solid-Solid (FZSS)

  • Expects: Mittag-Leffler or power-law relaxation with finite equilibrium modulus

  • Incompatible with: Exponential decay (use Maxwell/Zener instead)

  • Incompatible with: Liquid-like behavior (use FractionalMaxwellLiquid)

Fractional Maxwell Liquid (FML)

  • Expects: Liquid-like behavior (no equilibrium modulus)

  • Incompatible with: Solid-like materials (use FZSS or FractionalKelvinVoigt)

Fractional Maxwell Gel (FMG)

  • Expects: Power-law relaxation (gel-like)

  • Incompatible with: Exponential decay (use Maxwell instead)

Maxwell Model

  • Expects: Exponential decay

  • Incompatible with: Power-law decay (use FMG or FZSS)

Zener Model

  • Expects: Exponential decay with equilibrium modulus

  • Incompatible with: Power-law decay (use FZSS)

Fractional Kelvin-Voigt

  • Expects: Solid-like behavior

  • Incompatible with: Liquid-like behavior (use FractionalMaxwellLiquid)
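The rules above form a small lookup table. The following is a hypothetical encoding for illustration only, not how check_model_compatibility is implemented: decay and material labels are plain strings (the DecayType/MaterialType enums would serve the same role), and model names are assumed to match class names (FractionalMaxwellGel in particular is an assumed name).

```python
# Hypothetical lookup table: model -> {incompatible feature: suggested alternatives}
INCOMPATIBLE = {
    "FractionalZenerSolidSolid": {
        "exponential": ["Maxwell", "Zener"],
        "liquid": ["FractionalMaxwellLiquid"],
    },
    "FractionalMaxwellLiquid": {
        "solid": ["FractionalZenerSolidSolid", "FractionalKelvinVoigt"],
    },
    "FractionalMaxwellGel": {"exponential": ["Maxwell"]},
    "Maxwell": {"power_law": ["FractionalMaxwellGel", "FractionalZenerSolidSolid"]},
    "Zener": {"power_law": ["FractionalZenerSolidSolid"]},
    "FractionalKelvinVoigt": {"liquid": ["FractionalMaxwellLiquid"]},
}


def check_rules(model_name, decay, material):
    """Return a compatibility summary from the lookup table (sketch)."""
    rules = INCOMPATIBLE.get(model_name, {})
    warnings, recommendations = [], []
    for feature in (decay, material):
        if feature in rules:
            warnings.append(f"{model_name} is not expected to fit {feature} data")
            for alt in rules[feature]:
                if alt not in recommendations:  # keep each suggestion once
                    recommendations.append(alt)
    return {
        "compatible": not warnings,
        "warnings": warnings,
        "recommendations": recommendations,
    }


print(check_rules("FractionalZenerSolidSolid", "exponential", "liquid"))
```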

Performance Considerations

  • Fast detection: < 1 ms for typical datasets (50-100 points)

  • Minimal overhead: Can be enabled during fitting without performance impact

  • Robust to noise: Uses statistical regression with confidence thresholds

  • Automatic test mode detection: Works with relaxation, creep, and oscillation data

Use Cases

  1. Model Selection: Identify appropriate models before fitting

  2. Error Diagnosis: Understand why optimization failed

  3. Automated Pipelines: Filter incompatible model-data combinations

  4. Model Comparison: Expect some models to fail (this is normal!)

  5. Educational: Learn about rheological model physics

See Also

  • Core Module (rheojax.core) - BaseModel integration with check_compatibility parameter

  • ../user_guide/model_selection - Comprehensive model selection guide

  • ../user_guide/getting_started - Basic usage examples

Data Quality Analysis

Data quality and range detection utilities.

This module provides utilities for detecting data characteristics that affect optimization quality, such as very wide frequency ranges (mastercurves).

rheojax.utils.data_quality.check_nan_inf(data, label='data')[source]

Check for NaN/Inf values and return a diagnostic dictionary.

Parameters:
  • data (ndarray) – Array to inspect (any shape; will be flattened internally).

  • label (str) – Human-readable name for this array used in the returned dict.

Returns:

  • 'label': The provided label string.

  • 'n_nan': Number of NaN values.

  • 'n_inf': Number of Inf values (±∞).

  • 'has_issues': True if any NaN or Inf is present.

  • 'fraction_clean': Fraction of finite values in [0, 1].

Return type:

dict[str, object]

Example

>>> arr = np.array([1.0, np.nan, np.inf, 2.0])
>>> result = check_nan_inf(arr, label="G_star")
>>> result['n_nan']
1
>>> result['has_issues']
True
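The documented keys can be reproduced with NumPy's isnan, isinf and isfinite. nan_inf_report is a hypothetical stand-in used only to make the definitions concrete; returning fraction_clean = 1.0 for an empty array is an assumption.

```python
import numpy as np


def nan_inf_report(data, label="data"):
    """Hypothetical stand-in returning the keys documented for check_nan_inf."""
    flat = np.asarray(data).ravel()  # any shape; complex values are supported
    n_nan = int(np.isnan(flat).sum())
    n_inf = int(np.isinf(flat).sum())  # counts both +inf and -inf
    n_total = flat.size
    fraction_clean = float(np.isfinite(flat).sum() / n_total) if n_total else 1.0
    return {
        "label": label,
        "n_nan": n_nan,
        "n_inf": n_inf,
        "has_issues": (n_nan + n_inf) > 0,
        "fraction_clean": fraction_clean,
    }


print(nan_inf_report(np.array([1.0, np.nan, np.inf, 2.0]), label="G_star"))
```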
rheojax.utils.data_quality.check_monotonicity(x, threshold=0.95)[source]

Check whether an array is approximately monotonic.

An array is considered monotonic if at least threshold fraction of consecutive differences share the same sign.

Parameters:
  • x (ndarray) – 1-D array to check.

  • threshold (float) – Minimum fraction of steps that must be consistently increasing or decreasing to classify as monotonic (default: 0.95).

Returns:

  • 'is_monotonic': True if the fraction in the dominant direction reaches threshold.

  • 'direction': 'increasing', 'decreasing', 'constant', or 'mixed'.

  • 'fraction': Fraction of steps in the dominant direction.

Return type:

dict[str, object]

Example

>>> x = np.array([1.0, 2.0, 3.0, 2.9, 4.0])
>>> result = check_monotonicity(x, threshold=0.95)
>>> result['direction']
'increasing'
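A sketch of the rule described above that counts the signs of consecutive differences. monotonicity_report is hypothetical. Its tie handling is an assumption: zero steps count toward the total but not toward either direction, all-zero steps give 'constant', and equal up and down counts give 'mixed'. It also returns the dominant direction even when the array is not monotonic, which matches the example output above.

```python
import numpy as np


def monotonicity_report(x, threshold=0.95):
    """Hypothetical sketch: share of consecutive steps in the dominant direction."""
    diffs = np.diff(np.asarray(x, dtype=float))
    n_steps = diffs.size
    n_up = int(np.sum(diffs > 0))
    n_down = int(np.sum(diffs < 0))
    if n_up == 0 and n_down == 0:
        # No steps at all, or every step is zero
        return {"is_monotonic": True, "direction": "constant", "fraction": 1.0}
    if n_up == n_down:
        return {"is_monotonic": False, "direction": "mixed", "fraction": n_up / n_steps}
    direction = "increasing" if n_up > n_down else "decreasing"
    fraction = max(n_up, n_down) / n_steps
    return {"is_monotonic": fraction >= threshold, "direction": direction, "fraction": fraction}


print(monotonicity_report([1.0, 2.0, 3.0, 2.9, 4.0]))
```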

The data quality module provides intelligent analysis of experimental data characteristics to optimize fitting strategies, especially for wide-range frequency or time-domain data spanning multiple decades.

Functions

rheojax.utils.data_quality.detect_data_range_decades(x)[source]

Detect the range of data in decades (log10 scale).

Parameters:

x (ndarray) – Data array (e.g., frequency, time)

Return type:

float

Returns:

Range in decades (log10(max/min))

Example

>>> freq = np.array([1e-8, 1e-6, 1e-4, 1e4])
>>> decades = detect_data_range_decades(freq)
>>> print(f"{decades:.1f} decades")  # 12.0 decades

Detects the number of decades spanned by the independent variable (frequency, time, or shear rate). This helps identify when multi-start or log-residuals optimization may be beneficial.

rheojax.utils.data_quality.check_wide_frequency_range(x, threshold_decades=8.0, warn=True, recommend_log_residuals=True)[source]

Check if data has a very wide frequency/time range (e.g., mastercurve).

Wide-range data (>8 decades) can cause optimization problems:

  • Optimizer bias toward high-value regions

  • Poor parameter recovery

  • Convergence to local minima

Recommended solutions:

  • Use log-space residuals (use_log_residuals=True)

  • Fit to a subset of the data for initialization

  • Use multi-start optimization

Parameters:
  • x (ndarray) – Independent variable data (frequency, time, etc.)

  • threshold_decades (float) – Threshold for “wide range” warning (default: 8.0)

  • warn (bool) – Whether to emit a warning if range is wide (default: True)

  • recommend_log_residuals (bool) – Whether to recommend log-residuals in warning

Returns:

  • 'is_wide_range': True if range > threshold

  • 'decades': Actual range in decades

  • 'recommendation': Recommended action (or empty string)

Return type:

dict[str, bool | float | str]

Example

>>> omega = np.logspace(-8, 4, 100)  # 12 decades (mastercurve)
>>> result = check_wide_frequency_range(omega)
>>> if result['is_wide_range']:
...     print(f"Wide range detected: {result['decades']:.1f} decades")
...     print(result['recommendation'])

Comprehensive analysis of frequency-domain data to determine whether it spans a wide range (more than threshold_decades, 8 by default) and whether special optimization techniques are needed.
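The range computation and the threshold test reduce to a log10 ratio and a comparison. The sketch below is hypothetical (range_decades and wide_range_report are illustrative names). Ignoring non-positive and non-finite entries, and the wording of the recommendation string, are assumptions.

```python
import numpy as np


def range_decades(x):
    """Hypothetical: log10(max/min) over the positive, finite entries of x."""
    x = np.asarray(x, dtype=float)
    positive = x[np.isfinite(x) & (x > 0)]
    if positive.size == 0:
        return 0.0
    return float(np.log10(positive.max() / positive.min()))


def wide_range_report(x, threshold_decades=8.0):
    """Hypothetical dict with the same keys as check_wide_frequency_range."""
    decades = range_decades(x)
    is_wide = decades > threshold_decades
    recommendation = (
        "Consider log-space residuals (use_log_residuals=True) or multi-start optimization."
        if is_wide
        else ""
    )
    return {"is_wide_range": is_wide, "decades": decades, "recommendation": recommendation}


print(wide_range_report(np.logspace(-8, 4, 100)))  # 12-decade mastercurve
```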

rheojax.utils.data_quality.suggest_optimization_strategy(x, y, test_mode=None)[source]

Suggest optimization strategy based on data characteristics.

Analyzes data range, complexity, and test mode to recommend:

  • Whether to use log-residuals

  • Whether to use multi-start optimization

  • Whether to use subset initialization

Parameters:
  • x (ndarray) – Independent variable (frequency, time, etc.)

  • y (ndarray) – Dependent variable (modulus, stress, etc.)

  • test_mode (str | None) – Test mode (‘oscillation’, ‘relaxation’, ‘creep’)

Returns:

  • 'use_log_residuals': Recommended for wide ranges

  • 'use_multi_start': Recommended for complex landscapes

  • 'use_subset_init': Recommended for very wide ranges

  • 'rationale': Explanation of recommendations

Return type:

dict[str, bool | str | float]

Example

>>> omega = np.logspace(-8, 4, 100)
>>> G_star = ...  # Complex modulus data
>>> strategy = suggest_optimization_strategy(omega, G_star, 'oscillation')
>>> print(strategy['rationale'])

Provides intelligent recommendations for optimization strategy based on data characteristics (range, domain, test mode). Returns configuration for:

  • Use of log-residuals (for wide-range data)

  • Multi-start optimization (for complex landscapes)

  • Recommended number of random starts

Wide-Range Data Challenges

When experimental data spans many decades (e.g., frequency from 0.01 to 1000 Hz), standard least-squares fitting can encounter problems:

Problem: Linear residuals \(\sum (y_{\text{pred}} - y_{\text{exp}})^2\) are dominated by high-magnitude points, causing poor fits at low values.

Example: For \(G'\) spanning 100 Pa to 1e6 Pa:

  • High-frequency point (1e6 Pa) with a 100% error: squared residual ~ 1e12

  • Low-frequency point (100 Pa) with a 100% error: squared residual ~ 1e4

  • Optimizer focuses on high-frequency region, ignores low-frequency
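A quick numeric check of this imbalance, using the same 10% relative error at 100 Pa and at 1e6 Pa: the linear squared residuals differ by a factor of 1e8, while the log10 squared residuals are identical. The array names are illustrative.

```python
import numpy as np

# Two points four decades apart, both predicted 10% too high
G_exp = np.array([1e2, 1e6])  # Pa
G_pred = 1.10 * G_exp

linear_sq = (G_pred - G_exp) ** 2                   # about 1e2 and 1e10 Pa^2
log_sq = (np.log10(G_pred) - np.log10(G_exp)) ** 2  # equal for both points

print("share of linear cost:", linear_sq / linear_sq.sum())
print("share of log cost:   ", log_sq / log_sq.sum())
```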

Solutions:

  1. Log-Residuals: Minimize \(\sum (\log y_{\text{pred}} - \log y_{\text{exp}})^2\)

    • Balances contributions across decades

    • Approximately equivalent to minimizing relative error (exact in the small-error limit)

    • Automatically enabled for data > 4 decades

  2. Multi-Start Optimization: Run multiple optimizations from random initial points

    • Escapes local minima

    • Finds global optimum more reliably

    • Recommended for complex model landscapes
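The multi-start idea can be sketched with SciPy's L-BFGS-B on a one-dimensional function that has several local minima. This is an illustration, not rheojax's multi-start implementation. For reproducibility it uses evenly spaced starting points; the random variant described above would draw the starts uniformly within the bounds instead.

```python
import numpy as np
from scipy.optimize import minimize


def objective(x):
    # 1-D test function with several local minima on [-5, 5]
    return np.sin(3.0 * x[0]) + 0.1 * x[0] ** 2


def multi_start(fun, bounds, starts):
    """Run L-BFGS-B from every start point and keep the lowest minimum."""
    best = None
    for x0 in starts:
        res = minimize(fun, np.atleast_1d(x0), method="L-BFGS-B", bounds=bounds)
        if best is None or res.fun < best.fun:
            best = res
    return best


bounds = [(-5.0, 5.0)]
single = minimize(objective, np.array([3.0]), method="L-BFGS-B", bounds=bounds)
best = multi_start(objective, bounds, np.linspace(-5.0, 5.0, 21))
print(f"single start: x={single.x[0]:.3f}, f={single.fun:.3f}")
print(f"multi-start:  x={best.x[0]:.3f}, f={best.fun:.3f}")
```

The single run from x = 3 settles in the nearest local minimum. The multi-start run reaches the global minimum near x ≈ -0.51 because at least one start lies in its basin.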

Detection and Recommendations

The module automatically detects when to use these strategies:

from rheojax.utils.data_quality import suggest_optimization_strategy
import numpy as np

# Wide-range oscillation data
omega = np.logspace(-2, 3, 100)  # 5 decades
G_star = ...  # Complex modulus data

strategy = suggest_optimization_strategy(
    x=omega,
    test_mode='oscillation'
)

print(f"Use log-residuals: {strategy['use_log_residuals']}")  # True
print(f"Use multi-start: {strategy['use_multi_start']}")      # True
print(f"Number of starts: {strategy['n_starts']}")            # 10

# Apply recommendations to model fitting
from rheojax.models import Maxwell

model = Maxwell()
model.fit(
    omega, G_star,
    test_mode='oscillation',
    use_log_residuals=strategy['use_log_residuals'],
    multi_start=strategy['use_multi_start'],
    n_starts=strategy['n_starts']
)

Examples

Detect Data Range

from rheojax.utils.data_quality import detect_data_range_decades
import numpy as np

# Narrow range (2 decades)
freq_narrow = np.logspace(0, 2, 50)
decades = detect_data_range_decades(freq_narrow)
print(f"Range: {decades:.1f} decades")  # 2.0 decades

# Wide range (5 decades)
freq_wide = np.logspace(-2, 3, 100)
decades = detect_data_range_decades(freq_wide)
print(f"Range: {decades:.1f} decades")  # 5.0 decades

Check Frequency Range

from rheojax.utils.data_quality import check_wide_frequency_range
import numpy as np

omega = np.logspace(-1, 3, 80)  # 4 decades
result = check_wide_frequency_range(omega)

print(f"Is wide range: {result['is_wide_range']}")      # True
print(f"Decades: {result['decades']:.2f}")              # 4.0
print(f"Recommend log: {result['use_log_residuals']}")  # True
print(f"Recommend multi-start: {result['use_multi_start']}")  # False

# Very wide range triggers multi-start
omega_very_wide = np.logspace(-2, 4, 100)  # 6 decades
result = check_wide_frequency_range(omega_very_wide)
print(f"Multi-start: {result['use_multi_start']}")  # True
print(f"Starts: {result['n_starts']}")              # 15

Get Complete Strategy

from rheojax.utils.data_quality import suggest_optimization_strategy
import numpy as np

# Time-domain relaxation data
time = np.logspace(-3, 2, 100)  # 5 decades
G_t = ...  # Relaxation modulus

strategy = suggest_optimization_strategy(
    x=time,
    test_mode='relaxation'
)

# Use strategy with BaseModel
from rheojax.models.fractional_maxwell_liquid import FractionalMaxwellLiquid

model = FractionalMaxwellLiquid()
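model.fit(
    time, G_t,
    test_mode='relaxation',
    # Pass recommendations explicitly: the strategy dict also holds non-fit
    # keys such as 'rationale', and 'use_multi_start' maps to multi_start
    use_log_residuals=strategy['use_log_residuals'],
    multi_start=strategy['use_multi_start'],
    n_starts=strategy['n_starts'],
)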

Integration with BaseModel

The BaseModel._fit() method automatically calls suggest_optimization_strategy() when no explicit optimization configuration is provided, so wide-range data gets log-residuals and, where needed, multi-start optimization without manual tuning.

Automatic behavior:

  • Data < 3 decades: Standard least-squares

  • Data 3-5 decades: Log-residuals enabled

  • Data > 5 decades: Log-residuals + multi-start (10-20 starts)
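
The decade bands above can be mirrored by a short heuristic. This is a hypothetical sketch, not the library's implementation: data_range_decades() and strategy_sketch() are illustrative names, and treating exactly 3 decades as "wide" and exactly 5 as "not very wide" is an assumption about boundary handling.

```python
# Hypothetical mirror of the documented decade bands, not RheoJAX code
import numpy as np


def data_range_decades(x):
    # Number of decades spanned by the strictly positive values of x
    x = np.asarray(x, dtype=float)
    x = x[x > 0]
    return float(np.log10(x.max() / x.min()))


def strategy_sketch(x):
    # < 3 decades: standard; 3-5: log-residuals; > 5: log-residuals + multi-start
    decades = data_range_decades(x)
    return {
        "decades": decades,
        "use_log_residuals": decades >= 3.0,
        "use_multi_start": decades > 5.0,
    }


for grid in (np.logspace(0, 2, 50), np.logspace(-1, 3, 80), np.logspace(-2, 4, 100)):
    print(strategy_sketch(grid))
```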

Performance Considerations

  • Detection overhead: < 0.1 ms (negligible)

  • Log-residuals: Same computational cost as linear

  • Multi-start: roughly N times slower (N = number of starts), but more robust

  • Memory: Minimal additional memory for multi-start

Use Cases

  1. Wide-range oscillation data: Master curves spanning 8+ decades

  2. Time-temperature superposition: Combined data across temperatures

  3. Multi-technique fitting: Combining relaxation + oscillation data

  4. Fractional models: Complex parameter landscapes benefit from multi-start

  5. Automated pipelines: Robust fitting without manual tuning

Modulus Conversion (DMTA Support)

Modulus conversion utilities for DMTA/DMA data analysis.

This module provides functions to convert between shear modulus G* (measured by rotational rheometers) and Young’s modulus E* (measured by DMTA/DMA instruments in tension, bending, or compression).

The fundamental relationship from isotropic linear elasticity:

\[E^*(\omega) = 2(1 + \nu) \, G^*(\omega)\]

where \(\nu\) is the Poisson’s ratio of the material.

Example

>>> import numpy as np
>>> from rheojax.utils.modulus_conversion import convert_modulus
>>> from rheojax.core.test_modes import DeformationMode
>>>
>>> # Convert E* (DMTA) to G* (shear) for rubber (ν = 0.5 -> factor = 3)
>>> E_star = np.array([1e9 + 1e8j, 2e9 + 2e8j])  # E* in Pa
>>> G_star = convert_modulus(E_star, DeformationMode.TENSION, DeformationMode.SHEAR, poisson_ratio=0.5)
>>>
>>> # Use preset materials
>>> from rheojax.utils.modulus_conversion import POISSON_PRESETS
>>> nu = POISSON_PRESETS["glassy_polymer"]  # 0.35

rheojax.utils.modulus_conversion.convert_modulus(data, from_mode, to_mode, poisson_ratio=0.5)[source]

Convert modulus data between shear (G*) and tensile (E*) representations.

Applies the isotropic elasticity relationship E* = 2(1+ν) * G*. Works with both real and complex arrays, and both NumPy and JAX arrays.

Parameters:
  • data (ndarray | Any) – Modulus data array (real or complex). Can be NumPy or JAX array.

  • from_mode (DeformationMode | str) – Source deformation mode (e.g., “tension”, “shear”)

  • to_mode (DeformationMode | str) – Target deformation mode (e.g., “shear”, “tension”)

  • poisson_ratio (float) – Poisson’s ratio of the material (default: 0.5 for rubber)

Return type:

ndarray | Any

Returns:

Converted modulus data in the same array type as input

Raises:

ValueError – If Poisson’s ratio is out of bounds or modes are invalid

Example

>>> E_star = np.array([1e9 + 1e8j, 2e9 + 2e8j])  # E* in Pa
>>> G_star = convert_modulus(E_star, "tension", "shear", poisson_ratio=0.5)
>>> # G_star ≈ E_star / 3 for rubber

rheojax.utils.modulus_conversion.convert_rheodata(data, to_mode, poisson_ratio=0.5)[source]

Convert RheoData between shear and tensile modulus representations.

Creates a new RheoData with converted y-values and updated metadata. The original RheoData is not modified.

Parameters:
  • data (RheoData) – Source RheoData object

  • to_mode (DeformationMode | str) – Target deformation mode

  • poisson_ratio (float) – Poisson’s ratio of the material

Return type:

RheoData

Returns:

New RheoData with converted modulus values and updated metadata

Example

>>> from rheojax.core.data import RheoData
>>> # DMTA data in tension
>>> dmta_data = RheoData(x=omega, y=E_star,
...     metadata={"deformation_mode": "tension"})
>>> # Convert to shear for model fitting
>>> shear_data = convert_rheodata(dmta_data, "shear", poisson_ratio=0.5)

The modulus conversion module provides utilities for converting between shear modulus G* and Young’s modulus E* for DMTA/DMA data analysis.

The fundamental relationship from isotropic linear elasticity:

\[E^*(\omega) = 2(1 + \nu) \, G^*(\omega)\]

where \(\nu\) is the Poisson’s ratio of the material.

Functions

rheojax.utils.modulus_conversion.convert_modulus(data, from_mode, to_mode, poisson_ratio=0.5)[source]

Array-level conversion between E* and G* using Poisson’s ratio.

Converts complex modulus arrays in either direction:

  • tension → shear: \(G^* = E^* / [2(1+\nu)]\)

  • shear → tension: \(E^* = 2(1+\nu) \, G^*\)
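
Both directions reduce to a scalar factor. A minimal NumPy sketch, assuming the isotropic bounds -1 < ν ≤ 0.5 for validation (the library's exact bounds check is not documented here); tension_to_shear() and shear_to_tension() are hypothetical helpers, not part of rheojax.utils.modulus_conversion:

```python
# Hypothetical helpers, not the rheojax.utils.modulus_conversion API
import numpy as np


def _check_nu(nu):
    # Assumed validity range for an isotropic material: -1 < nu <= 0.5
    if not -1.0 < nu <= 0.5:
        raise ValueError(f"Poisson's ratio {nu} outside (-1, 0.5]")


def tension_to_shear(E_star, nu=0.5):
    # G* = E* / [2 (1 + nu)]; works for real or complex arrays
    _check_nu(nu)
    return np.asarray(E_star) / (2.0 * (1.0 + nu))


def shear_to_tension(G_star, nu=0.5):
    # E* = 2 (1 + nu) G*
    _check_nu(nu)
    return 2.0 * (1.0 + nu) * np.asarray(G_star)


E_star = np.array([1e9 + 1e8j, 2e9 + 2e8j])  # E* in Pa, as in the docstring example
G_star = tension_to_shear(E_star, nu=0.5)    # divides by 3 for an incompressible material
```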

rheojax.utils.modulus_conversion.convert_rheodata(data, to_mode, poisson_ratio=0.5)[source]

RheoData-level conversion with automatic metadata update (units, deformation_mode).

Material Presets

rheojax.utils.modulus_conversion.POISSON_PRESETS

Dictionary of common Poisson’s ratio values by material class:

Material              \(\nu\)    Notes
--------------------  -------   -----------------------
rubber / elastomer    0.50      Incompressible (E = 3G)
hydrogel              0.50      Water-swollen networks
semicrystalline       0.40      PE, PP, PA, PET
thermoset             0.38      Epoxies, polyesters
glassy_polymer        0.35      PS, PMMA, PC below Tg
metal                 0.30      Steel, aluminum
foam                  0.30      Open-cell foams
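
To see how much the choice of ν matters, the sketch below computes the factor E*/G* = 2(1 + ν) for each row of the table. The dictionary is a hand-copied, hypothetical stand-in for POISSON_PRESETS, and its keys may not match the real preset names:

```python
# Hand-copied from the table above; hypothetical keys, not necessarily POISSON_PRESETS keys
table_nu = {
    "rubber / elastomer": 0.50,
    "hydrogel": 0.50,
    "semicrystalline": 0.40,
    "thermoset": 0.38,
    "glassy_polymer": 0.35,
    "metal": 0.30,
    "foam": 0.30,
}

# E* / G* = 2 (1 + nu) for an isotropic material
factors = {name: 2.0 * (1.0 + nu) for name, nu in table_nu.items()}
for name, factor in factors.items():
    print(f"{name:20s} E*/G* = {factor:.2f}")
```

Using the rubber value (factor 3.0) for a glassy polymer (factor 2.7) would underestimate G* by about 10%.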

Examples

Array-Level Conversion

import numpy as np
from rheojax.utils.modulus_conversion import convert_modulus, POISSON_PRESETS

# DMTA data: E* from tensile DMA
omega = np.logspace(-2, 2, 50)
E_prime = 3e9 * np.ones(50)       # E' (Pa)
E_double_prime = 1e8 * omega**0.3  # E'' (Pa)
E_star = E_prime + 1j * E_double_prime

# Convert to G* for rubber (ν = 0.5, factor = 3)
G_star = convert_modulus(E_star, "tension", "shear", poisson_ratio=0.5)
# G* = E* / 3 for rubber

# Use preset Poisson's ratio
nu = POISSON_PRESETS["glassy_polymer"]  # 0.35
G_star = convert_modulus(E_star, "tension", "shear", poisson_ratio=nu)

Fitting DMTA Data Directly

from rheojax.models import Maxwell

# All models accept deformation_mode in fit()/predict()
model = Maxwell()
model.fit(
    omega, E_star,
    test_mode='oscillation',
    deformation_mode='tension',
    poisson_ratio=0.5,
)

# predict() returns E* when fitted with tensile deformation_mode
E_pred = model.predict(omega, test_mode='oscillation')

Device Utilities

GPU detection and warning utilities for RheoJAX (System CUDA version).

This module provides utilities to detect GPU availability and warn users when they have GPU hardware available but are using CPU-only JAX.

rheojax.utils.device.get_system_cuda_version()[source]

Detect system CUDA version from nvcc.

Returns:

Tuple of (full_version, major_version) or (None, None) if not found. Example: (“12.6”, 12) or (“13.0”, 13)

Return type:

tuple[str | None, int | None]

rheojax.utils.device.get_gpu_info()[source]

Detect GPU name and SM version.

Returns:

Tuple of (gpu_name, sm_version) or (None, None) if not found. Example: (“NVIDIA GeForce RTX 4090”, 8.9)

Return type:

tuple[str | None, float | None]

Get recommended JAX package based on system CUDA.

Returns:

Package name like “jax[cuda12-local]” or “jax[cuda13-local]”, or None if no compatible setup found.

Return type:

str | None
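
get_system_cuda_version() reads the CUDA version from nvcc. A minimal sketch of that idea together with the package recommendation, assuming the usual nvcc output line "Cuda compilation tools, release X.Y, ..." and that only CUDA 12 and 13 map to a package (as in the examples above). All function names here are hypothetical:

```python
# Hypothetical sketch, not RheoJAX's device module
import re
import shutil
import subprocess


def parse_nvcc_version(text):
    # Look for "release X.Y" in `nvcc --version` output
    match = re.search(r"release (\d+)\.(\d+)", text)
    if match is None:
        return None, None
    major, minor = match.group(1), match.group(2)
    return f"{major}.{minor}", int(major)


def recommend_jax_package(cuda_major):
    # Only the CUDA majors shown in the documented examples are mapped
    if cuda_major in (12, 13):
        return f"jax[cuda{cuda_major}-local]"
    return None


def system_cuda_version_sketch():
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None, None  # CUDA toolkit not on PATH
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True, check=False)
    return parse_nvcc_version(result.stdout)


sample = "Cuda compilation tools, release 12.6, V12.6.77"
version, major = parse_nvcc_version(sample)
print(version, major, recommend_jax_package(major))
```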

rheojax.utils.device.check_plugin_conflicts()[source]

Check for known JAX CUDA plugin conflicts.

Returns:

List of issue descriptions (empty = no issues).

Return type:

list[str]

rheojax.utils.device.check_gpu_availability(warn=True)[source]

Check if GPU is available but not being used by JAX.

Prints a helpful warning if GPU hardware and system CUDA are detected but JAX is running in CPU-only mode.

Parameters:

warn (bool) – If True, print warning when GPU available but not used. Default is True.

Returns:

True if GPU is being used by JAX, False otherwise.

Return type:

bool

Examples

Call this at package initialization or in CLI entry points:

>>> from rheojax.utils.device import check_gpu_availability
>>> check_gpu_availability()  # Prints warning if GPU detected but not used

rheojax.utils.device.get_device_info()[source]

Get comprehensive device information.

Returns:

Dictionary with:

  • jax_version: JAX version string

  • jax_backend: Current backend (cpu, gpu)

  • devices: List of device strings

  • gpu_count: Number of GPU devices

  • using_gpu: Boolean

  • gpu_hardware: GPU name

  • gpu_sm_version: SM version (float)

  • system_cuda_version: System CUDA version string

  • system_cuda_major: System CUDA major version (int)

  • recommended_package: Recommended JAX package

Return type:

dict

rheojax.utils.device.get_gpu_memory_info()[source]

Get GPU memory information using nvidia-smi.

Return type:

dict

Returns:

  • dict – Dictionary with keys:

    • ‘total_mb’: Total GPU memory in MB

    • ‘used_mb’: Used GPU memory in MB

    • ‘free_mb’: Free GPU memory in MB

    • ‘utilization_percent’: GPU utilization percentage

  • Returns empty dict if nvidia-smi is not available.

Examples

>>> from rheojax.utils.device import get_gpu_memory_info
>>> info = get_gpu_memory_info()
>>> if info:
...     print(f"GPU Memory: {info['used_mb']}/{info['total_mb']} MB")

rheojax.utils.device.print_device_summary()[source]

Print a summary of available compute devices.

Displays:

  • JAX version

  • Available devices (CPU/GPU)

  • GPU memory info (if available)

  • Warning if GPU hardware is detected but not being used

Return type:

None

Examples

>>> from rheojax.utils.device import print_device_summary
>>> print_device_summary()
JAX Device Summary
==================
JAX version: 0.8.0
Devices: [CpuDevice(id=0)]
Using: CPU-only

The device module provides GPU detection and diagnostic utilities for RheoJAX. These functions help users identify available compute resources and configure JAX for optimal performance.

Functions

rheojax.utils.device.check_gpu_availability(warn=True)[source]

Checks if GPU hardware and CUDA are available but JAX is running CPU-only. Prints a helpful message with installation instructions if GPU is not being used. Also checks for plugin conflicts even when GPU is working.

rheojax.utils.device.check_plugin_conflicts()[source]

Detects known JAX CUDA plugin issues: dual cuda12/cuda13 plugins installed simultaneously (causes PJRT registration conflicts) or plugin/jaxlib version mismatches (causes silent CPU fallback). Returns a list of issue descriptions.
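
The dual-plugin check can be approximated with the standard library. A hypothetical sketch that covers only the cuda12/cuda13 coexistence case (not the version-mismatch check), assuming the plugin distributions use jax-cuda12 / jax-cuda13 name prefixes, as in the uninstall command under Plugin Conflict Detection:

```python
# Hypothetical sketch of a dual-plugin check, not check_plugin_conflicts() itself
from importlib import metadata


def installed_distribution_names():
    # Normalized names of installed distributions (lowercase, '-' instead of '_')
    names = set()
    for dist in metadata.distributions():
        meta = dist.metadata
        name = (meta.get("Name") if meta is not None else None) or ""
        names.add(name.lower().replace("_", "-"))
    return names


def find_dual_cuda_plugins(installed=None):
    # Flag cuda12 and cuda13 JAX plugins installed side by side
    if installed is None:
        installed = installed_distribution_names()
    cuda12 = sorted(n for n in installed if n.startswith("jax-cuda12"))
    cuda13 = sorted(n for n in installed if n.startswith("jax-cuda13"))
    if cuda12 and cuda13:
        return [f"Both cuda12 ({', '.join(cuda12)}) and cuda13 ({', '.join(cuda13)}) "
                "JAX plugins are installed"]
    return []


print(find_dual_cuda_plugins())
```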

rheojax.utils.device.get_device_info()[source]

Returns comprehensive device information including JAX version, backend, device list, GPU hardware name, SM version, system CUDA version, recommended package, and any plugin issues.

rheojax.utils.device.get_gpu_memory_info()[source]

Queries nvidia-smi for GPU memory utilization (total, used, free, utilization %). Returns empty dict on systems without NVIDIA GPUs.
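
A sketch of the same kind of query using nvidia-smi's CSV output. The function names are hypothetical and the parsing is an assumption about how the library might do it. Note that nvidia-smi reports memory in MiB, which the sketch stores under the documented *_mb keys:

```python
# Hypothetical sketch, not get_gpu_memory_info() itself
import shutil
import subprocess

QUERY = "memory.total,memory.used,memory.free,utilization.gpu"


def parse_memory_line(line):
    # One CSV row from --format=csv,noheader,nounits (memory in MiB, utilization in %)
    total, used, free, util = (float(v) for v in line.split(","))
    return {"total_mb": total, "used_mb": used, "free_mb": free,
            "utilization_percent": util}


def gpu_memory_sketch():
    exe = shutil.which("nvidia-smi")
    if exe is None:
        return {}  # no NVIDIA driver tools on this system
    result = subprocess.run(
        [exe, f"--query-gpu={QUERY}", "--format=csv,noheader,nounits"],
        capture_output=True, text=True, check=False,
    )
    lines = result.stdout.strip().splitlines()
    if result.returncode != 0 or not lines:
        return {}
    try:
        return parse_memory_line(lines[0])  # first GPU only
    except ValueError:
        return {}  # e.g. a field reported as "[N/A]"


print(gpu_memory_sketch())
```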

rheojax.utils.device.print_device_summary()[source]

Prints a formatted summary of the compute environment (JAX version, devices, GPU memory). Useful at the start of scripts and notebooks.

Examples

Quick Environment Check

from rheojax.utils import print_device_summary, check_gpu_availability

# Print full device summary at script start
print_device_summary()
# JAX Device Summary
# ==================
# JAX version: 0.8.3
# Devices: [CpuDevice(id=0)]
# Using: CPU-only

# Programmatic check
using_gpu = check_gpu_availability(warn=False)
if not using_gpu:
    print("Running on CPU — consider installing JAX with CUDA support")

Detailed Device Info

from rheojax.utils import get_device_info, get_gpu_memory_info

info = get_device_info()
print(f"JAX {info['jax_version']} on {info['jax_backend']}")
print(f"GPU hardware: {info['gpu_hardware'] or 'None'}")
print(f"System CUDA: {info['system_cuda_version'] or 'Not found'}")

# Check for plugin issues
if info['plugin_issues']:
    for issue in info['plugin_issues']:
        print(f"WARNING: {issue}")

# GPU memory (if available)
mem = get_gpu_memory_info()
if mem:
    print(f"GPU Memory: {mem['used_mb']}/{mem['total_mb']} MB")

Plugin Conflict Detection

from rheojax.utils import check_plugin_conflicts

issues = check_plugin_conflicts()
if issues:
    for issue in issues:
        print(f"Issue: {issue}")
    print("Fix: pip uninstall -y jax jaxlib "
          "jax-cuda12-plugin jax-cuda12-pjrt "
          "jax-cuda13-plugin jax-cuda13-pjrt")
    print("Then: make install-jax-gpu")
else:
    print("No plugin conflicts detected")

See Also

  • Development Status & Performance - Technology stack and GPU requirements

  • make install-jax-gpu - Automated GPU JAX installation

  • make gpu-diagnose - Diagnose common GPU issues (plugin conflicts, version mismatches)

  • make gpu-check - Verify GPU backend, devices, and SVD computation

Fit Quality Metrics

Fit quality metrics for rheological model evaluation.

This module provides functions to compute standard statistical metrics for evaluating model fit quality.

rheojax.utils.metrics.compute_fit_quality(y_true, y_pred)[source]

Compute R² and RMSE fit quality metrics.

Parameters:
  • y_true (numpy.typing.ArrayLike) – Observed (ground truth) values.

  • y_pred (numpy.typing.ArrayLike) – Predicted values from the model.

Returns:

Dictionary containing:

  • ‘R2’: Coefficient of determination (R²)

  • ‘RMSE’: Root mean squared error

  • ‘nrmse’: Normalized RMSE (RMSE / range of y_true)

Return type:

dict[str, float]

Examples

>>> y_true = [1.0, 2.0, 3.0, 4.0]
>>> y_pred = [1.1, 1.9, 3.1, 3.9]
>>> metrics = compute_fit_quality(y_true, y_pred)
>>> metrics['R2'] > 0.99
True

rheojax.utils.metrics.r2_complex(y_true, y_pred)[source]

Compute R² for complex-valued data using magnitudes.

Parameters:
  • y_true (numpy.typing.ArrayLike) – Observed complex values.

  • y_pred (numpy.typing.ArrayLike) – Predicted complex values.

Returns:

Coefficient of determination computed on magnitudes |G*|.

Return type:

float

Note

This metric evaluates magnitude fit only. Phase errors (e.g., correct |G*| but wrong tan(δ)) are not captured. For phase-sensitive evaluation, use r2_complex_components() which averages R² over the real and imaginary components independently.

rheojax.utils.metrics.r2_complex_components(y_true, y_pred)[source]

Compute R² for complex data using separate real and imaginary components.

Returns the arithmetic mean of R²(real) and R²(imag), capturing both magnitude and phase accuracy. A model that fits |G*| perfectly but has the wrong phase angle will score lower here than with r2_complex().

Parameters:
  • y_true (numpy.typing.ArrayLike) – Observed complex values (e.g., G* = G’ + i·G’’).

  • y_pred (numpy.typing.ArrayLike) – Predicted complex values.

Returns:

Average R² across real (G’) and imaginary (G’’) components.

Return type:

float

Examples

>>> import numpy as np
>>> omega = np.logspace(-2, 2, 50)
>>> G_star = omega * 1j  # Pure viscous
>>> r2_complex_components(G_star, G_star)
1.0

The metrics module provides standard statistical measures for evaluating model fit quality after NLSQ optimization or Bayesian inference.

Functions

rheojax.utils.metrics.compute_fit_quality(y_true, y_pred)[source]

Computes R², RMSE, and normalized RMSE for real-valued data. Handles multi-dimensional arrays (flattened automatically).
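
The three metrics follow standard definitions. A NumPy sketch, assuming R² = 1 - SS_res/SS_tot and NRMSE = RMSE / (max(y_true) - min(y_true)) as documented; fit_quality_sketch() is a hypothetical name, not the library function:

```python
# Hypothetical re-implementation of the documented metrics, not RheoJAX code
import numpy as np


def fit_quality_sketch(y_true, y_pred):
    # Flatten so multi-dimensional inputs are treated as one set of points
    y_true = np.ravel(np.asarray(y_true, dtype=float))
    y_pred = np.ravel(np.asarray(y_pred, dtype=float))
    residuals = y_true - y_pred
    ss_res = np.sum(residuals**2)                   # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares
    rmse = np.sqrt(np.mean(residuals**2))
    return {
        "R2": float(1.0 - ss_res / ss_tot),
        "RMSE": float(rmse),
        "nrmse": float(rmse / (y_true.max() - y_true.min())),
    }


print(fit_quality_sketch([1.0, 2.0, 3.0, 4.0], [1.1, 1.9, 3.1, 3.9]))
```

For the six-point example under Real-Valued Fit Assessment, these definitions give R² ≈ 0.9992, RMSE ≈ 7.9, and NRMSE ≈ 0.0099.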

rheojax.utils.metrics.r2_complex(y_true, y_pred)[source]

Computes R² for complex-valued modulus data (G*, E*) using magnitude comparison. Useful for oscillation fits where predictions are complex.
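
The difference between magnitude-based and component-based R² shows up when a prediction has the right magnitude but the wrong phase. A hypothetical NumPy sketch of both documented definitions (r2_magnitude() and r2_components() are illustrative names, not the library functions), applied to a single-mode Maxwell G* rotated by 0.3 rad:

```python
# Hypothetical sketch of the two documented R² variants, not RheoJAX code
import numpy as np


def r2(y_true, y_pred):
    # Standard coefficient of determination for real-valued data
    y_true = np.ravel(np.asarray(y_true, dtype=float))
    y_pred = np.ravel(np.asarray(y_pred, dtype=float))
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


def r2_magnitude(y_true, y_pred):
    # R² on |G*| only (the behavior documented for r2_complex)
    return r2(np.abs(y_true), np.abs(y_pred))


def r2_components(y_true, y_pred):
    # Mean of R²(G') and R²(G'') (the behavior documented for r2_complex_components)
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return 0.5 * (r2(y_true.real, y_pred.real) + r2(y_true.imag, y_pred.imag))


omega = np.logspace(-2, 2, 50)
G_true = 1e3 * (1j * omega) / (1.0 + 1j * omega)  # Maxwell G*, G0 = 1e3 Pa, tau = 1 s
G_pred = G_true * np.exp(0.3j)                      # same |G*|, phase shifted by 0.3 rad

print(r2_magnitude(G_true, G_pred))   # ~1.0: the phase error is invisible
print(r2_components(G_true, G_pred))  # well below 1: the phase error is penalized
```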

Examples

Real-Valued Fit Assessment

import numpy as np
from rheojax.utils import compute_fit_quality

# After model fitting
y_data = np.array([1000, 800, 650, 500, 320, 200])
y_pred = np.array([1010, 790, 660, 495, 325, 195])

metrics = compute_fit_quality(y_data, y_pred)
print(f"R² = {metrics['R2']:.4f}")     # 0.9992
print(f"RMSE = {metrics['RMSE']:.1f}")  # 7.9 Pa
print(f"NRMSE = {metrics['nrmse']:.4f}")  # 0.0099

Complex Modulus Fit Assessment

import numpy as np
from rheojax.utils import r2_complex

# Oscillation data (G' + iG'')
G_star_data = np.array([1e5 + 1e3j, 5e4 + 5e3j, 1e4 + 2e4j])
G_star_pred = np.array([9.9e4 + 1.1e3j, 5.1e4 + 4.9e3j, 1.05e4 + 2.1e4j])

r2 = r2_complex(G_star_data, G_star_pred)
print(f"R² (magnitude) = {r2:.4f}")

EPM Kernels

Kernels for Elasto-Plastic Models (EPM).

This module implements the core physics kernels for scalar EPM simulations using JAX. It includes the FFT-based elastic propagator for stress redistribution, logic for plastic events (dual-mode: hard/smooth), and the full time-stepping kernel.

rheojax.utils.epm_kernels.make_propagator_q(L_x, L_y, shear_modulus=1.0)[source]

Create the quadrupolar Eshelby propagator in Fourier space.

\[G(\mathbf{q}) = -\frac{2\mu\,(q_x q_y)^2}{|\mathbf{q}|^4} \quad \text{for } \mathbf{q} \neq 0, \qquad G(0) = 0\]

Reference: Talamali et al. (2011) Phys. Rev. E 84, 016115. Eq. (6): G(q) = -2*mu*(qx*qy)^2/|q|^4 — quadrupolar Eshelby propagator.

Parameters:
  • L_x (int) – Lattice size in x.

  • L_y (int) – Lattice size in y.

  • shear_modulus (float) – Shear modulus mu.

Return type:

Array

Returns:

2D array of the propagator in Fourier space (L_x, L_y // 2 + 1).

rheojax.utils.epm_kernels.solve_elastic_propagator(plastic_strain_rate, propagator_q)[source]

Solve for the elastic stress redistribution rate using FFT.

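Calculates sigma_dot_el = G * epsilon_dot_pl, where * is a periodic convolution over the lattice, evaluated as a pointwise product in Fourier space.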

Parameters:
  • plastic_strain_rate (Array) – 2D array of plastic strain rate field (L, L).

  • propagator_q (Array) – Precomputed propagator in Fourier space (L, L // 2 + 1).

Return type:

Array

Returns:

2D array of elastic stress redistribution rate.

rheojax.utils.epm_kernels.compute_plastic_strain_rate(stress, yield_thresholds, fluidity=1.0, smooth=False, smoothing_width=0.1, n_fluid=1.0, sigma_c_mean=1.0, fluidity_form='overstress')[source]

Compute the local plastic strain rate.

Supports two activation modes and three constitutive laws selected by fluidity_form:

  1. Hard activation: gamma_dot_p = f(sigma) * Theta(|sigma| - sigma_c)

  2. Smooth activation: gamma_dot_p = f(sigma) * 0.5 * (1 + tanh((|sigma| - sigma_c)/w))

The constitutive law f(sigma) depends on fluidity_form:

  • “linear” (classical Bingham):

    f(sigma) = sigma / tau_pl

    High-rate asymptote: stress ~ gamma_dot * tau_pl. No yield-stress baseline in the asymptote.

  • “power” (power-law fluidity, soft-glassy rheology):

    f(sigma) = sign(sigma) * |sigma / sigma_c_mean|^n_fluid * sigma_c_mean / tau_pl

    High-rate asymptote: stress ~ sigma_c_mean * (gamma_dot * tau_pl / sigma_c_mean)^(1/n_fluid). Shear-thinning but no additive yield-stress baseline.

  • “overstress” (Herschel-Bulkley, DEFAULT):

    f(sigma) = sign(sigma) * (|sigma| - sigma_c_mean)_+^n_fluid / (sigma_c_mean^(n_fluid-1) * tau_pl)

    Only stress above the threshold contributes to plastic flow. High-rate asymptote: stress ~ sigma_c_mean + sigma_c_mean * (gamma_dot * tau_pl / sigma_c_mean)^(1/n_fluid). This is the full Herschel-Bulkley form sigma = sigma_y + K * gamma_dot^n_HB with sigma_y = sigma_c_mean and n_HB = 1/n_fluid. Recommended for HB-like flow curves (emulsions, gels, foams, yield-stress fluids in general).

At n_fluid = 1, “power” reduces to “linear”; “overstress” at n_fluid = 1 gives a Bingham fluid with explicit yield stress (sigma = sigma_c_mean + gamma_dot * tau_pl).

Parameters:
  • stress (Array) – Local stress field.

  • yield_thresholds (Array) – Local yield thresholds.

  • fluidity (float) – Inverse plastic timescale (\(1/\tau_{pl}\)).

  • smooth (bool) – Whether to use the differentiable smooth approximation.

  • smoothing_width (float) – Width parameter \(w\) for smoothing.

  • n_fluid (float) – Power-law / HB exponent. The implied HB flow exponent is n_HB = 1/n_fluid.

  • sigma_c_mean (float) – Mean yield threshold, used as the scale for the power-law forms.

  • fluidity_form (str) – One of “linear”, “power”, or “overstress”. Default “overstress”.

Return type:

Array

Returns:

Plastic strain rate field.
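
The three constitutive laws and the two activation gates can be written out directly. A NumPy sketch that follows the formulas above, with fluidity = 1/tau_pl; plastic_rate_sketch() is a hypothetical name and the code is not the library's JAX kernel:

```python
# Hypothetical NumPy sketch of the documented formulas, not RheoJAX's kernel
import numpy as np


def plastic_rate_sketch(stress, yield_thresholds, fluidity=1.0, smooth=False,
                        smoothing_width=0.1, n_fluid=1.0, sigma_c_mean=1.0,
                        fluidity_form="overstress"):
    s = np.asarray(stress, dtype=float)
    excess = np.abs(s) - np.asarray(yield_thresholds, dtype=float)
    if smooth:
        activation = 0.5 * (1.0 + np.tanh(excess / smoothing_width))  # differentiable gate
    else:
        activation = (excess > 0.0).astype(float)                       # Heaviside gate
    if fluidity_form == "linear":
        f = s * fluidity
    elif fluidity_form == "power":
        f = np.sign(s) * np.abs(s / sigma_c_mean) ** n_fluid * sigma_c_mean * fluidity
    elif fluidity_form == "overstress":
        over = np.maximum(np.abs(s) - sigma_c_mean, 0.0)  # only stress above threshold flows
        f = np.sign(s) * over**n_fluid / sigma_c_mean ** (n_fluid - 1.0) * fluidity
    else:
        raise ValueError(f"unknown fluidity_form: {fluidity_form!r}")
    return f * activation


stress = np.array([-2.0, 0.5, 1.5, 3.0])
thresholds = np.ones(4)
print(plastic_rate_sketch(stress, thresholds, fluidity_form="linear"))
print(plastic_rate_sketch(stress, thresholds))  # default "overstress"
```

With n_fluid = 2, sigma_c_mean = 1, and tau_pl = 1, a stress of 3 gives a plastic rate of 4, which matches the steady-state Herschel-Bulkley asymptote sigma = 1 + sqrt(4) = 3 quoted for the "overstress" form.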

rheojax.utils.epm_kernels.update_yield_thresholds(key, active_mask, current_thresholds, mean=1.0, std=0.1)[source]

Renew yield thresholds for active sites.

Parameters:
  • key (Array) – PRNG Key.

  • active_mask (Array) – Boolean mask of sites that yielded.

  • current_thresholds (Array) – Current thresholds.

  • mean (float) – Mean of the Gaussian distribution.

  • std (float) – Standard deviation of the Gaussian distribution.

Return type:

Array

Returns:

Updated yield thresholds.

rheojax.utils.epm_kernels.epm_step(state, propagator_q, shear_rate, dt, params, smooth=False, fluidity_form='overstress')[source]

Perform one full EPM time step.

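Dynamics: sigma_dot = mu * gamma_dot - mu * gamma_dot_p + G * gamma_dot_p, where G * gamma_dot_p is the elastic stress redistribution computed with the Eshelby propagator (see solve_elastic_propagator()).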

Parameters:
  • state (tuple[Array, Array, float, Array]) – Tuple (stress, yield_thresholds, accumulated_strain, key).

  • propagator_q (Array) – Precomputed propagator.

  • shear_rate (float) – Macroscopic imposed shear rate gamma_dot.

  • dt (float) – Time step size.

  • params (dict) – Dictionary of model parameters (mu, tau_pl, sigma_c_mean, etc.).

  • smooth (bool) – Use smooth yielding (for inference) vs hard yielding (for simulation).

Return type:

tuple[Array, Array, float, Array]

Returns:

Updated state tuple.
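
One time step of the documented dynamics, sketched in NumPy. Assumptions (all hypothetical, not the library's epm_step): forward-Euler integration, hard yielding with the "linear" fluidity form, a NumPy Generator in place of the JAX PRNG key (so the state tuple drops the key), and a Gaussian threshold redraw for yielded sites whose standard deviation comes from an illustrative sigma_c_std entry in params.

```python
# Hypothetical sketch of one EPM step, not rheojax.utils.epm_kernels.epm_step
import numpy as np


def propagator_q(L, mu=1.0):
    # Quadrupolar Eshelby kernel on an L x L unit-spacing lattice, with G(0) = 0
    qx = 2.0 * np.pi * np.fft.fftfreq(L)
    qy = 2.0 * np.pi * np.fft.rfftfreq(L)
    QX, QY = np.meshgrid(qx, qy, indexing="ij")
    q2 = QX**2 + QY**2
    G = -2.0 * mu * (QX * QY) ** 2 / np.where(q2 > 0, q2**2, 1.0)
    return np.where(q2 > 0, G, 0.0)


def epm_step_sketch(state, prop_q, shear_rate, dt, params, rng):
    stress, thresholds, strain = state
    mu, tau_pl = params["mu"], params["tau_pl"]
    active = np.abs(stress) > thresholds             # hard yielding
    rate_p = np.where(active, stress / tau_pl, 0.0)  # "linear" fluidity form
    # Elastic redistribution G * gamma_dot_p via FFT
    redistributed = np.fft.irfft2(prop_q * np.fft.rfft2(rate_p), s=stress.shape)
    stress_dot = mu * shear_rate - mu * rate_p + redistributed
    new_stress = stress + dt * stress_dot            # forward Euler
    # Renew thresholds only at sites that yielded during this step
    fresh = rng.normal(params["sigma_c_mean"], params["sigma_c_std"], size=stress.shape)
    new_thresholds = np.where(active, fresh, thresholds)
    return new_stress, new_thresholds, strain + shear_rate * dt


params = {"mu": 2.0, "tau_pl": 1.0, "sigma_c_mean": 1.0, "sigma_c_std": 0.1}
rng = np.random.default_rng(0)
L = 8
state = (np.zeros((L, L)), np.full((L, L), 10.0), 0.0)
state = epm_step_sketch(state, propagator_q(L, params["mu"]), 0.5, 0.1, params, rng)
```

Because the propagator vanishes at q = 0, the lattice-averaged stress obeys mean(sigma_dot) = mu * gamma_dot - mu * mean(gamma_dot_p), which gives a quick consistency check on any implementation.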

The EPM kernels module provides low-level JAX functions for Elasto-Plastic Model simulations, including the Eshelby propagator and plastic event logic.

Functions

rheojax.utils.epm_kernels.make_propagator_q(L_x, L_y, shear_modulus=1.0)[source]

Creates the quadrupolar Eshelby propagator in Fourier space.

rheojax.utils.epm_kernels.solve_elastic_propagator(plastic_strain_rate, propagator_q)[source]

Solves for elastic stress redistribution using FFT.

rheojax.utils.epm_kernels.compute_plastic_strain_rate(stress, yield_thresholds, fluidity=1.0, smooth=False, smoothing_width=0.1, n_fluid=1.0, sigma_c_mean=1.0, fluidity_form='overstress')[source]

Computes local plastic strain rate (supports Hard and Smooth modes).

rheojax.utils.epm_kernels.epm_step(state, propagator_q, shear_rate, dt, params, smooth=False, fluidity_form='overstress')[source]

Performs one full time step of the EPM dynamics.