"""RheoData class - JAX-native rheological data container.
This module provides the RheoData abstraction for rheological data that supports
both NumPy and JAX arrays with additional features for rheological analysis.
"""
from __future__ import annotations
import warnings
from dataclasses import InitVar, dataclass, field
from typing import TYPE_CHECKING, Any
import numpy as np
if TYPE_CHECKING: # pragma: no cover - typing helper only
import jax.numpy as jnp_typing
from rheojax.core.test_modes import TestModeEnum
else:
jnp_typing = np
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger
# Safe JAX import (enforces float64)
jax, jnp = safe_import_jax()
# Set unconditionally: this module has no fallback path when the JAX import
# above fails. NOTE(review): assumes safe_import_jax raises (rather than
# returning placeholders) when JAX is unavailable — confirm.
HAS_JAX = True
# Module-level logger
logger = get_logger(__name__)
# Inputs accepted for the x/y fields before __post_init__ converts them to
# arrays. At runtime ``jnp_typing`` is plain NumPy (TYPE_CHECKING fallback
# above), so the JAX member of this alias only matters to static checkers.
type ArrayLike = np.ndarray | jnp_typing.ndarray | list | tuple
def _coerce_ndarray(data: ArrayLike | jnp_typing.ndarray | None) -> np.ndarray:
    """Return ``data`` as a NumPy array so callers can use plain NumPy ops.

    NumPy arrays are returned as-is (no copy). JAX arrays and any other
    array-like input go through ``np.asarray``; each such conversion is
    logged at debug level with its source type.

    Raises:
        ValueError: If ``data`` is None.
    """
    if data is None:
        logger.error("Array data is None during conversion")
        raise ValueError("Array data must be initialized before conversion")
    if isinstance(data, np.ndarray):
        return data

    from_jax = HAS_JAX and isinstance(data, jnp.ndarray)
    logger.debug(
        "Converting JAX array to NumPy" if from_jax else "Converting array-like to NumPy",
        from_type="jax.ndarray" if from_jax else type(data).__name__,
        to_type="np.ndarray",
    )
    return np.asarray(data)
[docs]
@dataclass
class RheoData:
    """JAX-native container for rheological data with NumPy/JAX array support.

    This class provides a unified interface for rheological data that supports
    both NumPy and JAX arrays with additional features needed for rheological
    analysis including automatic test mode detection, data validation, and
    domain-specific operations.

    Attributes:
        x: Independent variable data (e.g., time, frequency)
        y: Dependent variable data (e.g., stress, strain, modulus). May be
            complex (e.g. G* = G' + iG''), and may be 2-D as long as its
            first dimension matches x (e.g. (N,) x with (N, 2) y).
        x_units: Units for x-axis data
        y_units: Units for y-axis data
        domain: Data domain ('time' or 'frequency')
        metadata: Dictionary of additional metadata
        validate: Whether to validate data on creation

    Init-only arguments:
        initial_test_mode: Optional explicit test mode; recorded on the
            instance and mirrored into ``metadata`` by ``__post_init__``.
    """

    # Raw inputs; __post_init__ rejects None and converts them to arrays.
    x: ArrayLike | None = None
    y: ArrayLike | None = None
    x_units: str | None = None
    y_units: str | None = None
    domain: str = "time"
    # Optional explicit test mode passed during initialization (e.g., relaxation/creep/oscillation)
    initial_test_mode: InitVar[str | None] = None
    metadata: dict[str, Any] = field(default_factory=dict)
    validate: bool = True
    # Explicit test mode resolved in __post_init__ (from initial_test_mode or
    # metadata["test_mode"]); consulted first by the test_mode property.
    _explicit_test_mode: str | None = field(default=None, repr=False, init=False)
    # Memoized to_jax() result; cleared by __setattr__ when x or y is reassigned.
    _jax_cache: RheoData | None = field(default=None, repr=False, init=False)
