"""Base pipeline class for fluent API workflows.
This module provides the core Pipeline class that enables intuitive method
chaining for common rheological analysis workflows.
Example:
>>> from rheojax.pipeline import Pipeline
>>> pipeline = Pipeline()
>>> result = (pipeline
... .load('data.csv')
... .transform('smooth', window_size=5)
... .fit('maxwell')
... .plot()
... .save('result.hdf5')
... .get_result())
"""
from __future__ import annotations
import copy
import uuid
import warnings
from pathlib import Path
from typing import Any
import numpy as np
from rheojax.core.base import BaseModel, BaseTransform
from rheojax.core.data import RheoData
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry, TransformRegistry
from rheojax.logging import get_logger, log_pipeline_stage
# Obtain JAX through the project helper instead of importing it directly
# (the helper enforces float64).
jax, jnp = safe_import_jax()
# Module-level logger shared by every Pipeline instance in this module.
logger = get_logger(__name__)
def _is_jax_array(x: Any) -> bool:
    """Tell whether ``x`` looks like a JAX array.

    The check is duck-typed so it does not depend on a particular JAX
    version: the object must expose a ``devices`` attribute and must not be
    a NumPy ``ndarray``.
    """
    if isinstance(x, np.ndarray):
        return False
    return hasattr(x, "devices")
[docs]
class Pipeline:
"""Fluent API for rheological analysis workflows.
This class provides a chainable interface for loading data, applying
transforms, fitting models, and generating outputs. All methods return
self to enable method chaining.
Attributes:
data: Current RheoData state
steps: List of (operation, object) tuples; the object is the fitted model for fitting steps and a details dict for export steps
history: List of (operation, details) tuples tracking all operations
_last_model: Last fitted model for convenience
Example:
>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell').plot()
"""
[docs]
def __init__(self, data: RheoData | None = None):
    """Set up an empty pipeline, optionally seeded with a dataset.

    Args:
        data: RheoData to start from. When omitted, ``load()`` has to be
            called before any step that needs data.
    """
    # Short random tag used to correlate this pipeline's log records.
    self._id = str(uuid.uuid4())[:8]

    # Data plus the record of what has been executed so far.
    self.data = data
    self.steps: list[tuple[str, Any]] = []
    self.history: list[tuple[Any, ...]] = []

    # Most recent fitting artefacts.
    self._last_model: BaseModel | None = None
    self._last_fit_result: Any = None
    self._last_bayesian_result: Any = None
    self._last_comparison: Any = None

    # Transform bookkeeping.
    self._transform_results: dict[str, tuple[Any, RheoData | None]] = {}
    self._last_transform_name: str | None = None

    # Plotting artefacts.
    self._current_figure: Any = None
    self._diagnostic_results: Any = None

    has_initial_data = data is not None
    logger.debug(
        "Pipeline initialized",
        pipeline_id=self._id,
        has_initial_data=has_initial_data,
    )
[docs]
def load(
    self,
    file_path: str | Path,
    format: str = "auto",
    *,
    test_mode: str | None = None,
    initial_test_mode: str | None = None,
    **kwargs,
) -> Pipeline:
    """Load data from file.

    Args:
        file_path: Path to data file
        format: File format ('auto', 'csv', 'excel', 'trios', 'hdf5', 'npz')
        test_mode: Optional rheological mode metadata to attach to the
            resulting RheoData (e.g., 'relaxation', 'creep', 'oscillation')
        initial_test_mode: Backwards-compatible alias for test_mode; only
            used when test_mode is None
        **kwargs: Additional arguments passed to reader (the 'npz' reader
            is called with the path only)

    Returns:
        self for method chaining

    Raises:
        ValueError: If the format is not recognized, or the reader returns
            an empty list of segments.
        Exception: Any error raised by the reader (e.g. for a missing
            file) is logged and re-raised unchanged.

    Example:
        >>> pipeline = Pipeline().load('data.csv', x_col='time', y_col='stress')
    """
    from rheojax.io import auto_load

    path = Path(file_path)
    explicit_mode = test_mode if test_mode is not None else initial_test_mode
    with log_pipeline_stage(
        logger, "load", pipeline_id=self._id, file_path=str(path), format=format
    ) as ctx:
        try:
            # Readers are imported lazily so only the requested backend
            # is loaded.
            if format == "auto":
                result = auto_load(path, **kwargs)
            elif format == "csv":
                from rheojax.io import load_csv

                result = load_csv(path, **kwargs)
            elif format == "excel":
                from rheojax.io import load_excel

                result = load_excel(path, **kwargs)
            elif format == "trios":
                from rheojax.io import load_trios

                result = load_trios(path, **kwargs)
            elif format == "hdf5":
                from rheojax.io import load_hdf5

                result = load_hdf5(path, **kwargs)
            elif format == "npz":
                from rheojax.io.writers.npz_writer import load_npz

                result = load_npz(path)
            else:
                raise ValueError(f"Unknown format: {format}")

            # Handle multiple segments (for TRIOS)
            if isinstance(result, list):
                if not result:
                    # Fail with a clear message instead of an IndexError
                    # from result[0].
                    raise ValueError(f"No data segments found in {path}")
                if len(result) > 1:
                    warnings.warn(
                        f"Loaded {len(result)} segments. Using first segment.",
                        stacklevel=2,
                    )
                self.data = result[0]
                ctx["n_segments"] = len(result)
            else:
                self.data = result

            self._apply_test_mode_metadata(self.data, explicit_mode)
            # Identity check: never rely on the truthiness of a data object.
            ctx["n_points"] = len(self.data.x) if self.data is not None else 0
            ctx["test_mode"] = explicit_mode
        except Exception as e:
            logger.error(
                "Failed to load data",
                pipeline_id=self._id,
                file_path=str(path),
                format=format,
                error=str(e),
                exc_info=True,
            )
            raise
    self.history.append(("load", str(path), format))
    return self
