"""HDF5 writer for rheological data."""
from __future__ import annotations
import enum
import os
import tempfile
import warnings
from pathlib import Path
from typing import Any
import numpy as np
from rheojax.core.data import RheoData
from rheojax.io._exceptions import RheoJaxValidationWarning
from rheojax.logging import get_logger, log_io
# Module logger used by every reader/writer in this file.
logger = get_logger(__name__)

# Scalar values that h5py can write directly as HDF5 attributes:
# Python builtins first, then their NumPy counterparts.
_HDF5_SCALAR_TYPES = (
    str, int, float, bool,
    np.integer, np.floating, np.bool_,
)

# HDF5 attributes cannot hold ``None``; this marker string is written in its
# place and translated back to ``None`` when metadata is read.
_NONE_SENTINEL = "__rheojax_None__"
def save_hdf5(
    data: RheoData,
    filepath: str | Path,
    compression: bool = True,
    compression_level: int = 4,
    **kwargs,
) -> None:
    """Save RheoData to HDF5 file.

    HDF5 is the recommended format for archiving rheological data. It provides:

    - Efficient storage with compression
    - Preservation of all metadata
    - Fast read/write performance
    - Cross-platform compatibility

    Layout written:

    - ``x`` dataset (float64) and ``y`` dataset (float64, or complex when
      ``data.y`` is complex), each with a ``units`` attribute when set
    - root attributes ``domain``, ``test_mode`` / ``deformation_mode`` (when
      set) and ``rheojax_version`` (when available)
    - a ``metadata`` group when ``data.metadata`` is non-empty

    The write is atomic: everything goes to a temporary file in the target
    directory, which is then renamed over ``filepath``.

    Args:
        data: RheoData object to save
        filepath: Output file path; missing parent directories are created
        compression: Enable gzip compression (default: True)
        compression_level: Compression level 0-9 (default: 4)
        **kwargs: Accepted for backward compatibility but currently ignored;
            any keys passed are reported in a debug log message.

    Raises:
        ImportError: If h5py not installed
        ValueError: If ``compression_level`` is outside 0-9
        IOError: If file cannot be written
    """
    try:
        import h5py
    except ImportError as exc:
        logger.error(
            "h5py import failed",
            error_type="ImportError",
            suggestion="pip install h5py",
            exc_info=True,
        )
        raise ImportError(
            "h5py is required for HDF5 writing. Install with: pip install h5py"
        ) from exc

    if not (0 <= compression_level <= 9):
        raise ValueError(f"compression_level must be 0-9, got {compression_level}")

    if kwargs:
        # These were documented as "passed to h5py" but have never been
        # forwarded anywhere; make the no-op visible instead of silent.
        logger.debug(
            "Ignoring unsupported keyword arguments",
            ignored_keys=sorted(kwargs),
        )

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    # Compression settings shared by the x and y datasets.
    compression_algorithm: str | None = None
    compression_opts: int | None = None
    if compression:
        compression_algorithm = "gzip"
        compression_opts = compression_level
        logger.debug(
            "Compression settings configured",
            algorithm=compression_algorithm,
            compression_level=compression_opts,
        )

    with log_io(logger, "write", filepath=str(filepath)) as ctx:
        # Atomic write: write to a temp file in the same directory (keeping
        # the final os.replace on one filesystem), then rename. This prevents
        # corrupt files from interrupted writes.
        tmp_path: str | None = None
        try:
            tmp_fd, tmp_path = tempfile.mkstemp(dir=filepath.parent, suffix=".h5.tmp")
            # h5py reopens the file by name; the raw descriptor is not needed.
            os.close(tmp_fd)
            with h5py.File(tmp_path, "w") as f:
                # Store x and y data with explicit float64 preservation
                x_arr = np.asarray(data.x, dtype=np.float64)
                y_arr = np.asarray(data.y)
                # Preserve complex dtype; ensure real arrays are float64
                if not np.issubdtype(y_arr.dtype, np.complexfloating):
                    y_arr = np.asarray(y_arr, dtype=np.float64)
                logger.debug(
                    "Writing data arrays",
                    x_shape=x_arr.shape,
                    x_dtype=str(x_arr.dtype),
                    y_shape=y_arr.shape,
                    y_dtype=str(y_arr.dtype),
                    compression=compression_algorithm,
                )
                f.create_dataset(
                    "x",
                    data=x_arr,
                    compression=compression_algorithm,
                    compression_opts=compression_opts,
                )
                f.create_dataset(
                    "y",
                    data=y_arr,
                    compression=compression_algorithm,
                    compression_opts=compression_opts,
                )

                # Store units as attributes
                if data.x_units is not None:
                    f["x"].attrs["units"] = data.x_units
                if data.y_units is not None:
                    f["y"].attrs["units"] = data.y_units
                logger.debug(
                    "Units stored",
                    x_units=data.x_units,
                    y_units=data.y_units,
                )

                # Store domain
                f.attrs["domain"] = data.domain
                logger.debug("Domain stored", domain=data.domain)

                # Store test_mode and deformation_mode as top-level attrs
                # (belt-and-suspenders: also in metadata dict)
                test_mode = data.test_mode
                if test_mode is not None:
                    f.attrs["test_mode"] = str(test_mode)
                deformation_mode = data.deformation_mode
                if deformation_mode is not None:
                    f.attrs["deformation_mode"] = str(deformation_mode)

                # Store metadata
                if data.metadata:
                    metadata_group = f.create_group("metadata")
                    dropped = _write_metadata_recursive(metadata_group, data.metadata)
                    logger.debug(
                        "Metadata written",
                        metadata_keys=list(data.metadata.keys()),
                    )
                    if dropped:
                        logger.warning(
                            "Some metadata keys could not be serialized "
                            "and were dropped from the HDF5 file",
                            dropped_keys=dropped,
                        )

                # Store rheojax version. Best effort: a missing package or a
                # missing __version__ attribute must not abort the save.
                try:
                    import rheojax
                except ImportError:
                    rheojax_version = None
                else:
                    rheojax_version = getattr(rheojax, "__version__", None)
                if rheojax_version is not None:
                    f.attrs["rheojax_version"] = rheojax_version
                    logger.debug("Version stored", rheojax_version=rheojax_version)

            # Atomic rename: only overwrites target after successful write
            os.replace(tmp_path, filepath)
            tmp_path = None  # Prevent cleanup since rename succeeded
        finally:
            # Clean up temp file if rename didn't happen (write failed)
            if tmp_path is not None:
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass  # best-effort cleanup; the original error propagates

        ctx["data_points"] = len(data.x)  # type: ignore[arg-type]
        ctx["compression"] = compression
        ctx["has_metadata"] = bool(data.metadata)