[docs]
def __post_init__(self, initial_test_mode: str | None):
    """Normalize inputs, record the test mode, and validate shapes.

    Args:
        initial_test_mode: Init-only explicit test mode. When given it is
            stored as the explicit mode and mirrored into
            ``metadata["test_mode"]`` and ``metadata["detected_test_mode"]``.
            Otherwise a pre-existing ``metadata["test_mode"]`` is adopted.

    Raises:
        ValueError: If ``x`` or ``y`` is missing, if either converts to a
            0-d (scalar) array, or if their first dimensions differ.
    """
    logger.debug(
        "Creating RheoData",
        domain=self.domain,
        test_mode=initial_test_mode,
        validate=self.validate,
    )
    # Work on a private shallow copy. The writes below (and the later
    # "detected_test_mode" write in ``test_mode``) must not leak into a dict
    # the caller may reuse for other instances. None is tolerated because
    # callers may pass it explicitly.
    self.metadata = {} if self.metadata is None else dict(self.metadata)
    # Persist explicitly provided test mode into metadata and internal cache
    if initial_test_mode is not None:
        self._explicit_test_mode = initial_test_mode
        self.metadata["test_mode"] = initial_test_mode
        self.metadata["detected_test_mode"] = initial_test_mode
    elif "test_mode" in self.metadata:
        self._explicit_test_mode = self.metadata.get("test_mode")
        # R12-B-012: also populate "detected_test_mode" in the elif path so
        # that callers which read metadata["detected_test_mode"] (e.g. test
        # assertions and the GUI) see the same value regardless of whether
        # the mode was set at construction time or pre-populated in metadata.
        self.metadata["detected_test_mode"] = self._explicit_test_mode
    if self.x is None or self.y is None:
        logger.error("x and y data must be provided")
        raise ValueError("x and y data must be provided")
    # Convert to arrays
    self.x = self._ensure_array(self.x)
    self.y = self._ensure_array(self.y)
    x_array = _coerce_ndarray(self.x)
    y_array = _coerce_ndarray(self.y)
    # Log creation details after array conversion
    logger.debug(
        "RheoData arrays created",
        x_shape=x_array.shape,
        y_shape=y_array.shape,
        x_dtype=str(x_array.dtype),
        y_dtype=str(y_array.dtype),
    )
    # A 0-d array has no first dimension. Without this guard the comparison
    # below would fail with an opaque IndexError instead of a clear message.
    if x_array.ndim == 0 or y_array.ndim == 0:
        logger.error(
            "x and y must be at least one-dimensional",
            x_shape=x_array.shape,
            y_shape=y_array.shape,
        )
        raise ValueError(
            "x and y must be at least one-dimensional arrays. "
            f"Got x: {x_array.shape}, y: {y_array.shape}"
        )
    # Validate shapes — allow (N,) x with (N,2) y for DMTA/GMM complex data
    if x_array.shape[0] != y_array.shape[0]:
        logger.error(
            "Shape mismatch between x and y data",
            x_shape=x_array.shape,
            y_shape=y_array.shape,
        )
        raise ValueError(
            f"x and y must have the same first dimension. "
            f"Got x: {x_array.shape}, y: {y_array.shape}"
        )
    # Validate data if requested
    if self.validate:
        self._validate_data()
[docs]
def __setattr__(self, name: str, value: object) -> None:
    """Set an attribute and drop caches derived from the reassigned field.

    ``to_jax()`` memoizes a converted copy carrying the following:

    * x and y
    * both unit strings
    * the domain
    * a copy of the metadata
    * the explicit test mode

    Reassigning any of these therefore clears ``_jax_cache`` so the next
    ``to_jax()`` rebuilds it from current values.

    Reassigning x or y also clears the cached auto-detected test mode
    (``_detected_test_mode``), because it was derived from the old data.

    In-place mutation (``data.y[0] = ...``, ``data.metadata[k] = ...``)
    does not pass through here and invalidates nothing.
    """
    super().__setattr__(name, value)
    if name in {
        "x",
        "y",
        "x_units",
        "y_units",
        "domain",
        "metadata",
        "_explicit_test_mode",
    } and hasattr(self, "_jax_cache"):
        super().__setattr__("_jax_cache", None)
    if name in {"x", "y"} and getattr(self, "_detected_test_mode", None) is not None:
        super().__setattr__("_detected_test_mode", None)
def _ensure_array(self, data: ArrayLike) -> np.ndarray:
    """Return ``data`` as an array, passing NumPy and JAX arrays through.

    Any other input (lists, tuples, scalars, other array-likes) is logged
    at debug level and converted with ``np.array``, which copies.
    """
    if isinstance(data, (np.ndarray, jnp.ndarray)):
        return data
    source_name = type(data).__name__
    if isinstance(data, (list, tuple)):
        logger.debug("Converting list/tuple to array", from_type=source_name)
    else:
        logger.debug("Converting to array", from_type=source_name)
    return np.array(data)
