# Source code for rheojax.utils.spp_kernels

"""JAX-compatible SPP (Sequence of Physical Processes) kernel functions.

This module provides efficient, JAX-compatible implementations of SPP analysis
kernel functions for LAOS (Large Amplitude Oscillatory Shear) rheology. SPP
analysis enables cycle-by-cycle decomposition of nonlinear stress responses
into elastic and viscous contributions, extracting physically meaningful
yield parameters.

The SPP framework was developed by Rogers (2012, 2017) and provides:
- Time-resolved apparent cage modulus G'_cage(t)
- Static and dynamic yield stress extraction
- Phase reconstruction from harmonic decomposition
- Lissajous-Bowditch plot metrics
- Frenet-Serret frame analysis (T, N, B vectors)
- Moduli rate calculations (Ġ', Ġ'', G_speed)

Key Functions
-------------
- apparent_cage_modulus: Time-resolved elastic modulus from stress/strain
- static_yield_stress: Yield stress at strain reversal (strain = ±gamma0)
- dynamic_yield_stress: Yield stress at rate reversal (strain rate = 0)
- harmonic_reconstruction: Stress reconstruction from Fourier components
- harmonic_reconstruction_full: Full Fourier with phase alignment (MATLAB-compatible)
- spp_fourier_analysis: Complete Fourier-based SPP analysis with analytical derivatives
- lissajous_metrics: Bowditch diagram derived quantities (S, T ratios)
- zero_crossing_detection: Robust strain/rate zero-crossing finder
- frenet_serret_frame: Compute T, N, B trajectory vectors
- moduli_rates: Compute Ġ'(t), Ġ''(t), G_speed, δ̇(t)
- yield_from_displacement_stress: SPP-based yield stress extraction

Physical Interpretation
-----------------------
SPP analysis provides a phenomenological interpretation of LAOS behavior:
- G'_cage(t): Instantaneous elastic modulus reflecting cage structure
- static_yield: Static yield stress (stress at max strain, cage breakage threshold)
- dynamic_yield: Dynamic yield stress (stress at zero rate, flow cessation threshold)
- Power-law regime: Post-yield flow characterized by sigma ~ strain_rate^n
- Frenet-Serret frame: Geometric analysis of (γ, γ̇/ω, σ) trajectory

References
----------
- S.A. Rogers et al., "A sequence of physical processes determined and
  quantified in large-amplitude oscillatory shear (LAOS): Application to
  theoretical nonlinear models", J. Rheol. 56(1), 2012
- S.A. Rogers, "In search of physical meaning: defining transient parameters
  for nonlinear viscoelasticity", Rheol. Acta 56, 2017
- G.J. Donley et al., "Time-resolved dynamics of the yielding transition
  in soft materials", J. Non-Newton. Fluid Mech. 264, 2019
"""

from functools import partial
from typing import TYPE_CHECKING

import numpy as np

from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger

# Module-level logger. Elsewhere in this module it is called with structured
# keyword fields, e.g. logger.debug("...", omega=omega).
logger = get_logger(__name__)

# Safe JAX import (enforces float64)
# Float64 precision is critical for accurate numerical differentiation
jax, jnp = safe_import_jax()

# ``Array`` is only needed by static type checkers: annotations in this module
# refer to it as the string "Array", so no runtime import is required.
if TYPE_CHECKING:
    from jax import Array


# ============================================================================
# Apparent Cage Modulus
# ============================================================================


@jax.jit
def apparent_cage_modulus(
    stress: "Array",
    strain: "Array",
    strain_amplitude: float,
) -> "Array":
    """
    Compute a time-resolved, strain-sign-corrected normalized stress.

    The kernel evaluates, sample by sample,

        G_cage(t) = stress(t) / gamma0 * sign(strain(t))

    i.e. the stress normalized by the strain amplitude, with its sign flipped
    during the negative-strain half-cycle.

    NOTE(review): Rogers et al. (2012) J. Rheol. 56(1) is cited for this
    quantity, but the SPP "cage modulus" is commonly defined as the slope
    d(sigma)/d(gamma) evaluated where sigma = 0. Confirm that this
    normalized-stress variant is the intended definition before interpreting
    the output as G'_cage.

    Parameters
    ----------
    stress : Array
        Time-resolved stress signal (Pa)
    strain : Array
        Time-resolved strain signal (dimensionless)
    strain_amplitude : float
        Maximum strain amplitude gamma0 (dimensionless)

    Returns
    -------
    Array
        Normalized, sign-corrected stress (Pa), same length as ``stress``.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import apparent_cage_modulus
    >>>
    >>> # Sinusoidal LAOS data
    >>> t = jnp.linspace(0, 2*jnp.pi, 1000)
    >>> gamma_0 = 1.0
    >>> gamma = gamma_0 * jnp.sin(t)
    >>> sigma = 100.0 * jnp.sin(t) + 10.0 * jnp.sin(3*t)  # With 3rd harmonic
    >>>
    >>> G_cage = apparent_cage_modulus(sigma, gamma, gamma_0)

    Notes
    -----
    - For a purely elastic response sigma = G * gamma with
      gamma = gamma0 * sin(omega t), the result is G * abs(sin(omega t)).
      It is therefore NOT constant over a cycle even for a linear material.
    - Samples where the strain is exactly 0 take the sign of the next sample.
      The array is treated as periodic, so the last sample looks at the
      first. If that neighbour is also exactly 0, the output there is 0.
    - ``strain_amplitude`` values <= 1e-10 (including negative values) are
      clamped to 1e-10 to avoid division by zero.
    """
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))
    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    gamma_0 = jnp.float64(strain_amplitude)

    # Avoid division by zero
    gamma_0 = jnp.where(gamma_0 > 1e-10, gamma_0, 1e-10)

    # Compute sign of strain (avoid division by zero at crossings)
    strain_sign = jnp.sign(strain_arr)

    # At an exact zero, borrow the sign of the next sample (periodic wrap)
    strain_sign = jnp.where(
        strain_arr == 0,
        jnp.sign(jnp.roll(strain_arr, -1)),
        strain_sign,
    )

    # stress / gamma0 * sign(strain)
    G_cage = stress_arr / gamma_0 * strain_sign

    return G_cage
# ============================================================================ # Yield Stress Extraction # ============================================================================
@jax.jit
def static_yield_stress(
    stress: "Array",
    strain: "Array",
    strain_amplitude: float,
    tolerance: float = 0.02,
) -> float:
    """Estimate the static yield stress from samples near strain reversal.

    A sample counts as "at reversal" when
    ``abs(strain) >= strain_amplitude * (1 - tolerance)``. The result is the
    mean of ``abs(stress)`` over those samples. If no sample qualifies, the
    maximum of ``abs(stress)`` over the whole signal is returned instead.
    """
    abs_sigma = jnp.abs(jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64)))
    abs_gamma = jnp.abs(jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64)))
    amplitude = jnp.float64(strain_amplitude)

    near_extremum = abs_gamma >= amplitude * (1.0 - tolerance)
    n_near = jnp.sum(near_extremum)

    # Both branches are evaluated; the division result is discarded when n_near == 0.
    mean_near = jnp.sum(jnp.where(near_extremum, abs_sigma, 0.0)) / n_near
    return jnp.where(n_near > 0, mean_near, abs_sigma.max())
@jax.jit
def dynamic_yield_stress(
    stress: "Array",
    strain_rate: "Array",
    rate_amplitude: float,
    tolerance: float = 0.02,
) -> float:
    """Approximate dynamic yield stress near zero strain rate.

    Selects samples with ``abs(strain_rate) <= rate_amplitude * tolerance``
    and returns the mean of ``abs(stress)`` over them.

    If no sample falls inside that window (coarse sampling, or a too-small or
    non-positive ``rate_amplitude``), the fallback is ``abs(stress)`` at the
    sample whose ``abs(strain_rate)`` is smallest. That is the closest
    available sample to rate reversal. The previous fallback, the minimum of
    ``abs(stress)`` over the whole signal, is ~0 for any oscillatory stress
    and so was not a meaningful yield estimate.
    """
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))
    strain_rate_arr = jnp.atleast_1d(jnp.asarray(strain_rate, dtype=jnp.float64))
    gamma_dot_0 = jnp.float64(rate_amplitude)

    abs_stress = jnp.abs(stress_arr)
    abs_rate = jnp.abs(strain_rate_arr)

    # Find points where |γ̇| ≈ 0 (rate reversal)
    threshold = gamma_dot_0 * tolerance
    at_zero_rate = abs_rate <= threshold
    count = jnp.sum(at_zero_rate)

    # Average stress magnitude at zero-rate points (count clamped to avoid 0/0)
    mean_at_zero = jnp.sum(jnp.where(at_zero_rate, abs_stress, 0.0)) / jnp.maximum(
        count, 1
    )

    # Fallback: stress at the sample closest to rate reversal
    closest_to_reversal = abs_stress[jnp.argmin(abs_rate)]

    return jnp.where(count > 0, mean_at_zero, closest_to_reversal)
# ============================================================================ # Phase Reconstruction # ============================================================================
[docs] @partial(jax.jit, static_argnums=(2,)) def harmonic_reconstruction( stress: "Array", omega: float, n_harmonics: int = 39, dt: float | None = None, ) -> tuple["Array", "Array", "Array"]: """ Reconstruct stress signal from harmonic components (Fourier decomposition). Extracts odd harmonic amplitudes and phases from LAOS stress signal, enabling reconstruction and harmonic ratio analysis. Parameters ---------- stress : Array Time-resolved stress signal σ(t) (Pa) omega : float Fundamental angular frequency ω (rad/s) n_harmonics : int, optional Number of odd harmonics to extract (default: 5, gives 1ω, 3ω, 5ω, 7ω, 9ω) dt : float, optional Time step. If None, assumes stress spans exactly one period. Returns ------- amplitudes : Array Harmonic amplitudes [A_1, A_3, A_5, ...] (Pa) phases : Array Harmonic phases [φ_1, φ_3, φ_5, ...] (radians) reconstructed : Array Reconstructed stress from harmonics (Pa) Examples -------- >>> import jax.numpy as jnp >>> from rheojax.utils.spp_kernels import harmonic_reconstruction >>> >>> omega = 1.0 >>> t = jnp.linspace(0, 2*jnp.pi, 1000) >>> sigma = 100.0 * jnp.sin(t) + 20.0 * jnp.sin(3*t + 0.1) >>> >>> amps, phases, reconstructed = harmonic_reconstruction(sigma, omega) >>> # amps[0] ≈ 100.0, amps[1] ≈ 20.0 >>> # phases[0] ≈ 0, phases[1] ≈ 0.1 Notes ----- - Only odd harmonics (1, 3, 5, ...) are physically relevant in LAOS - Even harmonics indicate asymmetric response (wall slip, etc.) - I_n/I_1 ratio quantifies nonlinearity strength """ stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64)) n_points = len(stress_arr) # Determine time array if dt is None: period = 2.0 * jnp.pi / omega dt = period / n_points t = jnp.arange(n_points) * dt # Extract odd harmonics via discrete Fourier projection amplitudes = jnp.zeros(n_harmonics, dtype=jnp.float64) phases = jnp.zeros(n_harmonics, dtype=jnp.float64) def extract_harmonic(carry, harmonic_idx): n = 2 * harmonic_idx + 1 # Odd harmonics: 1, 3, 5, ... 
omega_n = n * omega # Fourier projection cos_component = jnp.sum(stress_arr * jnp.cos(omega_n * t)) * 2.0 / n_points sin_component = jnp.sum(stress_arr * jnp.sin(omega_n * t)) * 2.0 / n_points # Amplitude and phase amplitude = jnp.sqrt(cos_component**2 + sin_component**2 + 1e-30) phase = jnp.arctan2(-cos_component, sin_component) return carry, (amplitude, phase) _, (amplitudes, phases) = jax.lax.scan( extract_harmonic, None, jnp.arange(n_harmonics), ) # Reconstruct signal — vectorized over harmonics and time harmonics = 2 * jnp.arange(n_harmonics) + 1 # [1, 3, 5, ...] # phase matrix: (n_harmonics, len(t)) phase_matrix = harmonics[:, None] * omega * t[None, :] + phases[:, None] reconstructed = jnp.sum(amplitudes[:, None] * jnp.sin(phase_matrix), axis=0) return amplitudes, phases, reconstructed
# ============================================================================ # Phase-Aligned Harmonic Reconstruction (MATLAB-Compatible) # ============================================================================
@partial(jax.jit, static_argnums=(3,))
def compute_phase_offset(
    strain: "Array",
    omega: float,
    dt: float,
    n_cycles: int = 1,
) -> float:
    """
    Compute phase offset Delta for aligning strain to start at zero crossing.

    This matches MATLAB SPPplus_fourier_v2.m phase offset calculation:
        Delta = atan(An1_n(p+1)/Bn1_n(p+1))
        if Bn1_n(p+1) < 0: Delta = Delta + pi

    MATLAB's ``atan`` only covers (-pi/2, pi/2), so the +pi branch resolves
    the quadrant. The resulting range is (-pi/2, 3pi/2]. ``arctan2`` already
    resolves the quadrant, so adding pi on top of it would double the
    correction. Instead, ``arctan2`` is mapped onto MATLAB's range.

    Parameters
    ----------
    strain : Array
        Strain signal γ(t)
    omega : float
        Angular frequency ω (rad/s). Not used in the computation.
    dt : float
        Time step (s). Not used in the computation; kept for API compatibility.
    n_cycles : int
        Number of complete cycles in data (default: 1). Static under jit;
        must be smaller than ``len(strain)``.

    Returns
    -------
    float
        Phase offset Delta (radians) in (-pi/2, 3pi/2]. Rotating the
        fundamental by Delta makes the strain a positive-amplitude sine.
    """
    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    L = len(strain_arr)
    p = n_cycles

    # Compute FFT of strain
    fft_strain = jnp.fft.fft(strain_arr)

    # Fourier coefficients (MATLAB convention): An1 = 2*Re(FFT)/L, Bn1 = -2*Im(FFT)/L
    An1_n = 2 * jnp.real(fft_strain) / L
    Bn1_n = -2 * jnp.imag(fft_strain) / L

    # Fundamental sits at index p+1 in MATLAB, i.e. index p (0-based) here
    An_fund = An1_n[p]
    Bn_fund = Bn1_n[p]

    # Equivalent to MATLAB atan(A/B) (+pi if B<0): shift arctan2's
    # (-pi, -pi/2) branch up by 2*pi to land in (pi, 3pi/2).
    Delta = jnp.arctan2(An_fund, Bn_fund)
    Delta = jnp.where(Delta < -jnp.pi / 2, Delta + 2 * jnp.pi, Delta)

    return Delta