def _write_metadata_recursive(
    group: Any,
    metadata: dict[str, Any],
    _path: str = "",
) -> list[str]:
    """Recursively write metadata to an HDF5 group.

    Storage rules, applied per key:

    * Keys that are not ``str``/``bytes`` are converted with ``str()``
      (h5py only accepts str/bytes names).
    * Enum members are replaced by their ``.value`` (also inside lists/tuples).
    * ``None`` (including an enum whose value is ``None``) is stored as the
      ``_NONE_SENTINEL`` string attribute.
    * Nested dicts become subgroups, written recursively.
    * Lists/tuples of strings become variable-length string attributes; other
      lists are stored as NumPy array attributes, falling back to their
      ``str()`` representation (with a warning) when that fails.
    * NumPy arrays become attributes, or datasets when larger than 60,000
      bytes (HDF5 attributes are limited to 64 KB). Fixed-width unicode
      arrays are stored as variable-length UTF-8 strings; arrays h5py still
      cannot store fall through to stringification.
    * HDF5-native scalars are stored as-is; anything else is stringified,
      and values that cannot even be stringified are dropped.

    Args:
        group: HDF5 group (``h5py.Group``) to write into.
        metadata: Metadata dictionary.
        _path: Internal path prefix for logging (do not set externally).

    Returns:
        List of metadata key paths that could not be serialized.
    """
    # The caller (save_hdf5) has already verified h5py is importable.
    import h5py

    dropped_keys: list[str] = []
    for raw_key, value in metadata.items():
        # h5py raises TypeError for non-str/bytes names, which would abort
        # the whole save; coerce e.g. integer keys instead.
        key = raw_key if isinstance(raw_key, (str, bytes)) else str(raw_key)
        full_key = f"{_path}/{key}" if _path else key

        # Convert enum values to their underlying Python type first
        # (h5py can't serialize str-enum subclasses directly), so an enum
        # whose value is None still takes the sentinel path below.
        if isinstance(value, enum.Enum):
            value = value.value

        if value is None:
            # None is not HDF5-storable; store as sentinel string
            group.attrs[key] = _NONE_SENTINEL
            continue

        if isinstance(value, dict):
            subgroup = group.create_group(key)
            dropped_keys.extend(
                _write_metadata_recursive(subgroup, value, _path=full_key)
            )
            continue

        if isinstance(value, (list, tuple)):
            # Unwrap enum members so they serialize like their scalar form.
            items = [v.value if isinstance(v, enum.Enum) else v for v in value]
            try:
                if items and all(isinstance(v, str) for v in items):
                    group.attrs.create(key, items, dtype=h5py.string_dtype())
                else:
                    group.attrs[key] = np.array(items)
            except (TypeError, ValueError):
                # Lists of mixed types — fall back to string
                group.attrs[key] = str(value)
                logger.warning(
                    "Metadata key '%s' contains mixed-type list, stored as string representation",
                    full_key,
                )
                warnings.warn(
                    f"Metadata key '{full_key}' contains mixed-type list, "
                    f"stored as string representation",
                    RheoJaxValidationWarning,
                    stacklevel=4,
                )
            continue

        if isinstance(value, np.ndarray):
            arr = value
            str_dtype = None
            if arr.dtype.kind == "U":
                # h5py has no HDF5 mapping for NumPy fixed-width unicode;
                # store as variable-length UTF-8 strings instead.
                arr = arr.astype(object)
                str_dtype = h5py.string_dtype()
            try:
                # HDF5 attributes have a 64 KB size limit; store large arrays
                # as datasets within the metadata group instead.
                if value.nbytes > 60_000:
                    group.create_dataset(key, data=arr, dtype=str_dtype)
                else:
                    group.attrs.create(key, arr, dtype=str_dtype)
                continue
            except (TypeError, ValueError):
                # e.g. object arrays of arbitrary Python objects: fall
                # through to the stringification fallback below rather than
                # aborting the whole save.
                pass

        elif isinstance(value, _HDF5_SCALAR_TYPES):
            group.attrs[key] = value
            continue

        # Last resort: stringify
        try:
            group.attrs[key] = str(value)
            logger.debug(
                "Metadata stringified for storage",
                key=full_key,
                original_type=type(value).__name__,
            )
        except (TypeError, ValueError, OSError):
            dropped_keys.append(full_key)
            logger.warning(
                "Could not serialize metadata key — value dropped",
                key=full_key,
                value_type=type(value).__name__,
            )
    return dropped_keys