def _apply_test_mode_metadata(
    self, data: RheoData | None, mode: str | None
) -> None:
    """Record an explicitly requested test mode on freshly loaded data.

    Does nothing when either ``data`` or ``mode`` is None. Otherwise the
    mode is written to ``metadata["test_mode"]`` and used as the value of
    ``metadata["detected_test_mode"]`` if that key is not set yet.
    """
    if data is not None and mode is not None:
        if data.metadata is None:
            data.metadata = {}
        meta = data.metadata
        meta["test_mode"] = mode
        if "detected_test_mode" not in meta:
            meta["detected_test_mode"] = mode
        # Persist explicit annotation for downstream helpers that rely on
        # it, but only on objects that already carry the attribute.
        if hasattr(data, "_explicit_test_mode"):
            data._explicit_test_mode = mode
[docs]
def fit(
    self,
    model: str | BaseModel,
    method: str = "auto",
    **fit_kwargs,
) -> Pipeline:
    """Fit a model to the currently loaded data.

    Args:
        model: Registry name of the model, or an already constructed
            model instance.
        method: Optimization method forwarded to ``model.fit()``
            ('nlsq', 'scipy', 'auto'). The default 'auto' lets the model
            choose.
        **fit_kwargs: Extra keyword arguments forwarded to ``model.fit()``.
            ``test_mode``, ``deformation_mode`` and ``poisson_ratio`` are
            taken from the data metadata unless given explicitly.

    Returns:
        self for method chaining

    Raises:
        ValueError: If no data has been loaded.

    Example:
        >>> pipeline.fit('maxwell')
        >>> # or with instance
        >>> from rheojax.models.linear import Maxwell
        >>> pipeline.fit(Maxwell())
    """
    if self.data is None:
        raise ValueError("No data loaded. Call load() first.")

    # Resolve the model: names go through the registry, instances are
    # used as they are.
    if isinstance(model, str):
        model_obj = ModelRegistry.create(model)
        model_name = model
    else:
        model_obj = model
        model_name = model_obj.__class__.__name__
    logger.debug(
        "Creating model for fitting",
        pipeline_id=self._id,
        model=model_name,
    )

    # Fitting runs on NumPy arrays; JAX arrays are converted first.
    X = self.data.x
    y = self.data.y
    X = np.asarray(X) if _is_jax_array(X) else X
    y = np.asarray(y) if _is_jax_array(y) else y

    with log_pipeline_stage(
        logger,
        "fit",
        pipeline_id=self._id,
        model=model_name,
        data_shape=X.shape,  # type: ignore[union-attr]
    ) as ctx:
        try:
            # PB-001 / R9-PIPE-DMT: inherit protocol settings from the
            # data metadata unless the caller supplied them explicitly.
            dataset = getattr(self, "data", None)
            metadata = (
                getattr(dataset, "metadata", None) if dataset is not None else None
            )
            if metadata is not None:
                for key in ("test_mode", "deformation_mode", "poisson_ratio"):
                    if key in fit_kwargs:
                        continue
                    inherited = metadata.get(key)
                    if inherited is not None:
                        fit_kwargs[key] = inherited

            # R12-E-001: forward method kwarg to model.fit()
            fit_kwargs["method"] = method
            model_obj.fit(X, y, **fit_kwargs)

            self._last_model = model_obj
            self._last_fit_result = None  # Lazily built by get_fit_result()
            self.steps.append(("fit", model_obj))

            # Scoring is best effort: a failure is recorded as NaN rather
            # than failing the fit.
            try:
                score = model_obj.score(X, y)
            except Exception:
                score = float("nan")
            ctx["r_squared"] = score
            self.history.append(("fit", model_name, score))
        except Exception as e:
            logger.error(
                "Model fitting failed",
                pipeline_id=self._id,
                model=model_name,
                error=str(e),
                exc_info=True,
            )
            raise
    return self
