"""Batch processing pipeline for multiple datasets.
This module provides utilities for applying the same pipeline to
multiple datasets efficiently, with parallel processing support.
Example:
>>> from rheojax.pipeline import Pipeline, BatchPipeline
>>> template = Pipeline().fit('maxwell').plot()
>>> batch = BatchPipeline(template)
>>> batch.process_directory('data/', pattern='*.csv')
>>> batch.export_summary('summary.xlsx')
"""
from __future__ import annotations
import copy
import warnings
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from rheojax.core.data import RheoData
from rheojax.logging import get_logger, log_fit, log_pipeline_stage
from rheojax.pipeline.base import Pipeline
logger = get_logger(__name__)
[docs]
class BatchPipeline:
"""Apply pipeline to multiple datasets.
This class enables batch processing of multiple data files with
the same pipeline configuration, collecting results for analysis.
Attributes:
template_pipeline: Template Pipeline to apply to each dataset
results: List of (file_path, result, metrics) tuples
Example:
>>> template = Pipeline().fit('maxwell')
>>> batch = BatchPipeline(template)
>>> batch.process_files(['data1.csv', 'data2.csv'])
"""
[docs]
def __init__(self, template_pipeline: Pipeline | None = None):
"""Initialize batch pipeline.
Args:
template_pipeline: Template Pipeline to clone for each file.
If None, must be set before processing.
"""
self.template_pipeline = template_pipeline
self.results: list[tuple[Path, RheoData, dict[str, Any]]] = []
self.errors: list[tuple[Path, Exception]] = []
logger.debug(
"BatchPipeline initialized",
has_template=template_pipeline is not None,
)
[docs]
def set_template(self, pipeline: Pipeline) -> BatchPipeline:
"""Set template pipeline.
Args:
pipeline: Pipeline to use as template
Returns:
self for method chaining
"""
self.template_pipeline = pipeline
logger.debug("Template pipeline set", pipeline_type=type(pipeline).__name__)
return self
[docs]
def process_files(
self,
file_paths: Iterable[str | Path],
format: str = "auto",
parallel: bool = False,
parallel_io: bool = True,
n_workers: int | None = None,
**load_kwargs,
) -> BatchPipeline:
"""Process multiple files with the pipeline.
Args:
file_paths: List of file paths to process
format: File format for loading
parallel: Whether to use parallel processing for the full pipeline.
Default False: JAX JIT cache is not thread-safe with concurrent
ThreadPoolExecutor. Set True only for I/O-bound pipelines without
JAX JIT calls (e.g., loading + simple numpy transforms).
parallel_io: Whether to load files in parallel using threads.
Default True: file I/O is thread-safe and benefits from
parallelism. Loading phase runs in threads, pipeline replay
runs sequentially.
n_workers: Number of parallel workers (default: min(4, cpu_count))
**load_kwargs: Additional arguments for data loading
Returns:
self for method chaining
Note:
During replay, protocol-specific kwargs (gamma_dot, sigma_init,
lam_init, sigma_0, lam_0, gamma_0, omega_laos, n_cycles,
points_per_cycle) are stripped from the template's fit kwargs
because they are data-dependent and should not be reused
across datasets. DMTA kwargs (deformation_mode, poisson_ratio)
and solver settings (method) are preserved.
Example:
>>> batch.process_files(['data1.csv', 'data2.csv'])
>>> # Parallel mode (use with caution — JAX JIT not thread-safe):
>>> batch.process_files(['data1.csv', 'data2.csv'], parallel=True)
"""
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
if self.template_pipeline is None:
logger.error("No template pipeline set")
raise ValueError("No template pipeline set. Call set_template() first.")
normalized_paths = [Path(p) for p in file_paths]
if not normalized_paths:
logger.debug("No files to process")
return self
logger.info(
"Starting batch processing",
n_files=len(normalized_paths),
parallel=parallel,
n_workers=n_workers if parallel else 1,
)
if parallel:
if parallel_io:
logger.debug(
"parallel_io is ignored when parallel=True "
"(full pipeline runs in threads, including I/O)"
)
import warnings as _batch_warnings
has_fit_steps = any(
step_action in ("fit", "fit_nlsq")
for step_action, _ in self.template_pipeline.steps
)
if has_fit_steps:
_batch_warnings.warn(
"parallel=True with a fitting pipeline may cause JAX JIT compilation "
"races between threads. Set parallel=False for pipelines that call "
"model.fit().",
UserWarning,
stacklevel=2,
)
# Parallel processing with ThreadPoolExecutor
if n_workers is None:
n_workers = min(4, os.cpu_count() or 1)
def process_one(file_path):
try:
logger.debug("Processing file", filepath=str(file_path))
result, metrics = self._process_file(
file_path, format=format, **load_kwargs
)
logger.debug(
"File processed successfully",
filepath=str(file_path),
n_points=len(result.x) if result else 0,
)
return (file_path, result, metrics, None)
except Exception as e:
logger.error(
"Failed to process file",
filepath=str(file_path),
error_type=type(e).__name__,
error_message=str(e),
exc_info=True,
)
return (file_path, None, None, e)
# NOTE: This uses concurrent.futures.ThreadPoolExecutor (not Qt threads).
# Designed for headless/pipeline use only. If called from the GUI,
# the calling thread blocks at as_completed(). Use WorkerPool for
# GUI integration.
with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = {
executor.submit(process_one, fp): fp for fp in normalized_paths
}
for future in as_completed(futures):
file_path, result, metrics, error = future.result()
if error is None:
self.results.append((file_path, result, metrics))
else:
self.errors.append((file_path, error))
warnings.warn(
f"Failed to process {file_path}: {error}", stacklevel=2
)
else:
# Phase 1: Optionally pre-load files in parallel (I/O only, thread-safe)
preloaded: dict[Path, RheoData] = {}
if parallel_io and len(normalized_paths) > 1:
io_workers = n_workers or min(len(normalized_paths), 8)
preloaded = self._parallel_preload(
normalized_paths,
format=format,
n_workers=io_workers,
**load_kwargs,
)
# Phase 2: Sequential pipeline replay (JAX-safe)
for file_path in normalized_paths:
try:
logger.debug("Processing file", filepath=str(file_path))
result, metrics = self._process_file(
file_path,
format=format,
preloaded_data=preloaded.get(file_path),
**load_kwargs,
)
self.results.append((file_path, result, metrics))
logger.debug(
"File processed successfully",
filepath=str(file_path),
n_points=len(result.x) if result else 0,
)
except Exception as e:
self.errors.append((file_path, e))
logger.error(
"Failed to process file",
filepath=str(file_path),
error_type=type(e).__name__,
error_message=str(e),
exc_info=True,
)
warnings.warn(f"Failed to process {file_path}: {e}", stacklevel=2)
logger.info(
"Batch processing completed",
n_success=len(self.results),
n_errors=len(self.errors),
)
return self
[docs]
def process_directory(
self,
directory: str | Path,
pattern: str = "*.csv",
recursive: bool = False,
**kwargs,
) -> BatchPipeline:
"""Process all files in directory matching pattern.
Args:
directory: Directory path
pattern: File pattern (e.g., '*.csv', '*.xlsx')
recursive: Whether to search recursively
**kwargs: Additional arguments passed to process_files
Returns:
self for method chaining
Example:
>>> batch.process_directory('data/', pattern='*.csv')
"""
directory_path = Path(directory)
logger.debug(
"Scanning directory",
directory=str(directory_path),
pattern=pattern,
recursive=recursive,
)
if not directory_path.exists():
logger.error("Directory not found", directory=str(directory))
raise FileNotFoundError(f"Directory not found: {directory}")
if recursive:
file_paths = list(directory_path.rglob(pattern))
else:
file_paths = list(directory_path.glob(pattern))
logger.debug(
"Directory scan completed",
directory=str(directory_path),
n_files_found=len(file_paths),
)
if not file_paths:
logger.warning(
"No files matching pattern found",
directory=str(directory),
pattern=pattern,
)
warnings.warn(
f"No files matching '{pattern}' found in {directory}", stacklevel=2
)
return self
return self.process_files(file_paths, **kwargs)
def _parallel_preload(
self,
file_paths: list[Path],
format: str = "auto",
n_workers: int = 8,
**load_kwargs,
) -> dict[Path, RheoData]:
"""Pre-load files in parallel using threads (I/O only, thread-safe).
Returns a dict mapping file_path -> RheoData for successfully loaded files.
Failures are logged but do not raise (handled later in _process_file).
"""
from concurrent.futures import ThreadPoolExecutor, as_completed
from rheojax.io import auto_load
loaded: dict[Path, RheoData] = {}
def _load_one(
fp: Path,
) -> tuple[Path, RheoData | list[RheoData] | None, Exception | None]:
try:
data = auto_load(fp, format=format, **load_kwargs)
return (fp, data, None)
except Exception as e:
return (fp, None, e)
with ThreadPoolExecutor(max_workers=n_workers) as executor:
futures = {executor.submit(_load_one, fp): fp for fp in file_paths}
for future in as_completed(futures):
fp, data, err = future.result()
if err is None and data is not None:
loaded[fp] = data # type: ignore[assignment]
elif err is not None:
logger.debug(
"Parallel preload failed for file",
filepath=str(fp),
error=str(err),
)
logger.debug(
"Parallel preload completed",
n_loaded=len(loaded),
n_total=len(file_paths),
)
return loaded
def _process_file(
self,
file_path: Path,
format: str = "auto",
preloaded_data: RheoData | None = None,
**load_kwargs,
) -> tuple[RheoData, dict[str, Any]]:
"""Process single file with pipeline.
Args:
file_path: Path to file
format: File format
preloaded_data: Pre-loaded RheoData (skips I/O if provided)
**load_kwargs: Additional load arguments
Returns:
Tuple of (result_data, metrics)
"""
# Clone template pipeline
pipeline = self._clone_pipeline(self.template_pipeline)
path = Path(file_path)
# R11-BATCH-001: Clear template-copied steps before load+replay to avoid
# duplicates. Reset before load() so the load step itself is not mixed
# with stale template steps.
pipeline.steps = []
pipeline._last_model = None
# Load data — use preloaded_data if available (from parallel I/O phase)
if preloaded_data is not None:
pipeline.data = preloaded_data
else:
with log_pipeline_stage(logger, "load", filepath=str(path)):
pipeline.load(path, format=format, **load_kwargs)
# R12-E-006: pre-initialize metrics so transform replay errors can be
# recorded inside the loop below before fit metrics are appended.
metrics: dict[str, Any] = {}
# R10-BATCH-001: Replay template steps on the newly loaded data.
# Steps are recorded as ("fit", model_obj) or ("transform", transform_obj)
# tuples. For each step we create a fresh model/transform of the same class
# and re-fit/re-transform on the new dataset, preserving fit kwargs that were
# stored in _last_fit_kwargs by the model itself.
fit_kwargs_replay: dict[str, Any] = {}
for step_action, step_obj in self.template_pipeline.steps:
if step_action in ("fit", "fit_nlsq"):
model_cls = type(step_obj)
new_model = model_cls()
X = np.asarray(pipeline.data.x)
y = np.asarray(pipeline.data.y)
_lfk = getattr(step_obj, "_last_fit_kwargs", None)
fit_kwargs_replay = dict(_lfk) if _lfk is not None else {}
# Strip internal tracking keys and protocol-specific kwargs
# that should not be replayed from the template to new datasets.
_batch_strip_keys = {
# NOTE: "method" is intentionally NOT stripped — ODE models
# that require method="scipy" must preserve this in replay.
"gamma_dot",
"sigma_init",
"lam_init",
"sigma_0",
"lam_0",
"gamma_0",
"omega_laos",
"n_cycles",
"points_per_cycle",
}
for _k in _batch_strip_keys:
fit_kwargs_replay.pop(_k, None)
# R12-E-003: forward deformation_mode and poisson_ratio from
# the template model so DMTA fits are replayed correctly.
_deformation_mode = getattr(step_obj, "_deformation_mode", None)
if _deformation_mode is not None:
fit_kwargs_replay.setdefault("deformation_mode", _deformation_mode)
_poisson_ratio = getattr(step_obj, "_poisson_ratio", None)
if _poisson_ratio is not None:
fit_kwargs_replay.setdefault("poisson_ratio", _poisson_ratio)
new_model.fit(X, y, **fit_kwargs_replay)
pipeline._last_model = new_model
pipeline.steps.append((step_action, new_model))
logger.debug(
"Replayed fit step",
model=model_cls.__name__,
filepath=str(path),
)
elif step_action == "transform":
# Re-apply the transform to the pipeline's current data.
# SYS-08: use shallow copy for stateless transforms to avoid
# cloning large internal buffers. Stateless transforms expose a
# `stateless` class attribute (or instance attribute) set to True.
# Transforms with fitted state (e.g. Mastercurve shift_factors)
# keep deepcopy to preserve template params across datasets.
try:
transform_cls = type(step_obj)
if getattr(step_obj, "stateless", False):
new_transform = copy.copy(step_obj)
else:
new_transform = copy.deepcopy(step_obj)
transform_result = new_transform.transform(pipeline.data)
# Handle transforms that return (data, extra) tuples
if isinstance(transform_result, tuple):
pipeline.data = transform_result[0]
else:
pipeline.data = transform_result
# Propagate test_mode from data metadata into replay kwargs
# so that a subsequent fit step picks it up correctly.
if pipeline.data is not None and hasattr(pipeline.data, "metadata"):
_tm = (pipeline.data.metadata or {}).get("test_mode")
if _tm is not None and "test_mode" not in fit_kwargs_replay:
fit_kwargs_replay["test_mode"] = _tm
pipeline.steps.append((step_action, new_transform))
logger.debug(
"Replayed transform step",
transform=transform_cls.__name__,
filepath=str(path),
)
except Exception as _te:
# R12-E-006: elevate to ERROR — downstream fit uses unprocessed data
logger.error(
"Transform replay failed; skipping — downstream fit uses unprocessed data",
transform=type(step_obj).__name__,
error=str(_te),
)
metrics["transform_replay_failed"] = True
elif step_action == "fit_bayesian":
# Replay Bayesian inference on the newly fitted model.
if pipeline._last_model is None:
logger.warning(
"Skipping fit_bayesian step — no prior fit available",
filepath=str(path),
)
continue
try:
X = np.asarray(pipeline.data.x)
if pipeline.data.y is None:
raise ValueError(
f"Cannot replay fit_bayesian: data.y is None for {path}"
)
y = np.asarray(pipeline.data.y)
_bayes_kwargs: dict[str, Any] = {}
# Forward test_mode from fit replay
if "test_mode" in fit_kwargs_replay:
_bayes_kwargs["test_mode"] = fit_kwargs_replay["test_mode"]
# Forward deformation_mode/poisson_ratio
_dm = fit_kwargs_replay.get("deformation_mode")
if _dm is not None:
_bayes_kwargs["deformation_mode"] = _dm
_pr = fit_kwargs_replay.get("poisson_ratio")
if _pr is not None:
_bayes_kwargs["poisson_ratio"] = _pr
# Carry Bayesian sampling kwargs from the template model.
# These are stored by Pipeline.fit_bayesian() on the model
# as _last_bayesian_kwargs (separate from _last_fit_kwargs
# which only holds protocol kwargs from NLSQ).
_template_bayes = getattr(step_obj, "_last_bayesian_kwargs", None)
if _template_bayes is not None:
for _bk in (
"num_warmup",
"num_samples",
"num_chains",
"seed",
"target_accept_prob",
):
if _bk in _template_bayes:
_bayes_kwargs.setdefault(_bk, _template_bayes[_bk])
bayes_result = pipeline._last_model.fit_bayesian(
X, y, **_bayes_kwargs
)
pipeline._last_bayesian_result = bayes_result
pipeline.steps.append((step_action, pipeline._last_model))
metrics["bayesian_completed"] = True
logger.debug(
"Replayed fit_bayesian step",
model=type(pipeline._last_model).__name__,
filepath=str(path),
)
except Exception as _be:
logger.error(
"Bayesian replay failed; skipping",
model=type(pipeline._last_model).__name__,
error=str(_be),
)
metrics["bayesian_replay_failed"] = True
elif step_action == "export":
# Replay export step for each processed file.
try:
export_config = step_obj if isinstance(step_obj, dict) else {}
_out_path = export_config.get("output_path", "")
_fmt = export_config.get("format", "directory")
per_file_out = None
if _out_path:
# Create per-file output subdirectory to avoid collisions
per_file_out = Path(_out_path) / path.stem
pipeline.export(
str(per_file_out),
format=_fmt,
)
metrics["export_path"] = str(per_file_out)
logger.debug(
"Replayed export step",
filepath=str(path),
output=str(per_file_out) if _out_path else "(no output path)",
)
except Exception as _ee:
logger.error(
"Export replay failed; skipping",
error=str(_ee),
)
metrics["export_replay_failed"] = True
else:
logger.warning(
"Unknown step action in batch replay; skipping",
step_action=step_action,
filepath=str(path),
)
metrics.setdefault("unknown_steps_skipped", []).append(step_action)
result = pipeline.get_result()
# Compute metrics if model was fitted
if pipeline._last_model is not None:
model = pipeline._last_model
X = np.asarray(result.x)
y = np.asarray(result.y)
with log_fit(
logger,
model=model.__class__.__name__,
data_shape=X.shape,
) as ctx:
metrics["r_squared"] = model.score(X, y)
metrics["parameters"] = model.get_params()
metrics["model"] = model.__class__.__name__
# Calculate RMSE
# R8-PIPE-005: handle complex oscillation data in RMSE
y_pred = model.predict(X)
residuals = np.asarray(y) - np.asarray(y_pred)
metrics["rmse"] = float(np.sqrt(np.mean(np.abs(residuals) ** 2)))
ctx["r_squared"] = metrics["r_squared"]
ctx["rmse"] = metrics["rmse"]
return result, metrics
def _clone_pipeline(self, pipeline: Pipeline) -> Pipeline:
"""Clone pipeline for independent execution.
SYS-07: Lightweight structural clone — calls Pipeline.__init__ (so new
fields added to __init__ are automatically included) then copies only
the history list from the template. _process_file immediately resets
pipeline.steps, pipeline._last_model, and pipeline.data, so only a
valid fresh Pipeline instance with inherited history is needed.
Args:
pipeline: Pipeline to clone
Returns:
New Pipeline instance with a clean state (no data, no model)
"""
clone = Pipeline() # uses __init__ defaults — future-proof
clone.history = list(pipeline.history)
return clone
[docs]
def get_results(self) -> list[tuple[Path, RheoData, dict[str, Any]]]:
"""Get all processing results.
Returns:
List of (file_path, result_data, metrics) tuples
Example:
>>> results = batch.get_results()
>>> for path, data, metrics in results:
... print(f"{path}: R²={metrics.get('r_squared', 0):.4f}")
"""
return self.results.copy()
[docs]
def get_errors(self) -> list[tuple[Path, Exception]]:
"""Get processing errors.
Returns:
List of (file_path, exception) tuples
Example:
>>> errors = batch.get_errors()
>>> for path, error in errors:
... print(f"Error in {path}: {error}")
"""
return self.errors.copy()
[docs]
def get_summary_dataframe(self) -> pd.DataFrame:
"""Get summary DataFrame of all results.
Returns:
DataFrame with file paths and metrics
Example:
>>> df = batch.get_summary_dataframe()
>>> print(df)
"""
if not self.results:
return pd.DataFrame()
summary_data: list[dict[str, Any]] = []
for file_path, result, metrics in self.results:
path_obj = Path(file_path)
row = {
"file_path": str(path_obj),
"file_name": path_obj.name,
"n_points": len(result.x) if result.x is not None else 0,
}
row.update(metrics)
summary_data.append(row)
return pd.DataFrame(summary_data)
[docs]
def export_summary(
self, output_path: str | Path, format: str = "excel"
) -> BatchPipeline:
"""Export summary of batch results.
Args:
output_path: Output file path
format: Output format ('excel', 'csv')
Returns:
self for method chaining
Example:
>>> batch.export_summary('summary.xlsx')
"""
df = self.get_summary_dataframe()
if df.empty:
logger.warning("No results to export")
warnings.warn("No results to export", stacklevel=2)
return self
output_path = Path(output_path)
logger.info(
"Exporting batch summary",
output_path=str(output_path),
format=format,
n_results=len(df),
)
if format == "excel":
df.to_excel(output_path, index=False)
elif format == "csv":
df.to_csv(output_path, index=False)
else:
logger.error("Unknown export format", format=format)
raise ValueError(f"Unknown format: {format}")
logger.debug("Export completed", output_path=str(output_path))
return self
[docs]
def apply_filter(
self, filter_fn: Callable[[Path, RheoData, dict[str, Any]], bool]
) -> BatchPipeline:
"""Filter results based on custom criteria.
Args:
filter_fn: Function that takes (file_path, data, metrics) and
returns True to keep the result
Returns:
self for method chaining
Example:
>>> # Keep only results with R² > 0.9
>>> batch.apply_filter(lambda p, d, m: m.get('r_squared', 0) > 0.9)
"""
original_count = len(self.results)
self.results = [
(path, data, metrics)
for path, data, metrics in self.results
if filter_fn(path, data, metrics)
]
logger.debug(
"Filter applied",
original_count=original_count,
filtered_count=len(self.results),
removed_count=original_count - len(self.results),
)
return self
[docs]
def get_statistics(self) -> dict[str, Any]:
"""Get statistics across all results.
Returns:
Dictionary with summary statistics
Example:
>>> stats = batch.get_statistics()
>>> print(f"Mean R²: {stats['mean_r_squared']:.4f}")
"""
if not self.results:
return {}
# Collect metrics
r_squared_values = []
rmse_values = []
for _, _, metrics in self.results:
if "r_squared" in metrics:
r_squared_values.append(metrics["r_squared"])
if "rmse" in metrics:
rmse_values.append(metrics["rmse"])
stats = {
"total_files": len(self.results),
"total_errors": len(self.errors),
"success_rate": (
len(self.results) / (len(self.results) + len(self.errors))
if (len(self.results) + len(self.errors)) > 0
else 0
),
}
if r_squared_values:
stats.update(
{
"mean_r_squared": float(np.mean(r_squared_values)),
"std_r_squared": float(np.std(r_squared_values)),
"min_r_squared": float(np.min(r_squared_values)),
"max_r_squared": float(np.max(r_squared_values)),
}
)
if rmse_values:
stats.update(
{
"mean_rmse": float(np.mean(rmse_values)),
"std_rmse": float(np.std(rmse_values)),
"min_rmse": float(np.min(rmse_values)),
"max_rmse": float(np.max(rmse_values)),
}
)
return stats
[docs]
def clear(self) -> BatchPipeline:
"""Clear all results and errors.
Returns:
self for method chaining
"""
n_results = len(self.results)
n_errors = len(self.errors)
self.results.clear()
self.errors.clear()
logger.debug(
"BatchPipeline cleared",
cleared_results=n_results,
cleared_errors=n_errors,
)
return self
[docs]
def __len__(self) -> int:
"""Get number of processed results."""
return len(self.results)
[docs]
def __repr__(self) -> str:
"""String representation."""
return (
f"BatchPipeline(results={len(self.results)}, " f"errors={len(self.errors)})"
)
__all__ = ["BatchPipeline"]