# Source code for rheojax.pipeline.bayesian

"""Specialized pipeline for Bayesian workflows.

This module provides the BayesianPipeline class for orchestrating the complete
NLSQ → NumPyro NUTS workflow with a fluent API.

Example:
    >>> from rheojax.pipeline.bayesian import BayesianPipeline
    >>> pipeline = BayesianPipeline()
    >>> result = (pipeline
    ...     .load('data.csv')
    ...     .fit_nlsq('maxwell')
    ...     .fit_bayesian(num_samples=2000)
    ...     .plot_posterior()
    ...     .save('results.hdf5'))
"""

from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd

from rheojax.core.arviz_utils import import_arviz
from rheojax.core.base import BaseModel
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.registry import ModelRegistry
from rheojax.logging import get_logger, log_bayesian, log_fit
from rheojax.pipeline.base import Pipeline

# Safe JAX import (verifies NLSQ was imported first)
jax, jnp = safe_import_jax()

# Module-level logger used by the BayesianPipeline methods in this module.
logger = get_logger(__name__)


class BayesianPipeline(Pipeline):
    """Specialized pipeline for Bayesian rheological analysis workflows.

    Extends the base Pipeline with a fluent API for the
    NLSQ → NumPyro NUTS workflow:

    - NLSQ optimization for fast point estimation
    - Bayesian inference warm-started from the NLSQ estimates
    - Convergence diagnostics (R-hat, ESS, divergences)
    - Posterior visualization (histograms, traces and ArviZ diagnostics)

    Every public step returns ``self`` so that calls can be chained.

    Attributes:
        data: Current RheoData state (inherited from Pipeline).
        _last_model: Last fitted model (inherited from Pipeline).
        _nlsq_result: Stored NLSQ optimization result.
        _bayesian_result: Stored Bayesian inference result.
        _diagnostics: Stored convergence diagnostics.

    Example:
        >>> pipeline = BayesianPipeline()
        >>> result = (pipeline
        ...     .load('data.csv')
        ...     .fit_nlsq('maxwell')
        ...     .fit_bayesian(num_samples=2000)
        ...     .plot_posterior()
        ...     .save('results.hdf5'))
    """

    # Posterior variables whose peak-to-peak range does not exceed this
    # absolute tolerance are treated as degenerate (constant) and are left
    # out of the automatic variable selection of plot_pair()/plot_forest().
    _DEGENERATE_PTP_TOL = 1e-10

    def __init__(self, data=None):
        """Initialize Bayesian pipeline.

        Args:
            data: Optional initial RheoData. If None, must call load() first.
        """
        super().__init__(data=data)
        self._nlsq_result = None
        self._bayesian_result = None
        self._diagnostics = None
        logger.debug("BayesianPipeline initialized", has_data=data is not None)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _xy_as_numpy(self):
        """Return ``(x, y)`` of the loaded data, converting JAX arrays to NumPy."""
        X = self.data.x
        y = self.data.y
        if isinstance(X, jnp.ndarray):
            X = np.array(X)
        if isinstance(y, jnp.ndarray):
            y = np.array(y)
        return X, y

    def _require_bayesian_result(self, purpose: str):
        """Return the stored Bayesian result or raise ValueError if absent.

        Args:
            purpose: Short description appended to the error log message.
        """
        if self._bayesian_result is None:
            logger.error(f"No Bayesian result available for {purpose}")
            raise ValueError("No Bayesian result available. Call fit_bayesian() first.")
        return self._bayesian_result

    @staticmethod
    def _finite_values(values) -> list:
        """Return the entries of *values* that are neither None nor non-finite."""
        return [v for v in values if v is not None and np.isfinite(v)]

    def _posterior_params_to_plot(self, param_name: str | None):
        """Resolve which posterior parameters a histogram/trace plot covers.

        Returns:
            Tuple ``(posterior_samples, params_to_plot)``.

        Raises:
            ValueError: If *param_name* is unknown or there is nothing to plot.
        """
        posterior_samples = self._bayesian_result.posterior_samples
        if param_name is None:
            params_to_plot = list(posterior_samples.keys())
        elif param_name in posterior_samples:
            params_to_plot = [param_name]
        else:
            logger.error(
                "Parameter not found in posterior samples",
                param_name=param_name,
                available_params=list(posterior_samples.keys()),
            )
            raise ValueError(
                f"Parameter '{param_name}' not found in posterior samples. "
                f"Available parameters: {list(posterior_samples.keys())}"
            )
        if not params_to_plot:
            # Without this guard the subplot grid computation divides by zero.
            raise ValueError("Posterior samples contain no parameters to plot.")
        return posterior_samples, params_to_plot

    @staticmethod
    def _load_arviz(function_name: str, label: str):
        """Import ArviZ, requiring *function_name*; re-raise with install hint.

        Args:
            function_name: ArviZ plotting function that must be available.
            label: Human-readable plot name used in the log/error messages.
        """
        try:
            return import_arviz(required=(function_name,))
        except ImportError as exc:
            logger.error(f"ArviZ not installed for {label} plot", exc_info=True)
            raise ImportError(
                f"ArviZ is required for {label} plots. "
                "Install it with: pip install arviz"
            ) from exc

    def _varying_posterior_vars(self, idata) -> list[str] | None:
        """List posterior variables that actually vary across samples.

        Returns None when *idata* exposes no posterior group, so callers can
        leave ``var_names`` untouched in that case.
        """
        posterior = getattr(idata, "posterior", None)
        if posterior is None:
            return None
        return [
            name
            for name in posterior.data_vars
            if np.ptp(posterior[name].values.ravel()) > self._DEGENERATE_PTP_TOL
        ]

    def _capture_figure(self, axes, show: bool) -> None:
        """Store the figure behind an ArviZ plot result and optionally show it.

        *axes* may be a single Axes-like object or an array of them; when no
        figure can be derived from it, the current pyplot figure is used.
        """
        import matplotlib.pyplot as plt

        fig = getattr(axes, "figure", None)
        if fig is None and hasattr(axes, "ravel"):
            flat = axes.ravel()
            if len(flat) > 0:
                fig = getattr(flat[0], "figure", None)
        if fig is None:
            fig = plt.gcf()

        # Store figure for save_figure() method
        self._current_figure = fig
        if show:
            plt.show()

    # ------------------------------------------------------------------
    # Fitting
    # ------------------------------------------------------------------
    def fit_nlsq(self, model: str | BaseModel, **nlsq_kwargs) -> BayesianPipeline:
        """Fit model using NLSQ optimization for point estimation.

        The fitted model and its NLSQ result are stored so that a later
        fit_bayesian() call can warm-start from the point estimates.

        Args:
            model: Model name (string) or Model instance to fit.
            **nlsq_kwargs: Additional arguments passed to the model's ``fit``
                (e.g., max_iter, ftol, xtol, gtol). A ``method`` entry is
                discarded because ``method="nlsq"`` is always used.

        Returns:
            self for method chaining

        Raises:
            ValueError: If data not loaded

        Note:
            ``test_mode``, ``deformation_mode`` and ``poisson_ratio`` are taken
            from ``self.data.metadata`` when not given explicitly, and the
            resolved values are written back to the metadata so that a
            subsequent ``fit_bayesian()`` call inherits them.

        Example:
            >>> pipeline.fit_nlsq('maxwell')
            >>> # or with instance
            >>> from rheojax.models import Maxwell
            >>> pipeline.fit_nlsq(Maxwell(), max_iter=1000)
        """
        if self.data is None:
            logger.error("No data loaded for NLSQ fit")
            raise ValueError("No data loaded. Call load() first.")

        # Create model if string
        if isinstance(model, str):
            model_obj = ModelRegistry.create(model)
            model_name = model
        else:
            model_obj = model
            model_name = model_obj.__class__.__name__

        X, y = self._xy_as_numpy()
        metadata = getattr(self.data, "metadata", None)

        logger.debug(
            "Starting NLSQ fit",
            model=model_name,
            data_shape=X.shape,  # type: ignore[union-attr]
        )

        with log_fit(
            logger,
            model=model_name,
            data_shape=X.shape,  # type: ignore[union-attr]
            test_mode=(
                metadata.get("test_mode", "unknown")
                if metadata is not None
                else "unknown"
            ),
        ) as ctx:
            # BP-002 / R9-PIPE-DMT: explicit kwargs win; otherwise inherit
            # test_mode, deformation_mode and poisson_ratio from metadata.
            if metadata is not None:
                for key in ("test_mode", "deformation_mode", "poisson_ratio"):
                    if key not in nlsq_kwargs:
                        inherited = metadata.get(key)
                        if inherited is not None:
                            nlsq_kwargs[key] = inherited

            # Remove 'method' from nlsq_kwargs to prevent "multiple values"
            # TypeError since we explicitly pass method="nlsq" below.
            nlsq_kwargs.pop("method", None)

            model_obj.fit(X, y, method="nlsq", **nlsq_kwargs)

            r_squared = model_obj.score(X, y)
            ctx["r_squared"] = r_squared
            ctx["n_parameters"] = len(model_obj.parameters)

        # R10-PIPE-BAY-001 / R12-E-005: write the resolved settings back to
        # data.metadata (creating the dict if needed) so fit_bayesian() can
        # propagate them without the caller repeating the kwargs.
        if metadata is None:
            metadata = {}
            self.data.metadata = metadata
        for key in ("deformation_mode", "poisson_ratio"):
            explicit = nlsq_kwargs.get(key)
            if explicit is not None:
                metadata[key] = explicit
        resolved_tm = getattr(model_obj, "_test_mode", None)
        if resolved_tm is not None:
            metadata["test_mode"] = (
                resolved_tm.value if hasattr(resolved_tm, "value") else str(resolved_tm)
            )

        # Store fitted model
        self._last_model = model_obj
        self.steps.append(("fit_nlsq", model_obj))
        self.history.append(("fit_nlsq", model_name, r_squared))

        # Store NLSQ result from model
        self._nlsq_result = model_obj.get_nlsq_result()

        logger.info(
            "NLSQ fit completed",
            model=model_name,
            r_squared=r_squared,
        )
        return self

    def fit_bayesian(  # type: ignore[override]
        self,
        num_samples: int = 2000,
        num_warmup: int = 1000,
        num_chains: int = 4,
        **nuts_kwargs,
    ) -> BayesianPipeline:
        """Perform Bayesian inference using NumPyro NUTS sampler.

        Delegates to the last fitted model's ``fit_bayesian``. When that model
        reports ``fitted_``, its current parameter values are passed as
        ``initial_values`` to warm-start the sampler.

        Args:
            num_samples: Number of posterior samples per chain (default: 2000)
            num_warmup: Number of warmup/burn-in iterations (default: 1000)
            num_chains: Number of MCMC chains (default: 4). Multiple chains
                enable proper R-hat computation.
            **nuts_kwargs: Additional arguments forwarded to the model's
                ``fit_bayesian`` (e.g., target_accept_prob, chain_method).
                They override the values assembled by this method.

        Returns:
            self for method chaining

        Raises:
            ValueError: If no model has been fitted with fit_nlsq(), or if no
                data is loaded.

        Example:
            >>> pipeline.fit_nlsq('maxwell').fit_bayesian(num_samples=2000)
            >>> # With custom NUTS parameters
            >>> pipeline.fit_bayesian(
            ...     num_samples=3000,
            ...     num_warmup=1500,
            ...     num_chains=4,
            ...     target_accept_prob=0.9
            ... )
        """
        if self._last_model is None:
            logger.error("No model fitted before Bayesian inference")
            raise ValueError(
                "No model fitted. Call fit_nlsq() first to fit a model "
                "before running Bayesian inference."
            )
        if self.data is None:
            logger.error("No data loaded for Bayesian inference")
            raise ValueError("No data loaded. Call load() first.")

        X, y = self._xy_as_numpy()

        # Extract initial values from NLSQ fit for warm-start
        initial_values = None
        if self._last_model.fitted_:
            initial_values = {
                name: v
                for name in self._last_model.parameters
                if (v := self._last_model.parameters.get_value(name)) is not None
            }
            logger.debug(
                "Using NLSQ warm-start",
                n_initial_values=len(initial_values),
            )

        # Pull test_mode / deformation_mode / poisson_ratio from metadata.
        # test_mode is passed through unchanged; per the original note the
        # model's fit_bayesian() handles str → TestMode conversion itself.
        test_mode = None
        deformation_mode = None
        poisson_ratio = None
        metadata = getattr(self.data, "metadata", None)
        if metadata is not None:
            test_mode = metadata.get("test_mode")
            # R9-PIPE-DMT: propagate deformation_mode for DMTA data
            deformation_mode = metadata.get("deformation_mode")
            poisson_ratio = metadata.get("poisson_ratio")

        model_name = self._last_model.__class__.__name__
        logger.debug(
            "Starting Bayesian inference",
            model=model_name,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            test_mode=test_mode,
        )

        with log_bayesian(
            logger,
            model=model_name,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
        ) as ctx:
            bay_kwargs: dict = {
                "test_mode": test_mode,
                "num_warmup": num_warmup,
                "num_samples": num_samples,
                "num_chains": num_chains,
                "initial_values": initial_values,
            }
            if deformation_mode is not None:
                bay_kwargs["deformation_mode"] = deformation_mode
            if poisson_ratio is not None:
                bay_kwargs["poisson_ratio"] = poisson_ratio
            bay_kwargs.update(nuts_kwargs)

            result = self._last_model.fit_bayesian(X, y, **bay_kwargs)

            # Add diagnostics to context
            divergences = result.diagnostics.get("divergences")
            if divergences is None:
                divergences = 0
            ctx["divergences"] = divergences

            # R13-BAY-PIPE-001: drop missing/non-finite entries before
            # aggregating so that one failed parameter diagnostic does not
            # turn the summary metric into NaN (or raise on None).
            if "r_hat" in result.diagnostics:
                r_hat_values = self._finite_values(result.diagnostics["r_hat"].values())
                if r_hat_values:
                    ctx["r_hat_max"] = max(r_hat_values)
                else:
                    ctx["r_hat_max"] = None
                    logger.warning("All R-hat values are NaN — diagnostics invalid")
            if "ess" in result.diagnostics:
                ess_values = self._finite_values(result.diagnostics["ess"].values())
                if ess_values:
                    ctx["ess_min"] = min(ess_values)
                else:
                    ctx["ess_min"] = None
                    logger.warning("All ESS values are NaN — diagnostics invalid")

        # Store results
        self._bayesian_result = result
        self._diagnostics = result.diagnostics

        # Add to history
        self.history.append(("fit_bayesian", num_samples, num_warmup, divergences))

        logger.info(
            "Bayesian inference completed",
            model=model_name,
            divergences=divergences,
            num_samples=num_samples,
            num_chains=num_chains,
        )
        return self

    # ------------------------------------------------------------------
    # Result accessors
    # ------------------------------------------------------------------
    def get_diagnostics(self) -> dict[str, Any]:
        """Get convergence diagnostics from Bayesian inference.

        Returns:
            A deep copy of the diagnostics stored by fit_bayesian(), so that
            callers cannot mutate the pipeline's internal state. This module
            reads the keys ``r_hat`` (dict), ``ess`` (dict) and
            ``divergences`` from it.

        Raises:
            ValueError: If Bayesian inference has not been run

        Example:
            >>> diagnostics = pipeline.get_diagnostics()
            >>> print(f"R-hat: {diagnostics['r_hat']}")
            >>> print(f"ESS: {diagnostics['ess']}")
            >>> print(f"Divergences: {diagnostics['divergences']}")
        """
        import copy

        self._require_bayesian_result("diagnostics")
        logger.debug("Retrieving diagnostics", n_params=len(self._diagnostics))
        return copy.deepcopy(self._diagnostics)

    def get_posterior_summary(self) -> pd.DataFrame:
        """Get formatted posterior summary statistics.

        Builds a DataFrame from the Bayesian result's ``summary`` mapping
        (parameter name → statistics mapping).

        Returns:
            DataFrame with parameters as rows and statistics as columns

        Raises:
            ValueError: If Bayesian inference has not been run

        Example:
            >>> summary = pipeline.get_posterior_summary()
            >>> print(summary)
        """
        result = self._require_bayesian_result("posterior summary")

        # Parameters come in as columns; transpose so each becomes a row.
        df = pd.DataFrame(dict(result.summary.items())).T

        logger.debug("Posterior summary retrieved", n_parameters=len(df))
        return df

    # ------------------------------------------------------------------
    # Matplotlib plots built directly from posterior samples
    # ------------------------------------------------------------------
    def plot_posterior(
        self, param_name: str | None = None, show: bool = True, **plot_kwargs
    ) -> BayesianPipeline:
        """Plot posterior distributions as histograms.

        Each parameter gets its own subplot (at most three per row) with the
        posterior mean and median marked as vertical lines.

        Args:
            param_name: Name of specific parameter to plot. If None, plots
                all parameters (default: None)
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to ``Axes.hist``
                (``bins`` defaults to 30 and ``alpha`` to 0.7)

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run, if
                *param_name* is not in the posterior samples, or if the
                posterior contains no parameters.

        Example:
            >>> pipeline.plot_posterior()
            >>> pipeline.plot_posterior('a', bins=50, alpha=0.7)
            >>> pipeline.plot_posterior(show=False).save_figure('posterior.pdf')
        """
        result = self._require_bayesian_result("posterior plot")

        import matplotlib.pyplot as plt

        posterior_samples, params_to_plot = self._posterior_params_to_plot(param_name)
        logger.debug("Plotting posterior", params=params_to_plot)

        # Grid layout: up to three columns, as many rows as needed.
        n_params = len(params_to_plot)
        n_cols = min(3, n_params)
        n_rows = (n_params + n_cols - 1) // n_cols
        # squeeze=False always yields a 2-D array, so no single-axes special case.
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False
        )
        axes_flat = axes.ravel()

        # Work on a copy so the caller's kwargs are never mutated.
        hist_kwargs = dict(plot_kwargs)
        bins = hist_kwargs.pop("bins", 30)
        alpha = hist_kwargs.pop("alpha", 0.7)

        for ax, param in zip(axes_flat, params_to_plot):
            ax.hist(posterior_samples[param], bins=bins, alpha=alpha, **hist_kwargs)

            # Add summary statistics
            stats = result.summary[param]
            ax.axvline(
                stats["mean"], color="red", linestyle="--", linewidth=2, label="Mean"
            )
            ax.axvline(
                stats["median"],
                color="blue",
                linestyle="--",
                linewidth=2,
                label="Median",
            )

            ax.set_xlabel(f"{param}")
            ax.set_ylabel("Frequency")
            ax.set_title(f"Posterior: {param}")
            ax.legend()
            ax.grid(alpha=0.3)

        # Hide unused subplots
        for spare in axes_flat[n_params:]:
            spare.set_visible(False)

        plt.tight_layout()

        # Store figure for save_figure() method
        self._current_figure = fig
        if show:
            plt.show()

        self.history.append(("plot_posterior", param_name if param_name else "all"))
        return self

    def plot_trace(
        self, param_name: str | None = None, show: bool = True, **plot_kwargs
    ) -> BayesianPipeline:
        """Plot MCMC trace plots.

        Draws the sampled values of each parameter against sample index, one
        subplot per parameter, with the posterior mean as a horizontal line.

        Args:
            param_name: Name of specific parameter to plot. If None, plots
                all parameters (default: None)
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to ``Axes.plot``
                (``alpha`` defaults to 0.7)

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run, if
                *param_name* is not in the posterior samples, or if the
                posterior contains no parameters.

        Example:
            >>> pipeline.plot_trace()
            >>> pipeline.plot_trace('a', alpha=0.5)
            >>> pipeline.plot_trace(show=False).save_figure('trace.pdf')
        """
        result = self._require_bayesian_result("trace plot")

        import matplotlib.pyplot as plt

        posterior_samples, params_to_plot = self._posterior_params_to_plot(param_name)
        logger.debug("Plotting trace", params=params_to_plot)

        n_params = len(params_to_plot)
        fig, axes = plt.subplots(
            n_params, 1, figsize=(10, 3 * n_params), squeeze=False
        )

        # Work on a copy so the caller's kwargs are never mutated.
        line_kwargs = dict(plot_kwargs)
        alpha = line_kwargs.pop("alpha", 0.7)

        for ax, param in zip(axes.ravel(), params_to_plot):
            ax.plot(posterior_samples[param], alpha=alpha, **line_kwargs)

            # Add mean line
            mean = result.summary[param]["mean"]
            ax.axhline(mean, color="red", linestyle="--", linewidth=2, label="Mean")

            ax.set_xlabel("Iteration")
            ax.set_ylabel(f"{param}")
            ax.set_title(f"Trace: {param}")
            ax.legend()
            ax.grid(alpha=0.3)

        plt.tight_layout()

        # Store figure for save_figure() method
        self._current_figure = fig
        if show:
            plt.show()

        self.history.append(("plot_trace", param_name if param_name else "all"))
        return self

    # ------------------------------------------------------------------
    # ArviZ-based diagnostics
    # ------------------------------------------------------------------
    def _get_inference_data(self) -> Any:
        """Get ArviZ InferenceData from the Bayesian result.

        Delegates to the result's ``to_inference_data()``.

        Returns:
            ArviZ InferenceData object

        Raises:
            ValueError: If Bayesian inference has not been run
        """
        result = self._require_bayesian_result("InferenceData conversion")
        logger.debug("Converting Bayesian result to InferenceData")
        return result.to_inference_data()

    def plot_pair(
        self,
        var_names: list[str] | None = None,
        kind: str = "scatter",
        divergences: bool = True,
        show: bool = True,
        **plot_kwargs,
    ) -> BayesianPipeline:
        """Plot pairwise relationships between parameters (pair plot).

        Useful for spotting parameter correlations, funnel geometry and
        multimodal posteriors. When *var_names* is None, variables whose
        samples are (numerically) constant are skipped automatically; if none
        remain, a warning is logged and no plot is produced.

        Args:
            var_names: List of parameter names to plot. If None, plots all
                non-degenerate parameters (default: None)
            kind: Type of pair plot - "scatter", "kde", or "hexbin"
                (default: "scatter")
            divergences: Whether to highlight divergent transitions
                (default: True)
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to arviz.plot_pair()

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run
            ImportError: If arviz is not installed

        Example:
            >>> pipeline.plot_pair()
            >>> pipeline.plot_pair(var_names=["G0", "eta"], kind="kde")
            >>> pipeline.plot_pair(show=False).save_figure("pair.pdf")
        """
        logger.debug(
            "Creating pair plot",
            var_names=var_names,
            kind=kind,
            divergences=divergences,
        )
        az = self._load_arviz("plot_pair", "pair")
        idata = self._get_inference_data()

        # Filter degenerate parameters to prevent ArviZ KDE crashes
        if var_names is None:
            varying = self._varying_posterior_vars(idata)
            if varying is not None:
                if not varying:
                    logger.warning(
                        "All posterior parameters are degenerate, skipping pair plot"
                    )
                    return self
                var_names = varying

        axes = az.plot_pair(
            idata,
            var_names=var_names,
            kind=kind,
            divergences=divergences,
            **plot_kwargs,
        )
        self._capture_figure(axes, show)

        self.history.append(("plot_pair", var_names if var_names else "all"))
        return self

    def plot_forest(
        self,
        var_names: list[str] | None = None,
        combined: bool = True,
        hdi_prob: float = 0.95,
        show: bool = True,
        **plot_kwargs,
    ) -> BayesianPipeline:
        """Plot forest plot with credible intervals for parameters.

        Shows each parameter's estimate with its highest-density interval,
        which makes magnitudes and uncertainties easy to compare. When
        *var_names* is None, variables whose samples are (numerically)
        constant are skipped automatically; if none remain, a warning is
        logged and no plot is produced.

        Args:
            var_names: List of parameter names to plot. If None, plots all
                non-degenerate parameters (default: None)
            combined: Whether to combine multiple chains (default: True)
            hdi_prob: Probability mass for credible interval (default: 0.95)
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to arviz.plot_forest()

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run
            ImportError: If arviz is not installed

        Example:
            >>> pipeline.plot_forest()
            >>> pipeline.plot_forest(var_names=["G0", "eta"], hdi_prob=0.68)
            >>> pipeline.plot_forest(show=False).save_figure("forest.pdf")
        """
        logger.debug(
            "Creating forest plot",
            var_names=var_names,
            combined=combined,
            hdi_prob=hdi_prob,
        )
        az = self._load_arviz("plot_forest", "forest")
        idata = self._get_inference_data()

        # Filter degenerate parameters to prevent ArviZ KDE crashes
        if var_names is None:
            varying = self._varying_posterior_vars(idata)
            if varying is not None:
                if not varying:
                    logger.warning(
                        "All posterior parameters are degenerate, skipping forest plot"
                    )
                    return self
                var_names = varying

        axes = az.plot_forest(
            idata,
            var_names=var_names,
            combined=combined,
            hdi_prob=hdi_prob,
            **plot_kwargs,
        )
        self._capture_figure(axes, show)

        self.history.append(("plot_forest", var_names if var_names else "all"))
        return self

    def plot_energy(self, show: bool = True, **plot_kwargs) -> BayesianPipeline:
        """Plot NUTS energy diagnostic plot.

        Compares the marginal and transition energy distributions of the
        sampler; a visible mismatch points at problematic posterior geometry
        (heavy tails, funnels, poor parameterization).

        Args:
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to arviz.plot_energy()

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run
            ImportError: If arviz is not installed
            RuntimeError: If the InferenceData has no ``sample_stats.energy``

        Example:
            >>> pipeline.plot_energy()
            >>> pipeline.plot_energy(show=False).save_figure("energy.pdf")
        """
        logger.debug("Creating energy plot")
        az = self._load_arviz("plot_energy", "energy")
        idata = self._get_inference_data()

        sample_stats = getattr(idata, "sample_stats", None)
        if sample_stats is None or not hasattr(sample_stats, "energy"):
            logger.error("Energy diagnostic missing from InferenceData")
            raise RuntimeError(
                "Energy diagnostic is missing from InferenceData.sample_stats. "
                "Ensure NumPyro was run with NUTS and that energy/potential_energy "
                "fields are available for conversion to ArviZ."
            )

        axes = az.plot_energy(idata, **plot_kwargs)
        self._capture_figure(axes, show)

        self.history.append(("plot_energy", None))
        return self

    def plot_autocorr(
        self,
        var_names: list[str] | None = None,
        combined: bool = False,
        show: bool = True,
        **plot_kwargs,
    ) -> BayesianPipeline:
        """Plot autocorrelation diagnostic for MCMC mixing.

        Autocorrelation that stays high over many lags indicates poor mixing
        and a low effective sample size; ideally it decays quickly to zero.

        Args:
            var_names: List of parameter names to plot. If None, plots all
                parameters (default: None)
            combined: Whether to combine multiple chains (default: False)
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to
                arviz.plot_autocorr() (e.g., max_lag)

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run
            ImportError: If arviz is not installed

        Example:
            >>> pipeline.plot_autocorr()
            >>> pipeline.plot_autocorr(var_names=["G0"], max_lag=100)
            >>> pipeline.plot_autocorr(show=False).save_figure("autocorr.pdf")
        """
        logger.debug(
            "Creating autocorrelation plot",
            var_names=var_names,
            combined=combined,
        )
        az = self._load_arviz("plot_autocorr", "autocorrelation")
        idata = self._get_inference_data()

        axes = az.plot_autocorr(
            idata,
            var_names=var_names,
            combined=combined,
            **plot_kwargs,
        )
        self._capture_figure(axes, show)

        self.history.append(("plot_autocorr", var_names if var_names else "all"))
        return self

    def plot_rank(
        self,
        var_names: list[str] | None = None,
        show: bool = True,
        **plot_kwargs,
    ) -> BayesianPipeline:
        """Plot rank plot for convergence diagnostics.

        Rank histograms are an alternative to trace plots: a uniform rank
        distribution across chains indicates good mixing, while
        non-uniformity suggests convergence problems.

        Args:
            var_names: List of parameter names to plot. If None, plots all
                parameters (default: None)
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to arviz.plot_rank()

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run
            ImportError: If arviz is not installed

        Example:
            >>> pipeline.plot_rank()
            >>> pipeline.plot_rank(var_names=["G0", "eta"])
            >>> pipeline.plot_rank(show=False).save_figure("rank.pdf")
        """
        logger.debug("Creating rank plot", var_names=var_names)
        az = self._load_arviz("plot_rank", "rank")
        idata = self._get_inference_data()

        axes = az.plot_rank(
            idata,
            var_names=var_names,
            **plot_kwargs,
        )
        self._capture_figure(axes, show)

        self.history.append(("plot_rank", var_names if var_names else "all"))
        return self

    def plot_ess(
        self,
        var_names: list[str] | None = None,
        kind: str = "local",
        show: bool = True,
        **plot_kwargs,
    ) -> BayesianPipeline:
        """Plot effective sample size (ESS) diagnostic.

        ESS quantifies how many independent samples the chains are worth;
        low values indicate high autocorrelation and a need for more samples.

        Args:
            var_names: List of parameter names to plot. If None, plots all
                parameters (default: None)
            kind: Type of ESS plot - "local", "quantile", or "evolution"
                (default: "local")
            show: Whether to call plt.show() (default: True)
            **plot_kwargs: Additional arguments passed to arviz.plot_ess()
                (e.g., min_ess)

        Returns:
            self for method chaining

        Raises:
            ValueError: If Bayesian inference has not been run
            ImportError: If arviz is not installed

        Example:
            >>> pipeline.plot_ess()
            >>> pipeline.plot_ess(kind="quantile")
            >>> pipeline.plot_ess(show=False).save_figure("ess.pdf")
        """
        logger.debug(
            "Creating ESS plot",
            var_names=var_names,
            kind=kind,
        )
        az = self._load_arviz("plot_ess", "ESS")
        idata = self._get_inference_data()

        axes = az.plot_ess(
            idata,
            var_names=var_names,
            kind=kind,
            **plot_kwargs,
        )
        self._capture_figure(axes, show)

        self.history.append(("plot_ess", var_names if var_names else "all"))
        return self

    # ------------------------------------------------------------------
    # Housekeeping
    # ------------------------------------------------------------------
    def reset(self) -> BayesianPipeline:
        """Reset pipeline to initial state.

        Runs the base-class reset and additionally clears the stored NLSQ
        result, Bayesian result and diagnostics.

        Returns:
            self for method chaining

        Example:
            >>> pipeline.reset()
        """
        super().reset()
        self._nlsq_result = None
        self._bayesian_result = None
        self._diagnostics = None
        logger.debug("BayesianPipeline reset")
        return self

    def __repr__(self) -> str:
        """String representation of Bayesian pipeline.

        ``steps`` reports the number of history entries.
        """
        return (
            f"BayesianPipeline(steps={len(self.history)}, "
            f"has_data={self.data is not None}, "
            f"has_model={self._last_model is not None}, "
            f"has_nlsq={self._nlsq_result is not None}, "
            f"has_bayesian={self._bayesian_result is not None})"
        )
# Names exported by ``from rheojax.pipeline.bayesian import *``.
__all__ = ["BayesianPipeline"]