[docs]
def predict(
    self, model: BaseModel | None = None, X: np.ndarray | None = None
) -> RheoData:
    """Generate predictions from fitted model.

    Args:
        model: Model to use for prediction. If None, uses last fitted model.
        X: Input data for prediction. If None, uses current data.x.

    Returns:
        RheoData with predictions. Units, domain and metadata are taken
        from the current data when present; the metadata additionally
        carries ``type="prediction"`` and the model class name.

    Raises:
        ValueError: If no model is available, or X is None and no data
            is loaded.

    Example:
        >>> predictions = pipeline.predict()
    """
    if model is None:
        model = self._last_model
    if model is None:
        raise ValueError("No model fitted. Call fit() first.")
    if X is None:
        if self.data is None:
            raise ValueError("No data available for prediction.")
        X = self.data.x
    # Convert to numpy for prediction — np.asarray is zero-copy for CPU arrays
    if _is_jax_array(X):
        X = np.asarray(X)
    logger.debug(
        "Generating predictions",
        pipeline_id=self._id,
        model=model.__class__.__name__,
        n_points=len(X),
    )
    predictions = model.predict(X)

    # Use identity checks, not truthiness: a data object that evaluates
    # falsy must still contribute its units, domain and metadata.
    source = self.data
    if source is not None:
        x_units = source.x_units
        y_units = source.y_units
        domain = source.domain
        base_metadata = source.metadata if source.metadata is not None else {}
    else:
        x_units = None
        y_units = None
        domain = "time"
        base_metadata = {}

    return RheoData(
        x=X,
        y=predictions,
        x_units=x_units,
        y_units=y_units,
        domain=domain,
        metadata={
            **base_metadata,
            "type": "prediction",
            "model": model.__class__.__name__,
        },
        validate=False,
    )
[docs]
def plot(
    self,
    show: bool = True,
    style: str = "default",
    include_prediction: bool = False,
    **plot_kwargs,
) -> Pipeline:
    """Plot current data state.

    Args:
        show: Whether to call plt.show()
        style: Plot style ('default', 'publication', 'presentation')
        include_prediction: If True and a model has been fitted, overlay
            its predictions. Ignored when no model is available.
        **plot_kwargs: Additional arguments passed to plotting function

    Returns:
        self for method chaining

    Raises:
        ValueError: If no data has been loaded.

    Example:
        >>> pipeline.plot(style='publication')
    """
    if self.data is None:
        raise ValueError("No data loaded. Call load() first.")
    with log_pipeline_stage(
        logger,
        "plot",
        pipeline_id=self._id,
        style=style,
        include_prediction=include_prediction,
    ) as ctx:
        from rheojax.visualization.plotter import plot_rheo_data

        fig, ax = plot_rheo_data(self.data, style=style, **plot_kwargs)

        # Optionally overlay predictions
        if include_prediction and self._last_model is not None:
            predictions = self.predict()
            # The plotter may return one Axes or an array of them. `.flat`
            # picks the first Axes for any array dimensionality, whereas
            # `ax[0]` on a 2-D grid would return a whole row.
            ax_plot = ax.flat[0] if isinstance(ax, np.ndarray) else ax
            ax_plot.plot(
                predictions.x,
                predictions.y,
                "--",
                label="Model Prediction",
                linewidth=2,
            )
            ax_plot.legend()
            ctx["prediction_overlay"] = True

        # Store the figure for save_figure() before showing it, matching
        # plot_fit() and plot_bayesian().
        self._current_figure = fig
        if show:
            import matplotlib.pyplot as plt

            plt.show()
    self.history.append(("plot", style))
    return self
