"""FFT-based frequency analysis transform for rheological data.
This module provides FFT analysis to convert time-domain rheological data
(relaxation, creep) to frequency domain for spectral analysis and feature extraction.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import numpy as np
from rheojax.core.base import BaseTransform
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger, log_transform
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
# Module logger
logger = get_logger(__name__)
# Typing-only imports: static type checkers take the first branch, so they
# see the real jax.numpy array type and RheoData without those imports being
# executed at runtime. At runtime TYPE_CHECKING is False, so the alias below
# falls back to NumPy's ndarray.
if TYPE_CHECKING:
    import jax.numpy as jnp_typing
    from rheojax.core.data import RheoData
else:  # pragma: no cover - typing fallback when JAX missing at runtime
    jnp_typing = np
# Array type used in this module's annotations (PEP 695 alias; the right-hand
# side is evaluated lazily, only when the alias value is inspected).
type JaxArray = jnp_typing.ndarray
# Window names accepted by FFTAnalysis; "none" selects an all-ones
# (rectangular) window.
WindowType = Literal["hann", "hamming", "blackman", "bartlett", "none"]
[docs]
@TransformRegistry.register("fft_analysis", type=TransformType.SPECTRAL)
class FFTAnalysis(BaseTransform):
    """Transform time-domain rheological data to frequency domain using FFT.

    This transform applies a real-input Fast Fourier Transform to convert
    time-domain signals to the frequency domain, enabling analysis of
    characteristic frequencies, relaxation time distributions, and spectral
    features. It is registered in the transform registry under the name
    ``"fft_analysis"`` with type ``TransformType.SPECTRAL``.

    Features:

    - Multiple window functions (Hann, Hamming, Blackman, Bartlett)
    - Optional linear detrending (removes the DC offset and any linear drift)
    - Power spectral density (PSD) calculation
    - Peak detection for characteristic frequencies (via
      ``scipy.signal.find_peaks``)
    - JAX-accelerated computation

    Input handling:

    - Complex-valued signals are reduced to their real part before the FFT.
    - If the time spacing varies by more than 10% (standard deviation of the
      steps relative to the median step), the signal is linearly interpolated
      onto a uniform grid with the same number of points.
    - At least 2 data points are required.

    Parameters
    ----------
    window : WindowType, default='hann'
        Window function to apply before FFT. Options: 'hann', 'hamming',
        'blackman', 'bartlett', 'none' ('none' applies an all-ones window)
    detrend : bool, default=True
        Whether to subtract a least-squares straight line, fitted against
        the sample index, before the FFT
    return_psd : bool, default=False
        If True, return power spectral density instead of FFT magnitude
    normalize : bool, default=True
        Whether to divide the magnitude spectrum by its maximum so that the
        largest bin equals 1. This discards physical amplitude units and is
        not applied when ``return_psd=True``.

    Notes
    -----
    The complex FFT coefficients are stored in the output metadata under
    ``"fft_complex"`` so that the inverse transform can reconstruct the
    signal. The reconstruction is of the windowed (and, if enabled,
    detrended) signal; use ``window='none'`` for a lossless round trip of the
    windowing step.

    Examples
    --------
    >>> from rheojax.core.data import RheoData
    >>> from rheojax.core.jax_config import safe_import_jax
    >>> from rheojax.transforms.fft_analysis import FFTAnalysis
    >>> jax, jnp = safe_import_jax()
    >>>
    >>> # Create time-domain relaxation data
    >>> t = jnp.linspace(0, 10, 1000)
    >>> G_t = jnp.exp(-t/2.0) # Exponential relaxation
    >>> data = RheoData(x=t, y=G_t, domain='time')
    >>>
    >>> # Apply FFT analysis
    >>> fft = FFTAnalysis(window='hann', detrend=True)
    >>> freq_data = fft.transform(data)
    >>>
    >>> # freq_data.x contains frequencies, freq_data.y contains spectrum
    """
[docs]
def __init__(
    self,
    window: WindowType = "hann",
    detrend: bool = True,
    return_psd: bool = False,
    normalize: bool = True,
):
    """Store the FFT configuration on the instance.

    Parameters
    ----------
    window : WindowType
        Name of the window function applied before the FFT
    detrend : bool
        Whether a linear trend is removed before the FFT
    return_psd : bool
        Whether the output is a power spectral density rather than a
        magnitude spectrum
    normalize : bool
        Whether the magnitude spectrum is scaled by its maximum
    """
    super().__init__()
    # No validation happens here; an unknown window name is only rejected
    # when the window is actually built.
    self.window, self.detrend = window, detrend
    self.return_psd, self.normalize = return_psd, normalize
def _get_window(self, n: int) -> JaxArray:
"""Get window function of length n.
Parameters
----------
n : int
Length of window
Returns
-------
jnp.ndarray
Window coefficients
"""
if self.window == "hann":
return jnp.hanning(n)
elif self.window == "hamming":
return jnp.hamming(n)
elif self.window == "blackman":
return jnp.blackman(n)
elif self.window == "bartlett":
return jnp.bartlett(n)
elif self.window == "none":
return jnp.ones(n)
else:
raise ValueError(f"Unknown window type: {self.window}")
def _detrend_data(self, y: JaxArray) -> JaxArray:
"""Remove linear trend from data.
Parameters
----------
y : jnp.ndarray
Input data (must have at least 2 points; single-point arrays are
returned unchanged since no trend can be estimated)
Returns
-------
jnp.ndarray
Detrended data
"""
n = len(y)
# Guard: need at least 2 points to fit a linear trend
if n < 2:
return y
# Fit linear trend: y = a*x + b
x = jnp.arange(n, dtype=jnp.float64)
# Linear regression
x_mean = jnp.mean(x)
y_mean = jnp.mean(y)
denom = jnp.sum((x - x_mean) ** 2)
# Guard against degenerate case (all x identical, which cannot happen
# for arange but is defended against for robustness)
slope = jnp.where(
denom > 1e-30,
jnp.sum((x - x_mean) * (y - y_mean)) / denom,
0.0,
)
intercept = y_mean - slope * x_mean
# Remove trend
trend = slope * x + intercept
return y - trend
def _transform(self, data: RheoData) -> RheoData:
    """Apply FFT transform to time-domain data.

    Processing order: validate the input, drop any imaginary part, check
    the time spacing (resampling onto a uniform grid when it varies by more
    than 10%), detrend (optional), apply the window, compute the real-input
    FFT, then convert to a magnitude spectrum or one-sided PSD.

    Parameters
    ----------
    data : RheoData
        Input time-domain data

    Returns
    -------
    RheoData
        Frequency-domain data. ``x`` holds the non-negative FFT
        frequencies, ``y`` the magnitude spectrum (or PSD when
        ``return_psd`` is set). The metadata carries ``n_points``, ``dt``
        and the complex coefficients (``fft_complex``) needed by the
        inverse transform.

    Raises
    ------
    ValueError
        If data is already in frequency domain, has fewer than 2 points,
        or its time array is not increasing
    """
    from rheojax.core.data import RheoData

    input_shape = (len(data.x),) if hasattr(data.x, "__len__") else (1,)  # type: ignore[arg-type]
    with log_transform(
        logger,
        "fft_analysis",
        input_shape=input_shape,
        window=self.window,
        detrend=self.detrend,
        return_psd=self.return_psd,
    ) as ctx:
        # Validate domain
        if data.domain == "frequency":
            logger.error(
                "FFT analysis requires time-domain data",
                current_domain=data.domain,
            )
            raise ValueError("FFT analysis requires time-domain data")

        # Get time and signal data
        t = data.x
        y = data.y

        # Convert to JAX arrays for computation
        if not isinstance(t, jnp.ndarray):
            t = jnp.array(t)
        if not isinstance(y, jnp.ndarray):
            y = jnp.array(y)
        logger.debug("Processing FFT input", n_points=len(t), dtype=str(y.dtype))

        # Guard: minimum data size (the spacing estimate and the detrend
        # fit both need at least two samples)
        n = len(t)
        if n < 2:
            raise ValueError("FFT requires at least 2 data points")

        # Handle complex data by taking real part
        if jnp.iscomplexobj(y):
            logger.debug("Taking real part of complex signal")
            y = jnp.real(y)

        # Check spacing before detrending/windowing/FFT
        # R8-FFT-001: use median diff for robust dt estimation (handles log-spaced time)
        t_np = np.asarray(t)
        dt_values = np.diff(t_np)
        dt = float(np.median(dt_values))
        if dt <= 0:
            raise ValueError(
                "Time array must be monotonically increasing for FFT "
                f"(got dt={dt:.3e})"
            )
        # The median alone misses arrays that are mostly increasing but
        # contain backward steps; np.interp below silently returns
        # meaningless values for such input, so reject it explicitly.
        if np.any(dt_values < 0):
            raise ValueError(
                "Time array must be monotonically increasing for FFT "
                f"(got min dt={float(np.min(dt_values)):.3e})"
            )

        spacing_variation = np.std(dt_values) / dt
        if spacing_variation > 0.1:  # >10% variation in spacing
            logger.warning(
                "Non-uniform time spacing detected (std/mean=%.2f). "
                "FFT requires uniform spacing — interpolating to uniform grid.",
                spacing_variation,
            )
            y_np = np.asarray(y)
            t_uniform = np.linspace(float(t_np[0]), float(t_np[-1]), n)
            # R9-FFT-001: TODO — Replace np.interp with interpax.interp1d for JIT compatibility.
            # Currently safe since this runs before jnp.fft.rfft (outside JIT).
            y = jnp.array(np.interp(t_uniform, t_np, y_np))
            t = jnp.array(t_uniform)
            dt = float(t_uniform[1] - t_uniform[0])

        # R9-FFT-002: recompute n after resampling so rfftfreq and window
        # length stay consistent with the (possibly resampled) y array.
        n = len(y)
        freqs = jnp.fft.rfftfreq(n, d=dt)

        # Detrend if requested. This runs AFTER the resampling step: the
        # fit is done against the sample index, which is proportional to
        # time only on a uniform grid. For uniformly sampled input the
        # result is the same as detrending first.
        if self.detrend:
            logger.debug("Applying detrending")
            y = self._detrend_data(y)

        # Apply window function
        logger.debug("Applying window function", window=self.window)
        window = self._get_window(len(y))
        y_windowed = y * window

        # Compute FFT
        # Use rfft for real signals (more efficient)
        logger.debug("Computing FFT")
        fft_result = jnp.fft.rfft(y_windowed)

        # Compute magnitude or PSD
        if self.return_psd:
            # Power spectral density
            logger.debug("Computing power spectral density")
            # NOTE(review): this scale, |X|^2 / (n * dt), differs from the
            # usual periodogram density |X|^2 * dt / sum(window^2) (e.g.
            # scipy.signal.periodogram, scaling='density'). Left unchanged
            # because callers may depend on the current values — confirm
            # the intended convention.
            spectrum = jnp.abs(fft_result) ** 2 / (n * dt)
            # One-sided PSD: double every bin that has a negative-frequency
            # mirror. DC never has one. The last rfft bin is the Nyquist
            # bin (no mirror) only when n is even; for odd n it is an
            # ordinary bin and must be doubled as well.
            last = -1 if n % 2 == 0 else None
            spectrum = spectrum.at[1:last].set(spectrum[1:last] * 2.0)
        else:
            # Magnitude spectrum
            logger.debug("Computing magnitude spectrum")
            spectrum = jnp.abs(fft_result)

        # Normalize if requested
        if self.normalize and not self.return_psd:
            # R8-FFT-002: normalize=True scales spectrum to [0,1], discarding physical
            # amplitude units. For quantitative harmonic analysis, use normalize=False.
            # Consider changing default to normalize=False in a future version.
            logger.debug("Normalizing spectrum")
            max_val = jnp.max(spectrum)
            # An (almost) all-zero spectrum is returned as is instead of
            # being divided by ~0.
            spectrum = jnp.where(max_val > 1e-12, spectrum / max_val, spectrum)

        # Create metadata
        new_metadata = (data.metadata or {}).copy()
        new_metadata.update(
            {
                "transform": "fft",
                "window": self.window,
                "detrended": self.detrend,
                "psd": self.return_psd,
                "original_domain": "time",
                "n_points": len(t),
                "dt": float(dt),
                # Store complex coefficients as serializable list (T-010)
                "fft_complex": fft_result.tolist(),
                # T-011: inverse FFT reconstructs the windowed signal, not
                # the original. For lossless round-trip, use window='none'.
                "_windowed": self.window != "none",
            }
        )
        ctx["output_shape"] = (len(freqs),)
        ctx["frequency_range"] = (float(freqs[0]), float(freqs[-1]))

        # Create new RheoData in frequency domain
        return RheoData(
            x=freqs,
            y=spectrum,
            x_units="Hz" if data.x_units else None,
            y_units="PSD" if self.return_psd else "magnitude",
            domain="frequency",
            metadata=new_metadata,
            validate=False,
        )
def _inverse_transform(self, data: RheoData) -> RheoData:
    """Rebuild a time-domain signal from stored FFT coefficients.

    The reconstruction uses the complex coefficients kept in the metadata
    under ``"fft_complex"``, not the magnitude/PSD values in ``data.y``.
    The time axis is regenerated as ``arange(n_points) * dt`` and therefore
    starts at zero.

    Parameters
    ----------
    data : RheoData
        Frequency-domain data produced by the forward FFT transform

    Returns
    -------
    RheoData
        Time-domain data

    Raises
    ------
    ValueError
        If data is not in frequency domain, was not produced by the FFT
        transform, or lacks ``n_points``, ``dt`` or ``fft_complex`` in its
        metadata
    """
    from rheojax.core.data import RheoData

    logger.debug("Starting inverse FFT transform")

    if data.domain != "frequency":
        logger.error(
            "Inverse FFT requires frequency-domain data",
            current_domain=data.domain,
        )
        raise ValueError("Inverse FFT requires frequency-domain data")

    # FFT-INV-001: data.metadata may be None for a plain RheoData, so all
    # lookups go through this dict (never through data.metadata directly).
    _fft_meta = data.metadata or {}
    recorded_transform = _fft_meta.get("transform")
    if recorded_transform != "fft":
        logger.error(
            "Data was not created by FFT transform",
            metadata_transform=recorded_transform,
        )
        raise ValueError("Data was not created by FFT transform")

    # Parameters recorded by the forward transform
    n_points = _fft_meta.get("n_points")
    dt = _fft_meta.get("dt")
    stored_coeffs = _fft_meta.get("fft_complex")

    has_n_points = n_points is not None
    has_dt = dt is not None
    if not (has_n_points and has_dt):
        logger.error(
            "Missing metadata for inverse FFT",
            has_n_points=has_n_points,
            has_dt=has_dt,
        )
        raise ValueError("Missing metadata for inverse FFT (n_points, dt)")

    if stored_coeffs is None:
        logger.error("Missing complex FFT coefficients for inverse transform")
        raise ValueError("Missing complex FFT coefficients for inverse transform")

    logger.debug("Performing inverse FFT", n_points=n_points, dt=dt)

    # T-010: the coefficients were stored as a plain list; turn them back
    # into a JAX array before inverting.
    coeffs = jnp.array(stored_coeffs)
    y_reconstructed = jnp.fft.irfft(coeffs, n=n_points)

    # Regenerate the (zero-based) time axis
    time_axis = jnp.arange(n_points) * dt

    # Carry the incoming metadata forward, marking this as an inverse step
    new_metadata = _fft_meta.copy()
    new_metadata["transform"] = "ifft"
    new_metadata["original_domain"] = "frequency"

    logger.debug("Inverse FFT completed", output_points=len(y_reconstructed))
    return RheoData(
        x=time_axis,
        y=y_reconstructed,
        x_units="s" if data.x_units else None,
        y_units="reconstructed",
        domain="time",
        metadata=new_metadata,
        validate=False,
    )
[docs]
def find_peaks(
    self, freq_data: RheoData, prominence: float = 0.1, n_peaks: int = 5
) -> tuple[JaxArray, JaxArray]:
    """Find characteristic frequency peaks in FFT spectrum.

    Peaks are local maxima located with ``scipy.signal.find_peaks`` on the
    spectrum scaled by its largest non-NaN value, so ``prominence`` is
    relative to that maximum. The returned heights are taken from the
    unscaled spectrum.

    Parameters
    ----------
    freq_data : RheoData
        Frequency-domain data from FFT
    prominence : float, default=0.1
        Minimum prominence for peak detection (relative to max)
    n_peaks : int, default=5
        Maximum number of peaks to return. Values of 0 or less yield no
        peaks.

    Returns
    -------
    peak_freqs : JaxArray
        Frequencies of detected peaks
    peak_heights : JaxArray
        Heights of detected peaks

    Notes
    -----
    When more than ``n_peaks`` peaks are found, only the ``n_peaks`` most
    prominent are kept and they are returned in order of increasing
    prominence. Otherwise the peaks are returned in the order
    ``scipy.signal.find_peaks`` reports them. Empty arrays are returned for
    an empty or all-NaN spectrum.
    """
    logger.debug(
        "Finding peaks in FFT spectrum",
        prominence=prominence,
        n_peaks=n_peaks,
    )

    def _no_peaks() -> tuple[JaxArray, JaxArray]:
        # Fresh empty float64 arrays for every "nothing found" exit.
        return jnp.array([], dtype=jnp.float64), jnp.array([], dtype=jnp.float64)

    freqs = np.asarray(freq_data.x)
    spectrum = np.asarray(freq_data.y)
    if len(spectrum) == 0:
        logger.debug("Empty spectrum — no peaks to detect")
        return _no_peaks()

    # A non-positive limit means "return nothing". Without this guard,
    # n_peaks=0 would make the slice [-0:] below select EVERY peak, and
    # negative values would drop peaks from the wrong end.
    if n_peaks <= 0:
        logger.debug("Non-positive n_peaks — no peaks requested", n_peaks=n_peaks)
        return _no_peaks()

    # Simple peak detection: find local maxima
    from scipy.signal import find_peaks as scipy_find_peaks

    # Normalize spectrum for prominence calculation
    if np.all(np.isnan(spectrum)):
        logger.debug("All-NaN spectrum — no peaks to detect")
        return _no_peaks()
    max_val = np.nanmax(spectrum)
    # An (almost) all-zero spectrum is used unscaled to avoid dividing by ~0.
    spectrum_norm = spectrum / max_val if max_val > 1e-12 else spectrum

    # Find peaks
    peak_indices, properties = scipy_find_peaks(
        spectrum_norm, prominence=prominence
    )
    logger.debug("Initial peaks found", n_peaks_found=len(peak_indices))

    # Keep only the n_peaks most prominent (ascending prominence order)
    if len(peak_indices) > n_peaks:
        prominences = properties["prominences"]
        top_indices = np.argsort(prominences)[-n_peaks:]
        peak_indices = peak_indices[top_indices]

    peak_freqs = freqs[peak_indices]
    peak_heights = spectrum[peak_indices]
    logger.debug(
        "Peak detection completed",
        n_peaks_returned=len(peak_freqs),
        peak_frequencies=peak_freqs.tolist() if len(peak_freqs) > 0 else [],
    )

    # Convert back to JAX
    return jnp.array(peak_freqs), jnp.array(peak_heights)
[docs]
def get_characteristic_time(self, freq_data: RheoData) -> float:
    """Return the reciprocal of the dominant peak frequency.

    The dominant peak is whatever ``find_peaks`` returns when asked for a
    single peak with its default prominence.

    Parameters
    ----------
    freq_data : RheoData
        Frequency-domain data from FFT

    Returns
    -------
    float
        Characteristic time (1 / peak_frequency), or NaN when no peak is
        detected
    """
    dominant, _ = self.find_peaks(freq_data, n_peaks=1)
    # No detected peak means there is no frequency to invert.
    return 1.0 / float(dominant[0]) if len(dominant) else float("nan")
# Public API of this module: only the transform class is exported.
__all__ = ["FFTAnalysis"]