Source code for rheojax.io.readers._validation

"""Protocol-aware validation for rheological data loaded from files.

This module provides opt-in quality checks that can be run after loading data
with any reader. Checks are protocol-specific (relaxation, creep, oscillation,
rotation/flow_curve, startup) and emit :class:`RheoJaxValidationWarning` for
each issue found.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import numpy as np

from rheojax.io._exceptions import RheoJaxValidationWarning
from rheojax.io.readers._utils import validate_transform
from rheojax.logging import get_logger

if TYPE_CHECKING:
    from rheojax.core.data import RheoData

# Module-level logger obtained from rheojax.logging.get_logger; the checks
# below log structured debug records through it.
logger = get_logger(__name__)

# Public API of this module.
__all__ = ["LoaderReport", "validate_protocol"]


# =============================================================================
# Report dataclass
# =============================================================================


@dataclass
class LoaderReport:
    """Findings gathered while loading and validating a rheological dataset.

    Attributes:
        warnings: Human-readable, non-fatal data-quality messages.
        errors: Messages describing problems serious enough to make analysis
            unreliable.
        skipped_rows: Count of rows dropped while parsing; stays 0 unless the
            caller fills it in.
        protocol_inferred: ``True`` when the test protocol was inferred
            rather than supplied explicitly.
        units_converted: Maps each converted field name to its original unit.
        quality_flags: Named pass/fail booleans for downstream consumers.
    """

    # Message collections; every instance receives its own fresh containers.
    warnings: list[str] = field(default_factory=list)
    errors: list[str] = field(default_factory=list)
    # Set by the caller when NaN rows were dropped during loading.
    skipped_rows: int = field(default=0)
    protocol_inferred: bool = field(default=False)
    # field name -> original unit string
    units_converted: dict[str, str] = field(default_factory=dict)
    # flag name -> pass (True) / fail (False)
    quality_flags: dict[str, bool] = field(default_factory=dict)


# =============================================================================
# Helpers
# =============================================================================


def _infer_protocol(data: RheoData) -> tuple[str | None, bool]:
    """Return (protocol, was_inferred).

    Checks ``data.initial_test_mode`` first (set as ``_explicit_test_mode``
    internally), then falls back to ``metadata["detected_test_mode"]``.
    """
    # initial_test_mode is an InitVar; after __post_init__ it is stored in
    # _explicit_test_mode AND metadata["detected_test_mode"].
    explicit = getattr(data, "_explicit_test_mode", None)
    if explicit is not None:
        return explicit, False

    detected = data.metadata.get("detected_test_mode")
    if detected is not None:
        return detected, True

    # Last resort: "test_mode" key (some readers set this without the
    # "detected_" prefix)
    fallback = data.metadata.get("test_mode")
    if fallback is not None:
        return fallback, True

    return None, True


# =============================================================================
# Per-protocol checks
# =============================================================================


def _check_relaxation(data: RheoData, report: LoaderReport) -> None:
    """Quality checks for relaxation modulus G(t) data."""
    y = np.asarray(data.y)
    if np.iscomplexobj(y):
        y = np.abs(y)

    if y.ndim == 2:
        y = y[:, 0]

    # Check for monotonic decay: majority of diffs should be <= 0
    if len(y) >= 3:
        diffs = np.diff(y)
        n_increasing = int(np.sum(diffs > 0))
        fraction_increasing = n_increasing / len(diffs)
        if fraction_increasing > 0.3:
            msg = (
                f"Relaxation data is not monotonically decaying: "
                f"{n_increasing}/{len(diffs)} steps are increasing "
                f"({fraction_increasing:.1%}). Data may be noisy or "
                f"incorrectly labelled."
            )
            report.warnings.append(msg)
            report.quality_flags["monotonic_decay"] = False
            warnings.warn(msg, RheoJaxValidationWarning, stacklevel=3)
            logger.debug(
                "Relaxation monotonicity check failed",
                n_increasing=n_increasing,
                total_steps=len(diffs),
            )
        else:
            report.quality_flags["monotonic_decay"] = True

    # Check that t[0] is not too large relative to the total time range
    x = np.asarray(data.x)
    if len(x) >= 2:
        t_range = float(x[-1] - x[0])
        t_start = float(x[0])
        if t_range > 0 and t_start / t_range > 0.5:
            msg = (
                f"Relaxation data starts at t={t_start:.3g} which is "
                f"{t_start / t_range:.1%} of the total time range "
                f"({t_range:.3g}). Early transient may be missing."
            )
            report.warnings.append(msg)
            report.quality_flags["early_transient_present"] = False
            warnings.warn(msg, RheoJaxValidationWarning, stacklevel=3)
            logger.debug(
                "Relaxation start-time check failed",
                t_start=t_start,
                t_range=t_range,
            )
        else:
            report.quality_flags["early_transient_present"] = True


def _check_creep(data: RheoData, report: LoaderReport) -> None:
    """Quality checks for creep compliance J(t) data."""
    meta = data.metadata
    has_stress = any(k in meta for k in ("sigma_applied", "sigma_0", "stress_applied"))
    if not has_stress:
        msg = (
            "Creep data is missing applied stress metadata. "
            "Expected 'sigma_applied' or 'sigma_0' in metadata for "
            "accurate compliance scaling."
        )
        report.warnings.append(msg)
        report.quality_flags["sigma_metadata_present"] = False
        warnings.warn(msg, RheoJaxValidationWarning, stacklevel=4)
        logger.debug(
            "Creep sigma metadata check failed", metadata_keys=list(meta.keys())
        )
    else:
        report.quality_flags["sigma_metadata_present"] = True


def _check_oscillation(data: RheoData, report: LoaderReport) -> None:
    """Quality checks for oscillatory (SAOS/MAOS) data."""
    x = np.asarray(data.x)
    if len(x) >= 2:
        x_pos = x[x > 0]
        if len(x_pos) >= 2:
            decades = float(np.log10(x_pos.max() / x_pos.min()))
            if decades < 2.0:
                msg = (
                    f"Oscillation frequency range spans only {decades:.2f} decades "
                    f"(min={x_pos.min():.3g}, max={x_pos.max():.3g} rad/s). "
                    f"At least 2 decades are recommended for reliable fitting."
                )
                report.warnings.append(msg)
                report.quality_flags["frequency_range_sufficient"] = False
                warnings.warn(msg, RheoJaxValidationWarning, stacklevel=3)
                logger.debug(
                    "Oscillation frequency range check failed",
                    decades=decades,
                    omega_min=float(x_pos.min()),
                    omega_max=float(x_pos.max()),
                )
            else:
                report.quality_flags["frequency_range_sufficient"] = True
        else:
            msg = "Oscillation data has fewer than 2 positive frequency points."
            report.warnings.append(msg)
            report.quality_flags["frequency_range_sufficient"] = False
            warnings.warn(msg, RheoJaxValidationWarning, stacklevel=3)


def _check_rotation(data: RheoData, report: LoaderReport) -> None:
    """Quality checks for steady-state flow curve (rotation) data."""
    meta = data.metadata
    has_rate = any(k in meta for k in ("gamma_dot", "shear_rate", "applied_shear_rate"))
    if not has_rate:
        msg = (
            "Flow curve (rotation) data is missing shear rate metadata. "
            "Expected 'gamma_dot' or 'shear_rate' in metadata."
        )
        report.warnings.append(msg)
        report.quality_flags["shear_rate_metadata_present"] = False
        warnings.warn(msg, RheoJaxValidationWarning, stacklevel=4)
        logger.debug(
            "Rotation shear-rate metadata check failed",
            metadata_keys=list(meta.keys()),
        )
    else:
        report.quality_flags["shear_rate_metadata_present"] = True


def _check_startup(data: RheoData, report: LoaderReport) -> None:
    """Quality checks for startup-of-flow data."""
    meta = data.metadata
    has_rate = any(k in meta for k in ("gamma_dot", "shear_rate"))
    if not has_rate:
        msg = (
            "Startup data is missing applied shear rate metadata. "
            "Expected 'gamma_dot' or 'shear_rate' in metadata."
        )
        report.warnings.append(msg)
        report.quality_flags["shear_rate_metadata_present"] = False
        warnings.warn(msg, RheoJaxValidationWarning, stacklevel=4)
        logger.debug(
            "Startup shear-rate metadata check failed",
            metadata_keys=list(meta.keys()),
        )
    else:
        report.quality_flags["shear_rate_metadata_present"] = True


# =============================================================================
# Public API
# =============================================================================

# Dispatch table used by validate_protocol: protocol name -> checker taking
# (data, report) and recording its findings on the report.
# "flow_curve" is an alias of "rotation" and shares its checker.
_PROTOCOL_CHECKERS = dict(
    relaxation=_check_relaxation,
    creep=_check_creep,
    oscillation=_check_oscillation,
    rotation=_check_rotation,
    flow_curve=_check_rotation,
    startup=_check_startup,
)


def _protocol_key(protocol: object) -> object:
    """Return the ``_PROTOCOL_CHECKERS`` lookup key derived from *protocol*.

    Enum-like values are unwrapped to their ``.value``, since a plain ``Enum``
    member never equals its string value. Strings are stripped and
    lower-cased. Anything else is returned unchanged.
    """
    value = getattr(protocol, "value", protocol)
    if isinstance(value, str):
        return value.strip().lower()
    return protocol


def validate_protocol(
    data: RheoData,
    intended_transform: str | None = None,
) -> LoaderReport:
    """Run protocol-aware quality checks on loaded rheological data.

    The test protocol is taken from the explicit ``initial_test_mode`` when
    one was given. Otherwise it falls back to
    ``data.metadata["detected_test_mode"]`` and then
    ``data.metadata["test_mode"]``. Protocol-specific quality checks are then
    run; the lookup ignores case and surrounding whitespace and accepts Enum
    members. A :class:`RheoJaxValidationWarning` is emitted for every issue
    found, so callers using ``warnings.filterwarnings`` can control visibility.
    Empty data is the one exception: it is recorded in ``report.errors``
    without a warning, and no further checks run.

    Args:
        data: Loaded rheological data container.
        intended_transform: Optional transform name (e.g. ``"mastercurve"``,
            ``"owchirp"``) to validate transform compatibility in addition to
            protocol checks.

    Returns:
        :class:`LoaderReport` with all findings.
    """
    report = LoaderReport()

    # Guard: empty data -- nothing meaningful can be checked.
    x = np.asarray(data.x)
    if x.size == 0:
        msg = "Data is empty (zero points). No validation performed."
        report.errors.append(msg)
        logger.debug("validate_protocol: empty data, skipping all checks")
        return report

    # Infer protocol
    protocol, was_inferred = _infer_protocol(data)
    report.protocol_inferred = was_inferred
    logger.debug(
        "validate_protocol: protocol resolved",
        protocol=protocol,
        inferred=was_inferred,
    )

    if protocol is None:
        msg = (
            "Could not determine test protocol from data. "
            "Pass 'initial_test_mode' to the reader for reliable validation."
        )
        report.warnings.append(msg)
        # stacklevel=2 attributes the warning to the caller of this function.
        warnings.warn(msg, RheoJaxValidationWarning, stacklevel=2)
    else:
        checker = _PROTOCOL_CHECKERS.get(_protocol_key(protocol))
        if checker is not None:
            checker(data, report)
        else:
            msg = (
                f"Unknown protocol '{protocol}'. "
                f"Supported: {sorted(_PROTOCOL_CHECKERS.keys())}"
            )
            report.warnings.append(msg)
            warnings.warn(msg, RheoJaxValidationWarning, stacklevel=2)

    # Optional transform validation; the raw protocol value is forwarded
    # unchanged as test_mode.
    if intended_transform is not None:
        transform_warnings = validate_transform(
            intended_transform=intended_transform,
            domain=data.domain,
            metadata=data.metadata,
            test_mode=protocol,
        )
        for tw in transform_warnings:
            report.warnings.append(tw)
            warnings.warn(tw, RheoJaxValidationWarning, stacklevel=2)

    logger.debug(
        "validate_protocol: complete",
        protocol=protocol,
        n_warnings=len(report.warnings),
        n_errors=len(report.errors),
        quality_flags=report.quality_flags,
    )
    return report