@partial(jax.jit, static_argnums=(4, 5, 6))
def harmonic_reconstruction_full(
    strain: "Array",
    strain_rate: "Array",
    stress: "Array",
    omega: float,
    n_harmonics: int = 39,
    n_cycles: int = 1,
    W_int: int | None = None,
) -> dict:
    """
    Full Fourier-based harmonic reconstruction with phase alignment (MATLAB-compatible).

    Implements the complete workflow from SPPplus_fourier_v2.m:
    1. FFT all three waveforms (strain, rate, stress)
    2. Compute phase offset Delta from strain fundamental
    3. Rotate all Fourier coefficients to align with phase reference
    4. Reconstruct aligned waveforms

    Parameters
    ----------
    strain : Array
        Strain signal γ(t) (dimensionless)
    strain_rate : Array
        Strain rate signal γ̇(t) (1/s) - will be normalized by omega
    stress : Array
        Stress signal σ(t) (Pa)
    omega : float
        Angular frequency ω (rad/s)
    n_harmonics : int
        Highest odd harmonic used for stress reconstruction (default: 39).
        Static under jit.
    n_cycles : int
        Number of complete cycles in data (default: 1). Static under jit.
    W_int : int, optional
        Number of harmonic orders (beyond the DC term) reported in ``ft_out``.
        Defaults to ``round(L / (2 * n_cycles))``. Static under jit.

    Returns
    -------
    dict
        Dictionary containing:
        - Delta: Phase offset (radians), in (-pi/2, 3pi/2] as in MATLAB
        - An_strain, Bn_strain: Aligned strain Fourier coefficients
        - An_rate, Bn_rate: Aligned rate Fourier coefficients
        - An_stress, Bn_stress: Aligned stress Fourier coefficients
        - strain_recon: Reconstructed strain (fundamental only)
        - rate_recon: Reconstructed rate/omega (fundamental only)
        - stress_recon: Reconstructed stress (odd harmonics up to n_harmonics)
        - time_new: Time array with L samples spanning one period 2π/ω
        - ft_out: (W_int + 1, 2) array. Column 0 is the frequency k*ω/(2π).
          Column 1 is the stress FFT amplitude at bin k*n_cycles, normalized
          by the fundamental amplitude.

    Notes
    -----
    This function matches MATLAB SPPplus_fourier_v2.m coefficient rotation:
        An_n[nn+1] = An1_n[nn+1]*cos(Delta/p*nn) - Bn1_n[nn+1]*sin(Delta/p*nn)
        Bn_n[nn+1] = Bn1_n[nn+1]*cos(Delta/p*nn) + An1_n[nn+1]*sin(Delta/p*nn)

    MATLAB computes ``Delta = atan(A/B)`` and adds pi when B < 0. Here
    ``arctan2`` (which already resolves the quadrant) is mapped onto the same
    (-pi/2, 3pi/2] range. Adding pi on top of ``arctan2`` would flip every
    odd harmonic whenever B < 0.
    """
    logger.debug(
        "Starting harmonic reconstruction (full)",
        omega=omega,
        n_harmonics=n_harmonics,
        n_cycles=n_cycles,
    )

    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    rate_arr = jnp.atleast_1d(jnp.asarray(strain_rate, dtype=jnp.float64))
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))

    L = len(strain_arr)
    p = int(n_cycles)
    W = W_int if W_int is not None else int(round(L / (2 * p)))

    logger.debug(
        "Harmonic reconstruction parameters",
        signal_length=L,
        n_cycles_parsed=p,
        window_size=W,
    )

    # Normalize rate by omega (MATLAB convention)
    rate_normalized = rate_arr / omega

    # Compute FFT of all signals
    fft_strain = jnp.fft.fft(strain_arr)
    fft_rate = jnp.fft.fft(rate_normalized)
    fft_stress = jnp.fft.fft(stress_arr)

    # Convert to MATLAB-style coefficients
    # An = 2*Re(FFT)/L, Bn = -2*Im(FFT)/L
    An1_strain = 2 * jnp.real(fft_strain) / L
    Bn1_strain = -2 * jnp.imag(fft_strain) / L
    An1_rate = 2 * jnp.real(fft_rate) / L
    Bn1_rate = -2 * jnp.imag(fft_rate) / L
    An1_stress = 2 * jnp.real(fft_stress) / L
    Bn1_stress = -2 * jnp.imag(fft_stress) / L

    # Zero the DC component for both An (cosine) and Bn (sine) coefficients
    An1_strain = An1_strain.at[0].set(0.0)
    Bn1_strain = Bn1_strain.at[0].set(0.0)
    An1_rate = An1_rate.at[0].set(0.0)
    Bn1_rate = Bn1_rate.at[0].set(0.0)
    An1_stress = An1_stress.at[0].set(0.0)
    Bn1_stress = Bn1_stress.at[0].set(0.0)

    # Phase offset Delta from the strain fundamental. This is equivalent to
    # MATLAB's atan(A/B) (+pi if B<0): arctan2's (-pi, -pi/2) branch is
    # shifted up by 2*pi.
    An_fund = An1_strain[p]
    Bn_fund = Bn1_strain[p]
    Delta = jnp.arctan2(An_fund, Bn_fund)
    Delta = jnp.where(Delta < -jnp.pi / 2, Delta + 2 * jnp.pi, Delta)

    # Rotate coefficients to align with phase reference
    # An_new = An1*cos(Delta/p*n) - Bn1*sin(Delta/p*n)
    # Bn_new = Bn1*cos(Delta/p*n) + An1*sin(Delta/p*n)
    n_indices = jnp.arange(L // 2)
    rotation_angle = Delta / p * n_indices
    cos_rot = jnp.cos(rotation_angle)
    sin_rot = jnp.sin(rotation_angle)

    # Apply rotation to strain
    An_strain = An1_strain[: L // 2] * cos_rot - Bn1_strain[: L // 2] * sin_rot
    Bn_strain = Bn1_strain[: L // 2] * cos_rot + An1_strain[: L // 2] * sin_rot

    # Apply rotation to rate
    An_rate = An1_rate[: L // 2] * cos_rot - Bn1_rate[: L // 2] * sin_rot
    Bn_rate = Bn1_rate[: L // 2] * cos_rot + An1_rate[: L // 2] * sin_rot

    # Apply rotation to stress
    An_stress = An1_stress[: L // 2] * cos_rot - Bn1_stress[: L // 2] * sin_rot
    Bn_stress = Bn1_stress[: L // 2] * cos_rot + An1_stress[: L // 2] * sin_rot

    # Time array: L samples spanning one fundamental period 2π/ω (not shifted;
    # the alignment lives entirely in the rotated coefficients)
    dt = 2 * jnp.pi / omega / L
    time_new = dt * jnp.arange(L)

    # Reconstruct waveforms from aligned coefficients
    # Only use fundamental for strain/rate (n=1), odd harmonics up to n_harmonics for stress
    def reconstruct_signal(An, Bn, max_harmonic, fundamental_only=False):
        """Reconstruct signal from Fourier coefficients (vectorized)."""
        if fundamental_only:
            idx = p
            cos_term = jnp.cos(omega * time_new)
            sin_term = jnp.sin(omega * time_new)
            an = jnp.where(idx < len(An), An[jnp.minimum(idx, len(An) - 1)], 0.0)
            bn = jnp.where(idx < len(Bn), Bn[jnp.minimum(idx, len(Bn) - 1)], 0.0)
            return an * cos_term + bn * sin_term

        # Odd harmonics: vectorized over harmonic indices; coefficients for
        # harmonic n live at FFT bin n*p, zero-padded when out of range
        harmonics = jnp.arange(1, max_harmonic + 1, 2)
        indices = p * harmonics
        an = jnp.where(indices < len(An), An[jnp.minimum(indices, len(An) - 1)], 0.0)
        bn = jnp.where(indices < len(Bn), Bn[jnp.minimum(indices, len(Bn) - 1)], 0.0)
        phase = harmonics[:, None] * omega * time_new[None, :]  # (H, L)
        return jnp.sum(
            an[:, None] * jnp.cos(phase) + bn[:, None] * jnp.sin(phase), axis=0
        )

    strain_recon = reconstruct_signal(An_strain, Bn_strain, 1, fundamental_only=True)
    rate_recon = reconstruct_signal(An_rate, Bn_rate, 1, fundamental_only=True)
    stress_recon = reconstruct_signal(An_stress, Bn_stress, n_harmonics)

    # Fourier amplitude spectrum for stress (MATLAB ft_out), normalized by the
    # fundamental (row 1) with a floor to avoid division by zero
    W_idx = int(W)
    k_indices = jnp.arange(W_idx + 1) * p
    stress_fft_scaled = fft_stress / L
    ft_amp = 2 * jnp.abs(stress_fft_scaled[k_indices])
    ft_amp = ft_amp / jnp.maximum(ft_amp[1], 1e-20)
    f_domain = jnp.arange(W_idx + 1, dtype=jnp.float64) * (omega / (2 * jnp.pi))
    ft_out = jnp.stack([f_domain, ft_amp], axis=1)

    return {
        "Delta": Delta,
        "An_strain": An_strain,
        "Bn_strain": Bn_strain,
        "An_rate": An_rate,
        "Bn_rate": Bn_rate,
        "An_stress": An_stress,
        "Bn_stress": Bn_stress,
        "strain_recon": strain_recon,
        "rate_recon": rate_recon,
        "stress_recon": stress_recon,
        "time_new": time_new,
        "ft_out": ft_out,
    }
def _phase_angle_rate(sigma_d_n, sigma_dd_n, sigma_ddd_n, eps=1e-20):
    """
    Rate of change of the SPP phase angle delta(t) = arctan2(G''_t, G'_t).

    The inputs are the first, second and third time derivatives of the
    stress, divided by omega, omega**2 and omega**3 respectively. These are
    derivatives with respect to the phase omega*t.

    This is valid when the strain is a single positive-amplitude sinusoid. In
    that case G'_t + i*G''_t is proportional to
    exp(-i*omega*t) * (sigma' - i*sigma''), so

        d(delta)/d(omega t) = -1 + d/d(omega t) arctan2(-sigma'', sigma')
                            = -sigma' * (sigma''' + sigma') / (sigma'^2 + sigma''^2)

    The result vanishes identically for a linear viscoelastic response.
    ``eps`` floors the denominator.

    Returns d(delta)/d(omega t) in radians per radian of phase.
    """
    denom = jnp.maximum(sigma_d_n**2 + sigma_dd_n**2, eps)
    return -sigma_d_n * (sigma_ddd_n + sigma_d_n) / denom


@partial(jax.jit, static_argnums=(4, 5))
def spp_fourier_analysis(
    strain: "Array",
    stress: "Array",
    omega: float,
    dt: float,
    n_harmonics: int = 39,
    n_cycles: int = 1,
) -> dict:
    """
    Complete SPP analysis using Fourier-based analytical derivatives (MATLAB-compatible).

    Implements the full workflow from SPPplus_fourier_v2.m:
    1. FFT strain and stress signals
    2. Compute phase offset and rotate coefficients
    3. Compute derivatives ANALYTICALLY from Fourier coefficients
    4. Calculate G'(t), G''(t) via cross-product formula
    5. Extract all SPP metrics including moduli rates and Frenet-Serret frame

    This is more accurate than numerical differentiation for noisy data.

    Parameters
    ----------
    strain : Array
        Strain signal γ(t) (dimensionless)
    stress : Array
        Stress signal σ(t) (Pa)
    omega : float
        Angular frequency ω (rad/s)
    dt : float
        Time step (s), used to differentiate strain into strain rate
    n_harmonics : int
        Highest odd harmonic used for stress reconstruction (default: 39).
        Static under jit.
    n_cycles : int
        Number of complete cycles in data (default: 1). Static under jit.

    Returns
    -------
    dict
        Dictionary containing all SPP metrics:
        - Gp_t: Instantaneous G'(t) (Pa); NaN where the binormal z-component vanishes
        - Gpp_t: Instantaneous G''(t) (Pa); NaN where the binormal z-component vanishes
        - G_star_t: Complex modulus ``|G*(t)|`` (Pa)
        - tan_delta_t: Loss tangent tan(δ)(t)
        - delta_t: Phase angle δ(t) = arctan2(G'', G') (radians)
        - disp_stress: Displacement stress (Pa)
        - eq_strain_est: Equivalent strain estimate
        - Gp_t_dot: Time derivative of G'(t) (Pa/s)
        - Gpp_t_dot: Time derivative of G''(t) (Pa/s)
        - G_speed: Moduli rate magnitude (Pa/s)
        - delta_t_dot: Phase angle rate dδ/d(ωt) (radians per radian of
          phase; multiply by ω for rad/s)
        - T_vec, N_vec, B_vec: Frenet-Serret frame vectors
        - strain_recon, rate_recon, stress_recon: Reconstructed waveforms
        - time_new: Time array of the reconstruction
        - Delta: Phase offset used
        - fsf_data_out: [T_vec, N_vec, B_vec] concatenated column-wise
        - ft_out: Normalized stress amplitude spectrum

    Notes
    -----
    ANALYTICAL derivatives from Fourier series:
        f(t) = Σ [An*cos(nωt) + Bn*sin(nωt)]
        f'(t) = Σ [-nω*An*sin(nωt) + nω*Bn*cos(nωt)]
        f''(t) = Σ [-n²ω²*An*cos(nωt) - n²ω²*Bn*sin(nωt)]
        f'''(t) = Σ [n³ω³*An*sin(nωt) - n³ω³*Bn*cos(nωt)]
    """
    logger.info(
        "Starting SPP Fourier analysis",
        omega=omega,
        dt=dt,
        n_harmonics=n_harmonics,
        n_cycles=n_cycles,
    )

    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))

    L = len(strain_arr)
    p = int(n_cycles)
    W_int = int(round(L / (2 * p)))

    logger.debug(
        "SPP Fourier analysis input data",
        signal_length=L,
        n_cycles_parsed=p,
        window_size=W_int,
    )

    # Compute strain rate from strain (wrapped 8-point stencil)
    logger.debug("Computing strain rate from strain (8-point stencil)")
    strain_rate = differentiate_rate_from_strain(
        strain_arr, dt, step_size=8, looped=True
    )

    # Get phase-aligned reconstruction with concrete W
    logger.debug("Performing phase-aligned Fourier reconstruction")
    fourier_result = harmonic_reconstruction_full(
        strain_arr, strain_rate, stress_arr, omega, n_harmonics, p, W_int
    )

    Delta = fourier_result["Delta"]
    An_strain = fourier_result["An_strain"]
    Bn_strain = fourier_result["Bn_strain"]
    An_rate = fourier_result["An_rate"]
    Bn_rate = fourier_result["Bn_rate"]
    An_stress = fourier_result["An_stress"]
    Bn_stress = fourier_result["Bn_stress"]
    time_new = fourier_result["time_new"]

    def compute_derivatives_from_fourier(An, Bn, max_harmonic):
        """Compute signal and its 1st, 2nd, 3rd derivatives from Fourier coefficients.

        Vectorized over harmonics using broadcasting: harmonic indices (H,)
        broadcast against time points (L,) to produce (H, L) intermediates,
        then summed along the harmonic axis.
        """
        # Odd harmonic orders [1, 3, ..., max_harmonic] live at FFT bins n*p
        harmonics = jnp.arange(1, max_harmonic + 1, 2)
        indices = p * harmonics

        # Gather coefficients, zero-padding for out-of-range indices
        an = jnp.where(indices < len(An), An[jnp.minimum(indices, len(An) - 1)], 0.0)
        bn = jnp.where(indices < len(Bn), Bn[jnp.minimum(indices, len(Bn) - 1)], 0.0)

        n_omega = harmonics * omega  # (H,)
        phase = n_omega[:, None] * time_new[None, :]  # (H, L)
        cos_terms = jnp.cos(phase)
        sin_terms = jnp.sin(phase)

        # f(t) = sum_n [An*cos + Bn*sin]
        f = jnp.sum(an[:, None] * cos_terms + bn[:, None] * sin_terms, axis=0)

        # f'(t) = sum_n [-nω*An*sin + nω*Bn*cos]
        fd = jnp.sum(
            n_omega[:, None] * (-an[:, None] * sin_terms + bn[:, None] * cos_terms),
            axis=0,
        )

        # f''(t) = sum_n [-n²ω²*An*cos - n²ω²*Bn*sin]
        n_omega2 = (n_omega**2)[:, None]
        fdd = jnp.sum(
            -n_omega2 * (an[:, None] * cos_terms + bn[:, None] * sin_terms), axis=0
        )

        # f'''(t) = sum_n [n³ω³*An*sin - n³ω³*Bn*cos]
        n_omega3 = (n_omega**3)[:, None]
        fddd = jnp.sum(
            n_omega3 * (an[:, None] * sin_terms - bn[:, None] * cos_terms), axis=0
        )

        return f, fd, fdd, fddd

    # Strain and rate: fundamental only; stress: odd harmonics up to n_harmonics
    strain_recon, strain_d, strain_dd, strain_ddd = compute_derivatives_from_fourier(
        An_strain, Bn_strain, 1
    )
    rate_recon, rate_d, rate_dd, rate_ddd = compute_derivatives_from_fourier(
        An_rate, Bn_rate, 1
    )
    stress_recon, stress_d, stress_dd, stress_ddd = compute_derivatives_from_fourier(
        An_stress, Bn_stress, n_harmonics
    )

    # Response-curve derivatives in [γ, γ̇/ω, σ] space
    # (rate_recon is already γ̇/ω because the rate was normalized upstream)
    rd = jnp.stack([strain_d, rate_d, stress_d], axis=1)
    rdd = jnp.stack([strain_dd, rate_dd, stress_dd], axis=1)
    rddd = jnp.stack([strain_ddd, rate_ddd, stress_ddd], axis=1)

    # Binormal direction: rd × rdd
    rd_x_rdd = jnp.stack(
        [
            rd[:, 1] * rdd[:, 2] - rd[:, 2] * rdd[:, 1],
            rd[:, 2] * rdd[:, 0] - rd[:, 0] * rdd[:, 2],
            rd[:, 0] * rdd[:, 1] - rd[:, 1] * rdd[:, 0],
        ],
        axis=1,
    )

    # rd × (rd × rdd), used for the normal vector
    rd_x_rd_x_rdd = jnp.stack(
        [
            rd[:, 1] * rd_x_rdd[:, 2] - rd[:, 2] * rd_x_rdd[:, 1],
            rd[:, 2] * rd_x_rdd[:, 0] - rd[:, 0] * rd_x_rdd[:, 2],
            rd[:, 0] * rd_x_rdd[:, 1] - rd[:, 1] * rd_x_rdd[:, 0],
        ],
        axis=1,
    )

    # Magnitudes (+ 1e-30 guards sqrt(0) infinite gradient)
    eps = 1e-20
    mag_rd = jnp.sqrt(jnp.sum(rd**2, axis=1) + 1e-30)
    mag_rd_x_rdd = jnp.sqrt(jnp.sum(rd_x_rdd**2, axis=1) + 1e-30)

    # R11-SPP-KRN-001: Avoid sign(0)=0 at Frenet degeneracy
    binormal_z = rd_x_rdd[:, 2]
    denom_valid = jnp.abs(binormal_z) > eps
    Gp_t = jnp.where(
        denom_valid,
        -rd_x_rdd[:, 0] / binormal_z,
        jnp.nan,  # Explicit NaN at degeneracy for callers to handle
    )
    Gpp_t = jnp.where(
        denom_valid,
        -rd_x_rdd[:, 1] / binormal_z,
        jnp.nan,  # Explicit NaN at degeneracy for callers to handle
    )

    # Moduli rates (MATLAB formula)
    # Gp_t_dot = -rd[:,1] * (rddd · rd_x_rdd) / rd_x_rdd[:,2]²
    # Gpp_t_dot = rd[:,0] * (rddd · rd_x_rdd) / rd_x_rdd[:,2]²
    rddd_dot_rd_x_rdd = jnp.sum(rddd * rd_x_rdd, axis=1)
    denom_sq = binormal_z**2

    Gp_t_dot = jnp.where(
        denom_valid,
        -rd[:, 1] * rddd_dot_rd_x_rdd / jnp.maximum(denom_sq, eps),
        jnp.nan,
    )
    Gpp_t_dot = jnp.where(
        denom_valid,
        rd[:, 0] * rddd_dot_rd_x_rdd / jnp.maximum(denom_sq, eps),
        jnp.nan,
    )
    G_speed = jnp.sqrt(Gp_t_dot**2 + Gpp_t_dot**2 + 1e-30)

    # Complex modulus and phase angle
    G_star_t = jnp.sqrt(Gp_t**2 + Gpp_t**2 + 1e-30)
    tan_delta_t = Gpp_t / jnp.maximum(jnp.abs(Gp_t), eps) * jnp.sign(Gp_t)

    # Phase angle via arctan2 — handles all four quadrants correctly
    delta_t = jnp.arctan2(Gpp_t, Gp_t)

    # Phase angle rate, from ω-normalized stress derivatives (strain is a pure
    # fundamental sinusoid here, which the closed form requires)
    rd_tn = rd / omega
    rdd_tn = rdd / omega**2
    rddd_tn = rddd / omega**3
    delta_t_dot = _phase_angle_rate(rd_tn[:, 2], rdd_tn[:, 2], rddd_tn[:, 2], eps)

    # Displacement stress: sigma_d = sigma - G'_t * gamma - G''_t/omega * gamma_dot
    disp_stress = stress_recon - (Gp_t * strain_recon + Gpp_t * rate_recon)

    # Equilibrium strain: gamma_eq = gamma - sigma_d / G'_t (Rogers 2017).
    # abs(G'_t) prevents sign inversion near yielding where G'_t -> 0;
    # the estimate is physically ill-defined in that regime.
    eq_strain_est = strain_recon - disp_stress / jnp.maximum(jnp.abs(Gp_t), eps)

    # Frenet-Serret frame
    T_vec = rd / jnp.maximum(mag_rd[:, None], eps)
    N_vec = -rd_x_rd_x_rdd / jnp.maximum((mag_rd * mag_rd_x_rdd)[:, None], eps)
    B_vec = rd_x_rdd / jnp.maximum(mag_rd_x_rdd[:, None], eps)

    return {
        # Core SPP metrics
        "Gp_t": Gp_t,
        "Gpp_t": Gpp_t,
        "G_star_t": G_star_t,
        "tan_delta_t": tan_delta_t,
        "delta_t": delta_t,
        "disp_stress": disp_stress,
        "eq_strain_est": eq_strain_est,
        # Moduli rates
        "Gp_t_dot": Gp_t_dot,
        "Gpp_t_dot": Gpp_t_dot,
        "G_speed": G_speed,
        "delta_t_dot": delta_t_dot,
        # Frenet-Serret frame
        "T_vec": T_vec,
        "N_vec": N_vec,
        "B_vec": B_vec,
        # Reconstructed waveforms
        "strain_recon": strain_recon,
        "rate_recon": rate_recon,
        "stress_recon": stress_recon,
        "time_new": time_new,
        # Phase alignment
        "Delta": Delta,
        # FSF and spectrum
        "fsf_data_out": jnp.concatenate([T_vec, N_vec, B_vec], axis=1),
        "ft_out": fourier_result["ft_out"],
    }
# ============================================================================ # Power-Law Fitting # ============================================================================
@jax.jit
def power_law_fit(
    stress: "Array",
    strain_rate: "Array",
    threshold_fraction: float = 0.1,
) -> tuple[float, float, float]:
    """Fit sigma = K * strain_rate ** n on the positive flowing branch.

    Only first-quadrant samples are used: those with positive stress and a
    strain rate above ``threshold_fraction * max(abs(strain_rate))``. An
    ordinary least-squares line is fitted in log-log space.

    Returns ``(K, n, r_squared)``. With fewer than two distinct rates, the
    exponent defaults to 1.0. With no usable samples, K = 1.0, and r_squared
    is 0.0 when the log-stress spread is negligible.
    """
    sigma = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))
    rate = jnp.atleast_1d(jnp.asarray(strain_rate, dtype=jnp.float64))

    # First-quadrant samples above a fraction of the peak |rate|
    rate_cutoff = threshold_fraction * jnp.max(jnp.abs(rate))
    in_fit = (rate > rate_cutoff) & (sigma > 0)
    count = jnp.sum(in_fit)

    # Log coordinates; excluded samples contribute 0 to every sum below
    x = jnp.where(in_fit, jnp.log(jnp.maximum(rate, 1e-30)), 0.0)
    y = jnp.where(in_fit, jnp.log(jnp.maximum(sigma, 1e-30)), 0.0)

    sx, sy = jnp.sum(x), jnp.sum(y)
    sxx, sxy = jnp.sum(x**2), jnp.sum(x * y)

    # Normal-equation solution for y = intercept + slope * x
    normal_det = count * sxx - sx**2
    slope = jnp.where(normal_det > 1e-10, (count * sxy - sx * sy) / normal_det, 1.0)
    intercept = jnp.where(count > 0, (sy - slope * sx) / count, 0.0)

    # Coefficient of determination in log space
    y_bar = sy / jnp.maximum(count, 1.0)
    residual = y - (intercept + slope * x)
    ss_tot = jnp.sum(jnp.where(in_fit, (y - y_bar) ** 2, 0.0))
    ss_res = jnp.sum(jnp.where(in_fit, residual**2, 0.0))
    r_squared = jnp.where(ss_tot > 1e-10, 1.0 - ss_res / ss_tot, 0.0)

    return jnp.exp(intercept), slope, r_squared
# ============================================================================ # Lissajous-Bowditch Metrics # ============================================================================
@jax.jit
def lissajous_metrics(
    stress: "Array",
    strain: "Array",
    strain_rate: "Array",
    strain_amplitude: float,
    rate_amplitude: float,
) -> dict:
    """
    Compute Lissajous-Bowditch diagram derived metrics.

    Windowed estimates of the large/minimum-strain moduli and
    large/minimum-rate viscosities, plus the stiffening (S) and thickening
    (T) ratios built from them.

    Parameters
    ----------
    stress : Array
        Time-resolved stress signal σ(t) (Pa)
    strain : Array
        Time-resolved strain signal γ(t) (dimensionless)
    strain_rate : Array
        Time-resolved strain rate signal γ̇(t) (1/s)
    strain_amplitude : float
        Maximum strain amplitude γ_0 (dimensionless)
    rate_amplitude : float
        Maximum strain rate amplitude γ̇_0 = ω * γ_0 (1/s)

    Returns
    -------
    dict
        - G_L: mean |σ| over samples with |γ| >= 0.98 γ_0, divided by γ_0
        - G_M: mean periodic central-difference dσ/dγ over |γ| <= 0.02 γ_0
        - eta_L: mean |σ| over samples with |γ̇| >= 0.98 γ̇_0, divided by γ̇_0
        - eta_M: mean periodic central-difference dσ/dγ̇ over |γ̇| <= 0.02 γ̇_0
        - S_factor: (G_L - G_M) / G_L, or 0 when |G_L| <= 1e-10
        - T_factor: (η_L - η_M) / η_L, or 0 when |η_L| <= 1e-10

    Empty windows give a mean of 0. Local slopes whose denominator
    difference is <= 1e-10 in magnitude count as 0.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import lissajous_metrics
    >>>
    >>> omega = 1.0
    >>> t = jnp.linspace(0, 2*jnp.pi, 1000)
    >>> gamma_0 = 1.0
    >>> gamma = gamma_0 * jnp.sin(omega * t)
    >>> gamma_dot = gamma_0 * omega * jnp.cos(omega * t)
    >>> sigma = 100.0 * gamma + 10.0 * gamma_dot  # Linear viscoelastic
    >>>
    >>> metrics = lissajous_metrics(sigma, gamma, gamma_dot, gamma_0, gamma_0 * omega)

    Notes
    -----
    - S > 0: strain stiffening, S < 0: strain softening
    - T > 0: shear thickening, T < 0: shear thinning
    - Ideally S = T = 0 for a linear viscoelastic response. The finite
      windows used here only approximate that, because elastic stress leaks
      into the large-rate window and vice versa.
    """
    sigma = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))
    gamma = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    gamma_dot = jnp.atleast_1d(jnp.asarray(strain_rate, dtype=jnp.float64))
    gamma_0 = jnp.float64(strain_amplitude)
    gamma_dot_0 = jnp.float64(rate_amplitude)

    def masked_mean(values, mask):
        # Mean over the mask; 0.0 when the mask selects nothing
        return jnp.sum(jnp.where(mask, values, 0.0)) / jnp.maximum(jnp.sum(mask), 1.0)

    def periodic_central_diff(x):
        # x[i+1] - x[i-1] with wrap-around at the array ends
        return jnp.roll(x, -1) - jnp.roll(x, 1)

    def guarded_slope(num, den):
        return jnp.where(jnp.abs(den) > 1e-10, num / den, 0.0)

    def relative_excess(large, minimal):
        return jnp.where(jnp.abs(large) > 1e-10, (large - minimal) / large, 0.0)

    abs_sigma = jnp.abs(sigma)
    d_sigma = periodic_central_diff(sigma)

    # Elastic measures: secant at max strain, tangent at zero strain
    G_L = masked_mean(abs_sigma, jnp.abs(gamma) >= 0.98 * gamma_0) / gamma_0
    G_M = masked_mean(
        guarded_slope(d_sigma, periodic_central_diff(gamma)),
        jnp.abs(gamma) <= 0.02 * gamma_0,
    )

    # Viscous measures: secant at max rate, tangent at zero rate
    eta_L = masked_mean(abs_sigma, jnp.abs(gamma_dot) >= 0.98 * gamma_dot_0) / gamma_dot_0
    eta_M = masked_mean(
        guarded_slope(d_sigma, periodic_central_diff(gamma_dot)),
        jnp.abs(gamma_dot) <= 0.02 * gamma_dot_0,
    )

    return {
        "G_L": G_L,
        "G_M": G_M,
        "eta_L": eta_L,
        "eta_M": eta_M,
        "S_factor": relative_excess(G_L, G_M),
        "T_factor": relative_excess(eta_L, eta_M),
    }
# ============================================================================ # Zero-Crossing Detection (Robust) # ============================================================================
@jax.jit
def zero_crossing_indices(signal: "Array") -> "Array":
    """
    Find indices of zero-crossings in a signal.

    A crossing is reported between samples i and i+1 when the signal changes
    from strictly positive to strictly negative (or vice versa). Samples that
    are exactly zero inherit the sign of the most recent non-zero sample. This
    means a signal passing *through* zero (e.g. ``[1, 0, -1]``) is detected,
    while one that merely touches zero (``[1, 0, 1]``) is not.

    Parameters
    ----------
    signal : Array
        Input signal to analyze for zero-crossings

    Returns
    -------
    Array
        Boolean mask of the same length as ``signal``. Entry i is True when
        a crossing occurs between samples i and i+1; the last entry is always
        False.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import zero_crossing_indices
    >>>
    >>> signal = jnp.sin(jnp.linspace(0, 4*jnp.pi, 100))
    >>> crossings = zero_crossing_indices(signal)
    >>> # crossings is True at indices where sin crosses zero

    Notes
    -----
    - No interpolation or hysteresis filtering is applied; noisy signals may
      report multiple crossings.
    - Leading zeros (before any non-zero sample) never produce a crossing.
    - NaN samples never participate in a crossing.
    """
    signal_arr = jnp.atleast_1d(jnp.asarray(signal, dtype=jnp.float64))
    n = signal_arr.shape[0]
    if n == 0:  # static shape under jit
        return jnp.zeros((0,), dtype=bool)

    signs = jnp.sign(signal_arr)

    # Forward-fill the sign of the last non-zero sample so exact zeros do not
    # hide a crossing (sign difference of 1 instead of 2).
    positions = jnp.arange(n)
    last_nonzero = jax.lax.cummax(jnp.where(signs != 0, positions, -1), axis=0)
    carried = jnp.where(
        last_nonzero >= 0,
        signs[jnp.maximum(last_nonzero, 0)],
        0.0,
    )

    # Opposite strict signs -> negative product
    flips = carried[:-1] * carried[1:] < 0

    # Pad to match original length
    return jnp.concatenate([flips, jnp.zeros((1,), dtype=bool)])
@jax.jit
def harmonic_truncation_robustness(
    amplitudes: "Array",
    n_harmonics_original: int,
    n_harmonics_truncated: int,
) -> float:
    """
    Compute truncation error metric for harmonic decomposition.

    Quantifies how much signal energy is lost when truncating to fewer
    harmonics, enabling assessment of reconstruction quality.

    Parameters
    ----------
    amplitudes : Array
        Full set of harmonic amplitudes [A_1, A_3, A_5, ...]
    n_harmonics_original : int
        Original number of harmonics. Accepted for API compatibility; the
        total energy is always taken over all entries of ``amplitudes``.
    n_harmonics_truncated : int
        Number of leading harmonics to keep after truncation. Values <= 0
        keep nothing; values >= len(amplitudes) keep everything.

    Returns
    -------
    float
        Fraction of total energy retained after truncation (0 to 1), or 1.0
        when the total energy is (numerically) zero.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import harmonic_truncation_robustness
    >>>
    >>> # Fundamental dominant with small 3rd harmonic
    >>> amps = jnp.array([100.0, 10.0, 2.0, 0.5, 0.1])
    >>> robustness = harmonic_truncation_robustness(amps, 5, 2)
    >>> # robustness ≈ 0.9996 (most energy in first 2 harmonics)

    Notes
    -----
    - Value near 1.0 indicates safe truncation
    - Value < 0.95 suggests significant information loss
    - Useful for adaptive harmonic selection
    - Uses an index mask rather than slicing, so ``n_harmonics_truncated`` may
      be a traced value under ``jax.jit``; a dynamic slice bound would fail to
      trace.
    """
    amplitudes_arr = jnp.atleast_1d(jnp.asarray(amplitudes, dtype=jnp.float64))

    # Energy per harmonic (squared amplitude)
    energy = amplitudes_arr**2
    total_energy = jnp.sum(energy)

    # Energy in the leading n_harmonics_truncated harmonics
    keep = jnp.arange(amplitudes_arr.shape[0]) < n_harmonics_truncated
    retained_energy = jnp.sum(jnp.where(keep, energy, 0.0))

    # Fraction retained
    robustness = jnp.where(
        total_energy > 1e-20,
        retained_energy / total_energy,
        1.0,
    )

    return robustness
# ============================================================================ # SPP Stress Decomposition # ============================================================================
@jax.jit
def spp_stress_decomposition(
    stress: "Array",
    strain: "Array",
    strain_rate: "Array",
    strain_amplitude: float,
    rate_amplitude: float,
) -> tuple["Array", "Array"]:
    """
    Decompose total stress into elastic and viscous contributions via linear projection.

    The stress is fitted by least squares as σ ≈ a·(γ/γ_0) + b·(γ̇/γ̇_0), and the
    residual is distributed symmetrically. This is a linear projection, NOT the
    full SPP Frenet-Serret decomposition. For the full SPP decomposition (which
    includes the displacement stress sigma_d via the osculating plane), use
    :func:`spp_fourier_analysis` instead.

    Separates σ(t) = σ'(t) + σ''(t) where:
    - σ'(t): Elastic (in-phase with strain) component
    - σ''(t): Viscous (in-phase with strain rate) component

    Parameters
    ----------
    stress : Array
        Time-resolved stress signal σ(t) (Pa)
    strain : Array
        Time-resolved strain signal γ(t) (dimensionless)
    strain_rate : Array
        Time-resolved strain rate signal γ̇(t) (1/s)
    strain_amplitude : float
        Maximum strain amplitude γ_0 (dimensionless)
    rate_amplitude : float
        Maximum strain rate amplitude γ̇_0 = ω * γ_0 (1/s)

    Returns
    -------
    sigma_elastic : Array
        Elastic stress contribution σ'(t) (Pa)
    sigma_viscous : Array
        Viscous stress contribution σ''(t) (Pa)

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import spp_stress_decomposition
    >>>
    >>> omega = 1.0
    >>> t = jnp.linspace(0, 2*jnp.pi, 1000)
    >>> gamma_0 = 1.0
    >>> gamma = gamma_0 * jnp.sin(omega * t)
    >>> gamma_dot = gamma_0 * omega * jnp.cos(omega * t)
    >>> # G' = 100 Pa, G'' = 50 Pa
    >>> sigma = 100.0 * gamma + 50.0 * gamma_dot / omega
    >>>
    >>> sigma_e, sigma_v = spp_stress_decomposition(
    ...     sigma, gamma, gamma_dot, gamma_0, gamma_0 * omega
    ... )
    >>> # sigma_e ≈ 100 * gamma (elastic)
    >>> # sigma_v ≈ 50 * gamma_dot / omega (viscous)

    Notes
    -----
    - Both coefficients are solved jointly from the 2x2 normal equations. The
      split is therefore exact for σ = a·γ + b·γ̇ even when the record does not
      cover whole cycles, where γ and γ̇ are not orthogonal over the samples.
      When the cross term Σ(γ/γ_0)(γ̇/γ̇_0) vanishes, this reduces to independent
      projections onto each direction.
    - If the normal equations are (near-)singular, e.g. zero strain or
      collinear strain/rate, the independent projections are used instead.
    - For nonlinear responses, the residual (displacement stress) is split
      equally between the elastic and viscous components.
    - For linear viscoelastic: σ_e = G' * γ, σ_v = (G'' / ω) * γ̇
    - Decomposition satisfies σ = σ_e + σ_v at all times
    """
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))
    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    rate_arr = jnp.atleast_1d(jnp.asarray(strain_rate, dtype=jnp.float64))

    gamma_0 = jnp.float64(strain_amplitude)
    rate_0 = jnp.float64(rate_amplitude)

    # Normalize strain and rate
    gamma_norm = strain_arr / gamma_0
    rate_norm = rate_arr / rate_0

    # Gram matrix entries and right-hand side of the normal equations
    s_gg = jnp.sum(gamma_norm**2)
    s_rr = jnp.sum(rate_norm**2)
    s_gr = jnp.sum(gamma_norm * rate_norm)
    b_g = jnp.sum(stress_arr * gamma_norm)
    b_r = jnp.sum(stress_arr * rate_norm)

    # Joint least-squares solution (Cramer's rule) when well conditioned
    det = s_gg * s_rr - s_gr**2
    well_posed = det > 1e-8 * s_gg * s_rr
    safe_det = jnp.where(well_posed, det, 1.0)
    proj_elastic_joint = (s_rr * b_g - s_gr * b_r) / safe_det
    proj_viscous_joint = (s_gg * b_r - s_gr * b_g) / safe_det

    # Independent projections (fallback for degenerate bases)
    proj_elastic_sep = b_g / jnp.maximum(s_gg, 1e-10)
    proj_viscous_sep = b_r / jnp.maximum(s_rr, 1e-10)

    proj_elastic = jnp.where(well_posed, proj_elastic_joint, proj_elastic_sep)
    proj_viscous = jnp.where(well_posed, proj_viscous_joint, proj_viscous_sep)

    sigma_elastic = proj_elastic * gamma_norm
    sigma_viscous = proj_viscous * rate_norm

    # Ensure decomposition is exact by distributing residual
    residual = stress_arr - sigma_elastic - sigma_viscous

    # Add half of residual to each (symmetric distribution)
    sigma_elastic = sigma_elastic + 0.5 * residual
    sigma_viscous = sigma_viscous + 0.5 * residual

    return sigma_elastic, sigma_viscous