[docs]
def save(self, file_path: str | Path, format: str = "hdf5", **kwargs) -> Pipeline:
    """Write the current data to disk.

    Before writing, the parameters of the most recent 'fit' / 'fit_nlsq'
    step (if any) are copied into the data metadata as ``fitted_<name>``
    entries together with ``fitted_model``.

    Args:
        file_path: Output file path
        format: Output format ('hdf5', 'excel', 'csv')
        **kwargs: Additional arguments passed to the writer

    Returns:
        self for method chaining

    Raises:
        ValueError: If no data is loaded or the format is unknown.

    Example:
        >>> pipeline.save('output.hdf5')
    """
    if self.data is None:
        raise ValueError("No data to save. Call load() first.")
    path = Path(file_path)
    fit_tags = ("fit", "fit_nlsq")

    # R12-E-007: include fitted model parameters in data metadata before saving
    fit_entries = (
        [entry for entry in self.steps if entry[0] in fit_tags]
        if self.steps
        else []
    )
    if fit_entries:
        fitted = fit_entries[-1][1]
        if hasattr(fitted, "parameters"):
            if self.data.metadata is None:
                self.data.metadata = {}
            for pname in fitted.parameters.keys():
                try:
                    self.data.metadata[f"fitted_{pname}"] = float(
                        fitted.parameters.get_value(pname)
                    )
                except (TypeError, ValueError):
                    # Best effort: values that cannot be stored as a float
                    # are left out.
                    continue
            self.data.metadata["fitted_model"] = type(fitted).__name__

    with log_pipeline_stage(
        logger,
        "save",
        pipeline_id=self._id,
        file_path=str(path),
        format=format,
    ) as ctx:
        try:
            if format == "hdf5":
                from rheojax.io import save_hdf5

                save_hdf5(self.data, path, **kwargs)
            elif format == "excel":
                from rheojax.io import save_excel

                # R13-PIPE-XLS-001: "parameters" holds the actual model
                # parameters (name -> value); data-level labels (units,
                # domain) go into a separate "fit_quality" dict for the
                # Fit Quality sheet.
                parameters: dict[str, Any] = {}
                excel_entries = [
                    entry for entry in self.steps if entry[0] in fit_tags
                ]
                if excel_entries:
                    excel_model = excel_entries[-1][1]
                    if hasattr(excel_model, "parameters"):
                        for pname in excel_model.parameters.keys():
                            try:
                                parameters[pname] = float(
                                    excel_model.parameters.get_value(pname)
                                )
                            except (TypeError, ValueError):
                                continue
                        parameters["model"] = type(excel_model).__name__

                fit_quality: dict[str, Any] = {}
                for label in ("x_units", "y_units", "domain"):
                    label_value = getattr(self.data, label)
                    if label_value:
                        fit_quality[label] = label_value

                excel_payload: dict[str, Any] = {
                    "x": np.array(self.data.x),
                    "predictions": np.array(self.data.y),
                }
                if parameters:
                    excel_payload["parameters"] = parameters
                if fit_quality:
                    excel_payload["fit_quality"] = fit_quality
                # R13-PIPE-XLS-002: propagate deformation_mode for
                # correct E'/G' column labels in the Predictions sheet.
                if self.data.metadata:
                    mode = self.data.metadata.get("deformation_mode")
                    if mode is not None:
                        excel_payload["deformation_mode"] = mode
                save_excel(excel_payload, path, **kwargs)
            elif format == "csv":
                # R13-PIPE-CSV-001: complex y is split into real/imag
                # columns; 2D y is split into numbered columns.
                import pandas as pd

                x_arr = np.array(self.data.x)
                y_arr = np.array(self.data.y)
                columns: dict[str, Any] = {"x": x_arr}
                if np.iscomplexobj(y_arr):
                    columns["y_real"] = np.real(y_arr)
                    columns["y_imag"] = np.imag(y_arr)
                elif y_arr.ndim == 2:
                    for ci in range(y_arr.shape[1]):
                        columns[f"y_{ci}"] = y_arr[:, ci]
                else:
                    columns["y"] = y_arr
                pd.DataFrame(columns).to_csv(path, index=False, **kwargs)
            else:
                raise ValueError(f"Unknown format: {format}")
            ctx["n_points"] = len(self.data.x)
        except Exception as e:
            logger.error(
                "Failed to save data",
                pipeline_id=self._id,
                file_path=str(path),
                format=format,
                error=str(e),
                exc_info=True,
            )
            raise
    self.history.append(("save", str(path), format))
    return self