def save_fit_result_hdf5(
    result: Any,
    filepath: str | Path,
    compression: bool = True,
    compression_level: int = 4,
) -> None:
    """Save a FitResult to HDF5 file.

    Stores model identification, statistics, parameters, the fitted curve
    and the input data in a structured HDF5 layout:

    - root attributes: ``rheojax_type`` (``"FitResult"``), ``model_name``,
      ``model_class_name``, ``protocol``, ``n_params``, each finite statistic
      among ``r_squared``/``aic``/``bic``/``aicc``/``rmse``/``mae``, and
      ``timestamp`` when set (non-scalar timestamps such as ``datetime`` are
      stored as ISO-8601 / ``str()`` text)
    - ``params`` group: one float attribute per parameter
    - ``params_units`` group: one string attribute per parameter (if any)
    - datasets ``fitted_curve``, ``input_x``, ``input_y`` (each when not None)

    The write is atomic: a temporary file in the target directory is renamed
    over ``filepath`` only after it has been fully written.

    Args:
        result: A FitResult instance (from rheojax.core.fit_result).
        filepath: Output file path; missing parent directories are created.
        compression: Enable gzip compression (default: True).
        compression_level: Compression level 0-9 (default: 4).

    Raises:
        ImportError: If h5py not installed.
        ValueError: If compression is enabled and ``compression_level`` is
            outside 0-9.
    """
    try:
        import h5py
    except ImportError as exc:
        raise ImportError(
            "h5py is required for HDF5 writing. Install with: pip install h5py"
        ) from exc

    # Same check and message as save_hdf5, raised before any file is touched.
    # Only enforced when compression is on, since the level is unused otherwise.
    if compression and not (0 <= compression_level <= 9):
        raise ValueError(f"compression_level must be 0-9, got {compression_level}")

    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)
    comp_algo: str | None = "gzip" if compression else None
    comp_opts = compression_level if compression else None

    def _write_array(f: Any, name: str, values: Any) -> None:
        """Write ``values`` as dataset ``name``, compressed when h5py allows it."""
        arr = np.asarray(values)
        # h5py rejects filter options on scalar (0-d) datasets, and an empty
        # array has nothing to compress.
        compressible = arr.ndim > 0 and arr.size > 0
        f.create_dataset(
            name,
            data=arr,
            compression=comp_algo if compressible else None,
            compression_opts=comp_opts if compressible else None,
        )

    # Atomic write: write to a temp file in the same directory, then rename.
    tmp_path: str | None = None
    try:
        tmp_fd, tmp_path = tempfile.mkstemp(dir=filepath.parent, suffix=".h5.tmp")
        # h5py reopens the file by name; the raw descriptor is not needed.
        os.close(tmp_fd)
        with h5py.File(tmp_path, "w") as f:
            f.attrs["rheojax_type"] = "FitResult"
            f.attrs["model_name"] = result.model_name or ""
            f.attrs["model_class_name"] = result.model_class_name or ""
            f.attrs["protocol"] = result.protocol or ""
            f.attrs["n_params"] = result.n_params

            # Store scalar statistics (NaN/inf are skipped)
            for attr_name in ("r_squared", "aic", "bic", "aicc", "rmse", "mae"):
                val = getattr(result, attr_name, None)
                if val is not None and np.isfinite(val):
                    f.attrs[attr_name] = float(val)

            # Parameters
            params_grp = f.create_group("params")
            for name, value in result.params.items():
                params_grp.attrs[name] = float(value)

            # Parameter units
            if result.params_units:
                units_grp = f.create_group("params_units")
                for name, unit in result.params_units.items():
                    units_grp.attrs[name] = str(unit)

            # Fitted curve and input data
            if result.fitted_curve is not None:
                _write_array(f, "fitted_curve", result.fitted_curve)
            if result.X is not None:
                _write_array(f, "input_x", result.X)
            if result.y is not None:
                _write_array(f, "input_y", result.y)

            # Timestamp: HDF5 cannot store arbitrary objects (e.g. datetime).
            timestamp = result.timestamp
            if timestamp:
                if not isinstance(timestamp, _HDF5_SCALAR_TYPES):
                    timestamp = (
                        timestamp.isoformat()
                        if hasattr(timestamp, "isoformat")
                        else str(timestamp)
                    )
                f.attrs["timestamp"] = timestamp

        # Atomic rename
        os.replace(tmp_path, filepath)
        tmp_path = None  # Prevent cleanup since rename succeeded
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # best-effort cleanup; the original error propagates

    logger.info(
        "Saved FitResult to HDF5",
        filepath=str(filepath),
        model_name=result.model_name,
    )
