Pipeline API

This page documents the high-level Pipeline API for rheological analysis workflows.

Overview

The Pipeline API provides a fluent interface for chaining operations from data loading through model fitting and export. It’s designed for rapid analysis with minimal boilerplate code.
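Chaining works because every state-changing method returns the pipeline object itself, while accessor methods end the chain by returning data. The sketch below shows that pattern with a hypothetical `MiniPipeline` class (not rheojax code). It also records `(operation, details)` tuples, in the spirit of the `history` attribute documented for `Pipeline`.

```python
class MiniPipeline:
    """Hypothetical toy pipeline illustrating the fluent pattern (not rheojax code)."""

    def __init__(self, data=None):
        self.data = data
        self.history = []  # list of (operation, details) tuples

    def load(self, values):
        self.data = [float(v) for v in values]
        self.history.append(("load", {"n_points": len(self.data)}))
        return self  # returning self is what makes chaining possible

    def transform(self, func, name):
        if self.data is None:
            raise ValueError("Data not loaded; call load() first")
        self.data = [func(v) for v in self.data]
        self.history.append(("transform", {"name": name}))
        return self

    def get_result(self):
        # Accessors end the chain by returning data instead of self
        return self.data


pipeline = MiniPipeline().load([1, 2, 3]).transform(lambda v: 2 * v, name="double")
print(pipeline.get_result())  # [2.0, 4.0, 6.0]
print(pipeline.history)
```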

Core Components:

  1. Pipeline: Base fluent API with method chaining

  2. Specialized Workflows: Pre-configured pipelines for common tasks

  3. PipelineBuilder: Programmatic pipeline construction

  4. BatchPipeline: Process multiple datasets

Basic Pipeline

class rheojax.pipeline.Pipeline(data=None)

Bases: object

Fluent API for rheological analysis workflows.

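This class provides a chainable interface for loading data, applying transforms, fitting models, and generating outputs. Methods that load, transform, fit, plot, save, or export return self to enable method chaining; predict() and the get_*() accessors return results instead.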

data

Current RheoData state

steps

List of (operation, object) tuples for fitted models

history

List of (operation, details) tuples tracking all operations

_last_model

Last fitted model for convenience

Example

>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell').plot()

__init__(data=None)

Initialize pipeline.

Parameters:

data (RheoData | None) – Optional initial RheoData. If None, must call load() first.

load(file_path, format='auto', *, test_mode=None, initial_test_mode=None, **kwargs)

Load data from file.

Parameters:
  • file_path (str | Path) – Path to data file

  • format (str) – File format (‘auto’, ‘csv’, ‘excel’, ‘trios’, ‘hdf5’)

  • test_mode (str | None) – Optional rheological mode metadata to attach to the resulting RheoData (e.g., ‘relaxation’, ‘creep’, ‘oscillation’)

  • initial_test_mode (str | None) – Backwards-compatible alias for test_mode

  • **kwargs – Additional arguments passed to reader

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline = Pipeline().load('data.csv', x_col='time', y_col='stress')
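Because `initial_test_mode` is documented only as a backwards-compatible alias for `test_mode`, one way to picture it is as a keyword merged into `test_mode` before the reader runs. A minimal sketch with a hypothetical helper `resolve_test_mode`; rejecting two different values is an assumption, not documented behaviour.

```python
def resolve_test_mode(test_mode=None, initial_test_mode=None):
    """Hypothetical helper: fold a backwards-compatible alias into test_mode."""
    if test_mode is None:
        return initial_test_mode  # may still be None: no mode metadata attached
    if initial_test_mode is not None and initial_test_mode != test_mode:
        # Assumption: conflicting values are rejected rather than silently overridden
        raise ValueError(
            f"test_mode={test_mode!r} conflicts with initial_test_mode={initial_test_mode!r}"
        )
    return test_mode


print(resolve_test_mode(initial_test_mode="creep"))  # creep
```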

transform(transform, **kwargs)

Apply a transform to the data.

Parameters:
  • transform (str | BaseTransform) – Transform name (string) or Transform instance

  • **kwargs – Arguments passed to transform constructor (if string)

Return type:

Pipeline

Returns:

self for method chaining

Raises:

ValueError – If data not loaded or transform not found

Example

>>> pipeline.transform('smooth', window_size=5)
>>> # or with instance
>>> from rheojax.transforms import SmoothTransform
>>> pipeline.transform(SmoothTransform(window_size=5))
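The two accepted forms of the transform argument amount to a name-to-class lookup: keyword arguments go to the constructor when a name is given, and an instance is used as-is. A minimal sketch of that dispatch, in which `MovingAverage`, `TRANSFORMS`, and `resolve_transform` are hypothetical stand-ins rather than rheojax's registry or `SmoothTransform`:

```python
class MovingAverage:
    """Hypothetical stand-in for a smoothing transform (not rheojax's SmoothTransform)."""

    def __init__(self, window_size=3):
        if window_size < 1:
            raise ValueError("window_size must be >= 1")
        self.window_size = window_size

    def apply(self, values):
        # Trailing moving average; the first points use a shorter window
        out = []
        for i in range(len(values)):
            window = values[max(0, i - self.window_size + 1): i + 1]
            out.append(sum(window) / len(window))
        return out


TRANSFORMS = {"smooth": MovingAverage}  # hypothetical name -> class registry


def resolve_transform(transform, **kwargs):
    """Accept a registered name (kwargs go to the constructor) or a ready instance."""
    if isinstance(transform, str):
        try:
            cls = TRANSFORMS[transform]
        except KeyError:
            raise ValueError(f"Unknown transform: {transform!r}") from None
        return cls(**kwargs)
    return transform  # an instance is used as-is; kwargs are ignored in this sketch


smoother = resolve_transform("smooth", window_size=2)
print(smoother.apply([1.0, 3.0, 5.0]))  # [1.0, 2.0, 4.0]
```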

fit(model, method='auto', **fit_kwargs)

Fit a model to the data.

Parameters:
  • model (str | BaseModel) – Model name (string) or Model instance

  • method (str) – Optimization method passed to model.fit() (‘nlsq’, ‘scipy’, ‘auto’). Default ‘auto’ lets the model choose.

  • **fit_kwargs – Additional arguments passed to optimizer

Return type:

Pipeline

Returns:

self for method chaining

Raises:

ValueError – If data not loaded or model not found

Example

>>> pipeline.fit('maxwell')
>>> # or with instance
>>> from rheojax.models.linear import Maxwell
>>> pipeline.fit(Maxwell())

predict(model=None, X=None)

Generate predictions from fitted model.

Parameters:
  • model (BaseModel | None) – Model to use for prediction. If None, uses last fitted model.

  • X (ndarray | None) – Input data for prediction. If None, uses current data.x.

Return type:

RheoData

Returns:

RheoData with predictions

Raises:

ValueError – If no model has been fitted

Example

>>> predictions = pipeline.predict()

plot(show=True, style='default', include_prediction=False, **plot_kwargs)

Plot current data state.

Parameters:
  • show (bool) – Whether to call plt.show()

  • style (str) – Plot style (‘default’, ‘publication’, ‘presentation’)

  • include_prediction (bool) – If True and model fitted, overlay predictions

  • **plot_kwargs – Additional arguments passed to plotting function

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.plot(style='publication')

save(file_path, format='hdf5', **kwargs)

Save current data to file.

Parameters:
  • file_path (str | Path) – Output file path

  • format (str) – Output format (‘hdf5’, ‘excel’, ‘csv’)

  • **kwargs – Additional arguments passed to writer

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.save('output.hdf5')

save_figure(filepath, format=None, dpi=300, **kwargs)

Save the most recent plot to file.

Convenience method for exporting plots with publication-quality defaults. Wraps rheojax.visualization.plotter.save_figure() to enable fluent API chaining.

Parameters:
  • filepath (str | Path) – Output file path. Format inferred from extension if not specified.

  • format (str | None) – Output format (‘pdf’, ‘svg’, ‘png’, ‘eps’). If None, inferred from filepath.

  • dpi (int) – Resolution for raster formats (PNG).

  • **kwargs (Any) – Additional arguments passed to save_figure(). See rheojax.visualization.plotter.save_figure() for details.

Returns:

self – Returns self to enable method chaining

Return type:

Pipeline

Raises:
  • ValueError – If no plot exists (plot() not called yet)

  • ValueError – If format cannot be inferred or is unsupported

  • OSError – If filepath directory doesn’t exist

Examples

Basic usage with method chaining:

>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell').plot().save_figure('result.pdf')

Save multiple formats:

>>> pipeline.plot(style='publication')
>>> pipeline.save_figure('figure.pdf')
>>> pipeline.save_figure('figure.png', dpi=600)
>>> pipeline.save_figure('figure.svg', transparent=True)

Explicit format:

>>> pipeline.plot().save_figure('output', format='pdf')

See also

plot

Generate plot with automatic type selection

rheojax.visualization.plotter.save_figure

Core export function

Notes

This method saves the most recent plot generated by plot(). If you call plot() multiple times, only the last figure is saved. To save multiple plots, call save_figure() after each plot() call.

The figure is stored internally by plot() and retrieved by save_figure().
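The format rule in the parameter list (explicit `format` wins, otherwise the file extension decides, and unsupported or missing formats raise `ValueError`) can be sketched as below. `resolve_figure_format` is a hypothetical helper, not the code behind `save_figure()`.

```python
from pathlib import Path

SUPPORTED = {"pdf", "svg", "png", "eps"}


def resolve_figure_format(filepath, fmt=None):
    """Hypothetical helper: choose the output format as the parameters above describe."""
    if fmt is None:
        suffix = Path(filepath).suffix.lower().lstrip(".")
        if not suffix:
            raise ValueError(f"Cannot infer a format from {str(filepath)!r}; pass format=")
        fmt = suffix
    fmt = fmt.lower()
    if fmt not in SUPPORTED:
        raise ValueError(f"Unsupported figure format: {fmt!r}")
    return fmt


print(resolve_figure_format("figure.PDF"))  # pdf
```

Only raster output (PNG here) depends on `dpi`; vector formats such as PDF, SVG, and EPS ignore it.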

fit_bayesian(model=None, seed=None, **bayesian_kwargs)

Run Bayesian (NUTS) inference on current data.

Uses the last fitted model (or a new one) with NLSQ warm-start.

Parameters:
  • model (str | BaseModel | None) – Model name, instance, or None to reuse last fitted model.

  • seed (int | None) – Random seed for reproducibility. If None, a seed of 0 is used.

  • **bayesian_kwargs – Arguments forwarded to model.fit_bayesian() (num_warmup, num_samples, num_chains, target_accept_prob, etc.)

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.fit('maxwell').fit_bayesian(seed=42, num_warmup=1000)

plot_fit(confidence=0.95, show_residuals=True, show_uncertainty=True, show=True, style='default', **kwargs)

Plot NLSQ fit with uncertainty band and residuals.

Requires a prior call to fit(). Uses FitPlotter internally.

Parameters:
  • confidence (float) – Confidence level for uncertainty band (default: 0.95).

  • show_residuals (bool) – If True, add residuals subplot.

  • show_uncertainty (bool) – If True and covariance available, show band.

  • show (bool) – Whether to call plt.show() (default: True).

  • style (str) – Plot style (‘default’, ‘publication’, ‘presentation’).

  • **kwargs – Additional arguments forwarded to FitPlotter.plot_nlsq().

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.fit('maxwell').plot_fit(confidence=0.95)
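A common way to turn a fitted-parameter covariance into an uncertainty band is the delta method. It linearizes the model around the best-fit parameters and propagates the covariance, var f ≈ ∇f Σ ∇fᵀ, then scales by the normal critical value for the requested confidence. The sketch computes that band at a single x. `delta_method_band` and `maxwell_relaxation` are hypothetical helpers, the numbers are made up, and whether FitPlotter uses exactly this approach is an assumption.

```python
import math
from statistics import NormalDist


def delta_method_band(f, params, cov, x, confidence=0.95, rel_step=1e-6):
    """Return (lower, upper) of a normal-approximation band for f(x, params)."""
    z = NormalDist().inv_cdf(0.5 + confidence / 2)  # two-sided critical value
    k = len(params)
    grad = []
    for i in range(k):
        step = rel_step * max(abs(params[i]), 1.0)
        up = list(params)
        down = list(params)
        up[i] += step
        down[i] -= step
        grad.append((f(x, up) - f(x, down)) / (2 * step))  # central difference
    var = sum(grad[i] * cov[i][j] * grad[j] for i in range(k) for j in range(k))
    half_width = z * math.sqrt(max(var, 0.0))
    y = f(x, params)
    return y - half_width, y + half_width


def maxwell_relaxation(t, p):
    """Maxwell relaxation modulus G(t) = G0 * exp(-t * G0 / eta)."""
    G0, eta = p
    return G0 * math.exp(-t * G0 / eta)


# Made-up best-fit parameters [G0, eta] and covariance, for illustration only
lower, upper = delta_method_band(maxwell_relaxation, [1e5, 1e3],
                                 [[1e6, 0.0], [0.0, 1e2]], x=0.005)
```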

plot_bayesian(credible_level=0.95, max_draws=500, show_nlsq_overlay=False, show_residuals=False, show=True, style='default', **kwargs)

Plot Bayesian posterior predictive with credible interval.

Requires a prior call to fit_bayesian().

Parameters:
  • credible_level (float) – Credible interval level (default: 0.95).

  • max_draws (int) – Maximum posterior draws for band computation.

  • show_nlsq_overlay (bool) – If True, overlay NLSQ fit for comparison.

  • show_residuals (bool) – If True, add residuals subplot.

  • show (bool) – Whether to call plt.show() (default: True).

  • style (str) – Plot style.

  • **kwargs – Additional arguments forwarded to FitPlotter.plot_bayesian().

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.fit('maxwell').fit_bayesian(seed=42).plot_bayesian()
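A posterior predictive band of this kind is typically built by evaluating the model once per posterior draw and taking equal-tailed percentiles at each x. Capping the number of draws, as `max_draws` suggests, keeps that affordable. A minimal sketch at a single x, in which `thin`, `quantile`, and `credible_band` are hypothetical helpers. The even-stride thinning and the linear-interpolation percentiles are assumptions, not necessarily what `plot_bayesian()` does.

```python
import math


def thin(draws, max_draws):
    """Keep at most max_draws evenly spaced posterior draws."""
    stride = max(1, math.ceil(len(draws) / max_draws))
    return draws[::stride]


def quantile(sorted_vals, q):
    """Linear-interpolation quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * (pos - lo)


def credible_band(predict, draws, x, credible_level=0.95, max_draws=500):
    """Equal-tailed credible interval of predict(x, draw) over the posterior draws."""
    preds = sorted(predict(x, d) for d in thin(draws, max_draws))
    tail = (1 - credible_level) / 2
    return quantile(preds, tail), quantile(preds, 1 - tail)


# Made-up posterior draws of a decay rate k, for illustration only
draws = [0.9 + 0.2 * i / 999 for i in range(1000)]
lower, upper = credible_band(lambda t, k: math.exp(-k * t), draws, x=1.0)
```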

plot_diagnostics(output_dir=None, style='default', prefix='mcmc', formats=('pdf', 'png'), dpi=300, **kwargs)

Generate ArviZ MCMC diagnostic suite (6 plots).

Requires a prior call to fit_bayesian().

Parameters:
  • output_dir (str | Path | None) – Directory for saving plots. If None, displays only.

  • style (str) – Plot style.

  • prefix (str) – Filename prefix for saved plots.

  • formats (tuple[str, ...]) – Output formats (default: (‘pdf’, ‘png’)).

  • dpi (int) – Resolution for raster formats.

  • **kwargs – Additional arguments forwarded to generate_diagnostic_suite().

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.fit_bayesian(seed=42).plot_diagnostics(output_dir='./diag')

plot_transform(transform_name=None, show_intermediate=True, show=True, style='default', **kwargs)

Plot the result of a previously applied transform.

Uses TransformPlotter for per-transform layout dispatch.

Parameters:
  • transform_name (str | None) – Name of the transform to plot. If None, uses the most recently applied transform.

  • show_intermediate (bool) – Whether to show before/after comparison.

  • show (bool) – Whether to call plt.show() (default: True).

  • style (str) – Plot style.

  • **kwargs – Additional arguments forwarded to TransformPlotter.

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.transform('mastercurve', reference_temp=25.0).plot_transform()

get_result()

Get current data state.

Return type:

RheoData

Returns:

Current RheoData

Example

>>> data = pipeline.get_result()

get_history()

Get pipeline execution history.

Return type:

list[tuple[Any, ...]]

Returns:

List of (operation, details) tuples

Example

>>> history = pipeline.get_history()
>>> for step in history:
...     print(step)

get_last_model()

Get the last fitted model.

Return type:

BaseModel | None

Returns:

Last fitted BaseModel or None

Example

>>> model = pipeline.get_last_model()
>>> params = model.get_params()

get_all_models()

Get all fitted models from pipeline.

Return type:

list[BaseModel]

Returns:

List of all fitted models

Example

>>> models = pipeline.get_all_models()

get_fitted_parameters()

Get fitted parameters from the last model as a dictionary.

This is a convenience method that extracts parameter values from the last fitted model’s ParameterSet.

Return type:

dict[str, float]

Returns:

Dictionary mapping parameter names to their fitted values

Raises:

ValueError – If no model has been fitted yet

Example

>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell')
>>> params = pipeline.get_fitted_parameters()
>>> print(params)  # {'G0': 100000.0, 'eta': 1000.0}
>>> G0 = params['G0']

compare_models(models, criterion='aic', **fit_kwargs)

Compare multiple models on the current data.

Fits each model and ranks by information criterion. The best model becomes _last_model and is appended to steps.

Parameters:
  • models (list[str | BaseModel]) – List of model names (strings) or BaseModel instances.

  • criterion (str) – Ranking criterion (‘aic’, ‘aicc’, ‘bic’).

  • **fit_kwargs – Extra kwargs forwarded to each model.fit() call.

Return type:

Pipeline

Returns:

self for method chaining

Raises:

ValueError – If no data is loaded.

Example

>>> pipeline.load('data.csv').compare_models(['maxwell', 'zener'])
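For least-squares fits with Gaussian errors, the three criteria are commonly written in terms of the residual sum of squares (RSS), the number of data points n, and the number of parameters k. Lower is better in all three, and the penalty terms stop extra parameters from winning on RSS alone. The sketch below ranks models that way. `information_criterion` and `rank_models` are hypothetical helpers, the numbers are made up, and rheojax may include additive constants that this form drops. Those constants do not change the ranking.

```python
import math


def information_criterion(rss, n, k, criterion="aic"):
    """Gaussian least-squares form of AIC/AICc/BIC (additive constants dropped)."""
    base = n * math.log(rss / n)
    if criterion == "aic":
        return base + 2 * k
    if criterion == "aicc":
        if n - k - 1 <= 0:
            raise ValueError("AICc needs n > k + 1")
        return base + 2 * k + 2 * k * (k + 1) / (n - k - 1)
    if criterion == "bic":
        return base + k * math.log(n)
    raise ValueError(f"Unknown criterion: {criterion!r}")


def rank_models(fits, n, criterion="aic"):
    """fits maps model name -> (residual sum of squares, number of parameters)."""
    scores = {name: information_criterion(rss, n, k, criterion)
              for name, (rss, k) in fits.items()}
    return sorted(scores, key=scores.get)  # lowest score ranks first


# Made-up numbers: a 3-parameter model that barely reduces the RSS
print(rank_models({"maxwell": (4.0, 2), "zener": (3.9, 3)}, n=50))  # ['maxwell', 'zener']
```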

get_fit_result()

Construct a FitResult from the last fitted model.

Return type:

Any

Returns:

FitResult with model metadata, fitted parameters, and statistics.

Raises:

ValueError – If no model has been fitted.

Example

>>> result = pipeline.load('data.csv').fit('maxwell').get_fit_result()
>>> print(result.summary())

clone()

Create a copy of the pipeline.

Return type:

Pipeline

Returns:

New Pipeline with copied data and history

Example

>>> pipeline2 = pipeline.clone()

reset()

Reset pipeline to initial state.

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.reset()

export(output, format='auto', *, include_data=True, include_figures=True, include_diagnostics=True, figure_formats=('pdf', 'png'), figure_dpi=300, **kwargs)

Export the full analysis to a directory or file.

This bundles data, parameters, statistics, figures, transform results, and Bayesian diagnostics into a single export.

Parameters:
  • output (str | Path) – Output path. If a directory (no extension or trailing /), exports as structured directory. If .xlsx, exports Excel.

  • format (str) – Export format (‘auto’, ‘directory’, ‘excel’). ‘auto’ infers from the output path extension.

  • include_data (bool) – Save raw and transformed data files.

  • include_figures (bool) – Save generated matplotlib figures.

  • include_diagnostics (bool) – Save MCMC diagnostic plots.

  • figure_formats (tuple[str, ...]) – Formats for figure files (default: (‘pdf’, ‘png’)).

  • figure_dpi (int) – Resolution for raster figures (default: 300).

  • **kwargs – Additional arguments forwarded to the exporter.

Return type:

Pipeline

Returns:

self for method chaining

Example

>>> pipeline.load('data.csv').fit('maxwell').plot_fit().export('./results')
>>> pipeline.export('report.xlsx')

__repr__()

String representation of pipeline.

Return type:

str

Description: Core pipeline class providing fluent method chaining for rheological analysis workflows.

Example - Basic Usage:

from rheojax.pipeline import Pipeline

# Create pipeline and chain operations
pipeline = (Pipeline()
    .load('data.txt')                        # Load data
    .transform('smooth', window_size=11)     # Smooth noisy data
    .fit('maxwell')                          # Fit model
    .plot(show=True))                        # Visualize

# Retrieve results
params = pipeline.get_fitted_parameters()    # {parameter name: fitted value}
print(f"Parameters: {params}")
print(pipeline.get_fit_result().summary())   # Fit statistics

Key Methods:

load(file_path, format='auto', **kwargs)

Load data from a file. To start from an existing RheoData object, pass it to the Pipeline constructor instead.

Parameters:

  • file_path (str or Path): Path to the data file

  • format (str): File format - ‘auto’, ‘csv’, ‘excel’, ‘trios’, ‘hdf5’

  • **kwargs: Format-specific arguments (e.g., x_col, y_col)

Returns: self (for chaining)

Example:

# Auto-detect format
pipeline = Pipeline().load('data.txt')

# Explicit format
pipeline = Pipeline().load('data.csv', format='csv',
                            x_col='frequency', y_col='modulus')

# From an existing RheoData object (passed to the constructor)
from rheojax.core import RheoData
data = RheoData(x=freq, y=modulus, ...)
pipeline = Pipeline(data)

transform(name, **params)

Apply data transform.

Parameters:

  • name (str): Transform name - ‘smooth’, ‘fft’, ‘mastercurve’, etc.

  • **params: Transform-specific parameters

Returns: self (for chaining)

Example:

# Single transform
pipeline = (Pipeline()
    .load('data.txt')
    .transform('smooth', method='savgol', window_size=11))

# Multiple transforms (chained)
pipeline = (Pipeline()
    .load('data.txt')
    .transform('smooth', window_size=11)
    .transform('fft', window='hann'))

fit(model, initial_params=None, bounds=None, **kwargs)

Fit rheological model to data.

Parameters:

  • model (str or BaseModel): Model name or instance

  • initial_params (dict, optional): Initial parameter values

  • bounds (dict, optional): Parameter bounds

  • **kwargs: Optimization options

Returns: self (for chaining)

Example:

# By name
pipeline = Pipeline().load('data.txt').fit('maxwell')

# With initial parameters
pipeline = (Pipeline()
    .load('data.txt')
    .fit('maxwell',
         initial_params={'G_s': 1e5, 'eta_s': 1e3},
         bounds={'G_s': (1e3, 1e7), 'eta_s': (1e1, 1e5)}))

# Multiple models (comparison; the best-ranked model becomes the last fitted model)
pipeline = (Pipeline()
    .load('data.txt')
    .compare_models(['maxwell', 'zener', 'springpot']))

plot(show=True, style='default', include_prediction=False, **plot_kwargs)

Visualize data and model fit.

Parameters:

  • show (bool): Whether to call plt.show()

  • style (str): Plot style - ‘default’, ‘publication’, ‘presentation’

  • include_prediction (bool): If True and a model has been fitted, overlay its predictions

  • **plot_kwargs: Plotting options

Returns: self (for chaining)

Example:

# Show plot
pipeline.plot(show=True)

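# Save to file (plot without displaying, then export the figure)
pipeline.plot(show=False).save_figure('fit_result.png', dpi=300)

# Custom style with model predictions overlaid
pipeline.plot(show=True, style='publication',
              include_prediction=True, title='Maxwell Fit')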

save(filepath, format='hdf5', **kwargs)

Export results to file.

Parameters:

  • filepath (str): Output file path

  • format (str): Format - ‘hdf5’, ‘excel’, ‘csv’

  • **kwargs: Format-specific options

Returns: self (for chaining)

Example:

# HDF5 (full fidelity)
pipeline.save('results.hdf5')

# Excel report
pipeline.save('report.xlsx', format='excel', include_plots=True)

get_results()

Retrieve analysis results as dictionary.

Returns: dict with keys:

  • 'parameters': Fitted parameter values

  • 'r2': R^2 score

  • 'rmse': Root mean squared error

  • 'predictions': Model predictions

  • 'residuals': Fit residuals

  • 'data': Original RheoData object

  • 'model': Fitted model instance

Example:

results = pipeline.get_results()

print(f"R^2 = {results['r2']:.4f}")
print(f"RMSE = {results['rmse']:.2e}")
print(f"Parameters:")
for name, value in results['parameters'].items():
    print(f"  {name} = {value:.4e}")

Specialized Workflows

MastercurvePipeline

class rheojax.pipeline.MastercurvePipeline(reference_temp=298.15)

Bases: Pipeline

Pipeline for time-temperature superposition analysis.

This pipeline automates the construction of mastercurves from multi-temperature rheological data using horizontal shift factors.

reference_temp

Reference temperature for mastercurve

shift_factors

Dictionary of temperature -> shift factor

Example

>>> pipeline = MastercurvePipeline(reference_temp=298.15)
>>> pipeline.run(file_paths, temperatures)
>>> mastercurve = pipeline.get_result()

__init__(reference_temp=298.15)

Initialize mastercurve pipeline.

Parameters:

reference_temp (float) – Reference temperature in Kelvin (default: 298.15 K)

run(file_paths, temperatures, format='auto', parallel_io=True, **load_kwargs)

Execute mastercurve workflow.

Parameters:
  • file_paths (list[str]) – List of data file paths (one per temperature)

  • temperatures (list[float]) – List of temperatures (in Kelvin)

  • format (str) – File format for loading

  • parallel_io (bool) – Whether to load files in parallel (default True)

  • **load_kwargs – Additional arguments passed to load (e.g., x_col, y_col)

Return type:

MastercurvePipeline

Returns:

self for method chaining

Raises:

ValueError – If file_paths and temperatures have different lengths

get_shift_factors()

Get computed shift factors.

Return type:

dict[float, float]

Returns:

Dictionary mapping temperature to shift factor
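Each shift factor a_T moves a curve measured at temperature T horizontally onto the reference-temperature mastercurve: frequencies become ω·a_T, with a_T = 1 at the reference temperature. This reference does not say how MastercurvePipeline obtains its factors. The sketch below is only an illustration of what the numbers mean. It evaluates the WLF equation, log10(a_T) = -C1 (T - T_ref) / (C2 + T - T_ref), using placeholder constants. C1 = 8.86 and C2 = 101.6 K are the classic WLF values for a reference about 50 K above Tg; in practice they are fitted per material. `wlf_shift_factor` and `reduced_frequencies` are hypothetical helpers.

```python
def wlf_shift_factor(T, T_ref, C1=8.86, C2=101.6):
    """a_T from the WLF equation: log10(a_T) = -C1 (T - T_ref) / (C2 + T - T_ref)."""
    dT = T - T_ref
    if C2 + dT <= 0:
        raise ValueError("WLF is undefined for T <= T_ref - C2")
    return 10 ** (-C1 * dT / (C2 + dT))


def reduced_frequencies(omega, a_T):
    """Horizontal shift: data measured at T lands at omega * a_T on the mastercurve."""
    return [a_T * w for w in omega]


T_ref = 323.15                                   # K (50 degC)
temperatures = [298.15, 313.15, 328.15, 343.15]  # K
# Same layout as get_shift_factors(): {temperature: shift factor}
shift_factors = {T: wlf_shift_factor(T, T_ref) for T in temperatures}
for T, a_T in shift_factors.items():
    print(f"T = {T:.2f} K: a_T = {a_T:.3e}")
```

Colder curves get a_T > 1 and move to higher reduced frequency; hotter curves move the other way.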

Description: Pre-configured pipeline for time-temperature superposition analysis.

Example:

from rheojax.pipeline import MastercurvePipeline

# Create mastercurve pipeline (reference temperature in Kelvin: 50 °C = 323.15 K)
mc_pipeline = MastercurvePipeline(reference_temp=323.15)

# Load and process multi-temperature data (temperatures in Kelvin)
files = ['data_25C.txt', 'data_40C.txt', 'data_55C.txt', 'data_70C.txt']
temperatures = [298.15, 313.15, 328.15, 343.15]

mc_pipeline.run(files, temperatures)

# Access mastercurve results
mastercurve = mc_pipeline.get_result()
shift_factors = mc_pipeline.get_shift_factors()   # {temperature: shift factor}

for T, a_T in shift_factors.items():
    print(f"T = {T:.2f} K: a_T = {a_T:.3e}")

# Fit model to mastercurve
mc_pipeline.fit('fractional_maxwell_gel')
mc_pipeline.plot(show=True, style='publication')

Key Methods:

  • run(file_paths, temperatures): Create mastercurve from files

  • fit(model): Fit model to mastercurve

  • get_shift_factors(): Get temperature shift factors

  • get_result(): Get the mastercurve as RheoData

ModelComparisonPipeline

class rheojax.pipeline.ModelComparisonPipeline(models)

Bases: Pipeline

Pipeline for comparing multiple models on the same data.

This pipeline fits multiple models to the same dataset and computes comparison metrics (RMSE, R², AIC, etc.).

models

List of model names to compare

results

Dictionary of model_name -> metrics

Example

>>> pipeline = ModelComparisonPipeline(['maxwell', 'zener', 'springpot'])
>>> pipeline.run(data)
>>> best = pipeline.get_best_model()
>>> print(pipeline.get_comparison_table())

__init__(models)

Initialize model comparison pipeline.

Parameters:

models (list[str]) – List of model names to compare

run(data, parallel=False, n_workers=None, **fit_kwargs)

Fit multiple models and compare.

Parameters:
  • data (RheoData) – RheoData to fit

  • parallel (bool) – Whether to fit models in parallel subprocesses. Each model gets its own process with independent JIT cache.

  • n_workers (int | None) – Number of parallel workers (default: auto)

  • **fit_kwargs – Additional arguments passed to fit

Return type:

ModelComparisonPipeline

Returns:

self for method chaining

get_best_model(metric='rmse', minimize=True)

Return name of best-fitting model.

Parameters:
  • metric (str) – Metric to use for comparison (‘rmse’, ‘r_squared’, ‘aic’, ‘bic’)

  • minimize (bool) – If True, lower values are better (e.g., RMSE, AIC, BIC)

Return type:

str

Returns:

Name of best model

Example

>>> best = pipeline.get_best_model(metric='aic')
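Selecting the best model is a min/max over the comparison table, and the `minimize` flag decides which. Going by the parameter description above, a higher-is-better metric such as `r_squared` presumably needs `minimize=False`; otherwise the worst fit is returned. A minimal sketch over a table shaped like `get_comparison_table()` output; `best_model` is a hypothetical helper and the metrics are made up.

```python
def best_model(table, metric="rmse", minimize=True):
    """Hypothetical helper mirroring get_best_model() on {name: {metric: value}}."""
    choose = min if minimize else max
    return choose(table, key=lambda name: table[name][metric])


table = {  # made-up metrics for illustration
    "maxwell": {"rmse": 120.0, "r_squared": 0.981, "aic": 310.2},
    "zener": {"rmse": 95.0, "r_squared": 0.989, "aic": 305.7},
}
print(best_model(table, metric="aic"))                        # zener
print(best_model(table, metric="r_squared", minimize=False))  # zener
```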

get_comparison_table()

Get comparison table of all models.

Return type:

dict[str, dict[str, float]]

Returns:

Dictionary of model_name -> metrics

Example

>>> table = pipeline.get_comparison_table()
>>> for model, metrics in table.items():
...     print(f"{model}: R²={metrics['r_squared']:.4f}")

get_model_result(model_name)

Get detailed results for a specific model.

Parameters:

model_name (str) – Name of the model

Return type:

dict[str, Any]

Returns:

Dictionary with model, parameters, and metrics

Example

>>> result = pipeline.get_model_result('maxwell')
>>> params = result['parameters']

Description: Systematically compare multiple models on the same dataset.

Example:

from rheojax.pipeline import ModelComparisonPipeline, Pipeline

# Models to compare
models = ['maxwell', 'zener', 'fractional_maxwell_gel',
          'fractional_kelvin_voigt', 'springpot']

# Create comparison pipeline
comparison = ModelComparisonPipeline(models)

# Load data, then run the comparison (run() takes a RheoData)
data = Pipeline().load('data.txt').get_result()
comparison.run(data)

# Get comparison table: {model_name: {metric: value}}
comparison_table = comparison.get_comparison_table()

# Print comparison
print("\nModel Comparison:")
print(f"{'Model':<30} {'R^2':<10} {'RMSE':<12} {'AIC':<12}")
print("-" * 70)
for name, metrics in comparison_table.items():
    print(f"{name:<30} {metrics['r_squared']:<10.4f} "
          f"{metrics['rmse']:<12.2e} {metrics['aic']:<12.1f}")

# Get best model (returns the model name)
best = comparison.get_best_model(metric='aic')  # 'rmse', 'r_squared', 'aic', 'bic'
print(f"\nBest model (AIC): {best}")

# Inspect the winner in detail
best_result = comparison.get_model_result(best)
print(best_result['parameters'])

Key Methods:

  • run(data): Fit all models to a RheoData

  • get_best_model(metric, minimize): Select best by RMSE, R^2, AIC, or BIC

  • get_comparison_table(): Metrics for every model

  • get_model_result(model_name): Fitted model, parameters, and metrics

CreepToRelaxationPipeline

class rheojax.pipeline.CreepToRelaxationPipeline(data=None)

Bases: Pipeline

Convert creep compliance data to relaxation modulus.

This pipeline performs the numerical conversion from J(t) to G(t) using regularized numerical inversion techniques.

Example

>>> pipeline = CreepToRelaxationPipeline()
>>> pipeline.run(creep_data)
>>> relaxation_data = pipeline.get_result()

run(creep_data, method='approximate')

Execute conversion workflow.

Parameters:
  • creep_data (RheoData) – RheoData with creep compliance J(t)

  • method (str) – Conversion method (‘approximate’, ‘exact’)

Return type:

CreepToRelaxationPipeline

Returns:

self for method chaining

Raises:

ValueError – If input is not creep data
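Creep compliance and relaxation modulus are linked by the interconversion relation ∫₀ᵗ G(τ) J(t − τ) dτ = t. One standard numerical route is to differentiate it into G(t) J(0) + ∫₀ᵗ G(τ) J′(t − τ) dτ = 1 and march forward in time with the trapezoidal rule. This reference does not say whether method='approximate' or method='exact' works this way. The sketch below is a uniform-grid version with the hypothetical name `creep_to_relaxation`. It assumes J(0) > 0 and estimates J′ by finite differences.

```python
import math


def creep_to_relaxation(J, h):
    """Invert sampled J(t) -> G(t) on a uniform grid t_i = i*h (trapezoidal rule).

    Solves G(t) J(0) + integral_0^t G(tau) J'(t - tau) dtau = 1; needs J[0] > 0.
    """
    n_pts = len(J)
    if n_pts < 2 or J[0] <= 0:
        raise ValueError("need at least two samples and J(0) > 0")
    # Finite-difference estimate of J'(t): one-sided at the ends, central inside
    dJ = [0.0] * n_pts
    dJ[0] = (J[1] - J[0]) / h
    dJ[-1] = (J[-1] - J[-2]) / h
    for i in range(1, n_pts - 1):
        dJ[i] = (J[i + 1] - J[i - 1]) / (2 * h)

    G = [1.0 / J[0]]  # instantaneous response: G(0) J(0) = 1
    for n in range(1, n_pts):
        # Trapezoid weights: 1/2 at both endpoints, 1 in between (G[n] term moved left)
        known = 0.5 * G[0] * dJ[n] + sum(G[k] * dJ[n - k] for k in range(1, n))
        G.append((1.0 - h * known) / (J[0] + 0.5 * h * dJ[0]))
    return G


# Check against a Maxwell element with G = 1, eta = 1: J(t) = 1 + t, G(t) = exp(-t)
h = 0.01
J = [1.0 + i * h for i in range(201)]
G = creep_to_relaxation(J, h)
print(G[100], math.exp(-1.0))  # both close to 0.3679
```

The cost grows quadratically with the number of samples, which is acceptable for a sketch but not for long, densely sampled creep tests.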

Description: Convert creep compliance J(t) to relaxation modulus G(t).

Example:

from rheojax.pipeline import CreepToRelaxationPipeline, Pipeline

# Load creep compliance J(t), tagging it as creep data
creep_data = Pipeline().load('creep_data.txt', test_mode='creep').get_result()

converter = CreepToRelaxationPipeline()
converter.run(creep_data, method='approximate')   # 'approximate' or 'exact'
relaxation_data = converter.get_result()

# Fit model to relaxation data
converter.fit('maxwell')
converter.plot(show=True)

FrequencyToTimePipeline

class rheojax.pipeline.FrequencyToTimePipeline(data=None)

Bases: Pipeline

Convert frequency domain data to time domain.

This pipeline converts dynamic modulus G*(ω) to relaxation modulus G(t) using Fourier transform techniques.

Example

>>> pipeline = FrequencyToTimePipeline()
>>> pipeline.run(frequency_data)
>>> time_data = pipeline.get_result()

run(frequency_data, time_range=None, n_points=100)

Execute frequency to time conversion.

Parameters:
  • frequency_data (RheoData) – RheoData in frequency domain

  • time_range (tuple | None) – Optional (t_min, t_max) for time range

  • n_points (int) – Number of time points to generate

Return type:

FrequencyToTimePipeline

Returns:

self for method chaining
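The sketch below is background only. It shows two ingredients of a frequency-to-time conversion: a log-spaced time grid built from `time_range` and `n_points`, and the discrete Maxwell (Prony) representation. In that representation, a single set of modes (g_i, τ_i) gives both G′(ω), G″(ω) and G(t) = Σ g_i exp(−t/τ_i). Fitting such modes to measured G′ and G″ and then evaluating G(t) is one classical interconversion route; this reference describes the pipeline's own method only as Fourier-transform based. The default range t ∈ [1/ω_max, 1/ω_min] is an assumption, and all function names are hypothetical.

```python
import math


def log_time_grid(omega, time_range=None, n_points=100):
    """Log-spaced time points; defaults to (1/omega_max, 1/omega_min) if no range is given."""
    if time_range is None:
        # Assumption: the measured frequency window maps roughly to t ~ 1/omega
        time_range = (1.0 / max(omega), 1.0 / min(omega))
    t_min, t_max = time_range
    if not 0 < t_min < t_max or n_points < 2:
        raise ValueError("need 0 < t_min < t_max and n_points >= 2")
    a, b = math.log10(t_min), math.log10(t_max)
    return [10 ** (a + (b - a) * i / (n_points - 1)) for i in range(n_points)]


def prony_storage_loss(omega, modes):
    """G'(omega), G''(omega) of a discrete Maxwell spectrum modes = [(g_i, tau_i), ...]."""
    g1 = sum(g * (omega * tau) ** 2 / (1 + (omega * tau) ** 2) for g, tau in modes)
    g2 = sum(g * omega * tau / (1 + (omega * tau) ** 2) for g, tau in modes)
    return g1, g2


def prony_relaxation(t, modes):
    """G(t) = sum_i g_i * exp(-t / tau_i) for the same spectrum."""
    return sum(g * math.exp(-t / tau) for g, tau in modes)


modes = [(1e5, 0.1), (3e4, 10.0)]  # made-up (g_i [Pa], tau_i [s]) pairs
grid = log_time_grid([0.1, 1.0, 10.0, 100.0], n_points=4)
G_t = [prony_relaxation(t, modes) for t in grid]
```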

Description: Convert frequency-domain data to time-domain via inverse FFT.

Example:

from rheojax.pipeline import FrequencyToTimePipeline, Pipeline

# Load the frequency sweep G*(ω)
freq_data = Pipeline().load('frequency_sweep.txt', test_mode='oscillation').get_result()

ft_pipeline = FrequencyToTimePipeline()
ft_pipeline.run(freq_data,
                time_range=(1e-3, 1e3),   # Time range (s)
                n_points=200)             # Number of time points
time_data = ft_pipeline.get_result()
ft_pipeline.plot(show=True)

BayesianPipeline

class rheojax.pipeline.bayesian.BayesianPipeline(data=None)

Bases: Pipeline

Specialized pipeline for Bayesian rheological analysis workflows.

This class extends the base Pipeline to provide a fluent API for the NLSQ → NumPyro NUTS workflow. It supports:

  • NLSQ optimization for fast point estimation

  • Bayesian inference with automatic warm-start from NLSQ

  • Convergence diagnostics (R-hat, ESS, divergences)

  • Posterior visualization (distributions and trace plots)

Fitting and plotting methods return self to enable method chaining; get_diagnostics() and get_posterior_summary() return results instead.

data

Current RheoData state (inherited from Pipeline)

_last_model

Last fitted model (inherited from Pipeline)

_nlsq_result

Stored NLSQ optimization result

_bayesian_result

Stored Bayesian inference result

_diagnostics

Stored convergence diagnostics

Example

>>> pipeline = BayesianPipeline()
>>> pipeline.load('data.csv') \
...     .fit_nlsq('maxwell') \
...     .fit_bayesian(num_samples=2000) \
...     .plot_posterior() \
...     .save('results.hdf5')

__init__(data=None)

Initialize Bayesian pipeline.

Parameters:

data – Optional initial RheoData. If None, must call load() first.

fit_nlsq(model, **nlsq_kwargs)

Fit model using NLSQ optimization for point estimation.

This method performs fast GPU-accelerated nonlinear least squares optimization to obtain point estimates of model parameters. The optimization result is stored for potential warm-starting of Bayesian inference.

Parameters:
  • model (str | BaseModel) – Model name (string) or Model instance to fit

  • **nlsq_kwargs – Additional arguments passed to NLSQ optimizer (e.g., max_iter, ftol, xtol, gtol)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

ValueError – If data not loaded

Note

This method writes resolved deformation_mode, poisson_ratio, and test_mode back to self.data.metadata so that a subsequent fit_bayesian() call inherits these settings without the caller having to repeat them.

Example

>>> pipeline.fit_nlsq('maxwell')
>>> # or with instance
>>> from rheojax.models import Maxwell
>>> pipeline.fit_nlsq(Maxwell(), max_iter=1000)

fit_bayesian(num_samples=2000, num_warmup=1000, num_chains=4, **nuts_kwargs)

Perform Bayesian inference using NumPyro NUTS sampler.

This method runs NUTS (No-U-Turn Sampler) for Bayesian parameter estimation. If a model has been previously fitted with fit_nlsq(), the NLSQ point estimates are automatically used for warm-starting the sampler, leading to faster convergence.

Multi-chain sampling is enabled by default (num_chains=4) to provide reliable convergence diagnostics (R-hat, ESS) and parallel execution on multi-GPU systems.

Parameters:
  • num_samples (int) – Number of posterior samples per chain (default: 2000)

  • num_warmup (int) – Number of warmup/burn-in iterations (default: 1000)

  • num_chains (int) – Number of MCMC chains (default: 4). Multiple chains enable proper R-hat computation and parallel execution. Chain method is auto-selected: ‘parallel’ on multi-GPU, ‘vectorized’ on single GPU/CPU.

  • **nuts_kwargs – Additional arguments passed to NUTS sampler (e.g., target_accept_prob, chain_method)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

ValueError – If no model has been fitted with fit_nlsq()

Example

>>> pipeline.fit_nlsq('maxwell').fit_bayesian(num_samples=2000)
>>> # With custom NUTS parameters
>>> pipeline.fit_bayesian(
...     num_samples=3000,
...     num_warmup=1500,
...     num_chains=4,
...     target_accept_prob=0.9
... )

get_diagnostics()

Get convergence diagnostics from Bayesian inference.

Returns diagnostics including R-hat (Gelman-Rubin statistic), effective sample size (ESS), and number of divergent transitions.

Returns:

  • r_hat: R-hat for each parameter (dict)

  • ess: Effective sample size for each parameter (dict)

  • divergences: Number of divergent transitions (int)

Return type:

dict[str, Any]

Raises:

ValueError – If Bayesian inference has not been run

Example

>>> diagnostics = pipeline.get_diagnostics()
>>> print(f"R-hat: {diagnostics['r_hat']}")
>>> print(f"ESS: {diagnostics['ess']}")
>>> print(f"Divergences: {diagnostics['divergences']}")
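R-hat compares the between-chain variance with the within-chain variance. Values close to 1.0 mean the chains agree, while values well above 1 mean they have not converged to the same distribution. The sketch below computes the classic, non-split Gelman-Rubin statistic for one parameter. `gelman_rubin_rhat` is a hypothetical helper; ArviZ and NumPyro report a rank-normalized split R-hat that differs in detail.

```python
import math
import statistics


def gelman_rubin_rhat(chains):
    """Classic (non-split) Gelman-Rubin R-hat for one parameter; chains: equal-length lists."""
    m, n = len(chains), len(chains[0])
    if m < 2 or n < 2 or any(len(c) != n for c in chains):
        raise ValueError("need >= 2 chains of equal length >= 2")
    chain_means = [statistics.fmean(c) for c in chains]
    grand_mean = statistics.fmean(chain_means)
    B = n / (m - 1) * sum((mu - grand_mean) ** 2 for mu in chain_means)  # between-chain
    W = statistics.fmean(statistics.variance(c) for c in chains)          # within-chain
    var_plus = (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
    return math.sqrt(var_plus / W)


# Two well-mixed chains with identical means
chains = [[0.0, 1.0] * 50, [1.0, 0.0] * 50]
print(round(gelman_rubin_rhat(chains), 3))  # 0.995
```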

get_posterior_summary()

Get formatted posterior summary statistics.

Returns a pandas DataFrame with summary statistics for each parameter including mean, standard deviation, median, and quantiles (5%, 25%, 75%, 95%).

Return type:

DataFrame

Returns:

DataFrame with parameters as rows and statistics as columns

Raises:

ValueError – If Bayesian inference has not been run

Example

>>> summary = pipeline.get_posterior_summary()
>>> print(summary)
     mean     std  median     q05     q25     q75     q95
a   5.123   0.245   5.110   4.721   4.962   5.285   5.531
b   0.487   0.032   0.485   0.435   0.465   0.509   0.542
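Each row of the summary reduces one parameter's posterior draws to a handful of statistics. The sketch below computes the same columns for a single parameter using only the standard library. `posterior_summary` is a hypothetical helper. That `std` is the sample (n − 1) standard deviation and that quantiles interpolate linearly between order statistics are assumptions about what rheojax does.

```python
import statistics


def posterior_summary(samples):
    """Summary columns like those shown above, for one parameter's posterior draws."""
    # Percentile cut points 1..99 with linear interpolation between order statistics
    cuts = statistics.quantiles(samples, n=100, method="inclusive")
    return {
        "mean": statistics.fmean(samples),
        "std": statistics.stdev(samples),  # assumption: sample (n - 1) standard deviation
        "median": statistics.median(samples),
        "q05": cuts[4],
        "q25": cuts[24],
        "q75": cuts[74],
        "q95": cuts[94],
    }


print(posterior_summary([float(v) for v in range(101)])["q05"])  # 5.0
```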

plot_posterior(param_name=None, show=True, **plot_kwargs)

Plot posterior distributions.

Generates histogram plots of posterior distributions for model parameters. If param_name is None, plots all parameters in separate subplots.

Parameters:
  • param_name (str | None) – Name of specific parameter to plot. If None, plots all parameters (default: None)

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to matplotlib (e.g., bins, alpha, color)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

ValueError – If Bayesian inference has not been run

Example

>>> # Plot all parameters
>>> pipeline.plot_posterior()
>>> # Plot specific parameter
>>> pipeline.plot_posterior('a', bins=50, alpha=0.7)
>>> # Plot without showing (for save_figure)
>>> pipeline.plot_posterior(show=False).save_figure('posterior.pdf')

plot_trace(param_name=None, show=True, **plot_kwargs)

Plot MCMC trace plots.

Generates trace plots showing parameter values across MCMC iterations. Useful for diagnosing convergence issues. If param_name is None, plots all parameters.

Parameters:
  • param_name (str | None) – Name of specific parameter to plot. If None, plots all parameters (default: None)

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to matplotlib (e.g., alpha, linewidth)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

ValueError – If Bayesian inference has not been run

Example

>>> # Plot all trace plots
>>> pipeline.plot_trace()
>>> # Plot specific parameter
>>> pipeline.plot_trace('a', alpha=0.5)
>>> # Plot without showing (for save_figure)
>>> pipeline.plot_trace(show=False).save_figure('trace.pdf')
plot_pair(var_names=None, kind='scatter', divergences=True, show=True, **plot_kwargs)[source]

Plot pairwise relationships between parameters (pair plot).

Creates a matrix of scatter or KDE plots showing correlations between parameters. This is critical for identifying parameter dependencies, non-identifiability issues, and understanding the joint posterior structure. Divergent transitions are highlighted by default to identify problematic posterior geometry.

Parameters:
  • var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)

  • kind (str) – Type of pair plot - “scatter”, “kde”, or “hexbin” (default: “scatter”)

  • divergences (bool) – Whether to highlight divergent transitions in red (default: True). Useful for identifying problematic regions.

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to arviz.plot_pair() (e.g., marginals, point_estimate_marker_kwargs)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

Example

>>> # Plot all parameters with divergences highlighted
>>> pipeline.plot_pair()
>>>
>>> # Plot specific parameters as KDE
>>> pipeline.plot_pair(var_names=["G0", "eta"], kind="kde")
>>>
>>> # Save without showing
>>> pipeline.plot_pair(show=False).save_figure("pair.pdf")

Note

Pair plots are essential for diagnosing:

  • Parameter correlations (strong correlations can indicate non-identifiability)

  • Funnel geometry (divergences concentrated in specific regions)

  • Multimodal posteriors (multiple clusters)
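Narrow diagonal ridges in a pair plot correspond to large pairwise correlations, which can also be screened numerically. This sketch assumes the draws are available as a dict of 1-D NumPy arrays. `strongly_correlated` and its 0.9 cutoff are illustrative. The G0/eta draws are synthetic and were constructed to be nearly collinear.

```python
import numpy as np


def posterior_correlations(samples: dict[str, np.ndarray]) -> tuple[list[str], np.ndarray]:
    """Pearson correlation matrix between parameters' posterior draws."""
    names = list(samples)
    stacked = np.vstack([np.asarray(samples[n], dtype=float) for n in names])
    return names, np.corrcoef(stacked)  # each row is one parameter


def strongly_correlated(samples, threshold=0.9):
    """Parameter pairs whose |correlation| exceeds `threshold`."""
    names, corr = posterior_correlations(samples)
    pairs = []
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            if abs(corr[i, j]) > threshold:
                pairs.append((names[i], names[j], float(corr[i, j])))
    return pairs


# Synthetic draws: eta is almost a linear function of G0, tau is independent
rng = np.random.default_rng(1)
G0 = rng.normal(1e5, 1e3, size=2000)
eta = 10.0 * G0 + rng.normal(0.0, 50.0, size=2000)
tau = rng.normal(1.0, 0.1, size=2000)
print(strongly_correlated({"G0": G0, "eta": eta, "tau": tau}))
```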

plot_forest(var_names=None, combined=True, hdi_prob=0.95, show=True, **plot_kwargs)[source]

Plot forest plot with credible intervals for parameters.

Creates a forest plot showing parameter estimates with credible intervals (highest density intervals). Excellent for comparing parameter magnitudes and uncertainties at a glance. Each parameter is shown as a point estimate with error bars representing the credible interval.

Parameters:
  • var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)

  • combined (bool) – Whether to combine multiple chains (default: True)

  • hdi_prob (float) – Probability mass for credible interval (default: 0.95). Common values: 0.68 (≈1σ), 0.95 (≈2σ), 0.997 (≈3σ), by analogy with a Gaussian posterior

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to arviz.plot_forest() (e.g., rope, r_hat, colors)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

Example

>>> # Plot all parameters with 95% CI
>>> pipeline.plot_forest()
>>>
>>> # Plot specific parameters with 68% CI
>>> pipeline.plot_forest(var_names=["G0", "eta"], hdi_prob=0.68)
>>>
>>> # Save without showing
>>> pipeline.plot_forest(show=False).save_figure("forest.pdf")

Note

Forest plots are useful for:

  • Quickly comparing parameter magnitudes

  • Assessing parameter uncertainty

  • Identifying poorly estimated parameters (wide intervals)
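The highest density interval behind each forest-plot bar is the narrowest interval that contains `hdi_prob` of the draws. The following sketch sorts the draws and slides a window over them, which is essentially the approach ArviZ takes for unimodal draws. Multimodal posteriors need a different method. `hdi` is a hypothetical standalone function, not part of the pipeline API. The final loop prints the Gaussian mass within ±1σ, ±2σ and ±3σ, which the common `hdi_prob` values approximate.

```python
import math

import numpy as np


def hdi(draws, hdi_prob=0.95):
    """Narrowest interval containing `hdi_prob` of the sorted draws."""
    x = np.sort(np.asarray(draws, dtype=float))
    n = len(x)
    k = int(np.floor(hdi_prob * n))   # index span covered by one window
    widths = x[k:] - x[: n - k]       # width of every candidate window
    i = int(np.argmin(widths))        # the narrowest window wins
    return float(x[i]), float(x[i + k])


# Skewed toy sample: the HDI avoids the outlier at 100
print(hdi([0, 1, 2, 3, 4, 5, 100], hdi_prob=0.7))

# Mass of a standard normal within +/- k sigma
for k in (1, 2, 3):
    print(k, round(math.erf(k / math.sqrt(2)), 4))  # 0.6827, 0.9545, 0.9973
```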

plot_energy(show=True, **plot_kwargs)[source]

Plot NUTS energy diagnostic plot.

Creates an energy plot showing the distribution of energy transitions during NUTS sampling. This NUTS-specific diagnostic helps identify problematic posterior geometry such as heavy tails, funnels, or multimodal distributions. A marked mismatch between the marginal energy distribution and the energy-transition distribution indicates that the sampler is not exploring the posterior efficiently.

Parameters:
  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to arviz.plot_energy()

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

Example

>>> # Plot energy diagnostic
>>> pipeline.plot_energy()
>>>
>>> # Save without showing
>>> pipeline.plot_energy(show=False).save_figure("energy.pdf")

Note

Energy diagnostics help identify:

  • Heavy-tailed posteriors (the energy distribution has fat tails)

  • Funnel geometry (energy varies dramatically)

  • Problematic parameterizations

Good NUTS sampling shows similar marginal and transition energy distributions.
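A numerical companion to the energy plot is the estimated Bayesian fraction of missing information (E-BFMI). It is the mean squared energy change between successive draws divided by the variance of the energy. Values below roughly 0.3 are commonly treated as a warning sign. This sketch assumes that per-draw energies are available as an array of shape (n_chains, n_draws). The `bfmi` helper and the synthetic "healthy" and "sticky" energies are illustrative.

```python
import numpy as np


def bfmi(energy) -> np.ndarray:
    """Estimated E-BFMI, one value per chain.

    `energy` has shape (n_chains, n_draws): the Hamiltonian energy
    recorded at each NUTS draw.
    """
    energy = np.atleast_2d(np.asarray(energy, dtype=float))
    transition_sq = np.square(np.diff(energy, axis=1)).mean(axis=1)
    return transition_sq / np.var(energy, axis=1)


rng = np.random.default_rng(2)
healthy = rng.normal(size=(4, 1000))  # energy refreshes freely between draws
sticky = np.cumsum(rng.normal(scale=0.05, size=(4, 1000)), axis=1)  # slow drift
print(bfmi(healthy).round(2), bfmi(sticky).round(3))
```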

plot_autocorr(var_names=None, combined=False, show=True, **plot_kwargs)[source]

Plot autocorrelation diagnostic for MCMC mixing.

Creates autocorrelation plots showing how correlated consecutive samples are in the MCMC chain. High autocorrelation indicates poor mixing and suggests more samples are needed for reliable inference. Ideally, autocorrelation should decay quickly to zero.

Parameters:
  • var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)

  • combined (bool) – Whether to combine multiple chains (default: False)

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to arviz.plot_autocorr() (e.g., max_lag)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

Example

>>> # Plot autocorrelation for all parameters
>>> pipeline.plot_autocorr()
>>>
>>> # Plot specific parameters with longer lag
>>> pipeline.plot_autocorr(var_names=["G0"], max_lag=100)
>>>
>>> # Save without showing
>>> pipeline.plot_autocorr(show=False).save_figure("autocorr.pdf")

Note

Autocorrelation diagnostics help identify:

  • Poor mixing (high autocorrelation persists)

  • Need for more samples (ESS will be low)

  • Chain length adequacy

Goal: autocorrelation drops to ~0 within a few dozen lags.
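Each curve in an autocorrelation plot is the sample autocorrelation of one chain at increasing lags. The following is a minimal, direct NumPy sketch. `autocorr` is a hypothetical helper. The AR(1) series produced by the hypothetical `ar1_chain` stands in for MCMC chains: a small `phi` mimics fast mixing and a `phi` close to 1 mimics slow mixing.

```python
import numpy as np


def autocorr(draws, max_lag: int) -> np.ndarray:
    """Sample autocorrelation rho_0..rho_max_lag of a single chain."""
    x = np.asarray(draws, dtype=float)
    x = x - x.mean()
    denom = np.dot(x, x)
    rho = [1.0]
    for lag in range(1, max_lag + 1):
        rho.append(float(np.dot(x[:-lag], x[lag:]) / denom))
    return np.array(rho)


def ar1_chain(phi: float, n: int, seed: int = 0) -> np.ndarray:
    """AR(1) series x_t = phi * x_{t-1} + noise."""
    rng = np.random.default_rng(seed)
    x = np.zeros(n)
    for t in range(1, n):
        x[t] = phi * x[t - 1] + rng.normal()
    return x


print(autocorr(ar1_chain(0.2, 5000), 5).round(2))   # decays almost at once
print(autocorr(ar1_chain(0.95, 5000), 5).round(2))  # decays slowly
```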

plot_rank(var_names=None, show=True, **plot_kwargs)[source]

Plot rank plot for convergence diagnostics.

Creates rank plots (histograms of the posterior draws ranked across all chains, drawn separately for each chain), a modern alternative to trace plots for diagnosing convergence. A uniform rank distribution in every chain indicates good mixing and convergence; non-uniformity suggests convergence problems.

Parameters:
  • var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to arviz.plot_rank()

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

Example

>>> # Plot rank diagnostic for all parameters
>>> pipeline.plot_rank()
>>>
>>> # Plot specific parameters
>>> pipeline.plot_rank(var_names=["G0", "eta"])
>>>
>>> # Save without showing
>>> pipeline.plot_rank(show=False).save_figure("rank.pdf")

Note

Rank plots help identify:

  • Non-convergence (non-uniform rank distribution)

  • Chain sticking (vertical bands)

  • Insufficient mixing (patterns in ranks)

Goal: a uniform histogram across all bins.
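A rank plot ranks all draws from all chains together and then histograms each chain's ranks separately. The following sketch shows that computation. It assumes the draws for one parameter are stored in an array of shape (n_chains, n_draws). `rank_histograms` is a hypothetical helper, and the shifted chain is synthetic.

```python
import numpy as np


def rank_histograms(chains, n_bins: int = 10) -> np.ndarray:
    """Per-chain histograms of pooled ranks (the data behind a rank plot)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    total = n_chains * n_draws
    # Rank every draw against all draws from all chains
    ranks = np.argsort(np.argsort(chains.ravel())).reshape(n_chains, n_draws)
    return np.array(
        [np.histogram(r, bins=n_bins, range=(0, total))[0] for r in ranks]
    )


rng = np.random.default_rng(3)
mixed = rng.normal(size=(4, 1000))
stuck = mixed.copy()
stuck[0] += 2.0  # chain 0 explores a shifted region
print(rank_histograms(mixed, n_bins=5))  # roughly flat rows
print(rank_histograms(stuck, n_bins=5))  # chain 0 piles up in the top bins
```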

plot_ess(var_names=None, kind='local', show=True, **plot_kwargs)[source]

Plot effective sample size (ESS) diagnostic.

Creates a plot showing the effective sample size for each parameter, which quantifies how many independent samples the MCMC chain is equivalent to. Low ESS indicates high autocorrelation and suggests more samples are needed. ESS values should ideally be > 400.

Parameters:
  • var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)

  • kind (str) – Type of ESS plot - “local”, “quantile”, or “evolution” (default: “local”)

  • show (bool) – Whether to call plt.show() (default: True)

  • **plot_kwargs – Additional arguments passed to arviz.plot_ess() (e.g., min_ess)

Return type:

BayesianPipeline

Returns:

self for method chaining

Raises:

Example

>>> # Plot ESS for all parameters
>>> pipeline.plot_ess()
>>>
>>> # Plot quantile ESS
>>> pipeline.plot_ess(kind="quantile")
>>>
>>> # Save without showing
>>> pipeline.plot_ess(show=False).save_figure("ess.pdf")

Note

ESS diagnostics help assess:

  • Sampling efficiency (ESS / total samples)

  • Which parameters need more sampling

  • Overall chain quality

Goal: ESS > 400 for both bulk and tail estimates.
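ESS follows from the autocorrelations. For N draws, ESS ≈ N / (1 + 2 Σ ρ_t). The following single-chain sketch is deliberately simplified: it stops the sum at the first non-positive autocorrelation. It is not ArviZ's rank-normalized bulk/tail ESS, and `ess_single_chain` is a hypothetical helper. A synthetic chain that holds each independent value for 10 iterations should come out near N/10.

```python
import numpy as np


def ess_single_chain(draws) -> float:
    """Simplified ESS for one chain: N / (1 + 2 * sum of rho_t).

    The sum stops at the first non-positive autocorrelation.
    """
    x = np.asarray(draws, dtype=float)
    n = len(x)
    x = x - x.mean()
    denom = np.dot(x, x)
    tau = 1.0  # integrated autocorrelation time
    for lag in range(1, n):
        rho = np.dot(x[:-lag], x[lag:]) / denom
        if rho <= 0:
            break
        tau += 2.0 * rho
    return n / tau


rng = np.random.default_rng(4)
independent = rng.normal(size=4000)
sticky = np.repeat(rng.normal(size=400), 10)  # each value held for 10 iterations
print(round(ess_single_chain(independent)), round(ess_single_chain(sticky)))
```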

reset()[source]

Reset pipeline to initial state.

Clears all data, models, and results including NLSQ and Bayesian inference results.

Return type:

BayesianPipeline

Returns:

self for method chaining

Example

>>> pipeline.reset()
__repr__()[source]

String representation of Bayesian pipeline.

Return type:

str

Description: Specialized pipeline for Bayesian rheological analysis with NLSQ -> NUTS workflow.

Key Features:

  • NLSQ optimization for fast point estimation

  • Automatic warm-start for NumPyro NUTS sampling

  • Comprehensive ArviZ diagnostics (6 plot types)

  • Fluent API for method chaining

  • Convergence monitoring (R-hat, ESS, divergences)

Example - Complete Bayesian Workflow:

from rheojax.pipeline.bayesian import BayesianPipeline

# Create and execute pipeline
pipeline = (BayesianPipeline()
    .load('data.csv', x_col='time', y_col='stress')
    .fit_nlsq('maxwell')                    # Fast point estimate
    .fit_bayesian(num_samples=2000,         # NUTS with warm-start
                  num_warmup=1000)
    .plot_posterior()                       # Posterior distributions
    .plot_trace()                           # MCMC trace plots
    .save('results.hdf5'))                  # Export results

# Access results
summary = pipeline.get_posterior_summary()
diagnostics = pipeline.get_diagnostics()
intervals = pipeline.get_credible_intervals()

Example - ArviZ Diagnostic Suite:

import arviz as az

# Comprehensive MCMC quality assessment
(pipeline
    .plot_pair(divergences=True)   # Parameter correlations with divergences
    .plot_forest(hdi_prob=0.95)    # Credible intervals comparison
    .plot_energy()                 # NUTS energy diagnostic
    .plot_autocorr()               # Mixing diagnostic
    .plot_rank()                   # Convergence diagnostic
    .plot_ess(kind='local'))       # Effective sample size

# Convert to ArviZ InferenceData for advanced analysis
idata = pipeline._bayesian_result.to_inference_data()
az.summary(idata)

Key Methods:

  • fit_nlsq(model_name, **kwargs): NLSQ optimization for point estimation

  • fit_bayesian(num_samples, num_warmup, **kwargs): NumPyro NUTS sampling with warm-start

  • plot_posterior(**kwargs): Plot posterior distributions

  • plot_trace(**kwargs): Plot MCMC trace diagnostics

  • plot_pair(**kwargs): Plot parameter correlations (ArviZ)

  • plot_forest(**kwargs): Plot credible intervals (ArviZ)

  • plot_energy(**kwargs): Plot NUTS energy diagnostic (ArviZ)

  • plot_autocorr(**kwargs): Plot autocorrelation (ArviZ)

  • plot_rank(**kwargs): Plot rank diagnostic (ArviZ)

  • plot_ess(**kwargs): Plot effective sample size (ArviZ)

  • get_posterior_summary(): Get posterior summary statistics

  • get_diagnostics(): Get convergence diagnostics (R-hat, ESS, divergences)

  • get_credible_intervals(credibility=0.95): Get credible intervals

Pipeline Builder

class rheojax.pipeline.PipelineBuilder[source]

Bases: object

Build and validate pipelines programmatically.

This class provides a fluent API for constructing pipelines with validation of step order and dependencies.

Example

>>> builder = PipelineBuilder()
>>> builder.add_load_step('data.csv')
>>> builder.add_fit_step('maxwell')
>>> pipeline = builder.build()
__init__()[source]

Initialize pipeline builder.

add_load_step(file_path, format='auto', **kwargs)[source]

Add data loading step.

Parameters:
  • file_path (str | Path) – Path to data file

  • format (str) – File format (‘auto’, ‘csv’, ‘excel’, etc.)

  • **kwargs – Additional arguments for loader

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_load_step('data.csv', x_col='time', y_col='stress')
add_transform_step(transform_name, **kwargs)[source]

Add transform step.

Parameters:
  • transform_name (str) – Name of transform to apply

  • **kwargs – Arguments for transform constructor

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_transform_step('smooth', window_size=5)
add_fit_step(model_name, method='auto', use_jax=True, **kwargs)[source]

Add model fitting step.

Parameters:
  • model_name (str) – Name of model to fit

  • method (str) – Optimization method

  • use_jax (bool) – Whether to use JAX gradients

  • **kwargs – Additional fit arguments

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_fit_step('maxwell')
add_predict_step(store_as=None, **kwargs)[source]

Add prediction step.

Parameters:
  • store_as (str | None) – Optional name to store prediction

  • **kwargs – Additional prediction arguments

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_predict_step(store_as='prediction')
add_plot_step(show=False, style='default', **kwargs)[source]

Add plotting step.

Parameters:
  • show (bool) – Whether to display plot

  • style (str) – Plot style

  • **kwargs – Additional plot arguments

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_plot_step(style='publication', show=True)
add_bayesian_step(num_warmup=1000, num_samples=2000, num_chains=4, seed=0, warm_start=True, **kwargs)[source]

Add Bayesian inference step (NUTS sampling).

Parameters:
  • num_warmup (int) – Number of warmup iterations per chain

  • num_samples (int) – Number of posterior samples per chain

  • num_chains (int) – Number of MCMC chains

  • seed (int) – Random seed for reproducibility

  • warm_start (bool) – Whether to use NLSQ results as initial values

  • **kwargs – Additional arguments for fit_bayesian()

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_bayesian_step(num_warmup=500, num_samples=1000)
add_export_step(output_path, format='auto', **kwargs)[source]

Add analysis export step.

Parameters:
  • output_path (str | Path) – Output directory or file path

  • format (str) – Export format (‘directory’, ‘excel’, ‘hdf5’, ‘auto’)

  • **kwargs – Additional arguments for Pipeline.export()

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_export_step('./results', format='directory')
add_save_step(file_path, format='hdf5', **kwargs)[source]

Add data saving step.

Parameters:
  • file_path (str | Path) – Output file path

  • format (str) – Output format

  • **kwargs – Additional save arguments

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.add_save_step('output.hdf5')
build(validate=True)[source]

Build and optionally validate pipeline.

Parameters:

validate (bool) – Whether to validate pipeline structure

Return type:

Pipeline

Returns:

Constructed Pipeline instance

Raises:

ValueError – If validation fails

Example

>>> pipeline = builder.build()
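build(validate=True) checks step order and dependencies, but the exact rules are not documented here. The following hypothetical sketch shows that kind of check over the (step_type, kwargs) tuples returned by get_steps(). The step-type strings and the rules (the first step must be 'load'; 'predict' and 'bayesian' need an earlier 'fit') are assumptions made for illustration.

```python
from typing import Any

Step = tuple[str, dict[str, Any]]

# Assumed dependency rule: these step types need an earlier 'fit' step
REQUIRES_FIT = {"predict", "bayesian"}


def validate_steps(steps: list[Step]) -> list[str]:
    """Return ordering problems; an empty list means the order looks valid."""
    if not steps:
        return ["pipeline has no steps"]
    problems = []
    if steps[0][0] != "load":
        problems.append(f"first step must be 'load', got {steps[0][0]!r}")
    fitted = False
    for index, (step_type, _kwargs) in enumerate(steps):
        if step_type == "fit":
            fitted = True
        elif step_type in REQUIRES_FIT and not fitted:
            problems.append(f"step {index} ({step_type!r}) needs an earlier 'fit' step")
    return problems


print(validate_steps([("load", {"file_path": "data.csv"}), ("fit", {"model_name": "maxwell"})]))
print(validate_steps([("fit", {"model_name": "maxwell"}), ("load", {"file_path": "data.csv"})]))
```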
clear()[source]

Clear all steps.

Return type:

PipelineBuilder

Returns:

self for method chaining

Example

>>> builder.clear()
get_steps()[source]

Get current pipeline steps.

Return type:

list[tuple[str, dict[str, Any]]]

Returns:

List of (step_type, kwargs) tuples

Example

>>> steps = builder.get_steps()
__len__()[source]

Get number of steps.

Return type:

int

__repr__()[source]

String representation.

Return type:

str

Description: Programmatic pipeline construction for complex custom workflows.

Example - Basic Builder:

from rheojax.pipeline import PipelineBuilder

# Build custom pipeline
builder = PipelineBuilder()

builder.add_load_step('data.txt', format='auto')
builder.add_transform_step('smooth', method='savgol', window=11)
builder.add_transform_step('fft', window='hann')
builder.add_fit_step('maxwell', initial_params={'G_s': 1e5})
builder.add_plot_step(show=False, save='result.png')
builder.add_save_step('result.hdf5')

# Build and execute
pipeline = builder.build()
results = pipeline.execute()

Example - Conditional Logic:

builder = PipelineBuilder()

builder.add_load_step('data.txt')

# Conditional transform
builder.add_conditional_step(
    condition=lambda state: state['data'].metadata.get('noisy', False),
    true_step=('transform', {'name': 'smooth', 'window': 11}),
    false_step=None  # Skip if not noisy
)

builder.add_fit_step('maxwell')

pipeline = builder.build()
results = pipeline.execute()

Key Methods:

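  • add_load_step(file_path, format='auto', **kwargs): Add data loading step

  • add_transform_step(transform_name, **kwargs): Add transform step

  • add_fit_step(model_name, method='auto', use_jax=True, **kwargs): Add model fitting step

  • add_predict_step(store_as=None, **kwargs): Add prediction step

  • add_bayesian_step(num_warmup=1000, num_samples=2000, num_chains=4, seed=0, warm_start=True, **kwargs): Add NUTS sampling step

  • add_plot_step(show=False, style='default', **kwargs): Add visualization step

  • add_export_step(output_path, format='auto', **kwargs): Add analysis export step

  • add_save_step(file_path, format='hdf5', **kwargs): Add data saving step

  • add_conditional_step(condition, true_step, false_step): Add conditional logic

  • build(validate=True): Build and optionally validate the pipeline

  • execute(): Execute built pipeline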

Batch Processing

class rheojax.pipeline.BatchPipeline(template_pipeline=None)[source]

Bases: object

Apply pipeline to multiple datasets.

This class enables batch processing of multiple data files with the same pipeline configuration, collecting results for analysis.

template_pipeline

Template Pipeline to apply to each dataset

results

List of (file_path, result, metrics) tuples

Example

>>> template = Pipeline().fit('maxwell')
>>> batch = BatchPipeline(template)
>>> batch.process_files(['data1.csv', 'data2.csv'])
__init__(template_pipeline=None)[source]

Initialize batch pipeline.

Parameters:

template_pipeline (Pipeline | None) – Template Pipeline to clone for each file. If None, must be set before processing.

set_template(pipeline)[source]

Set template pipeline.

Parameters:

pipeline (Pipeline) – Pipeline to use as template

Return type:

BatchPipeline

Returns:

self for method chaining

process_files(file_paths, format='auto', parallel=False, parallel_io=True, n_workers=None, **load_kwargs)[source]

Process multiple files with the pipeline.

Parameters:
  • file_paths (Iterable[str | Path]) – List of file paths to process

  • format (str) – File format for loading

  • parallel (bool) – Whether to use parallel processing for the full pipeline. Default False: JAX JIT cache is not thread-safe with concurrent ThreadPoolExecutor. Set True only for I/O-bound pipelines without JAX JIT calls (e.g., loading + simple numpy transforms).

  • parallel_io (bool) – Whether to load files in parallel using threads. Default True: file I/O is thread-safe and benefits from parallelism. Loading phase runs in threads, pipeline replay runs sequentially.

  • n_workers (int | None) – Number of parallel workers (default: min(4, cpu_count))

  • **load_kwargs – Additional arguments for data loading

Return type:

BatchPipeline

Returns:

self for method chaining

Note

During replay, protocol-specific kwargs (gamma_dot, sigma_init, lam_init, sigma_0, lam_0, gamma_0, omega_laos, n_cycles, points_per_cycle) are stripped from the template’s fit kwargs because they are data-dependent and should not be reused across datasets. DMTA kwargs (deformation_mode, poisson_ratio) and solver settings (method) are preserved.
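The filtering described in this note amounts to dropping a fixed set of keys from the template's fit kwargs. The following sketch uses the key names listed above. `replayable_fit_kwargs` is a hypothetical helper, and the example values (including `deformation_mode='tension'`) are illustrative only.

```python
from typing import Any

# Data-dependent protocol kwargs, as listed in the note above
PROTOCOL_KWARGS = frozenset({
    "gamma_dot", "sigma_init", "lam_init", "sigma_0", "lam_0",
    "gamma_0", "omega_laos", "n_cycles", "points_per_cycle",
})


def replayable_fit_kwargs(fit_kwargs: dict[str, Any]) -> dict[str, Any]:
    """Copy of the template's fit kwargs without the data-dependent ones."""
    return {k: v for k, v in fit_kwargs.items() if k not in PROTOCOL_KWARGS}


template_kwargs = {
    "method": "auto",              # solver setting: kept
    "gamma_dot": 1.0,              # protocol-specific: stripped
    "deformation_mode": "tension", # DMTA setting: kept
    "poisson_ratio": 0.5,          # DMTA setting: kept
    "n_cycles": 5,                 # protocol-specific: stripped
}
print(replayable_fit_kwargs(template_kwargs))
```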

Example

>>> batch.process_files(['data1.csv', 'data2.csv'])
>>> # Parallel mode (use with caution — JAX JIT not thread-safe):
>>> batch.process_files(['data1.csv', 'data2.csv'], parallel=True)
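The split between `parallel_io` and `parallel` can be pictured as a two-phase loop. File reads run in a thread pool, and the compute-heavy replay then runs one dataset at a time. The following sketch shows that pattern using only the standard library. `load_file`, `analyze` and `process` are hypothetical stand-ins for the reader and the JAX-backed fit, and this is not rheojax's implementation.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Optional


def load_file(path: Path) -> str:
    """Hypothetical loader standing in for the real file reader (pure I/O)."""
    return path.read_text()


def analyze(text: str) -> int:
    """Hypothetical analysis standing in for the JAX-backed pipeline replay."""
    return len(text.split())


def process(paths: list[Path], n_workers: Optional[int] = None) -> list[tuple[Path, int]]:
    n_workers = n_workers or min(4, os.cpu_count() or 1)
    # Phase 1: reads run in a thread pool; map() preserves input order
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        loaded = list(pool.map(load_file, paths))
    # Phase 2: analysis runs sequentially, one dataset at a time
    return [(path, analyze(text)) for path, text in zip(paths, loaded)]


with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for i in range(3):
        path = Path(tmp) / f"data{i}.csv"
        path.write_text("x " * (i + 1))
        paths.append(path)
    results = process(paths)
    word_counts = [count for _, count in results]
print(word_counts)
```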
process_directory(directory, pattern='*.csv', recursive=False, **kwargs)[source]

Process all files in directory matching pattern.

Parameters:
  • directory (str | Path) – Directory path

  • pattern (str) – File pattern (e.g., ‘*.csv’, ‘*.xlsx’)

  • recursive (bool) – Whether to search recursively

  • **kwargs – Additional arguments passed to process_files

Return type:

BatchPipeline

Returns:

self for method chaining

Example

>>> batch.process_directory('data/', pattern='*.csv')
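File matching for process_directory can be thought of as a glob over the directory, with `recursive=True` extending the search into subdirectories. The following sketch uses pathlib and assumes glob semantics equivalent to `Path.glob` and `Path.rglob`. `matching_files` is hypothetical, and the sorted order is a choice made here, not a documented guarantee.

```python
import tempfile
from pathlib import Path


def matching_files(directory, pattern: str = "*.csv", recursive: bool = False) -> list[Path]:
    """Sorted files in `directory` matching a glob pattern."""
    root = Path(directory)
    candidates = root.rglob(pattern) if recursive else root.glob(pattern)
    return sorted(p for p in candidates if p.is_file())


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "sub").mkdir()
    for name in ("a.csv", "b.xlsx", "sub/c.csv"):
        (root / name).write_text("time,stress\n")
    flat = [p.relative_to(root).as_posix() for p in matching_files(root)]
    deep = [p.relative_to(root).as_posix() for p in matching_files(root, recursive=True)]
print(flat, deep)
```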
get_results()[source]

Get all processing results.

Return type:

list[tuple[Path, RheoData, dict[str, Any]]]

Returns:

List of (file_path, result_data, metrics) tuples

Example

>>> results = batch.get_results()
>>> for path, data, metrics in results:
...     print(f"{path}: R²={metrics.get('r_squared', 0):.4f}")
get_errors()[source]

Get processing errors.

Return type:

list[tuple[Path, Exception]]

Returns:

List of (file_path, exception) tuples

Example

>>> errors = batch.get_errors()
>>> for path, error in errors:
...     print(f"Error in {path}: {error}")
get_summary_dataframe()[source]

Get summary DataFrame of all results.

Return type:

DataFrame

Returns:

DataFrame with file paths and metrics

Example

>>> df = batch.get_summary_dataframe()
>>> print(df)
export_summary(output_path, format='excel')[source]

Export summary of batch results.

Parameters:
  • output_path (str | Path) – Output file path

  • format (str) – Output format (‘excel’, ‘csv’)

Return type:

BatchPipeline

Returns:

self for method chaining

Example

>>> batch.export_summary('summary.xlsx')
apply_filter(filter_fn)[source]

Filter results based on custom criteria.

Parameters:

filter_fn (Callable[[Path, RheoData, dict[str, Any]], bool]) – Function that takes (file_path, data, metrics) and returns True to keep the result

Return type:

BatchPipeline

Returns:

self for method chaining

Example

>>> # Keep only results with R² > 0.9
>>> batch.apply_filter(lambda p, d, m: m.get('r_squared', 0) > 0.9)
get_statistics()[source]

Get statistics across all results.

Return type:

dict[str, Any]

Returns:

Dictionary with summary statistics

Example

>>> stats = batch.get_statistics()
>>> print(f"Mean R²: {stats['mean_r_squared']:.4f}")
clear()[source]

Clear all results and errors.

Return type:

BatchPipeline

Returns:

self for method chaining

__len__()[source]

Get number of processed results.

Return type:

int

__repr__()[source]

String representation.

Return type:

str

Description: Process multiple datasets with the same workflow. File loading can run in parallel threads; by default the pipeline itself is replayed sequentially.

Example - Basic Batch:

from rheojax.pipeline import Pipeline, BatchPipeline

# Define template pipeline
template = (Pipeline()
    .transform('smooth', window=11)
    .fit('maxwell')
    .plot(save='${filename}_fit.png')  # ${filename} replaced per file
    .save('${filename}_results.hdf5'))

# Create batch processor
batch = BatchPipeline(template)

# Process directory
batch.process_directory('data/', pattern='*.txt')

# Get all results
all_results = batch.get_results()

# Export summary
batch.export_summary('batch_summary.xlsx')

Example - Parallel Processing:

# Load files with up to 4 I/O threads; pipeline replay stays sequential
batch = BatchPipeline(template)
batch.process_directory('data/', pattern='*.txt',
                        parallel_io=True, n_workers=4)

# Process specific files
files = ['sample1.txt', 'sample2.txt', 'sample3.txt']
batch.process_files(files, n_workers=4)

Key Methods:

  • process_directory(directory, pattern='*.csv', recursive=False, **kwargs): Process all matching files in a directory

  • process_files(file_paths, ...): Process specific files

  • get_results(): Retrieve (file_path, result_data, metrics) tuples for all files

  • export_summary(output_path, format='excel'): Export comparison table

  • get_errors(): Get (file_path, exception) tuples for failed processing attempts

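Parameters:

  • template_pipeline (Pipeline | None): Template pipeline to apply (passed to BatchPipeline())

  • parallel (bool): Replay the full pipeline in parallel threads (process_files; default: False)

  • parallel_io (bool): Load files in parallel threads (process_files; default: True)

  • n_workers (int | None): Number of worker threads (process_files; default: min(4, cpu_count))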

Error Handling

Pipeline Error Management

pipeline = (Pipeline()
    .load('data.txt')
    .fit('maxwell', fail_on_error=False))  # Don't raise exception

# Check for errors
if pipeline.has_errors():
    errors = pipeline.get_errors()
    print(f"Errors encountered: {errors}")
else:
    results = pipeline.get_results()

Pipeline Validation

Validate before execution:

pipeline = (Pipeline()
    .load('data.txt')
    .fit('maxwell'))

# Validate pipeline
is_valid, messages = pipeline.validate()

if is_valid:
    results = pipeline.execute()
else:
    print(f"Validation failed: {messages}")

Debug Mode

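Enable debug logging and inspect the pipeline state:

import logging

from rheojax.pipeline import Pipeline

# Pipeline() accepts only an optional RheoData, so verbosity is set through
# Python's logging module (assumes rheojax emits standard log records)
logging.basicConfig(level=logging.DEBUG)

pipeline = Pipeline().load('data.txt').fit('maxwell')

# Inspect state through the documented attributes
for operation, details in pipeline.history:
    print(f"{operation}: {details}")
print(f"Data loaded: {pipeline.data is not None}")
print(f"Model fitted: {pipeline._last_model is not None}")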

Best Practices

Method Chaining Style

Recommended (readable, clean):

results = (Pipeline()
    .load('data.txt')
    .transform('smooth', window=11)
    .fit('maxwell')
    .plot(show=True)
    .get_results())

Acceptable (for debugging):

pipeline = Pipeline()
pipeline.load('data.txt')
pipeline.transform('smooth', window=11)
pipeline.fit('maxwell')
pipeline.plot(show=True)
results = pipeline.get_results()

Error Recovery

# Try multiple models until one succeeds
models = ['maxwell', 'zener', 'fractional_maxwell_gel']

for model_name in models:
    try:
        results = (Pipeline()
            .load('data.txt')
            .fit(model_name)
            .get_results())
        print(f"Success with {model_name}")
        break
    except Exception as e:
        print(f"{model_name} failed: {e}")
        continue

Performance Optimization

# Load batch files with parallel I/O threads; pipeline replay stays sequential
batch = BatchPipeline(template)
batch.process_directory('data/', parallel_io=True, n_workers=4)

See Also

  • /user_guide/pipeline_api - Comprehensive pipeline tutorial

  • /user_guide/modular_api - Low-level API for custom control

  • Models API - Model API reference

  • Transforms API - Transform API reference

  • rheojax.core.base.BaseModel - Base model class

  • rheojax.core.base.BaseTransform - Base transform class