Core Module (rheojax.core)

The core module provides fundamental data structures and abstractions for rheological analysis.

Data Container

RheoData

class rheojax.core.data.RheoData(x=None, y=None, x_units=None, y_units=None, domain='time', initial_test_mode=None, metadata=<factory>, validate=True)[source]

Bases: object

JAX-native container for rheological data with NumPy/JAX array support.

This class provides a unified interface for rheological data stored as either NumPy or JAX arrays. It adds features for rheological analysis, including automatic test mode detection, data validation, and domain-specific operations.

x

Independent variable data (e.g., time, frequency)

y

Dependent variable data (e.g., stress, strain, modulus)

x_units

Units for x-axis data

y_units

Units for y-axis data

domain

Data domain (‘time’ or ‘frequency’)

metadata

Dictionary of additional metadata

validate

Whether to validate data on creation

initial_test_mode: InitVar[str | None] = None
__post_init__(initial_test_mode)[source]

Initialize and validate RheoData.

__setattr__(name, value)[source]

Invalidate JAX cache when x or y data is reassigned.

Return type:

None

to_jax()[source]

Convert arrays to JAX arrays.

The result is cached for subsequent calls; the cache is invalidated when x or y is reassigned.

Return type:

RheoData

Returns:

New RheoData with JAX arrays

to_numpy()[source]

Convert arrays to NumPy arrays.

Uses np.asarray() for zero-copy conversion when possible, providing 10-30% memory savings for large arrays (>100k points).

Return type:

RheoData

Returns:

New RheoData with NumPy arrays

copy()[source]

Create a copy of the RheoData.

Return type:

RheoData

Returns:

Copy of the RheoData instance

update_metadata(metadata)[source]

Update metadata dictionary.

Parameters:

metadata (dict[str, Any]) – Dictionary of metadata to add/update

to_dict()[source]

Convert to dictionary representation.

Return type:

dict[str, Any]

Returns:

Dictionary with data and metadata

classmethod from_dict(data_dict)[source]

Create from dictionary representation.

Parameters:

data_dict (dict[str, Any]) – Dictionary with data and metadata

Return type:

RheoData

Returns:

RheoData instance

property shape: tuple

Shape of the y data.

property ndim: int

Number of dimensions of y data.

property size: int

Size of y data.

property dtype

Data type of y data.

property is_complex: bool

Check if y data is complex.

property modulus: ndarray | None

Get the modulus (magnitude |y|) of complex data.

property phase: ndarray | None

Get phase of complex data.
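For complex modulus data, modulus and phase are the polar form of G* = G’ + i·G’’. The NumPy sketch below works on a synthetic array to show the relationships these two properties are expected to satisfy. It does not call RheoData, and reporting the phase in radians (the np.angle default) is an assumption, since the unit is not stated above.

```python
import numpy as np

# Synthetic complex modulus G* = G' + i*G'' at three frequencies (Pa)
G_star = np.array([1.0e3 + 2.0e2j, 5.0e3 + 5.0e3j, 2.0e4 + 4.0e4j])

modulus = np.abs(G_star)   # |G*| = sqrt(G'^2 + G''^2)
phase = np.angle(G_star)   # phase angle delta in radians, atan2(G'', G')

# The polar form reconstructs the original complex data
reconstructed = modulus * np.exp(1j * phase)

# tan(delta) equals G''/G', i.e. the loss tangent
tan_delta = np.tan(phase)
```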

property y_real: ndarray | Array

Get real component of y data.

For complex modulus data (G* = G’ + i·G’’), this returns the storage modulus (G’). For real data, returns y unchanged.

Returns:

Real component of y data (G’ for complex modulus)

Example

>>> import matplotlib.pyplot as plt
>>> data = read_trios('frequency_sweep.txt')  # Returns complex G*
>>> G_prime = data[0].y_real  # Storage modulus (G')
>>> plt.loglog(data[0].x, G_prime, label="G'")
property y_imag: ndarray | Array

Get imaginary component of y data.

For complex modulus data (G* = G’ + i·G’’), this returns the loss modulus (G’’). For real data, returns zeros.

Returns:

Imaginary component of y data (G’’ for complex modulus)

Example

>>> import matplotlib.pyplot as plt
>>> data = read_trios('frequency_sweep.txt')  # Returns complex G*
>>> G_double_prime = data[0].y_imag  # Loss modulus (G'')
>>> plt.loglog(data[0].x, G_double_prime, label='G"')
property storage_modulus: ndarray | None

Get storage modulus (G’) from complex modulus data.

Equivalent to y_real for complex data, but returns None for real data; the name makes the rheological intent explicit.

Returns:

Storage modulus (G’) if data is complex, None otherwise

Example

>>> data = read_trios('frequency_sweep.txt')
>>> G_prime = data[0].storage_modulus
property loss_modulus: ndarray | None

Get loss modulus (G’’) from complex modulus data.

Equivalent to y_imag for complex data, but returns None for real data; the name makes the rheological intent explicit.

Returns:

Loss modulus (G’’) if data is complex, None otherwise

Example

>>> data = read_trios('frequency_sweep.txt')
>>> G_double_prime = data[0].loss_modulus
property tan_delta: ndarray | None

Get loss tangent (tan δ = G’’/G’) from complex modulus data.

The loss tangent quantifies the ratio of viscous to elastic response:

  • tan δ < 1: Elastic-dominant (solid-like)

  • tan δ > 1: Viscous-dominant (liquid-like)

  • tan δ = 1: Equal elastic and viscous contributions

Returns:

Loss tangent (dimensionless) if data is complex, None otherwise

Example

>>> data = read_trios('frequency_sweep.txt')
>>> tan_d = data[0].tan_delta
>>> print(f"Material type: {'solid-like' if tan_d.mean() < 1 else 'liquid-like'}")
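Because tan δ usually varies with frequency, a per-point view can be more informative than the single mean used in the example above. The NumPy sketch below applies the thresholds listed above to synthetic G’ and G’’ arrays. It does not call RheoData, and the label "balanced" for tan δ = 1 is an illustrative choice.

```python
import numpy as np

# Synthetic storage (G') and loss (G'') moduli in Pa
G_prime = np.array([10.0, 100.0, 1000.0])
G_double_prime = np.array([50.0, 100.0, 200.0])

tan_delta = G_double_prime / G_prime  # dimensionless loss tangent

# Classify each frequency point using the tan(delta) thresholds
regime = np.where(tan_delta < 1, "solid-like",
                  np.where(tan_delta > 1, "liquid-like", "balanced"))
```

Measured tan δ rarely equals 1 exactly, so in practice a small tolerance around 1 may be preferable to exact comparison.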
property test_mode: str

Automatically detect or retrieve test mode.

The test mode is detected based on data characteristics and cached in a private field. If already detected, returns the cached value. If explicitly set in metadata[‘test_mode’], returns that value.

Returns:

Test mode string (relaxation, creep, oscillation, rotation, unknown)

property deformation_mode: str

Get deformation mode from metadata.

Returns ‘shear’ if not explicitly set. Possible values: ‘shear’, ‘tension’, ‘bending’, ‘compression’.

property storage_modulus_label: str

Get appropriate storage modulus label based on deformation mode.

Returns “E’” for tensile/bending/compression, “G’” for shear.

property loss_modulus_label: str

Get appropriate loss modulus label based on deformation mode.

Returns "E''" for tensile/bending/compression, "G''" for shear.

__getitem__(idx)[source]

Support indexing and slicing.

__add__(other)[source]

Add another RheoData object or a scalar.

__sub__(other)[source]

Subtract another RheoData object or a scalar.

__mul__(other)[source]

Multiply by a scalar or another RheoData object.

interpolate(new_x)[source]

Interpolate data to new x values.

Parameters:

new_x (numpy.typing.ArrayLike) – New x values for interpolation

Return type:

RheoData

Returns:

Interpolated RheoData

resample(n_points)[source]

Resample data to specified number of points.

Parameters:

n_points (int) – Number of points to resample to

Return type:

RheoData

Returns:

Resampled RheoData

smooth(window_size=5)[source]

Smooth data using a moving average.

Parameters:

window_size (int) – Size of the smoothing window in points (default: 5)

Return type:

RheoData

Returns:

Smoothed RheoData
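The window alignment and edge handling of smooth() are not documented here. As a minimal sketch, assuming a centered moving average over window_size points, np.convolve with mode="same" keeps the output length equal to the input; near the ends the window is implicitly zero-padded, which the sketch corrects by dividing by the number of real samples in each window. The function name moving_average is hypothetical.

```python
import numpy as np

def moving_average(y, window_size=5):
    """Centered moving average that keeps len(y) and renormalizes at the edges."""
    y = np.asarray(y, dtype=float)
    kernel = np.ones(window_size)
    # Sum of samples inside each window (zero-padded at the boundaries)
    window_sums = np.convolve(y, kernel, mode="same")
    # Number of real (non-padded) samples inside each window
    counts = np.convolve(np.ones_like(y), kernel, mode="same")
    return window_sums / counts
```

An odd window_size keeps the window symmetric around each point; with an even size, mode="same" shifts the window by half a sample.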

derivative()[source]

Compute the numerical derivative dy/dx.

Return type:

RheoData

Returns:

RheoData with derivative values

integral()[source]

Compute the numerical integral of y with respect to x.

Return type:

RheoData

Returns:

RheoData with integrated values
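The numerical schemes behind derivative() and integral() are not specified above. A common choice, shown here only as an assumption, is central differences via np.gradient for the derivative and the cumulative trapezoidal rule for the integral. Both handle non-uniform x spacing, which matters for log-spaced time or frequency axes. The helper names are hypothetical.

```python
import numpy as np

def derivative(x, y):
    """dy/dx by central differences (one-sided at the two end points)."""
    return np.gradient(np.asarray(y, dtype=float), np.asarray(x, dtype=float))

def cumulative_integral(x, y):
    """Running integral of y dx by the trapezoidal rule, starting at 0."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    increments = 0.5 * (y[1:] + y[:-1]) * np.diff(x)
    return np.concatenate(([0.0], np.cumsum(increments)))

# Log-spaced time axis, as is typical for relaxation data
t = np.logspace(-2, 1, 200)
y = 3.0 * t**2
```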

to_frequency_domain()[source]

Convert time domain data to frequency domain.

Return type:

RheoData

Returns:

Frequency domain RheoData

to_time_domain()[source]

Convert frequency domain data to time domain.

Return type:

RheoData

Returns:

Time domain RheoData

__init__(x=None, y=None, x_units=None, y_units=None, domain='time', initial_test_mode=None, metadata=<factory>, validate=True)

slice(start=None, end=None)[source]

Slice data between x values.

Parameters:
  • start – Starting x value of the slice

  • end – Ending x value of the slice

Return type:

RheoData

Returns:

Sliced RheoData
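Whether slice() treats start and end as inclusive, and how it interprets None, is not stated above. A minimal NumPy sketch of slicing by x value, assuming inclusive limits and None meaning "no limit on that side" (slice_by_x is a hypothetical helper):

```python
import numpy as np

def slice_by_x(x, y, start=None, end=None):
    """Keep points with start <= x <= end; None leaves that side unbounded."""
    x = np.asarray(x)
    y = np.asarray(y)
    mask = np.ones(x.shape, dtype=bool)
    if start is not None:
        mask &= x >= start
    if end is not None:
        mask &= x <= end
    return x[mask], y[mask]

omega = np.logspace(-1, 2, 7)  # 0.1 ... 100 rad/s, half-decade steps
G = 2.0 * omega
```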

Base Classes

BaseModel

class rheojax.core.base.BaseModel[source]

Bases: BayesianMixin, ABC

Abstract base class for all rheological models.

This class defines the standard interface that all models must implement, supporting JAX arrays, scikit-learn style APIs, and Bayesian inference via NumPyro NUTS.

All models inherit Bayesian capabilities from BayesianMixin, including:

  • fit_bayesian(): Bayesian parameter estimation using NUTS

  • sample_prior(): Sample from prior distributions

  • get_credible_intervals(): Compute highest density intervals

The fit() method uses NLSQ optimization by default for fast point estimation, which can be used to warm-start Bayesian inference.


__init__()[source]

Initialize base model.

fit(X, y=None, method='nlsq', check_compatibility=False, use_log_residuals=None, use_multi_start=None, n_starts=5, perturb_factor=0.3, deformation_mode=None, poisson_ratio=0.5, auto_init=False, return_result=False, check_physics=False, uncertainty=None, **kwargs)[source]

Fit the model to data using NLSQ optimization.

This method uses NLSQ (GPU-accelerated nonlinear least squares) by default for fast point estimation. The optimization result is stored for potential warm-starting of Bayesian inference.

For very wide frequency ranges (>10 decades), multi-start optimization is automatically enabled to escape local minima.

Parameters:
  • X (numpy.typing.ArrayLike) – Input features

  • y (Optional[numpy.typing.ArrayLike]) – Target values

  • method (str) – Optimization method (‘nlsq’ by default for compatibility)

  • check_compatibility (bool) – Whether to check model-data compatibility before fitting. If True, warns when model may not be appropriate for data. Default is False for backward compatibility.

  • use_log_residuals (bool | None) – Whether to use log-space residuals for fitting. Recommended for wide frequency ranges (>8 decades) to prevent optimizer bias. If None (default), automatically detected based on data range. Explicit True/False overrides auto-detection.

  • use_multi_start (bool | None) – Whether to use multi-start optimization to escape local minima. Recommended for very wide ranges (>10 decades). If None (default), automatically enabled for >10 decades.

  • n_starts (int) – Number of random starts for multi-start optimization (default: 5)

  • perturb_factor (float) – Perturbation magnitude for multi-start random starts (default: 0.3). Parameters are perturbed by ± perturb_factor * (value or range). Larger values (0.7-0.9) explore wider parameter space.

  • auto_init (bool) – If True, calls auto_p0() to estimate initial parameters from data before running the optimizer (default: False).

  • return_result (bool) – If True, returns a FitResult instead of self. This intentionally breaks method chaining for workflows that need structured result objects (default: False).

  • check_physics (bool) – If True, runs post-fit physics validation and emits RheoJaxPhysicsWarning for any violations (default: False).

  • uncertainty (str | None) – Post-fit uncertainty method. "hessian" for fast Cramér-Rao bounds, "bootstrap" for residual bootstrap CIs, or None to skip (default: None).

  • **kwargs – Additional fitting options passed to _fit()

Return type:

BaseModel | Any

Returns:

self for method chaining (default), or FitResult if return_result=True.

Example

>>> model = Maxwell()
>>> model.fit(t, G_data)  # Uses NLSQ by default
>>> model.fit(t, G_data, method='nlsq', max_iter=1000)
>>> model.fit(t, G_data, check_compatibility=True)  # Check compatibility
>>> model.fit(omega, G_star, use_log_residuals=True)  # Force log-residuals
>>> model.fit(mastercurve, None, use_multi_start=True, n_starts=10)  # Multi-start
>>> result = model.fit(t, G_data, return_result=True)  # Structured result
>>> result = model.fit(t, G_data, auto_init=True, check_physics=True,
...                    return_result=True)  # Full pipeline
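The use_log_residuals option addresses a weighting problem. With linear residuals y_pred - y, a fixed relative error contributes in proportion to the magnitude of y, so the highest-modulus decades dominate the cost. The NumPy sketch below compares both residual types for a synthetic, uniform 1% error over 8 decades. It illustrates the motivation only and is not RheoJAX's residual function.

```python
import numpy as np

y_true = np.logspace(0, 8, 9)  # data spanning 8 decades (e.g., Pa)
y_pred = 1.01 * y_true         # uniform 1% relative error

linear_res = y_pred - y_true
log_res = np.log10(y_pred) - np.log10(y_true)

# Share of the total squared linear residual coming from the largest point
linear_share = linear_res[-1]**2 / np.sum(linear_res**2)
```

Here the last point carries about 99% of the linear cost, while every log residual is the same (log10(1.01)), so each decade is weighted equally.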
precompile(test_mode='relaxation', X=None, y=None)[source]

Precompile NLSQ residual functions to eliminate JIT cold-start.

Triggers JIT compilation by running a minimal fit (max_iter=1) with dummy data. The model parameters are reset afterwards so the model is left in its original state.

This is useful for interactive sessions or benchmarks where the first-fit JIT overhead (~870 ms; the actual cost depends on the model and hardware) should be excluded.

Parameters:
  • test_mode (str) – Test mode to precompile for (default: ‘relaxation’).

  • X (Optional[numpy.typing.ArrayLike]) – Optional input data for shape inference. If None, uses a 10-point logspace array.

  • y (Optional[numpy.typing.ArrayLike]) – Optional output data. If None, generates ones matching X.

Return type:

float

Returns:

Compilation time in seconds.

Example

>>> model = Maxwell()
>>> compile_time = model.precompile(test_mode='relaxation')
>>> print(f"Compiled in {compile_time:.2f}s")
>>> model.fit(X, y)  # No JIT overhead
fit_bayesian(X, y=None, num_warmup=1000, num_samples=2000, num_chains=4, initial_values=None, test_mode=None, seed=None, deformation_mode=None, poisson_ratio=0.5, **nuts_kwargs)[source]

Perform Bayesian inference using NumPyro NUTS sampler.

This method delegates to BayesianMixin.fit_bayesian() to run NUTS sampling for Bayesian parameter estimation. If initial_values is not provided and the model has been previously fitted with fit(), the NLSQ point estimates are automatically used for warm-starting.

Multi-chain sampling is enabled by default (num_chains=4) to provide reliable convergence diagnostics (R-hat, ESS) and parallel execution on multi-GPU systems.

Parameters:
  • X (numpy.typing.ArrayLike) – Independent variable data (input features) or RheoData object

  • y (Optional[numpy.typing.ArrayLike]) – Dependent variable data (observations to fit). If X is RheoData, y is ignored and extracted from X.

  • num_warmup (int) – Number of warmup/burn-in iterations (default: 1000)

  • num_samples (int) – Number of posterior samples per chain (default: 2000)

  • num_chains (int) – Number of MCMC chains (default: 4). Multiple chains enable proper R-hat computation and parallel execution. Chain method is auto-selected: ‘parallel’ on multi-GPU, ‘vectorized’ on single GPU/CPU.

  • initial_values (dict[str, float] | None) – Optional dict of initial parameter values for warm-start. If None and model is fitted, uses NLSQ estimates.

  • test_mode (str | None) – Explicit test mode (e.g., ‘relaxation’, ‘creep’, ‘oscillation’). If None, inferred from RheoData.metadata[‘test_mode’] or defaults to ‘relaxation’. Overrides RheoData metadata if provided.

  • seed (int | None) – Random seed for reproducibility. If None, uses seed=0 for deterministic results. Set to different values for independent runs.

  • **nuts_kwargs – Additional arguments passed to NUTS sampler (e.g., target_accept_prob, chain_method)

Return type:

BayesianResult

Returns:

BayesianResult containing posterior samples, summary statistics, and convergence diagnostics (R-hat, ESS, divergences)

Example

>>> model = Maxwell()
>>> # Warm-start from NLSQ with explicit mode
>>> model.fit(t, G_data, test_mode='relaxation')  # NLSQ optimization
>>> result = model.fit_bayesian(t, G_data, test_mode='relaxation')
>>>
>>> # RheoData with embedded mode (recommended)
>>> rheo_data = RheoData(x=omega, y=G_star, metadata={'test_mode': 'oscillation'})
>>> result = model.fit_bayesian(rheo_data)
>>>
>>> # Or provide explicit initial values
>>> result = model.fit_bayesian(
...     t, G_data,
...     initial_values={'G0': 1e5, 'eta': 1e3},
...     test_mode='creep'
... )
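To make the R-hat diagnostic concrete, the sketch below computes the classic (non-split) Gelman-Rubin statistic with NumPy from an array of shape (num_chains, num_samples). NumPyro provides a split-chain variant (numpyro.diagnostics.split_gelman_rubin), and the values reported by fit_bayesian may be computed differently, so this sketch only shows why more than one chain is needed: with a single chain, the between-chain variance is undefined.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic potential scale reduction factor R-hat.

    chains: array of shape (num_chains, num_samples) for one parameter.
    """
    chains = np.asarray(chains, dtype=float)
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    W = chains.var(axis=1, ddof=1).mean()  # mean within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_hat = (n - 1) / n * W + B / n      # pooled posterior variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 2000))  # four chains sampling the same distribution
stuck = mixed.copy()
stuck[3] += 5.0                     # one chain stuck in a different mode
```

Values close to 1 indicate that the chains agree; the shifted chain pushes R-hat well above common thresholds such as 1.01 or 1.1.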
predict(X, test_mode=None, deformation_mode=None, poisson_ratio=None, **kwargs)[source]

Make predictions.

Parameters:
  • X (numpy.typing.ArrayLike) – Input features

  • test_mode (str | None) – Optional test mode (‘oscillation’, ‘relaxation’, ‘creep’, ‘flow’). If provided, sets model’s test_mode before prediction. Useful for data generation without fitting.

  • deformation_mode (str | DeformationMode | None) – Optional deformation mode for output conversion. If None, uses the mode stored from fit(). If tensile, converts G* predictions to E* space.

  • poisson_ratio (float | None) – Poisson’s ratio for conversion. If None, uses value stored from fit() (default 0.5).

  • **kwargs – Additional arguments passed to the internal _predict method.

Return type:

numpy.typing.ArrayLike

Returns:

Model predictions (in E* space if deformation_mode is tensile)
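For isotropic, linear viscoelastic materials, the shear and tensile moduli are related by E* = 2(1 + ν)G*, which gives E* = 3G* for the default incompressible case ν = 0.5. The NumPy sketch below applies that standard relation. It is an assumption that predict() uses exactly this conversion with a frequency-independent ν, and the helper names are hypothetical.

```python
import numpy as np

def shear_to_tensile(G_star, poisson_ratio=0.5):
    """Convert shear modulus G* to tensile modulus E* via E* = 2(1 + nu) G*."""
    return 2.0 * (1.0 + poisson_ratio) * np.asarray(G_star)

def tensile_to_shear(E_star, poisson_ratio=0.5):
    """Inverse relation G* = E* / (2(1 + nu))."""
    return np.asarray(E_star) / (2.0 * (1.0 + poisson_ratio))

G_star = np.array([1.0e4 + 2.0e3j, 5.0e4 + 1.0e4j])
E_star = shear_to_tensile(G_star)  # nu = 0.5 gives E* = 3 G*
```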

fit_predict(X, y, **kwargs)[source]

Fit model and return predictions.

Parameters:
  • X (numpy.typing.ArrayLike) – Input features

  • y (numpy.typing.ArrayLike) – Target values

  • **kwargs – Additional fitting options

Return type:

numpy.typing.ArrayLike

Returns:

Model predictions on training data

get_nlsq_result()[source]

Get stored NLSQ optimization result.

Returns:

OptimizationResult from NLSQ fit, or None if not fitted

Example

>>> model.fit(t, G_data)
>>> result = model.get_nlsq_result()
>>> if result:
...     print(f"Converged: {result.success}")
property pcov_

Parameter covariance matrix from NLSQ fit.

Returns:

ndarray of shape (n_params, n_params), or None if not fitted

property popt_

Optimal parameter values from NLSQ fit.

Returns:

ndarray of shape (n_params,), or None if not fitted

get_parameter_uncertainties()[source]

Get standard errors for fitted parameters from NLSQ covariance.

Returns:

Dictionary of {param_name: std_error}, or None if the covariance is unavailable

Return type:

dict | None
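Standard errors from a least-squares covariance matrix are conventionally the square roots of its diagonal. The sketch below builds the {param_name: std_error} mapping that way from a synthetic covariance matrix. Treating a missing, non-finite, or negative variance as "unavailable" is an illustrative choice, and standard_errors and the parameter names are hypothetical.

```python
import numpy as np

def standard_errors(param_names, pcov):
    """Map parameter names to sqrt(diag(pcov)), or None if unusable."""
    if pcov is None:
        return None
    variances = np.diag(np.asarray(pcov, dtype=float))
    if not np.all(np.isfinite(variances)) or np.any(variances < 0):
        return None
    return {name: float(np.sqrt(v)) for name, v in zip(param_names, variances)}

# Hypothetical 2-parameter fit: variances 4.0e6 Pa^2 and 2.5e1 (Pa*s)^2
pcov = np.array([[4.0e6, 1.0e2],
                 [1.0e2, 2.5e1]])
errors = standard_errors(["G0", "eta"], pcov)
```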

get_bayesian_result()[source]

Get stored Bayesian inference result.

Return type:

BayesianResult | None

Returns:

BayesianResult from fit_bayesian(), or None if not run

Example

>>> model.fit_bayesian(t, G_data)
>>> result = model.get_bayesian_result()
>>> print(result.diagnostics['r_hat'])
get_params(deep=True)[source]

Get model parameters.

Parameters:

deep (bool) – If True, return parameters of sub-objects

Return type:

dict[str, Any]

Returns:

Dictionary of parameter names and values

set_params(**params)[source]

Set model parameters.

Parameters:

**params – Parameter names and values

Return type:

BaseModel

Returns:

self for method chaining

score(X, y)[source]

Compute model score (R² by default).

Parameters:
  • X (numpy.typing.ArrayLike) – Input features

  • y (numpy.typing.ArrayLike) – True target values

Return type:

float

Returns:

Model score (R² coefficient)
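The coefficient of determination is R² = 1 - SS_res / SS_tot. Here SS_res is the sum of squared residuals, and SS_tot is the sum of squared deviations from the mean of y. A NumPy sketch for real-valued data follows. How score() handles complex G* data (for example, scoring G’ and G’’ separately or jointly) is not documented here, so the sketch covers only the real case, and r_squared is a hypothetical helper.

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Coefficient of determination for real-valued data."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

y = np.array([1.0, 2.0, 3.0, 4.0])
```

A perfect fit gives 1, predicting the mean everywhere gives 0, and a fit worse than the mean gives a negative value.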

to_dict()[source]

Serialize model to dictionary.

Return type:

dict[str, Any]

Returns:

Dictionary representation of model

classmethod from_dict(data)[source]

Create model from dictionary.

Parameters:

data (dict[str, Any]) – Dictionary representation

Return type:

BaseModel

Returns:

Model instance

__repr__()[source]

String representation of model.

Return type:

str

BaseTransform

class rheojax.core.base.BaseTransform[source]

Bases: ABC

Abstract base class for all data transforms.

This class defines the standard interface that all transforms must implement, supporting JAX arrays and composable transformations.


__init__()[source]

Initialize base transform.

transform(data)[source]

Transform the data.

Parameters:

data – Input data (RheoData or list[RheoData])

Returns:

Transformed data (RheoData or tuple[RheoData, dict])

inverse_transform(data)[source]

Apply inverse transformation.

Parameters:

data – Transformed data (RheoData)

Returns:

Original data (RheoData)

fit(data)[source]

Fit the transform to data (learn parameters if needed).

Parameters:

data – Training data (RheoData or list[RheoData])

Return type:

BaseTransform

Returns:

self for method chaining

fit_transform(data)[source]

Fit to data and transform it.

Parameters:

data – Input data (RheoData or list[RheoData])

Returns:

Transformed data (RheoData or tuple[RheoData, dict])

__add__(other)[source]

Compose transforms using + operator.

Parameters:

other (BaseTransform) – Another transform to compose

Return type:

TransformPipeline

Returns:

Pipeline of transforms
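Composing with + builds a pipeline that applies its transforms in order. The pure-Python sketch below shows the same pattern using hypothetical stand-in classes (Transform, Pipeline, Scale, Shift) that operate on plain floats rather than RheoData. It illustrates the composition idea and is not RheoJAX's implementation.

```python
class Transform:
    """Hypothetical minimal transform supporting + composition."""

    def transform(self, data):
        raise NotImplementedError

    def __add__(self, other):
        return Pipeline([self, other])


class Pipeline(Transform):
    """Applies its transforms sequentially, in list order."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def transform(self, data):
        for step in self.transforms:
            data = step.transform(data)
        return data

    def __add__(self, other):
        # Flatten so (a + b) + c becomes one three-step pipeline
        return Pipeline(self.transforms + [other])


class Scale(Transform):
    def __init__(self, factor):
        self.factor = factor

    def transform(self, data):
        return data * self.factor


class Shift(Transform):
    def __init__(self, offset):
        self.offset = offset

    def transform(self, data):
        return data + self.offset


pipeline = Scale(2.0) + Shift(1.0) + Scale(10.0)
```

Flattening chained additions into a single pipeline is a design choice in this sketch. Whether TransformPipeline flattens or nests is not stated above.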

batch_transform(datasets)[source]

Transform multiple datasets sequentially.

Applies the transform to each dataset in order. Sequential execution is required because JAX JIT compilation is not thread-safe.

Parameters:

datasets (list) – Input datasets to transform.

Returns:

Transformed datasets, one per input.

Return type:

list

__repr__()[source]

String representation of transform.

Return type:

str

TransformPipeline

class rheojax.core.base.TransformPipeline(transforms)[source]

Bases: BaseTransform

Pipeline of multiple transforms applied sequentially.


__init__(transforms)[source]

Initialize transform pipeline.

Parameters:

transforms (list[BaseTransform]) – List of transforms to apply in order

fit(data)[source]

Fit all transforms in the pipeline.

Parameters:

data (numpy.typing.ArrayLike) – Training data

Return type:

TransformPipeline

Returns:

self for method chaining

__repr__()[source]

String representation of pipeline.

Return type:

str

Parameters

Parameter

class rheojax.core.parameters.Parameter(name, value=None, bounds=None, units=None, description=None, constraints=None)[source]

Bases: object

Single parameter with value, bounds, and metadata.

A Parameter represents a model parameter with support for bounds validation, units tracking, and constraint enforcement. Parameters can be used in both NLSQ optimization and Bayesian inference workflows.

Attributes:

  • name (str) – Parameter identifier used for lookup and serialization

  • value (float | None) – Current parameter value (None if unset)

  • bounds (tuple[float, float] | None) – Lower and upper bounds as (min, max)

  • units (str | None) – Physical units string for display (e.g., “Pa”, “s”)

  • description (str | None) – Human-readable description

  • constraints (list[ParameterConstraint]) – List of ParameterConstraint objects used for validation

Example

>>> param = Parameter("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
>>> param.value = 2e5  # Checked against bounds on assignment
>>> param.validate(param.value)
True

__init__(name, value=None, bounds=None, units=None, description=None, constraints=None)[source]
name
units
description
constraints
prior: dict[str, Any] | None
property bounds: tuple[float, float] | None

Get parameter bounds.

property value: float | None

Get parameter value.

property was_clamped: bool

Return True if the last assignment clamped the value.
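The was_clamped property implies that assigning an out-of-bounds value clamps it to the nearest bound instead of raising an error. A minimal sketch of that behavior with a hypothetical ClampedValue class (not RheoJAX's Parameter) is shown below. Whether Parameter also emits a warning when it clamps is not stated here.

```python
class ClampedValue:
    """Hypothetical value holder that clamps assignments to (min, max)."""

    def __init__(self, value, bounds):
        self.bounds = bounds
        self.was_clamped = False
        self.value = value  # routed through the setter below

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value):
        lo, hi = self.bounds
        clamped = min(max(new_value, lo), hi)
        # Record whether this particular assignment had to be clamped
        self.was_clamped = clamped != new_value
        self._value = clamped


p = ClampedValue(1e5, bounds=(1e3, 1e9))
p.value = 5e9  # above the upper bound, so it is clamped to 1e9
```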

validate(value, context=None)[source]

Validate value against all constraints.

Parameters:
  • value (float) – Value to validate

  • context (dict[str, float] | None) – Context with other parameter values

Return type:

bool

Returns:

True if all constraints are satisfied

__hash__()[source]

Make Parameter hashable for use as dict keys.

Return type:

int

Returns:

Hash based on immutable identity attributes only

__eq__(other)[source]

Check equality with another Parameter.

Consistent with __hash__, equality is based on identity attributes (name, bounds, units). The value is excluded because it changes during fitting while the parameter's identity stays the same.

Parameters:

other (object) – Object to compare with

Return type:

bool

Returns:

True if parameters have the same identity
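Excluding the value from __eq__ and __hash__ keeps a parameter usable as a dict key while its value changes during fitting. Below is a minimal sketch using a hypothetical ParamKey class that keys on (name, bounds, units) as described above. It is not RheoJAX's Parameter.

```python
class ParamKey:
    """Hypothetical parameter whose identity ignores its current value."""

    def __init__(self, name, value=None, bounds=None, units=None):
        self.name = name
        self.value = value
        # Store bounds as a tuple so the identity is hashable
        self.bounds = tuple(bounds) if bounds is not None else None
        self.units = units

    def _identity(self):
        return (self.name, self.bounds, self.units)

    def __eq__(self, other):
        if not isinstance(other, ParamKey):
            return NotImplemented
        return self._identity() == other._identity()

    def __hash__(self):
        return hash(self._identity())


p = ParamKey("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
cache = {p: "initial fit"}
p.value = 2e5  # the value changes during fitting; the key still resolves
```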

to_dict()[source]

Convert to dictionary representation.

Return type:

dict[str, Any]

classmethod from_dict(data)[source]

Create from dictionary representation.

Return type:

Parameter

ParameterSet

class rheojax.core.parameters.ParameterSet[source]

Bases: object

Collection of parameters for a model or transform.

A ParameterSet manages multiple Parameter objects with dict-like access, batch operations, and serialization support. It is the primary interface for working with model parameters in RheoJAX.

Key Features:
  • Dict-like access: params["G0"] or params.get("G0")

  • Batch operations: get_values(), set_values(), get_bounds()

  • Unpack helper: G0, eta = params.unpack("G0", "eta")

  • Serialization: to_dict() / from_dict() for JSON/HDF5

Example

>>> params = ParameterSet()
>>> _ = params.add("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
>>> _ = params.add("eta", value=1e3, bounds=(1e-3, 1e9), units="Pa*s")
>>> G0, eta = params.unpack("G0", "eta")
>>> print(f"G0={G0:.2e}, eta={eta:.2e}")
G0=1.00e+05, eta=1.00e+03

See also

Parameter: Individual parameter class.

SharedParameterSet: For multi-model parameter sharing.


__init__()[source]

Initialize empty parameter set.

add(name, value=None, bounds=None, units=None, description=None, constraints=None, overwrite=False)[source]

Add a parameter to the set.

Parameters:
  • name (str) – Parameter name

  • value (float | None) – Initial value

  • bounds (tuple[float, float] | None) – Value bounds (min, max)

  • units (str | None) – Parameter units

  • description (str | None) – Parameter description

  • constraints (list[ParameterConstraint] | None) – List of constraints

  • overwrite (bool) – If True, silently overwrite an existing parameter without emitting a warning. Default is False (warns on overwrite).

Return type:

Parameter

Returns:

The created Parameter object

get(name)[source]

Get a parameter by name.

Parameters:

name (str) – Parameter name

Return type:

Parameter | None

Returns:

Parameter object or None if not found

set_value(name, value)[source]

Set parameter value.

Parameters:
  • name (str) – Parameter name

  • value (float) – New value

Raises:
  • KeyError – If the parameter is not found

  • ValueError – If value violates constraints

set_bounds(name, bounds)[source]

Set bounds for a parameter.

Parameters:
  • name (str) – Parameter name

  • bounds (tuple[float, float]) – Tuple of (min, max) values

Raises:
get_values()[source]

Get all parameter values as array.

Return type:

ndarray

Returns:

Array of parameter values in order

set_values(values)[source]

Set parameter values from array or dictionary.

Parameters:

values (Union[numpy.typing.ArrayLike, dict[str, float]]) – Array of values in order, or dict mapping names to values

Raises:

ValueError – If wrong number of values (array) or unknown parameter (dict)

update(values, *, strict=True)[source]

Apply a batch of name→value updates with optional failure tolerance.

Replaces a pattern common in notebooks that mix parameters from different model schemas: looping over d.items() and calling set_value(k, v) inside a try/except whose handler only calls logger.warning(...). The two failure modes are reported separately, so the caller (or a schema-migration review) can distinguish “this key does not exist on this model” from “this value violates the constraints” without scanning ERROR-level logs.

Parameters:
  • values (dict[str, float]) – Mapping of parameter name → new value.

  • strict (bool) – When True (default), re-raises the first KeyError (unknown name) or ValueError (bad value) so calling code cannot silently drift out of the current schema. When False, collects every failure into the returned dict without logging at ERROR level — useful during migration to draft a single summary warning.

Return type:

dict[str, str]

Returns:

Dict of {name: reason} for entries that failed. Empty when all succeeded (including when values is empty).

Raises:
  • KeyError – (strict=True) if any name is unknown.

  • ValueError – (strict=True) if any value violates constraints.
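The strict and non-strict behavior of update() can be sketched over plain dicts. The hypothetical batch_update function below stands in for ParameterSet, using an illustrative schema of known names with (min, max) bounds. In strict mode it re-raises the first failure; otherwise it collects {name: reason}. It is not RheoJAX's implementation. Because it applies updates in order, strict mode leaves earlier successful updates in place, and whether update() behaves the same way is not stated above.

```python
def batch_update(values, bounds, current, *, strict=True):
    """Apply name -> value updates to `current`, validating against `bounds`.

    Returns {name: reason} for failures (empty if all succeeded).
    """
    failures = {}
    for name, value in values.items():
        try:
            if name not in bounds:
                raise KeyError(f"unknown parameter {name!r}")
            lo, hi = bounds[name]
            if not lo <= value <= hi:
                raise ValueError(f"{name}={value} outside [{lo}, {hi}]")
            current[name] = value
        except (KeyError, ValueError) as exc:
            if strict:
                raise
            # Non-strict: record the failure type and message, keep going
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


bounds = {"G0": (1e3, 1e9), "eta": (1e-3, 1e9)}
current = {"G0": 1e5, "eta": 1e3}
report = batch_update({"G0": 2e5, "tau": 0.1, "eta": -1.0}, bounds, current,
                      strict=False)
```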

get_bounds()[source]

Get bounds for all parameters.

Return type:

list[tuple[float | None, float | None]]

Returns:

List of (min, max) tuples

get_value(name)[source]

Get value of a specific parameter.

Parameters:

name (str) – Parameter name

Return type:

float | None

Returns:

Parameter value or None

unpack(*names)[source]

Extract multiple parameter values in a single call.

This method provides a concise way to extract several parameter values at once, reducing boilerplate in model implementations.

Parameters:

*names (str) – Parameter names to extract

Return type:

tuple[float | None, ...]

Returns:

Tuple of parameter values in the same order as requested. Returns None for parameters with None values.

Raises:

KeyError – If any parameter name is not found. The error message includes the missing name and lists available parameters.

Examples

Basic usage - extract multiple parameters in one line:

>>> params = ParameterSet()
>>> _ = params.add('x', value=1.5)
>>> _ = params.add('G0', value=100.0)
>>> _ = params.add('tau0', value=0.01)
>>> x, G0, tau0 = params.unpack('x', 'G0', 'tau0')
>>> x
1.5
>>> G0
100.0

Before (verbose):

x = params.get_value('x')
G0 = params.get_value('G0')
tau0 = params.get_value('tau0')

After (concise):

x, G0, tau0 = params.unpack('x', 'G0', 'tau0')
__len__()[source]

Number of parameters.

Return type:

int

__contains__(name)[source]

Check if parameter exists.

Return type:

bool

__iter__()[source]

Iterate over parameter names.

keys()[source]

Return an iterator over parameter names (dict-like interface).

Returns:

Iterator over parameter names in order

Examples

>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5)
>>> _ = params.add('beta', value=1.0)
>>> list(params.keys())
['alpha', 'beta']
values()[source]

Return an iterator over Parameter objects (dict-like interface).

Returns:

Iterator over Parameter objects in order

Examples

>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5, units='')
>>> for param in params.values():
...     print(f"{param.name}: {param.value}")
alpha: 0.5
items()[source]

Return an iterator over (name, Parameter) tuples (dict-like interface).

Returns:

Iterator over (name, Parameter) tuples in order

Examples

>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5)
>>> for name, param in params.items():
...     print(f"{name}: {param.value}")
alpha: 0.5
__getitem__(key)[source]

Get parameter by name using subscript notation.

Parameters:

key (str) – Parameter name

Return type:

Parameter

Returns:

Parameter object

Raises:

KeyError – If parameter not found

Examples

>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5)
>>> param = params['alpha']  # Get parameter object
>>> value = params['alpha'].value  # Get value
__setitem__(key, value)[source]

Set parameter value using subscript notation.

Parameters:
  • key (str) – Parameter name

  • value (float | Parameter) – New value (float) or Parameter object

Raises:
  • KeyError – If parameter not found and value is float

  • ValueError – If value violates constraints

Examples

>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5, bounds=(0, 1))
>>> params['alpha'] = 0.7  # Set value
>>> # Or replace entire parameter:
>>> params['alpha'] = Parameter('alpha', value=0.8, bounds=(0, 1))
to_dict()[source]

Convert to dictionary representation.

Return type:

dict[str, dict[str, Any]]

classmethod from_dict(data)[source]

Create from dictionary representation.

Uses Parameter.from_dict() to preserve constraints (not just bounds).

Return type:

ParameterSet

ParameterConstraint

class rheojax.core.parameters.ParameterConstraint(type, min_value=None, max_value=None, value=None, relation=None, other_param=None, validator=None)[source]

Bases: object

Constraint on a parameter value.


Types:

  • "bounds": Min/max value bounds

  • "positive": Must be > 0

  • "integer": Must be an integer

  • "fixed": Fixed to specific value

  • "relative": Relative to another parameter

  • "custom": Custom validator function

type: str
min_value: float | None = None
max_value: float | None = None
value: float | None = None
relation: str | None = None
other_param: str | None = None
validator: Callable[[float], bool] | None = None
to_dict()[source]

Serialize constraint to a dictionary.

Return type:

dict[str, Any]

validate(value, context=None)[source]

Check if value satisfies the constraint.

Parameters:
  • value (float) – Value to validate

  • context (dict[str, float] | None) – Context with other parameter values (for relative constraints)

Return type:

bool

Returns:

True if constraint is satisfied
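The constraint types listed above map naturally onto simple predicates. The sketch below uses a hypothetical stand-in dataclass, Constraint, with the same field names as ParameterConstraint, and implements the bounds, positive, integer, fixed, and custom types. The "relative" type is omitted because the format of its relation field is not documented here. Inclusive bounds are an assumption, and this is not the RheoJAX implementation.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Constraint:
    """Hypothetical stand-in mirroring ParameterConstraint's fields."""
    type: str
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    value: Optional[float] = None
    validator: Optional[Callable[[float], bool]] = None

    def validate(self, x: float) -> bool:
        if self.type == "bounds":
            # Inclusive limits assumed; None leaves that side open
            lo_ok = self.min_value is None or x >= self.min_value
            hi_ok = self.max_value is None or x <= self.max_value
            return lo_ok and hi_ok
        if self.type == "positive":
            return x > 0
        if self.type == "integer":
            return float(x).is_integer()
        if self.type == "fixed":
            return x == self.value
        if self.type == "custom":
            return bool(self.validator(x))
        raise ValueError(f"unsupported constraint type: {self.type!r}")


checks = [
    Constraint("bounds", min_value=0.0, max_value=1.0),
    Constraint("positive"),
    Constraint("custom", validator=lambda x: x != 0.5),
]
```

A parameter-level validate() can then combine such checks with all(c.validate(value) for c in checks).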

__init__(type, min_value=None, max_value=None, value=None, relation=None, other_param=None, validator=None)

SharedParameterSet

class rheojax.core.parameters.SharedParameterSet[source]

Bases: object

Manages parameters shared across multiple models.


__init__()[source]

Initialize shared parameter set.

add_shared(name, value=None, bounds=None, units=None, constraints=None, group=None)[source]

Add a shared parameter.

Parameters:
  • name (str) – Parameter name

  • value (float | None) – Initial value

  • bounds (tuple[float, float] | None) – Value bounds

  • units (str | None) – Parameter units

  • constraints (list[ParameterConstraint] | None) – Parameter constraints

  • group (str | None) – Optional group name

Return type:

Parameter

Returns:

The created Parameter

link_model(model, param_name)[source]

Link a model to a shared parameter.

Parameters:
  • model (Any) – Model to link

  • param_name (str) – Name of shared parameter

link_parameter_set(param_set, param_name)[source]

Link a parameter set to a shared parameter.

Parameters:
  • param_set (ParameterSet) – ParameterSet to link

  • param_name (str) – Name of shared parameter

set_value(name, value)[source]

Set shared parameter value.

Parameters:
  • name (str) – Parameter name

  • value (float) – New value

Raises:

ValueError – If value violates constraints

get_value(name)[source]

Get shared parameter value.

Parameters:

name (str) – Parameter name

Return type:

float | None

Returns:

Parameter value or None

get_linked_models(param_name)[source]

Get models linked to a parameter.

Parameters:

param_name (str) – Parameter name

Return type:

list[Any]

Returns:

List of linked models

create_group(group_name, param_names)[source]

Create a parameter group.

Parameters:
  • group_name (str) – Name for the group

  • param_names (list[str]) – Parameter names to include

get_group(group_name)[source]

Get parameters in a group.

Parameters:

group_name (str) – Group name

Return type:

list[str]

Returns:

List of parameter names in group

__contains__(name)[source]

Check if shared parameter exists.

Return type:

bool
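A shared parameter is a single value that is owned centrally and seen by every linked consumer. The stand-alone sketch below illustrates that idea under explicit assumptions: set_value() is assumed to push the new value into every linked parameter set (modelled here as a plain dict), and constraints are modelled as simple predicates rather than ParameterConstraint objects. The class name SharedValues is hypothetical, and this is not the rheojax implementation.

```python
from typing import Callable, Dict, List, Optional


class SharedValues:
    """Hypothetical stand-in: one value per name, pushed to linked dicts."""

    def __init__(self) -> None:
        self._values: Dict[str, Optional[float]] = {}
        self._checks: Dict[str, List[Callable[[float], bool]]] = {}
        self._links: Dict[str, List[dict]] = {}
        self._groups: Dict[str, List[str]] = {}

    def add_shared(self, name, value=None, constraints=None, group=None):
        self._values[name] = value
        self._checks[name] = list(constraints or [])
        self._links[name] = []
        if group is not None:
            self._groups.setdefault(group, []).append(name)

    def link_parameter_set(self, param_set: dict, name: str) -> None:
        self._links[name].append(param_set)
        if self._values[name] is not None:
            param_set[name] = self._values[name]

    def set_value(self, name: str, value: float) -> None:
        # Reject the update before touching any linked consumer.
        if not all(check(value) for check in self._checks[name]):
            raise ValueError(f"{name}={value} violates a constraint")
        self._values[name] = value
        for param_set in self._links[name]:
            param_set[name] = value

    def get_value(self, name: str) -> Optional[float]:
        return self._values.get(name)

    def get_group(self, group: str) -> List[str]:
        return list(self._groups.get(group, []))

    def __contains__(self, name: str) -> bool:
        return name in self._values
```

Validating before propagating keeps all linked sets consistent: a rejected value never reaches any of them.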

ParameterOptimizer

class rheojax.core.parameters.ParameterOptimizer(parameters, use_jax=False, track_history=False)[source]

Bases: object

Optimizer for parameter fitting with JAX gradient support.

__init__(parameters, use_jax=False, track_history=False)[source]

Initialize parameter optimizer.

Parameters:
  • parameters (ParameterSet) – ParameterSet to optimize

  • use_jax (bool) – Whether to use JAX for optimization

  • track_history (bool) – Whether to track optimization history

history: list[dict[str, Any]]
objective: Callable | None
constraints: list[Callable]
callback: Callable | None
property n_parameters: int

Number of parameters.

get_values()[source]

Get current parameter values.

Return type:

ndarray

get_bounds()[source]

Get parameter bounds.

Return type:

list[tuple[float | None, float | None]]

set_objective(objective)[source]

Set objective function to minimize.

Parameters:

objective (Callable) – Function that takes parameter values and returns scalar

evaluate(values)[source]

Evaluate objective at given values.

Parameters:

values (numpy.typing.ArrayLike) – Parameter values

Return type:

float

Returns:

Objective function value

compute_gradient(values)[source]

Compute gradient of objective.

Parameters:

values (numpy.typing.ArrayLike) – Parameter values

Return type:

ndarray

Returns:

Gradient vector

add_constraint(constraint)[source]

Add optimization constraint.

Parameters:

constraint (Callable) – Function that returns >= 0 for valid values

validate_constraints(values)[source]

Check if constraints are satisfied.

Parameters:

values (numpy.typing.ArrayLike) – Parameter values

Return type:

bool

Returns:

True if all constraints satisfied

set_callback(callback)[source]

Set optimization callback.

Parameters:

callback (Callable) – Function called after each iteration

step(values, iteration=None)[source]

Perform one optimization step.

Parameters:
  • values (numpy.typing.ArrayLike) – Current parameter values

  • iteration (int | None) – Current iteration number

get_history()[source]

Get optimization history.

Return type:

list[dict[str, Any]]

Returns:

List of history dictionaries
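The ParameterOptimizer methods above (set_objective, compute_gradient, add_constraint, validate_constraints, step, get_history) combine into a simple fitting loop. The stand-alone sketch below illustrates that loop with projected gradient descent. It rests on several assumptions and is not rheojax's algorithm. It uses a central finite-difference gradient where rheojax with use_jax=True would presumably rely on JAX autodiff. The learning rate, clipping to bounds, and history keys are also hypothetical choices.

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Bounds = List[Tuple[Optional[float], Optional[float]]]


class OptimizerSketch:
    """Hypothetical projected gradient descent with history tracking."""

    def __init__(self, values: Sequence[float], bounds: Bounds, lr: float = 0.1):
        self.values = list(values)
        self.bounds = bounds
        self.lr = lr
        self.objective: Optional[Callable[[List[float]], float]] = None
        self.constraints: List[Callable[[List[float]], float]] = []
        self.history: List[Dict] = []

    def set_objective(self, objective) -> None:
        self.objective = objective

    def add_constraint(self, constraint) -> None:
        # Same convention as the docs: constraint(values) >= 0 means valid.
        self.constraints.append(constraint)

    def validate_constraints(self, values) -> bool:
        return all(c(values) >= 0 for c in self.constraints)

    def compute_gradient(self, values, h: float = 1e-6) -> List[float]:
        # Central finite differences stand in for JAX autodiff here.
        grad = []
        for i in range(len(values)):
            up, down = list(values), list(values)
            up[i] += h
            down[i] -= h
            grad.append((self.objective(up) - self.objective(down)) / (2 * h))
        return grad

    def _clip(self, values) -> List[float]:
        # Project each value back into its (lower, upper) bounds.
        clipped = []
        for v, (lo, hi) in zip(values, self.bounds):
            if lo is not None:
                v = max(v, lo)
            if hi is not None:
                v = min(v, hi)
            clipped.append(v)
        return clipped

    def step(self, iteration: int) -> None:
        grad = self.compute_gradient(self.values)
        self.values = self._clip([v - self.lr * g for v, g in zip(self.values, grad)])
        self.history.append(
            {"iteration": iteration, "values": list(self.values),
             "objective": self.objective(self.values)}
        )
```

In use, you set a quadratic objective, call step() repeatedly, and then read the history. Any parameter whose unconstrained optimum lies outside its bounds ends up pinned at the nearest bound.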

Test Modes

Test mode detection for rheological data.

This module provides automatic detection of rheological test modes based on data characteristics, units, and metadata.

class rheojax.core.test_modes.TestModeEnum(*values)[source]

Bases: StrEnum

Enumeration of rheological test modes.

Note: Named TestModeEnum (not TestMode) to avoid pytest collection warnings. Pytest treats classes starting with ‘Test’ and ending without ‘Enum’ as test classes.

Note on EPM/Flow protocols:
  • FLOW_CURVE: Steady-state stress vs shear rate (same physics as ROTATION)

  • STARTUP: Transient stress evolution at constant shear rate

  • ROTATION: Generic rotational/steady shear mode (legacy)

RELAXATION = 'relaxation'
CREEP = 'creep'
OSCILLATION = 'oscillation'
LAOS = 'laos'
ROTATION = 'rotation'
FLOW_CURVE = 'flow_curve'
STARTUP = 'startup'
UNKNOWN = 'unknown'
__str__()[source]

Return string representation.

Return type:

str

classmethod from_protocol(protocol)[source]

Convert inventory Protocol to TestModeEnum.

Return type:

TestModeEnum

to_protocol()[source]

Convert TestModeEnum to inventory Protocol (best effort).

Return type:

Protocol | None

rheojax.core.test_modes.RheoTestMode

alias of TestModeEnum

rheojax.core.test_modes.is_monotonic_increasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1)[source]

Check if data is mostly monotonically increasing.

Parameters:
  • data (ndarray | Array) – Array to check

  • strict (bool) – If True, require strictly increasing (no equal values)

  • tolerance (float) – Relative tolerance based on data magnitude

  • allow_fraction (float) – Fraction of points allowed to violate monotonicity (0-1)

Return type:

bool

Returns:

True if data is mostly monotonically increasing

rheojax.core.test_modes.is_monotonic_decreasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1)[source]

Check if data is mostly monotonically decreasing.

Parameters:
  • data (ndarray | Array) – Array to check

  • strict (bool) – If True, require strictly decreasing (no equal values)

  • tolerance (float) – Relative tolerance based on data magnitude

  • allow_fraction (float) – Fraction of points allowed to violate monotonicity (0-1)

Return type:

bool

Returns:

True if data is mostly monotonically decreasing
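The tolerance and allow_fraction parameters make these checks robust to noisy measurements. The sketch below shows one plausible reading of them, and this interpretation is an assumption: rheojax's exact rule may differ. A step counts as a violation if it decreases by more than tolerance × max|data|, or, with strict=True, if it fails to increase by more than that amount. The check passes if the fraction of violating steps is at most allow_fraction. The function names are hypothetical.

```python
from typing import Sequence


def mostly_increasing(data: Sequence[float], strict: bool = False,
                      tolerance: float = 1e-10,
                      allow_fraction: float = 0.1) -> bool:
    """Sketch of a noise-tolerant monotonicity test (hypothetical rule)."""
    if len(data) < 2:
        return True
    # Scale the tolerance by the data magnitude so units do not matter.
    tol = tolerance * max(abs(v) for v in data)
    diffs = [b - a for a, b in zip(data, data[1:])]
    if strict:
        violations = sum(d <= tol for d in diffs)   # equal values violate
    else:
        violations = sum(d < -tol for d in diffs)   # only real drops violate
    return violations / len(diffs) <= allow_fraction


def mostly_decreasing(data: Sequence[float], **kwargs) -> bool:
    # A decreasing series is an increasing series with the sign flipped.
    return mostly_increasing([-v for v in data], **kwargs)
```

With the default allow_fraction=0.1, a single noisy dip in ten steps is still accepted as increasing.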

rheojax.core.test_modes.detect_test_mode(rheo_data)[source]

Detect rheological test mode from data characteristics.

The detection algorithm uses the following heuristics:

  1. Check metadata[‘test_mode’] if explicitly provided

  2. Check domain and units:

    • frequency domain with rad/s or Hz → OSCILLATION

    • time domain with 1/s or s^-1 x-units → ROTATION

  3. Check monotonicity for time-domain data:

    • monotonic decreasing → RELAXATION

    • monotonic increasing → CREEP

  4. Fall back to UNKNOWN if ambiguous

Parameters:

rheo_data (RheoData) – RheoData object to analyze

Return type:

TestModeEnum

Returns:

Detected TestMode

Raises:

ValueError – If rheo_data is invalid
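The four-step heuristic above can be written as a short decision function. The sketch follows the documented order, but it takes plain arguments instead of a RheoData object and returns the lowercase TestModeEnum values as strings. The function name, the exact unit spellings accepted, the 90% direction threshold, and the choice to test monotonicity of y are assumptions made for illustration.

```python
from typing import Optional, Sequence

FREQUENCY_UNITS = {"rad/s", "hz"}   # assumed accepted spellings
RATE_UNITS = {"1/s", "s^-1"}


def _fraction_of_steps(y: Sequence[float], increasing: bool) -> float:
    # Fraction of consecutive steps that move in the requested direction.
    steps = [b - a for a, b in zip(y, y[1:])]
    good = sum((d > 0) if increasing else (d < 0) for d in steps)
    return good / len(steps) if steps else 0.0


def detect_mode_sketch(domain: str, x_units: Optional[str],
                       y: Sequence[float], metadata: Optional[dict] = None,
                       threshold: float = 0.9) -> str:
    metadata = metadata or {}
    # 1. An explicit metadata["test_mode"] always wins.
    if metadata.get("test_mode"):
        return str(metadata["test_mode"]).lower()
    units = (x_units or "").strip().lower()
    # 2. Domain and units.
    if domain == "frequency" and units in FREQUENCY_UNITS:
        return "oscillation"
    if domain == "time" and units in RATE_UNITS:
        return "rotation"
    # 3. Monotonicity of the response for time-domain data.
    if domain == "time":
        if _fraction_of_steps(y, increasing=False) >= threshold:
            return "relaxation"
        if _fraction_of_steps(y, increasing=True) >= threshold:
            return "creep"
    # 4. Ambiguous data.
    return "unknown"
```

Because metadata is checked first, setting metadata['test_mode'] explicitly is the reliable way to bypass the heuristics for unusual data.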

rheojax.core.test_modes.validate_test_mode(test_mode)[source]

Validate and convert test mode to TestMode enum.

Parameters:

test_mode (str | TestModeEnum) – Test mode as string or TestMode enum

Return type:

TestModeEnum

Returns:

TestMode enum

Raises:

ValueError – If test_mode is invalid
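validate_test_mode() accepts either a string or an enum member and normalises both to the enum. The sketch below shows that normalisation with a plain str-backed Enum. The member values are copied from TestModeEnum above. The class and function names, the case-insensitive matching, and the error message are assumptions.

```python
from enum import Enum
from typing import Union


class ModeSketch(str, Enum):
    """Stand-in for TestModeEnum with the same string values."""
    RELAXATION = "relaxation"
    CREEP = "creep"
    OSCILLATION = "oscillation"
    LAOS = "laos"
    ROTATION = "rotation"
    FLOW_CURVE = "flow_curve"
    STARTUP = "startup"
    UNKNOWN = "unknown"


def validate_mode_sketch(test_mode: Union[str, ModeSketch]) -> ModeSketch:
    if isinstance(test_mode, ModeSketch):
        return test_mode
    try:
        # Enum lookup by value after normalising case and whitespace.
        return ModeSketch(str(test_mode).strip().lower())
    except ValueError:
        valid = ", ".join(m.value for m in ModeSketch)
        raise ValueError(
            f"Invalid test mode {test_mode!r}; expected one of: {valid}"
        ) from None
```

Because the enum mixes in str, members compare equal to their plain string values. Code that still passes strings therefore keeps working.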

rheojax.core.test_modes.get_compatible_test_modes(model_name)[source]

Get compatible test modes for a given model.

Queries the ModelRegistry to determine which test modes are supported by the specified model, using the Protocol-Driven Inventory System.

Parameters:

model_name (str) – Name of the rheological model

Return type:

list[TestModeEnum]

Returns:

List of compatible TestMode values

rheojax.core.test_modes.suggest_models_for_test_mode(test_mode)[source]

Suggest appropriate models for a given test mode.

Queries the ModelRegistry to find models compatible with the specified test mode using the Protocol-Driven Inventory System.

Parameters:

test_mode (TestModeEnum) – Detected test mode

Return type:

list[str]

Returns:

List of recommended model names

TestMode

rheojax.core.test_modes.TestMode

Enumeration of rheological test modes.

Values:

  • RELAXATION: Stress relaxation test

  • CREEP: Creep compliance test

  • OSCILLATION: Oscillatory (SAOS/LAOS) test

  • ROTATION: Steady shear (flow curve) test

  • FLOW_CURVE: Steady-state stress vs shear rate

  • STARTUP: Transient stress at constant shear rate

  • LAOS: Large Amplitude Oscillatory Shear

  • UNKNOWN: Unknown or ambiguous test type

alias of TestModeEnum

DeformationMode

class rheojax.core.test_modes.DeformationMode(*values)[source]

Bases: StrEnum

Deformation geometry mode for rheological measurements.

This enum classifies the type of mechanical deformation applied during a rheological measurement. Shear-based instruments (rotational rheometers) measure G*(ω), while tensile, bending, and compression instruments (DMTA/DMA) measure E*(ω).

Deformation geometry for rheological measurements. Controls whether models work with shear modulus G* or Young’s modulus E*.

Values:

  • SHEAR: Rotational rheometer geometry (measures G*)

  • TENSION: DMTA/DMA tensile geometry (measures E*)

  • BENDING: DMTA/DMA bending geometry (measures E*)

  • COMPRESSION: DMTA/DMA compression geometry (measures E*)

Conversion:

\[E^*(\omega) = 2(1 + \nu) \, G^*(\omega)\]

where \(\nu\) is the Poisson’s ratio of the material.

Usage with models:

from rheojax.models import Maxwell

model = Maxwell()
model.fit(
    omega, E_star,
    test_mode='oscillation',
    deformation_mode='tension',
    poisson_ratio=0.5,  # rubber
)

See rheojax.utils.modulus_conversion for array-level conversion utilities.

SHEAR = 'shear'
TENSION = 'tension'
BENDING = 'bending'
COMPRESSION = 'compression'
is_tensile()[source]

True if this deformation mode produces Young’s modulus E*.

Tension, bending, and compression geometries all measure E*, while shear measures G*.

Return type:

bool
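The conversion factor 2(1 + ν) is a real scalar applied to a complex modulus, so the storage and loss parts scale identically. The sketch below illustrates this with Python complex numbers and a stand-in enum that mirrors DeformationMode's members and its is_tensile() rule. The names GeometrySketch, shear_to_youngs and youngs_to_shear are hypothetical. For the library's own array-level helpers, see rheojax.utils.modulus_conversion.

```python
from enum import Enum


class GeometrySketch(str, Enum):
    """Stand-in mirroring DeformationMode's members."""
    SHEAR = "shear"
    TENSION = "tension"
    BENDING = "bending"
    COMPRESSION = "compression"

    def is_tensile(self) -> bool:
        # Tension, bending and compression measure E*; shear measures G*.
        return self is not GeometrySketch.SHEAR


def shear_to_youngs(G_star: complex, poisson_ratio: float) -> complex:
    """E*(ω) = 2(1 + ν) G*(ω): storage and loss parts scale alike."""
    return 2.0 * (1.0 + poisson_ratio) * G_star


def youngs_to_shear(E_star: complex, poisson_ratio: float) -> complex:
    """Inverse conversion, G*(ω) = E*(ω) / (2(1 + ν))."""
    return E_star / (2.0 * (1.0 + poisson_ratio))
```

For an incompressible material (ν = 0.5) the factor is exactly 3, which gives the familiar E = 3G.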

Bayesian Inference

The Bayesian inference module provides NumPyro NUTS sampling capabilities with NLSQ warm-start for all rheological models through the BayesianMixin class.

BayesianMixin

class rheojax.core.bayesian.BayesianMixin[source]

Bases: object

Mixin class providing Bayesian inference capabilities via NumPyro NUTS.

This mixin adds methods for Bayesian parameter estimation to any class that has a parameters attribute (ParameterSet). It implements:

  • NUTS sampling for posterior inference

  • Prior sampling for prior predictive checks

  • Credible interval computation (highest density intervals)

  • Convergence diagnostics (R-hat, ESS)

The mixin is designed to be composed with model classes, typically through BaseModel. All 20+ rheological models automatically inherit these capabilities when BaseModel is extended with BayesianMixin.

Requirements:
  • Class must have parameters attribute (ParameterSet)

  • Class must define model_function(X, params) method for predictions

  • Class must have X_data and y_data attributes when fitting

Example

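>>> import numpy as np
>>> from rheojax.core.bayesian import BayesianMixin
>>> from rheojax.core.parameters import ParameterSet
>>> class MyModel(BayesianMixin):
...     def __init__(self):
...         self.parameters = ParameterSet()
...         self.parameters.add("a", bounds=(0, 10))
...         self.X_data = None
...         self.y_data = None
...
...     def model_function(self, X, params):
...         # Linear model y = a * X
...         return params[0] * X
...
>>> X = np.linspace(0.1, 10.0, 50)
>>> y = 2.5 * X + np.random.default_rng(0).normal(0.0, 0.1, X.size)
>>> model = MyModel()
>>> model.X_data, model.y_data = X, y
>>> result = model.fit_bayesian(X, y)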

Mixin class that adds Bayesian inference capabilities to models. Provides:

  • NLSQ -> NUTS warm-start workflow (2-5x faster convergence)

  • Automatic prior specification from parameter bounds

  • Credible interval calculation

  • Model function for NumPyro NUTS sampling

parameters: ParameterSet
sample_prior(num_samples=1000, seed=None)[source]

Sample from prior distributions over parameter bounds.

Samples from uniform prior distributions defined by parameter bounds. This is useful for prior predictive checks and understanding the prior’s influence on the posterior.

Parameters:
  • num_samples (int) – Number of samples to draw from prior (default: 1000)

  • seed (int | None) – Random seed for reproducibility (default: None)

Return type:

dict[str, ndarray]

Returns:

Dictionary mapping parameter names to arrays of prior samples. Each array has shape (num_samples,) and dtype float64.

Raises:

Example

>>> model = MyModel()
>>> prior_samples = model.sample_prior(num_samples=500, seed=42)
>>> print(prior_samples["a"].shape)  # (500,)
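Sampling a uniform prior over each parameter's bounds is simple. The stand-alone sketch below mirrors what sample_prior() is documented to return, a dict with num_samples draws per parameter, using only the standard library. The bounds dictionary and the function name are illustrative assumptions, and plain lists stand in for NumPy float64 arrays.

```python
import math
import random
from typing import Dict, List, Optional, Tuple


def sample_uniform_prior(bounds: Dict[str, Tuple[float, float]],
                         num_samples: int = 1000,
                         seed: Optional[int] = None) -> Dict[str, List[float]]:
    rng = random.Random(seed)  # seeded generator for reproducibility
    samples = {}
    for name, (low, high) in bounds.items():
        # A uniform prior needs finite bounds with low < high.
        if not (math.isfinite(low) and math.isfinite(high) and low < high):
            raise ValueError(f"{name}: need finite bounds with low < high")
        samples[name] = [rng.uniform(low, high) for _ in range(num_samples)]
    return samples
```

Keep in mind that a uniform prior on a range spanning several decades, such as (1e-3, 1e2), places about 90% of its mass in the top decade [10, 100]. Prior predictive checks make this effect visible.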
get_credible_intervals(posterior_samples, credibility=0.95)[source]

Compute highest density intervals (HDI) for posterior samples.

Computes the highest posterior density interval for each parameter, which is the shortest interval containing the specified probability mass. This is preferred over equal-tailed intervals for skewed distributions.

Parameters:
  • posterior_samples (dict[str, ndarray]) – Dictionary mapping parameter names to posterior sample arrays (from BayesianResult.posterior_samples)

  • credibility (float) – Probability mass to include in the interval (default: 0.95). Common values are 0.68, 0.95 and 0.997, roughly the 1, 2 and 3 sigma coverages of a normal distribution.

Return type:

dict[str, tuple[float, float]]

Returns:

Dictionary mapping parameter names to (lower, upper) credible interval tuples. All values are float64.

Example

>>> result = model.fit_bayesian(X, y)
>>> intervals_95 = model.get_credible_intervals(
...     result.posterior_samples, credibility=0.95
... )
>>> print(f"95% CI for a: {intervals_95['a']}")
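The highest density interval can be computed from sorted samples by sliding a window that holds the required number of samples and keeping the narrowest window. The sketch below is that standard sample-based HDI estimator, which assumes a unimodal posterior. The function name is hypothetical, and rheojax may delegate to a library routine instead.

```python
import math
from typing import Sequence, Tuple


def hdi(samples: Sequence[float], credibility: float = 0.95) -> Tuple[float, float]:
    """Narrowest interval containing `credibility` of the samples."""
    if not 0.0 < credibility < 1.0:
        raise ValueError("credibility must be in (0, 1)")
    ordered = sorted(samples)
    if not ordered:
        raise ValueError("need at least one sample")
    n = len(ordered)
    # Number of samples the interval must contain.
    n_in = max(1, math.ceil(credibility * n))
    best = (ordered[0], ordered[n_in - 1])
    # Slide a window of n_in consecutive sorted samples; keep the narrowest.
    for i in range(1, n - n_in + 1):
        lo, hi = ordered[i], ordered[i + n_in - 1]
        if hi - lo < best[1] - best[0]:
            best = (lo, hi)
    return best
```

For a right-skewed sample, the HDI hugs the dense low end. An equal-tailed interval would instead be stretched toward the long tail.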
precompile_bayesian(X=None, y=None, test_mode=None, num_chains=4)[source]

Precompile NUTS kernel to eliminate JIT overhead in subsequent calls.

Triggers JIT compilation of the NumPyro model by running a minimal sampling pass (1 warmup step, 1 sample). This caches the compiled kernel so that subsequent fit_bayesian() calls are 2-5x faster.

Parameters:
  • X (ndarray | RheoData | None) – Sample input data for determining array shapes. If None, uses a default 10-point linspace [0.01, 100].

  • y (ndarray | None) – Sample output data. If None, generates dummy data.

  • test_mode (str | TestModeEnum | None) – Test mode to precompile for. If None, defaults to ‘relaxation’.

  • num_chains (int) – Number of MCMC chains to compile for (default: 4).

Returns:

Compilation time in seconds.

Return type:

float

Example

>>> from rheojax.models import Maxwell
>>> model = Maxwell()
>>> compile_time = model.precompile_bayesian(test_mode='relaxation')
>>> print(f"Compiled in {compile_time:.1f}s")
>>> # Now fit_bayesian() will be faster
>>> result = model.fit_bayesian(X, y)  # No compilation overhead
fit_bayesian(X, y=None, num_warmup=1000, num_samples=2000, num_chains=4, initial_values=None, test_mode=None, seed=None, **nuts_kwargs)[source]

Perform Bayesian inference using NumPyro NUTS sampler.

Runs NUTS (No-U-Turn Sampler) to obtain posterior samples for model parameters. Supports warm-starting from NLSQ point estimates for faster convergence. Uses uniform priors over parameter bounds.

Multi-chain sampling is enabled by default (num_chains=4) to provide reliable convergence diagnostics (R-hat, ESS) and parallel execution on multi-GPU systems.

CRITICAL: test_mode is captured in the model_function closure so that posteriors are correct for every test mode (relaxation, creep, oscillation).

Parameters:
  • X (ndarray | RheoData) – Independent variable data (input features) or RheoData object

  • y (ndarray | None) – Dependent variable data (observations to fit). If X is RheoData, y is ignored and extracted from X.

  • num_warmup (int) – Number of warmup/burn-in iterations (default: 1000)

  • num_samples (int) – Number of posterior samples per chain (default: 2000)

  • num_chains (int) – Number of MCMC chains (default: 4). Multiple chains enable proper R-hat computation and parallel execution.

  • initial_values (dict[str, float] | None) – Optional dict of initial parameter values for warm-start (e.g., from NLSQ). Keys are parameter names.

  • test_mode (str | TestModeEnum | None) – Explicit test mode (e.g., ‘relaxation’, ‘creep’, ‘oscillation’). If None, inferred from RheoData.metadata[‘test_mode’] or defaults to ‘relaxation’. Overrides RheoData metadata if provided.

  • seed (int | None) – Random seed for reproducibility. If None, uses seed=0 for deterministic results. Set to different values for independent runs.

  • **nuts_kwargs – Additional arguments passed to NUTS sampler (e.g., target_accept_prob, chain_method)

Return type:

BayesianResult

Returns:

BayesianResult containing posterior_samples, summary, diagnostics.

Example

>>> result = model.fit_bayesian(X, y, test_mode='oscillation')
>>> print(result.diagnostics["r_hat"])  # Per-parameter dict; values should be < 1.01
>>>
>>> # For production: use num_chains=4 (default)
>>> result = model.fit_bayesian(X, y, num_chains=4)
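The R-hat threshold in the example above (values below about 1.01) compares between-chain variance with within-chain variance. To make the diagnostic concrete, the sketch below computes the classic (non-split) Gelman-Rubin statistic for a single parameter. NumPyro and ArviZ report refined variants that use split chains and rank normalization, so the values will not match result.diagnostics exactly. The function name is hypothetical.

```python
import math
from statistics import mean, variance
from typing import Sequence


def gelman_rubin(chains: Sequence[Sequence[float]]) -> float:
    """Classic (non-split) potential scale reduction factor R-hat."""
    m = len(chains)
    n = len(chains[0])
    if m < 2 or any(len(c) != n for c in chains):
        raise ValueError("need >= 2 chains of equal length")
    chain_means = [mean(c) for c in chains]
    # W: average within-chain variance; B: between-chain variance of means.
    W = mean(variance(c) for c in chains)
    B = n * variance(chain_means)
    # Pooled estimate of the posterior variance.
    var_hat = (n - 1) / n * W + B / n
    return math.sqrt(var_hat / W)
```

Chains stuck in different regions inflate B and push R-hat far above 1. This is why running several chains (num_chains=4) is what makes the diagnostic meaningful.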

BayesianResult

class rheojax.core.bayesian.BayesianResult(posterior_samples, summary, diagnostics, num_samples, num_chains, mcmc=None, model_comparison=<factory>, _inference_data=None, _inference_data_ll=None)[source]

Bases: object

Results from Bayesian inference with NUTS sampling.

This dataclass stores the complete results of NumPyro NUTS sampling, including posterior samples, summary statistics, convergence diagnostics, and placeholders for future model comparison metrics.

posterior_samples

Dictionary mapping parameter names to arrays of posterior samples (shape: [num_samples * num_chains, ]). All arrays are float64.

summary

Dictionary with summary statistics for each parameter. Contains nested dicts with ‘mean’, ‘std’, and quantiles.

diagnostics

Dictionary with convergence diagnostics, including:

  • r_hat: Gelman-Rubin statistic for each parameter (dict)

  • ess: Effective sample size for each parameter (dict)

  • divergences: Number of divergent transitions (int)

num_samples

Number of posterior samples per chain (after warmup).

num_chains

Number of MCMC chains used in sampling.

mcmc

NumPyro MCMC object containing full sampling information including NUTS-specific diagnostics (energy, divergences, tree depth). Required for ArviZ visualization with full diagnostics.

model_comparison

Dictionary for model comparison metrics (WAIC, LOO). Currently a placeholder for future implementation.

_inference_data

Cached ArviZ InferenceData object. Automatically created on first access via to_inference_data(). Do not set manually.

Example

>>> result = model.fit_bayesian(X, y)
>>> print(result.summary["a"]["mean"])
>>> print(result.diagnostics["r_hat"]["a"])
>>> # Convert to ArviZ InferenceData for advanced plotting
>>> idata = result.to_inference_data()

Dataclass storing complete Bayesian inference results:

Attributes:

  • posterior_samples: Dict mapping parameter names to posterior samples (float64 arrays)

  • summary: Dict with summary statistics (mean, std, quantiles) for each parameter

  • diagnostics: Convergence diagnostics including R-hat, ESS, divergences

  • model_comparison: Dict reserved for model comparison metrics such as WAIC and LOO (currently a placeholder)

  • ArviZ InferenceData: available via to_inference_data() for advanced diagnostics

posterior_samples: dict[str, ndarray]
summary: dict[str, dict[str, float]]
diagnostics: DiagnosticsDict
num_samples: int
num_chains: int
mcmc: MCMC | None = None
model_comparison: dict[str, float]
__post_init__()[source]

Validate result after initialization.

to_inference_data(log_likelihood=False)[source]

Convert to ArviZ InferenceData format for advanced visualization.

Converts the NumPyro MCMC result to ArviZ InferenceData format, which enables access to ArviZ’s comprehensive plotting and diagnostic tools. The conversion preserves all NUTS-specific diagnostics including energy, divergences, and tree depth information.

The InferenceData object is cached after first conversion to avoid repeated conversion overhead. The log_likelihood=False and log_likelihood=True variants are cached independently.

Parameters:

log_likelihood (bool) – If True, compute pointwise log-likelihood for WAIC/LOO model comparison (az.waic(), az.loo()). This re-evaluates the model for all samples (~600-800ms slower). Default False for faster conversion when only plotting.

Returns:

ArviZ InferenceData object with the groups:

  • posterior: Posterior samples for all parameters

  • sample_stats: NUTS diagnostics (energy, divergences, etc.)

  • log_likelihood: Only when log_likelihood=True

  • Additional groups as available from NumPyro

Return type:

Any

Raises:

Example

>>> import arviz as az
>>> result = model.fit_bayesian(X, y)
>>> idata = result.to_inference_data()  # Fast: no log-lik
>>> az.plot_trace(idata)
>>>
>>> # For model comparison (slower):
>>> idata_ll = result.to_inference_data(log_likelihood=True)
>>> az.waic(idata_ll)

Note

Requires the arviz package (pip install arviz). The MCMC object must be present; fit_bayesian() stores it automatically.

__init__(posterior_samples, summary, diagnostics, num_samples, num_chains, mcmc=None, model_comparison=<factory>, _inference_data=None, _inference_data_ll=None)

JAX Configuration

JAX configuration and safe import mechanism for float64 precision.

This module provides utilities to ensure JAX operates in float64 mode by enforcing that NLSQ is imported before JAX and explicitly enabling float64.

NLSQ v0.2.1+ uses float32 by default with automatic precision fallback. RheoJAX explicitly enables float64 for numerical stability in rheological calculations.

Critical Configuration Steps:
  1. Import nlsq (required for GPU-accelerated optimization)

  2. Import JAX

  3. Enable float64: jax.config.update("jax_enable_x64", True)

Usage:

from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

This replaces direct JAX imports throughout the RheoJAX codebase.

rheojax.core.jax_config.suppress_glyph_warnings()[source]

Suppress matplotlib font glyph warnings.

These warnings are purely cosmetic — plots render correctly, the glyph is just displayed as a box or skipped. Common when using Unicode subscripts (e.g., σ₀, τ₀) with fonts that lack those glyphs. This is harmless for headless batch runs and provides no actionable information to users.

It is not applied as a module-level side effect: call it explicitly, or rely on safe_import_jax(), which calls it automatically.

Return type:

None

rheojax.core.jax_config.verify_float64()[source]

Verify that JAX is operating in float64 mode.

This function checks that JAX’s default dtype is float64. It should be called after JAX has been imported to validate the configuration.

Raises:

RuntimeError – If JAX is not in float64 mode.

Return type:

None

Example

>>> import nlsq
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> from rheojax.core.jax_config import verify_float64
>>> verify_float64()  # Validates float64 mode
rheojax.core.jax_config.safe_import_jax()[source]

Safely import JAX with float64 precision enforcement.

This function ensures that NLSQ has been imported before JAX and explicitly enables float64 precision. NLSQ v0.2.1+ uses float32 by default, so RheoJAX must explicitly configure JAX for float64.

It uses a thread-safe singleton pattern to cache validation results and avoid repeated checks.

Returns:

A tuple of (jax, jax.numpy) modules for use.

Return type:

tuple[Any, Any]

Raises:
  • ImportError – If NLSQ has not been imported before calling this function.

  • RuntimeError – If float64 mode cannot be enabled.

Example

>>> # Correct usage (NLSQ imported first at package level)
>>> import nlsq
>>> from rheojax.core.jax_config import safe_import_jax
>>> jax, jnp = safe_import_jax()
>>> arr = jnp.array([1.0, 2.0, 3.0])  # Operates in float64

Note

The rheojax package automatically imports NLSQ and configures JAX in __init__.py, so users don’t need to worry about configuration. This function is for internal use by RheoJAX modules.

rheojax.core.jax_config.lazy_import(module_name)[source]

Return a lazy proxy for module_name.

The real import is deferred until the first attribute access on the returned object. This is safe for modules that are only used inside method bodies (not at module-level scope for decorators, base classes, etc.).

Example:

diffrax = lazy_import("diffrax")
# ... later, inside a method ...
solver = diffrax.Tsit5()   # triggers ``import diffrax`` on first use
Return type:

_LazyModule
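The lazy proxy pattern behind lazy_import() can be sketched with importlib and __getattr__. This is an illustrative re-implementation with assumed details, not rheojax's _LazyModule: the class name LazyModuleSketch and the double-checked lock are hypothetical. It shows why the first attribute access performs the import, rather than the call that creates the proxy.

```python
import importlib
import threading
from types import ModuleType
from typing import Optional


class LazyModuleSketch:
    """Proxy that imports `module_name` on first attribute access."""

    _module: Optional[ModuleType] = None  # class default: not yet imported

    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._lock = threading.Lock()

    def _load(self) -> ModuleType:
        if self._module is None:
            with self._lock:
                if self._module is None:  # double-checked under the lock
                    self._module = importlib.import_module(self._module_name)
        return self._module

    def __getattr__(self, attr: str):
        # Only called for names not found on the proxy, i.e. module members.
        return getattr(self._load(), attr)


def lazy_import_sketch(module_name: str) -> LazyModuleSketch:
    return LazyModuleSketch(module_name)
```

This also explains the restriction in the docstring. A decorator or base class is evaluated at module import time, so it would trigger the import immediately, which defeats the purpose.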

rheojax.core.jax_config.reset_validation()[source]

Reset validation state (for testing purposes only).

This function is intended for use in tests that need to simulate different import scenarios. It should not be used in production code.

Warning

This is not thread-safe and should only be used in single-threaded test environments.

Return type:

None

The JAX configuration module ensures float64 precision throughout the JAX stack by enforcing proper import order (NLSQ must be imported before JAX).

Safe JAX import that verifies NLSQ was imported first for float64 precision.

Usage:

# CORRECT - Always use in RheoJAX modules
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

# INCORRECT - Never import JAX directly in RheoJAX modules
import jax  # Bypasses the NLSQ import-order check and float64 setup

Registry

Plugin registry system for models and transforms.

This module provides a registry system for discovering, registering, and managing models and transforms as plugins, enabling extensibility of the rheojax package.

class rheojax.core.registry.PluginType(*values)[source]

Bases: Enum

Types of plugins that can be registered.

MODEL = 'model'
TRANSFORM = 'transform'
class rheojax.core.registry.PluginInfo(name, plugin_class, plugin_type, metadata, doc=None, protocols=<factory>, deformation_modes=<factory>, transform_type=None)[source]

Bases: object

Information about a registered plugin.

name: str
plugin_class: type
plugin_type: PluginType
metadata: dict[str, Any]
doc: str | None = None
protocols: list[Protocol]
deformation_modes: list[DeformationMode]
transform_type: TransformType | None = None
__post_init__()[source]

Extract documentation from plugin class.

__init__(name, plugin_class, plugin_type, metadata, doc=None, protocols=<factory>, deformation_modes=<factory>, transform_type=None)
class rheojax.core.registry.Registry[source]

Bases: object

Central registry for models and transforms.

This class manages plugin registration, discovery, and retrieval for all models and transforms in the rheojax package.

static __new__(cls)[source]

Ensure singleton pattern (thread-safe).

classmethod get_instance()[source]

Get the singleton registry instance.

Return type:

Registry

Returns:

The global Registry instance

register(name, plugin_class, plugin_type, metadata=None, validate=False, force=False, protocols=None, deformation_modes=None, transform_type=None)[source]

Register a plugin in the registry.

Parameters:
  • name (str) – Unique name for the plugin

  • plugin_class (type) – The plugin class to register

  • plugin_type (PluginType | str) – Type of plugin (MODEL or TRANSFORM)

  • metadata (dict[str, Any] | None) – Optional metadata dictionary

  • validate (bool) – Whether to validate the plugin interface

  • force (bool) – Whether to overwrite existing registration

  • protocols (list[Protocol | str] | None) – List of supported protocols (for models)

  • deformation_modes – Supported deformation modes (for models; see DeformationMode)

  • transform_type (TransformType | str | None) – Type of transform (for transforms)

Raises:

ValueError – If the plugin is already registered and force=False, or if the plugin is invalid

get(name, plugin_type, raise_on_missing=False)[source]

Retrieve a registered plugin class.

Parameters:
  • name (str) – Name of the plugin

  • plugin_type (PluginType | str) – Type of plugin

  • raise_on_missing (bool) – Whether to raise a KeyError if the plugin is not found

Return type:

type | None

Returns:

The plugin class, or None if not found

Raises:

KeyError – If plugin not found and raise_on_missing=True
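The register/get contract above can be sketched as follows: registering a duplicate name raises ValueError unless force=True, and get returns None for an unknown name unless raise_on_missing=True, in which case it raises KeyError. The PluginKind enum and MiniRegistry class are hypothetical. Keying plugins by (type, name) is an assumption, chosen because every lookup takes both a name and a plugin_type.

```python
from enum import Enum


class PluginKind(Enum):
    # Hypothetical mirror of the documented PluginType values.
    MODEL = "model"
    TRANSFORM = "transform"


class MiniRegistry:
    """Hypothetical illustration of register/get semantics."""

    def __init__(self):
        self._plugins: dict[tuple[PluginKind, str], type] = {}

    def register(self, name, plugin_class, plugin_type, force=False):
        # PluginKind(...) accepts either an enum member or its string value.
        key = (PluginKind(plugin_type), name)
        if key in self._plugins and not force:
            raise ValueError(f"'{name}' is already registered as {key[0].value}")
        self._plugins[key] = plugin_class

    def get(self, name, plugin_type, raise_on_missing=False):
        plugin = self._plugins.get((PluginKind(plugin_type), name))
        if plugin is None and raise_on_missing:
            raise KeyError(f"plugin '{name}' not found")
        return plugin
```

Passing force=True to register simply overwrites the existing entry in this sketch.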

get_info(name, plugin_type)[source]

Get full information about a registered plugin.

Parameters:
  • name (str) – Name of the plugin

  • plugin_type (PluginType | str) – Type of plugin

Return type:

PluginInfo | None

Returns:

PluginInfo object or None if not found

get_all_models()[source]

Get list of all registered model names.

Return type:

list[str]

Returns:

List of model names

get_all_transforms()[source]

Get list of all registered transform names.

Return type:

list[str]

Returns:

List of transform names

unregister(name, plugin_type)[source]

Remove a plugin from the registry.

Parameters:
  • name (str) – Name of the plugin to remove

  • plugin_type (PluginType | str) – Type of plugin

get_all()[source]

Get all registered plugins with their types.

Return type:

dict[str, tuple[type, PluginType]]

Returns:

Dictionary mapping plugin names to (class, type) tuples

clear()[source]

Clear all registered plugins.

__len__()[source]

Get total number of registered plugins.

Return type:

int

Returns:

Total count of models and transforms

__contains__(name)[source]

Check if a plugin name is registered.

Parameters:

name (str) – Plugin name to check

Return type:

bool

Returns:

True if registered as either model or transform

get_stats()[source]

Get registration statistics.

Return type:

dict[str, int]

Returns:

Dictionary with counts of registered plugins

discover(module_name)[source]

Discover and register plugins from a module.

Parameters:

module_name (str) – Name of the module to import and scan

discover_directory(path)[source]

Discover plugins in a directory.

Parameters:

path (str) – Path to directory containing plugin modules
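One common way to implement module discovery is to import the module by name with importlib, which also runs any registration decorators it contains, and then scan its namespace with inspect. The sketch assumes plugins can be recognized by subclassing a given base class; the function name discover_subclasses is hypothetical, and rheojax's discover may rely on decorator side effects alone.

```python
import importlib
import inspect


def discover_subclasses(module_name: str, base: type) -> list[str]:
    """Hypothetical helper: names of classes in a module that subclass `base`."""
    # Importing executes module-level code, including any @register decorators.
    module = importlib.import_module(module_name)
    found = []
    for name, obj in inspect.getmembers(module, inspect.isclass):
        # Keep strict subclasses only; skip the base class itself.
        if obj is not base and issubclass(obj, base):
            found.append(name)
    return sorted(found)
```

For example, discover_subclasses("json", ValueError) picks up JSONDecodeError from the standard library's json package.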

create_instance(name, plugin_type, *args, **kwargs)[source]

Create an instance of a registered plugin.

Parameters:
  • name (str) – Name of the plugin

  • plugin_type (PluginType | str) – Type of plugin

  • *args – Positional arguments for plugin constructor

  • **kwargs – Keyword arguments for plugin constructor

Return type:

Any

Returns:

Instance of the plugin class

Raises:

KeyError – If plugin not found

find_compatible(protocol=None, deformation_mode=None, transform_type=None, **criteria)[source]

Find plugins matching all of the given criteria.

Parameters:
  • protocol (Protocol | str | None) – Filter models by supported protocol

  • deformation_mode (DeformationMode | str | None) – Filter models by supported deformation mode

  • transform_type (TransformType | str | None) – Filter transforms by type

  • **criteria – Additional criteria to match against plugin metadata

Return type:

list[str]

Returns:

List of plugin names matching all criteria
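The "match all criteria" semantics can be sketched as a filter over plugin records. A protocol or deformation-mode filter keeps only plugins that list the requested value, and each extra keyword must equal the corresponding metadata entry; filters left as None are ignored. The Info dataclass, the find_matching function, and the "family" metadata key used in the test are hypothetical simplifications, not rheojax's PluginInfo or find_compatible.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Info:
    # Hypothetical, trimmed-down analogue of PluginInfo.
    name: str
    protocols: list[str] = field(default_factory=list)
    deformation_modes: list[str] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)


def find_matching(plugins, protocol=None, deformation_mode=None, **criteria):
    """Return names of plugins that satisfy every supplied filter."""
    matches = []
    for info in plugins:
        if protocol is not None and protocol not in info.protocols:
            continue
        if deformation_mode is not None and deformation_mode not in info.deformation_modes:
            continue
        # Every extra keyword must match the plugin's metadata exactly.
        if any(info.metadata.get(key) != value for key, value in criteria.items()):
            continue
        matches.append(info.name)
    return matches
```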

export_state()[source]

Export registry state for serialization.

Return type:

dict[str, Any]

Returns:

Dictionary representation of registry state

import_state(state)[source]

Import registry state from serialization.

Parameters:

state (dict[str, Any]) – Dictionary representation of registry state

inventory()[source]

Get full inventory of registered plugins.

Return type:

dict[str, Any]

Returns:

Dictionary with models (by protocol) and transforms (by type)

model(name=None, protocols=None, **metadata)[source]

Decorator for registering a model.

Parameters:
  • name (str | None) – Optional name for the model (uses class name if not provided)

  • protocols (list[Protocol | str] | None) – List of supported protocols

  • **metadata – Additional metadata for the model

Returns:

Decorator function

transform(name=None, transform_type=None, **metadata)[source]

Decorator for registering a transform.

Parameters:
  • name (str | None) – Optional name for the transform (uses class name if not provided)

  • transform_type (TransformType | str | None) – Type of transform

  • **metadata – Additional metadata for the transform

Returns:

Decorator function

class rheojax.core.registry.ModelRegistry[source]

Bases: object

Convenient interface for model registration and creation.

This class provides a simplified API specifically for models, delegating to the main Registry singleton.

Example

>>> @ModelRegistry.register('maxwell', protocols=['relaxation'])
... class Maxwell(BaseModel):
...     pass
>>>
>>> model = ModelRegistry.create('maxwell')
>>> models = ModelRegistry.find(protocol='relaxation')
classmethod register(name, protocols=None, deformation_modes=None, **metadata)[source]

Decorator for registering a model.

Parameters:
  • name (str) – Name for the model

  • protocols (list[Protocol | str] | None) – List of supported protocols

  • deformation_modes (list[DeformationMode | str] | None) – List of supported deformation modes (shear, tension, bending, compression). Models that support the oscillation protocol and work in shear-modulus (G) space can support all four modes via automatic E* ↔ G* conversion.

  • **metadata – Additional metadata for the model

Returns:

Decorator function

Example

>>> @ModelRegistry.register('maxwell',
...     protocols=['relaxation', 'oscillation'],
...     deformation_modes=['shear', 'tension', 'bending', 'compression'])
... class Maxwell(BaseModel):
...     pass
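For a linear, isotropic material with a constant, real Poisson ratio ν, the tensile and shear complex moduli are related by E* = 2(1 + ν)G*, which reduces to E* = 3G* for an incompressible material (ν = 0.5). The sketch applies that standard relation; whether rheojax fixes ν = 0.5 or reads it from model parameters or data metadata is not stated here, so the poisson_ratio argument and both function names are assumptions.

```python
import numpy as np


def shear_to_tensile(G_star, poisson_ratio=0.5):
    """E* = 2(1 + nu) G*, assuming a constant, real Poisson ratio."""
    return 2.0 * (1.0 + poisson_ratio) * np.asarray(G_star)


def tensile_to_shear(E_star, poisson_ratio=0.5):
    """Inverse relation: G* = E* / (2(1 + nu))."""
    return np.asarray(E_star) / (2.0 * (1.0 + poisson_ratio))


G_star = np.array([1000 + 500j, 2000 + 800j])
E_star = shear_to_tensile(G_star)   # incompressible default: E* = 3 G*
G_back = tensile_to_shear(E_star)   # round trip recovers G*
```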
classmethod create(name, *args, **kwargs)[source]

Create a model instance by name (factory method).

Parameters:
  • name (str) – Name of the model to create

  • *args – Positional arguments for model constructor

  • **kwargs – Keyword arguments for model constructor

Return type:

Any

Returns:

Instance of the model class

Raises:

KeyError – If model not found

Example

>>> model = ModelRegistry.create('maxwell')
classmethod list_models()[source]

List all registered model names (discovery).

Return type:

list[str]

Returns:

List of registered model names

Example

>>> models = ModelRegistry.list_models()
>>> print(models)
['maxwell', 'zener', 'springpot', ...]
classmethod find(protocol=None, deformation_mode=None, **criteria)[source]

Find models matching criteria.

Parameters:
  • protocol (Protocol | str | None) – Filter by supported protocol

  • deformation_mode (DeformationMode | str | None) – Filter by supported deformation mode

  • **criteria – Additional metadata criteria

Return type:

list[str]

Returns:

List of matching model names

classmethod get_info(name)[source]

Get information about a registered model.

Parameters:

name (str) – Name of the model

Return type:

PluginInfo | None

Returns:

PluginInfo object with model details

Example

>>> info = ModelRegistry.get_info('maxwell')
>>> print(info.protocols)
[<Protocol.RELAXATION: 'relaxation'>, ...]
classmethod for_protocol(protocol)[source]

Get all models supporting a given protocol.

Parameters:

protocol (Protocol | str) – Protocol to filter by (e.g. "relaxation").

Return type:

list[PluginInfo]

Returns:

List of PluginInfo objects for matching models.

classmethod compatible_models(data)[source]

Find models compatible with a RheoData instance.

Maps data.test_mode to a protocol, reads data.metadata.get("deformation_mode"), and returns the models that support both.

Parameters:

data (Any) – RheoData instance; its test_mode and optional deformation_mode metadata determine the filter.

Return type:

list[PluginInfo]

Returns:

List of PluginInfo objects for compatible models.

classmethod model_info(name)[source]

Get aggregated model information including parameter metadata.

Temporarily instantiates the model to read parameter names, bounds, and units. Returns a ModelInfo dataclass.

Parameters:

name (str) – Registry name of the model.

Return type:

Any

Returns:

ModelInfo instance with full model metadata.

Raises:

KeyError – If model is not registered.

classmethod unregister(name)[source]

Unregister a model.

Parameters:

name (str) – Name of the model to remove

class rheojax.core.registry.TransformRegistry[source]

Bases: object

Convenient interface for transform registration and creation.

This class provides a simplified API specifically for transforms, delegating to the main Registry singleton.

Example

>>> @TransformRegistry.register('fft_analysis', type='spectral')
... class RheoAnalysis(BaseTransform):
...     pass
>>>
>>> transform = TransformRegistry.create('fft_analysis')
>>> transforms = TransformRegistry.find(type='spectral')
classmethod register(name, type=None, **metadata)[source]

Decorator for registering a transform.

Parameters:
  • name (str) – Name for the transform

  • type (TransformType | str | None) – Type of transform, as a TransformType member or its string value (e.g. 'spectral')

  • **metadata – Additional metadata for the transform

Returns:

Decorator function

Example

>>> @TransformRegistry.register('fft_analysis', type='spectral')
... class RheoAnalysis(BaseTransform):
...     pass
classmethod create(name, *args, **kwargs)[source]

Create a transform instance by name (factory method).

Parameters:
  • name (str) – Name of the transform to create

  • *args – Positional arguments for transform constructor

  • **kwargs – Keyword arguments for transform constructor

Return type:

Any

Returns:

Instance of the transform class

Raises:

KeyError – If transform not found

Example

>>> transform = TransformRegistry.create('fft_analysis')
classmethod list_transforms()[source]

List all registered transform names (discovery).

Return type:

list[str]

Returns:

List of registered transform names

Example

>>> transforms = TransformRegistry.list_transforms()
>>> print(transforms)
['fft_analysis', 'mastercurve', 'owchirp', ...]
classmethod find(type=None, **criteria)[source]

Find transforms matching criteria.

Parameters:
  • type (TransformType | str | None) – Filter by transform type

  • **criteria – Additional metadata criteria

Return type:

list[str]

Returns:

List of matching transform names

classmethod get_info(name)[source]

Get information about a registered transform.

Parameters:

name (str) – Name of the transform

Return type:

PluginInfo | None

Returns:

PluginInfo object with transform details

Example

>>> info = TransformRegistry.get_info('fft_analysis')
>>> info.transform_type
<TransformType.SPECTRAL: 'spectral'>
classmethod unregister(name)[source]

Unregister a transform.

Parameters:

name (str) – Name of the transform to remove

The registry system provides a centralized way to discover and instantiate models and transforms by name, protocol, deformation mode, or transform type.

Inventory (Protocols & Capabilities)

Core definitions for the Protocol-Driven Inventory System.

This module defines the classifications used to categorize models and transforms in the RheoJAX registry. It provides the type system for:

  1. Models (via Protocol)

  2. Transforms (via TransformType)

class rheojax.core.inventory.Protocol(*values)[source]

Bases: StrEnum

Rheological experimental protocols supported by models.

A protocol defines a specific type of experiment or measurement that a model is capable of simulating or fitting.

FLOW_CURVE = 'flow_curve'
CREEP = 'creep'
RELAXATION = 'relaxation'
STARTUP = 'startup'
OSCILLATION = 'oscillation'
LAOS = 'laos'
class rheojax.core.inventory.TransformType(*values)[source]

Bases: StrEnum

Categories of data transformation operations.

Transforms are classified by their mathematical operation on the data domain.

SPECTRAL = 'spectral'
SUPERPOSITION = 'superposition'
DECOMPOSITION = 'decomposition'
ANALYSIS = 'analysis'
PROCESSING = 'processing'

Examples

Creating RheoData

import numpy as np
from rheojax.core import RheoData

# Simple time-domain data
time = np.array([0.1, 1.0, 10.0])
stress = np.array([1000, 800, 600])
data = RheoData(
    x=time,
    y=stress,
    x_units="s",
    y_units="Pa",
    domain="time"
)

# Complex frequency-domain data
omega = np.logspace(-2, 2, 50)
Gp = 1000 * omega**0.5
Gpp = 500 * omega**0.3
G_star = Gp + 1j * Gpp

freq_data = RheoData(
    x=omega,
    y=G_star,
    x_units="rad/s",
    y_units="Pa",
    domain="frequency"
)
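For complex frequency-domain data like G_star above, the storage modulus G', loss modulus G'', phase angle δ, and loss tangent tan δ = G''/G' can be recovered with plain NumPy. This sketch works on the arrays directly and does not assume any helper methods on RheoData.

```python
import numpy as np

omega = np.logspace(-2, 2, 50)                      # rad/s
G_star = 1000 * omega**0.5 + 1j * 500 * omega**0.3  # Pa

G_prime = G_star.real                  # storage modulus G'
G_double_prime = G_star.imag           # loss modulus G''
G_abs = np.abs(G_star)                 # magnitude |G*|
delta = np.angle(G_star)               # phase angle in radians
tan_delta = G_double_prime / G_prime   # loss tangent
```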

Working with Parameters

from rheojax.core import Parameter, ParameterSet

# Create parameter set
params = ParameterSet()
params.add(
    name="E",
    value=1000.0,
    bounds=(100, 10000),
    units="Pa",
    description="Elastic modulus"
)
params.add(
    name="tau",
    value=1.0,
    bounds=(0.01, 100),
    units="s",
    description="Relaxation time"
)

# Get/set values
E_value = params.get_value("E")
params.set_value("tau", 2.5)

# Array interface
values = params.get_values()  # [1000.0, 2.5]
params.set_values([2000, 1.5])

Test Mode Detection

from rheojax.core.test_modes import detect_test_mode, TestMode

# Automatic detection
mode = detect_test_mode(data)
print(mode)  # TestMode.RELAXATION

# Check test mode
if data.test_mode == TestMode.RELAXATION:
    print("This is a stress relaxation test")

Using Base Classes

from rheojax.core import BaseModel, ParameterSet
import jax.numpy as jnp

class MaxwellModel(BaseModel):
    def __init__(self, E=1000.0, tau=1.0):
        super().__init__()
        self.parameters.add("E", value=E, bounds=(1, 1e6), units="Pa")
        self.parameters.add("tau", value=tau, bounds=(0.01, 1000), units="s")

    def _fit(self, X, y, **kwargs):
        # Placeholder: a real model would estimate E and tau from (X, y) here
        return self

    def _predict(self, X):
        E = self.parameters.get_value("E")
        tau = self.parameters.get_value("tau")
        return E * jnp.exp(-X / tau)

# Use model
model = MaxwellModel()
model.fit(time, stress)
predictions = model.predict(time)
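The _fit method above is a stub. As one possible body, the single-exponential relaxation y = E·exp(−t/τ) can be fitted by linear least squares on log(y), because log y = log E − t/τ. This is a simple swapped-in technique, not rheojax's optimizer: it requires strictly positive data, and the log transform amplifies relative noise at small y, so it suits initial guesses better than final estimates. The function name fit_exponential_loglinear is hypothetical.

```python
import numpy as np


def fit_exponential_loglinear(t, y):
    """Estimate (E, tau) for y = E * exp(-t / tau) via a line fit to log(y)."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(y, dtype=float)
    if np.any(y <= 0):
        raise ValueError("log-linear fit requires strictly positive y")
    # deg=1 returns [slope, intercept] for log(y) = intercept + slope * t
    slope, intercept = np.polyfit(t, np.log(y), 1)
    return np.exp(intercept), -1.0 / slope


t = np.linspace(0.1, 10, 50)
E_hat, tau_hat = fit_exponential_loglinear(t, 1000.0 * np.exp(-t / 2.5))
```

Inside MaxwellModel._fit, the estimates could then be stored with self.parameters.set_value("E", ...) and self.parameters.set_value("tau", ...), the set_value method shown in the parameters example.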

Bayesian Inference

from rheojax.models import Maxwell
import numpy as np

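# Generate synthetic relaxation data (G0 = 1e5 Pa, relaxation time 1 s) with Gaussian noise
rng = np.random.default_rng(42)
t = np.linspace(0.1, 10, 50)
G_data = 1e5 * np.exp(-t / 1.0) + rng.normal(0, 1e3, size=t.shape)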

# 1. NLSQ optimization (fast point estimate)
model = Maxwell()
model.fit(t, G_data)
print(f"NLSQ: G0={model.parameters.get_value('G0'):.3e}")

# 2. Bayesian inference with warm-start
result = model.fit_bayesian(
    t, G_data,
    num_warmup=1000,
    num_samples=2000,
    num_chains=1
)

# 3. Analyze results
print(f"Posterior mean: G0={result.summary['G0']['mean']:.3e} +/- {result.summary['G0']['std']:.3e}")
print(f"Convergence: R-hat={result.diagnostics['r_hat']['G0']:.4f}")

# 4. Get credible intervals
intervals = model.get_credible_intervals(result.posterior_samples, credibility=0.95)
print(f"G0 95% CI: [{intervals['G0'][0]:.3e}, {intervals['G0'][1]:.3e}]")
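A 95% credible interval can also be computed directly from posterior draws. The sketch uses the equal-tailed definition (the 2.5th and 97.5th percentiles). Whether get_credible_intervals uses equal-tailed or highest-density intervals is not stated here, so treat this as an independent cross-check rather than a re-implementation. The synthetic draws and the function name equal_tailed_interval are hypothetical; in practice the draws would come from result.posterior_samples, assumed here to map parameter names to arrays.

```python
import numpy as np


def equal_tailed_interval(samples, credibility=0.95):
    """Return (lower, upper) bounds leaving (1 - credibility) / 2 in each tail."""
    tail = (1.0 - credibility) / 2.0
    lower, upper = np.quantile(np.asarray(samples), [tail, 1.0 - tail])
    return float(lower), float(upper)


rng = np.random.default_rng(0)
# Hypothetical posterior draws for G0 (normal with mean 1e5 Pa, sd 2e3 Pa)
G0_draws = rng.normal(loc=1e5, scale=2e3, size=20_000)
lo, hi = equal_tailed_interval(G0_draws)
```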