# ============================================================================ # Numerical Differentiation (MATLAB-Compatible) # ============================================================================
# (offset, weight) taps of the 4th-order centred interior stencils, keyed by
# derivative order, plus the constant c of the denominator c * h**order.
# Taps are listed in the same order as the SPPplus formulas, so the
# floating-point summation order matches the hand-expanded expressions.
_FD4_CENTRED_TAPS = {
    1: (((2, -1.0), (1, 8.0), (-1, -8.0), (-2, 1.0)), 12.0),
    2: (((2, -1.0), (1, 16.0), (0, -30.0), (-1, 16.0), (-2, -1.0)), 12.0),
    3: (((3, -1.0), (2, 8.0), (1, -13.0), (-1, 13.0), (-2, -8.0), (-3, 1.0)), 8.0),
}

# 2nd-order forward (one-sided) taps used near the start, the denominator
# constant, and the boundary width in units of the step size k. Backward taps
# are the mirror image: offsets negated and weights multiplied by (-1)**order.
_FD4_FORWARD_TAPS = {
    1: (((2, -1.0), (1, 4.0), (0, -3.0)), 2.0, 3),
    2: (((3, -1.0), (2, 4.0), (1, -5.0), (0, 2.0)), 1.0, 3),
    3: (((4, -3.0), (3, 14.0), (2, -24.0), (1, 18.0), (0, -5.0)), 2.0, 4),
}


