Source code for rheojax.transforms.smooth_derivative

"""Smooth noise-robust numerical differentiation for rheological data.

This module provides noise-robust differentiation using Savitzky-Golay filtering
and other smoothing techniques, essential for converting between rheological
functions (e.g., creep compliance → relaxation modulus).
"""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np
from scipy.signal import savgol_filter

from rheojax.core.base import BaseTransform
from rheojax.core.inventory import TransformType
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import TransformRegistry
from rheojax.logging import get_logger, log_transform

# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()

# Module logger
logger = get_logger(__name__)

# Names needed only for annotations. Static type checkers see the real
# ``jax.numpy`` module and ``RheoData``; at runtime neither import below is
# executed and ``jnp_typing`` is simply bound to NumPy.
if TYPE_CHECKING:
    import jax.numpy as jnp_typing

    from rheojax.core.data import RheoData
else:  # pragma: no cover - typing fallback when JAX unavailable
    jnp_typing = np

# Array type used in the method annotations of this module. The ``type``
# statement is evaluated lazily, so the alias is only resolved on demand.
type JaxArray = jnp_typing.ndarray


# Names of the differentiation back-ends accepted by ``SmoothDerivative``.
DerivativeMethod = Literal["savgol", "finite_diff", "spline", "total_variation"]


