# Source code for rheojax.core.data

"""RheoData class - JAX-native rheological data container.

This module provides the RheoData abstraction for rheological data that supports
both NumPy and JAX arrays with additional features for rheological analysis.
"""

from __future__ import annotations

import warnings
from dataclasses import InitVar, dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:  # pragma: no cover - typing helper only
    import jax.numpy as jnp_typing

    from rheojax.core.test_modes import TestModeEnum
else:
    jnp_typing = np

from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger

# Safe JAX import (enforces float64). The returned pair is bound to ``jax`` and
# ``jnp``; ``jnp`` is used throughout as the jax.numpy namespace (jnp.ndarray,
# jnp.array, jnp.interp, ...).
jax, jnp = safe_import_jax()
# safe_import_jax() is called unconditionally above, so this flag is hard-coded
# True; it is kept so helpers such as _coerce_ndarray can guard JAX branches.
HAS_JAX = True

# Module-level logger. Call sites pass structured keyword fields
# (e.g. ``logger.debug("...", from_type=...)``).
logger = get_logger(__name__)


# Array-like inputs accepted for x/y: NumPy arrays, JAX arrays, lists, tuples.
# ``jnp_typing`` is jax.numpy only under TYPE_CHECKING; at runtime it is NumPy.
# NOTE: the ``type`` alias statement requires Python 3.12+.
type ArrayLike = np.ndarray | jnp_typing.ndarray | list | tuple


def _coerce_ndarray(data: ArrayLike | jnp_typing.ndarray | None) -> np.ndarray:
    """Convert any array-like input to a NumPy array for scalar ops."""
    if data is None:
        logger.error("Array data is None during conversion")
        raise ValueError("Array data must be initialized before conversion")
    if isinstance(data, np.ndarray):
        return data
    if HAS_JAX and isinstance(data, jnp.ndarray):
        logger.debug(
            "Converting JAX array to NumPy",
            from_type="jax.ndarray",
            to_type="np.ndarray",
        )
        return np.asarray(data)
    logger.debug(
        "Converting array-like to NumPy",
        from_type=type(data).__name__,
        to_type="np.ndarray",
    )
    return np.asarray(data)


@dataclass
class RheoData:
    """JAX-native container for rheological data with NumPy/JAX array support.

    This class provides a unified interface for rheological data that supports
    both NumPy and JAX arrays with additional features needed for rheological
    analysis including automatic test mode detection, data validation, and
    domain-specific operations.

    Attributes:
        x: Independent variable data (e.g., time, frequency). Required; stored
            as a NumPy or JAX array after construction.
        y: Dependent variable data (e.g., stress, strain, modulus). Required;
            may be complex (e.g., G* = G' + i*G''). Its first dimension must
            match that of ``x``.
        x_units: Units for x-axis data.
        y_units: Units for y-axis data.
        domain: Data domain ('time' or 'frequency').
        initial_test_mode: Init-only argument (not stored as a field). Optional
            explicit test mode such as 'relaxation', 'creep' or 'oscillation';
            when given it takes precedence over ``metadata['test_mode']`` and
            is recorded there.
        metadata: Dictionary of additional metadata.
        validate: Whether to validate data on creation.
    """

    # x and y are required at construction (__post_init__ rejects None).
    x: ArrayLike | None = None
    y: ArrayLike | None = None
    x_units: str | None = None
    y_units: str | None = None
    domain: str = "time"
    # Optional explicit test mode passed during initialization
    # (e.g., relaxation/creep/oscillation); init-only, not stored as a field.
    initial_test_mode: InitVar[str | None] = None
    metadata: dict[str, Any] = field(default_factory=dict)
    validate: bool = True
    # Explicit test mode resolved at construction or via update_metadata().
    _explicit_test_mode: str | None = field(default=None, repr=False, init=False)
    # Memoized result of to_jax(); reset when x or y is reassigned.
    _jax_cache: RheoData | None = field(default=None, repr=False, init=False)