def _fd4_rolled_sum(x, taps, k):
    """Return sum(w * x[i + off*k]) via circular shifts (caller trims wrapped edges)."""
    total = None
    for offset, weight in taps:
        shifted = x if offset == 0 else jnp.roll(x, -offset * k)
        term = weight * shifted
        total = term if total is None else total + term
    return total


def _fd4_gathered_sum(x, idx, taps, k):
    """Return sum(w * x[idx + off*k]) with gather indices clamped into range."""
    last = x.shape[0] - 1
    total = None
    for offset, weight in taps:
        term = weight * x[jnp.clip(idx + offset * k, 0, last)]
        total = term if total is None else total + term
    return total


@partial(jax.jit, static_argnums=(2, 3))
def numerical_derivative_4th_order(
    signal: "Array",
    dt: float,
    order: int = 1,
    step_size: int = 1,
) -> "Array":
    """
    Compute numerical derivatives using 4th-order finite differences (MATLAB SPPplus compatible).

    Implements the finite-difference schemes from SPPplus_numerical_v2.m:
    - 4th-order centered differences in the interior
    - Forward differences at the beginning boundary
    - Backward differences at the ending boundary

    This matches MATLAB's "standard" differentiation mode (num_mode=1).

    Parameters
    ----------
    signal : Array
        Input signal to differentiate (1D array)
    dt : float
        Time step between samples (s)
    order : int, optional
        Derivative order: 1, 2, or 3 (default: 1). ``order=0`` returns the
        signal unchanged (as a float64 array). Must be static under jit.
    step_size : int, optional
        Step size k for stencil (default: 1, larger = more smoothing).
        Must be a positive integer; static under jit.

    Returns
    -------
    Array
        Numerical derivative of same length as input (4th-order accurate in interior)

    Raises
    ------
    ValueError
        If ``order`` is not 0, 1, 2 or 3, or if ``step_size < 1``.

    Notes
    -----
    MATLAB SPPplus_numerical_v2.m stencils (mode 1):

    First derivative (interior, 4th order):
        rd = (-f[p+2k] + 8*f[p+k] - 8*f[p-k] + f[p-2k]) / (12*k*dt)

    Second derivative (interior, 4th order):
        rdd = (-f[p+2k] + 16*f[p+k] - 30*f[p] + 16*f[p-k] - f[p-2k]) / (12*(k*dt)^2)

    Third derivative (interior, 4th order):
        rddd = (-f[p+3k] + 8*f[p+2k] - 13*f[p+k] + 13*f[p-k] - 8*f[p-2k] + f[p-3k]) / (8*(k*dt)^3)

    Forward/backward stencils at boundaries use 2nd-order accurate formulas.
    They replace the first/last 3k samples (4k for order 3). Where a one-sided
    stencil would itself run off the signal (very short inputs), the
    edge-padded centred value is kept.
    """
    signal_arr = jnp.atleast_1d(jnp.asarray(signal, dtype=jnp.float64))

    if order == 0:
        return signal_arr
    if order not in _FD4_CENTRED_TAPS:
        raise ValueError(f"order must be 0, 1, 2 or 3; got {order!r}")
    if step_size < 1:
        raise ValueError(f"step_size must be a positive integer; got {step_size!r}")

    L = signal_arr.shape[0]
    k = step_size
    h = dt * k

    # Pad with edge values so the centred stencil is defined everywhere;
    # the padded margin (4k) exceeds the widest stencil reach (3k), so the
    # circular shifts never wrap into the retained window.
    pad_size = 4 * k
    signal_padded = jnp.pad(signal_arr, pad_size, mode="edge")

    centred_taps, centred_const = _FD4_CENTRED_TAPS[order]
    result_padded = _fd4_rolled_sum(signal_padded, centred_taps, k) / (
        centred_const * h**order
    )
    result = result_padded[pad_size : pad_size + L]

    # Replace the edge-contaminated samples with one-sided 2nd-order stencils
    fwd_taps, one_sided_const, zone = _FD4_FORWARD_TAPS[order]
    reach = max(offset for offset, _ in fwd_taps) * k
    boundary_k = min(zone * k, L - 1)
    one_sided_denom = one_sided_const * h**order

    # Forward differences at the start
    fwd_idx = jnp.arange(boundary_k)
    fwd_valid = fwd_idx + reach < L
    fwd_vals = _fd4_gathered_sum(signal_arr, fwd_idx, fwd_taps, k) / one_sided_denom
    result = result.at[fwd_idx].set(jnp.where(fwd_valid, fwd_vals, result[fwd_idx]))

    # Backward differences at the end (mirrored forward stencil)
    bwd_taps = tuple(
        (-offset, weight * (-1) ** order) for offset, weight in fwd_taps
    )
    bwd_idx = jnp.arange(L - boundary_k, L)
    bwd_valid = bwd_idx - reach >= 0
    bwd_vals = _fd4_gathered_sum(signal_arr, bwd_idx, bwd_taps, k) / one_sided_denom
    result = result.at[bwd_idx].set(jnp.where(bwd_valid, bwd_vals, result[bwd_idx]))

    return result