def _validate_data(self):
    """Reject NaN/non-finite values and warn about suspicious data.

    Checks run in this order:

    1. x, then y: ``ValueError`` on any NaN, then on any other non-finite
       value. NaN is tested first so it is reported as NaN.
    2. ``UserWarning`` if x (length > 1) is neither strictly increasing nor
       strictly decreasing.
    3. ``UserWarning`` if the domain is "frequency" and any real part of y
       is negative.

    Arrays that are neither NumPy nor JAX are skipped by every check.
    """
    logger.debug(
        "Validating data",
        checks=["nan", "finite", "monotonic", "negative_frequency"],
    )

    def _array_module(arr):
        # Pick the namespace matching the array type; None means "skip".
        if isinstance(arr, np.ndarray):
            return np
        if isinstance(arr, jnp.ndarray):
            return jnp
        return None

    # For complex y, isnan() is True when EITHER component is NaN. That is
    # intentional: partial NaN in complex modulus data is not meaningful.
    for label, values in (("x", self.x), ("y", self.y)):
        xp = _array_module(values)
        if xp is None:
            continue
        if bool(xp.any(xp.isnan(values))):
            logger.error(f"{label} data contains NaN values")
            raise ValueError(f"{label} data contains NaN values")
        if not bool(xp.all(xp.isfinite(values))):
            logger.error(f"{label} data contains non-finite values")
            raise ValueError(f"{label} data contains non-finite values")

    # Monotonic x-axis check (warning only)
    if len(self.x) > 1:
        xp = _array_module(self.x)
        if xp is not None:
            steps = xp.diff(self.x)
            rising = bool(xp.all(steps > 0))
            falling = bool(xp.all(steps < 0))
            if not rising and not falling:
                logger.debug("x-axis is not monotonic")
                warnings.warn("x-axis is not monotonic", UserWarning, stacklevel=2)

    # Negative real parts of y in the frequency domain (warning only)
    if self.domain == "frequency":
        xp = _array_module(self.y)
        if xp is not None and bool(xp.any(xp.real(self.y) < 0)):
            logger.debug("y data contains negative values in frequency domain")
            warnings.warn(
                "y data contains negative values in frequency domain",
                UserWarning,
                stacklevel=2,
            )
    logger.debug("Data validation completed successfully")