@TransformRegistry.register("smooth_derivative", type=TransformType.PROCESSING)
class SmoothDerivative(BaseTransform):
    """Smooth noise-robust numerical differentiation.

    This transform computes derivatives of noisy rheological data using
    regularization techniques to suppress noise amplification. Multiple
    methods are available:

    1. Savitzky-Golay: Fits local polynomials and computes analytical
       derivatives. Only applied to uniformly spaced data with at least
       ``window_length`` points; otherwise finite differences are used.
    2. Finite Difference: Repeated ``jnp.gradient`` calls.
    3. Spline: Fits an interpolating cubic spline (interpax) and evaluates
       its derivative.
    4. Total Variation: Currently approximated by finite differences
       followed by moving-average smoothing (a warning is emitted).

    Savitzky-Golay is recommended for most applications as it preserves peak
    positions better than simple smoothing while providing good noise
    suppression.

    Common use cases:
    - Creep compliance J(t) → relaxation modulus G(t) (via numerical inversion)
    - Storage modulus G'(ω) → loss modulus G"(ω) via Kramers-Kronig
    - Time-derivative of strain in controlled-strain experiments

    Parameters
    ----------
    method : DerivativeMethod, default='savgol'
        Differentiation method
    window_length : int, default=11
        Window length for Savitzky-Golay (must be odd)
    polyorder : int, default=3
        Polynomial order for Savitzky-Golay (must be < window_length)
    deriv : int, default=1
        Order of derivative (1, 2, 3, ...)
    smooth_before : bool, default=False
        Apply additional smoothing before differentiation
    smooth_after : bool, default=False
        Apply additional smoothing after differentiation
    smooth_window : int, default=5
        Moving-average window used by ``smooth_before``, ``smooth_after``
        and the ``total_variation`` method (even values are bumped to the
        next odd value)

    Examples
    --------
    >>> import numpy as np
    >>> from rheojax.core.data import RheoData
    >>> from rheojax.transforms.smooth_derivative import SmoothDerivative
    >>>
    >>> # Create noisy creep compliance data
    >>> rng = np.random.default_rng(0)
    >>> t = np.linspace(0.1, 10, 100)
    >>> J_t = t + 0.1 * rng.normal(size=t.size)  # Noisy linear creep
    >>> data = RheoData(x=t, y=J_t, domain='time')
    >>>
    >>> # Compute smooth derivative
    >>> deriv = SmoothDerivative(window_length=11, polyorder=3)
    >>> dJ_dt = deriv.transform(data)
    >>>
    >>> # For higher-order derivatives
    >>> deriv2 = SmoothDerivative(window_length=15, polyorder=4, deriv=2)
    >>> d2J_dt2 = deriv2.transform(data)
    """

    def __init__(
        self,
        method: DerivativeMethod = "savgol",
        window_length: int = 11,
        polyorder: int = 3,
        deriv: int = 1,
        smooth_before: bool = False,
        smooth_after: bool = False,
        smooth_window: int = 5,
    ):
        """Initialize Smooth Derivative transform.

        Parameters
        ----------
        method : DerivativeMethod
            Differentiation method
        window_length : int
            Window length (must be odd)
        polyorder : int
            Polynomial order for Savitzky-Golay
        deriv : int
            Derivative order
        smooth_before : bool
            Smooth before differentiation
        smooth_after : bool
            Smooth after differentiation
        smooth_window : int
            Smoothing window size

        Raises
        ------
        ValueError
            If ``window_length`` is even, ``polyorder >= window_length`` or
            ``deriv < 1``.
        """
        super().__init__()
        self.method = method
        self.window_length = window_length
        self.polyorder = polyorder
        self.deriv = deriv
        self.smooth_before = smooth_before
        self.smooth_after = smooth_after
        self.smooth_window = smooth_window

        # Validate parameters
        if self.window_length % 2 == 0:
            raise ValueError("window_length must be odd")
        if self.polyorder >= self.window_length:
            raise ValueError("polyorder must be less than window_length")
        if self.deriv < 1:
            raise ValueError("deriv must be at least 1")

    def _smooth_data(self, y: JaxArray, window: int) -> JaxArray:
        """Apply moving average smoothing.

        The window is forced to be odd and is shrunk to fit the data so the
        output always has the same length as ``y``. Near the edges the
        average is taken over the samples that actually fall inside the
        window, instead of implicitly padding with zeros (which would pull
        the boundary values toward zero).

        Parameters
        ----------
        y : jnp.ndarray
            Data to smooth (1-D)
        window : int
            Window size

        Returns
        -------
        jnp.ndarray
            Smoothed data, same length as ``y``
        """
        y = jnp.asarray(y)
        n_points = y.shape[0]
        if n_points == 0:
            # Nothing to smooth; convolve would reject empty input.
            return y

        if window % 2 == 0:
            window += 1
        if window > n_points:
            # mode="same" returns max(len(y), len(kernel)) samples, so an
            # oversized kernel would change the output length. Use the
            # largest odd window that fits instead.
            window = n_points if n_points % 2 == 1 else n_points - 1

        kernel = jnp.ones(window)
        window_sum = jnp.convolve(y, kernel, mode="same")
        # Number of real samples under the kernel at each position: equals
        # ``window`` in the interior and decreases toward the boundaries.
        window_count = jnp.convolve(jnp.ones_like(y), kernel, mode="same")
        return window_sum / window_count

    def _savgol_derivative(self, x: JaxArray, y: JaxArray) -> JaxArray:
        """Compute derivative using Savitzky-Golay filter.

        Falls back to :meth:`_finite_diff_derivative` when there are fewer
        than ``window_length`` points or when ``x`` is not uniformly spaced.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx

        Raises
        ------
        ValueError
            If fewer than 2 points are given, or the (uniform) spacing is
            numerically zero.
        """
        # Convert to numpy for scipy
        x_np = np.array(x) if isinstance(x, jnp.ndarray) else np.asarray(x)
        y_np = np.array(y) if isinstance(y, jnp.ndarray) else np.asarray(y)

        # Guard: need at least 2 points to compute dx[0], and window_length points
        # for Savitzky-Golay. Fall back to finite differences for small arrays.
        if len(x_np) < 2:
            raise ValueError(
                f"SmoothDerivative requires at least 2 data points, got {len(x_np)}"
            )
        if len(x_np) < self.window_length:
            logger.warning(
                "Data length < window_length; falling back to finite differences",
                data_length=len(x_np),
                window_length=self.window_length,
            )
            return self._finite_diff_derivative(x, y)

        # Check if uniformly spaced. atol=0 makes the test purely relative;
        # the default atol (1e-8) would classify any finely sampled grid as
        # uniform regardless of how irregular it is.
        dx = np.diff(x_np)
        is_uniform = np.allclose(dx, dx[0], rtol=1e-5, atol=0.0)

        if not is_uniform:
            # Non-uniform spacing: savgol_filter assumes a constant delta,
            # so use finite differences instead.
            logger.debug(
                "Non-uniform x spacing; using finite differences instead of "
                "Savitzky-Golay",
                n_points=len(x_np),
            )
            return jnp.array(self._finite_diff_derivative(x, y))

        delta = dx[0]
        # R7-DERIV-002: Guard against zero spacing (duplicate x values).
        # Cannot fall back to finite differences either since jnp.gradient
        # also divides by dx — raise ValueError instead.
        if abs(delta) < 1e-30:
            raise ValueError(
                "SmoothDerivative: uniform spacing is near-zero "
                f"(delta={delta:.2e}). Data likely contains duplicate x "
                "values. Remove duplicates before computing derivatives."
            )

        if self.deriv > self.polyorder:
            # A polynomial of degree polyorder has a vanishing derivative of
            # any higher order, so the filter output is identically zero.
            logger.warning(
                "deriv exceeds polyorder; Savitzky-Golay derivative will be zero",
                derivative_order=self.deriv,
                polyorder=self.polyorder,
            )

        dy_dx = savgol_filter(
            y_np,
            window_length=self.window_length,
            polyorder=self.polyorder,
            deriv=self.deriv,
            delta=delta,
        )
        return jnp.array(dy_dx)

    def _finite_diff_derivative(self, x: JaxArray, y: JaxArray) -> JaxArray:
        """Compute derivative using finite differences.

        The derivative of order ``self.deriv`` is obtained by applying
        ``jnp.gradient`` that many times.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx
        """
        dy_dx = y
        for _ in range(self.deriv):
            dy_dx = jnp.gradient(dy_dx, x)
        return dy_dx

    def _spline_derivative(self, x: JaxArray, y: JaxArray) -> JaxArray:
        """Compute derivative using JIT-safe cubic splines via interpax.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx, in the same order as the input ``x``
        """
        from interpax import CubicSpline

        # Ensure JAX arrays
        x_jax = jnp.asarray(x)
        y_jax = jnp.asarray(y)

        # Sort data if needed (interpax requires sorted x)
        sort_idx = jnp.argsort(x_jax)
        x_sorted = x_jax[sort_idx]
        y_sorted = y_jax[sort_idx]

        # Fit cubic spline (JIT-compatible)
        spline = CubicSpline(x_sorted, y_sorted)

        # Evaluate the derivative spline at the (sorted) sample points
        deriv_spline = spline.derivative(nu=self.deriv)
        dy_dx = deriv_spline(x_sorted)

        # Restore the caller's original ordering
        unsort_idx = jnp.argsort(sort_idx)
        dy_dx = dy_dx[unsort_idx]

        return dy_dx

    def _total_variation_derivative(
        self, x: JaxArray, y: JaxArray, alpha: float = 0.1
    ) -> JaxArray:
        """Compute derivative with total variation regularization.

        This minimizes: ||y - integral(u)||² + α * TV(u)
        where u = dy/dx is the derivative.

        .. warning::
            Currently uses finite differences with smoothing as an
            approximation, not true TV-regularised differentiation.

        Parameters
        ----------
        x : jnp.ndarray
            Independent variable
        y : jnp.ndarray
            Dependent variable
        alpha : float
            Regularization parameter (not used by the current approximation)

        Returns
        -------
        jnp.ndarray
            Derivative dy/dx
        """
        # TODO: Implement proper TV-regularised differentiation (Chambolle's algorithm)
        warnings.warn(
            "The 'total_variation' method currently uses finite differences with "
            "smoothing as an approximation. For true TV-regularised differentiation, "
            "use an external solver.",
            stacklevel=2,
        )
        dy_dx = self._finite_diff_derivative(x, y)

        # Apply TV denoising to the derivative
        # (Requires convex optimization - use simple smoothing for now)
        dy_dx = self._smooth_data(dy_dx, self.smooth_window)

        return dy_dx

    def _transform(self, data: RheoData) -> RheoData:
        """Compute smooth derivative of data.

        Complex ``y`` is reduced to its real part before differentiation.
        Optional moving-average smoothing is applied before and/or after
        the selected differentiation method.

        Parameters
        ----------
        data : RheoData
            Input data

        Returns
        -------
        RheoData
            Derivative data

        Raises
        ------
        ValueError
            If ``self.method`` is not one of the supported methods.
        """
        from rheojax.core.data import RheoData

        input_shape = (len(data.x),) if hasattr(data.x, "__len__") else (1,)  # type: ignore[arg-type]

        with log_transform(
            logger,
            "smooth_derivative",
            input_shape=input_shape,
            method=self.method,
            derivative_order=self.deriv,
            window_length=self.window_length,
        ) as ctx:
            # Get data
            x = data.x
            y = data.y

            # Convert to JAX arrays
            if not isinstance(x, jnp.ndarray):
                x = jnp.array(x)
            if not isinstance(y, jnp.ndarray):
                y = jnp.array(y)

            logger.debug(
                "Processing derivative input",
                n_points=len(x),
                dtype=str(y.dtype),
            )

            # Handle complex data
            if jnp.iscomplexobj(y):
                logger.debug("Taking real part of complex signal")
                y = jnp.real(y)

            # Pre-smoothing if requested
            if self.smooth_before:
                logger.debug("Applying pre-smoothing", window=self.smooth_window)
                y = self._smooth_data(y, self.smooth_window)

            # Compute derivative based on method
            logger.debug("Computing derivative", method=self.method)
            if self.method == "savgol":
                dy_dx = self._savgol_derivative(x, y)
            elif self.method == "finite_diff":
                dy_dx = self._finite_diff_derivative(x, y)
            elif self.method == "spline":
                dy_dx = self._spline_derivative(x, y)
            elif self.method == "total_variation":
                dy_dx = self._total_variation_derivative(x, y)
            else:
                logger.error("Unknown differentiation method", method=self.method)  # type: ignore[unreachable]
                raise ValueError(f"Unknown method: {self.method}")

            # Post-smoothing if requested
            if self.smooth_after:
                logger.debug("Applying post-smoothing", window=self.smooth_window)
                dy_dx = self._smooth_data(dy_dx, self.smooth_window)

            # Create new y_units
            if data.y_units and data.x_units:
                if self.deriv == 1:
                    new_y_units = f"d({data.y_units})/d({data.x_units})"
                else:
                    new_y_units = (
                        f"d^{self.deriv}({data.y_units})/d({data.x_units})^{self.deriv}"
                    )
            else:
                new_y_units = f"derivative_order_{self.deriv}"

            # Create metadata
            new_metadata = (data.metadata or {}).copy()
            new_metadata.update(
                {
                    "transform": "derivative",
                    "method": self.method,
                    "derivative_order": self.deriv,
                    "window_length": self.window_length,
                    "polyorder": self.polyorder,
                }
            )

            ctx["output_shape"] = (len(x),)

            return RheoData(
                x=x,
                y=dy_dx,
                x_units=data.x_units,
                y_units=new_y_units,
                domain=data.domain,
                metadata=new_metadata,
                validate=False,
            )

    def _inverse_transform(self, data: RheoData) -> RheoData:
        """Apply numerical integration (inverse of derivative).

        The data are integrated ``self.deriv`` times with the cumulative
        trapezoid rule, each pass starting from zero. Integration constants
        are therefore lost: the result matches the original signal only up
        to a polynomial of degree ``self.deriv - 1``.

        Parameters
        ----------
        data : RheoData
            Derivative data

        Returns
        -------
        RheoData
            Integrated data (approximation of original)
        """
        from scipy.integrate import cumulative_trapezoid as scipy_cumtrapz

        from rheojax.core.data import RheoData

        logger.debug("Starting numerical integration (inverse derivative)")

        # Get data
        x = data.x
        dy_dx = data.y

        # Convert to numpy for scipy
        x_np = np.array(x) if isinstance(x, jnp.ndarray) else x
        y_integrated = np.array(dy_dx) if isinstance(dy_dx, jnp.ndarray) else dy_dx

        # One cumulative-trapezoid pass per derivative order, so that an
        # n-th derivative is integrated n times.
        for _ in range(self.deriv):
            y_integrated = scipy_cumtrapz(y_integrated, x_np, initial=0)

        logger.debug(
            "Integration completed",
            n_points=len(y_integrated),
        )

        # Create metadata
        new_metadata = (data.metadata or {}).copy()
        new_metadata.update(
            {"transform": "integral", "original_transform": "derivative"}
        )

        return RheoData(
            x=x,
            y=jnp.array(y_integrated),
            x_units=data.x_units,
            y_units="integrated",
            domain=data.domain,
            metadata=new_metadata,
            validate=False,
        )

    def estimate_noise_level(self, data: RheoData) -> float:
        """Estimate noise level in data.

        This uses the median absolute deviation (MAD) of the second
        derivative as a robust noise estimator.

        Parameters
        ----------
        data : RheoData
            Input data

        Returns
        -------
        float
            Estimated noise standard deviation; ``0.0`` when fewer than
            3 points are available
        """
        # Get data
        x = data.x
        y = data.y

        # Convert to arrays
        if not isinstance(x, jnp.ndarray):
            x = jnp.array(x)
        if not isinstance(y, jnp.ndarray):
            y = jnp.array(y)

        # R7-DERIV-001: Need at least 3 points for second derivative estimation
        if len(x) < 3:
            logger.warning(
                "estimate_noise_level: need at least 3 data points, got %d",
                len(x),
            )
            return 0.0

        # Compute second derivative (amplifies noise)
        d2y = jnp.gradient(jnp.gradient(y, x), x)

        # MAD estimator
        median = jnp.median(d2y)
        mad = jnp.median(jnp.abs(d2y - median))

        # Convert MAD to standard deviation (assumes Gaussian noise)
        sigma = 1.4826 * mad

        return float(sigma)
# Public API of this module: only the transform class is exported.
__all__ = ["SmoothDerivative"]