#: Alias for :func:`numerical_derivative_4th_order`. #: Kept for backwards compatibility and MATLAB SPPplus naming parity. numerical_derivative = numerical_derivative_4th_order
@partial(jax.jit, static_argnums=(2,))
def numerical_derivative_periodic(
    signal: "Array",
    dt: float,
    step_size: int = 1,
) -> tuple["Array", "Array", "Array"]:
    """
    Compute 1st, 2nd, and 3rd derivatives assuming periodic signal (MATLAB "looped" mode).

    Every sample uses a centred 4th-order stencil. Neighbours beyond either end
    are taken from the opposite end (circular wrap-around), which is appropriate
    when the record holds whole periods of a steady-state LAOS oscillation.
    Matches MATLAB SPPplus_numerical_v2.m "looped" differentiation mode.

    Parameters
    ----------
    signal : Array
        Periodic input signal (e.g., one or more complete LAOS cycles)
    dt : float
        Time step between samples (s)
    step_size : int, optional
        Stencil spacing k in samples (default: 1); static under jit

    Returns
    -------
    d1 : Array
        First derivative
    d2 : Array
        Second derivative
    d3 : Array
        Third derivative

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import numerical_derivative_periodic
    >>>
    >>> # One complete period of sine wave
    >>> omega = 1.0
    >>> t = jnp.linspace(0, 2*jnp.pi/omega, 1000, endpoint=False)
    >>> dt = t[1] - t[0]
    >>> signal = jnp.sin(omega * t)
    >>>
    >>> d1, d2, d3 = numerical_derivative_periodic(signal, dt)
    >>> # d1 ≈ omega * cos(omega*t)
    >>> # d2 ≈ -omega^2 * sin(omega*t)
    >>> # d3 ≈ -omega^3 * cos(omega*t)

    Notes
    -----
    - Assumes signal represents complete periods (periodic boundary)
    - Stencils (f[i±m] meaning f at i ± m*k, spacing h = k*dt):
      d1 = (-f[i+2] + 8 f[i+1] - 8 f[i-1] + f[i-2]) / (12 h)
      d2 = (-f[i+2] + 16 f[i+1] - 30 f[i] + 16 f[i-1] - f[i-2]) / (12 h^2)
      d3 = (-f[i+3] + 8 f[i+2] - 13 f[i+1] + 13 f[i-1] - 8 f[i-2] + f[i-3]) / (8 h^3)
    """
    values = jnp.atleast_1d(jnp.asarray(signal, dtype=jnp.float64))
    k = step_size  # static, so the shift amounts below are Python ints
    spacing = dt * k

    def ahead(m):
        """Sample m*k positions ahead, wrapping circularly (m < 0 looks behind)."""
        return jnp.roll(values, -m * k)

    fwd1, fwd2, fwd3 = ahead(1), ahead(2), ahead(3)
    back1, back2, back3 = ahead(-1), ahead(-2), ahead(-3)

    first = (-fwd2 + 8 * fwd1 - 8 * back1 + back2) / (12 * spacing)
    second = (
        -fwd2 + 16 * fwd1 - 30 * values + 16 * back1 - back2
    ) / (12 * spacing**2)
    third = (
        -fwd3 + 8 * fwd2 - 13 * fwd1 + 13 * back1 - 8 * back2 + back3
    ) / (8 * spacing**3)

    return first, second, third