[docs]
def to_jax(self) -> RheoData:
    """Return a JAX-backed counterpart of this data.

    x is converted to float64. y is converted to complex128 when complex,
    otherwise to float64.

    The first call builds a new RheoData (validation disabled, metadata
    shallow-copied, explicit test mode forwarded) and stores it in
    ``_jax_cache``. Later calls return that same object until a field
    reassignment clears the cache (reassigning x or y always does).
    Because the object is shared, mutating it affects later calls.

    Returns:
        RheoData holding JAX arrays.
    """
    cached = self._jax_cache
    if cached is not None:
        return cached
    logger.debug(
        "Converting RheoData to JAX arrays", from_type="numpy", to_type="jax"
    )
    target_y_dtype = jnp.float64
    if np.iscomplexobj(self.y):
        target_y_dtype = jnp.complex128
    converted = RheoData(
        x=jnp.array(self.x, dtype=jnp.float64),
        y=jnp.array(self.y, dtype=target_y_dtype),
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
    self._jax_cache = converted
    return converted
[docs]
def to_numpy(self) -> RheoData:
    """Return a new RheoData whose x and y are NumPy arrays.

    Conversion uses ``np.asarray``, so dtypes (e.g. float64, complex128)
    are preserved. Arrays that are already NumPy are shared with the result
    rather than copied.

    Returns:
        RheoData with NumPy arrays, validation disabled, shallow-copied
        metadata and the same explicit test mode.
    """
    logger.debug(
        "Converting RheoData to NumPy arrays", from_type="jax", to_type="numpy"
    )
    return RheoData(
        x=np.asarray(self.x),
        y=np.asarray(self.y),
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def copy(self) -> RheoData:
    """Return a copy of this RheoData.

    x and y are copied via their ``.copy()`` method when they have one and
    reused otherwise. Metadata is copied shallowly (nested values are
    shared). The copy is built with validation disabled.
    """
    logger.debug("Creating copy of RheoData")

    def _duplicate(arr):
        return arr.copy() if hasattr(arr, "copy") else arr

    return RheoData(
        x=_duplicate(self.x),
        y=_duplicate(self.y),
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def to_dict(self) -> dict[str, Any]:
    """Convert to a plain-Python dictionary representation.

    Output layout:

    * x is exported as a list.
    * Complex y is split into ``"y_real"`` and ``"y_imag"`` lists; real y
      is exported under ``"y"``.
    * ``"test_mode"`` is included only when an explicit test mode was set.
    * ``"metadata"`` is a shallow copy, so editing the returned dict cannot
      mutate this instance (``from_dict`` copies on the way back in).

    Returns:
        Dictionary with data and metadata.
    """
    logger.debug("Converting RheoData to dictionary")
    x_data = self.x.tolist() if hasattr(self.x, "tolist") else list(self.x)
    y_arr = np.asarray(self.y)
    data_dict: dict[str, Any] = {
        "x": x_data,
        "x_units": self.x_units,
        "y_units": self.y_units,
        "domain": self.domain,
        # Copy: returning self.metadata directly would alias it.
        "metadata": dict(self.metadata),
    }
    if np.iscomplexobj(y_arr):
        data_dict["y_real"] = np.real(y_arr).tolist()
        data_dict["y_imag"] = np.imag(y_arr).tolist()
    else:
        data_dict["y"] = y_arr.tolist()
    if self._explicit_test_mode is not None:
        data_dict["test_mode"] = self._explicit_test_mode
    return data_dict
[docs]
@classmethod
def from_dict(cls, data_dict: dict[str, Any]) -> RheoData:
    """Rebuild an instance from the layout produced by ``to_dict``.

    y is reassembled as complex when both ``"y_real"`` and ``"y_imag"`` are
    present; otherwise ``"y"`` is required. A missing or None
    ``"metadata"`` becomes an empty dict (always a fresh copy). The result
    is built with validation disabled.

    Args:
        data_dict: Dictionary with data and metadata.

    Returns:
        RheoData instance.
    """
    logger.debug("Creating RheoData from dictionary")
    if "y_real" in data_dict and "y_imag" in data_dict:
        real_part = np.array(data_dict["y_real"])
        imag_part = np.array(data_dict["y_imag"])
        y_values = real_part + 1j * imag_part
    else:
        y_values = np.array(data_dict["y"])
    return cls(
        x=np.array(data_dict["x"]),
        y=y_values,
        x_units=data_dict.get("x_units"),
        y_units=data_dict.get("y_units"),
        domain=data_dict.get("domain", "time"),
        metadata=dict(data_dict.get("metadata", {}) or {}),
        initial_test_mode=data_dict.get("test_mode"),
        validate=False,
    )
# NumPy-like interface
@property
def shape(self) -> tuple:
    """Shape of ``y`` (read from a NumPy view of it)."""
    y_host = _coerce_ndarray(self.y)
    return y_host.shape
@property
def ndim(self) -> int:
    """Number of dimensions of ``y``."""
    y_host = _coerce_ndarray(self.y)
    return y_host.ndim
@property
def size(self) -> int:
    """Total number of elements in ``y``, as a Python int."""
    element_count = _coerce_ndarray(self.y).size
    return int(element_count)
@property
def dtype(self):
    """NumPy dtype of ``y``."""
    y_host = _coerce_ndarray(self.y)
    return y_host.dtype
@property
def is_complex(self) -> bool:
    """True when ``y`` has a complex dtype (e.g. complex modulus G*)."""
    y_host = _coerce_ndarray(self.y)
    return np.iscomplexobj(y_host)
@property
def modulus(self) -> np.ndarray | None:
    """Magnitude ``|y|`` for complex data; None when y is real."""
    if not self.is_complex:
        return None
    return np.abs(self.y)
@property
def phase(self) -> np.ndarray | None:
    """Phase angle of complex ``y`` in radians (``np.angle``); None when y is real."""
    if not self.is_complex:
        return None
    return np.angle(self.y)
@property
def y_real(self) -> np.ndarray | jnp_typing.ndarray:
    """Real component of y.

    For complex modulus data (G* = G' + i·G'') this is the storage modulus
    G'. Real-valued y is returned unchanged. The array backend (JAX or
    NumPy) follows y.

    Example:
        >>> data = read_trios('frequency_sweep.txt')  # Returns complex G*
        >>> G_prime = data[0].y_real  # Storage modulus (G')
    """
    if not self.is_complex:
        return self.y
    backend = jnp if isinstance(self.y, jnp.ndarray) else np
    return backend.real(self.y)
@property
def y_imag(self) -> np.ndarray | jnp_typing.ndarray:
    """Imaginary component of y.

    For complex modulus data (G* = G' + i·G'') this is the loss modulus
    G''. For real-valued y an all-zeros array shaped like y is returned.
    The array backend (JAX or NumPy) follows y.

    Example:
        >>> data = read_trios('frequency_sweep.txt')  # Returns complex G*
        >>> G_double_prime = data[0].y_imag  # Loss modulus (G'')
    """
    backend = jnp if isinstance(self.y, jnp.ndarray) else np
    if self.is_complex:
        return backend.imag(self.y)
    return backend.zeros_like(self.y)
@property
def storage_modulus(self) -> np.ndarray | None:
    """Storage modulus G' (``y_real``) for complex data, else None.

    Same values as ``y_real``; the name makes the rheological intent explicit.

    Example:
        >>> G_prime = read_trios('frequency_sweep.txt')[0].storage_modulus
    """
    return self.y_real if self.is_complex else None
@property
def loss_modulus(self) -> np.ndarray | None:
    """Loss modulus G'' (``y_imag``) for complex data, else None.

    Same values as ``y_imag``; the name makes the rheological intent explicit.

    Example:
        >>> G_double_prime = read_trios('frequency_sweep.txt')[0].loss_modulus
    """
    return self.y_imag if self.is_complex else None
@property
def tan_delta(self) -> np.ndarray | None:
    """Loss tangent tan δ = G''/G' from complex modulus data.

    The loss tangent quantifies the ratio of viscous to elastic response:

    - tan δ < 1: Elastic-dominant (solid-like)
    - tan δ > 1: Viscous-dominant (liquid-like)
    - tan δ = 1: Equal elastic and viscous contributions

    Entries where G' <= 0 are NaN. The array backend follows y.

    Returns:
        Loss tangent (dimensionless) if data is complex, None otherwise.

    Example:
        >>> tan_d = read_trios('frequency_sweep.txt')[0].tan_delta
    """
    if not self.is_complex:
        return None
    G_prime = self.y_real
    G_double_prime = self.y_imag
    if isinstance(G_prime, jnp.ndarray):
        return jnp.where(G_prime > 0, G_double_prime / G_prime, jnp.nan)
    # np.where evaluates both branches, so the division still runs where
    # G' == 0. Silence the divide/invalid RuntimeWarnings that would emit;
    # those entries are replaced by NaN below anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        ratio = G_double_prime / G_prime
    return np.where(G_prime > 0, ratio, np.nan)
@property
def test_mode(self) -> str:
    """Return the test mode, detecting it from the data if necessary.

    Resolution order:

    1. The explicit mode captured at construction (``_explicit_test_mode``).
    2. ``metadata["test_mode"]``. A string matching a ``TestMode`` member
       (case-insensitive) is normalized to that member's value; anything
       else is returned as-is. This is checked before the detection cache,
       so an explicit value set after auto-detection is honoured.
    3. A previously auto-detected mode cached on the instance.
    4. Fresh detection via ``detect_test_mode``. The result is cached
       privately and mirrored into ``metadata["detected_test_mode"]``.

    Returns:
        Test mode string (relaxation, creep, oscillation, rotation, unknown).
    """
    # Prefer explicitly provided test mode
    if self._explicit_test_mode is not None:
        return self._explicit_test_mode
    # An explicit metadata entry outranks heuristic detection.
    if "test_mode" in self.metadata:
        raw_mode = self.metadata["test_mode"]
        try:
            from rheojax.core.test_modes import TestMode

            return TestMode(raw_mode.lower()).value if isinstance(raw_mode, str) else raw_mode
        except (ValueError, AttributeError):
            return raw_mode
    # R8-DATA-001: private cache of a previous detection
    cached_mode = getattr(self, "_detected_test_mode", None)
    if cached_mode is not None:
        return cached_mode
    # Lazy import to avoid circular dependency
    from rheojax.core.test_modes import detect_test_mode

    logger.debug("Detecting test mode from data characteristics")
    mode = detect_test_mode(self)
    # R8-DATA-001: cache in private field AND metadata for observability,
    # so downstream code can read the detected mode without re-detecting.
    self._detected_test_mode = mode
    self.metadata["detected_test_mode"] = mode
    logger.debug("Test mode detected", test_mode=mode)
    return mode
@property
def deformation_mode(self) -> str:
    """Deformation mode stored in ``metadata["deformation_mode"]``.

    Defaults to ``"shear"`` when the key is absent. Other values used in
    this class: 'tension', 'bending', 'compression'.
    """
    mode = self.metadata.get("deformation_mode", "shear")
    return mode
@property
def storage_modulus_label(self) -> str:
    """Axis label for the storage modulus.

    ``"E'"`` when ``DeformationMode(...).is_tensile()`` is true, otherwise
    ``"G'"``. An unrecognised deformation mode also falls back to ``"G'"``.
    """
    from rheojax.core.test_modes import DeformationMode

    try:
        tensile = DeformationMode(self.deformation_mode).is_tensile()
    except ValueError:
        return "G'"
    return "E'" if tensile else "G'"
@property
def loss_modulus_label(self) -> str:
    """Axis label for the loss modulus.

    ``'E"'`` when ``DeformationMode(...).is_tensile()`` is true, otherwise
    ``'G"'``. An unrecognised deformation mode also falls back to ``'G"'``.
    """
    from rheojax.core.test_modes import DeformationMode

    try:
        tensile = DeformationMode(self.deformation_mode).is_tensile()
    except ValueError:
        return 'G"'
    return 'E"' if tensile else 'G"'
[docs]
def __getitem__(self, idx):
    """Index or slice the data.

    A Python or NumPy integer returns the ``(x[idx], y[idx])`` pair. Any
    other index (slice, mask, index array, ...) returns a new RheoData built
    with validation disabled and a shallow copy of the metadata.
    """
    if not isinstance(idx, (int, np.integer)):
        logger.debug("Slicing RheoData", index_type=type(idx).__name__)
        return RheoData(
            x=self.x[idx],
            y=self.y[idx],
            x_units=self.x_units,
            y_units=self.y_units,
            domain=self.domain,
            initial_test_mode=self._explicit_test_mode,
            metadata=self.metadata.copy(),
            validate=False,
        )
    return (self.x[idx], self.y[idx])
[docs]
def __add__(self, other):
    """Add another RheoData (same x-axis) or a scalar/array to y.

    x, units, domain and explicit test mode come from ``self``. Metadata is
    shallow-copied and the result is built with validation disabled.

    Raises:
        ValueError: If ``other`` is a RheoData with a different x-axis.
    """
    if isinstance(other, RheoData):
        try:
            same_axis = np.array_equal(np.asarray(self.x), np.asarray(other.x))
        except Exception:
            # Fallback when the NumPy comparison itself fails.
            same_axis = bool(jnp.all(self.x == other.x))
        if not same_axis:
            logger.error("x-axes must match for addition")
            raise ValueError("x-axes must match for addition")
        summed = self.y + other.y
    else:
        summed = self.y + other
    return RheoData(
        x=self.x,
        y=summed,
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def __sub__(self, other):
    """Subtract another RheoData (same x-axis) or a scalar/array from y.

    x, units, domain and explicit test mode come from ``self``. Metadata is
    shallow-copied and the result is built with validation disabled.

    Raises:
        ValueError: If ``other`` is a RheoData with a different x-axis.
    """
    subtrahend = other
    if isinstance(other, RheoData):
        try:
            axes_agree = np.array_equal(np.asarray(self.x), np.asarray(other.x))
        except Exception:
            # Fallback when the NumPy comparison itself fails.
            axes_agree = bool(jnp.all(self.x == other.x))
        if not axes_agree:
            logger.error("x-axes must match for subtraction")
            raise ValueError("x-axes must match for subtraction")
        subtrahend = other.y
    return RheoData(
        x=self.x,
        y=self.y - subtrahend,
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def __mul__(self, other):
    """Multiply y element-wise by a scalar/array or another RheoData's y.

    A RheoData operand must share the same x-axis. x, units, domain and
    explicit test mode come from ``self``, metadata is shallow-copied, and
    validation is disabled.

    Raises:
        ValueError: If ``other`` is a RheoData with a different x-axis.
    """
    if not isinstance(other, RheoData):
        factor = other
    else:
        try:
            axes_match = np.array_equal(np.asarray(self.x), np.asarray(other.x))
        except Exception:
            # Fallback when the NumPy comparison itself fails.
            axes_match = bool(jnp.all(self.x == other.x))
        if not axes_match:
            logger.error("x-axes must match for multiplication")
            raise ValueError("x-axes must match for multiplication")
        factor = other.y
    return RheoData(
        x=self.x,
        y=self.y * factor,
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
# Data operations
[docs]
def interpolate(self, new_x: ArrayLike) -> RheoData:
    """Linearly interpolate y onto new x values.

    ``np.interp`` / ``jnp.interp`` assume increasing sample points and
    silently return wrong values otherwise. Validation accepts strictly
    decreasing x, so whenever x is not non-decreasing the samples are
    stably sorted by x (y reordered with them) first.

    Complex y is interpolated as separate real and imaginary parts, because
    ``interp`` does not handle complex arrays. JAX is used when x or y is a
    JAX array, NumPy otherwise.

    Args:
        new_x: New x values for interpolation. A scalar is treated as a
            single point.

    Returns:
        RheoData on ``new_x`` with the same units, domain and explicit test
        mode, shallow-copied metadata, and validation disabled.
    """
    new_x = self._ensure_array(new_x)
    if new_x.ndim == 0:
        # A 0-d query would give 0-d x/y, which the constructor rejects.
        new_x = new_x.reshape(1)
    logger.debug("Interpolating data", n_new_points=int(new_x.shape[0]))

    x_src, y_src = self.x, self.y
    x_host = np.asarray(x_src)
    if x_host.ndim == 1 and x_host.size > 1 and np.any(np.diff(x_host) < 0):
        # interp requires increasing sample points; a stable sort keeps the
        # original order of tied x values.
        order = np.argsort(x_host, kind="stable")
        x_src = x_src[order]
        y_src = y_src[order]

    use_jax = isinstance(self.x, jnp.ndarray) or isinstance(self.y, jnp.ndarray)
    xp = jnp if use_jax else np
    if np.iscomplexobj(y_src):
        # interp discards the imaginary part of complex input, so handle the
        # two components separately.
        new_y = xp.interp(new_x, x_src, xp.real(y_src)) + 1j * xp.interp(
            new_x, x_src, xp.imag(y_src)
        )
    else:
        new_y = xp.interp(new_x, x_src, y_src)
    return RheoData(
        x=new_x,
        y=new_y,
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def resample(self, n_points: int) -> RheoData:
    """Resample onto ``n_points`` evenly spaced x values spanning min..max of x.

    Frequency-domain data uses log spacing and therefore needs x > 0. Any
    other domain uses linear spacing. Values come from ``interpolate``.

    Args:
        n_points: Number of points to resample to.

    Returns:
        Resampled RheoData.

    Raises:
        ValueError: If the domain is "frequency" and min(x) <= 0.
    """
    logger.debug("Resampling data", n_points=n_points, domain=self.domain)
    x_host = _coerce_ndarray(self.x)
    lo, hi = x_host.min(), x_host.max()
    if self.domain != "frequency":
        # Linear-spaced for time domain
        return self.interpolate(np.linspace(lo, hi, n_points))
    if lo <= 0:
        raise ValueError(
            f"Cannot resample in log-space: x contains non-positive values "
            f"(min={float(lo):.3g}). Ensure all x > 0 for frequency domain."
        )
    return self.interpolate(np.logspace(np.log10(lo), np.log10(hi), n_points))
[docs]
def smooth(self, window_size: int = 5) -> RheoData:
    """Smooth y with a centered moving average.

    Window handling:

    * The window length is forced odd (an even value is bumped up by one),
      so the average is symmetric.
    * It is capped at the largest odd value not exceeding the number of
      samples. ``convolve(..., mode="same")`` returns
      ``max(len(y), window)`` points, so an oversized window would make y
      longer than x.
    * Near the ends, each output is the mean of the samples actually inside
      the window. Implicit zero padding would drag edge values toward zero.
    * The window counts samples; x spacing is ignored.

    Complex y is smoothed component-wise. Empty data is returned unchanged.

    Args:
        window_size: Requested window length in samples.

    Returns:
        Smoothed RheoData with the same x, units, domain and explicit test
        mode, shallow-copied metadata, and validation disabled.

    Raises:
        ValueError: If the window is smaller than 1 after the odd adjustment.
    """
    if window_size % 2 == 0:
        window_size += 1  # Make odd for symmetric window
    if window_size < 1:
        logger.error("Invalid smoothing window", window_size=window_size)
        raise ValueError(f"window_size must be a positive integer, got {window_size}")
    n_points = len(self.y)
    if window_size > n_points > 0:
        window_size = n_points if n_points % 2 else n_points - 1
    logger.debug("Smoothing data", window_size=window_size)

    if n_points == 0:
        # Nothing to smooth; convolve rejects empty input.
        smoothed_y = self.y
    else:
        xp = jnp if isinstance(self.y, jnp.ndarray) else np
        kernel = xp.ones(window_size)
        # Number of real samples under the window at each output position;
        # dividing by it turns the windowed sum into a true mean at the edges.
        counts = xp.convolve(xp.ones(n_points), kernel, mode="same")
        if np.iscomplexobj(self.y):
            # Smooth real and imaginary parts separately.
            smoothed_y = (
                xp.convolve(xp.real(self.y), kernel, mode="same") / counts
                + 1j * (xp.convolve(xp.imag(self.y), kernel, mode="same") / counts)
            )
        else:
            smoothed_y = xp.convolve(self.y, kernel, mode="same") / counts
    return RheoData(
        x=self.x,
        y=smoothed_y,
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def derivative(self) -> RheoData:
    """Numerical derivative dy/dx.

    Uses ``np.gradient`` / ``jnp.gradient`` with x as the sample
    coordinates. Complex y is differentiated component-wise. y_units
    becomes ``"d(<y>)/d(<x>)"`` when both units are set, else None.

    Returns:
        RheoData with derivative values.
    """
    logger.debug("Computing numerical derivative")
    if np.iscomplexobj(self.y):
        xp = jnp if isinstance(self.y, jnp.ndarray) else np
        slope = xp.gradient(xp.real(self.y), self.x) + 1j * xp.gradient(
            xp.imag(self.y), self.x
        )
    elif isinstance(self.x, jnp.ndarray) or isinstance(self.y, jnp.ndarray):
        slope = jnp.gradient(self.y, self.x)
    else:
        slope = np.gradient(self.y, self.x)

    derived_units = None
    if self.y_units and self.x_units:
        derived_units = f"d({self.y_units})/d({self.x_units})"
    return RheoData(
        x=self.x,
        y=slope,
        x_units=self.x_units,
        y_units=derived_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
[docs]
def integral(self) -> RheoData:
    """Cumulative trapezoidal integral of y over x, starting at 0.

    NumPy data uses ``scipy.integrate.cumulative_trapezoid`` with
    ``initial=0``. When x or y is a JAX array, the same rule is computed by
    hand: I[0] = 0, I[k] = I[k-1] + (y[k-1] + y[k]) / 2 * (x[k] - x[k-1]).
    y_units becomes ``"∫<y>·d<x>"`` when both units are set, else None.

    Returns:
        RheoData with integrated values.
    """
    logger.debug("Computing numerical integral")
    if not (isinstance(self.x, jnp.ndarray) or isinstance(self.y, jnp.ndarray)):
        # Use NumPy/SciPy cumulative trapezoid
        from scipy.integrate import cumulative_trapezoid

        running_total = cumulative_trapezoid(self.y, self.x, initial=0)
    else:
        # JAX has no cumulative_trapezoid; accumulate the trapezoid areas.
        step_widths = jnp.diff(self.x)
        mean_heights = (self.y[:-1] + self.y[1:]) / 2.0  # type: ignore[operator]
        running_total = jnp.concatenate(
            [jnp.zeros(1, dtype=self.y.dtype), jnp.cumsum(mean_heights * step_widths)]  # type: ignore[union-attr]
        )

    integral_units = None
    if self.y_units and self.x_units:
        integral_units = f"∫{self.y_units}·d{self.x_units}"
    return RheoData(
        x=self.x,
        y=running_total,
        x_units=self.x_units,
        y_units=integral_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
# Domain conversion placeholders
[docs]
def to_frequency_domain(self) -> RheoData:
    """Convert time-domain data to the frequency domain (not implemented yet).

    If the domain is anything other than "time", a UserWarning is emitted
    and a copy is returned unchanged.

    Raises:
        NotImplementedError: For time-domain data (transform not yet available).
    """
    if self.domain == "time":
        logger.error("Frequency domain conversion not yet implemented")
        # This would use FFT transform when implemented
        raise NotImplementedError(
            "Frequency domain conversion will be implemented with transforms"
        )
    logger.debug("Data is already in frequency domain")
    warnings.warn(
        "Data is already in frequency domain", UserWarning, stacklevel=2
    )
    return self.copy()
[docs]
def to_time_domain(self) -> RheoData:
    """Convert frequency-domain data to the time domain (not implemented yet).

    If the domain is anything other than "frequency", a UserWarning is
    emitted and a copy is returned unchanged.

    Raises:
        NotImplementedError: For frequency-domain data (transform not yet available).
    """
    if self.domain == "frequency":
        logger.error("Time domain conversion not yet implemented")
        # This would use inverse FFT transform when implemented
        raise NotImplementedError(
            "Time domain conversion will be implemented with transforms"
        )
    logger.debug("Data is already in time domain")
    warnings.warn("Data is already in time domain", UserWarning, stacklevel=2)
    return self.copy()
# Data slicing methods
[docs]
def slice(self, start: float | None = None, end: float | None = None) -> RheoData:
    """Keep the points with start <= x <= end (either bound optional).

    The boolean mask is computed on a NumPy view of x, but the original x
    and y arrays are indexed with it, so their array type is preserved. A
    UserWarning is emitted when no point falls inside the range.

    Args:
        start: Inclusive lower x bound, or None for no lower bound.
        end: Inclusive upper x bound, or None for no upper bound.

    Returns:
        Sliced RheoData (validation disabled, metadata shallow-copied).
    """
    logger.debug("Slicing data by x range", start=start, end=end)
    x_host = np.asarray(self.x)
    keep = np.ones_like(x_host, dtype=bool)
    if start is not None:
        keep &= x_host >= start
    if end is not None:
        keep &= x_host <= end
    kept_x = self.x[keep]
    kept_y = self.y[keep]
    logger.debug("Sliced data", original_size=len(x_host), new_size=len(kept_x))
    result = RheoData(
        x=kept_x,
        y=kept_y,
        x_units=self.x_units,
        y_units=self.y_units,
        domain=self.domain,
        initial_test_mode=self._explicit_test_mode,
        metadata=self.metadata.copy(),
        validate=False,
    )
    if not len(result.x):
        warnings.warn(
            f"RheoData.slice(start={start}, end={end}) produced an empty result.",
            stacklevel=2,
        )
    return result