"""Smooth noise-robust numerical differentiation for rheological data.
This module provides noise-robust differentiation using Savitzky-Golay filtering
and other smoothing techniques, essential for converting between rheological
functions (e.g., creep compliance → relaxation modulus).
"""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Literal
import numpy as np
from scipy.signal import savgol_filter
from rheojax.core.base import BaseTransform
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger, log_transform
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
# Module logger
logger = get_logger(__name__)
if TYPE_CHECKING:
import jax.numpy as jnp_typing
from rheojax.core.data import RheoData
else: # pragma: no cover - typing fallback when JAX unavailable
jnp_typing = np
type JaxArray = jnp_typing.ndarray
DerivativeMethod = Literal["savgol", "finite_diff", "spline", "total_variation"]
@TransformRegistry.register("smooth_derivative", type=TransformType.PROCESSING)
class SmoothDerivative(BaseTransform):
    """Smooth noise-robust numerical differentiation.

    This transform computes derivatives of noisy rheological data using
    regularization techniques to suppress noise amplification. Multiple
    methods are available:

    1. Savitzky-Golay: Fits local polynomials and computes analytical
       derivatives. Only applied when ``x`` is uniformly spaced and the data
       has at least ``window_length`` points; otherwise the transform falls
       back to finite differences (a warning is logged).
    2. Finite Difference: Repeated ``jnp.gradient`` with optional smoothing.
    3. Spline: Fits a cubic spline through the data and differentiates it.
       No smoothing parameter is passed to the spline, so combine with
       ``smooth_before``/``smooth_after`` for noisy data.
    4. Total Variation: Placeholder that currently uses finite differences
       followed by moving-average smoothing.

    Savitzky-Golay is recommended for most applications as it preserves peak
    positions better than simple smoothing while providing good noise
    suppression.

    Common use cases:

    - Creep compliance J(t) → relaxation modulus G(t) (via numerical inversion)
    - Storage modulus G'(ω) → loss modulus G"(ω) via Kramers-Kronig
    - Time-derivative of strain in controlled-strain experiments

    Parameters
    ----------
    method : DerivativeMethod, default='savgol'
        Differentiation method
    window_length : int, default=11
        Window length for Savitzky-Golay (must be a positive odd integer)
    polyorder : int, default=3
        Polynomial order for Savitzky-Golay (must be < window_length)
    deriv : int, default=1
        Order of derivative (1, 2, 3, ...)
    smooth_before : bool, default=False
        Apply additional smoothing before differentiation
    smooth_after : bool, default=False
        Apply additional smoothing after differentiation
    smooth_window : int, default=5
        Moving-average window used by ``smooth_before``, ``smooth_after`` and
        the ``total_variation`` method (even values are rounded up to odd)

    Examples
    --------
    >>> import numpy as np
    >>> from rheojax.core.data import RheoData
    >>> from rheojax.transforms.smooth_derivative import SmoothDerivative
    >>>
    >>> # Create noisy creep compliance data
    >>> rng = np.random.default_rng(0)
    >>> t = np.linspace(0.1, 10, 100)
    >>> J_t = t + 0.1 * rng.normal(size=len(t))  # Noisy linear creep
    >>> data = RheoData(x=t, y=J_t, domain='time')
    >>>
    >>> # Compute smooth derivative
    >>> deriv = SmoothDerivative(window_length=11, polyorder=3)
    >>> dJ_dt = deriv.transform(data)
    >>>
    >>> # For higher-order derivatives
    >>> deriv2 = SmoothDerivative(window_length=15, polyorder=4, deriv=2)
    >>> d2J_dt2 = deriv2.transform(data)
    """

    def __init__(
        self,
        method: DerivativeMethod = "savgol",
        window_length: int = 11,
        polyorder: int = 3,
        deriv: int = 1,
        smooth_before: bool = False,
        smooth_after: bool = False,
        smooth_window: int = 5,
    ):
        """Initialize Smooth Derivative transform.

        Parameters
        ----------
        method : DerivativeMethod
            Differentiation method
        window_length : int
            Window length (must be a positive odd integer)
        polyorder : int
            Polynomial order for Savitzky-Golay (0 <= polyorder < window_length)
        deriv : int
            Derivative order (>= 1)
        smooth_before : bool
            Smooth before differentiation
        smooth_after : bool
            Smooth after differentiation
        smooth_window : int
            Smoothing window size

        Raises
        ------
        ValueError
            If ``window_length`` is even or not positive, if ``polyorder`` is
            negative or not smaller than ``window_length``, or if ``deriv``
            is smaller than 1.
        """
        super().__init__()
        self.method = method
        self.window_length = window_length
        self.polyorder = polyorder
        self.deriv = deriv
        self.smooth_before = smooth_before
        self.smooth_after = smooth_after
        self.smooth_window = smooth_window

        # Validate parameters
        if self.window_length % 2 == 0:
            raise ValueError("window_length must be odd")
        # Negative odd values satisfy the parity test above (-3 % 2 == 1), so
        # positivity has to be checked separately.
        if self.window_length < 1:
            raise ValueError("window_length must be positive")
        if self.polyorder >= self.window_length:
            raise ValueError("polyorder must be less than window_length")
        if self.polyorder < 0:
            raise ValueError("polyorder must be non-negative")
        if self.deriv < 1:
            raise ValueError("deriv must be at least 1")

        # A polynomial of degree `polyorder` has a vanishing derivative of any
        # higher order, so the Savitzky-Golay filter returns all zeros in that
        # case. Warn instead of raising to stay backward-compatible.
        if self.method == "savgol" and self.deriv > self.polyorder:
            logger.warning(
                "deriv exceeds polyorder; the Savitzky-Golay filter returns "
                "all zeros when deriv > polyorder",
                deriv=self.deriv,
                polyorder=self.polyorder,
            )

    def _smooth_data(self, y: JaxArray, window: int) -> JaxArray:
        """Apply moving average smoothing.

        The signal is extended at both ends by repeating its first/last
        sample before averaging. Compared with zero padding this keeps the
        boundary samples from being pulled toward zero and guarantees that
        the output has exactly as many samples as the input, even when the
        window is longer than the data.

        Parameters
        ----------
        y : jnp.ndarray
            1-D data to smooth
        window : int
            Window size (even values are incremented to the next odd value)

        Returns
        -------
        jnp.ndarray
            Smoothed data, same length as ``y``

        Raises
        ------
        ValueError
            If the (odd-adjusted) window is smaller than 1.
        """
        if window % 2 == 0:
            window += 1
        if window < 1:
            raise ValueError("smoothing window must be positive")

        half = window // 2
        kernel = jnp.ones(window) / window
        # Edge replication + "valid" convolution: len(y) + 2*half - window + 1
        # == len(y) samples come out.
        padded = jnp.pad(jnp.asarray(y), (half, half), mode="edge")
        return jnp.convolve(padded, kernel, mode="valid")

    def _savgol_derivative(self, x: JaxArray, y: JaxArray) -> JaxArray:
        """Compute derivative using Savitzky-Golay filter.

        The filter is only applied when ``x`` is uniformly spaced and at
        least ``window_length`` points are available. In every other case the
        method falls back to :meth:`_finite_diff_derivative` and logs a
        warning.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx

        Raises
        ------
        ValueError
            If fewer than 2 points are supplied, or if the (uniform) spacing
            is numerically zero.
        """
        # scipy works on host arrays
        x_np = np.asarray(x)
        y_np = np.asarray(y)

        # Guard: need at least 2 points to compute dx[0], and window_length
        # points for Savitzky-Golay. Fall back to finite differences for
        # small arrays.
        if len(x_np) < 2:
            raise ValueError(
                f"SmoothDerivative requires at least 2 data points, got {len(x_np)}"
            )
        if len(x_np) < self.window_length:
            logger.warning(
                "Data length < window_length; falling back to finite differences",
                data_length=len(x_np),
                window_length=self.window_length,
            )
            return self._finite_diff_derivative(x, y)

        # Check if uniformly spaced. atol=0 makes the comparison purely
        # relative; numpy's default atol=1e-8 would declare any grid with
        # spacings of that magnitude (or smaller) "uniform".
        dx = np.diff(x_np)
        is_uniform = np.allclose(dx, dx[0], rtol=1e-5, atol=0.0)

        if is_uniform:
            delta = dx[0]
            # R7-DERIV-002: Guard against zero spacing (duplicate x values).
            # Cannot fall back to finite differences either since jnp.gradient
            # also divides by dx — raise ValueError instead.
            if abs(delta) < 1e-30:
                raise ValueError(
                    "SmoothDerivative: uniform spacing is near-zero "
                    f"(delta={delta:.2e}). Data likely contains duplicate x "
                    "values. Remove duplicates before computing derivatives."
                )
            dy_dx = savgol_filter(
                y_np,
                window_length=self.window_length,
                polyorder=self.polyorder,
                deriv=self.deriv,
                delta=delta,
            )
        else:
            # savgol_filter assumes a constant sample spacing, so non-uniform
            # grids (e.g. log-spaced sweeps) use finite differences instead.
            logger.warning(
                "Non-uniform x spacing; falling back to finite differences",
                data_length=len(x_np),
            )
            dy_dx = self._finite_diff_derivative(x, y)

        return jnp.array(dy_dx)

    def _finite_diff_derivative(self, x: JaxArray, y: JaxArray) -> JaxArray:
        """Compute derivative using finite differences.

        ``jnp.gradient`` is applied ``self.deriv`` times in succession, each
        time with ``x`` as the coordinate array.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx (of order ``self.deriv``)
        """
        dy_dx = y
        for _ in range(self.deriv):
            dy_dx = jnp.gradient(dy_dx, x)
        return dy_dx

    def _spline_derivative(self, x: JaxArray, y: JaxArray) -> JaxArray:
        """Compute derivative using JIT-safe cubic splines via interpax.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx, in the same order as the input ``x``
        """
        from interpax import CubicSpline

        # Ensure JAX arrays
        x_jax = jnp.asarray(x)
        y_jax = jnp.asarray(y)

        # Sort data if needed (interpax requires sorted x)
        sort_idx = jnp.argsort(x_jax)
        x_sorted = x_jax[sort_idx]
        y_sorted = y_jax[sort_idx]

        # Fit cubic spline (JIT-compatible)
        spline = CubicSpline(x_sorted, y_sorted)

        # Differentiate the spline and evaluate it on the sorted grid
        deriv_spline = spline.derivative(nu=self.deriv)
        dy_dx = deriv_spline(x_sorted)

        # argsort of the sorting permutation is its inverse: restore the
        # caller's original ordering.
        unsort_idx = jnp.argsort(sort_idx)
        return dy_dx[unsort_idx]

    def _total_variation_derivative(
        self, x: JaxArray, y: JaxArray, alpha: float = 0.1
    ) -> JaxArray:
        """Compute derivative with total variation regularization.

        This minimizes:
            ||y - integral(u)||² + α * TV(u)
        where u = dy/dx is the derivative.

        .. warning::
            Currently uses finite differences with smoothing as an
            approximation, not true TV-regularised differentiation.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable
        alpha : float
            Regularization parameter (currently unused by the approximation)

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx
        """
        # TODO: Implement proper TV-regularised differentiation (Chambolle's algorithm)
        warnings.warn(
            "The 'total_variation' method currently uses finite differences with "
            "smoothing as an approximation. For true TV-regularised differentiation, "
            "use an external solver.",
            stacklevel=2,
        )
        dy_dx = self._finite_diff_derivative(x, y)
        # Apply TV denoising to the derivative
        # (Requires convex optimization - use simple smoothing for now)
        dy_dx = self._smooth_data(dy_dx, self.smooth_window)
        return dy_dx

    def _transform(self, data: RheoData) -> RheoData:
        """Compute smooth derivative of data.

        Complex ``y`` is reduced to its real part before differentiation.

        Parameters
        ----------
        data : RheoData
            Input data

        Returns
        -------
        RheoData
            Derivative data with updated ``y_units`` and metadata

        Raises
        ------
        ValueError
            If ``self.method`` is not one of the supported method names.
        """
        from rheojax.core.data import RheoData

        input_shape = (len(data.x),) if hasattr(data.x, "__len__") else (1,)  # type: ignore[arg-type]

        with log_transform(
            logger,
            "smooth_derivative",
            input_shape=input_shape,
            method=self.method,
            derivative_order=self.deriv,
            window_length=self.window_length,
        ) as ctx:
            # Get data
            x = data.x
            y = data.y

            # Convert to JAX arrays
            if not isinstance(x, jnp.ndarray):
                x = jnp.array(x)
            if not isinstance(y, jnp.ndarray):
                y = jnp.array(y)

            logger.debug(
                "Processing derivative input",
                n_points=len(x),
                dtype=str(y.dtype),
            )

            # Handle complex data
            if jnp.iscomplexobj(y):
                logger.debug("Taking real part of complex signal")
                y = jnp.real(y)

            # Pre-smoothing if requested
            if self.smooth_before:
                logger.debug("Applying pre-smoothing", window=self.smooth_window)
                y = self._smooth_data(y, self.smooth_window)

            # Compute derivative based on method
            logger.debug("Computing derivative", method=self.method)
            if self.method == "savgol":
                dy_dx = self._savgol_derivative(x, y)
            elif self.method == "finite_diff":
                dy_dx = self._finite_diff_derivative(x, y)
            elif self.method == "spline":
                dy_dx = self._spline_derivative(x, y)
            elif self.method == "total_variation":
                dy_dx = self._total_variation_derivative(x, y)
            else:
                logger.error("Unknown differentiation method", method=self.method)  # type: ignore[unreachable]
                raise ValueError(f"Unknown method: {self.method}")

            # Post-smoothing if requested
            if self.smooth_after:
                logger.debug("Applying post-smoothing", window=self.smooth_window)
                dy_dx = self._smooth_data(dy_dx, self.smooth_window)

            # Create new y_units
            if data.y_units and data.x_units:
                if self.deriv == 1:
                    new_y_units = f"d({data.y_units})/d({data.x_units})"
                else:
                    new_y_units = (
                        f"d^{self.deriv}({data.y_units})/d({data.x_units})^{self.deriv}"
                    )
            else:
                new_y_units = f"derivative_order_{self.deriv}"

            # Create metadata (copy so the input's metadata is not mutated)
            new_metadata = (data.metadata or {}).copy()
            new_metadata.update(
                {
                    "transform": "derivative",
                    "method": self.method,
                    "derivative_order": self.deriv,
                    "window_length": self.window_length,
                    "polyorder": self.polyorder,
                }
            )

            ctx["output_shape"] = (len(x),)

            return RheoData(
                x=x,
                y=dy_dx,
                x_units=data.x_units,
                y_units=new_y_units,
                domain=data.domain,
                metadata=new_metadata,
                validate=False,
            )

    def _inverse_transform(self, data: RheoData) -> RheoData:
        """Apply numerical integration (inverse of derivative).

        The cumulative trapezoid rule is applied ``self.deriv`` times so that
        an n-th order derivative is integrated n times. Every integration
        starts from 0, so the integration constants of the original signal
        are not recovered.

        Parameters
        ----------
        data : RheoData
            Derivative data

        Returns
        -------
        RheoData
            Integrated data (approximation of original)
        """
        from scipy.integrate import cumulative_trapezoid as scipy_cumtrapz

        from rheojax.core.data import RheoData

        logger.debug("Starting numerical integration (inverse derivative)")

        # Get data
        x = data.x
        dy_dx = data.y

        # Convert to numpy for scipy
        x_np = np.asarray(x)
        y_integrated = np.asarray(dy_dx)

        # Numerical integration (cumulative trapezoid), once per derivative order
        for _ in range(self.deriv):
            y_integrated = scipy_cumtrapz(y_integrated, x_np, initial=0)

        logger.debug(
            "Integration completed",
            n_points=len(y_integrated),
        )

        # Create metadata
        new_metadata = (data.metadata or {}).copy()
        new_metadata.update(
            {"transform": "integral", "original_transform": "derivative"}
        )

        return RheoData(
            x=x,
            y=jnp.array(y_integrated),
            x_units=data.x_units,
            y_units="integrated",
            domain=data.domain,
            metadata=new_metadata,
            validate=False,
        )

    def estimate_noise_level(self, data: RheoData) -> float:
        """Estimate noise level in data.

        This uses the median absolute deviation (MAD) of the finite-difference
        second derivative as a robust noise indicator. Complex ``y`` is
        reduced to its real part first, as in the forward transform.

        Parameters
        ----------
        data : RheoData
            Input data

        Returns
        -------
        float
            ``1.4826 * MAD`` of the second derivative (the Gaussian-consistent
            scale estimate). The value is expressed in units of d²y/dx² and is
            not rescaled to the units of ``y``. Returns ``0.0`` when fewer
            than 3 points are available.
        """
        # Get data
        x = data.x
        y = data.y

        # Convert to arrays
        if not isinstance(x, jnp.ndarray):
            x = jnp.array(x)
        if not isinstance(y, jnp.ndarray):
            y = jnp.array(y)

        # R7-DERIV-001: Need at least 3 points for second derivative estimation
        if len(x) < 3:
            logger.warning(
                "estimate_noise_level: need at least 3 data points, got %d",
                len(x),
            )
            return 0.0

        # Match _transform: work on the real part of complex signals so the
        # median and the float conversion below are well defined.
        if jnp.iscomplexobj(y):
            y = jnp.real(y)

        # Compute second derivative (amplifies noise)
        d2y = jnp.gradient(jnp.gradient(y, x), x)

        # MAD estimator
        median = jnp.median(d2y)
        mad = jnp.median(jnp.abs(d2y - median))

        # Convert MAD to standard deviation (assumes Gaussian noise)
        sigma = 1.4826 * mad
        return float(sigma)
# Public API of this module: only the transform class is exported.
__all__ = ["SmoothDerivative"]