@partial(jax.jit, static_argnums=(4, 5))
def spp_numerical_analysis(
    strain: "Array",
    stress: "Array",
    omega: "float | Array",
    dt: float,
    step_size: int = 8,
    num_mode: int = 2,
) -> dict:
    """
    Perform full SPP analysis using numerical differentiation (MATLAB-compatible).

    Implements the numerical SPP workflow from SPPplus_numerical_v2.m:
    1. Compute strain rate from strain
    2. Compute derivatives of [strain, rate/ω, stress] trajectory
    3. Calculate instantaneous ``G'_t`` and ``G''_t`` via cross-product formula
    4. Extract ``tan(δ)_t``, phase angle, displacement stress, moduli rates
       and the Frenet-Serret frame

    Parameters
    ----------
    strain : Array
        Strain signal γ(t) (dimensionless)
    stress : Array
        Stress signal σ(t) (Pa)
    omega : float | Array
        Angular frequency ω (rad/s). Can be scalar or per-sample array.
    dt : float
        Time step between samples (s)
    step_size : int, optional
        Finite difference step size k (default: 8 for Rogers parity); static.
    num_mode : int, optional
        2 = periodic/looped (default); any other value = edge-aware
        (forward/backward + centered). Static.

    Returns
    -------
    dict
        Dictionary containing:
        - Gp_t, Gpp_t: Instantaneous storage / loss moduli (Pa); NaN where
          the z-component of rd × rdd vanishes (Frenet degeneracy)
        - G_star_t: Instantaneous complex modulus ``|G*(t)|`` (Pa)
        - tan_delta_t: Instantaneous tan(δ)(t)
        - delta_t: Instantaneous phase angle δ(t) = atan2(G''_t, G'_t) (radians)
        - disp_stress: Displacement stress (Pa)
        - eq_strain_est: Equivalent strain estimate
        - Gp_t_dot, Gpp_t_dot, G_speed: Moduli rates (per unit time)
        - delta_t_dot: Phase-angle rate dδ_t/d(ωt) (dimensionless)
        - T_vec, N_vec, B_vec: Frenet-Serret frame, shape (n, 3) each
        - strain_rate_normalized, fsf_data_out, strain_recon, rate_recon,
          stress_recon, time_new: intermediates / export helpers

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from rheojax.utils.spp_kernels import spp_numerical_analysis
    >>>
    >>> omega = 1.0
    >>> t = jnp.linspace(0, 2*jnp.pi, 1000, endpoint=False)
    >>> dt = t[1] - t[0]
    >>> gamma_0 = 1.0
    >>> strain = gamma_0 * jnp.sin(omega * t)
    >>> # Linear viscoelastic response
    >>> stress = 100.0 * strain + 50.0 * gamma_0 * omega * jnp.cos(omega * t)
    >>>
    >>> result = spp_numerical_analysis(strain, stress, omega, dt)
    >>> # result['Gp_t'] ≈ 100.0 and result['delta_t_dot'] ≈ 0 (linear material)

    Notes
    -----
    - Matches MATLAB SPPplus cross-product formulation
    - ``G'_t = -rd_x_rdd[:,0] / rd_x_rdd[:,2]``
    - ``G''_t = -rd_x_rdd[:,1] / rd_x_rdd[:,2]``
    - Phase-angle rate (primes = d/d(ωt) of the stress):
      ``δ̇_t = -σ'(σ''' + σ') / (σ'^2 + σ''^2)``. This follows from the G'_t/G''_t
      expressions above for sinusoidal strain. It vanishes for a linear
      viscoelastic response, where δ_t is constant.
    - Works directly with raw experimental data (no Fourier decomposition)
    """
    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))

    # Handle scalar vs per-sample omega (ndim/shape are static under jit)
    omega_arr = jnp.asarray(omega, dtype=jnp.float64)
    if omega_arr.ndim == 0:
        omega_arr = jnp.full_like(strain_arr, omega_arr)
    elif omega_arr.shape[0] != strain_arr.shape[0]:
        raise ValueError(
            "omega array length must match strain length for numerical SPP"
        )
    omega_scalar = jnp.mean(omega_arr)

    # Compute strain rate (normalize by omega as in MATLAB)
    strain_rate = differentiate_rate_from_strain(
        strain_arr, dt, step_size=step_size, looped=(num_mode == 2)
    )
    strain_rate_normalized = strain_rate / omega_arr

    # Response wave columns: [strain, rate/omega, stress]
    columns = (strain_arr, strain_rate_normalized, stress_arr)

    # Per-column (d1, d2, d3) derivatives
    if num_mode == 2:
        per_column = [
            numerical_derivative_periodic(column, dt, step_size) for column in columns
        ]
    else:
        per_column = [
            tuple(
                numerical_derivative(column, dt, order=deriv_order, step_size=step_size)
                for deriv_order in (1, 2, 3)
            )
            for column in columns
        ]

    # Stack into (n, 3) arrays: rd, rdd, rddd
    rd = jnp.stack([derivs[0] for derivs in per_column], axis=1)
    rdd = jnp.stack([derivs[1] for derivs in per_column], axis=1)
    rddd = jnp.stack([derivs[2] for derivs in per_column], axis=1)

    # Cross products: rd × rdd, and rd × (rd × rdd) for the Frenet-Serret frame
    rd_x_rdd = jnp.cross(rd, rdd)
    rd_x_rd_x_rdd = jnp.cross(rd, rd_x_rdd)

    # Magnitudes for Frenet-Serret frame (+ 1e-30 guards sqrt(0) infinite gradient)
    eps = 1e-20  # Avoid division by zero
    mag_rd = jnp.sqrt(jnp.sum(rd**2, axis=1) + 1e-30)
    mag_rd_x_rdd = jnp.sqrt(jnp.sum(rd_x_rdd**2, axis=1) + 1e-30)

    # R11-SPP-KRN-001: Avoid sign(0)=0 at Frenet degeneracy
    # G'_t = -rd_x_rdd[:,0] / rd_x_rdd[:,2]
    # G''_t = -rd_x_rdd[:,1] / rd_x_rdd[:,2]
    z_comp = rd_x_rdd[:, 2]
    denom_valid = jnp.abs(z_comp) > eps
    Gp_t = jnp.where(
        denom_valid,
        -rd_x_rdd[:, 0] / z_comp,
        jnp.nan,  # Explicit NaN at degeneracy for callers to handle
    )
    Gpp_t = jnp.where(
        denom_valid,
        -rd_x_rdd[:, 1] / z_comp,
        jnp.nan,  # Explicit NaN at degeneracy for callers to handle
    )

    # Moduli rates (MATLAB formula - Gap 4)
    # Gp_t_dot = -rd[:,1] * (rddd · rd_x_rdd) / rd_x_rdd[:,2]²
    # Gpp_t_dot = rd[:,0] * (rddd · rd_x_rdd) / rd_x_rdd[:,2]²
    rddd_dot_rd_x_rdd = jnp.sum(rddd * rd_x_rdd, axis=1)
    z_sq = z_comp**2
    Gp_t_dot = jnp.where(
        denom_valid,
        -rd[:, 1] * rddd_dot_rd_x_rdd / jnp.maximum(z_sq, eps),
        jnp.nan,
    )
    Gpp_t_dot = jnp.where(
        denom_valid,
        rd[:, 0] * rddd_dot_rd_x_rdd / jnp.maximum(z_sq, eps),
        jnp.nan,
    )
    G_speed = jnp.sqrt(Gp_t_dot**2 + Gpp_t_dot**2 + 1e-30)

    # Complex modulus magnitude
    G_star_t = jnp.sqrt(Gp_t**2 + Gpp_t**2 + 1e-30)

    # Loss tangent and phase angle
    tan_delta_t = Gpp_t / jnp.maximum(jnp.abs(Gp_t), eps) * jnp.sign(Gp_t)
    # Phase angle via arctan2 — handles all four quadrants correctly
    delta_t = jnp.arctan2(Gpp_t, Gp_t)

    # Phase angle rate in normalized time τ = ωt.
    # With σ' = R cos θ, σ'' = R sin θ one finds δ_t = -(θ + τ), so
    # dδ_t/dτ = -(dθ/dτ + 1) = -σ'(σ''' + σ') / (σ'^2 + σ''^2).
    sigma_prime = rd[:, 2] / omega_scalar  # dσ/dτ
    sigma_dprime = rdd[:, 2] / omega_scalar**2  # d²σ/dτ²
    sigma_tprime = rddd[:, 2] / omega_scalar**3  # d³σ/dτ³
    sigma_denom = jnp.maximum(sigma_prime**2 + sigma_dprime**2, eps)
    delta_t_dot = -sigma_prime * (sigma_tprime + sigma_prime) / sigma_denom

    # Displacement stress (MATLAB formula)
    disp_stress = stress_arr - (Gp_t * strain_arr + Gpp_t * strain_rate_normalized)

    # Equivalent strain estimate
    eq_strain_est = strain_arr - disp_stress / jnp.maximum(jnp.abs(Gp_t), eps)

    # Frenet-Serret frame (Gap 5): tangent, principal normal, binormal
    T_vec = rd / jnp.maximum(mag_rd[:, None], eps)
    N_vec = -rd_x_rd_x_rdd / jnp.maximum((mag_rd * mag_rd_x_rdd)[:, None], eps)
    B_vec = rd_x_rdd / jnp.maximum(mag_rd_x_rdd[:, None], eps)

    return {
        # Core SPP metrics
        "Gp_t": Gp_t,
        "Gpp_t": Gpp_t,
        "G_star_t": G_star_t,
        "tan_delta_t": tan_delta_t,
        "delta_t": delta_t,
        "disp_stress": disp_stress,
        "eq_strain_est": eq_strain_est,
        # Moduli rates (Gap 4)
        "Gp_t_dot": Gp_t_dot,
        "Gpp_t_dot": Gpp_t_dot,
        "G_speed": G_speed,
        "delta_t_dot": delta_t_dot,
        # Frenet-Serret frame (Gap 5)
        "T_vec": T_vec,
        "N_vec": N_vec,
        "B_vec": B_vec,
        # Intermediate values for debugging
        "strain_rate_normalized": strain_rate_normalized,
        "fsf_data_out": jnp.concatenate([T_vec, N_vec, B_vec], axis=1),
        # Reconstructions / time grid (numerical keeps original)
        "strain_recon": strain_arr,
        "rate_recon": strain_rate_normalized,
        "stress_recon": stress_arr,
        "time_new": jnp.arange(len(strain_arr)) * dt,
    }
