Pipeline API
This page documents the high-level Pipeline API for rheological analysis workflows.
Overview
The Pipeline API provides a fluent interface for chaining operations from data loading through model fitting and export. It’s designed for rapid analysis with minimal boilerplate code.
Core Components:
Pipeline: Base fluent API with method chaining
Specialized Workflows: Pre-configured pipelines for common tasks
PipelineBuilder: Programmatic pipeline construction
BatchPipeline: Process multiple datasets
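The chaining style described in the overview comes down to every method returning the object it was called on. The following is a minimal sketch of that pattern only; `MiniPipeline` and its methods are hypothetical stand-ins written for illustration and are not part of rheojax.

```python
from __future__ import annotations


class MiniPipeline:
    """Hypothetical stand-in that mimics the chaining style (not rheojax code)."""

    def __init__(self, data: list[float] | None = None) -> None:
        self.data = data
        self.history: list[tuple[str, str]] = []

    def load(self, values: list[float]) -> "MiniPipeline":
        self.data = list(values)
        self.history.append(("load", f"{len(self.data)} points"))
        return self  # returning self is what makes chaining possible

    def transform(self, scale: float) -> "MiniPipeline":
        if self.data is None:
            raise ValueError("No data loaded")
        self.data = [scale * v for v in self.data]
        self.history.append(("transform", f"scale={scale}"))
        return self

    def get_result(self) -> list[float]:
        if self.data is None:
            raise ValueError("No data loaded")
        return self.data


pipeline = MiniPipeline().load([1.0, 2.0, 3.0]).transform(scale=2.0)
print(pipeline.get_result())
print(pipeline.history)
```

Terminal methods such as `get_result()` return data rather than `self`, so they end a chain.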
Basic Pipeline
- class rheojax.pipeline.Pipeline(data=None)
Bases: object
Fluent API for rheological analysis workflows.
This class provides a chainable interface for loading data, applying transforms, fitting models, and generating outputs. All methods return self to enable method chaining.
- data
Current RheoData state
- steps
List of (operation, object) tuples for fitted models
- history
List of (operation, details) tuples tracking all operations
- _last_model
Last fitted model for convenience
Example
>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell').plot()
- load(file_path, format='auto', *, test_mode=None, initial_test_mode=None, **kwargs)
Load data from file.
- Parameters:
format (str) – File format (‘auto’, ‘csv’, ‘excel’, ‘trios’, ‘hdf5’)
test_mode (str | None) – Optional rheological mode metadata to attach to the resulting RheoData (e.g., ‘relaxation’, ‘creep’, ‘oscillation’)
initial_test_mode (str | None) – Backwards-compatible alias for test_mode
**kwargs – Additional arguments passed to reader
- Returns:
self for method chaining
- Raises:
FileNotFoundError – If file doesn’t exist
ValueError – If file format not recognized
Example
>>> pipeline = Pipeline().load('data.csv', x_col='time', y_col='stress')
- transform(transform, **kwargs)
Apply a transform to the data.
- Parameters:
transform (str | BaseTransform) – Transform name (string) or Transform instance
**kwargs – Arguments passed to transform constructor (if string)
- Returns:
self for method chaining
- Raises:
ValueError – If data not loaded or transform not found
Example
>>> pipeline.transform('smooth', window_size=5)
>>> # or with instance
>>> from rheojax.transforms import SmoothTransform
>>> pipeline.transform(SmoothTransform(window_size=5))
- fit(model, method='auto', **fit_kwargs)
Fit a model to the data.
- Returns:
self for method chaining
- Raises:
ValueError – If data not loaded or model not found
Example
>>> pipeline.fit('maxwell')
>>> # or with instance
>>> from rheojax.models.linear import Maxwell
>>> pipeline.fit(Maxwell())
- predict(model=None, X=None)
Generate predictions from fitted model.
- Returns:
RheoData with predictions
- Raises:
ValueError – If no model has been fitted
Example
>>> predictions = pipeline.predict()
- plot(show=True, style='default', include_prediction=False, **plot_kwargs)
Plot current data state.
- Returns:
self for method chaining
Example
>>> pipeline.plot(style='publication')
- save(file_path, format='hdf5', **kwargs)
Save current data to file.
- Returns:
self for method chaining
Example
>>> pipeline.save('output.hdf5')
- save_figure(filepath, format=None, dpi=300, **kwargs)
Save the most recent plot to file.
Convenience method for exporting plots with publication-quality defaults. Wraps rheojax.visualization.plotter.save_figure() to enable fluent API chaining.
- Parameters:
filepath (str | Path) – Output file path. Format inferred from extension if not specified.
format (str | None) – Output format (‘pdf’, ‘svg’, ‘png’, ‘eps’). If None, inferred from filepath.
dpi (int) – Resolution for raster formats (PNG).
**kwargs (Any) – Additional arguments passed to save_figure(). See rheojax.visualization.plotter.save_figure() for details.
- Returns:
self, to enable method chaining
- Raises:
ValueError – If no plot exists (plot() not called yet)
ValueError – If format cannot be inferred or is unsupported
OSError – If filepath directory doesn’t exist
Examples
Basic usage with method chaining:
>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell').plot().save_figure('result.pdf')
Save multiple formats:
>>> pipeline.plot(style='publication')
>>> pipeline.save_figure('figure.pdf')
>>> pipeline.save_figure('figure.png', dpi=600)
>>> pipeline.save_figure('figure.svg', transparent=True)
Explicit format:
>>> pipeline.plot().save_figure('output', format='pdf')
See also
plot – Generate plot with automatic type selection
rheojax.visualization.plotter.save_figure – Core export function
Notes
This method saves the most recent plot generated by plot(). If you call plot() multiple times, only the last figure is saved. To save multiple plots, call save_figure() after each plot() call.
The figure is stored internally by plot() and retrieved by save_figure().
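The format rules for save_figure() (an explicit format wins, otherwise the extension decides, and an unresolvable format raises ValueError) can be sketched as below. `infer_figure_format` is a hypothetical helper written for this page, not the rheojax implementation; only the list of formats is taken from the parameter description above.

```python
from pathlib import Path

# Formats taken from the `format` parameter description above.
SUPPORTED_FORMATS = {"pdf", "svg", "png", "eps"}


def infer_figure_format(filepath, format=None):
    """Hypothetical helper: resolve the output format as the text describes."""
    # An explicit format wins; otherwise fall back to the file extension.
    resolved = format if format is not None else Path(filepath).suffix.lstrip(".")
    resolved = resolved.lower()
    if resolved not in SUPPORTED_FORMATS:
        raise ValueError(f"Cannot infer a supported format from {filepath!r}")
    return resolved


print(infer_figure_format("result.pdf"))
print(infer_figure_format("output", format="pdf"))
```

Under this reading, a bare name such as 'output' with no format argument would fall into the ValueError case listed under Raises.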
- fit_bayesian(model=None, seed=None, **bayesian_kwargs)
Run Bayesian (NUTS) inference on current data.
Uses the last fitted model (or a new one) with NLSQ warm-start.
- Returns:
self for method chaining
Example
>>> pipeline.fit('maxwell').fit_bayesian(seed=42, num_warmup=1000)
- plot_fit(confidence=0.95, show_residuals=True, show_uncertainty=True, show=True, style='default', **kwargs)
Plot NLSQ fit with uncertainty band and residuals.
Requires a prior call to fit(). Uses FitPlotter internally.
- Parameters:
confidence (float) – Confidence level for uncertainty band (default: 0.95).
show_residuals (bool) – If True, add residuals subplot.
show_uncertainty (bool) – If True and covariance available, show band.
show (bool) – Whether to call plt.show() (default: True).
style (str) – Plot style (‘default’, ‘publication’, ‘presentation’).
**kwargs – Additional arguments forwarded to FitPlotter.plot_nlsq().
- Returns:
self for method chaining
Example
>>> pipeline.fit('maxwell').plot_fit(confidence=0.95)
- plot_bayesian(credible_level=0.95, max_draws=500, show_nlsq_overlay=False, show_residuals=False, show=True, style='default', **kwargs)
Plot Bayesian posterior predictive with credible interval.
Requires a prior call to fit_bayesian().
- Parameters:
credible_level (float) – Credible interval level (default: 0.95).
max_draws (int) – Maximum posterior draws for band computation.
show_nlsq_overlay (bool) – If True, overlay NLSQ fit for comparison.
show_residuals (bool) – If True, add residuals subplot.
show (bool) – Whether to call plt.show() (default: True).
style (str) – Plot style.
**kwargs – Additional arguments forwarded to FitPlotter.plot_bayesian().
- Returns:
self for method chaining
Example
>>> pipeline.fit('maxwell').fit_bayesian(seed=42).plot_bayesian()
- plot_diagnostics(output_dir=None, style='default', prefix='mcmc', formats=('pdf', 'png'), dpi=300, **kwargs)
Generate ArviZ MCMC diagnostic suite (6 plots).
Requires a prior call to fit_bayesian().
- Parameters:
output_dir (str | Path | None) – Directory for saving plots. If None, displays only.
style (str) – Plot style.
prefix (str) – Filename prefix for saved plots.
formats (tuple[str, ...]) – Output formats (default: (‘pdf’, ‘png’)).
dpi (int) – Resolution for raster formats.
**kwargs – Additional arguments forwarded to generate_diagnostic_suite().
- Returns:
self for method chaining
Example
>>> pipeline.fit_bayesian(seed=42).plot_diagnostics(output_dir='./diag')
- plot_transform(transform_name=None, show_intermediate=True, show=True, style='default', **kwargs)
Plot the result of a previously applied transform.
Uses TransformPlotter for per-transform layout dispatch.
- Parameters:
transform_name (str | None) – Name of the transform to plot. If None, uses the most recently applied transform.
show_intermediate (bool) – Whether to show before/after comparison.
show (bool) – Whether to call plt.show() (default: True).
style (str) – Plot style.
**kwargs – Additional arguments forwarded to TransformPlotter.
- Returns:
self for method chaining
Example
>>> pipeline.transform('mastercurve', reference_temp=25.0).plot_transform()
- get_result()
Get current data state.
- Returns:
Current RheoData
Example
>>> data = pipeline.get_result()
- get_history()
Get pipeline execution history.
Example
>>> history = pipeline.get_history()
>>> for step in history:
...     print(step)
- get_last_model()
Get the last fitted model.
- Return type:
BaseModel | None
- Returns:
Last fitted BaseModel or None
Example
>>> model = pipeline.get_last_model()
>>> params = model.get_params()
- get_all_models()
Get all fitted models from pipeline.
- Return type:
list[BaseModel]
- Returns:
List of all fitted models
Example
>>> models = pipeline.get_all_models()
- get_fitted_parameters()
Get fitted parameters from the last model as a dictionary.
This is a convenience method that extracts parameter values from the last fitted model’s ParameterSet.
- Returns:
Dictionary mapping parameter names to their fitted values
- Raises:
ValueError – If no model has been fitted yet
Example
>>> pipeline = Pipeline()
>>> pipeline.load('data.csv').fit('maxwell')
>>> params = pipeline.get_fitted_parameters()
>>> print(params)  # {'G0': 100000.0, 'eta': 1000.0}
>>> G0 = params['G0']
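Once the parameters are available as a plain dictionary, derived quantities are simple arithmetic. The sketch below assumes that `G0` is the Maxwell modulus and `eta` the viscosity, in Pa and Pa·s respectively (the units are an assumption), and reuses the illustrative values from the example output above. The function `maxwell_relaxation_modulus` is defined here for illustration and is not a rheojax call.

```python
import math

# Illustrative values copied from the example output above.
params = {"G0": 100000.0, "eta": 1000.0}

# Maxwell relaxation time: tau = eta / G0
tau = params["eta"] / params["G0"]


def maxwell_relaxation_modulus(t, G0, eta):
    """Single-mode Maxwell relaxation modulus G(t) = G0 * exp(-t / tau)."""
    return G0 * math.exp(-t * G0 / eta)


print(f"tau = {tau:.3g} s")
print(f"G(tau) = {maxwell_relaxation_modulus(tau, **params):.4g} Pa")
```

At t = tau the modulus has decayed to G0 / e, which is a quick sanity check on fitted values.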
- compare_models(models, criterion='aic', **fit_kwargs)
Compare multiple models on the current data.
Fits each model and ranks by information criterion. The best model becomes _last_model and is appended to steps.
- Returns:
self for method chaining
- Raises:
ValueError – If no data is loaded.
Example
>>> pipeline.load('data.csv').compare_models(['maxwell', 'zener'])
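To see how an information criterion trades goodness of fit against parameter count, the sketch below uses the standard least-squares form of AIC, n·ln(RSS/n) + 2k, which holds up to an additive constant under Gaussian errors. The residual sums of squares are made-up numbers, and whether rheojax computes AIC in exactly this form is not stated on this page.

```python
import math


def aic_least_squares(rss, n_points, n_params):
    """AIC for a least-squares fit with Gaussian errors, up to a constant."""
    return n_points * math.log(rss / n_points) + 2 * n_params


# Hypothetical residual sums of squares for two fits to the same 50 points.
candidates = {
    "maxwell": {"rss": 4.0, "n_params": 2},
    "zener": {"rss": 3.9, "n_params": 3},
}

scores = {
    name: aic_least_squares(fit["rss"], 50, fit["n_params"])
    for name, fit in candidates.items()
}
best = min(scores, key=scores.get)  # lower AIC is better
print(best, scores)
```

Here the three-parameter model lowers the residual only slightly, so the penalty for its extra parameter outweighs the gain and the simpler model ranks first.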
- get_fit_result()
Construct a FitResult from the last fitted model.
- Returns:
FitResult with model metadata, fitted parameters, and statistics.
- Raises:
ValueError – If no model has been fitted.
Example
>>> result = pipeline.load('data.csv').fit('maxwell').get_fit_result()
>>> print(result.summary())
- clone()
Create a copy of the pipeline.
- Returns:
New Pipeline with copied data and history
Example
>>> pipeline2 = pipeline.clone()
- reset()
Reset pipeline to initial state.
- Returns:
self for method chaining
Example
>>> pipeline.reset()
- export(output, format='auto', *, include_data=True, include_figures=True, include_diagnostics=True, figure_formats=('pdf', 'png'), figure_dpi=300, **kwargs)
Export the full analysis to a directory or file.
This bundles data, parameters, statistics, figures, transform results, and Bayesian diagnostics into a single export.
- Parameters:
output (str | Path) – Output path. If a directory (no extension or trailing /), exports as structured directory. If .xlsx, exports Excel.
format (str) – Export format (‘auto’, ‘directory’, ‘excel’). ‘auto’ infers from the output path extension.
include_data (bool) – Save raw and transformed data files.
include_figures (bool) – Save generated matplotlib figures.
include_diagnostics (bool) – Save MCMC diagnostic plots.
figure_formats (tuple[str, ...]) – Formats for figure files (default: (‘pdf’, ‘png’)).
figure_dpi (int) – Resolution for raster figures (default: 300).
**kwargs – Additional arguments forwarded to the exporter.
- Returns:
self for method chaining
Example
>>> pipeline.load('data.csv').fit('maxwell').plot_fit().export('./results')
>>> pipeline.export('report.xlsx')
Description: Core pipeline class providing fluent method chaining for rheological analysis workflows.
Example - Basic Usage:
from rheojax.pipeline import Pipeline
# Create pipeline and chain operations
results = (Pipeline()
    .load('data.txt')                  # Load data
    .transform('smooth', window=11)    # Smooth noisy data
    .fit('maxwell')                    # Fit model
    .plot(show=True)                   # Visualize
    .get_results())                    # Retrieve results
print(f"R^2 = {results['r2']:.4f}")
print(f"Parameters: {results['parameters']}")
Key Methods:
load(source, format='auto', **kwargs)
Load data from file or RheoData object.
Parameters:
source (str or RheoData): File path or data object
format (str): File format - ‘auto’, ‘trios’, ‘csv’, ‘excel’
**kwargs: Format-specific arguments
Returns: self (for chaining)
Example:
# Auto-detect format
pipeline = Pipeline().load('data.txt')
# Explicit format
pipeline = Pipeline().load('data.csv', format='csv',
                           x_col='frequency', y_col='modulus')
# From RheoData object
from rheojax.core import RheoData
data = RheoData(x=freq, y=modulus, ...)
pipeline = Pipeline().load(data)
transform(name, **params)
Apply data transform.
Parameters:
name (str): Transform name - ‘smooth’, ‘fft’, ‘mastercurve’, etc.
**params: Transform-specific parameters
Returns: self (for chaining)
Example:
# Single transform
pipeline = (Pipeline()
    .load('data.txt')
    .transform('smooth', method='savgol', window=11))
# Multiple transforms (chained)
pipeline = (Pipeline()
    .load('data.txt')
    .transform('smooth', window=11)
    .transform('fft', window='hann'))
fit(model, initial_params=None, bounds=None, **kwargs)
Fit rheological model to data.
Parameters:
model (str or BaseModel): Model name or instance
initial_params (dict, optional): Initial parameter values
bounds (dict, optional): Parameter bounds
**kwargs: Optimization options
Returns: self (for chaining)
Example:
# By name
pipeline = Pipeline().load('data.txt').fit('maxwell')
# With initial parameters
pipeline = (Pipeline()
    .load('data.txt')
    .fit('maxwell',
         initial_params={'G_s': 1e5, 'eta_s': 1e3},
         bounds={'G_s': (1e3, 1e7), 'eta_s': (1e1, 1e5)}))
# Multiple models (comparison)
pipeline = (Pipeline()
    .load('data.txt')
    .fit(['maxwell', 'zener', 'springpot']))
plot(show=False, save=None, style='default', **kwargs)
Visualize data and model fit.
Parameters:
show (bool): Display interactive plot
save (str, optional): Save to file
style (str): Plot style - ‘default’, ‘publication’, ‘presentation’
**kwargs: Plotting options
Returns: self (for chaining)
Example:
# Show plot
pipeline.plot(show=True)
# Save to file
pipeline.plot(save='fit_result.png', dpi=300)
# Custom style
pipeline.plot(show=True, style='publication',
              include_residuals=True, title='Maxwell Fit')
save(filepath, format='hdf5', **kwargs)
Export results to file.
Parameters:
filepath (str): Output file path
format (str): Format - ‘hdf5’, ‘excel’, ‘csv’
**kwargs: Format-specific options
Returns: self (for chaining)
Example:
# HDF5 (full fidelity)
pipeline.save('results.hdf5')
# Excel report
pipeline.save('report.xlsx', format='excel', include_plots=True)
get_results()
Retrieve analysis results as dictionary.
Returns: dict with keys:
'parameters': Fitted parameter values
'r2': R^2 score
'rmse': Root mean squared error
'predictions': Model predictions
'residuals': Fit residuals
'data': Original RheoData object
'model': Fitted model instance
Example:
results = pipeline.get_results()
print(f"R^2 = {results['r2']:.4f}")
print(f"RMSE = {results['rmse']:.2e}")
print("Parameters:")
for name, value in results['parameters'].items():
    print(f"  {name} = {value:.4e}")
Specialized Workflows
MastercurvePipeline
- class rheojax.pipeline.MastercurvePipeline(reference_temp=298.15)
Bases: Pipeline
Pipeline for time-temperature superposition analysis.
This pipeline automates the construction of mastercurves from multi-temperature rheological data using horizontal shift factors.
- reference_temp
Reference temperature for mastercurve
- shift_factors
Dictionary of temperature -> shift factor
Example
>>> pipeline = MastercurvePipeline(reference_temp=298.15)
>>> pipeline.run(file_paths, temperatures)
>>> mastercurve = pipeline.get_result()
- __init__(reference_temp=298.15)
Initialize mastercurve pipeline.
- Parameters:
reference_temp (float) – Reference temperature in Kelvin (default: 298.15 K)
- run(file_paths, temperatures, format='auto', parallel_io=True, **load_kwargs)
Execute mastercurve workflow.
- Parameters:
file_paths (list[str]) – List of data file paths (one per temperature)
temperatures (list[float]) – List of temperatures (in Kelvin)
format (str) – File format for loading
parallel_io (bool) – Whether to load files in parallel (default True)
**load_kwargs – Additional arguments passed to load (e.g., x_col, y_col)
- Returns:
self for method chaining
- Raises:
ValueError – If file_paths and temperatures have different lengths
Description: Pre-configured pipeline for time-temperature superposition analysis.
Example:
from rheojax.pipeline import MastercurvePipeline
# Create mastercurve pipeline
mc_pipeline = MastercurvePipeline(
    reference_temp=50,   # Reference temperature (degC)
    method='wlf',        # 'wlf' or 'arrhenius'
    optimize=True        # Optimize WLF/Arrhenius parameters
)
# Load and process multi-temperature data
files = ['data_25C.txt', 'data_40C.txt', 'data_55C.txt', 'data_70C.txt']
temperatures = [25, 40, 55, 70]
results = mc_pipeline.run(files, temperatures)
# Access mastercurve results
mastercurve = results['mastercurve']
shift_factors = results['shift_factors']
wlf_params = results['wlf_parameters']
print(f"WLF C1 = {wlf_params['C1']:.2f}")
print(f"WLF C2 = {wlf_params['C2']:.2f} K")
# Fit model to mastercurve
mc_pipeline.fit('fractional_maxwell_gel')
mc_pipeline.plot(show=True, style='publication')
Key Methods:
run(files, temperatures): Create mastercurve from files
fit(model): Fit model to mastercurve
get_shift_factors(): Get temperature shift factors
get_wlf_parameters(): Get fitted WLF C1, C2
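For orientation, the WLF relation behind the C1 and C2 parameters is log10(a_T) = -C1 (T - Tref) / (C2 + T - Tref). The sketch below evaluates it directly; the constants and the reference temperature are illustrative assumptions, and `wlf_log10_shift` is a hypothetical helper, not a rheojax function.

```python
def wlf_log10_shift(temp, ref_temp, c1, c2):
    """WLF equation: log10(a_T) = -C1 (T - Tref) / (C2 + T - Tref)."""
    return c1 * (ref_temp - temp) / (c2 + (temp - ref_temp))


# Hypothetical WLF constants, chosen only for illustration.
C1, C2 = 8.86, 101.6
reference_temp = 323.15  # K

for temp in (298.15, 313.15, 323.15, 343.15):
    log_a_t = wlf_log10_shift(temp, reference_temp, C1, C2)
    print(f"T = {temp:.2f} K  log10(a_T) = {log_a_t:+.3f}")
```

The shift is zero at the reference temperature, positive below it (slower relaxation) and negative above it. Because only temperature differences enter the formula, Kelvin and degrees Celsius give the same shift as long as both temperatures use the same unit.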
ModelComparisonPipeline
- class rheojax.pipeline.ModelComparisonPipeline(models)
Bases: Pipeline
Pipeline for comparing multiple models on the same data.
This pipeline fits multiple models to the same dataset and computes comparison metrics (RMSE, R², AIC, etc.).
- models
List of model names to compare
- results
Dictionary of model_name -> metrics
Example
>>> pipeline = ModelComparisonPipeline(['maxwell', 'zener', 'springpot'])
>>> pipeline.run(data)
>>> best = pipeline.get_best_model()
>>> print(pipeline.get_comparison_table())
- run(data, parallel=False, n_workers=None, **fit_kwargs)
Fit multiple models and compare.
- Returns:
self for method chaining
- get_best_model(metric='rmse', minimize=True)
Return name of best-fitting model.
- Returns:
Name of best model
Example
>>> best = pipeline.get_best_model(metric='aic')
- get_comparison_table()
Get comparison table of all models.
Example
>>> table = pipeline.get_comparison_table()
>>> for model, metrics in table.items():
...     print(f"{model}: R²={metrics['r_squared']:.4f}")
Description: Systematically compare multiple models on the same dataset.
Example:
from rheojax.pipeline import ModelComparisonPipeline
# Models to compare
models = ['maxwell', 'zener', 'fractional_maxwell_gel',
          'fractional_kelvin_voigt', 'springpot']
# Create comparison pipeline
comparison = ModelComparisonPipeline(models)
# Load data and run comparison
comparison.load('data.txt')
comparison.run()
# Get comparison table
results = comparison.get_results()
comparison_table = results['comparison']
# Print comparison
print("\nModel Comparison:")
print(f"{'Model':<30} {'R^2':<10} {'RMSE':<12} {'AIC':<12}")
print("-" * 70)
for row in comparison_table:
    print(f"{row['model']:<30} {row['r2']:<10.4f} "
          f"{row['rmse']:<12.2e} {row['aic']:<12.1f}")
# Get best model (get_best_model takes `metric` and returns the model name)
best = comparison.get_best_model(metric='aic')
print(f"\nBest model (AIC): {best}")
# Visualize comparison
comparison.plot_comparison(show=True)
comparison.plot_ranking(criterion='aic', show=True)
Key Methods:
run(): Fit all models
get_best_model(metric): Select best by AIC, BIC, or R^2
plot_comparison(): Multi-panel plot of all models
plot_ranking(): Bar chart ranking by criterion
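The comparison metrics named above follow their usual definitions. A minimal sketch with made-up observations and predictions is shown below. The dictionary layout mirrors the `metrics['r_squared']` access in the get_comparison_table() example, but the rest of the structure is an assumption, not the library's actual return value.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error between observations and predictions."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


# Made-up observations and two made-up sets of model predictions.
observed = [1.0, 2.0, 3.0, 4.0]
predictions = {
    "model_a": [1.1, 1.9, 3.2, 3.8],
    "model_b": [1.5, 1.5, 3.5, 3.5],
}
table = {
    name: {"rmse": rmse(observed, pred), "r_squared": r_squared(observed, pred)}
    for name, pred in predictions.items()
}
best = min(table, key=lambda name: table[name]["rmse"])  # lower RMSE wins
print(best, table)
```

Note the direction of each metric: RMSE and AIC are minimized while R² is maximized, which is presumably why get_best_model() exposes a `minimize` flag.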
CreepToRelaxationPipeline
- class rheojax.pipeline.CreepToRelaxationPipeline(data=None)
Bases: Pipeline
Convert creep compliance data to relaxation modulus.
This pipeline performs the numerical conversion from J(t) to G(t) using regularized numerical inversion techniques.
Example
>>> pipeline = CreepToRelaxationPipeline()
>>> pipeline.run(creep_data)
>>> relaxation_data = pipeline.get_result()
- run(creep_data, method='approximate')
Execute conversion workflow.
- Parameters:
- Returns:
self for method chaining
- Raises:
ValueError – If input is not creep data
Description: Convert creep compliance J(t) to relaxation modulus G(t).
Example:
from rheojax.pipeline import CreepToRelaxationPipeline
converter = CreepToRelaxationPipeline(
    method='integration',   # 'integration' or 'approximate'
    regularization=0.01     # Regularization parameter
)
converter.load('creep_data.txt')
relaxation_data = converter.convert()
# Fit model to relaxation data
converter.fit('maxwell')
converter.plot(show=True)
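In linear viscoelasticity, J(t) and G(t) are tied together by the convolution identity ∫₀ᵗ G(t − s) J(s) ds = t, and any J(t) to G(t) conversion amounts to inverting this relation numerically. The sketch below only checks the identity for a single-mode Maxwell material, where both functions are known in closed form. The material values are assumptions, and the code says nothing about which inversion scheme rheojax uses.

```python
import math

# Hypothetical single-mode Maxwell material (values chosen for illustration).
G0, TAU = 1.0e5, 0.01  # modulus in Pa, relaxation time in s


def relaxation_modulus(t):
    """Maxwell G(t) = G0 * exp(-t / tau)."""
    return G0 * math.exp(-t / TAU)


def creep_compliance(t):
    """Maxwell J(t) = (1 + t / tau) / G0."""
    return (1.0 + t / TAU) / G0


def convolution(t, n=2000):
    """Trapezoidal estimate of the integral of G(t - s) * J(s) over [0, t]."""
    h = t / n
    total = 0.5 * (relaxation_modulus(t) * creep_compliance(0.0)
                   + relaxation_modulus(0.0) * creep_compliance(t))
    for i in range(1, n):
        s = i * h
        total += relaxation_modulus(t - s) * creep_compliance(s)
    return total * h


t = 0.05
print(convolution(t), t)  # the two values should agree closely
```

A converted G(t) can be validated the same way: convolve it with the measured J(t) and check how closely the result follows t.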
FrequencyToTimePipeline
- class rheojax.pipeline.FrequencyToTimePipeline(data=None)
Bases: Pipeline
Convert frequency domain data to time domain.
This pipeline converts dynamic modulus G*(ω) to relaxation modulus G(t) using Fourier transform techniques.
Example
>>> pipeline = FrequencyToTimePipeline()
>>> pipeline.run(frequency_data)
>>> time_data = pipeline.get_result()
Description: Convert frequency-domain data to time-domain via inverse FFT.
Example:
from rheojax.pipeline import FrequencyToTimePipeline
ft_pipeline = FrequencyToTimePipeline(
    method='inverse_fft',      # 'inverse_fft' or 'analytical'
    time_range=(1e-3, 1e3),    # Time range (s)
    n_points=200               # Number of time points
)
ft_pipeline.load('frequency_sweep.txt')
time_data = ft_pipeline.convert()
ft_pipeline.plot(show=True)
BayesianPipeline
- class rheojax.pipeline.bayesian.BayesianPipeline(data=None)
Bases: Pipeline
Specialized pipeline for Bayesian rheological analysis workflows.
This class extends the base Pipeline to provide a fluent API for the NLSQ → NumPyro NUTS workflow. It supports:
- NLSQ optimization for fast point estimation
- Bayesian inference with automatic warm-start from NLSQ
- Convergence diagnostics (R-hat, ESS, divergences)
- Posterior visualization (distributions and trace plots)
All methods return self to enable method chaining.
- data
Current RheoData state (inherited from Pipeline)
- _last_model
Last fitted model (inherited from Pipeline)
- _nlsq_result
Stored NLSQ optimization result
- _bayesian_result
Stored Bayesian inference result
- _diagnostics
Stored convergence diagnostics
Example
>>> pipeline = BayesianPipeline()
>>> pipeline.load('data.csv') \
...     .fit_nlsq('maxwell') \
...     .fit_bayesian(num_samples=2000) \
...     .plot_posterior() \
...     .save('results.hdf5')
- __init__(data=None)
Initialize Bayesian pipeline.
- Parameters:
data – Optional initial RheoData. If None, must call load() first.
- fit_nlsq(model, **nlsq_kwargs)
Fit model using NLSQ optimization for point estimation.
This method performs fast GPU-accelerated nonlinear least squares optimization to obtain point estimates of model parameters. The optimization result is stored for potential warm-starting of Bayesian inference.
- Parameters:
model (str | BaseModel) – Model name (string) or Model instance to fit
**nlsq_kwargs – Additional arguments passed to NLSQ optimizer (e.g., max_iter, ftol, xtol, gtol)
- Returns:
self for method chaining
- Raises:
ValueError – If data not loaded
Note
This method writes resolved deformation_mode, poisson_ratio, and test_mode back to self.data.metadata so that a subsequent fit_bayesian() call inherits these settings without the caller having to repeat them.
Example
>>> pipeline.fit_nlsq('maxwell')
>>> # or with instance
>>> from rheojax.models import Maxwell
>>> pipeline.fit_nlsq(Maxwell(), max_iter=1000)
- fit_bayesian(num_samples=2000, num_warmup=1000, num_chains=4, **nuts_kwargs)
Perform Bayesian inference using NumPyro NUTS sampler.
This method runs NUTS (No-U-Turn Sampler) for Bayesian parameter estimation. If a model has been previously fitted with fit_nlsq(), the NLSQ point estimates are automatically used for warm-starting the sampler, leading to faster convergence.
Multi-chain sampling is enabled by default (num_chains=4) to provide reliable convergence diagnostics (R-hat, ESS) and parallel execution on multi-GPU systems.
- Parameters:
num_samples (int) – Number of posterior samples per chain (default: 2000)
num_warmup (int) – Number of warmup/burn-in iterations (default: 1000)
num_chains (int) – Number of MCMC chains (default: 4). Multiple chains enable proper R-hat computation and parallel execution. Chain method is auto-selected: ‘parallel’ on multi-GPU, ‘vectorized’ on single GPU/CPU.
**nuts_kwargs – Additional arguments passed to NUTS sampler (e.g., target_accept_prob, chain_method)
- Returns:
self for method chaining
- Raises:
ValueError – If no model has been fitted with fit_nlsq()
Example
>>> pipeline.fit_nlsq('maxwell').fit_bayesian(num_samples=2000)
>>> # With custom NUTS parameters
>>> pipeline.fit_bayesian(
...     num_samples=3000,
...     num_warmup=1500,
...     num_chains=4,
...     target_accept_prob=0.9
... )
- get_diagnostics()
Get convergence diagnostics from Bayesian inference.
Returns diagnostics including R-hat (Gelman-Rubin statistic), effective sample size (ESS), and number of divergent transitions.
- Returns:
Dictionary with the keys:
r_hat: R-hat for each parameter (dict)
ess: Effective sample size for each parameter (dict)
divergences: Number of divergent transitions (int)
- Raises:
ValueError – If Bayesian inference has not been run
Example
>>> diagnostics = pipeline.get_diagnostics()
>>> print(f"R-hat: {diagnostics['r_hat']}")
>>> print(f"ESS: {diagnostics['ess']}")
>>> print(f"Divergences: {diagnostics['divergences']}")
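To build intuition for the r_hat values, the sketch below computes the classic Gelman-Rubin statistic by comparing the between-chain and within-chain variance of a single parameter. It is a simplified, non-split form written for illustration; ArviZ's default is a rank-normalized split variant, so library values will differ in detail. The chains are made-up numbers.

```python
import math
from statistics import mean, variance


def gelman_rubin(chains):
    """Classic (non-split) R-hat for equal-length chains of one parameter."""
    n = len(chains[0])
    within = mean(variance(chain) for chain in chains)         # W
    between = n * variance([mean(chain) for chain in chains])  # B
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)


# Two tiny, made-up chains that agree, and two that sit in different regions.
agreeing = [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0]]
disagreeing = [[0.0, 1.0, 2.0, 3.0], [10.0, 11.0, 12.0, 13.0]]

print(gelman_rubin(agreeing))
print(gelman_rubin(disagreeing))
```

Values close to 1 mean the chains are statistically indistinguishable. Values well above 1, as for the second pair, mean the chains have not mixed and the posterior summary should not be trusted yet.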
- get_posterior_summary()
Get formatted posterior summary statistics.
Returns a pandas DataFrame with summary statistics for each parameter including mean, standard deviation, median, and quantiles (5%, 25%, 75%, 95%).
- Return type:
DataFrame
- Returns:
DataFrame with parameters as rows and statistics as columns
- Raises:
ValueError – If Bayesian inference has not been run
Example
>>> summary = pipeline.get_posterior_summary()
>>> print(summary)
    mean    std  median    q05    q25    q75    q95
a  5.123  0.245   5.110  4.721  4.962  5.285  5.531
b  0.487  0.032   0.485  0.435  0.465  0.509  0.542
- plot_posterior(param_name=None, show=True, **plot_kwargs)
Plot posterior distributions.
Generates histogram plots of posterior distributions for model parameters. If param_name is None, plots all parameters in separate subplots.
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
Example
>>> # Plot all parameters
>>> pipeline.plot_posterior()
>>> # Plot specific parameter
>>> pipeline.plot_posterior('a', bins=50, alpha=0.7)
>>> # Plot without showing (for save_figure)
>>> pipeline.plot_posterior(show=False).save_figure('posterior.pdf')
- plot_trace(param_name=None, show=True, **plot_kwargs)
Plot MCMC trace plots.
Generates trace plots showing parameter values across MCMC iterations. Useful for diagnosing convergence issues. If param_name is None, plots all parameters.
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
Example
>>> # Plot all trace plots
>>> pipeline.plot_trace()
>>> # Plot specific parameter
>>> pipeline.plot_trace('a', alpha=0.5)
>>> # Plot without showing (for save_figure)
>>> pipeline.plot_trace(show=False).save_figure('trace.pdf')
- plot_pair(var_names=None, kind='scatter', divergences=True, show=True, **plot_kwargs)
Plot pairwise relationships between parameters (pair plot).
Creates a matrix of scatter or KDE plots showing correlations between parameters. This is critical for identifying parameter dependencies, non-identifiability issues, and understanding the joint posterior structure. Divergent transitions are highlighted by default to identify problematic posterior geometry.
- Parameters:
var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)
kind (str) – Type of pair plot - “scatter”, “kde”, or “hexbin” (default: “scatter”)
divergences (bool) – Whether to highlight divergent transitions in red (default: True). Useful for identifying problematic regions.
show (bool) – Whether to call plt.show() (default: True)
**plot_kwargs – Additional arguments passed to arviz.plot_pair() (e.g., marginals, point_estimate_marker_style)
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
ImportError – If arviz is not installed
Example
>>> # Plot all parameters with divergences highlighted
>>> pipeline.plot_pair()
>>>
>>> # Plot specific parameters as KDE
>>> pipeline.plot_pair(var_names=["G0", "eta"], kind="kde")
>>>
>>> # Save without showing
>>> pipeline.plot_pair(show=False).save_figure("pair.pdf")
Note
Pair plots are essential for diagnosing:
- Parameter correlations (indicates non-identifiability)
- Funnel geometry (divergences concentrated in specific regions)
- Multimodal posteriors (multiple clusters)
- plot_forest(var_names=None, combined=True, hdi_prob=0.95, show=True, **plot_kwargs)
Plot forest plot with credible intervals for parameters.
Creates a forest plot showing parameter estimates with credible intervals (highest density intervals). Excellent for comparing parameter magnitudes and uncertainties at a glance. Each parameter is shown as a point estimate with error bars representing the credible interval.
- Parameters:
var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)
combined (bool) – Whether to combine multiple chains (default: True)
hdi_prob (float) – Probability mass for credible interval (default: 0.95). Common values: 0.68 (1σ), 0.95 (2σ), 0.997 (3σ)
show (bool) – Whether to call plt.show() (default: True)
**plot_kwargs – Additional arguments passed to arviz.plot_forest() (e.g., rope, ref_val, colors)
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
ImportError – If arviz is not installed
Example
>>> # Plot all parameters with 95% CI
>>> pipeline.plot_forest()
>>>
>>> # Plot specific parameters with 68% CI
>>> pipeline.plot_forest(var_names=["G0", "eta"], hdi_prob=0.68)
>>>
>>> # Save without showing
>>> pipeline.plot_forest(show=False).save_figure("forest.pdf")
Note
Forest plots are useful for:
- Quickly comparing parameter magnitudes
- Assessing parameter uncertainty
- Identifying parameters with poor estimation (wide intervals)
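The sigma labels attached to the common hdi_prob values are rounded figures for a normal distribution. The short sketch below computes the exact coverage using only the standard library; `gaussian_coverage` is a helper defined here for illustration.

```python
import math


def gaussian_coverage(n_sigma):
    """Probability mass within +/- n_sigma of the mean of a normal distribution."""
    return math.erf(n_sigma / math.sqrt(2.0))


for n_sigma in (1, 2, 3):
    print(f"{n_sigma} sigma -> {gaussian_coverage(n_sigma):.4f}")
# 1 sigma -> 0.6827
# 2 sigma -> 0.9545
# 3 sigma -> 0.9973
```

So hdi_prob=0.95 is slightly narrower than a true 2σ interval, and the correspondence only holds when the posterior is close to Gaussian.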
- plot_energy(show=True, **plot_kwargs)
Plot NUTS energy diagnostic plot.
Creates an energy plot showing the distribution of energy transitions during NUTS sampling. This is a NUTS-specific diagnostic that helps identify problematic posterior geometry such as heavy tails, funnels, or multimodal distributions. Energy transitions that differ between the marginal and transition distributions indicate sampling problems.
- Parameters:
show (bool) – Whether to call plt.show() (default: True)
**plot_kwargs – Additional arguments passed to arviz.plot_energy()
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
ImportError – If arviz is not installed
Example
>>> # Plot energy diagnostic
>>> pipeline.plot_energy()
>>>
>>> # Save without showing
>>> pipeline.plot_energy(show=False).save_figure("energy.pdf")
Note
Energy diagnostics help identify:
- Heavy-tailed posteriors (energy dist has fat tails)
- Funnel geometry (energy varies dramatically)
- Problematic parameterizations
Good NUTS sampling shows similar marginal and transition energy distributions.
- plot_autocorr(var_names=None, combined=False, show=True, **plot_kwargs)[source]¶
Plot autocorrelation diagnostic for MCMC mixing.
Creates autocorrelation plots showing how correlated consecutive samples are in the MCMC chain. High autocorrelation indicates poor mixing and suggests more samples are needed for reliable inference. Ideally, autocorrelation should decay quickly to zero.
- Parameters:
var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)
combined (bool) – Whether to combine multiple chains (default: False)
show (bool) – Whether to call plt.show() (default: True)
**plot_kwargs – Additional arguments passed to arviz.plot_autocorr() (e.g., max_lag)
- Return type:
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
ImportError – If arviz is not installed
Example
>>> # Plot autocorrelation for all parameters
>>> pipeline.plot_autocorr()
>>>
>>> # Plot specific parameters with longer lag
>>> pipeline.plot_autocorr(var_names=["G0"], max_lag=100)
>>>
>>> # Save without showing
>>> pipeline.plot_autocorr(show=False).save_figure("autocorr.pdf")
Note
Autocorrelation diagnostics help identify:
- Poor mixing (high autocorrelation persists)
- Need for more samples (ESS will be low)
- Chain length adequacy
Goal: autocorrelation drops to ~0 within a few dozen lags.
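To make the quantity behind the autocorrelation plot concrete, here is a minimal sketch of the sample autocorrelation function for a single chain, written with the standard library only. The helper name `autocorr` and the two synthetic AR(1) chains are assumptions for illustration; ArviZ uses its own (FFT-based) implementation.

```python
import random


def autocorr(chain, max_lag):
    """Sample autocorrelation of one chain for lags 0..max_lag."""
    n = len(chain)
    if not 0 <= max_lag < n:
        raise ValueError("max_lag must be in [0, len(chain))")
    mean = sum(chain) / n
    centered = [x - mean for x in chain]
    total = sum(c * c for c in centered)
    if total == 0:
        raise ValueError("chain is constant")
    # Lag-k autocovariance normalized by the lag-0 autocovariance
    return [
        sum(centered[i] * centered[i + lag] for i in range(n - lag)) / total
        for lag in range(max_lag + 1)
    ]


def ar1_chain(phi, n, rng):
    """Synthetic chain whose lag-k autocorrelation is roughly phi**k."""
    chain = [0.0]
    for _ in range(n - 1):
        chain.append(phi * chain[-1] + rng.gauss(0.0, 1.0))
    return chain


rng = random.Random(2)
rho_fast = autocorr(ar1_chain(0.1, 3000, rng), max_lag=20)  # mixes quickly
rho_slow = autocorr(ar1_chain(0.9, 3000, rng), max_lag=20)  # mixes slowly
print(f"lag 10, fast chain: {rho_fast[10]:+.2f}")
print(f"lag 10, slow chain: {rho_slow[10]:+.2f}")
```

The slowly mixing chain still shows substantial correlation at lag 10, which is the pattern the note describes as poor mixing.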
- plot_rank(var_names=None, show=True, **plot_kwargs)[source]¶
Plot rank plot for convergence diagnostics.
Creates rank plots (also called rank histograms or rank-normalization plots) which are a modern alternative to trace plots for diagnosing convergence. A uniform rank distribution across chains indicates good mixing and convergence. Non-uniformity suggests convergence problems.
- Parameters:
- Return type:
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
ImportError – If arviz is not installed
Example
>>> # Plot rank diagnostic for all parameters
>>> pipeline.plot_rank()
>>>
>>> # Plot specific parameters
>>> pipeline.plot_rank(var_names=["G0", "eta"])
>>>
>>> # Save without showing
>>> pipeline.plot_rank(show=False).save_figure("rank.pdf")
Note
Rank plots help identify:
- Non-convergence (non-uniform rank distribution)
- Chain sticking (vertical bands)
- Insufficient mixing (patterns in ranks)
Goal: uniform histogram across all bins.
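The idea behind a rank plot can be sketched in a few lines: pool the draws of all chains, rank them, and count how many ranks of each chain fall into each bin. The helper `rank_histograms` below is a simplified illustration (ties are not averaged, and the chains are synthetic), not the ArviZ implementation.

```python
import random


def rank_histograms(chains, n_bins=10):
    """Count, per chain, how many pooled ranks fall into each bin."""
    # Tag every draw with the index of the chain it came from
    pooled = [(value, c) for c, chain in enumerate(chains) for value in chain]
    pooled.sort(key=lambda item: item[0])
    total = len(pooled)
    counts = [[0] * n_bins for _ in chains]
    for rank, (_, c) in enumerate(pooled):
        # Map rank 0..total-1 onto bin 0..n_bins-1
        counts[c][rank * n_bins // total] += 1
    return counts


rng = random.Random(3)
# Three chains targeting the same distribution, one chain stuck elsewhere
chains = [[rng.gauss(0.0, 1.0) for _ in range(500)] for _ in range(3)]
chains.append([rng.gauss(2.0, 1.0) for _ in range(500)])

counts = rank_histograms(chains)
for c, row in enumerate(counts):
    print(f"chain {c}: {row}")
```

The shifted chain piles its ranks into the upper bins, while converged chains would each show roughly equal counts in every bin.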
- plot_ess(var_names=None, kind='local', show=True, **plot_kwargs)[source]¶
Plot effective sample size (ESS) diagnostic.
Creates a plot showing the effective sample size for each parameter, which quantifies how many independent samples the MCMC chain is equivalent to. Low ESS indicates high autocorrelation and suggests more samples are needed. ESS values should ideally be > 400.
- Parameters:
var_names (list[str] | None) – List of parameter names to plot. If None, plots all parameters (default: None)
kind (str) – Type of ESS plot - “local”, “quantile”, or “evolution” (default: “local”)
show (bool) – Whether to call plt.show() (default: True)
**plot_kwargs – Additional arguments passed to arviz.plot_ess() (e.g., min_ess)
- Return type:
- Returns:
self for method chaining
- Raises:
ValueError – If Bayesian inference has not been run
ImportError – If arviz is not installed
Example
>>> # Plot ESS for all parameters
>>> pipeline.plot_ess()
>>>
>>> # Plot quantile ESS
>>> pipeline.plot_ess(kind="quantile")
>>>
>>> # Save without showing
>>> pipeline.plot_ess(show=False).save_figure("ess.pdf")
Note
ESS diagnostics help assess:
- Sampling efficiency (ESS / total samples)
- Which parameters need more sampling
- Overall chain quality
Goal: ESS > 400 for bulk and tail estimates.
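The link between autocorrelation and ESS can be illustrated with the textbook single-chain estimate ESS = N / (1 + 2 Σ ρ_k). The sketch below truncates the sum at the first non-positive autocorrelation, which is a deliberate simplification; ArviZ uses rank-normalized, multi-chain estimators, so the numbers will differ. The function name and the synthetic chains are assumptions for illustration.

```python
import random


def effective_sample_size(chain):
    """Simplified single-chain ESS: N / (1 + 2 * sum of positive-lag rho)."""
    n = len(chain)
    mean = sum(chain) / n
    centered = [x - mean for x in chain]
    var = sum(c * c for c in centered) / n
    if var == 0:
        raise ValueError("chain is constant")
    rho_sum = 0.0
    for lag in range(1, n):
        cov = sum(centered[i] * centered[i + lag] for i in range(n - lag)) / n
        rho = cov / var
        if rho <= 0:
            break  # truncate at the first non-positive autocorrelation
        rho_sum += rho
    return n / (1.0 + 2.0 * rho_sum)


rng = random.Random(4)
independent = [rng.gauss(0.0, 1.0) for _ in range(2000)]
correlated = [0.0]
for _ in range(1999):
    correlated.append(0.9 * correlated[-1] + rng.gauss(0.0, 1.0))

ess_independent = effective_sample_size(independent)
ess_correlated = effective_sample_size(correlated)
print(f"ESS, independent draws: {ess_independent:.0f} of 2000")
print(f"ESS, correlated draws:  {ess_correlated:.0f} of 2000")
```

Both chains contain 2000 draws, but only the independent one clears the 400 target, which is why a low ESS calls for longer chains or a better parameterization.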
Description: Specialized pipeline for Bayesian rheological analysis with NLSQ -> NUTS workflow.
Key Features:
NLSQ optimization for fast point estimation
Automatic warm-start for NumPyro NUTS sampling
Comprehensive ArviZ diagnostics (6 plot types)
Fluent API for method chaining
Convergence monitoring (R-hat, ESS, divergences)
Example - Complete Bayesian Workflow:
from rheojax.pipeline.bayesian import BayesianPipeline
# Create and execute pipeline
pipeline = (BayesianPipeline()
    .load('data.csv', x_col='time', y_col='stress')
    .fit_nlsq('maxwell')                 # Fast point estimate
    .fit_bayesian(num_samples=2000,      # NUTS with warm-start
                  num_warmup=1000)
    .plot_posterior()                    # Posterior distributions
    .plot_trace()                        # MCMC trace plots
    .save('results.hdf5'))               # Export results
# Access results
summary = pipeline.get_posterior_summary()
diagnostics = pipeline.get_diagnostics()
intervals = pipeline.get_credible_intervals()
Example - ArviZ Diagnostic Suite:
# Comprehensive MCMC quality assessment
(pipeline
    .plot_pair(divergences=True)   # Parameter correlations with divergences
    .plot_forest(hdi_prob=0.95)    # Credible intervals comparison
    .plot_energy()                 # NUTS energy diagnostic
    .plot_autocorr()               # Mixing diagnostic
    .plot_rank()                   # Convergence diagnostic
    .plot_ess(kind='local'))       # Effective sample size
# Convert to ArviZ InferenceData for advanced analysis
idata = pipeline._bayesian_result.to_inference_data()
import arviz as az
az.summary(idata)
Key Methods:
fit_nlsq(model_name, **kwargs): NLSQ optimization for point estimation
fit_bayesian(num_samples, num_warmup, **kwargs): NumPyro NUTS sampling with warm-start
plot_posterior(**kwargs): Plot posterior distributions
plot_trace(**kwargs): Plot MCMC trace diagnostics
plot_pair(**kwargs): Plot parameter correlations (ArviZ)
plot_forest(**kwargs): Plot credible intervals (ArviZ)
plot_energy(**kwargs): Plot NUTS energy diagnostic (ArviZ)
plot_autocorr(**kwargs): Plot autocorrelation (ArviZ)
plot_rank(**kwargs): Plot rank diagnostic (ArviZ)
plot_ess(**kwargs): Plot effective sample size (ArviZ)
get_posterior_summary(): Get posterior summary statistics
get_diagnostics(): Get convergence diagnostics (R-hat, ESS)
get_credible_intervals(credibility=0.95): Get credible intervals
Pipeline Builder¶
- class rheojax.pipeline.PipelineBuilder[source]¶
Bases: object
Build and validate pipelines programmatically.
This class provides a fluent API for constructing pipelines with validation of step order and dependencies.
Example
>>> builder = PipelineBuilder()
>>> builder.add_load_step('data.csv')
>>> builder.add_fit_step('maxwell')
>>> pipeline = builder.build()
- add_load_step(file_path, format='auto', **kwargs)[source]¶
Add data loading step.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_load_step('data.csv', x_col='time', y_col='stress')
- add_transform_step(transform_name, **kwargs)[source]¶
Add transform step.
- Parameters:
transform_name (str) – Name of transform to apply
**kwargs – Arguments for transform constructor
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_transform_step('smooth', window_size=5)
- add_fit_step(model_name, method='auto', use_jax=True, **kwargs)[source]¶
Add model fitting step.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_fit_step('maxwell')
- add_predict_step(store_as=None, **kwargs)[source]¶
Add prediction step.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_predict_step(store_as='prediction')
- add_plot_step(show=False, style='default', **kwargs)[source]¶
Add plotting step.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_plot_step(style='publication', show=True)
- add_bayesian_step(num_warmup=1000, num_samples=2000, num_chains=4, seed=0, warm_start=True, **kwargs)[source]¶
Add Bayesian inference step (NUTS sampling).
- Parameters:
num_warmup (int) – Number of warmup iterations per chain
num_samples (int) – Number of posterior samples per chain
num_chains (int) – Number of MCMC chains
seed (int) – Random seed for reproducibility
warm_start (bool) – Whether to use NLSQ results as initial values
**kwargs – Additional arguments for fit_bayesian()
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_bayesian_step(num_warmup=500, num_samples=1000)
- add_export_step(output_path, format='auto', **kwargs)[source]¶
Add analysis export step.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_export_step('./results', format='directory')
- add_save_step(file_path, format='hdf5', **kwargs)[source]¶
Add data saving step.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> builder.add_save_step('output.hdf5')
- build(validate=True)[source]¶
Build and optionally validate pipeline.
- Parameters:
validate (bool) – Whether to validate pipeline structure
- Return type:
- Returns:
Constructed Pipeline instance
- Raises:
ValueError – If validation fails
Example
>>> pipeline = builder.build()
- clear()[source]¶
Clear all steps.
- Return type:
- Returns:
self for method chaining
Example
>>> builder.clear()
Description: Programmatic pipeline construction for complex custom workflows.
Example - Basic Builder:
from rheojax.pipeline import PipelineBuilder
# Build custom pipeline
builder = PipelineBuilder()
builder.add_load_step('data.txt', format='auto')
builder.add_transform_step('smooth', method='savgol', window=11)
builder.add_transform_step('fft', window='hann')
builder.add_fit_step('maxwell', initial_params={'G_s': 1e5})
builder.add_plot_step(show=False, save='result.png')
builder.add_save_step('result.hdf5')
# Build and execute
pipeline = builder.build()
results = pipeline.execute()
Example - Conditional Logic:
builder = PipelineBuilder()
builder.add_load_step('data.txt')
# Conditional transform
builder.add_conditional_step(
    condition=lambda state: state['data'].metadata.get('noisy', False),
    true_step=('transform', {'name': 'smooth', 'window': 11}),
    false_step=None  # Skip if not noisy
)
builder.add_fit_step('maxwell')
pipeline = builder.build()
results = pipeline.execute()
Key Methods:
add_load_step(source, **kwargs): Add data loading step
add_transform_step(name, **params): Add transform step
add_fit_step(model, **kwargs): Add model fitting step
add_plot_step(**kwargs): Add visualization step
add_save_step(filepath, **kwargs): Add export step
add_conditional_step(condition, true_step, false_step): Add conditional logic
build(): Build pipeline
execute(): Execute built pipeline
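The builder pattern with step-order validation can be sketched independently of rheojax. The class `MiniBuilder` below is a toy stand-in, and its rules (a load step must come first, and predict or Bayesian steps need an earlier fit step) are assumptions chosen for illustration; the checks actually performed by PipelineBuilder.build() may differ.

```python
class MiniBuilder:
    """Toy builder that records steps and checks their order."""

    def __init__(self):
        self.steps = []

    def add(self, kind, **kwargs):
        self.steps.append((kind, kwargs))
        return self  # returning self enables method chaining

    def validate(self):
        kinds = [kind for kind, _ in self.steps]
        if not kinds:
            raise ValueError("pipeline has no steps")
        # Assumed rule: data must be loaded before anything else
        if kinds[0] != "load":
            raise ValueError("first step must be 'load'")
        fitted = False
        for kind in kinds:
            if kind == "fit":
                fitted = True
            elif kind in {"predict", "bayesian"} and not fitted:
                # Assumed rule: these steps depend on a fitted model
                raise ValueError(f"'{kind}' requires a preceding 'fit' step")
        return True


ok = MiniBuilder().add("load", file_path="data.csv").add("fit", model="maxwell")
print(ok.validate())

bad = MiniBuilder().add("load", file_path="data.csv").add("predict")
try:
    bad.validate()
except ValueError as exc:
    message = str(exc)
    print(f"validation failed: {message}")
```

Validating at build time surfaces ordering mistakes before any data is loaded or any model is fitted.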
Batch Processing¶
- class rheojax.pipeline.BatchPipeline(template_pipeline=None)[source]¶
Bases: object
Apply pipeline to multiple datasets.
This class enables batch processing of multiple data files with the same pipeline configuration, collecting results for analysis.
- template_pipeline¶
Template Pipeline to apply to each dataset
- results¶
List of (file_path, result, metrics) tuples
Example
>>> template = Pipeline().fit('maxwell')
>>> batch = BatchPipeline(template)
>>> batch.process_files(['data1.csv', 'data2.csv'])
- set_template(pipeline)[source]¶
Set template pipeline.
- Parameters:
pipeline (Pipeline) – Pipeline to use as template
- Return type:
- Returns:
self for method chaining
- process_files(file_paths, format='auto', parallel=False, parallel_io=True, n_workers=None, **load_kwargs)[source]¶
Process multiple files with the pipeline.
- Parameters:
file_paths (Iterable[str | Path]) – List of file paths to process
format (str) – File format for loading
parallel (bool) – Whether to use parallel processing for the full pipeline. Default False: JAX JIT cache is not thread-safe with concurrent ThreadPoolExecutor. Set True only for I/O-bound pipelines without JAX JIT calls (e.g., loading + simple numpy transforms).
parallel_io (bool) – Whether to load files in parallel using threads. Default True: file I/O is thread-safe and benefits from parallelism. Loading phase runs in threads, pipeline replay runs sequentially.
n_workers (int | None) – Number of parallel workers (default: min(4, cpu_count))
**load_kwargs – Additional arguments for data loading
- Return type:
- Returns:
self for method chaining
Note
During replay, protocol-specific kwargs (gamma_dot, sigma_init, lam_init, sigma_0, lam_0, gamma_0, omega_laos, n_cycles, points_per_cycle) are stripped from the template’s fit kwargs because they are data-dependent and should not be reused across datasets. DMTA kwargs (deformation_mode, poisson_ratio) and solver settings (method) are preserved.
Example
>>> batch.process_files(['data1.csv', 'data2.csv'])
>>> # Parallel mode (use with caution; JAX JIT is not thread-safe):
>>> batch.process_files(['data1.csv', 'data2.csv'], parallel=True)
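The kwarg stripping described in the note for process_files amounts to filtering a dictionary against a fixed set of names. The sketch below uses the protocol-specific names listed in that note; the helper name `strip_protocol_kwargs` and the sample values (including the `'tension'` deformation mode) are assumptions for illustration, not rheojax internals.

```python
# Names taken from the process_files note: data-dependent protocol settings
PROTOCOL_KWARGS = frozenset({
    "gamma_dot", "sigma_init", "lam_init", "sigma_0", "lam_0",
    "gamma_0", "omega_laos", "n_cycles", "points_per_cycle",
})


def strip_protocol_kwargs(fit_kwargs):
    """Return a copy of fit_kwargs without data-dependent protocol settings."""
    return {k: v for k, v in fit_kwargs.items() if k not in PROTOCOL_KWARGS}


# Hypothetical fit kwargs recorded on a template pipeline
template_kwargs = {
    "method": "auto",               # solver setting: preserved
    "deformation_mode": "tension",  # DMTA setting: preserved
    "gamma_dot": 10.0,              # protocol-specific: stripped
    "n_cycles": 5,                  # protocol-specific: stripped
}
replay_kwargs = strip_protocol_kwargs(template_kwargs)
print(replay_kwargs)  # {'method': 'auto', 'deformation_mode': 'tension'}
```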
- process_directory(directory, pattern='*.csv', recursive=False, **kwargs)[source]¶
Process all files in directory matching pattern.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> batch.process_directory('data/', pattern='*.csv')
- get_results()[source]¶
Get all processing results.
- Return type:
- Returns:
List of (file_path, result_data, metrics) tuples
Example
>>> results = batch.get_results()
>>> for path, data, metrics in results:
...     print(f"{path}: R²={metrics.get('r_squared', 0):.4f}")
- get_errors()[source]¶
Get processing errors.
Example
>>> errors = batch.get_errors()
>>> for path, error in errors:
...     print(f"Error in {path}: {error}")
- get_summary_dataframe()[source]¶
Get summary DataFrame of all results.
- Return type:
DataFrame
- Returns:
DataFrame with file paths and metrics
Example
>>> df = batch.get_summary_dataframe()
>>> print(df)
- export_summary(output_path, format='excel')[source]¶
Export summary of batch results.
- Parameters:
- Return type:
- Returns:
self for method chaining
Example
>>> batch.export_summary('summary.xlsx')
- apply_filter(filter_fn)[source]¶
Filter results based on custom criteria.
- Parameters:
filter_fn (Callable[[Path, RheoData, dict[str, Any]], bool]) – Function that takes (file_path, data, metrics) and returns True to keep the result
- Return type:
- Returns:
self for method chaining
Example
>>> # Keep only results with R² > 0.9
>>> batch.apply_filter(lambda p, d, m: m.get('r_squared', 0) > 0.9)
- get_statistics()[source]¶
Get statistics across all results.
Example
>>> stats = batch.get_statistics()
>>> print(f"Mean R²: {stats['mean_r_squared']:.4f}")
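Filtering and summarizing operate on the (file_path, data, metrics) tuples described above. The following sketch reproduces that flow on plain Python data so the tuple layout is easy to see. The file names and metric values are made up, `None` stands in for the RheoData objects, and only the mean is computed; the full set of statistics returned by get_statistics() is not shown here.

```python
from statistics import mean

# Made-up results in the documented (file_path, data, metrics) layout
results = [
    ("sample1.csv", None, {"r_squared": 0.99}),
    ("sample2.csv", None, {"r_squared": 0.62}),
    ("sample3.csv", None, {"r_squared": 0.96}),
]


def keep_good_fit(file_path, data, metrics):
    """Filter function: keep results with R² above 0.9."""
    return metrics.get("r_squared", 0) > 0.9


# Same selection that apply_filter(keep_good_fit) is documented to perform
kept = [entry for entry in results if keep_good_fit(*entry)]
mean_r_squared = mean(metrics["r_squared"] for _, _, metrics in kept)

print([path for path, _, _ in kept])
print(f"Mean R²: {mean_r_squared:.4f}")
```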
Description: Process multiple datasets with the same workflow. Files are loaded in parallel by default; running the full pipeline in parallel is optional.
Example - Basic Batch:
from rheojax.pipeline import Pipeline, BatchPipeline
# Define template pipeline
template = (Pipeline()
    .transform('smooth', window=11)
    .fit('maxwell')
    .plot(save='${filename}_fit.png')  # ${filename} replaced per file
    .save('${filename}_results.hdf5'))
# Create batch processor
batch = BatchPipeline(template)
# Process directory
batch.process_directory('data/', pattern='*.txt')
# Get all results
all_results = batch.get_results()
# Export summary
batch.export_summary('batch_summary.xlsx')
Example - Parallel Processing:
# Load files in parallel threads; pipeline replay runs sequentially
batch = BatchPipeline(template)
files = ['sample1.txt', 'sample2.txt', 'sample3.txt']
batch.process_files(files, parallel_io=True, n_workers=4)
# Process every matching file in a directory
batch.process_directory('data/', pattern='*.txt')
Key Methods:
process_directory(path, pattern): Process all matching files in directory
process_files(file_list): Process specific files
get_results(): Retrieve results from all files
get_errors(): Get processing errors for files that failed
get_summary_dataframe(): Get a DataFrame of file paths and metrics
export_summary(filepath): Export comparison table
Parameters:
template_pipeline (Pipeline): Template pipeline to apply to each dataset
parallel (bool): Run the full pipeline in parallel (default: False; not safe with JAX JIT calls)
parallel_io (bool): Load files in parallel threads (default: True)
n_workers (int | None): Number of parallel workers (default: min(4, cpu_count))
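The two-phase strategy described for parallel_io (threaded loading followed by sequential processing) can be sketched with the standard library. Everything in this example is a stand-in: `load_text` replaces the real reader, `run_batch` is a hypothetical helper, and counting tokens replaces the pipeline replay. It also shows errors being collected per file instead of aborting the batch.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


def load_text(path):
    """Stand-in reader: return the raw text of a file."""
    return Path(path).read_text()


def run_batch(paths, process, n_workers=None):
    """Load files in threads, then process them one at a time."""
    n_workers = n_workers or min(4, os.cpu_count() or 1)
    # Phase 1: submit all loads; leaving the with-block waits for completion
    with ThreadPoolExecutor(max_workers=n_workers) as pool:
        futures = [(path, pool.submit(load_text, path)) for path in paths]
    # Phase 2: sequential processing, collecting failures instead of raising
    results, errors = [], []
    for path, future in futures:
        try:
            results.append((path, process(future.result())))
        except Exception as exc:
            errors.append((path, exc))
    return results, errors


with tempfile.TemporaryDirectory() as tmp:
    existing = []
    for name, text in [("a.txt", "1\n2\n3\n"), ("b.txt", "4\n5\n")]:
        path = Path(tmp) / name
        path.write_text(text)
        existing.append(path)
    missing = Path(tmp) / "missing.txt"
    results, errors = run_batch(existing + [missing],
                                lambda text: len(text.split()))

print([(p.name, n) for p, n in results])  # [('a.txt', 3), ('b.txt', 2)]
print([p.name for p, _ in errors])        # ['missing.txt']
```

Keeping the compute phase sequential avoids sharing JIT-compiled state between threads while still overlapping the file I/O.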
Error Handling¶
Pipeline Error Management¶
pipeline = (Pipeline()
    .load('data.txt')
    .fit('maxwell', fail_on_error=False))  # Don't raise exception
# Check for errors
if pipeline.has_errors():
    errors = pipeline.get_errors()
    print(f"Errors encountered: {errors}")
else:
    results = pipeline.get_results()
Pipeline Validation¶
Validate before execution:
pipeline = (Pipeline()
    .load('data.txt')
    .fit('maxwell'))
# Validate pipeline
is_valid, messages = pipeline.validate()
if is_valid:
    results = pipeline.execute()
else:
    print(f"Validation failed: {messages}")
Debug Mode¶
Enable debugging output:
# Enable debug logging
pipeline = Pipeline(debug=True)
# Or set verbosity
pipeline = Pipeline(verbose=2) # 0=silent, 1=info, 2=debug
# Inspect pipeline state
state = pipeline.get_state()
print(f"Current step: {state['current_step']}")
print(f"Data loaded: {state['data_loaded']}")
print(f"Model fitted: {state['model_fitted']}")
Best Practices¶
Method Chaining Style¶
Recommended (readable, clean):
results = (Pipeline()
    .load('data.txt')
    .transform('smooth', window=11)
    .fit('maxwell')
    .plot(show=True)
    .get_results())
Acceptable (for debugging):
pipeline = Pipeline()
pipeline.load('data.txt')
pipeline.transform('smooth', window=11)
pipeline.fit('maxwell')
pipeline.plot(show=True)
results = pipeline.get_results()
Error Recovery¶
# Try multiple models until one succeeds
models = ['maxwell', 'zener', 'fractional_maxwell_gel']
for model_name in models:
    try:
        results = (Pipeline()
            .load('data.txt')
            .fit(model_name)
            .get_results())
        print(f"Success with {model_name}")
        break
    except Exception as e:
        print(f"{model_name} failed: {e}")
        continue
else:
    # No break occurred, so every model failed and results is undefined
    raise RuntimeError("No model could be fitted")
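The fallback loop above can be wrapped in a small reusable helper that also keeps the failure reasons for later inspection. In this sketch `first_successful` is a hypothetical helper and `fake_fit` is a stand-in for a real pipeline run; neither is part of rheojax.

```python
def first_successful(candidates, attempt):
    """Return (name, result, failures) for the first candidate that works."""
    failures = {}
    for name in candidates:
        try:
            return name, attempt(name), failures
        except (ValueError, RuntimeError) as exc:
            # Remember why this candidate failed and move on
            failures[name] = exc
    raise RuntimeError(f"all candidates failed: {sorted(failures)}")


def fake_fit(model_name):
    """Stand-in for a pipeline run; pretends 'maxwell' does not converge."""
    if model_name == "maxwell":
        raise ValueError("did not converge")
    return {"model": model_name, "r_squared": 0.98}


name, result, failures = first_successful(
    ["maxwell", "zener", "fractional_maxwell_gel"], fake_fit
)
print(f"Success with {name}")
print(f"Failed before that: {list(failures)}")
```

Catching only the exception types a fit is expected to raise keeps unrelated bugs, such as a typo in a file path, from being silently skipped.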
Performance Optimization¶
# Cache intermediate results
pipeline = Pipeline(cache=True)
# Process in chunks for large batches
batch = BatchPipeline(template, n_jobs=-1)
batch.process_directory('data/', chunk_size=10)
See Also¶
/user_guide/pipeline_api - Comprehensive pipeline tutorial
/user_guide/modular_api - Low-level API for custom control
Models API - Model API reference
Transforms API - Transform API reference
rheojax.core.base.BaseModel - Base model class
rheojax.core.base.BaseTransform - Base transform class