[docs]
def fit_bayesian(
    self,
    model: str | BaseModel | None = None,
    seed: int | None = None,
    **bayesian_kwargs,
) -> Pipeline:
    """Run Bayesian (NUTS) inference on current data.

    Uses the last fitted model (or a new one) with NLSQ warm-start.

    Args:
        model: Model name, instance, or None to reuse last fitted model.
        seed: Random seed for reproducibility. Forwarded only when not
            None; otherwise the model's own default applies.
        **bayesian_kwargs: Arguments forwarded to model.fit_bayesian()
            (num_warmup, num_samples, num_chains, target_accept_prob, etc.).
            The pipeline-level ``warm_start`` flag (default True) is
            consumed here and never forwarded. When it is False,
            ``initial_values`` defaults to finite values derived from the
            parameter bounds. ``test_mode``, ``deformation_mode`` and
            ``poisson_ratio`` are taken from the data metadata unless
            given explicitly.

    Returns:
        self for method chaining

    Raises:
        ValueError: If no data is loaded, or no model is given and none
            has been fitted.

    Example:
        >>> pipeline.fit('maxwell').fit_bayesian(seed=42, num_warmup=1000)
    """
    if self.data is None:
        raise ValueError("No data loaded. Call load() first.")

    # Resolve model
    if model is not None:
        model_obj = ModelRegistry.create(model) if isinstance(model, str) else model
    elif self._last_model is not None:
        model_obj = self._last_model
    else:
        raise ValueError("No model available. Call fit() first or provide a model.")

    X = self.data.x
    y = self.data.y
    if _is_jax_array(X):
        X = np.asarray(X)
    if _is_jax_array(y):
        y = np.asarray(y)

    # PIPE-WARM-001: strip pipeline-level `warm_start` kwarg — it must not
    # be forwarded to model.fit_bayesian(), which passes **nuts_kwargs
    # straight to NUTS(). Passing warm_start=True to NUTS causes TypeError.
    #
    # When warm_start is False, build initial_values from the parameter
    # bounds so the sampler starts from prior-like values rather than NLSQ
    # estimates. We avoid model_obj.__class__() which crashes for models
    # with required constructor args (GeneralizedMaxwell(n_modes),
    # STZ(variant), etc.).
    use_warm_start = bayesian_kwargs.pop("warm_start", True)
    if not use_warm_start:
        import math  # noqa: PLC0415

        def _start_from_bounds(bounds: Any) -> float:
            """Return a finite start value lying inside ``bounds``."""
            lo, hi = bounds if bounds is not None else (None, None)
            lo_ok = lo is not None and math.isfinite(lo)
            hi_ok = hi is not None and math.isfinite(hi)
            if lo_ok and hi_ok:
                # Geometric mean for strictly positive ranges, arithmetic
                # mean otherwise.
                if lo > 0 and hi > 0:
                    return math.sqrt(lo * hi)
                return (lo + hi) / 2.0
            # Missing or infinite bound: a midpoint would be inf/nan.
            # Start from 1.0 and, if that is not strictly inside the one
            # finite bound, step away from that bound into the range.
            start = 1.0
            if lo_ok and start <= lo:
                start = lo + max(1.0, abs(lo))
            elif hi_ok and start >= hi:
                start = hi - max(1.0, abs(hi))
            return start

        midpoint_values: dict[str, float] = {
            name: _start_from_bounds(model_obj.parameters[name].bounds)
            for name in model_obj.parameters.keys()
        }
        bayesian_kwargs.setdefault("initial_values", midpoint_values)

    # Auto-propagate metadata. Only None counts as "not set": a truthiness
    # check would also drop falsy-but-valid values such as an empty string
    # or 0.0.
    _meta = getattr(self.data, "metadata", None) or {}
    for key in ("test_mode", "deformation_mode", "poisson_ratio"):
        if key not in bayesian_kwargs:
            inherited = _meta.get(key)
            if inherited is not None:
                bayesian_kwargs[key] = inherited
    if seed is not None:
        bayesian_kwargs["seed"] = seed

    with log_pipeline_stage(
        logger,
        "fit_bayesian",
        pipeline_id=self._id,
        model=model_obj.__class__.__name__,
    ) as ctx:
        try:
            result = model_obj.fit_bayesian(X, y, **bayesian_kwargs)
            self._last_bayesian_result = result
            self._last_model = model_obj
            # Store sampling kwargs on the model so BatchPipeline can
            # replay with the same configuration. _last_fit_kwargs only
            # contains protocol kwargs from NLSQ — Bayesian sampling
            # params (num_warmup, num_samples, num_chains, seed) are
            # consumed by NumPyro and never stored there.
            _sampling_keys = {
                "num_warmup",
                "num_samples",
                "num_chains",
                "seed",
                "target_accept_prob",
            }
            model_obj._last_bayesian_kwargs = {
                k: v for k, v in bayesian_kwargs.items() if k in _sampling_keys
            }
            self.steps.append(("fit_bayesian", model_obj))
            self.history.append(("fit_bayesian", model_obj.__class__.__name__))
            ctx["num_samples"] = getattr(result, "num_samples", None)
            ctx["num_chains"] = getattr(result, "num_chains", None)
        except Exception as e:
            logger.error(
                "Bayesian fitting failed",
                pipeline_id=self._id,
                error=str(e),
                exc_info=True,
            )
            raise
    return self
[docs]
def plot_fit(
    self,
    confidence: float = 0.95,
    show_residuals: bool = True,
    show_uncertainty: bool = True,
    show: bool = True,
    style: str = "default",
    **kwargs,
) -> Pipeline:
    """Draw the NLSQ fit with its uncertainty band and residuals.

    Needs a model from a previous fitting step; plotting is delegated to
    FitPlotter.plot_nlsq().

    Args:
        confidence: Confidence level for uncertainty band (default: 0.95).
        show_residuals: If True, add residuals subplot.
        show_uncertainty: If True and covariance available, show band.
        show: Whether to call plt.show() (default: True).
        style: Plot style ('default', 'publication', 'presentation').
        **kwargs: Additional arguments forwarded to FitPlotter.plot_nlsq().
            ``deformation_mode`` and ``test_mode`` are taken from the
            data metadata unless given explicitly.

    Returns:
        self for method chaining

    Raises:
        ValueError: If no model has been fitted or no data is loaded.

    Example:
        >>> pipeline.fit('maxwell').plot_fit(confidence=0.95)
    """
    if self._last_model is None:
        raise ValueError("No model fitted. Call fit() first.")
    if self.data is None:
        raise ValueError("No data loaded. Call load() first.")

    from rheojax.visualization.fit_plotter import FitPlotter

    fit_result = self.get_fit_result()
    plotter = FitPlotter()

    # np.asarray is zero-copy for CPU-backed arrays (JAX or numpy)
    x_values = np.asarray(self.data.x)
    y_values = np.asarray(self.data.y)

    # Fill in protocol settings from the metadata when not passed in.
    metadata = getattr(self.data, "metadata", None) or {}
    for key in ("deformation_mode", "test_mode"):
        if key in kwargs:
            continue
        inherited = metadata.get(key)
        if inherited is not None:
            kwargs[key] = inherited

    fig, _axes = plotter.plot_nlsq(
        x_values,
        y_values,
        fit_result,
        self._last_model,
        confidence=confidence,
        show_residuals=show_residuals,
        show_uncertainty=show_uncertainty,
        style=style,
        **kwargs,
    )
    self._current_figure = fig
    if show:
        import matplotlib.pyplot as plt

        plt.show()
    self.history.append(("plot_fit", style))
    return self