# ============================================================================ # Displacement-Stress Yield Extraction (Gap 6) # ============================================================================
@jax.jit
def yield_from_displacement_stress(
    disp_stress: "Array",
    strain: "Array",
    strain_rate: "Array",
    Gp_t: "Array",
    delta_t: "Array",
    strain_amplitude: float,
    rate_amplitude: float,
) -> dict:
    """
    Extract yield stresses from displacement stress curve (SPP methodology).

    This follows the Donley et al. (2019) framework for yield stress extraction:
    - σ_sy (static yield): displacement stress where G'(t) is near its minimum
    - σ_dy (dynamic yield): displacement stress where δ(t) exceeds π/2 - 0.1
      or where the strain rate is near zero (the larger of the two)

    Parameters
    ----------
    disp_stress : Array
        Displacement stress σ_disp = σ - (G'·γ + G''·γ̇/ω) (Pa)
    strain : Array
        Strain signal γ(t) (dimensionless)
    strain_rate : Array
        Strain rate signal γ̇(t)/ω (normalized, dimensionless)
    Gp_t : Array
        Instantaneous storage modulus G'(t) (Pa)
    delta_t : Array
        Instantaneous phase angle δ(t) (radians)
    strain_amplitude : float
        Maximum strain amplitude γ_0 (dimensionless)
    rate_amplitude : float
        Maximum strain rate amplitude γ̇_0 = ω * γ_0 (1/s). Currently unused;
        reserved for rate-dependent yield criteria.

    Returns
    -------
    dict
        Dictionary containing:
        - sigma_sy: Static yield stress (Pa) - mean |σ_disp| near the G'(t) minimum
        - sigma_dy: Dynamic yield stress (Pa) - max of the δ-based and zero-rate estimates
        - yield_strain_sy: Mean |strain| at static yield points
        - yield_strain_dy: Mean |strain| at near-zero-rate points
        - yield_indices_sy: Boolean mask of static yield points
        - yield_indices_dy: Boolean mask of near-zero-rate points
        - sigma_sy_disp: Mean |σ_disp| at |γ| >= 0.95 γ_0
        - sigma_dy_disp: Mean |σ_disp| at |γ̇| <= 5 % of max |γ̇|

    Notes
    -----
    Non-finite samples are ignored. This applies to NaN or inf in
    ``disp_stress`` and ``Gp_t`` (for example, the explicit NaNs that
    :func:`spp_numerical_analysis` emits at Frenet degeneracies), and to NaN
    in the other inputs. All reductions are NaN-aware, so a few degenerate
    points no longer turn every output into NaN. For all-finite input, the
    results are unchanged.

    References
    ----------
    G.J. Donley et al., "Time-resolved dynamics of the yielding transition in
    soft materials", J. Non-Newton. Fluid Mech. 264, 2019
    """
    disp_stress_arr = jnp.atleast_1d(jnp.asarray(disp_stress, dtype=jnp.float64))
    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))
    rate_arr = jnp.atleast_1d(jnp.asarray(strain_rate, dtype=jnp.float64))
    Gp_t_arr = jnp.atleast_1d(jnp.asarray(Gp_t, dtype=jnp.float64))
    delta_t_arr = jnp.atleast_1d(jnp.asarray(delta_t, dtype=jnp.float64))
    gamma_0 = jnp.float64(strain_amplitude)
    # rate_amplitude is received but not used in current yield extraction methods
    # Reserved for future rate-dependent yield criteria
    _ = rate_amplitude  # Explicitly acknowledge unused parameter

    eps = 1e-10

    abs_disp = jnp.abs(disp_stress_arr)
    abs_strain = jnp.abs(strain_arr)
    usable = jnp.isfinite(disp_stress_arr)  # samples allowed into any average
    abs_disp_max = jnp.nanmax(abs_disp)
    abs_disp_min = jnp.nanmin(abs_disp)

    def _masked_mean(values, mask, fallback):
        """Mean of ``values`` where ``mask`` is True, else ``fallback``."""
        count = jnp.sum(mask)
        total = jnp.sum(jnp.where(mask, values, 0.0))
        return jnp.where(count > 0, total / jnp.maximum(count, 1), fallback)

    # =========================================================================
    # Method 1: Static yield from G'(t) minima (cage breakage)
    # =========================================================================
    Gp_clean = jnp.where(jnp.isfinite(Gp_t_arr), Gp_t_arr, jnp.nan)
    Gp_min = jnp.nanmin(Gp_clean)
    Gp_max = jnp.nanmax(Gp_clean)
    Gp_range = jnp.maximum(Gp_max - Gp_min, eps)
    # Threshold: within 10% of the G' range above the minimum
    near_Gp_min = (Gp_clean < (Gp_min + 0.1 * Gp_range)) & usable

    sigma_sy = _masked_mean(abs_disp, near_Gp_min, abs_disp_max)
    yield_strain_sy = _masked_mean(abs_strain, near_Gp_min, gamma_0)

    # =========================================================================
    # Method 2: Dynamic yield from δ(t) → π/2 (flow cessation)
    # =========================================================================
    delta_threshold = jnp.pi / 2 - 0.1  # 0.1 rad (~6°) below π/2
    near_pi_half = (delta_t_arr > delta_threshold) & usable
    sigma_dy_from_delta = _masked_mean(abs_disp, near_pi_half, 0.0)

    # =========================================================================
    # Method 3: From displacement stress at strain/rate extrema (traditional)
    # =========================================================================
    # Static: displacement stress at |γ| ≈ γ_0
    near_max_strain = (abs_strain >= 0.95 * gamma_0) & usable
    sigma_sy_disp = _masked_mean(abs_disp, near_max_strain, abs_disp_max)

    # Dynamic: displacement stress at |γ̇| ≈ 0
    abs_rate = jnp.abs(rate_arr)
    near_zero_rate = (abs_rate <= 0.05 * jnp.nanmax(abs_rate)) & usable
    sigma_dy_disp = _masked_mean(abs_disp, near_zero_rate, abs_disp_min)

    # Dynamic yield: use the maximum of the two methods
    sigma_dy = jnp.maximum(sigma_dy_from_delta, sigma_dy_disp)

    # Strain at dynamic yield (near zero rate)
    yield_strain_dy = _masked_mean(abs_strain, near_zero_rate, 0.0)

    return {
        "sigma_sy": sigma_sy,
        "sigma_dy": sigma_dy,
        "yield_strain_sy": yield_strain_sy,
        "yield_strain_dy": yield_strain_dy,
        "yield_indices_sy": near_Gp_min,
        "yield_indices_dy": near_zero_rate,
        "sigma_sy_disp": sigma_sy_disp,
        "sigma_dy_disp": sigma_dy_disp,
    }
# ============================================================================ # Energy Integration (Gap 9) # ============================================================================
@jax.jit
def calculate_loop_energy(
    stress: "Array",
    strain: "Array",
) -> dict:
    """
    Calculate energy metrics from stress-strain loops via integration.

    Computes the dissipated energy (area of the hysteresis loop) and elastic
    energy metrics, ensuring parity with MATLAB SPP analysis.

    Parameters
    ----------
    stress : Array
        Stress signal σ(t) (Pa)
    strain : Array
        Strain signal γ(t) (dimensionless)

    Returns
    -------
    dict
        Dictionary containing:
        - dissipated_energy: Enclosed loop area |∮ σ dγ| (Pa)
        - dissipated_energy_density: Dissipated energy / π (Pa)
        - elastic_energy: 0.5 * |σ| * |γ| at the sample of maximum |γ| (Pa)

    Notes
    -----
    - Dissipated Energy (E_d): Defined as the area enclosed by the Lissajous
      loop of stress vs strain, E_d = ∮ σ dγ. This represents the viscous
      energy loss per unit volume per cycle.
    - Elastic Energy (E_s): Estimated as the energy stored at maximum strain,
      approximated by 0.5 * σ(γ_max) * γ_max.
    - Dissipated Energy Density: E_d / π. For a linear viscoelastic material,
      this equals G'' γ_0².

    The integration uses the trapezoidal rule over consecutive time-ordered
    samples, including the closing segment from the last sample back to the
    first. Periodic data sampled without a duplicated endpoint therefore still
    forms a closed loop. If the first sample is repeated at the end, that
    closing segment contributes (numerically) nothing. The integral covers all
    supplied samples, so multi-cycle input gives the total over those cycles.
    """
    stress_arr = jnp.atleast_1d(jnp.asarray(stress, dtype=jnp.float64))
    strain_arr = jnp.atleast_1d(jnp.asarray(strain, dtype=jnp.float64))

    # Dissipated Energy (Loop Area)
    # Closed-polygon trapezoid: each sample is paired with its successor,
    # the last sample wrapping back to the first.
    next_strain = jnp.roll(strain_arr, -1)
    next_stress = jnp.roll(stress_arr, -1)
    area = jnp.sum(0.5 * (stress_arr + next_stress) * (next_strain - strain_arr))

    # Take absolute value to ensure positive energy regardless of loop direction
    dissipated_energy = jnp.abs(area)

    # Dissipated Energy Density (Normalized by Pi)
    dissipated_energy_density = dissipated_energy / jnp.pi

    # Elastic Energy: E_s ≈ 0.5 * stress_at_max_strain * max_strain
    # This assumes linear-like storage behavior up to the peak.
    idx_max = jnp.argmax(jnp.abs(strain_arr))
    stress_at_max = jnp.abs(stress_arr[idx_max])
    strain_max = jnp.abs(strain_arr[idx_max])
    elastic_energy = 0.5 * stress_at_max * strain_max

    return {
        "dissipated_energy": dissipated_energy,
        "dissipated_energy_density": dissipated_energy_density,
        "elastic_energy": elastic_energy,
    }
