"""NumPy .npz writer/reader for RheoData objects."""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
import numpy as np
from rheojax.core.data import RheoData
from rheojax.io.json_encoder import NumpyJSONEncoder as _NumpyEncoder
from rheojax.logging import get_logger
logger = get_logger(__name__)
__all__ = ["save_npz", "load_npz", "save_fit_result_npz"]
def _encode_str(s: str | None) -> np.ndarray:
"""Encode a string (or None) as a uint8 byte array for npz storage."""
encoded = (s or "").encode("utf-8")
return np.frombuffer(encoded, dtype=np.uint8)
def _decode_str(arr: np.ndarray) -> str | None:
"""Decode a uint8 byte array back to a string (None if empty)."""
s = arr.tobytes().decode("utf-8")
return s if s else None
[docs]
def save_npz(
data: RheoData,
filepath: str | Path,
compressed: bool = True,
) -> None:
"""Save a RheoData object to a NumPy .npz archive.
Strings and metadata are stored as UTF-8 encoded uint8 byte arrays —
no pickle is used.
Args:
data: RheoData object to save.
filepath: Destination path (np.savez appends .npz if not present).
compressed: If True (default), use np.savez_compressed. If False,
use np.savez (larger file, faster write).
Raises:
OSError: If the file cannot be written.
"""
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
# Serialise metadata to JSON bytes (handles numpy types via custom encoder)
metadata_bytes = json.dumps(
data.metadata or {}, cls=_NumpyEncoder, allow_nan=True
).encode("utf-8")
arrays: dict[str, np.ndarray] = {
"x": np.asarray(data.x),
"y": np.asarray(data.y),
"_metadata_json": np.frombuffer(metadata_bytes, dtype=np.uint8),
"_x_units": _encode_str(data.x_units),
"_y_units": _encode_str(data.y_units),
"_domain": _encode_str(data.domain),
"_initial_test_mode": _encode_str(data._explicit_test_mode),
}
save_fn = np.savez_compressed if compressed else np.savez
save_fn(filepath, **arrays) # type: ignore[arg-type]
logger.info(
"Saved RheoData to npz",
filepath=str(filepath),
compressed=compressed,
n_points=len(data.x), # type: ignore[arg-type]
)
def save_fit_result_npz(
result: Any,
filepath: str | Path,
compressed: bool = True,
) -> None:
"""Save a FitResult to a NumPy .npz archive (no pickle, safe serialization).
Stores all fields as numpy arrays and UTF-8 encoded strings.
Args:
result: A FitResult instance (from rheojax.core.fit_result).
filepath: Destination path.
compressed: Use compressed npz (default: True).
"""
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
# Build arrays dict — all values are numpy arrays (no pickle used)
arrays: dict[str, np.ndarray] = {
"_model_name": _encode_str(result.model_name or ""),
"_model_class_name": _encode_str(result.model_class_name or ""),
"_protocol": _encode_str(result.protocol or ""),
"_n_params": np.array([result.n_params]),
"_timestamp": _encode_str(result.timestamp or ""),
}
# Parameters as JSON-encoded string (safe serialization)
param_names = list(result.params.keys())
param_values = np.array([result.params[n] for n in param_names], dtype=np.float64)
arrays["_param_names"] = _encode_str(json.dumps(param_names))
arrays["_param_values"] = param_values
# Units as JSON-encoded string
if result.params_units:
arrays["_params_units"] = _encode_str(json.dumps(result.params_units))
# Fitted curve
if result.fitted_curve is not None:
arrays["fitted_curve"] = np.asarray(result.fitted_curve)
# Input data
if result.X is not None:
arrays["input_x"] = np.asarray(result.X)
if result.y is not None:
arrays["input_y"] = np.asarray(result.y)
# Statistics as JSON-encoded string
stats = {}
for attr_name in ("r_squared", "aic", "bic", "rmse", "mae"):
val = getattr(result, attr_name, None)
if val is not None:
stats[attr_name] = float(val)
if stats:
arrays["_stats"] = _encode_str(json.dumps(stats))
save_fn = np.savez_compressed if compressed else np.savez
save_fn(filepath, **arrays) # type: ignore[arg-type]
logger.info(
"Saved FitResult to npz",
filepath=str(filepath),
model_name=result.model_name,
)
[docs]
def load_npz(filepath: str | Path) -> RheoData:
"""Load a RheoData object from a NumPy .npz archive.
Args:
filepath: Path to the .npz file (with or without .npz extension).
Returns:
Reconstructed RheoData object.
Raises:
FileNotFoundError: If the file does not exist.
ValueError: If the file is not a valid RheoData npz archive.
"""
filepath = Path(filepath)
if not filepath.exists():
with_suffix = filepath.with_suffix(".npz")
if with_suffix.exists():
filepath = with_suffix
else:
raise FileNotFoundError(f"File not found: {filepath}")
try:
npz = np.load(filepath, allow_pickle=False)
except Exception as e:
raise ValueError(f"Failed to load npz archive: {filepath}: {e}") from e
x = npz["x"]
y = npz["y"]
# Validate array shape compatibility before constructing RheoData
if x.ndim >= 1 and y.ndim >= 1 and len(x) != len(y):
raise ValueError(
f"Corrupt npz archive: x has {len(x)} points but y has {len(y)}"
)
# Parse metadata from UTF-8 bytes
try:
metadata: dict = json.loads(npz["_metadata_json"].tobytes().decode("utf-8"))
except (json.JSONDecodeError, KeyError, UnicodeDecodeError):
logger.warning("Could not parse metadata JSON from npz, using empty dict")
metadata = {}
x_units = _decode_str(npz["_x_units"]) if "_x_units" in npz else None
y_units = _decode_str(npz["_y_units"]) if "_y_units" in npz else None
domain = _decode_str(npz["_domain"]) if "_domain" in npz else None
initial_test_mode = (
_decode_str(npz["_initial_test_mode"]) if "_initial_test_mode" in npz else None
)
logger.info(
"Loaded RheoData from npz",
filepath=str(filepath),
n_points=len(x),
domain=domain,
initial_test_mode=initial_test_mode,
)
return RheoData(
x=x,
y=y,
x_units=x_units,
y_units=y_units,
domain=domain or "time",
initial_test_mode=initial_test_mode,
metadata=metadata,
validate=True,
)