def load_hdf5(filepath: str | Path) -> RheoData:
    """Load RheoData from HDF5 file.

    Reads the layout produced by ``save_hdf5``: ``x``/``y`` datasets with
    optional ``units`` attributes, a root ``domain`` attribute (default
    ``"time"``), optional ``test_mode``/``deformation_mode`` root attributes
    (copied into metadata when the metadata dict lacks them), and an optional
    ``metadata`` group. String attributes are decoded and length-limited.

    Args:
        filepath: Path to HDF5 file

    Returns:
        RheoData object (constructed with ``validate=True``)

    Raises:
        ImportError: If h5py not installed
        FileNotFoundError: If file doesn't exist
        ValueError: If file format is invalid (e.g. the ``x`` or ``y``
            dataset is missing)
    """
    try:
        import h5py
    except ImportError as exc:
        logger.error(
            "h5py import failed",
            error_type="ImportError",
            suggestion="pip install h5py",
            exc_info=True,
        )
        raise ImportError(
            "h5py is required for HDF5 reading. Install with: pip install h5py"
        ) from exc

    filepath = Path(filepath)
    if not filepath.exists():
        logger.error(
            "File not found",
            filepath=str(filepath),
            error_type="FileNotFoundError",
        )
        raise FileNotFoundError(f"File not found: {filepath}")

    def _decode(value: Any) -> Any:
        # R6-HDF5-001/002/003: h5py may return bytes instead of str on some
        # platforms/backends. Decode (and length-limit) string-like values for
        # downstream compatibility; leave anything else untouched.
        if isinstance(value, (bytes, str)):
            return _safe_decode_hdf5_string(value)
        return value

    with log_io(logger, "read", filepath=str(filepath)) as ctx:
        with h5py.File(filepath, "r") as f:
            # Validate the required layout up front so a foreign or corrupt
            # file raises the documented ValueError instead of a bare KeyError.
            missing = [
                name for name in ("x", "y") if not isinstance(f.get(name), h5py.Dataset)
            ]
            if missing:
                logger.error(
                    "Invalid HDF5 layout",
                    filepath=str(filepath),
                    missing_datasets=missing,
                    error_type="ValueError",
                )
                raise ValueError(
                    f"Invalid RheoJAX HDF5 file {filepath}: missing dataset(s) "
                    f"{', '.join(missing)}"
                )

            # Load data
            x = f["x"][:]
            y = f["y"][:]
            logger.debug(
                "Data arrays loaded",
                x_shape=x.shape,
                y_shape=y.shape,
            )

            # Load units
            x_units = _decode(f["x"].attrs.get("units", None))
            y_units = _decode(f["y"].attrs.get("units", None))
            logger.debug(
                "Units loaded",
                x_units=x_units,
                y_units=y_units,
            )

            # Load domain
            domain = _decode(f.attrs.get("domain", "time"))
            logger.debug("Domain loaded", domain=domain)

            # Load metadata
            metadata = {}
            if "metadata" in f:
                metadata = _read_metadata_recursive(f["metadata"])
                logger.debug(
                    "Metadata loaded",
                    metadata_keys=list(metadata.keys()),
                )

            # Restore test_mode/deformation_mode from top-level attrs
            # into metadata (belt-and-suspenders with metadata dict);
            # values already present in the metadata group take precedence.
            for attr_name in ("test_mode", "deformation_mode"):
                attr_value = f.attrs.get(attr_name, None)
                if attr_value is not None and attr_name not in metadata:
                    metadata[attr_name] = _decode(attr_value)

            ctx["data_points"] = len(x)
            ctx["has_metadata"] = bool(metadata)
            ctx["domain"] = domain

    return RheoData(
        x=x,
        y=y,
        x_units=x_units,
        y_units=y_units,
        domain=domain,
        initial_test_mode=metadata.get("test_mode"),
        metadata=metadata,
        validate=True,
    )