[docs] def __post_init__(self, initial_test_mode: str | None): """Initialize and validate RheoData.""" logger.debug( "Creating RheoData", domain=self.domain, test_mode=initial_test_mode, validate=self.validate, ) # Normalize metadata container (defensive — callers may pass None explicitly) if self.metadata is None: self.metadata = {} # Persist explicitly provided test mode into metadata and internal cache if initial_test_mode is not None: self._explicit_test_mode = initial_test_mode self.metadata["test_mode"] = initial_test_mode self.metadata["detected_test_mode"] = initial_test_mode elif self.metadata and "test_mode" in self.metadata: self._explicit_test_mode = self.metadata.get("test_mode") # R12-B-012: also populate "detected_test_mode" in the elif path so # that callers which read metadata["detected_test_mode"] (e.g. test # assertions and the GUI) see the same value regardless of whether # the mode was set at construction time or pre-populated in metadata. self.metadata["detected_test_mode"] = self._explicit_test_mode if self.x is None or self.y is None: logger.error("x and y data must be provided") raise ValueError("x and y data must be provided") # Convert to arrays self.x = self._ensure_array(self.x) self.y = self._ensure_array(self.y) x_array = _coerce_ndarray(self.x) y_array = _coerce_ndarray(self.y) # Log creation details after array conversion logger.debug( "RheoData arrays created", x_shape=x_array.shape, y_shape=y_array.shape, x_dtype=str(x_array.dtype), y_dtype=str(y_array.dtype), ) # Validate shapes — allow (N,) x with (N,2) y for DMTA/GMM complex data if x_array.shape[0] != y_array.shape[0]: logger.error( "Shape mismatch between x and y data", x_shape=x_array.shape, y_shape=y_array.shape, ) raise ValueError( f"x and y must have the same first dimension. " f"Got x: {x_array.shape}, y: {y_array.shape}" ) # Validate data if requested if self.validate: self._validate_data()
[docs] def __setattr__(self, name: str, value: object) -> None: """Invalidate JAX cache when x or y data is reassigned.""" super().__setattr__(name, value) if name in ("x", "y") and hasattr(self, "_jax_cache"): super().__setattr__("_jax_cache", None)
def _ensure_array(self, data: ArrayLike) -> np.ndarray: """Ensure data is a proper array.""" if isinstance(data, (np.ndarray, jnp.ndarray)): return data elif isinstance(data, (list, tuple)): logger.debug( "Converting list/tuple to array", from_type=type(data).__name__ ) return np.array(data) else: logger.debug("Converting to array", from_type=type(data).__name__) return np.array(data) def _validate_data(self): """Validate data for common issues.""" logger.debug( "Validating data", checks=["nan", "finite", "monotonic", "negative_frequency"], ) # Check for NaN values first (NaN is also non-finite) if isinstance(self.x, np.ndarray): if np.any(np.isnan(self.x)): logger.error("x data contains NaN values") raise ValueError("x data contains NaN values") if not np.all(np.isfinite(self.x)): logger.error("x data contains non-finite values") raise ValueError("x data contains non-finite values") elif isinstance(self.x, jnp.ndarray): if bool(jnp.any(jnp.isnan(self.x))): logger.error("x data contains NaN values") raise ValueError("x data contains NaN values") if not bool(jnp.all(jnp.isfinite(self.x))): logger.error("x data contains non-finite values") raise ValueError("x data contains non-finite values") if isinstance(self.y, np.ndarray): # Note: For complex arrays, np.isnan() returns True if EITHER # real or imaginary part is NaN. This is intentional — partial # NaN in complex modulus data is not physically meaningful. 
if np.any(np.isnan(self.y)): logger.error("y data contains NaN values") raise ValueError("y data contains NaN values") if not np.all(np.isfinite(self.y)): logger.error("y data contains non-finite values") raise ValueError("y data contains non-finite values") elif isinstance(self.y, jnp.ndarray): if bool(jnp.any(jnp.isnan(self.y))): logger.error("y data contains NaN values") raise ValueError("y data contains NaN values") if not bool(jnp.all(jnp.isfinite(self.y))): logger.error("y data contains non-finite values") raise ValueError("y data contains non-finite values") # Check for monotonic x-axis if len(self.x) > 1: if isinstance(self.x, np.ndarray): diffs = np.diff(self.x) if not (np.all(diffs > 0) or np.all(diffs < 0)): logger.debug("x-axis is not monotonic") warnings.warn("x-axis is not monotonic", UserWarning, stacklevel=2) elif isinstance(self.x, jnp.ndarray): diffs = jnp.diff(self.x) is_increasing = bool(jnp.all(diffs > 0)) is_decreasing = bool(jnp.all(diffs < 0)) if not (is_increasing or is_decreasing): logger.debug("x-axis is not monotonic") warnings.warn("x-axis is not monotonic", UserWarning, stacklevel=2) # Check for negative values in frequency domain if self.domain == "frequency": if isinstance(self.y, np.ndarray): if np.any(np.real(self.y) < 0): logger.debug("y data contains negative values in frequency domain") warnings.warn( "y data contains negative values in frequency domain", UserWarning, stacklevel=2, ) elif isinstance(self.y, jnp.ndarray): if bool(jnp.any(jnp.real(self.y) < 0)): logger.debug("y data contains negative values in frequency domain") warnings.warn( "y data contains negative values in frequency domain", UserWarning, stacklevel=2, ) logger.debug("Data validation completed successfully")
[docs] def to_jax(self) -> RheoData: """Convert arrays to JAX arrays. Returns cached result on subsequent calls — invalidated if x or y are reassigned. Returns: New RheoData with JAX arrays """ if self._jax_cache is not None: return self._jax_cache logger.debug( "Converting RheoData to JAX arrays", from_type="numpy", to_type="jax" ) y_dtype = jnp.complex128 if np.iscomplexobj(self.y) else jnp.float64 result = RheoData( x=jnp.array(self.x, dtype=jnp.float64), y=jnp.array(self.y, dtype=y_dtype), x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, ) self._jax_cache = result return result
[docs] def to_numpy(self) -> RheoData: """Convert arrays to NumPy arrays. Uses np.asarray() for zero-copy conversion when possible, providing 10-30% memory savings for large arrays (>100k points). Returns: New RheoData with NumPy arrays """ logger.debug( "Converting RheoData to NumPy arrays", from_type="jax", to_type="numpy" ) # Use asarray for zero-copy when array is already NumPy-compatible # Preserve dtype (handles both float64 and complex128) x_np = np.asarray(self.x) y_np = np.asarray(self.y) return RheoData( x=x_np, y=y_np, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def copy(self) -> RheoData: """Create a copy of the RheoData. Returns: Copy of the RheoData instance """ logger.debug("Creating copy of RheoData") return RheoData( x=self.x.copy() if hasattr(self.x, "copy") else self.x, y=self.y.copy() if hasattr(self.y, "copy") else self.y, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def update_metadata(self, metadata: dict[str, Any]): """Update metadata dictionary. Args: metadata: Dictionary of metadata to add/update """ logger.debug("Updating metadata", keys=list(metadata.keys())) # R10-DATA-001: Invalidate auto-detection cache when test_mode is updated. # Otherwise get_test_mode() continues returning the stale detected value # even after the caller explicitly sets metadata["test_mode"]. if "test_mode" in metadata: self._detected_test_mode: TestModeEnum | None = None self._explicit_test_mode = metadata["test_mode"] # R11-DATA-001: sync self.metadata.pop("detected_test_mode", None) # R11-DATA-002: clear stale self.metadata.update(metadata) # Invalidate JAX cache since metadata snapshot is now stale if hasattr(self, "_jax_cache"): super().__setattr__("_jax_cache", None)
[docs] def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation. Returns: Dictionary with data and metadata """ logger.debug("Converting RheoData to dictionary") x_data = self.x.tolist() if hasattr(self.x, "tolist") else list(self.x) y_arr = np.asarray(self.y) data_dict: dict[str, Any] = { "x": x_data, "x_units": self.x_units, "y_units": self.y_units, "domain": self.domain, "metadata": self.metadata, } if np.iscomplexobj(y_arr): data_dict["y_real"] = np.real(y_arr).tolist() data_dict["y_imag"] = np.imag(y_arr).tolist() else: data_dict["y"] = y_arr.tolist() if self._explicit_test_mode is not None: data_dict["test_mode"] = self._explicit_test_mode return data_dict
[docs] @classmethod def from_dict(cls, data_dict: dict[str, Any]) -> RheoData: """Create from dictionary representation. Args: data_dict: Dictionary with data and metadata Returns: RheoData instance """ logger.debug("Creating RheoData from dictionary") metadata = data_dict.get("metadata", {}) or {} test_mode = data_dict.get("test_mode") if "y_real" in data_dict and "y_imag" in data_dict: y = np.array(data_dict["y_real"]) + 1j * np.array(data_dict["y_imag"]) else: y = np.array(data_dict["y"]) return cls( x=np.array(data_dict["x"]), y=y, x_units=data_dict.get("x_units"), y_units=data_dict.get("y_units"), domain=data_dict.get("domain", "time"), metadata=dict(metadata), initial_test_mode=test_mode, validate=False, )
# NumPy-like interface @property def shape(self) -> tuple: """Shape of the y data.""" return _coerce_ndarray(self.y).shape @property def ndim(self) -> int: """Number of dimensions of y data.""" return _coerce_ndarray(self.y).ndim @property def size(self) -> int: """Size of y data.""" return int(_coerce_ndarray(self.y).size) @property def dtype(self): """Data type of y data.""" return _coerce_ndarray(self.y).dtype @property def is_complex(self) -> bool: """Check if y data is complex.""" return np.iscomplexobj(_coerce_ndarray(self.y)) @property def modulus(self) -> np.ndarray | None: """Get modulus of complex data.""" if self.is_complex: return np.abs(self.y) return None @property def phase(self) -> np.ndarray | None: """Get phase of complex data.""" if self.is_complex: return np.angle(self.y) return None @property def y_real(self) -> np.ndarray | jnp_typing.ndarray: """Get real component of y data. For complex modulus data (G* = G' + i·G''), this returns the storage modulus (G'). For real data, returns y unchanged. Returns: Real component of y data (G' for complex modulus) Example: >>> data = read_trios('frequency_sweep.txt') # Returns complex G* >>> G_prime = data[0].y_real # Storage modulus (G') >>> plt.loglog(data[0].x, G_prime, label="G'") """ if self.is_complex: if isinstance(self.y, jnp.ndarray): return jnp.real(self.y) return np.real(self.y) return self.y @property def y_imag(self) -> np.ndarray | jnp_typing.ndarray: """Get imaginary component of y data. For complex modulus data (G* = G' + i·G''), this returns the loss modulus (G''). For real data, returns zeros. 
Returns: Imaginary component of y data (G'' for complex modulus) Example: >>> data = read_trios('frequency_sweep.txt') # Returns complex G* >>> G_double_prime = data[0].y_imag # Loss modulus (G'') >>> plt.loglog(data[0].x, G_double_prime, label='G"') """ if self.is_complex: if isinstance(self.y, jnp.ndarray): return jnp.imag(self.y) return np.imag(self.y) if isinstance(self.y, jnp.ndarray): return jnp.zeros_like(self.y) return np.zeros_like(self.y) @property def storage_modulus(self) -> np.ndarray | None: """Get storage modulus (G') from complex modulus data. Alias for y_real that makes rheological intent explicit. Returns: Storage modulus (G') if data is complex, None otherwise Example: >>> data = read_trios('frequency_sweep.txt') >>> G_prime = data[0].storage_modulus """ if self.is_complex: return self.y_real return None @property def loss_modulus(self) -> np.ndarray | None: """Get loss modulus (G'') from complex modulus data. Alias for y_imag that makes rheological intent explicit. Returns: Loss modulus (G'') if data is complex, None otherwise Example: >>> data = read_trios('frequency_sweep.txt') >>> G_double_prime = data[0].loss_modulus """ if self.is_complex: return self.y_imag return None @property def tan_delta(self) -> np.ndarray | None: """Get loss tangent (tan δ = G''/G') from complex modulus data. 
The loss tangent quantifies the ratio of viscous to elastic response: - tan δ < 1: Elastic-dominant (solid-like) - tan δ > 1: Viscous-dominant (liquid-like) - tan δ = 1: Equal elastic and viscous contributions Returns: Loss tangent (dimensionless) if data is complex, None otherwise Example: >>> data = read_trios('frequency_sweep.txt') >>> tan_d = data[0].tan_delta >>> print(f"Material type: {'solid-like' if tan_d.mean() < 1 else 'liquid-like'}") """ if self.is_complex: G_prime = self.y_real G_double_prime = self.y_imag # Avoid division by zero if isinstance(G_prime, jnp.ndarray): return jnp.where(G_prime > 0, G_double_prime / G_prime, jnp.nan) return np.where(G_prime > 0, G_double_prime / G_prime, np.nan) return None @property def test_mode(self) -> str: """Automatically detect or retrieve test mode. The test mode is detected based on data characteristics and cached in a private field. If already detected, returns the cached value. If explicitly set in metadata['test_mode'], returns that value. Returns: Test mode string (relaxation, creep, oscillation, rotation, unknown) """ # Prefer explicitly provided test mode if self._explicit_test_mode is not None: return self._explicit_test_mode # R8-DATA-001: check private cache first, avoid shared metadata dict _cached = getattr(self, "_detected_test_mode", None) if _cached is not None: return _cached # Check if already set in metadata (explicit or previously detected) if "test_mode" in self.metadata: _raw = self.metadata["test_mode"] try: from rheojax.core.test_modes import TestMode return TestMode(_raw.lower()).value if isinstance(_raw, str) else _raw except (ValueError, AttributeError): return _raw # Lazy import to avoid circular dependency from rheojax.core.test_modes import detect_test_mode # Detect test mode logger.debug("Detecting test mode from data characteristics") mode = detect_test_mode(self) # R8-DATA-001: cache in private field AND metadata for observability. 
# Metadata cache allows downstream code to check detected mode without # triggering re-detection, consistent with __post_init__ behavior. self._detected_test_mode = mode self.metadata["detected_test_mode"] = mode logger.debug("Test mode detected", test_mode=mode) return mode @property def deformation_mode(self) -> str: """Get deformation mode from metadata. Returns 'shear' if not explicitly set. Possible values: 'shear', 'tension', 'bending', 'compression'. """ return self.metadata.get("deformation_mode", "shear") @property def storage_modulus_label(self) -> str: """Get appropriate storage modulus label based on deformation mode. Returns "E'" for tensile/bending/compression, "G'" for shear. """ from rheojax.core.test_modes import DeformationMode try: dm = DeformationMode(self.deformation_mode) return "E'" if dm.is_tensile() else "G'" except ValueError: return "G'" @property def loss_modulus_label(self) -> str: """Get appropriate loss modulus label based on deformation mode. Returns 'E"' for tensile/bending/compression, 'G"' for shear. """ from rheojax.core.test_modes import DeformationMode try: dm = DeformationMode(self.deformation_mode) return 'E"' if dm.is_tensile() else 'G"' except ValueError: return 'G"'
[docs] def __getitem__(self, idx): """Support indexing and slicing.""" if isinstance(idx, (int, np.integer)): return (self.x[idx], self.y[idx]) else: logger.debug("Slicing RheoData", index_type=type(idx).__name__) return RheoData( x=self.x[idx], y=self.y[idx], x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def __add__(self, other): """Add two RheoData objects or scalar.""" if isinstance(other, RheoData): try: _axes_equal = np.array_equal(np.asarray(self.x), np.asarray(other.x)) except Exception: _axes_equal = bool(jnp.all(self.x == other.x)) if not _axes_equal: logger.error("x-axes must match for addition") raise ValueError("x-axes must match for addition") return RheoData( x=self.x, y=self.y + other.y, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, ) else: return RheoData( x=self.x, y=self.y + other, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def __sub__(self, other): """Subtract two RheoData objects or scalar.""" if isinstance(other, RheoData): try: _axes_equal = np.array_equal(np.asarray(self.x), np.asarray(other.x)) except Exception: _axes_equal = bool(jnp.all(self.x == other.x)) if not _axes_equal: logger.error("x-axes must match for subtraction") raise ValueError("x-axes must match for subtraction") return RheoData( x=self.x, y=self.y - other.y, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, ) else: return RheoData( x=self.x, y=self.y - other, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def __mul__(self, other): """Multiply by scalar or another RheoData.""" if isinstance(other, RheoData): try: _axes_equal = np.array_equal(np.asarray(self.x), np.asarray(other.x)) except Exception: _axes_equal = bool(jnp.all(self.x == other.x)) if not _axes_equal: logger.error("x-axes must match for multiplication") raise ValueError("x-axes must match for multiplication") y_result = self.y * other.y else: y_result = self.y * other return RheoData( x=self.x, y=y_result, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
# Data operations
[docs] def interpolate(self, new_x: ArrayLike) -> RheoData: """Interpolate data to new x values. Args: new_x: New x values for interpolation Returns: Interpolated RheoData """ logger.debug( "Interpolating data", n_new_points=len(new_x) if hasattr(new_x, "__len__") else 1, ) new_x = self._ensure_array(new_x) if np.iscomplexobj(self.y): # Complex data: interpolate real and imaginary parts separately. # jnp.interp and np.interp do not support complex arrays — the # imaginary part would be silently discarded. if isinstance(self.x, jnp.ndarray): new_y_real = jnp.interp(new_x, self.x, jnp.real(self.y)) new_y_imag = jnp.interp(new_x, self.x, jnp.imag(self.y)) else: new_y_real = np.interp(new_x, self.x, np.real(self.y)) new_y_imag = np.interp(new_x, self.x, np.imag(self.y)) new_y = new_y_real + 1j * new_y_imag elif isinstance(self.x, jnp.ndarray) or isinstance(self.y, jnp.ndarray): # Use JAX interpolation new_y = jnp.interp(new_x, self.x, self.y) else: # Use NumPy interpolation new_y = np.interp(new_x, self.x, self.y) return RheoData( x=new_x, y=new_y, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def resample(self, n_points: int) -> RheoData: """Resample data to specified number of points. Args: n_points: Number of points to resample to Returns: Resampled RheoData """ logger.debug("Resampling data", n_points=n_points, domain=self.domain) x_array = _coerce_ndarray(self.x) if self.domain == "frequency": if x_array.min() <= 0: raise ValueError( f"Cannot resample in log-space: x contains non-positive values " f"(min={float(x_array.min()):.3g}). Ensure all x > 0 for frequency domain." ) new_x = np.logspace( np.log10(x_array.min()), np.log10(x_array.max()), n_points ) else: # Linear-spaced for time domain new_x = np.linspace(x_array.min(), x_array.max(), n_points) return self.interpolate(new_x)
[docs] def smooth(self, window_size: int = 5) -> RheoData: """Smooth data using moving average. Args: window_size: Size of smoothing window Returns: Smoothed RheoData """ if window_size % 2 == 0: window_size += 1 # Make odd for symmetric window logger.debug("Smoothing data", window_size=window_size) # Simple moving average kernel = np.ones(window_size) / window_size if np.iscomplexobj(self.y): # Complex data: convolve real and imaginary parts separately. # jnp.convolve and np.convolve may not handle complex arrays correctly. if isinstance(self.y, jnp.ndarray): smoothed_real = jnp.convolve(jnp.real(self.y), kernel, mode="same") smoothed_imag = jnp.convolve(jnp.imag(self.y), kernel, mode="same") else: smoothed_real = np.convolve(np.real(self.y), kernel, mode="same") smoothed_imag = np.convolve(np.imag(self.y), kernel, mode="same") smoothed_y = smoothed_real + 1j * smoothed_imag elif isinstance(self.y, jnp.ndarray): # Use JAX convolution smoothed_y = jnp.convolve(self.y, kernel, mode="same") else: # Use NumPy convolution smoothed_y = np.convolve(self.y, kernel, mode="same") return RheoData( x=self.x, y=smoothed_y, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def derivative(self) -> RheoData: """Compute numerical derivative. Returns: RheoData with derivative values """ logger.debug("Computing numerical derivative") if np.iscomplexobj(self.y): if isinstance(self.y, jnp.ndarray): dy_dx = jnp.gradient(jnp.real(self.y), self.x) + 1j * jnp.gradient( jnp.imag(self.y), self.x ) else: dy_dx = np.gradient(np.real(self.y), self.x) + 1j * np.gradient( np.imag(self.y), self.x ) elif isinstance(self.x, jnp.ndarray) or isinstance(self.y, jnp.ndarray): dy_dx = jnp.gradient(self.y, self.x) else: dy_dx = np.gradient(self.y, self.x) return RheoData( x=self.x, y=dy_dx, x_units=self.x_units, y_units=( f"d({self.y_units})/d({self.x_units})" if self.y_units and self.x_units else None ), domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
[docs] def integral(self) -> RheoData: """Compute numerical integral. Returns: RheoData with integrated values """ logger.debug("Computing numerical integral") if isinstance(self.x, jnp.ndarray) or isinstance(self.y, jnp.ndarray): # JAX has no cumulative_trapezoid; compute manually via # trapezoidal rule: I[0]=0, I[k] = I[k-1] + (y[k-1]+y[k])/2 * dx[k] dx = jnp.diff(self.x) avg_y = (self.y[:-1] + self.y[1:]) / 2.0 # type: ignore[operator] integrated = jnp.concatenate( [jnp.zeros(1, dtype=self.y.dtype), jnp.cumsum(avg_y * dx)] # type: ignore[union-attr] ) else: # Use NumPy/SciPy cumulative trapezoid from scipy.integrate import cumulative_trapezoid integrated = cumulative_trapezoid(self.y, self.x, initial=0) return RheoData( x=self.x, y=integrated, x_units=self.x_units, y_units=( f"∫{self.y_units}·d{self.x_units}" if self.y_units and self.x_units else None ), domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, )
# Domain conversion placeholders
[docs] def to_frequency_domain(self) -> RheoData: """Convert time domain data to frequency domain. Returns: Frequency domain RheoData """ if self.domain != "time": logger.debug("Data is already in frequency domain") warnings.warn( "Data is already in frequency domain", UserWarning, stacklevel=2 ) return self.copy() logger.error("Frequency domain conversion not yet implemented") # This would use FFT transform when implemented raise NotImplementedError( "Frequency domain conversion will be implemented with transforms" )
[docs] def to_time_domain(self) -> RheoData: """Convert frequency domain data to time domain. Returns: Time domain RheoData """ if self.domain != "frequency": logger.debug("Data is already in time domain") warnings.warn("Data is already in time domain", UserWarning, stacklevel=2) return self.copy() logger.error("Time domain conversion not yet implemented") # This would use inverse FFT transform when implemented raise NotImplementedError( "Time domain conversion will be implemented with transforms" )
# Data slicing methods
[docs] def slice(self, start: float | None = None, end: float | None = None) -> RheoData: """Slice data between x values. Args: start: Start x value end: End x value Returns: Sliced RheoData """ logger.debug("Slicing data by x range", start=start, end=end) # Use np.asarray only for mask computation — preserves JAX or NumPy array type x_np = np.asarray(self.x) mask = np.ones_like(x_np, dtype=bool) if start is not None: mask &= x_np >= start if end is not None: mask &= x_np <= end sliced_x = self.x[mask] sliced_y = self.y[mask] logger.debug("Sliced data", original_size=len(x_np), new_size=len(sliced_x)) result = RheoData( x=sliced_x, y=sliced_y, x_units=self.x_units, y_units=self.y_units, domain=self.domain, initial_test_mode=self._explicit_test_mode, metadata=self.metadata.copy(), validate=False, ) if len(result.x) == 0: import warnings as _warnings _warnings.warn( f"RheoData.slice(start={start}, end={end}) produced an empty result.", stacklevel=2, ) return result