"""Optimally Windowed Chirp (OWChirp) transform for LAOS analysis.
This module implements the OWChirp transform for analyzing Large Amplitude
Oscillatory Shear (LAOS) data, providing time-frequency analysis and nonlinear
viscoelastic parameter extraction.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from rheojax.core.base import BaseTransform
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger
# Obtain JAX through the project helper, which enforces float64 mode.
jax, jnp = safe_import_jax()

# Concrete JAX array class, kept under a private name so the runtime alias
# below (and any isinstance() checks against it) share one object.
_JaxArray = jax.Array

if TYPE_CHECKING:
    # Imports needed only by static type checkers; skipped at runtime.
    from jax import Array

    from rheojax.core.data import RheoData
else:
    # At runtime ``Array`` is just the JAX array class.
    Array = _JaxArray

# Logger shared by every function and method in this module.
logger = get_logger(__name__)
@TransformRegistry.register("owchirp", type=TransformType.ANALYSIS)
class OWChirp(BaseTransform):
    """Optimally Windowed Chirp transform for LAOS data analysis.

    The OWChirp transform uses Gaussian-windowed complex exponentials
    (Morlet-like wavelets) to perform time-frequency analysis of Large
    Amplitude Oscillatory Shear (LAOS) data, extracting nonlinear viscoelastic
    parameters and higher harmonics.

    This is particularly useful for:

    - Analyzing frequency-dependent nonlinear response
    - Extracting time-varying moduli during LAOS
    - Identifying structural changes during oscillatory deformation
    - Higher harmonic analysis (3rd, 5th, 7th harmonics)

    The Gaussian window width scales inversely with frequency
    (``sigma = wavelet_width / (2*pi*f)``). Every analysis frequency therefore
    sees the same number of oscillation periods inside its window, which
    balances time and frequency resolution.

    Parameters
    ----------
    n_frequencies : int, default=100
        Number of log-spaced frequency points for analysis (must be >= 1).
    frequency_range : tuple, default=(1e-2, 1e2)
        Frequency range (f_min, f_max) in Hz, with ``0 < f_min <= f_max``.
    wavelet_width : float, default=5.0
        Width parameter for the wavelet (controls time-frequency resolution).
        Must be positive.
    extract_harmonics : bool, default=True
        Whether to extract higher harmonics (3ω, 5ω, etc.).
    max_harmonic : int, default=7
        Maximum harmonic to extract (odd harmonics only).

    Raises
    ------
    ValueError
        If ``n_frequencies``, ``frequency_range`` or ``wavelet_width`` is
        outside its valid range.

    Examples
    --------
    Basic usage:

    >>> import jax.numpy as jnp
    >>> from rheojax.core.data import RheoData
    >>> from rheojax.transforms.owchirp import OWChirp
    >>>
    >>> # LAOS stress response data
    >>> t = jnp.linspace(0, 100, 10000)
    >>> omega = 1.0  # rad/s
    >>> # Nonlinear stress: includes 3rd harmonic
    >>> stress = jnp.sin(omega * t) + 0.2 * jnp.sin(3 * omega * t)
    >>> data = RheoData(x=t, y=stress, domain='time',
    ...                 metadata={'test_mode': 'oscillation'})
    >>>
    >>> # Apply OWChirp transform
    >>> owchirp = OWChirp(n_frequencies=50, extract_harmonics=True)
    >>> spectrum = owchirp.transform(data)
    >>>
    >>> # Extract nonlinear parameters
    >>> harmonics = owchirp.get_harmonics(data)
    """

    # Result keys used by get_harmonics() for named odd harmonics; any higher
    # harmonic requested via max_harmonic is reported as "harmonic_<n>".
    _HARMONIC_NAMES = {3: "third", 5: "fifth", 7: "seventh", 9: "ninth"}

    def __init__(
        self,
        n_frequencies: int = 100,
        frequency_range: tuple[float, float] = (1e-2, 1e2),
        wavelet_width: float = 5.0,
        extract_harmonics: bool = True,
        max_harmonic: int = 7,
    ):
        """Initialize OWChirp transform.

        Parameters
        ----------
        n_frequencies : int
            Number of frequency points (>= 1).
        frequency_range : tuple
            (f_min, f_max) in Hz, with ``0 < f_min <= f_max``.
        wavelet_width : float
            Wavelet width parameter (> 0).
        extract_harmonics : bool
            Extract higher harmonics.
        max_harmonic : int
            Maximum harmonic order.

        Raises
        ------
        ValueError
            If ``n_frequencies < 1``, if ``frequency_range`` is not a pair with
            ``0 < f_min <= f_max``, or if ``wavelet_width <= 0``.
        """
        super().__init__()
        # Validate eagerly. Otherwise bad values only surface later as NaN/inf
        # spectra (log10 of a non-positive bound, sigma == 0) or as an empty
        # spectrum that breaks peak detection.
        if n_frequencies < 1:
            raise ValueError(f"n_frequencies must be >= 1, got {n_frequencies}")
        f_min, f_max = frequency_range
        if not 0 < f_min <= f_max:
            raise ValueError(
                "frequency_range must satisfy 0 < f_min <= f_max, "
                f"got {frequency_range}"
            )
        if not wavelet_width > 0:
            raise ValueError(f"wavelet_width must be positive, got {wavelet_width}")
        self.n_frequencies = n_frequencies
        self.frequency_range = frequency_range
        self.wavelet_width = wavelet_width
        self.extract_harmonics = extract_harmonics
        self.max_harmonic = max_harmonic

    def _chirp_wavelet(
        self, t: Array, t_center: float, frequency: float | Array, width: float
    ) -> Array:
        """Generate a Gaussian-windowed complex exponential at one frequency.

        The wavelet is Morlet-like::

            ψ(t) = exp(-½ ((t - t_c) / σ)²) · exp(2πi f t),   σ = width / (2π f)

        The oscillatory factor is referenced to ``t`` itself, not ``t - t_c``.
        Callers that want the phase measured from the window centre should pass
        lag values (times relative to the centre) together with ``t_c = 0``.

        Parameters
        ----------
        t : Array
            Time (or lag) values at which to evaluate the wavelet.
        t_center : float
            Centre of the Gaussian envelope, in the same units as ``t``.
        frequency : float or Array
            Frequency in Hz. This is a scalar, and may be a traced value under
            ``jax.vmap``.
        width : float
            Width parameter. Larger values give finer frequency resolution and
            coarser time localization.

        Returns
        -------
        Array
            Complex wavelet values, shaped like ``t`` for a scalar frequency.
        """
        # R7-OWC-001: Guard against frequency=0 to prevent division by zero
        # in sigma computation. Use a small epsilon floor.
        freq_safe = jnp.maximum(frequency, 1e-30)
        # Gaussian envelope width shrinks as 1/f (constant cycles per window).
        sigma = width / (2.0 * jnp.pi * freq_safe)
        envelope = jnp.exp(-0.5 * (((t - t_center) / sigma) ** 2))
        omega = 2.0 * jnp.pi * frequency
        chirp = jnp.exp(1j * omega * t)
        return envelope * chirp

    def _wavelet_transform(self, t: Array, signal: Array, frequencies: Array) -> Array:
        """Compute the wavelet transform by direct summation (reference path).

        For every (frequency, centre time) pair, the signal is projected onto a
        wavelet centred at that time. Nested ``jax.vmap`` replaces Python loops,
        but every coefficient still sums over all samples, so the cost is
        O(n_frequencies * n_times**2). Prefer ``_optimized_wavelet_transform``
        for anything but short signals.

        Coefficients carry the same ``sqrt(f)`` scale normalization as the FFT
        path. On a uniform grid, both paths agree in magnitude. The phase here
        is referenced to absolute time, so it differs from the FFT path by a
        factor ``exp(-2πi f t_k)``.

        Parameters
        ----------
        t : Array
            Time array (assumed uniformly spaced).
        signal : Array
            Input signal, same length as ``t``.
        frequencies : Array
            Frequency array in Hz.

        Returns
        -------
        Array
            Complex wavelet coefficients, shape (n_frequencies, n_times).
        """
        # TRANS-001 / R11-OWC-003: one dt for the whole transform, taken as the
        # median step for robustness. Warn if spacing varies by more than 5%.
        if len(t) > 1:
            dt_arr = jnp.diff(t)
            dt_med = float(jnp.median(dt_arr))
            dt_std = float(jnp.std(dt_arr))
            if dt_med > 0 and (dt_std / dt_med) > 0.05:
                import warnings as _warnings

                _warnings.warn(
                    f"OWChirp (vmap path): non-uniform time spacing detected "
                    f"(std/median = {dt_std / dt_med:.2f}). "
                    "Results may be inaccurate.",
                    UserWarning,
                    stacklevel=2,
                )
            dt = dt_med
        else:
            # A single sample has no spacing; use unit weight.
            dt = 1.0

        freqs = jnp.asarray(frequencies)
        width = self.wavelet_width

        def compute_coeff(freq, t_center):
            wavelet = self._chirp_wavelet(t, t_center, freq, width)
            return jnp.sum(signal * jnp.conj(wavelet)) * dt

        # vmap over t_center (inner), then over freq (outer).
        compute_row = jax.vmap(compute_coeff, in_axes=(None, 0))
        compute_all = jax.vmap(compute_row, in_axes=(0, None))
        coefficients = compute_all(freqs, t)

        # L² normalization (matches the FFT path): wavelet scale a ∝ 1/f, so
        # a^{-1/2} ∝ sqrt(f).
        scale = jnp.sqrt(jnp.maximum(freqs, 1e-30))[:, None]
        return coefficients * scale

    def _optimized_wavelet_transform(
        self, t: Array, signal: Array, frequencies: Array
    ) -> Array:
        """Wavelet transform via batched FFT cross-correlation.

        For each frequency ``f`` and sample index ``k``, this computes::

            c_f[k] = sqrt(f) · dt · Σ_m s[k+m] · exp(-½ (m·dt/σ)²) · exp(-2πi f m dt)

        ``m`` runs over lags -(N-1)..N-1, so the Gaussian window is centred on
        sample ``k``. The phase is measured relative to that window centre.
        Near the ends of the record, part of the window falls outside the data
        and those samples contribute zero.

        For a long, uniformly sampled tone of amplitude A at frequency f, the
        magnitude in its own bin is about ``A·width / (2·sqrt(2π)) / sqrt(f)``.
        Magnitudes at different frequencies are therefore not raw amplitudes.

        Parameters
        ----------
        t : Array
            Time array, increasing and (ideally) uniformly spaced.
        signal : Array
            Real input signal, same length as ``t``.
        frequencies : Array
            Frequency array in Hz.

        Returns
        -------
        Array
            Complex wavelet coefficients, shape (n_frequencies, n_times).

        Raises
        ------
        ValueError
            If fewer than 2 time points are given, if ``signal`` and ``t``
            differ in length, or if the median time step is not positive.
        """
        n_orig = len(t)
        if n_orig < 2:
            raise ValueError("Wavelet transform requires at least 2 time points")
        if len(signal) != n_orig:
            raise ValueError(
                f"Signal length ({len(signal)}) must match time array length "
                f"({n_orig})"
            )

        # R11-OWC-001: median dt for robustness to non-uniform sampling.
        dt_arr = jnp.diff(t)
        dt = float(jnp.median(dt_arr))
        if not dt > 0:
            # Zero/negative/NaN steps would make the lag grid meaningless.
            raise ValueError(
                "OWChirp requires time values that increase with index "
                f"(median time step = {dt})"
            )
        dt_std = float(jnp.std(dt_arr))
        if (dt_std / dt) > 0.05:
            import warnings as _warnings

            _warnings.warn(
                f"OWChirp: non-uniform time spacing detected "
                f"(std/median = {dt_std / dt:.2f}). "
                "FFT-based CWT assumes uniform dt — results may be inaccurate. "
                "Interpolate to a uniform grid before transforming.",
                UserWarning,
                stacklevel=2,
            )

        # R10-OWC-002: zero-pad to 2N so the circular FFT correlation equals
        # the linear one (no wrap-around aliasing).
        n_pad = 2 * n_orig

        # Lag grid in FFT order: slots 0..N-1 hold lags 0..N-1 and slots
        # N+1..2N-1 hold lags -(N-1)..-1. This centres each wavelet on lag 0.
        # Building the kernel on absolute times instead would push the window
        # off the record whenever t[0] != 0, and would keep only its causal
        # half. Slot N (lag -N) only ever meets the zero padding of the signal.
        lags = jnp.fft.ifftshift(jnp.arange(-n_orig, n_orig)) * dt

        freqs = jnp.asarray(frequencies)
        width = self.wavelet_width

        def _make_kernel(freq: Array) -> Array:
            """Return the lag-domain wavelet for one frequency, shape (n_pad,)."""
            return self._chirp_wavelet(lags, 0.0, freq, width)

        # TR-01: batched FFTs instead of one FFT per frequency.
        # Shapes: F = n_frequencies, N = n_orig, P = n_pad.
        wavelet_matrix = jax.vmap(_make_kernel)(freqs)  # (F, P)
        wavelet_fft_matrix = jnp.fft.fft(wavelet_matrix, axis=-1)  # (F, P)

        signal_padded = jnp.pad(signal, (0, n_pad - n_orig))
        signal_fft = jnp.fft.fft(signal_padded)  # (P,)

        # Cross-correlation in the frequency domain, broadcast over F.
        conv_fft = signal_fft[None, :] * jnp.conj(wavelet_fft_matrix)  # (F, P)
        conv_full = jnp.fft.ifft(conv_fft, axis=-1)
        conv_trimmed = conv_full[:, :n_orig]  # (F, N)

        # Standard L² CWT normalization: ψ_a(t) = a^{-1/2} ψ(t/a) with the
        # scale a ∝ 1/f, i.e. multiply by sqrt(f).
        # TR-02 / R7-OWC-002: jnp.maximum keeps this on-device and guards f=0.
        freq_safe = jnp.maximum(freqs, 1e-30)  # (F,)
        scale = jnp.sqrt(freq_safe)[:, None]  # (F, 1), broadcast over N
        return conv_trimmed * scale * dt

    @staticmethod
    def _require_time_domain(data: RheoData) -> None:
        """Raise ``ValueError`` (after logging) unless ``data.domain == "time"``."""
        if data.domain != "time":
            logger.error(
                "Invalid domain for OWChirp",
                expected="time",
                got=data.domain,
            )
            raise ValueError("OWChirp requires time-domain data")

    @staticmethod
    def _prepare_signal(data: RheoData) -> tuple[Array, Array]:
        """Return ``(t, signal)`` as JAX arrays, real-valued, with the mean removed."""
        t = jnp.asarray(data.x)
        signal = jnp.asarray(data.y)
        if jnp.iscomplexobj(signal):
            logger.debug("Converting complex signal to real part")
            signal = jnp.real(signal)
        # R11-OWC-002: Remove DC offset to prevent spurious low-frequency peak
        signal = signal - jnp.mean(signal)
        return t, signal

    def _frequency_grid(self) -> Array:
        """Return the log-spaced analysis frequencies (Hz) for this configuration."""
        f_min, f_max = self.frequency_range
        return jnp.logspace(jnp.log10(f_min), jnp.log10(f_max), self.n_frequencies)

    def _transform(self, data: RheoData) -> RheoData:
        """Apply OWChirp transform to LAOS data.

        Parameters
        ----------
        data : RheoData
            Time-domain LAOS data (stress or strain).

        Returns
        -------
        RheoData
            Time-averaged wavelet magnitude spectrum (x in Hz, domain
            ``"frequency"``). Use ``get_time_frequency_map`` for the full
            time-resolved coefficients.

        Raises
        ------
        ValueError
            If data is not time-domain, or if the time axis is unusable
            (see ``_optimized_wavelet_transform``).
        """
        from rheojax.core.data import RheoData

        logger.info(
            "Starting OWChirp transform",
            n_frequencies=self.n_frequencies,
            frequency_range=self.frequency_range,
            wavelet_width=self.wavelet_width,
            extract_harmonics=self.extract_harmonics,
        )
        self._require_time_domain(data)

        t, signal = self._prepare_signal(data)
        logger.debug(
            "Input data extracted",
            data_points=len(t),
            domain=data.domain,
        )

        logger.debug(
            "Generating frequency array",
            f_min=self.frequency_range[0],
            f_max=self.frequency_range[1],
            n_frequencies=self.n_frequencies,
        )
        frequencies = self._frequency_grid()

        logger.debug("Computing optimized wavelet transform using FFT convolution")
        coefficients = self._optimized_wavelet_transform(t, signal, frequencies)

        # R8-OWC-001: averaging over the time axis discards time localization.
        # This note is DEBUG rather than INFO because it would otherwise be
        # emitted on every call.
        logger.debug("Computing magnitude spectrum")
        logger.debug(
            "Wavelet coefficients averaged over time axis. "
            "For time-resolved analysis, use get_time_frequency_map()."
        )
        spectrum = jnp.mean(jnp.abs(coefficients), axis=1)

        new_metadata = (data.metadata or {}).copy()
        new_metadata.update(
            {
                "transform": "owchirp",
                "wavelet_width": self.wavelet_width,
                "n_frequencies": self.n_frequencies,
                "frequency_range": self.frequency_range,
                # R9-OWC-001: Only the time-averaged spectrum is returned by
                # _transform(). The full 2D map requires get_time_frequency_map().
                "time_frequency_map": False,
            }
        )
        logger.info(
            "OWChirp transform completed",
            output_frequencies=len(frequencies),
            spectrum_shape=spectrum.shape,
        )
        return RheoData(
            x=frequencies,
            y=spectrum,
            x_units="Hz",
            y_units="magnitude",
            domain="frequency",
            metadata=new_metadata,
            validate=False,
        )

    def get_time_frequency_map(self, data: RheoData) -> tuple[Array, Array, Array]:
        """Get full time-frequency map (spectrogram).

        The signal is prepared the same way as in ``transform``: converted to a
        real array with the mean removed.

        Parameters
        ----------
        data : RheoData
            Time-domain LAOS data.

        Returns
        -------
        times : Array
            Time array.
        frequencies : Array
            Frequency array (Hz).
        coefficients : Array
            Complex wavelet coefficients, shape (n_frequencies, n_times).

        Raises
        ------
        ValueError
            If data is not time-domain, or if the time axis is unusable.
        """
        self._require_time_domain(data)
        t, signal = self._prepare_signal(data)
        frequencies = self._frequency_grid()
        coefficients = self._optimized_wavelet_transform(t, signal, frequencies)
        return t, frequencies, coefficients

    def get_harmonics(
        self, data: RheoData, fundamental_freq: float | None = None
    ) -> dict[str, tuple[float, float]]:
        """Extract harmonic content from LAOS data.

        Amplitudes are read from the time-averaged wavelet spectrum returned by
        ``transform``. They are wavelet magnitudes, not raw sinusoid amplitudes.

        Parameters
        ----------
        data : RheoData
            Time-domain LAOS data.
        fundamental_freq : float, optional
            Fundamental frequency in Hz. If None, it is auto-detected as the
            strongest peak of the wavelet spectrum.

        Returns
        -------
        dict
            Dictionary with harmonic frequencies and amplitudes::

                {'fundamental': (freq, amplitude),
                 'third': (3*freq, amplitude),
                 'fifth': (5*freq, amplitude),
                 ...}

            Odd harmonics above the ninth (when ``max_harmonic > 9``) use the
            key ``"harmonic_<n>"``.
        """
        logger.info(
            "Extracting harmonics",
            fundamental_freq=fundamental_freq,
            max_harmonic=self.max_harmonic,
        )
        freq_data = self.transform(data)
        # NumPy arrays for SciPy peak detection and boolean masking.
        freqs = np.asarray(freq_data.x)
        spectrum = np.asarray(freq_data.y)

        if fundamental_freq is None:
            fundamental_freq = self._detect_fundamental(freqs, spectrum)

        harmonics = {
            "fundamental": (
                fundamental_freq,
                self._get_amplitude_at_freq(freqs, spectrum, fundamental_freq),
            )
        }

        if self.extract_harmonics:
            logger.debug(
                "Extracting odd harmonics",
                max_harmonic=self.max_harmonic,
            )
            for n in range(3, self.max_harmonic + 1, 2):
                harmonic_freq = n * fundamental_freq
                amplitude = self._get_amplitude_at_freq(freqs, spectrum, harmonic_freq)
                name = self._HARMONIC_NAMES.get(n, f"harmonic_{n}")
                harmonics[name] = (harmonic_freq, amplitude)

        logger.info(
            "Harmonic extraction completed",
            n_harmonics=len(harmonics),
            fundamental_freq=fundamental_freq,
        )
        return harmonics

    @staticmethod
    def _detect_fundamental(freqs: np.ndarray, spectrum: np.ndarray) -> float:
        """Return the frequency (Hz) of the strongest peak in ``spectrum``.

        Peaks need a prominence of at least 1% of the spectrum maximum. If none
        qualify, the frequency with the largest amplitude is returned.
        """
        logger.debug("Auto-detecting fundamental frequency from FFT peak")
        from scipy.signal import find_peaks

        peaks, _properties = find_peaks(spectrum, prominence=0.01 * np.max(spectrum))
        if len(peaks) == 0:
            logger.debug(
                "No peaks detected with prominence threshold, using max amplitude"
            )
            fundamental = float(freqs[np.argmax(spectrum)])
            logger.debug(
                "Fundamental frequency from max amplitude",
                fundamental_freq=fundamental,
            )
            return fundamental

        # Fundamental is typically the strongest peak.
        strongest_peak = peaks[np.argmax(spectrum[peaks])]
        fundamental = float(freqs[strongest_peak])
        logger.debug(
            "Fundamental frequency detected",
            fundamental_freq=fundamental,
            n_peaks_found=len(peaks),
        )
        return fundamental

    def _get_amplitude_at_freq(
        self,
        freqs: np.ndarray,
        spectrum: np.ndarray,
        target_freq: float,
        window: float = 0.1,
    ) -> float:
        """Get the peak amplitude within a fractional window around a frequency.

        Parameters
        ----------
        freqs : np.ndarray
            Frequency array.
        spectrum : np.ndarray
            Spectrum values, aligned with ``freqs``.
        target_freq : float
            Target frequency.
        window : float
            Fractional half-width of the search window (e.g. 0.1 = ±10%).

        Returns
        -------
        float
            Maximum spectrum value inside the window, or 0.0 if no analysis
            frequency falls inside it.
        """
        f_min = target_freq * (1 - window)
        f_max = target_freq * (1 + window)
        mask = (freqs >= f_min) & (freqs <= f_max)
        if not mask.any():
            return 0.0
        return float(np.max(spectrum[mask]))
# Public API of this module: only the OWChirp transform is exported.
__all__ = ["OWChirp"]