_MAX_HDF5_STRING_LEN = 4096 # Limit string attributes from untrusted HDF5 files
def _safe_decode_hdf5_string(
value: bytes | str, max_len: int = _MAX_HDF5_STRING_LEN
) -> str:
"""Decode and truncate a string value read from HDF5 attributes."""
if isinstance(value, bytes):
value = value.decode("utf-8", errors="replace")
if len(value) > max_len:
logger.warning(
"HDF5 string attribute truncated",
original_len=len(value),
max_len=max_len,
)
value = value[:max_len]
return value
def _read_metadata_recursive(group: Any) -> dict[str, Any]:
    """Recursively read metadata from an HDF5 group.

    Attributes are read first, then subgroups (as nested dicts) and datasets.

    * bytes/str values are decoded and length-limited via
      ``_safe_decode_hdf5_string``.
    * 0-d NumPy values are unwrapped to native Python types via ``tolist()``.
    * Arrays whose first element is bytes/str (string dtypes, including
      ``h5py.string_dtype()`` data) become lists of ``str``; other arrays are
      returned unchanged.
    * The ``None`` sentinels (``_NONE_SENTINEL`` and legacy ``"__None__"``)
      in attributes are mapped back to ``None``.
    * A dataset/subgroup whose name matches an already-read attribute is
      skipped so the attribute value wins.

    Args:
        group: HDF5 group (``h5py.Group``).

    Returns:
        Metadata dictionary.
    """
    # Needed for the Group check below; imported once at the top of the
    # function. The caller (load_hdf5) has already verified h5py is installed.
    import h5py

    metadata: dict[str, Any] = {}

    # Read attributes
    for key, value in group.attrs.items():
        if isinstance(value, (bytes, str)):
            # h5py may return bytes instead of str on some platforms
            value = _safe_decode_hdf5_string(value)
        elif hasattr(value, "dtype") and hasattr(value, "tolist"):
            # R8-IO-003: unwrap NumPy scalars and decode string arrays.
            # Numeric arrays of rank >= 1 are left as-is without paying for
            # a full tolist() copy.
            try:
                if getattr(value, "ndim", 0) == 0 or value.dtype.kind in ("S", "U", "O"):
                    items = value.tolist()
                    if not isinstance(items, list):
                        # 0-d: tolist() returns a Python scalar — unwrap it
                        if isinstance(items, (bytes, str)):
                            value = _safe_decode_hdf5_string(items)
                        else:
                            value = items
                    elif items and isinstance(items[0], (bytes, str)):
                        # Tolerate mixed element types instead of crashing
                        value = [
                            _safe_decode_hdf5_string(v)
                            if isinstance(v, (bytes, str))
                            else str(v)
                            for v in items
                        ]
            except (AttributeError, UnicodeDecodeError):
                pass

        # Restore None values from sentinel (backward-compatible with old "__None__")
        if isinstance(value, str) and value in (_NONE_SENTINEL, "__None__"):
            metadata[key] = None
        else:
            metadata[key] = value

    # Read subgroups and datasets.
    # HDF5-READ-001: attrs are loaded first; skip any dataset/subgroup whose
    # name collides with an already-loaded attribute. Without this guard the
    # dataset loop would silently overwrite the attribute value, corrupting
    # metadata that was intentionally stored as a scalar attribute (e.g.
    # test_mode, deformation_mode stored as belt-and-suspenders duplicates).
    for key in group.keys():
        if key in metadata:
            logger.debug(
                "Skipping HDF5 dataset/subgroup — name collides with "
                "previously loaded attribute; attribute value is kept",
                key=key,
            )
            continue

        node = group[key]
        if isinstance(node, h5py.Group):
            metadata[key] = _read_metadata_recursive(node)
            continue

        # ``[()]`` reads both array and scalar (0-d) datasets; ``[:]`` raises
        # on scalar dataspaces.
        raw = node[()]
        if isinstance(raw, (bytes, str)):
            raw = _safe_decode_hdf5_string(raw)
        elif hasattr(raw, "dtype") and raw.dtype.kind in ("S", "U", "O"):
            # Sanitise string datasets the same way as attributes
            try:
                items = raw.tolist()
                if isinstance(items, (bytes, str)):
                    raw = _safe_decode_hdf5_string(items)
                elif isinstance(items, list) and items:
                    raw = [
                        (
                            _safe_decode_hdf5_string(v)
                            if isinstance(v, (bytes, str))
                            else v
                        )
                        for v in items
                    ]
            except (AttributeError, UnicodeDecodeError):
                pass
        metadata[key] = raw
    return metadata