# Source code for rheojax.utils.prony

"""Prony series utilities for Generalized Maxwell Model parameter identification.

This module provides utilities for working with Prony series representations of
viscoelastic relaxation moduli:

    E(t) = E_∞ + Σᵢ₌₁ᴺ Eᵢ exp(-t/τᵢ)

Key capabilities:
- Parameter validation and bounds checking
- Dynamic ParameterSet creation for N modes
- Log-space transforms for wide time-scale ranges
- Element minimization (optimal N selection)
- R² goodness-of-fit metric computation
- Softplus-based penalty (``softmax_penalty``) for constrained optimization

References:
    - Park, S. W., & Schapery, R. A. (1999). Methods of interconversion between
      linear viscoelastic material functions. Part I—A numerical method based on
      Prony series. International Journal of Solids and Structures, 36(11), 1653-1675.
    - pyvisco: https://github.com/saintsfan342000/pyvisco
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.logging import get_logger

# Module-level logger. Calls in this module pass context as keyword
# arguments, e.g. ``logger.debug("...", n_modes=n_modes)``.
logger = get_logger(__name__)

# Safe JAX import (enforces float64). ``jnp`` is the runtime jax.numpy used by
# the JAX-based helpers in this module (log-tau transforms, softmax_penalty).
jax, jnp = safe_import_jax()

# Static type checkers resolve ``jnp_typing`` to the real jax.numpy. At
# runtime the name is bound to NumPy instead, so nothing extra is imported.
if TYPE_CHECKING:  # pragma: no cover
    import jax.numpy as jnp_typing
else:
    jnp_typing = np


# Union of NumPy and JAX arrays accepted by the helpers in this module.
# This is a PEP 695 ``type`` statement: it requires Python 3.12+, and its
# value is evaluated lazily (only when ``ArrayLike.__value__`` is accessed).
type ArrayLike = np.ndarray | jnp_typing.ndarray


def validate_prony_parameters(
    E_inf: float, E_i: ArrayLike, tau_i: ArrayLike
) -> tuple[bool, str]:
    """Validate Prony series parameters for physical consistency.

    Checks, in this order (the first failure is reported):
        - E_inf is finite and ≥ 0 (equilibrium modulus non-negative)
        - E_i and tau_i have the same number of elements
        - All Eᵢ are finite and > 0 (positive mode strengths)
        - All τᵢ are finite and > 0 (positive relaxation times)

    NaN and ±inf are rejected explicitly. Ordered comparisons with NaN are
    always False, so a plain ``x <= 0`` test would let NaN through as "valid".

    Args:
        E_inf: Equilibrium modulus (Pa)
        E_i: Mode strengths (Pa); a scalar is treated as a single mode
        tau_i: Relaxation times (s); a scalar is treated as a single mode

    Returns:
        (valid, message): ``(True, "")`` when every check passes, otherwise
        ``(False, reason)`` describing the first failed check.

    Example:
        >>> E_inf = 1e3
        >>> E_i = np.array([1e5, 1e4, 1e3])
        >>> tau_i = np.array([1e-2, 1e-1, 1.0])
        >>> valid, msg = validate_prony_parameters(E_inf, E_i, tau_i)
        >>> print(valid)
        True
    """
    # np.atleast_1d turns scalar input into a length-1 array so len() works.
    # np.asarray alone yields a 0-d array, on which len() raises TypeError.
    E_i_arr = np.atleast_1d(np.asarray(E_i))
    tau_i_arr = np.atleast_1d(np.asarray(tau_i))

    logger.debug(
        "Validating Prony series parameters",
        E_inf=E_inf,
        n_modes=len(E_i_arr),
    )

    # E_inf must be a real number; NaN would slip through ``E_inf < 0``.
    if not np.isfinite(E_inf):
        logger.debug(
            "Prony validation failed: non-finite E_inf",
            E_inf=E_inf,
        )
        return False, f"E_inf must be finite, got {E_inf}"

    # Check E_inf non-negative
    if E_inf < 0:
        logger.debug(
            "Prony validation failed: negative E_inf",
            E_inf=E_inf,
        )
        return False, f"E_inf must be non-negative, got {E_inf}"

    # Check array lengths match
    if len(E_i_arr) != len(tau_i_arr):
        logger.debug(
            "Prony validation failed: length mismatch",
            len_E_i=len(E_i_arr),
            len_tau_i=len(tau_i_arr),
        )
        return (
            False,
            f"E_i and tau_i must have same length, got {len(E_i_arr)} and {len(tau_i_arr)}",
        )

    # Check all Eᵢ finite, then > 0
    nonfinite_E = np.where(~np.isfinite(E_i_arr))[0]
    if nonfinite_E.size:
        logger.debug(
            "Prony validation failed: non-finite mode strengths",
            bad_indices=nonfinite_E.tolist(),
        )
        return (
            False,
            f"All E_i must be finite, found non-finite at indices {nonfinite_E.tolist()}",
        )

    if np.any(E_i_arr <= 0):
        neg_indices = np.where(E_i_arr <= 0)[0]
        logger.debug(
            "Prony validation failed: non-positive mode strengths",
            neg_indices=neg_indices.tolist(),
        )
        return (
            False,
            f"All E_i must be positive, found non-positive at indices {neg_indices.tolist()}",
        )

    # Check all τᵢ finite, then > 0
    nonfinite_tau = np.where(~np.isfinite(tau_i_arr))[0]
    if nonfinite_tau.size:
        logger.debug(
            "Prony validation failed: non-finite relaxation times",
            bad_indices=nonfinite_tau.tolist(),
        )
        return (
            False,
            f"All tau_i must be finite, found non-finite at indices {nonfinite_tau.tolist()}",
        )

    if np.any(tau_i_arr <= 0):
        neg_indices = np.where(tau_i_arr <= 0)[0]
        logger.debug(
            "Prony validation failed: non-positive relaxation times",
            neg_indices=neg_indices.tolist(),
        )
        return (
            False,
            f"All tau_i must be positive, found non-positive at indices {neg_indices.tolist()}",
        )

    logger.debug("Prony parameters validation passed", n_modes=len(E_i_arr))
    return True, ""
def create_prony_parameter_set(
    n_modes: int, modulus_type: str = "shear"
) -> ParameterSet:
    """Build a ParameterSet holding the 2N+1 parameters of an N-mode Prony series.

    Parameters are registered in this order:
        1. ``<S>_inf`` - equilibrium modulus, initial value 1e3 Pa
        2. ``<S>_1`` ... ``<S>_N`` - mode strengths, initial value 1e5 Pa
        3. ``tau_1`` ... ``tau_N`` - relaxation times, initial value
           ``10 ** (i - 1 - N / 2)`` s (one decade apart)
    where ``<S>`` is ``G`` for shear and ``E`` for tensile moduli.

    Args:
        n_modes: Number of Maxwell modes (must be ≥ 1).
        modulus_type: ``'shear'`` for G(t) or ``'tensile'`` for E(t).

    Returns:
        ParameterSet with 2N+1 parameters configured for a Prony series.

    Raises:
        ValueError: If ``n_modes < 1`` or ``modulus_type`` is neither
            ``'shear'`` nor ``'tensile'``.

    Example:
        >>> params = create_prony_parameter_set(n_modes=3, modulus_type='shear')
        >>> list(params.keys())
        ['G_inf', 'G_1', 'G_2', 'G_3', 'tau_1', 'tau_2', 'tau_3']
    """
    logger.debug(
        "Creating Prony ParameterSet",
        n_modes=n_modes,
        modulus_type=modulus_type,
    )

    if n_modes < 1:
        logger.error("Invalid n_modes for Prony ParameterSet", n_modes=n_modes)
        raise ValueError(f"n_modes must be ≥ 1, got {n_modes}")

    valid_types = ["shear", "tensile"]
    if modulus_type not in valid_types:
        logger.error(
            "Invalid modulus_type for Prony ParameterSet",
            modulus_type=modulus_type,
            valid_types=valid_types,
        )
        raise ValueError(
            f"modulus_type must be 'shear' or 'tensile', got {modulus_type}"
        )

    is_shear = modulus_type == "shear"
    symbol = "G" if is_shear else "E"
    # Tensile moduli (E) are roughly 3x shear moduli (G), and DMTA polymer data
    # can reach ~10 GPa, so the tensile ceiling is wider than the shear one.
    modulus_upper = 1e9 if is_shear else 1e12
    modulus_bounds = (0.0, modulus_upper)

    param_set = ParameterSet()

    # Equilibrium modulus: a lower bound of 0 admits liquids.
    param_set.add(
        name=f"{symbol}_inf",
        value=1e3,
        bounds=modulus_bounds,
        units="Pa",
        description=f"Equilibrium {modulus_type} modulus",
    )

    mode_numbers = range(1, n_modes + 1)

    # Mode strengths. The lower bound is 0.0 (not a small positive number) so
    # element minimization can drive a superfluous mode towards zero without
    # ParameterSet.set_value rejecting the converged value on writeback.
    for k in mode_numbers:
        param_set.add(
            name=f"{symbol}_{k}",
            value=1e5,
            bounds=modulus_bounds,
            units="Pa",
            description=f"Mode {k} strength",
        )

    # Relaxation times. The [1e-30, 1e30] bounds are only a hard backstop so
    # TTS master curves spanning 20+ decades are not rejected. The fitting code
    # derives the actual initial-guess range from the observed x-data.
    for k in mode_numbers:
        param_set.add(
            name=f"tau_{k}",
            value=10.0 ** (k - 1 - n_modes / 2),
            bounds=(1e-30, 1e30),
            units="s",
            description=f"Mode {k} relaxation time",
        )

    logger.info(
        "Created Prony ParameterSet",
        n_modes=n_modes,
        modulus_type=modulus_type,
        n_parameters=len(param_set),
    )
    return param_set
def tau_to_log_tau(tau_i: ArrayLike) -> ArrayLike:
    """Map relaxation times onto a base-10 logarithmic scale.

    Optimizing over log10(tau) gives more uniform parameter sensitivity when
    the relaxation times span many decades (e.g. 1e-6 to 1e6 s).

    Args:
        tau_i: Relaxation times (s).

    Returns:
        JAX array of ``log10(tau_i)``.

    Example:
        >>> tau = np.array([1e-3, 1e-1, 1e1, 1e3])
        >>> print(tau_to_log_tau(tau))
        [-3. -1.  1.  3.]
    """
    return jnp.log10(jnp.asarray(tau_i))
def log_tau_to_tau(log_tau_i: ArrayLike) -> ArrayLike:
    """Map base-10 log relaxation times back to linear time.

    This is the inverse of ``tau_to_log_tau``.

    Args:
        log_tau_i: Values of ``log10(tau)``.

    Returns:
        JAX array of relaxation times ``10 ** log_tau_i`` (s).

    Example:
        >>> log_tau = np.array([-3., -1., 1., 3.])
        >>> print(log_tau_to_tau(log_tau))
        [1.e-03 1.e-01 1.e+01 1.e+03]
    """
    exponents = jnp.asarray(log_tau_i)
    return jnp.power(10.0, exponents)
def compute_r_squared(y_true: ArrayLike, y_pred: ArrayLike) -> float:
    """Compute the R² coefficient of determination.

        R² = 1 - SS_res / SS_tot
        SS_res = Σ|y_true - y_pred|²,  SS_tot = Σ|y_true - mean(y_true)|²

    R² ∈ (-∞, 1], with R² = 1 being a perfect fit. Squared magnitudes are
    used, so real data give the usual definition and complex data (e.g. a
    complex modulus G* = G' + iG'') give the modulus-squared residual.

    If ``y_pred`` has a different shape but the same number of elements as
    ``y_true`` (e.g. a column vector vs a flat array), it is reshaped to match.
    Without this, NumPy would silently broadcast to an (n, n) residual matrix.

    Args:
        y_true: True values
        y_pred: Predicted values (same size as ``y_true``, or broadcastable)

    Returns:
        R² as a Python float (1.0 = perfect fit, 0.0 = mean baseline,
        < 0 = worse than mean). When all ``y_true`` are identical (SS_tot = 0),
        returns 1.0 for an exact prediction and -1.0 otherwise.

    Example:
        >>> y_true = np.array([1., 2., 3., 4., 5.])
        >>> y_pred = np.array([1.1, 2.0, 2.9, 4.1, 5.0])
        >>> r2 = compute_r_squared(y_true, y_pred)
        >>> print(f"{r2:.4f}")
        0.9970
    """
    # Pure NumPy: this is a scalar metric with no benefit from JAX dispatch.
    # A JAX path paid host->device transfer plus a device->host sync for the
    # Python-level ``ss_tot == 0`` branch.
    y_true_arr = np.asarray(y_true)
    y_pred_arr = np.asarray(y_pred)

    # Align equal-sized arrays (e.g. (n, 1) vs (n,)) instead of broadcasting
    # them into an outer-difference matrix.
    if y_pred_arr.shape != y_true_arr.shape and y_pred_arr.size == y_true_arr.size:
        y_pred_arr = y_pred_arr.reshape(y_true_arr.shape)

    # |r|**2 equals r**2 exactly for real data and stays real for complex data.
    ss_res = np.sum(np.abs(y_true_arr - y_pred_arr) ** 2)
    ss_tot = np.sum(np.abs(y_true_arr - np.mean(y_true_arr)) ** 2)

    # All y_true identical: R² is undefined; report perfect or poor fit.
    if ss_tot == 0.0:
        return 1.0 if ss_res == 0.0 else -1.0

    return float(1.0 - ss_res / ss_tot)
def iterative_n_reduction(fit_results_dict: dict[int, float]) -> dict[str, ArrayLike]:
    """Collect R² as a function of mode count N for element-minimization plots.

    Args:
        fit_results_dict: Mapping ``n_modes -> R²``,
            e.g. ``{10: 0.998, 9: 0.997, 8: 0.995}``.

    Returns:
        Dictionary with:
            - ``'n_modes'``: NumPy array of N values, ascending
            - ``'r2'``: NumPy array of the matching R² values
            - ``'r2_min'``: smallest R² (Python float)
            - ``'r2_max'``: largest R² (Python float)

    Raises:
        ValueError: If ``fit_results_dict`` is empty.

    Example:
        >>> results = {10: 0.998, 8: 0.995, 6: 0.990, 4: 0.980, 2: 0.950}
        >>> diagnostics = iterative_n_reduction(results)
        >>> print(diagnostics['n_modes'])
        [ 2  4  6  8 10]
        >>> diagnostics['r2'].tolist()
        [0.95, 0.98, 0.99, 0.995, 0.998]
    """
    if not fit_results_dict:
        raise ValueError("fit_results_dict cannot be empty")

    # Keys are unique, so sorting (n, r2) pairs orders purely by n.
    ordered_pairs = sorted(fit_results_dict.items())
    mode_counts = np.array([n for n, _ in ordered_pairs])
    scores = np.array([score for _, score in ordered_pairs])

    return {
        "n_modes": mode_counts,
        "r2": scores,
        "r2_min": float(np.min(scores)),
        "r2_max": float(np.max(scores)),
    }
def select_optimal_n(
    r2_values: dict[int, float], optimization_factor: float = 1.5
) -> int:
    """Select the optimal number of modes using an R² threshold criterion.

    Algorithm:
        1. Find the maximum R² across all N: R²_max (best achievable fit)
        2. Degradation tolerance: ΔR² = (1 - R²_max) × (optimization_factor - 1)
           (the room 1 - R²_max is clamped at 0 in case R²_max > 1)
        3. Threshold: R²_threshold = R²_max - ΔR²
        4. Select the smallest N with R²_N ≥ R²_threshold

    Interpretation:
        - optimization_factor = 1.0: require R² ≥ R²_max (no degradation,
          maximum fidelity)
        - optimization_factor = 1.5: allow losing 50% of (1 - R²_max)
        - optimization_factor = 2.0: allow losing 100% of (1 - R²_max)
          (more parsimonious)
    A higher factor tolerates more degradation and yields a simpler model.

    Non-finite R² entries (NaN/inf) are ignored with a warning. A best R² ≤ 0
    (fit worse than the flat-line baseline) triggers a warning, but selection
    still proceeds.

    Args:
        r2_values: Mapping ``n_modes -> R²``
        optimization_factor: Parsimony factor (≥ 1.0; NaN is rejected)

    Returns:
        Optimal number of modes (N_opt)

    Raises:
        ValueError: If optimization_factor is not ≥ 1.0 or r2_values is empty

    Example:
        >>> r2 = {5: 0.998, 3: 0.9975, 2: 0.980, 1: 0.900}
        >>> # R²_max = 0.998, room = 1 - 0.998 = 0.002
        >>> # factor=1.5: ΔR² = 0.001, threshold = 0.997 -> smallest N is 3
        >>> select_optimal_n(r2, optimization_factor=1.5)
        3
        >>> # factor=1.0: ΔR² = 0, threshold = 0.998 -> need N=5
        >>> select_optimal_n(r2, optimization_factor=1.0)
        5
    """
    logger.debug(
        "Selecting optimal number of modes",
        n_candidates=len(r2_values),
        optimization_factor=optimization_factor,
    )

    # Written as ``not >=`` so a NaN factor is rejected too (NaN < 1.0 is False).
    if not optimization_factor >= 1.0:
        logger.error(
            "Invalid optimization_factor",
            optimization_factor=optimization_factor,
        )
        raise ValueError(
            f"optimization_factor must be ≥ 1.0, got {optimization_factor}"
        )

    if not r2_values:
        logger.error("Empty r2_values dictionary")
        raise ValueError("r2_values cannot be empty")

    # NaN makes max() depend on dict order and never satisfies ``>=``, so
    # only finite scores take part in the selection.
    finite_r2 = {n: r2 for n, r2 in r2_values.items() if np.isfinite(r2)}
    if not finite_r2:
        n_opt = min(r2_values)
        logger.warning(
            "select_optimal_n: no finite R\u00b2 values; returning minimum N",
            n_opt=n_opt,
        )
        return n_opt
    if len(finite_r2) < len(r2_values):
        logger.warning(
            "select_optimal_n: ignoring non-finite R\u00b2 values",
            ignored_n=sorted(n for n in r2_values if n not in finite_r2),
        )

    # Find maximum R² (best fit)
    r2_max = max(finite_r2.values())

    # R10-PRONY-001: a best fit worse than a flat line signals data-quality
    # issues or a wrong model. Checked up front so the warning always fires.
    if r2_max <= 0:
        logger.warning(
            "select_optimal_n: best R\u00b2 is non-positive (fit worse than "
            "flat-line baseline). Check data quality and model choice.",
            r2_max=r2_max,
        )

    # Allow losing (optimization_factor - 1) x the gap to a perfect fit.
    # Clamping keeps the threshold <= r2_max even for an invalid r2_max > 1.
    degradation_room = max(1.0 - r2_max, 0.0)
    allowed_degradation = degradation_room * (optimization_factor - 1.0)
    r2_threshold = r2_max - allowed_degradation

    logger.debug(
        "Computed R2 threshold for mode selection",
        r2_max=r2_max,
        r2_threshold=r2_threshold,
        allowed_degradation=allowed_degradation,
    )

    # Smallest N (ascending order) meeting the threshold
    n_sorted = sorted(finite_r2)
    n_opt = next((n for n in n_sorted if finite_r2[n] >= r2_threshold), None)

    if n_opt is None:
        # R8-PRONY-002: only reachable when the threshold itself is NaN, e.g.
        # an infinite factor with a perfect fit (0 * inf).
        n_opt = n_sorted[0]
        logger.warning(
            "select_optimal_n: no N satisfies R\u00b2 threshold; returning "
            "minimum N. Consider adjusting optimization_factor.",
            r2_threshold=r2_threshold,
            n_opt=n_opt,
            r2_at_n_opt=finite_r2[n_opt],
        )
        return n_opt

    logger.info(
        "Selected optimal number of modes",
        n_opt=n_opt,
        r2_at_n_opt=finite_r2[n_opt],
        r2_max=r2_max,
        r2_threshold=r2_threshold,
    )
    return n_opt
def softmax_penalty(E_i: ArrayLike, scale: float = 1.0):
    """Compute a smooth penalty for negative moduli in Step 1 fitting.

    Despite the name, this is a sum of softplus terms:

        Penalty = scale × Σᵢ log(1 + exp(-Eᵢ/scale))

    Each term tends to 0 for Eᵢ >> scale, equals scale·ln 2 at Eᵢ = 0, and
    grows roughly like -Eᵢ for Eᵢ << -scale. It is differentiable everywhere,
    which makes it suitable for encouraging positive Eᵢ during unconstrained
    optimization.

    Args:
        E_i: Array of mode strengths (Pa)
        scale: Smoothness parameter (default 1.0). Larger values give a
            smoother penalty but weaker enforcement.

    Returns:
        Penalty value (≥ 0) as a JAX scalar array.

    Note:
        Returns a JAX array for gradient compatibility. Do not convert to a
        Python ``float()`` inside JAX-traced functions.

    Example:
        >>> E_i = np.array([1e5, 1e4, -1e3])  # One negative mode
        >>> penalty = softmax_penalty(E_i, scale=1e3)
        >>> print(f"{penalty:.2f}")
        1313.31
        >>> E_i_pos = np.array([1e5, 1e4, 1e3])  # All positive
        >>> penalty_pos = softmax_penalty(E_i_pos, scale=1e3)
        >>> print(f"{penalty_pos:.2e}")
        3.13e+02
    """
    E_arr = jnp.asarray(E_i)
    x = -E_arr / scale
    # softplus(x) = log(1 + exp(x)) = logaddexp(0, x). jnp.logaddexp never
    # overflows and has a stable derivative. A jnp.where over two branch
    # formulas would overflow exp() in the unselected branch for large |x|,
    # and 0 * inf in the backward pass would yield NaN gradients.
    penalty = scale * jnp.sum(jnp.logaddexp(0.0, x))

    # Return JAX array (do not convert to Python float for gradient compatibility)
    return penalty
def warm_start_from_n_modes(
    params_n: ArrayLike, n_target: int, modulus_type: str = "shear"
) -> ArrayLike:
    """Build warm-start parameters for an n_target-mode fit from an N-mode solution.

    Used during element minimization to initialize the next (typically N-1
    mode) NLSQ fit from the previous solution instead of from scratch.

    Parameter layout (input and output):
        ``[E_inf, E_1, ..., E_N, tau_1, ..., tau_N]``, length 2*N + 1

    Behavior by case:
        - n_target < N: keep the FIRST n_target (E_i, tau_i) pairs in stored
          order. Modes are not re-ranked by E_i here. The caller must order
          modes by importance beforehand if "keep the strongest" is intended.
        - n_target == N: E_i and tau_i are passed through unchanged.
        - n_target > N: add n_target - N modes. New tau values are log-spaced
          candidates clamped to [1e-6, 1e6] s, skipping any within 0.3 decades
          of an existing tau. If too few survive, a fallback log-spaced set is
          used. New E_i are filled with the mean of the existing E_i (1e4 if
          there are none). All modes are then sorted by ascending tau.

    Args:
        params_n: Fitted parameters from the N-mode optimization,
            shape (2*N + 1,).
        n_target: Target number of modes for the next fit (typically N-1).
        modulus_type: 'shear' (G) or 'tensile' (E). Only logged; it does not
            affect the result.

    Returns:
        NumPy array of initial parameters for the n_target-mode fit,
        shape (2*n_target + 1,).

    Raises:
        ValueError: If n_target < 1 or len(params_n) is not of the form 2*N + 1.

    Example:
        >>> params_5 = np.array([1e3, 1e6, 5e5, 2e5, 8e4, 3e4,  # E_inf, E_1..E_5
        ...                      1e-2, 1e-1, 1.0, 1e1, 1e2])    # tau_1..tau_5
        >>> params_4 = warm_start_from_n_modes(params_5, n_target=4)  # drops mode 5
        >>> print(params_4.shape)  # 2*4 + 1 = 9 parameters
        (9,)
        >>> print(params_4)  # E_inf, E_1..E_4, tau_1..tau_4
        [1.e+03 1.e+06 5.e+05 2.e+05 8.e+04 1.e-02 1.e-01 1.e+00 1.e+01]

    Notes:
        - NOTE(review): the padding range [1e-6, 1e6] s is much narrower than
          the [1e-30, 1e30] tau bounds used by create_prony_parameter_set for
          TTS master curves. Confirm this is intended for wide-range data.
        - Warm-starting is reported to give roughly a 2-5x speedup in element
          minimization (not measurable from this function alone).
    """
    logger.debug(
        "Extracting warm-start parameters",
        n_target=n_target,
        params_length=len(params_n) if hasattr(params_n, "__len__") else 1,
        modulus_type=modulus_type,
    )

    if n_target < 1:
        logger.error("Invalid n_target for warm-start", n_target=n_target)
        raise ValueError(f"n_target must be ≥ 1, got {n_target}")

    # Host-side NumPy copy; JAX inputs are converted as well.
    params_arr = np.asarray(params_n)

    # Infer current N from params length: 2*N + 1
    if (len(params_arr) - 1) % 2 != 0:
        logger.error(
            "Invalid parameter array format",
            params_length=len(params_arr),
            expected_format="2*N+1",
        )
        raise ValueError(
            f"Invalid params_n length {len(params_arr)}, expected 2*N+1 format"
        )

    n_current = (len(params_arr) - 1) // 2

    # Extract components
    E_inf = params_arr[0]
    E_i = params_arr[1 : 1 + n_current]
    tau_i = params_arr[1 + n_current :]

    # Case 1: Reduce modes (typical for element minimization)
    if n_target < n_current:
        # Truncate to first n_target modes (positional, no re-ranking by E_i)
        E_i_target = E_i[:n_target]
        tau_i_target = tau_i[:n_target]

        logger.debug(
            "Warm-start: reducing modes",
            n_current=n_current,
            n_target=n_target,
            operation="truncate",
        )

    # Case 2: Same number of modes (no-op)
    elif n_target == n_current:
        E_i_target = E_i
        tau_i_target = tau_i

        logger.debug(
            "Warm-start: same number of modes",
            n_current=n_current,
            operation="passthrough",
        )

    # Case 3: Increase modes (edge case, pad relative to existing tau values)
    else:
        n_pad = n_target - n_current
        # Pad taus extending beyond existing range
        TAU_LB, TAU_UB = 1e-6, 1e6
        if len(tau_i) == 0:
            # No existing modes: use the full clamp range.
            tau_min_new = TAU_LB
            tau_max_new = TAU_UB
        else:
            # Extend the existing range by max(n_pad, 1) decades on each side,
            # clamped to [TAU_LB, TAU_UB].
            tau_min_new = max(tau_i.min() / (10.0 ** max(n_pad, 1)), TAU_LB)
            tau_max_new = min(tau_i.max() * (10.0 ** max(n_pad, 1)), TAU_UB)
        # Oversample (3 per needed mode) so filtering below can discard some.
        new_taus_candidates = np.logspace(
            np.log10(tau_min_new), np.log10(tau_max_new), n_pad * 3
        )
        # Filter out taus too close to existing ones (within 0.3 decades);
        # keep at most n_pad of the survivors, in candidate order.
        new_taus = new_taus_candidates[
            ~np.any(
                np.abs(
                    np.log10(new_taus_candidates[:, None])
                    - np.log10(tau_i[None, :])
                )
                < 0.3,
                axis=1,
            )
        ][:n_pad]
        if len(new_taus) < n_pad:
            # Fallback: extend beyond existing range (clamp to valid logspace bounds)
            tau_upper = min(tau_i.max() * 2, TAU_UB)
            tau_lower = max(tau_i.min() / 2, TAU_LB)
            if tau_lower >= tau_upper:
                tau_lower = TAU_LB
                tau_upper = TAU_UB
            new_taus = np.logspace(np.log10(tau_lower), np.log10(tau_upper), n_pad)
        # Merge old and new modes and order them by ascending tau; the same
        # permutation is applied to E_i below so (E_i, tau_i) pairs stay matched.
        tau_i_combined = np.concatenate([tau_i, new_taus])
        sort_idx = np.argsort(tau_i_combined)
        tau_i_target = tau_i_combined[sort_idx]
        # R8-PRONY-001: guard against empty E_i from corrupted params
        E_fill = E_i.mean() if len(E_i) > 0 else 1e4
        E_i_combined = np.concatenate([E_i, np.full(n_pad, E_fill)])
        E_i_target = E_i_combined[sort_idx]

        logger.debug(
            "Warm-start: increasing modes",
            n_current=n_current,
            n_target=n_target,
            n_pad=n_pad,
            operation="pad",
        )

    # Reconstruct parameter array
    params_target = np.concatenate([[E_inf], E_i_target, tau_i_target])

    logger.debug(
        "Warm-start parameters extracted",
        output_length=len(params_target),
        E_inf=E_inf,
    )

    return params_target