[docs]
def plot_bayesian(
    self,
    credible_level: float = 0.95,
    max_draws: int = 500,
    show_nlsq_overlay: bool = False,
    show_residuals: bool = False,
    show: bool = True,
    style: str = "default",
    **kwargs,
) -> Pipeline:
    """Draw the Bayesian posterior predictive with its credible interval.

    Needs a result from a previous fit_bayesian() call; plotting is
    delegated to FitPlotter.plot_bayesian().

    Args:
        credible_level: Credible interval level (default: 0.95).
        max_draws: Maximum posterior draws for band computation.
        show_nlsq_overlay: If True, overlay NLSQ fit for comparison.
        show_residuals: If True, add residuals subplot.
        show: Whether to call plt.show() (default: True).
        style: Plot style.
        **kwargs: Additional arguments forwarded to
            FitPlotter.plot_bayesian(). ``deformation_mode`` and
            ``test_mode`` are taken from the data metadata unless given
            explicitly.

    Returns:
        self for method chaining

    Raises:
        ValueError: If there is no Bayesian result, no model, or no data.

    Example:
        >>> pipeline.fit('maxwell').fit_bayesian(seed=42).plot_bayesian()
    """
    if self._last_bayesian_result is None:
        raise ValueError("No Bayesian result available. Call fit_bayesian() first.")
    if self._last_model is None:
        raise ValueError("No model available.")
    if self.data is None:
        raise ValueError("No data loaded.")

    from rheojax.visualization.fit_plotter import FitPlotter

    plotter = FitPlotter()

    # np.asarray is zero-copy for CPU-backed arrays (JAX or numpy)
    x_values = np.asarray(self.data.x)
    y_values = np.asarray(self.data.y)

    # Fill in protocol settings from the metadata when not passed in.
    metadata = getattr(self.data, "metadata", None) or {}
    for key in ("deformation_mode", "test_mode"):
        if key in kwargs:
            continue
        inherited = metadata.get(key)
        if inherited is not None:
            kwargs[key] = inherited

    # The NLSQ result is only built when the overlay is requested; if it
    # cannot be built the plot is drawn without it.
    fit_result = None
    if show_nlsq_overlay:
        try:
            fit_result = self.get_fit_result()
        except ValueError:
            pass

    fig, _axes = plotter.plot_bayesian(
        x_values,
        y_values,
        self._last_bayesian_result,
        self._last_model,
        credible_level=credible_level,
        max_draws=max_draws,
        show_nlsq_overlay=show_nlsq_overlay,
        fit_result=fit_result,
        show_residuals=show_residuals,
        style=style,
        **kwargs,
    )
    self._current_figure = fig
    if show:
        import matplotlib.pyplot as plt

        plt.show()
    self.history.append(("plot_bayesian", style))
    return self
[docs]
def plot_diagnostics(
    self,
    output_dir: str | Path | None = None,
    style: str = "default",
    prefix: str = "mcmc",
    formats: tuple[str, ...] = ("pdf", "png"),
    dpi: int = 300,
    **kwargs,
) -> Pipeline:
    """Produce the ArviZ MCMC diagnostic suite (6 plots).

    Needs a result from a previous fit_bayesian() call.

    Args:
        output_dir: Directory for saving plots. If None, displays only.
        style: Plot style.
        prefix: Filename prefix for saved plots.
        formats: Output formats (default: ('pdf', 'png')).
        dpi: Resolution for raster formats.
        **kwargs: Additional arguments forwarded to
            generate_diagnostic_suite().

    Returns:
        self for method chaining

    Raises:
        ValueError: If no Bayesian result is available.

    Example:
        >>> pipeline.fit_bayesian(seed=42).plot_diagnostics(output_dir='./diag')
    """
    bayes_result = self._last_bayesian_result
    if bayes_result is None:
        raise ValueError("No Bayesian result available. Call fit_bayesian() first.")

    from rheojax.visualization.fit_plotter import generate_diagnostic_suite

    suite = generate_diagnostic_suite(
        bayes_result,
        style=style,
        output_dir=output_dir,
        prefix=prefix,
        formats=formats,
        dpi=dpi,
        **kwargs,
    )
    self._diagnostic_results = suite

    # Expose the first diagnostic figure for save_figure() chaining.
    # Entries without a ``savefig`` attribute (e.g. paths) are skipped.
    if isinstance(suite, dict):
        first_figure = next(
            (item for item in suite.values() if hasattr(item, "savefig")),
            None,
        )
        if first_figure is not None:
            self._current_figure = first_figure

    self.history.append(("plot_diagnostics", str(output_dir)))
    return self
[docs]
def get_result(self) -> RheoData:
    """Return the pipeline's current data.

    Returns:
        Current RheoData

    Raises:
        ValueError: If no data is available.

    Example:
        >>> data = pipeline.get_result()
    """
    current = self.data
    if current is not None:
        return current
    raise ValueError("No data available. Call load() first.")
[docs]
def get_history(self) -> list[tuple[Any, ...]]:
    """Return a snapshot of the execution history.

    The returned list is a shallow copy, so changing it does not affect
    the pipeline's own history.

    Returns:
        List of (operation, details...) tuples

    Example:
        >>> history = pipeline.get_history()
        >>> for step in history:
        ...     print(step)
    """
    snapshot = list(self.history)
    return snapshot
