# Source code for rheojax.io.readers.excel_reader

"""Excel file reader for rheological data."""

from __future__ import annotations

import warnings
from pathlib import Path
from typing import Any

import numpy as np

from rheojax.core.data import RheoData
from rheojax.io.readers._utils import (
    VALID_TEST_MODES,
    VALID_TRANSFORMS,
    construct_complex_modulus,
    detect_deformation_mode_from_columns,
    detect_domain,
    detect_test_mode_from_columns,
    extract_unit_from_header,
    validate_transform,
)
from rheojax.logging import get_logger

logger = get_logger(__name__)


def load_excel(
    filepath: str | Path,
    x_col: str | int,
    y_col: str | int | None = None,
    *,
    y_cols: list[str | int] | None = None,
    sheet: str | int = 0,
    x_units: str | None = None,
    y_units: str | None = None,
    domain: str | None = None,
    test_mode: str | None = None,
    deformation_mode: str | None = None,
    temperature: float | None = None,
    metadata: dict | None = None,
    intended_transform: str | None = None,
    column_mapping: dict[str, str] | None = None,
    strain_amplitude: float | None = None,
    angular_frequency: float | None = None,
    applied_stress: float | None = None,
    shear_rate: float | None = None,
    reference_gamma_dot: float | None = None,
    header: int | None = 0,
    **kwargs,
) -> RheoData:
    """Load data from Excel file into RheoData.

    Args:
        filepath: Path to Excel file (.xlsx or .xls).
        x_col: Column name or index for x-axis data.
        y_col: Column name or index for y-axis data (single column).
            Mutually exclusive with y_cols.
        y_cols: List of two column names/indices for complex modulus [G', G''].
            First column is storage modulus (G'), second is loss modulus (G'').
            Mutually exclusive with y_col.
        sheet: Sheet name or index (default: 0 - first sheet).
        x_units: Units for x-axis (auto-detected from header if None).
        y_units: Units for y-axis (auto-detected from header if None).
        domain: Data domain ('time' or 'frequency', auto-detected if None).
        test_mode: Test mode ('relaxation', 'creep', 'oscillation', 'rotation').
            Auto-detected if None.
        deformation_mode: Deformation mode ('shear', 'tension', 'bending',
            'compression'). Auto-detected from column names if None. If
            'tension'/'bending'/'compression', sets metadata for DMTA support.
        temperature: Temperature in Kelvin for TTS workflows.
        metadata: Additional metadata dict to merge.
        intended_transform: Transform type for metadata validation. One of
            'mastercurve', 'srfs', 'owchirp', 'spp', 'fft', 'mutation',
            'derivative'. Matched case-insensitively and stored lowercased.
        column_mapping: Optional dict mapping original column names to new
            names. Applied immediately after reading, before any column
            lookup. Example: {"t": "time", "sigma": "stress"}.
        strain_amplitude: Strain amplitude (gamma_0) stored in metadata as
            ``gamma_0``. Used for LAOS/oscillation protocols.
        angular_frequency: Angular frequency (omega) stored in metadata as
            ``omega``. Used for oscillation protocols.
        applied_stress: Applied stress stored in metadata as
            ``sigma_applied``. Used for creep protocols.
        shear_rate: Shear rate stored in metadata as ``gamma_dot``. Used for
            flow/startup protocols.
        reference_gamma_dot: Reference shear rate stored in metadata as
            ``reference_gamma_dot``. Used for dimensionless flow analysis.
        header: Row number for column headers (None if no header).
        **kwargs: Additional arguments passed to pandas.read_excel. A
            ``usecols`` given here takes precedence over the internal
            column-subsetting optimization.

    Returns:
        RheoData object with populated fields.

    Raises:
        FileNotFoundError: If file doesn't exist.
        ImportError: If pandas or openpyxl not installed.
        KeyError: If a specified column (by name or integer position) doesn't
            exist. A KeyError raised by the Excel reader itself is also
            propagated unchanged.
        ValueError: If the workbook cannot be read (any other reader error,
            typically including an unknown sheet, is re-raised as ValueError),
            if data cannot be parsed as floats, if no finite rows remain, if
            test_mode/intended_transform is invalid, if y_cols has wrong
            length, or if both/neither of y_col and y_cols are provided.

    Warnings:
        UserWarning: If intended_transform metadata is missing.
        UserWarning: If domain incompatible with intended_transform.
        UserWarning: If test_mode conflicts with intended_transform.

    Example:
        >>> # Simple creep data from specific sheet
        >>> data = load_excel(
        ...     "data.xlsx",
        ...     x_col="time (s)",
        ...     y_col="J(t) (1/Pa)",
        ...     sheet="Creep Test",
        ... )

        >>> # Flow curve with explicit test mode
        >>> data = load_excel(
        ...     "flow_curve.xlsx",
        ...     x_col=0,
        ...     y_col=1,
        ...     test_mode='rotation',
        ...     x_units='1/s',
        ...     y_units='Pa·s',
        ... )

        >>> # Complex modulus from Excel
        >>> data = load_excel(
        ...     "frequency_sweep.xlsx",
        ...     x_col="omega (rad/s)",
        ...     y_cols=["G' (Pa)", "G'' (Pa)"],
        ...     intended_transform='mastercurve',
        ...     temperature=298.15,
        ... )
    """
    try:
        import pandas as pd
    except ImportError as exc:
        logger.error("pandas not installed for Excel reading", exc_info=True)
        raise ImportError(
            "pandas is required for Excel reading. Install with: pip install pandas openpyxl"
        ) from exc

    filepath = Path(filepath)
    logger.info("Opening file", filepath=str(filepath))

    if not filepath.exists():
        logger.error("File not found", filepath=str(filepath))
        raise FileNotFoundError(f"File not found: {filepath}")

    # --- Argument validation -------------------------------------------------
    if y_col is not None and y_cols is not None:
        raise ValueError("Cannot specify both y_col and y_cols. Use one or the other.")
    if y_col is None and y_cols is None:
        raise ValueError("Must specify either y_col or y_cols.")
    if y_cols is not None and len(y_cols) != 2:
        raise ValueError(
            f"y_cols must contain exactly 2 columns [G', G'']. Got {len(y_cols)} columns."
        )

    if test_mode is not None and test_mode.lower() not in VALID_TEST_MODES:
        raise ValueError(
            f"Invalid test_mode '{test_mode}'. "
            f"Valid options: {sorted(VALID_TEST_MODES)}"
        )

    if intended_transform is not None:
        if intended_transform.lower() not in VALID_TRANSFORMS:
            raise ValueError(
                f"Invalid intended_transform '{intended_transform}'. "
                f"Valid options: {sorted(VALID_TRANSFORMS)}"
            )
        # Normalize once: validation above is case-insensitive, so the stored
        # metadata and validate_transform() must both see the lowercase name.
        intended_transform = intended_transform.lower()

    # --- Read workbook -------------------------------------------------------
    read_kwargs = dict(kwargs)
    requested_cols = [x_col, y_col] if y_cols is None else [x_col, *y_cols]
    # Memory optimization for wide files: only parse the requested columns.
    # Restricted to all-named columns (integer positions refer to the full
    # sheet) and skipped with column_mapping (file holds pre-rename names).
    # A callable is used instead of a list so a missing name does not make the
    # reader fail; it is reported below as the documented KeyError instead.
    if (
        column_mapping is None
        and "usecols" not in read_kwargs
        and all(isinstance(col, str) for col in requested_cols)
    ):
        wanted = frozenset(requested_cols)
        read_kwargs["usecols"] = lambda name: name in wanted

    try:
        logger.debug("Reading Excel file", sheet=sheet)
        df = pd.read_excel(filepath, sheet_name=sheet, header=header, **read_kwargs)
    except (ImportError, KeyError):
        raise
    except Exception as e:
        logger.error(
            "Failed to parse Excel file", filepath=str(filepath), exc_info=True
        )
        raise ValueError(f"Error reading Excel columns: {e}") from e

    logger.debug("Excel file read successfully", n_rows=len(df), n_cols=len(df.columns))

    if column_mapping is not None:
        df = df.rename(columns=column_mapping)
        logger.debug("Applied column_mapping", mapping=column_mapping)

    # --- Column extraction ---------------------------------------------------
    # Header lookups live inside the try blocks: for integer positions they
    # index df.columns, and an out-of-range index must surface as KeyError.
    try:
        x_header = _get_column_header(df, x_col)
        x_data = _get_column_data(df, x_col)
    except (KeyError, IndexError) as e:
        logger.error("X column not found", x_col=x_col, exc_info=True)
        raise KeyError(f"X column not found: {e}") from e

    is_complex = y_cols is not None
    if y_cols is not None:
        try:
            y_headers = [_get_column_header(df, col) for col in y_cols]
            g_prime_data = _get_column_data(df, y_cols[0])
            g_double_prime_data = _get_column_data(df, y_cols[1])
        except (KeyError, IndexError) as e:
            logger.error("Y column not found", y_cols=y_cols, exc_info=True)
            raise KeyError(f"Y column not found: {e}") from e

        # Convert to float arrays before constructing complex modulus
        y_data = construct_complex_modulus(
            np.array(g_prime_data, dtype=float),
            np.array(g_double_prime_data, dtype=float),
        )
        logger.debug("Constructed complex modulus from G' and G''")
    else:
        if y_col is None:  # pragma: no cover — guarded by validation above
            raise ValueError("y_col must not be None for real data")
        try:
            y_headers = [_get_column_header(df, y_col)]
            y_data = _get_column_data(df, y_col)
        except (KeyError, IndexError) as e:
            logger.error("Y column not found", y_col=y_col, exc_info=True)
            raise KeyError(f"Y column not found: {e}") from e

    x_data = np.array(x_data, dtype=float)
    if not is_complex:
        y_data = np.array(y_data, dtype=float)

    # --- Drop non-finite rows ------------------------------------------------
    # np.isfinite covers both NaN and inf, preventing RheoData's isfinite
    # validation from raising a confusing ValueError on instrument artefacts.
    if is_complex:
        finite_mask = (
            np.isfinite(x_data) & np.isfinite(y_data.real) & np.isfinite(y_data.imag)
        )
    else:
        finite_mask = np.isfinite(x_data) & np.isfinite(y_data)
    valid_idx = np.flatnonzero(finite_mask)

    n_dropped = len(x_data) - len(valid_idx)
    if n_dropped > 0:
        logger.warning(
            "Dropped non-finite (NaN/Inf) rows during loading",
            n_dropped=n_dropped,
            n_total=len(x_data),
        )

    x_data = x_data[valid_idx]
    y_data = y_data[valid_idx, :] if y_data.ndim > 1 else y_data[valid_idx]

    if len(x_data) == 0:
        logger.error(
            "No valid data points after removing NaN values", filepath=str(filepath)
        )
        raise ValueError("No valid data points after removing NaN values")

    logger.debug("Data points after NaN removal", n_points=len(x_data))

    # --- Units / domain / test-mode detection --------------------------------
    if x_units is None:
        _, x_units = extract_unit_from_header(x_header)
    if y_units is None:
        # Use first y column header for units
        _, y_units = extract_unit_from_header(y_headers[0])

    if domain is None:
        domain = detect_domain(x_header, x_units, y_headers)
        logger.debug("Auto-detected domain", domain=domain)

    if test_mode is None:
        detected_test_mode = detect_test_mode_from_columns(
            x_header, y_headers, x_units, y_units
        )
        # G'/G'' pairs default to oscillation when headers are inconclusive
        if detected_test_mode is None and is_complex:
            detected_test_mode = "oscillation"
        logger.debug("Auto-detected test mode", test_mode=detected_test_mode)
    else:
        detected_test_mode = test_mode.lower()

    # --- Metadata ------------------------------------------------------------
    final_metadata: dict[str, Any] = {
        "source_file": filepath.name,
        "file_type": "excel",
        "sheet": sheet,
        "x_column": x_col,
        "y_column": y_cols if is_complex else y_col,
    }
    if metadata:
        final_metadata.update(metadata)

    # Protocol parameters are recorded only when explicitly supplied; they are
    # applied after the user metadata, so explicit arguments take precedence.
    protocol_fields = (
        ("temperature", temperature),
        ("gamma_0", strain_amplitude),
        ("omega", angular_frequency),
        ("sigma_applied", applied_stress),
        ("gamma_dot", shear_rate),
        ("reference_gamma_dot", reference_gamma_dot),
    )
    final_metadata.update(
        {key: value for key, value in protocol_fields if value is not None}
    )

    if intended_transform is not None:
        final_metadata["intended_transform"] = intended_transform

    # Validate transform requirements and emit warnings
    warning_messages = validate_transform(
        intended_transform,
        domain,
        final_metadata,
        detected_test_mode,
    )
    for msg in warning_messages:
        warnings.warn(msg, UserWarning, stacklevel=2)

    if deformation_mode is None:
        detected_deformation = detect_deformation_mode_from_columns(y_headers, y_units)
        if detected_deformation is not None:
            deformation_mode = detected_deformation
            logger.debug(
                "Auto-detected deformation mode", deformation_mode=deformation_mode
            )

    # Store deformation mode in metadata for BaseModel.fit() auto-detection
    if deformation_mode is not None:
        final_metadata["deformation_mode"] = deformation_mode

    logger.info(
        "File parsed",
        filepath=str(filepath),
        n_records=len(x_data),
        test_mode=detected_test_mode,
        domain=domain,
        deformation_mode=deformation_mode,
    )

    return RheoData(
        x=x_data,
        y=y_data,
        x_units=x_units,
        y_units=y_units,
        domain=domain,
        initial_test_mode=detected_test_mode,
        metadata=final_metadata,
        validate=True,
    )
def _get_column_header(df, col: str | int) -> str:
    """Return the header label for a column selector.

    A string selector is returned as-is (its presence in ``df`` is not
    checked here). An integer selector is resolved positionally against
    ``df.columns`` and the label is converted with ``str``.
    """
    return col if isinstance(col, str) else str(df.columns[col])


def _get_column_data(df, col: str | int):
    """Return the underlying values of one column.

    String selectors use label lookup (``df[col]``); integer selectors use
    positional lookup (``df.iloc[:, col]``).
    """
    column = df[col] if isinstance(col, str) else df.iloc[:, col]
    return column.values