Source code for rheojax.io.writers.excel_writer

"""Excel writer for rheological data and results."""

from __future__ import annotations

import os
import re
import tempfile
from pathlib import Path
from typing import Any

import numpy as np

from rheojax.logging import get_logger, log_io

logger = get_logger(__name__)


def _to_python_scalar(value: Any) -> Any:
    """Convert JAX/numpy scalars to native Python types for Excel compatibility."""
    if isinstance(value, (int, float, str, bool)):
        return value
    if isinstance(value, np.generic):
        return value.item()
    # JAX arrays: check for .item() method (0-d arrays)
    if hasattr(value, "item") and hasattr(value, "shape"):
        try:
            if value.shape == () or value.size == 1:
                return value.item()
        except (TypeError, ValueError):
            pass
    return value


[docs] def save_excel( results: dict[str, Any], filepath: str | Path, include_plots: bool = False, **kwargs ) -> None: """Save results to Excel file for reporting. Creates an Excel workbook with multiple sheets for different result types: - Parameters sheet: Model parameters and values - Fit Quality sheet: R², RMSE, and other metrics - Predictions sheet: Model predictions - Residuals sheet: Fitting residuals Args: results: Dictionary containing results - 'parameters': dict of parameter names and values - 'fit_quality': dict of fit metrics (R2, RMSE, etc.) - 'predictions': array of model predictions (optional) - 'residuals': array of residuals (optional) filepath: Output file path (.xlsx) include_plots: Include embedded plots (requires matplotlib) **kwargs: Additional arguments Raises: ImportError: If pandas or openpyxl not installed ValueError: If results format is invalid IOError: If file cannot be written """ try: import pandas as pd except ImportError as exc: logger.error( "pandas import failed", error_type="ImportError", suggestion="pip install pandas openpyxl", exc_info=True, ) raise ImportError( "pandas is required for Excel writing. Install with: pip install pandas openpyxl" ) from exc filepath = Path(filepath) filepath.parent.mkdir(parents=True, exist_ok=True) with log_io(logger, "write", filepath=str(filepath)) as ctx: sheets_written = [] if not results: tmp_fd, tmp_path = tempfile.mkstemp( dir=str(filepath.parent), suffix=".tmp.xlsx" ) try: os.close(tmp_fd) except OSError: os.unlink(tmp_path) raise try: with pd.ExcelWriter(tmp_path, engine="openpyxl") as writer: pd.DataFrame({"Info": ["No results to export"]}).to_excel( writer, sheet_name="Empty", index=False ) os.replace(tmp_path, str(filepath)) tmp_path = None # type: ignore[assignment] # prevent cleanup finally: if tmp_path is not None: try: os.unlink(tmp_path) except OSError: pass ctx["sheets_written"] = ["Empty"] ctx["num_sheets"] = 1 ctx["include_plots"] = include_plots return # Create Excel writer with atomic write (temp file + replace) tmp_fd, tmp_path = tempfile.mkstemp( dir=str(filepath.parent), suffix=".tmp.xlsx" ) try: os.close(tmp_fd) except OSError: os.unlink(tmp_path) raise try: with pd.ExcelWriter(tmp_path, engine="openpyxl") as writer: # Write parameters sheet if "parameters" in results: logger.debug( "Creating parameters dataframe", num_parameters=len(results["parameters"]), ) params_df = _create_parameters_dataframe(results["parameters"]) params_df.to_excel(writer, sheet_name="Parameters", index=False) sheets_written.append("Parameters") logger.debug("Parameters sheet written", rows=len(params_df)) # Write fit quality sheet if "fit_quality" in results: logger.debug( "Creating fit quality dataframe", num_metrics=len(results["fit_quality"]), ) quality_df = _create_quality_dataframe(results["fit_quality"]) quality_df.to_excel(writer, sheet_name="Fit Quality", index=False) sheets_written.append("Fit Quality") logger.debug("Fit Quality sheet written", rows=len(quality_df)) # Write predictions sheet if "predictions" in results: logger.debug( "Creating predictions dataframe", num_predictions=len(results["predictions"]), ) pred_df = _create_predictions_dataframe( results["predictions"], deformation_mode=results.get("deformation_mode"), ) pred_df.to_excel(writer, sheet_name="Predictions", index=False) sheets_written.append("Predictions") logger.debug("Predictions sheet written", rows=len(pred_df)) # Write residuals sheet if "residuals" in results: logger.debug( "Creating residuals dataframe", num_residuals=len(results["residuals"]), ) resid_df = _create_residuals_dataframe(results["residuals"]) resid_df.to_excel(writer, sheet_name="Residuals", index=False) sheets_written.append("Residuals") logger.debug("Residuals sheet written", rows=len(resid_df)) # Embed plots if requested if include_plots and "plots" in results: logger.debug( "Embedding plots", num_plots=len(results["plots"]), ) _embed_plots(writer, results["plots"]) sheets_written.extend( [f"Plot_{name[:25]}" for name in results["plots"].keys()] ) logger.debug( "Plots embedded", plot_names=list(results["plots"].keys()) ) os.replace(tmp_path, str(filepath)) tmp_path = None # type: ignore[assignment] # prevent cleanup finally: if tmp_path is not None: try: os.unlink(tmp_path) except OSError: pass ctx["sheets_written"] = sheets_written ctx["num_sheets"] = len(sheets_written) ctx["include_plots"] = include_plots
def _create_parameters_dataframe(parameters: dict[str, Any]) -> Any: """Create DataFrame for parameters. Args: parameters: Dictionary of parameter names and values Returns: pandas DataFrame """ import pandas as pd data = [] for name, value in parameters.items(): if isinstance(value, dict): # Parameter with additional info data.append( { "Parameter": name, "Value": _to_python_scalar(value.get("value", value)), "Units": value.get("units", ""), "Bounds": str(value.get("bounds", "")), } ) else: # Simple parameter value (may be JAX/numpy scalar) data.append( { "Parameter": name, "Value": _to_python_scalar(value), "Units": "", "Bounds": "", } ) return pd.DataFrame(data) def _create_quality_dataframe(fit_quality: dict[str, Any]) -> Any: """Create DataFrame for fit quality metrics. Args: fit_quality: Dictionary of fit quality metrics Returns: pandas DataFrame """ import pandas as pd data = [] for metric, value in fit_quality.items(): data.append( { "Metric": metric, "Value": _to_python_scalar(value), } ) return pd.DataFrame(data) def _create_predictions_dataframe( predictions: np.ndarray, deformation_mode: str | None = None, ) -> Any: """Create DataFrame for predictions. Handles complex arrays (G*=G'+iG'' or E*=E'+iE'') by splitting into separate columns. Column labels use the appropriate modulus prefix based on the deformation mode. Args: predictions: Array of predictions (real or complex) deformation_mode: Deformation mode ('tension', 'bending', 'compression', or None for shear). Controls column label prefix. Returns: pandas DataFrame """ import pandas as pd predictions = np.asarray(predictions) prefix = "E" if deformation_mode in {"tension", "bending", "compression"} else "G" if np.iscomplexobj(predictions): return pd.DataFrame( { "Index": np.arange(len(predictions)), f"{prefix}' (Storage)": np.real(predictions), f"{prefix}'' (Loss)": np.imag(predictions), } ) if predictions.ndim == 2: # GMM output returns (N, 2) real arrays [G'/E', G''/E''] col_names = [f"Component_{i}" for i in range(predictions.shape[1])] if predictions.shape[1] == 2: col_names = [f"{prefix}' (Storage)", f"{prefix}'' (Loss)"] df_dict = {"Index": np.arange(len(predictions))} for i, name in enumerate(col_names): df_dict[name] = predictions[:, i] return pd.DataFrame(df_dict) return pd.DataFrame( { "Index": np.arange(len(predictions)), "Prediction": predictions, } ) def _create_residuals_dataframe(residuals: np.ndarray) -> Any: """Create DataFrame for residuals. Handles complex arrays by splitting into real/imaginary components. Args: residuals: Array of residuals (real or complex) Returns: pandas DataFrame """ import pandas as pd residuals = np.asarray(residuals) if np.iscomplexobj(residuals): return pd.DataFrame( { "Index": np.arange(len(residuals)), "Residual (Real)": np.real(residuals), "Residual (Imag)": np.imag(residuals), } ) return pd.DataFrame( { "Index": np.arange(len(residuals)), "Residual": residuals, } ) def _embed_plots(writer: Any, plots: dict[str, Any]) -> None: """Embed plots in Excel workbook. Args: writer: ExcelWriter object plots: Dictionary of plot names and matplotlib figures Note: Requires openpyxl and matplotlib. """ import io try: from openpyxl.drawing.image import Image as XLImage except ImportError: logger.warning( "openpyxl.drawing.image not available — plots will not be embedded. " "Install openpyxl with: pip install openpyxl", reason="ImportError", ) return workbook = writer.book for plot_name, fig in plots.items(): # Create sheet for this plot sheet_name = re.sub( r"[\\/*?\[\]:]", "_", f"Plot_{plot_name[:25]}" ) # Excel sheet name limit logger.debug( "Embedding plot", plot_name=plot_name, sheet_name=sheet_name, ) if sheet_name not in workbook.sheetnames: workbook.create_sheet(sheet_name) sheet = workbook[sheet_name] # Save figure to bytes buffer buf = io.BytesIO() try: fig.savefig(buf, format="png", dpi=150, bbox_inches="tight") buf.seek(0) logger.debug( "Plot saved to buffer", plot_name=plot_name, buffer_size=buf.getbuffer().nbytes, ) # Create and add image to sheet img = XLImage(buf) img.anchor = "A1" sheet.add_image(img) finally: buf.close()