[docs]
def get_last_model(self) -> BaseModel | None:
    """Return the most recently fitted model, if any.

    Returns:
        The last fitted BaseModel, or None when nothing has been fitted.

    Example:
        >>> model = pipeline.get_last_model()
        >>> params = model.get_params()
    """
    latest = self._last_model
    return latest
[docs]
def get_all_models(self) -> list[BaseModel]:
    """Collect the models recorded by 'fit' and 'fit_nlsq' steps.

    Steps with any other operation name are not included.

    Returns:
        Models in the order their steps were recorded

    Example:
        >>> models = pipeline.get_all_models()
    """
    fit_operations = ("fit", "fit_nlsq")
    collected = []
    for entry in self.steps:
        if entry[0] in fit_operations:
            collected.append(entry[1])
    return collected
[docs]
def get_fitted_parameters(self) -> dict[str, float]:
    """Return the last model's parameter values as a plain dictionary.

    Convenience wrapper that reads every entry of the last fitted model's
    ParameterSet.

    Returns:
        Dictionary mapping parameter names to their values

    Raises:
        ValueError: If no model has been fitted yet

    Example:
        >>> pipeline = Pipeline()
        >>> pipeline.load('data.csv').fit('maxwell')
        >>> params = pipeline.get_fitted_parameters()
        >>> print(params)  # {'G0': 100000.0, 'eta': 1000.0}
        >>> G0 = params['G0']
    """
    if self._last_model is None:
        raise ValueError("No model fitted. Call fit() first.")
    values = {}
    for name in self._last_model.parameters.keys():
        values[name] = self._last_model.parameters.get_value(name)
    return values
[docs]
def compare_models(
    self,
    models: list[str | BaseModel],
    criterion: str = "aic",
    **fit_kwargs,
) -> Pipeline:
    """Fit several candidate models and rank them.

    The ranking is produced by ``rheojax.utils.model_selection.compare_models``.
    When the best result carries its fitted model, that model becomes
    ``_last_model`` and is appended to ``steps``.

    Args:
        models: List of model names (strings) or BaseModel instances.
        criterion: Ranking criterion ('aic', 'aicc', 'bic').
        **fit_kwargs: Extra kwargs forwarded to each ``model.fit()`` call.
            ``test_mode``, ``deformation_mode`` and ``poisson_ratio`` are
            taken from the data metadata unless given explicitly.

    Returns:
        self for method chaining

    Raises:
        ValueError: If no data is loaded.

    Example:
        >>> pipeline.load('data.csv').compare_models(['maxwell', 'zener'])
    """
    if self.data is None:
        raise ValueError("No data loaded. Call load() first.")

    from rheojax.utils.model_selection import compare_models as _compare

    X = self.data.x
    y = self.data.y
    X = np.asarray(X) if _is_jax_array(X) else X
    y = np.asarray(y) if _is_jax_array(y) else y

    # Inherit protocol settings from the metadata. Only None counts as
    # "not set", so falsy-but-valid values are still propagated.
    metadata = getattr(self.data, "metadata", None) or {}
    for key in ("test_mode", "deformation_mode", "poisson_ratio"):
        if key in fit_kwargs:
            continue
        inherited = metadata.get(key)
        if inherited is not None:
            fit_kwargs[key] = inherited

    # test_mode is passed as its own argument, so take it out of the
    # generic kwargs.
    test_mode = fit_kwargs.pop("test_mode", None)
    comparison = _compare(
        X,
        y,
        models=models,
        test_mode=test_mode,
        criterion=criterion,
        **fit_kwargs,
    )
    self._last_comparison = comparison
    self.history.append(("compare_models", comparison.best_model, criterion))

    # Reuse the already-fitted instance attached to the best result
    # instead of fitting again.
    if comparison.results:
        best_fr = None
        for candidate in comparison.results:
            if candidate.model_name == comparison.best_model:
                best_fr = candidate
                break
        if best_fr:
            fitted_model = getattr(best_fr, "_fitted_model", None)
        else:
            fitted_model = None

        if fitted_model is None:
            logger.warning(
                "Best model FitResult has no attached fitted model",
                model=comparison.best_model,
            )
        else:
            self._last_model = fitted_model
            self.steps.append(("compare_models", fitted_model))
    return self
[docs]
def get_fit_result(self) -> Any:
    """Build a FitResult for the last fitted model.

    The result is produced by ``build_fit_result`` from the model, the
    current data (when loaded) and the ``test_mode`` found in the data
    metadata.

    Returns:
        The object returned by ``build_fit_result``.

    Raises:
        ValueError: If no model has been fitted.

    Example:
        >>> result = pipeline.load('data.csv').fit('maxwell').get_fit_result()
        >>> print(result.summary())
    """
    if self._last_model is None:
        raise ValueError("No model fitted. Call fit() first.")

    from rheojax.utils.model_selection import build_fit_result

    dataset = self.data
    if dataset is None:
        # Without data everything derived from it is passed as None.
        X = None
        y = None
        test_mode = None
    else:
        X = dataset.x
        y = dataset.y
        metadata = getattr(dataset, "metadata", None) or {}
        test_mode = metadata.get("test_mode")

    return build_fit_result(
        self._last_model,
        X,
        y,
        test_mode=test_mode,
    )