@jax.jit
def frenet_serret_frame(
    rd: "Array",
    rdd: "Array",
) -> tuple["Array", "Array", "Array", "Array", "Array"]:
    """
    Compute the Frenet-Serret frame (T, N, B) for a 3D trajectory.

    The frame is a local orthonormal basis that moves along the
    (γ, γ̇/ω, σ) response curve and describes its local geometry.

    Parameters
    ----------
    rd : Array
        First derivative of response wave [d(γ)/dt, d(γ̇/ω)/dt, d(σ)/dt]
        Shape: (n_points, 3)
    rdd : Array
        Second derivative of response wave
        Shape: (n_points, 3)

    Returns
    -------
    T_vec : Array
        Unit tangent (direction of motion)
    N_vec : Array
        Unit principal normal (direction of curvature)
    B_vec : Array
        Unit binormal (``T × N``)
    curvature : Array
        Local curvature ``κ = |rd × rdd| / |rd|³``
    torsion : Array
        Placeholder zeros (torsion needs the third derivative)

    Notes
    -----
    Formulas (matching MATLAB SPPplus)::

        T = rd / |rd|
        N = -(rd × (rd × rdd)) / (|rd| × |rd × rdd|)
        B = (rd × rdd) / |rd × rdd|
        κ = |rd × rdd| / |rd|³
    """
    velocity = jnp.asarray(rd, dtype=jnp.float64)
    acceleration = jnp.asarray(rdd, dtype=jnp.float64)
    floor = 1e-20

    # rd × rdd and rd × (rd × rdd), row-wise
    binormal_dir = jnp.cross(velocity, acceleration)
    normal_dir = jnp.cross(velocity, binormal_dir)

    # Norms; the tiny offset keeps sqrt differentiable at 0
    speed = jnp.sqrt(jnp.sum(velocity**2, axis=1) + 1e-30)
    binormal_len = jnp.sqrt(jnp.sum(binormal_dir**2, axis=1) + 1e-30)

    tangent = velocity / jnp.maximum(speed[:, None], floor)
    normal = -normal_dir / jnp.maximum((speed * binormal_len)[:, None], floor)
    binormal = binormal_dir / jnp.maximum(binormal_len[:, None], floor)
    kappa = binormal_len / jnp.maximum(speed**3, floor)

    return tangent, normal, binormal, kappa, jnp.zeros_like(kappa)
# ============================================================================
# Export helpers (MATLAB-compatible schema)
# ============================================================================


def build_spp_exports(
    time: np.ndarray,
    strain: np.ndarray,
    rate_over_omega: np.ndarray,
    stress: np.ndarray,
    metrics: dict,
    fsf_data_out: np.ndarray | None,
    spp_params: np.ndarray,
) -> dict:
    """Assemble MATLAB-compatible spp_data_out / fsf_data_out tables.

    Parameters
    ----------
    time, strain, rate_over_omega, stress : np.ndarray
        Per-sample signals; they form the first four columns of
        ``spp_data_out`` in that order.
    metrics : dict
        Per-sample SPP metrics. The keys ``Gp_t``, ``Gpp_t``, ``G_star_t``,
        ``tan_delta_t``, ``delta_t``, ``disp_stress`` and ``eq_strain_est``
        are required (a missing one raises ``KeyError``). The keys
        ``Gp_t_dot``, ``Gpp_t_dot``, ``G_speed`` and ``delta_t_dot`` are
        optional; absent ones are written as NaN columns.
    fsf_data_out : np.ndarray or None
        Frenet-Serret frame table, passed through unchanged.
    spp_params : np.ndarray
        SPP parameter table, passed through unchanged.

    Returns
    -------
    dict with keys: spp_data_out, fsf_data_out, spp_params
        ``spp_data_out`` is a 15-column table: the four signals, then the
        seven required metrics, then the four optional metrics, in the
        order listed above.
    """
    logger.debug(
        "Building SPP export tables",
        n_points=len(time),
        has_fsf_data=fsf_data_out is not None,
        n_metrics=len(metrics),
    )

    required_keys = (
        "Gp_t",
        "Gpp_t",
        "G_star_t",
        "tan_delta_t",
        "delta_t",
        "disp_stress",
        "eq_strain_est",
    )
    optional_keys = ("Gp_t_dot", "Gpp_t_dot", "G_speed", "delta_t_dot")

    # Placeholder for absent optional metrics. Built with an explicit float
    # dtype: np.full_like(time, np.nan) would cast NaN to garbage integers
    # when ``time`` has an integer dtype (e.g. sample indices). Sharing one
    # array is safe because np.column_stack copies its inputs.
    nan_column = np.full(np.shape(time), np.nan)

    columns = [time, strain, rate_over_omega, stress]
    columns.extend(metrics[key] for key in required_keys)
    columns.extend(metrics.get(key, nan_column) for key in optional_keys)
    spp_data_out = np.column_stack(columns)

    logger.info(
        "SPP export tables built",
        spp_data_shape=spp_data_out.shape,
        has_fsf_data=fsf_data_out is not None,
    )

    return {
        "spp_data_out": spp_data_out,
        "fsf_data_out": fsf_data_out,
        "spp_params": spp_params,
    }


# ============================================================================
# Data Preprocessing (Gap 7, 8)
# ============================================================================
@partial(jax.jit, static_argnames=("step_size", "looped"))
def differentiate_rate_from_strain(
    strain: "Array",
    dt: float,
    step_size: int = 8,
    looped: bool = True,
) -> "Array":
    """
    Infer the strain rate γ̇(t) from a sampled strain signal γ(t).

    Two numerical paths are available. The periodic (wrapped) stencil path
    mirrors the MATLAB/Rogers SPPplus implementation. The edge-aware
    finite-difference path is kept for data that is not periodic.

    Parameters
    ----------
    strain : Array
        Strain signal γ(t) (dimensionless)
    dt : float
        Sampling interval (s)
    step_size : int
        Finite difference step size ``k`` (default: 8, Rogers parity).
        Static under ``jax.jit``.
    looped : bool
        True selects the periodic (wrapped) derivative; False selects the
        edge-aware one. Static under ``jax.jit``.

    Returns
    -------
    Array
        Strain rate γ̇(t) (1/s)

    Notes
    -----
    - With looped=True and step_size=8 this matches the wrapped 8-point
      rate inference SPPplus v2.1 uses when no rate column is provided.
    - With looped=False the earlier 4th-order finite-difference path is used.
    """
    if not looped:
        # Non-periodic record: edge-aware finite differences, first derivative.
        return numerical_derivative_4th_order(strain, dt, order=1, step_size=step_size)

    # Periodic record: the wrapped helper yields three derivatives; only the
    # first one (the rate) is wanted here.
    rate, _second, _third = numerical_derivative_periodic(strain, dt, step_size=step_size)
    return rate
def convert_units(
    data: "Array",
    from_unit: str,
    to_unit: str,
) -> "Array":
    """
    Convert data between common rheological units.

    Unit names are matched case-insensitively. Known units, by dimension:

    - strain: ``'fraction'``, ``'percent'``
    - stress: ``'Pa'``, ``'mPa'`` (milliPascal), ``'kPa'``
    - angle: ``'rad'``, ``'deg'``
    - time: ``'s'``, ``'ms'``

    Any two known units of the same dimension can be converted. Pairs that
    are not listed directly (e.g. ``'kPa'`` -> ``'mPa'``) are chained
    through the dimension's base unit (Pa, fraction, rad, s). Unknown units,
    or units from different dimensions, leave the data unchanged and log a
    warning.

    Warning: because matching ignores case, ``'MPa'`` is interpreted as
    milliPascal, *not* megaPascal.

    Parameters
    ----------
    data : Array
        Input data array (cast with ``jnp.asarray(..., dtype=jnp.float64)``)
    from_unit : str
        Source unit (e.g., 'percent', 'mPa', 'rad', 'deg')
    to_unit : str
        Target unit (e.g., 'fraction', 'Pa', 'rad', 'deg')

    Returns
    -------
    Array
        Converted data, or the cast input unchanged when no conversion
        applies.

    Examples
    --------
    >>> strain_fraction = convert_units(strain_percent, 'percent', 'fraction')
    >>> stress_Pa = convert_units(stress_mPa, 'mPa', 'Pa')
    >>> stress_mPa = convert_units(stress_kPa, 'kPa', 'mPa')  # factor 1e6
    """
    logger.debug(
        "Converting units",
        from_unit=from_unit,
        to_unit=to_unit,
        data_shape=getattr(data, "shape", "scalar"),
    )

    data_arr = jnp.asarray(data, dtype=jnp.float64)

    # Define conversion factors (all lowercase for case-insensitive matching)
    # Note: "mpa" means milliPascal (mPa), not megaPascal (MPa)
    conversions = {
        # Strain conversions
        ("percent", "fraction"): 0.01,
        ("fraction", "percent"): 100.0,
        # Stress conversions (mPa = milliPascal, kPa = kiloPascal)
        ("mpa", "pa"): 0.001,  # milliPascal to Pascal
        ("pa", "mpa"): 1000.0,  # Pascal to milliPascal
        ("kpa", "pa"): 1000.0,  # kiloPascal to Pascal
        ("pa", "kpa"): 0.001,  # Pascal to kiloPascal
        # Angle conversions
        ("deg", "rad"): jnp.pi / 180.0,
        ("rad", "deg"): 180.0 / jnp.pi,
        # Time conversions
        ("ms", "s"): 0.001,
        ("s", "ms"): 1000.0,
        # Identity
        ("pa", "pa"): 1.0,
        ("fraction", "fraction"): 1.0,
        ("rad", "rad"): 1.0,
        ("s", "s"): 1.0,
    }

    # Base unit of each dimension. Every unit has a direct entry to and from
    # its base in ``conversions`` (identity entries for the bases), so two
    # same-dimension units can always be chained through it.
    base_units = {
        "fraction": "fraction",
        "percent": "fraction",
        "pa": "pa",
        "mpa": "pa",
        "kpa": "pa",
        "rad": "rad",
        "deg": "rad",
        "s": "s",
        "ms": "s",
    }

    key = (from_unit.lower(), to_unit.lower())
    src, dst = key

    # Direct table hits keep their exact original factors.
    factor = conversions.get(key)
    if factor is None and src != dst:
        base = base_units.get(src)
        if base is not None and base_units.get(dst) == base:
            # Same dimension but not listed directly: src -> base -> dst
            # (e.g. kPa -> Pa -> mPa). Previously these pairs were silently
            # returned unconverted.
            factor = conversions[(src, base)] * conversions[(base, dst)]

    if factor is not None:
        logger.debug(
            "Unit conversion applied",
            conversion_key=key,
            factor=float(factor),
        )
        return data_arr * factor

    if src == dst:
        # Identity for a unit without an explicit identity entry
        # (e.g. 'deg' -> 'deg'): nothing to do.
        logger.debug(
            "No conversion found, returning unchanged data",
            from_unit=from_unit,
            to_unit=to_unit,
        )
        return data_arr

    # Unknown unit or cross-dimension request: keep the best-effort
    # passthrough, but make it visible since the result is not converted.
    logger.warning(
        "No conversion found, returning unchanged data",
        from_unit=from_unit,
        to_unit=to_unit,
    )
    return data_arr
# ============================================================================
# Convenience Exports
# ============================================================================

# NOTE(review): the module docstring also lists ``moduli_rates`` and
# ``zero_crossing_detection``, which are not exported here (``__all__`` has
# ``zero_crossing_indices``). Confirm which names actually exist before
# adding them: an undefined name in ``__all__`` breaks ``import *``.
__all__ = [
    # Core SPP functions
    "apparent_cage_modulus",
    "static_yield_stress",
    "dynamic_yield_stress",
    "harmonic_reconstruction",
    "power_law_fit",
    "lissajous_metrics",
    "zero_crossing_indices",
    "harmonic_truncation_robustness",
    "spp_stress_decomposition",
    # Numerical differentiation
    "numerical_derivative",
    "numerical_derivative_4th_order",
    "numerical_derivative_periodic",
    # SPP analysis functions
    "spp_numerical_analysis",
    "spp_fourier_analysis",
    # Phase-aligned Fourier (NEW - Gap 2, 3)
    "compute_phase_offset",
    "harmonic_reconstruction_full",
    # Frenet-Serret frame (NEW - Gap 5)
    "frenet_serret_frame",
    # Displacement-stress yield extraction (NEW - Gap 6)
    "yield_from_displacement_stress",
    # Energy Integration (NEW - Gap 9)
    "calculate_loop_energy",
    # Export helpers (MATLAB-compatible schema)
    "build_spp_exports",
    # Data preprocessing (NEW - Gap 7, 8)
    "differentiate_rate_from_strain",
    "convert_units",
]