[docs]
def clone(self) -> Pipeline:
    """Create a copy of the pipeline.

    The data is copied with its own ``copy()`` method, ``steps`` and the
    last model are deep-copied, and ``history`` is shallow-copied. No
    other state is transferred: everything else keeps the value a newly
    constructed Pipeline starts with.

    Returns:
        New Pipeline with copied data, steps, history and last model

    Example:
        >>> pipeline2 = pipeline.clone()
    """
    # Identity check rather than truthiness, so a data object that
    # evaluates falsy is still copied.
    new_pipeline = Pipeline(
        data=self.data.copy() if self.data is not None else None
    )

    # One memo shared by both deepcopy calls: the model referenced from
    # `steps` and from `_last_model` is copied once, so the clone keeps
    # the same aliasing as the original.
    memo: dict[int, Any] = {}
    new_pipeline.steps = copy.deepcopy(self.steps, memo)
    new_pipeline.history = self.history.copy()
    new_pipeline._last_model = copy.deepcopy(self._last_model, memo)

    logger.debug(
        "Pipeline cloned",
        original_id=self._id,
        new_id=new_pipeline._id,
    )
    return new_pipeline
[docs]
def reset(self) -> Pipeline:
    """Return the pipeline to its freshly constructed, empty state.

    Data, steps, history and all cached results are cleared; the pipeline
    id is kept.

    Returns:
        self for method chaining

    Example:
        >>> pipeline.reset()
    """
    logger.debug("Pipeline reset", pipeline_id=self._id)

    # Containers get fresh empty instances.
    self.steps = []
    self.history = []
    self._transform_results = {}

    # Everything else is simply cleared.
    for attribute in (
        "data",
        "_last_model",
        "_last_fit_result",
        "_last_bayesian_result",
        "_last_transform_name",
        "_current_figure",
        "_diagnostic_results",
        "_last_comparison",
    ):
        setattr(self, attribute, None)
    return self
[docs]
def export(
    self,
    output: str | Path,
    format: str = "auto",
    *,
    include_data: bool = True,
    include_figures: bool = True,
    include_diagnostics: bool = True,
    figure_formats: tuple[str, ...] = ("pdf", "png"),
    figure_dpi: int = 300,
    **kwargs,
) -> Pipeline:
    """Export the analysis through AnalysisExporter.

    Args:
        output: Output path. With ``format='auto'`` a path ending in
            ``.xlsx`` selects the Excel export; any other path is
            exported as a directory.
        format: Export format ('auto', 'directory', 'excel').
        include_data: Forwarded to the directory export.
        include_figures: Forwarded to the directory export, and to the
            Excel export as ``include_plots``.
        include_diagnostics: Forwarded to the directory export.
        figure_formats: Formats for figure files (default: ('pdf', 'png')).
        figure_dpi: Resolution for raster figures (default: 300).
        **kwargs: Additional arguments forwarded to the exporter.

    Returns:
        self for method chaining

    Raises:
        ValueError: If the format is not 'auto', 'directory' or 'excel'.

    Example:
        >>> pipeline.load('data.csv').fit('maxwell').plot_fit().export('./results')
        >>> pipeline.export('report.xlsx')
    """
    from rheojax.io.analysis_exporter import AnalysisExporter

    output_path = Path(output)
    exporter = AnalysisExporter(
        figure_formats=figure_formats,
        figure_dpi=figure_dpi,
    )

    # Resolve 'auto' from the file extension.
    if format == "auto":
        format = "excel" if output_path.suffix.lower() == ".xlsx" else "directory"

    target = str(output_path)
    with log_pipeline_stage(
        logger,
        "export",
        pipeline_id=self._id,
        output=target,
        format=format,
    ) as ctx:
        try:
            if format == "excel":
                exporter.export_excel(
                    self,
                    output_path,
                    include_plots=include_figures,
                    **kwargs,
                )
            elif format == "directory":
                exporter.export_directory(
                    self,
                    output_path,
                    include_data=include_data,
                    include_figures=include_figures,
                    include_diagnostics=include_diagnostics,
                    **kwargs,
                )
            else:
                raise ValueError(
                    f"Unknown export format: {format}. Use 'directory' or 'excel'."
                )
            ctx["format"] = format
        except Exception as e:
            logger.error(
                "Export failed",
                pipeline_id=self._id,
                output=target,
                error=str(e),
                exc_info=True,
            )
            raise

    # Record the export in both the step list and the history.
    self.steps.append(("export", {"output_path": target, "format": format}))
    self.history.append(("export", target, format))
    return self
[docs]
def __repr__(self) -> str:
    """Summarize the pipeline: history length, data and model presence."""
    return (
        f"Pipeline(steps={len(self.history)}, "
        f"has_data={self.data is not None}, "
        f"has_model={self._last_model is not None})"
    )
# Public API of this module: only the Pipeline class is exported.
__all__ = ["Pipeline"]