Core Module (rheojax.core)
The core module provides fundamental data structures and abstractions for rheological analysis.
Data Container
RheoData
- class rheojax.core.data.RheoData(x=None, y=None, x_units=None, y_units=None, domain='time', initial_test_mode=None, metadata=<factory>, validate=True)[source]
Bases: object
JAX-native container for rheological data with NumPy/JAX array support.
This class provides a unified interface for rheological data that supports both NumPy and JAX arrays with additional features needed for rheological analysis including automatic test mode detection, data validation, and domain-specific operations.
- x
Independent variable data (e.g., time, frequency)
- y
Dependent variable data (e.g., stress, strain, modulus)
- x_units
Units for x-axis data
- y_units
Units for y-axis data
- domain
Data domain (‘time’ or ‘frequency’)
- metadata
Dictionary of additional metadata
- validate
Whether to validate data on creation
- to_jax()[source]
Convert arrays to JAX arrays.
Returns a cached result on subsequent calls; the cache is invalidated if x or y is reassigned.
- Return type:
RheoData
- Returns:
New RheoData with JAX arrays
- to_numpy()[source]
Convert arrays to NumPy arrays.
Uses np.asarray() for zero-copy conversion when possible, providing 10-30% memory savings for large arrays (>100k points).
- Return type:
RheoData
- Returns:
New RheoData with NumPy arrays
- property dtype
Data type of y data.
- property y_real: ndarray | Array
Get real component of y data.
For complex modulus data (G* = G’ + i·G’’), this returns the storage modulus (G’). For real data, returns y unchanged.
- Returns:
Real component of y data (G’ for complex modulus)
Example
>>> data = read_trios('frequency_sweep.txt')  # Returns complex G*
>>> G_prime = data[0].y_real  # Storage modulus (G')
>>> plt.loglog(data[0].x, G_prime, label="G'")
- property y_imag: ndarray | Array
Get imaginary component of y data.
For complex modulus data (G* = G’ + i·G’’), this returns the loss modulus (G’’). For real data, returns zeros.
- Returns:
Imaginary component of y data (G’’ for complex modulus)
Example
>>> data = read_trios('frequency_sweep.txt')  # Returns complex G*
>>> G_double_prime = data[0].y_imag  # Loss modulus (G'')
>>> plt.loglog(data[0].x, G_double_prime, label='G"')
- property storage_modulus: ndarray | None
Get storage modulus (G’) from complex modulus data.
Counterpart of y_real that makes rheological intent explicit; unlike y_real, it returns None (not y) for real-valued data.
- Returns:
Storage modulus (G’) if data is complex, None otherwise
Example
>>> data = read_trios('frequency_sweep.txt')
>>> G_prime = data[0].storage_modulus
- property loss_modulus: ndarray | None
Get loss modulus (G’’) from complex modulus data.
Counterpart of y_imag that makes rheological intent explicit; unlike y_imag, it returns None (not zeros) for real-valued data.
- Returns:
Loss modulus (G’’) if data is complex, None otherwise
Example
>>> data = read_trios('frequency_sweep.txt')
>>> G_double_prime = data[0].loss_modulus
- property tan_delta: ndarray | None
Get loss tangent (tan δ = G’’/G’) from complex modulus data.
The loss tangent quantifies the ratio of viscous to elastic response:
tan δ < 1: Elastic-dominant (solid-like)
tan δ > 1: Viscous-dominant (liquid-like)
tan δ = 1: Equal elastic and viscous contributions
- Returns:
Loss tangent (dimensionless) if data is complex, None otherwise
Example
>>> data = read_trios('frequency_sweep.txt')
>>> tan_d = data[0].tan_delta
>>> print(f"Material type: {'solid-like' if tan_d.mean() < 1 else 'liquid-like'}")
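The relationships behind storage_modulus, loss_modulus, and tan_delta can be sketched with plain Python complex numbers. This is a minimal illustration, not RheoData's implementation; `split_complex_modulus` and `classify` are hypothetical helpers that mirror what those properties are documented to return, and the sweep values are made up.

```python
def split_complex_modulus(g_star):
    """Split complex moduli G* = G' + i*G'' into (G', G'', tan delta).

    Hypothetical helper mirroring the documented storage_modulus,
    loss_modulus and tan_delta properties; not part of rheojax.
    """
    g_prime = [g.real for g in g_star]         # storage modulus G'
    g_double_prime = [g.imag for g in g_star]  # loss modulus G''
    tan_delta = [gpp / gp for gp, gpp in zip(g_prime, g_double_prime)]
    return g_prime, g_double_prime, tan_delta


def classify(tan_delta):
    """Label the response from the mean loss tangent, as in the example above."""
    mean_tan = sum(tan_delta) / len(tan_delta)
    return "solid-like" if mean_tan < 1 else "liquid-like"


# Hypothetical frequency sweep in which G' dominates at every point.
g_star = [1000 + 100j, 2000 + 300j, 4000 + 800j]
gp, gpp, td = split_complex_modulus(g_star)
print(gp)            # [1000.0, 2000.0, 4000.0]
print(classify(td))  # solid-like
```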
- property test_mode: str
Automatically detect or retrieve test mode.
The test mode is detected from data characteristics and cached in a private field; if it has already been detected, the cached value is returned. If the mode is set explicitly in metadata['test_mode'], that value is returned.
- Returns:
Test mode string (e.g., relaxation, creep, oscillation, rotation, unknown)
- property deformation_mode: str
Get deformation mode from metadata.
Returns ‘shear’ if not explicitly set. Possible values: ‘shear’, ‘tension’, ‘bending’, ‘compression’.
- property storage_modulus_label: str
Get appropriate storage modulus label based on deformation mode.
Returns "E'" for tension/bending/compression and "G'" for shear.
- property loss_modulus_label: str
Get appropriate loss modulus label based on deformation mode.
Returns 'E"' for tension/bending/compression and 'G"' for shear.
- interpolate(new_x)[source]
Interpolate data to new x values.
- Parameters:
new_x (numpy.typing.ArrayLike) – New x values for interpolation
- Return type:
RheoData
- Returns:
Interpolated RheoData
- derivative()[source]
Compute numerical derivative.
- Return type:
RheoData
- Returns:
RheoData with derivative values
- integral()[source]
Compute numerical integral.
- Return type:
RheoData
- Returns:
RheoData with integrated values
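The documentation does not say which numerical schemes derivative() and integral() use. The sketch below assumes two common choices, simple finite differences (central in the interior, one-sided at the ends) and the cumulative trapezoidal rule, written in plain Python on sampled (x, y) pairs. `finite_difference` and `cumulative_trapezoid` are illustrative names, not rheojax functions.

```python
def finite_difference(x, y):
    """Derivative dy/dx: central differences inside, one-sided at the ends."""
    n = len(x)
    if n < 2:
        raise ValueError("need at least two points")
    d = [0.0] * n
    d[0] = (y[1] - y[0]) / (x[1] - x[0])
    d[-1] = (y[-1] - y[-2]) / (x[-1] - x[-2])
    for i in range(1, n - 1):
        d[i] = (y[i + 1] - y[i - 1]) / (x[i + 1] - x[i - 1])
    return d


def cumulative_trapezoid(x, y):
    """Running integral of y dx, starting at 0 at the first sample."""
    out = [0.0]
    for i in range(1, len(x)):
        out.append(out[-1] + 0.5 * (y[i] + y[i - 1]) * (x[i] - x[i - 1]))
    return out


# Strain ramp gamma(t) = 2 t sampled on a uniform grid.
t = [0.0, 0.5, 1.0, 1.5, 2.0]
gamma = [2.0 * ti for ti in t]
print(finite_difference(t, gamma))     # constant rate: [2.0, 2.0, 2.0, 2.0, 2.0]
print(cumulative_trapezoid(t, gamma))  # t**2: [0.0, 0.25, 1.0, 2.25, 4.0]
```

Both schemes are exact for this linear ramp, which makes it a convenient sanity check; on noisy measured data, differentiation amplifies noise while integration smooths it.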
- to_frequency_domain()[source]
Convert time domain data to frequency domain.
- Return type:
RheoData
- Returns:
Frequency domain RheoData
- to_time_domain()[source]
Convert frequency domain data to time domain.
- Return type:
RheoData
- Returns:
Time domain RheoData
- __init__(x=None, y=None, x_units=None, y_units=None, domain='time', initial_test_mode=None, metadata=<factory>, validate=True)
Base Classes
BaseModel
- class rheojax.core.base.BaseModel[source]
Bases: BayesianMixin, ABC
Abstract base class for all rheological models.
This class defines the standard interface that all models must implement, supporting JAX arrays, scikit-learn style APIs, and Bayesian inference via NumPyro NUTS.
All models inherit Bayesian capabilities from BayesianMixin, including:
fit_bayesian(): Bayesian parameter estimation using NUTS
sample_prior(): Sample from prior distributions
get_credible_intervals(): Compute highest density intervals
By default, fit() uses NLSQ optimization for fast point estimation; these point estimates can be used to warm-start Bayesian inference.
Abstract base class for all rheological models. Provides a consistent interface with support for scikit-learn style API and JAX arrays.
- __init__()[source]
Initialize base model.
- fit(X, y=None, method='nlsq', check_compatibility=False, use_log_residuals=None, use_multi_start=None, n_starts=5, perturb_factor=0.3, deformation_mode=None, poisson_ratio=0.5, auto_init=False, return_result=False, check_physics=False, uncertainty=None, **kwargs)[source]
Fit the model to data using NLSQ optimization.
This method uses NLSQ (GPU-accelerated nonlinear least squares) by default for fast point estimation. The optimization result is stored for potential warm-starting of Bayesian inference.
For very wide frequency ranges (>10 decades), multi-start optimization is automatically enabled to escape local minima.
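The multi-start idea, which n_starts and perturb_factor control, amounts to generating several perturbed starting points and keeping the best local result. In the sketch below, `perturbed_starts` is a hypothetical helper. Clipping each start to its bounds and reading "± perturb_factor * value" as a relative shift are assumptions for illustration, and rheojax's internal scheme may differ.

```python
import random


def perturbed_starts(p0, bounds, n_starts=5, perturb_factor=0.3, seed=0):
    """Return n_starts initial guesses: p0 itself plus randomly perturbed copies.

    Each parameter is shifted by up to +/- perturb_factor * |value| and then
    clipped into its (lo, hi) bounds. Hypothetical helper, not rheojax API.
    """
    rng = random.Random(seed)
    starts = [dict(p0)]  # always keep the original guess as the first start
    for _ in range(n_starts - 1):
        start = {}
        for name, value in p0.items():
            lo, hi = bounds[name]
            shift = rng.uniform(-perturb_factor, perturb_factor) * abs(value)
            start[name] = min(max(value + shift, lo), hi)
        starts.append(start)
    return starts


p0 = {"G0": 1e5, "eta": 1e3}
bounds = {"G0": (1e3, 1e9), "eta": (1e-3, 1e9)}
for s in perturbed_starts(p0, bounds, n_starts=3):
    print(s)
```

Each start would then seed a separate local optimization, and the lowest-cost result would be kept.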
- Parameters:
X (numpy.typing.ArrayLike) – Input features
y (Optional[numpy.typing.ArrayLike]) – Target values
method (str) – Optimization method (‘nlsq’ by default for compatibility)
check_compatibility (bool) – Whether to check model-data compatibility before fitting. If True, warns when the model may not be appropriate for the data. Default is False for backward compatibility.
use_log_residuals (bool | None) – Whether to use log-space residuals for fitting. Recommended for wide frequency ranges (>8 decades) to prevent optimizer bias. If None (default), automatically detected based on data range. Explicit True/False overrides auto-detection.
use_multi_start (bool | None) – Whether to use multi-start optimization to escape local minima. Recommended for very wide ranges (>10 decades). If None (default), automatically enabled for >10 decades.
n_starts (int) – Number of random starts for multi-start optimization (default: 5)
perturb_factor (float) – Perturbation magnitude for multi-start random starts (default: 0.3). Parameters are perturbed by ± perturb_factor * (value or range). Larger values (0.7-0.9) explore a wider parameter space.
auto_init (bool) – If True, calls auto_p0() to estimate initial parameters from data before running the optimizer (default: False).
return_result (bool) – If True, returns a FitResult instead of self. This intentionally breaks method chaining for workflows that need structured result objects (default: False).
check_physics (bool) – If True, runs post-fit physics validation and emits RheoJaxPhysicsWarning for any violations (default: False).
uncertainty (str | None) – Post-fit uncertainty method: "hessian" for fast Cramér-Rao bounds, "bootstrap" for residual bootstrap CIs, or None to skip (default: None).
**kwargs – Additional fitting options passed to _fit()
- Return type:
BaseModel | Any
- Returns:
self for method chaining (default), or a FitResult if return_result=True.
Example
>>> model = Maxwell()
>>> model.fit(t, G_data)  # Uses NLSQ by default
>>> model.fit(t, G_data, method='nlsq', max_iter=1000)
>>> model.fit(t, G_data, check_compatibility=True)  # Check compatibility
>>> model.fit(omega, G_star, use_log_residuals=True)  # Force log-residuals
>>> model.fit(mastercurve, None, use_multi_start=True, n_starts=10)  # Multi-start
>>> result = model.fit(t, G_data, return_result=True)  # Structured result
>>> result = model.fit(t, G_data, auto_init=True, check_physics=True,
...                    return_result=True)  # Full pipeline
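A toy calculation shows why log-space residuals (use_log_residuals) help when data span many decades. Below, every point has the same 10% relative error. With linear residuals, almost all of the squared-error cost comes from the largest values; in log space, each point contributes equally. This is a plain-Python illustration of the idea with made-up data, not the residual function rheojax actually builds.

```python
import math


def squared_residuals(data, model, log_space=False):
    """Per-point squared residuals, optionally computed on log10 of the values."""
    if log_space:
        return [(math.log10(m) - math.log10(d)) ** 2 for d, m in zip(data, model)]
    return [(m - d) ** 2 for d, m in zip(data, model)]


# Hypothetical modulus data spanning six decades; the "model" is 10% high everywhere.
data = [10.0 ** k for k in range(7)]  # 1 ... 1e6
model = [1.1 * d for d in data]

lin_cost = squared_residuals(data, model)
log_cost = squared_residuals(data, model, log_space=True)
# Share of the total cost contributed by the single largest point.
print(f"linear residuals: {max(lin_cost) / sum(lin_cost):.3f}")
print(f"log residuals:    {max(log_cost) / sum(log_cost):.3f}")
```

With linear residuals, the largest point carries about 99% of the cost, so an optimizer effectively ignores the low-modulus end. In log space, each of the seven points carries about 1/7.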
- precompile(test_mode='relaxation', X=None, y=None)[source]
Precompile NLSQ residual functions to eliminate JIT cold-start.
Triggers JIT compilation by running a minimal fit (max_iter=1) with dummy data. The model parameters are reset afterwards, so the model is left in its original state.
This is useful for interactive sessions or benchmarks where the ~870 ms first-fit JIT overhead should be excluded.
- Return type:
float
- Returns:
Compilation time in seconds.
Example
>>> model = Maxwell()
>>> t = model.precompile(test_mode='relaxation')
>>> print(f"Compiled in {t:.2f}s")
>>> model.fit(X, y)  # No JIT overhead
- fit_bayesian(X, y=None, num_warmup=1000, num_samples=2000, num_chains=4, initial_values=None, test_mode=None, seed=None, deformation_mode=None, poisson_ratio=0.5, **nuts_kwargs)[source]
Perform Bayesian inference using NumPyro NUTS sampler.
This method delegates to BayesianMixin.fit_bayesian() to run NUTS sampling for Bayesian parameter estimation. If initial_values is not provided and the model has been previously fitted with fit(), the NLSQ point estimates are automatically used for warm-starting.
Multi-chain sampling is enabled by default (num_chains=4) to provide reliable convergence diagnostics (R-hat, ESS) and parallel execution on multi-GPU systems.
- Parameters:
X (numpy.typing.ArrayLike) – Independent variable data (input features) or RheoData object
y (Optional[numpy.typing.ArrayLike]) – Dependent variable data (observations to fit). If X is RheoData, y is ignored and extracted from X.
num_warmup (int) – Number of warmup/burn-in iterations (default: 1000)
num_samples (int) – Number of posterior samples per chain (default: 2000)
num_chains (int) – Number of MCMC chains (default: 4). Multiple chains enable proper R-hat computation and parallel execution. The chain method is auto-selected: ‘parallel’ on multi-GPU, ‘vectorized’ on single GPU/CPU.
initial_values (dict[str, float] | None) – Optional dict of initial parameter values for warm-start. If None and the model is fitted, uses NLSQ estimates.
test_mode (str | None) – Explicit test mode (e.g., ‘relaxation’, ‘creep’, ‘oscillation’). If None, inferred from RheoData.metadata[‘test_mode’] or defaults to ‘relaxation’. Overrides RheoData metadata if provided.
seed (int | None) – Random seed for reproducibility. If None, uses seed=0 for deterministic results. Set to different values for independent runs.
**nuts_kwargs – Additional arguments passed to the NUTS sampler (e.g., target_accept_prob, chain_method)
- Return type:
BayesianResult
- Returns:
BayesianResult containing posterior samples, summary statistics, and convergence diagnostics (R-hat, ESS, divergences)
Example
>>> model = Maxwell()
>>> # Warm-start from NLSQ with explicit mode
>>> model.fit(t, G_data, test_mode='relaxation')  # NLSQ optimization
>>> result = model.fit_bayesian(t, G_data, test_mode='relaxation')
>>>
>>> # RheoData with embedded mode (recommended)
>>> rheo_data = RheoData(x=omega, y=G_star, metadata={'test_mode': 'oscillation'})
>>> result = model.fit_bayesian(rheo_data)
>>>
>>> # Or provide explicit initial values
>>> result = model.fit_bayesian(
...     t, G_data,
...     initial_values={'G0': 1e5, 'eta': 1e3},
...     test_mode='creep'
... )
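fit_bayesian() reports R-hat among its convergence diagnostics. The classic Gelman-Rubin statistic compares between-chain and within-chain variance. The sketch below implements that basic (non-split) form in plain Python. NumPyro also provides a split-chain variant, which its summaries typically report, so numbers will not match exactly. `gelman_rubin_rhat` is an illustrative name.

```python
import random
import statistics


def gelman_rubin_rhat(chains):
    """Basic Gelman-Rubin R-hat for equal-length chains of a scalar parameter."""
    n = len(chains[0])
    means = [statistics.fmean(c) for c in chains]
    # W: mean of the within-chain sample variances
    w = statistics.fmean(statistics.variance(c) for c in chains)
    # B: between-chain variance of the chain means, scaled by chain length
    b = n * statistics.variance(means)
    var_hat = (n - 1) / n * w + b / n  # pooled posterior variance estimate
    return (var_hat / w) ** 0.5


rng = random.Random(0)
mixed = [[rng.gauss(0.0, 1.0) for _ in range(2000)] for _ in range(4)]
stuck = [[rng.gauss(mu, 1.0) for _ in range(2000)] for mu in (0.0, 0.0, 3.0, 3.0)]
print(f"well-mixed chains:  R-hat = {gelman_rubin_rhat(mixed):.3f}")
print(f"disagreeing chains: R-hat = {gelman_rubin_rhat(stuck):.3f}")
```

Values close to 1 indicate that the chains agree. Chains stuck around different means push R-hat well above 1, which is the situation that multiple chains (num_chains=4) are meant to expose.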
- predict(X, test_mode=None, deformation_mode=None, poisson_ratio=None, **kwargs)[source]
Make predictions.
- Parameters:
X (numpy.typing.ArrayLike) – Input features
test_mode (str | None) – Optional test mode (‘oscillation’, ‘relaxation’, ‘creep’, ‘flow’). If provided, sets the model’s test_mode before prediction. Useful for data generation without fitting.
deformation_mode (str | DeformationMode | None) – Optional deformation mode for output conversion. If None, uses the mode stored from fit(). If tensile, converts G* predictions to E* space.
poisson_ratio (float | None) – Poisson’s ratio for conversion. If None, uses the value stored from fit() (default 0.5).
**kwargs – Additional arguments passed to the internal _predict method.
- Return type:
numpy.typing.ArrayLike
- Returns:
Model predictions (in E* space if deformation_mode is tensile)
- fit_predict(X, y, **kwargs)[source]
Fit model and return predictions.
- Parameters:
X (numpy.typing.ArrayLike) – Input features
y (numpy.typing.ArrayLike) – Target values
**kwargs – Additional fitting options
- Return type:
numpy.typing.ArrayLike
- Returns:
Model predictions on training data
- get_nlsq_result()[source]
Get stored NLSQ optimization result.
- Returns:
OptimizationResult from NLSQ fit, or None if not fitted
Example
>>> model.fit(t, G_data)
>>> result = model.get_nlsq_result()
>>> if result:
...     print(f"Converged: {result.success}")
- property pcov_
Parameter covariance matrix from NLSQ fit.
- Returns:
ndarray of shape (n_params, n_params), or None if not fitted
- property popt_
Optimal parameter values from NLSQ fit.
- Returns:
ndarray of shape (n_params,), or None if not fitted
- get_parameter_uncertainties()[source]
Get standard errors for fitted parameters from the NLSQ covariance.
- Returns:
dict of {param_name: std_error}, or None if the covariance is unavailable
- Return type:
dict | None
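Standard errors from a covariance matrix are the square roots of its diagonal. The sketch below shows a minimal version of what get_parameter_uncertainties() is documented to return. It assumes pcov_ rows are ordered like the parameter names. `standard_errors` is a hypothetical helper working on a plain nested list rather than rheojax's arrays, and the covariance values are made up.

```python
import math


def standard_errors(names, pcov):
    """Map each parameter name to sqrt(pcov[i][i]); None if pcov is missing."""
    if pcov is None:
        return None
    errors = {}
    for i, name in enumerate(names):
        var = pcov[i][i]
        # A negative or non-finite variance signals an ill-conditioned fit.
        errors[name] = math.sqrt(var) if var >= 0 and math.isfinite(var) else float("nan")
    return errors


pcov = [[4.0e6, 1.0e3],
        [1.0e3, 2.5e1]]
print(standard_errors(["G0", "eta"], pcov))  # {'G0': 2000.0, 'eta': 5.0}
```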
- get_bayesian_result()[source]
Get stored Bayesian inference result.
- Return type:
BayesianResult | None
- Returns:
BayesianResult from fit_bayesian(), or None if not run
Example
>>> model.fit_bayesian(t, G_data)
>>> result = model.get_bayesian_result()
>>> print(result.diagnostics['r_hat'])
- get_params(deep=True)[source]
Get model parameters.
- set_params(**params)[source]
Set model parameters.
- Parameters:
**params – Parameter names and values
- Return type:
BaseModel
- Returns:
self for method chaining
- score(X, y)[source]
Compute model score (R² by default).
- Parameters:
X (numpy.typing.ArrayLike) – Input features
y (numpy.typing.ArrayLike) – True target values
- Return type:
float
- Returns:
Model score (R² coefficient)
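score() is documented as returning the R² coefficient. The usual definition is R² = 1 - SS_res/SS_tot. Below is a plain-Python version of that formula. Whether rheojax applies it exactly this way (for example, to complex oscillation data) is an assumption.

```python
def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R² is undefined for constant y_true")
    return 1.0 - ss_res / ss_tot


y = [1.0, 2.0, 3.0, 4.0]
print(r2_score(y, y))                     # perfect fit: 1.0
print(r2_score(y, [2.5, 2.5, 2.5, 2.5]))  # predicting the mean: 0.0
```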
- to_dict()[source]
Serialize model to dictionary.
- classmethod from_dict(data)[source]
Create model from dictionary.
BaseTransform
- class rheojax.core.base.BaseTransform[source]
Bases: ABC
Abstract base class for all data transforms.
This class defines the standard interface that all transforms must implement, supporting JAX arrays and composable transformations.
Abstract base class for data transforms. Supports fit, transform, and inverse_transform operations with pipeline composition.
- __init__()[source]
Initialize base transform.
- transform(data)[source]
Transform the data.
- Parameters:
data – Input data (RheoData or list[RheoData])
- Returns:
Transformed data (RheoData or tuple[RheoData, dict])
- inverse_transform(data)[source]
Apply inverse transformation.
- Parameters:
data – Transformed data (RheoData)
- Returns:
Original data (RheoData)
- fit(data)[source]
Fit the transform to data (learn parameters if needed).
- Parameters:
data – Training data (RheoData or list[RheoData])
- Return type:
BaseTransform
- Returns:
self for method chaining
- fit_transform(data)[source]
Fit to data and transform it.
- Parameters:
data – Input data (RheoData or list[RheoData])
- Returns:
Transformed data (RheoData or tuple[RheoData, dict])
- __add__(other)[source]
Compose transforms using + operator.
- Parameters:
other (BaseTransform) – Another transform to compose
- Return type:
TransformPipeline
- Returns:
Pipeline of transforms
- batch_transform(datasets)[source]
Transform multiple datasets sequentially.
Applies the transform to each dataset in order. Sequential execution is required because JAX JIT compilation is not thread-safe.
TransformPipeline
- class rheojax.core.base.TransformPipeline(transforms)[source]
Bases: BaseTransform
Pipeline of multiple transforms applied sequentially.
Pipeline for composing multiple transforms that are applied sequentially.
- __init__(transforms)[source]
Initialize transform pipeline.
- Parameters:
transforms (list[BaseTransform]) – List of transforms to apply in order
- fit(data)[source]
Fit all transforms in the pipeline.
- Parameters:
data (numpy.typing.ArrayLike) – Training data
- Return type:
TransformPipeline
- Returns:
self for method chaining
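The + composition and sequential application documented for BaseTransform and TransformPipeline follow a familiar pattern. The sketch below reproduces that pattern with toy classes operating on plain lists. The class names are hypothetical stand-ins. Flattening nested pipelines so that (a + b) + c holds three stages is a design choice of the sketch, not documented rheojax behavior.

```python
class Transform:
    """Toy base class: subclasses implement transform(); + builds a pipeline."""

    def fit(self, data):
        return self  # stateless by default; returns self for chaining

    def transform(self, data):
        raise NotImplementedError

    def fit_transform(self, data):
        return self.fit(data).transform(data)

    def __add__(self, other):
        return Pipeline([self, other])


class Pipeline(Transform):
    def __init__(self, transforms):
        # Flatten nested pipelines so (a + b) + c holds [a, b, c].
        self.transforms = []
        for t in transforms:
            self.transforms.extend(t.transforms if isinstance(t, Pipeline) else [t])

    def fit(self, data):
        # Each stage is fitted on the output of the previous stage.
        for t in self.transforms:
            data = t.fit(data).transform(data)
        return self

    def transform(self, data):
        for t in self.transforms:
            data = t.transform(data)
        return data


class Scale(Transform):
    def __init__(self, factor):
        self.factor = factor

    def transform(self, data):
        return [self.factor * v for v in data]


class Shift(Transform):
    def __init__(self, offset):
        self.offset = offset

    def transform(self, data):
        return [v + self.offset for v in data]


pipe = Scale(2.0) + Shift(1.0) + Scale(10.0)
print(len(pipe.transforms))            # 3
print(pipe.fit_transform([0.0, 1.0]))  # [10.0, 30.0]
```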
Parameters
Parameter
- class rheojax.core.parameters.Parameter(name, value=None, bounds=None, units=None, description=None, constraints=None)[source]
Bases: object
Single parameter with value, bounds, and metadata.
A Parameter represents a model parameter with support for bounds validation, units tracking, and constraint enforcement. Parameters can be used in both NLSQ optimization and Bayesian inference workflows.
- name
Parameter identifier used for lookup and serialization.
- value
Current parameter value (may be None if unset).
- bounds
Lower and upper bounds as tuple (min, max).
- units
Physical units string for display (e.g., “Pa”, “s”).
- description
Human-readable description.
- constraints
List of ParameterConstraint objects for validation.
Example
>>> param = Parameter("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
>>> param.value = 2e5  # Validated against bounds
>>> param.validate(param.value)
True
Single parameter with value, bounds, units, and constraints.
Attributes:
name (str) – Parameter name
value (float | None) – Current value
bounds (tuple[float, float] | None) – (min, max) bounds
units (str | None) – Physical units
description (str | None) – Parameter description
constraints (list[ParameterConstraint]) – List of constraints
- __init__(name, value=None, bounds=None, units=None, description=None, constraints=None)[source]
- name
- units
- description
- constraints
- property was_clamped: bool
Return True if the last assignment clamped the value.
- validate(value, context=None)[source]
Validate value against all constraints.
- __hash__()[source]
Make Parameter hashable for use as dict keys.
- Return type:
int
- Returns:
Hash based on immutable identity attributes only
- __eq__(other)[source]
Check equality with another Parameter.
Matches __hash__: identity-based on (name, bounds, units). Value is excluded because it changes during fitting while the parameter identity remains the same.
- classmethod from_dict(data)[source]
Create from dictionary representation.
- Return type:
Parameter
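Two documented behaviors of Parameter are the was_clamped flag and the identity-based equality and hashing that ignore the current value. Both can be illustrated with a small stand-in class. `MiniParameter` is hypothetical. In particular, clamping out-of-range assignments to the nearest bound is an assumption inferred from the was_clamped description.

```python
class MiniParameter:
    """Toy parameter: clamps assignments into bounds; eq/hash ignore the value."""

    def __init__(self, name, value=None, bounds=None, units=None):
        self.name = name
        self.bounds = bounds
        self.units = units
        self._was_clamped = False
        self._value = None
        if value is not None:
            self.value = value

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new):
        self._was_clamped = False
        if self.bounds is not None:
            lo, hi = self.bounds
            clamped = min(max(new, lo), hi)  # assumption: clamp instead of raising
            self._was_clamped = clamped != new
            new = clamped
        self._value = new

    @property
    def was_clamped(self):
        return self._was_clamped

    def _identity(self):
        # Identity excludes the value, which changes during fitting.
        return (self.name, self.bounds, self.units)

    def __eq__(self, other):
        if not isinstance(other, MiniParameter):
            return NotImplemented
        return self._identity() == other._identity()

    def __hash__(self):
        return hash(self._identity())


p = MiniParameter("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
p.value = 1e12                 # above the upper bound
print(p.value, p.was_clamped)  # 1000000000.0 True
q = MiniParameter("G0", value=2e5, bounds=(1e3, 1e9), units="Pa")
print(p == q, hash(p) == hash(q))  # True True: value is not part of identity
```

Keeping the value out of __eq__ and __hash__ is what lets a parameter serve as a stable dict key while an optimizer updates it.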
ParameterSet
- class rheojax.core.parameters.ParameterSet[source]
Bases: object
Collection of parameters for a model or transform.
A ParameterSet manages multiple Parameter objects with dict-like access, batch operations, and serialization support. It is the primary interface for working with model parameters in RheoJAX.
- Key Features:
Dict-like access: params["G0"] or params.get("G0")
Batch operations: get_values(), set_values(), get_bounds()
Unpack helper: G0, eta = params.unpack("G0", "eta")
Serialization: to_dict() / from_dict() for JSON/HDF5
Example
>>> params = ParameterSet()
>>> _ = params.add("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
>>> _ = params.add("eta", value=1e3, bounds=(1e-3, 1e9), units="Pa*s")
>>> G0, eta = params.unpack("G0", "eta")
>>> print(f"G0={G0:.2e}, eta={eta:.2e}")
G0=1.00e+05, eta=1.00e+03
See also
Parameter: Individual parameter class.
SharedParameterSet: For multi-model parameter sharing.
- __init__()[source]
Initialize empty parameter set.
- add(name, value=None, bounds=None, units=None, description=None, constraints=None, overwrite=False)[source]
Add a parameter to the set.
- Return type:
Parameter
- Returns:
The created Parameter object
- get(name)[source]
Get a parameter by name.
- set_value(name, value)[source]
Set parameter value.
- Raises:
KeyError – If parameter not found
ValueError – If value violates constraints
- set_bounds(name, bounds)[source]
Set bounds for a parameter.
- get_values()[source]
Get all parameter values as array.
- Returns:
Array of parameter values in order
- set_values(values)[source]
Set parameter values from array or dictionary.
- Parameters:
values (Union[numpy.typing.ArrayLike, dict[str, float]]) – Array of values in order, or dict mapping names to values
- Raises:
ValueError – If wrong number of values (array) or unknown parameter (dict)
- update(values, *, strict=True)[source]
Apply a batch of name→value updates with optional failure tolerance.
Replaces the "for k, v in d.items(): try: set_value(k, v) except: logger.warning(...)" pattern found in notebooks that mix parameters from different model schemas. The two failure modes are reported separately, so the caller (or a schema-migration review) can tell "this key does not exist on this model" apart from "this value violates the constraints" without scanning ERROR-level logs.
- Parameters:
values (dict[str, float]) – Mapping of parameter name → new value.
strict (bool) – When True (default), re-raises the first KeyError (unknown name) or ValueError (bad value), so calling code cannot silently drift out of the current schema. When False, collects every failure into the returned dict without logging at ERROR level, which is useful during migration to draft a single summary warning.
- Return type:
dict
- Returns:
Dict of {name: reason} for entries that failed. Empty when all succeeded (including when values is empty).
- Raises:
KeyError – (strict=True) if any name is unknown.
ValueError – (strict=True) if any value violates constraints.
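The strict/non-strict contract of update() can be sketched with a dict-backed stand-in. `update_values` is hypothetical. It assumes each parameter's only constraint is a (lo, hi) bound and records failures as message strings, which may not match rheojax's exact reason format.

```python
def update_values(values, bounds, updates, strict=True):
    """Apply name -> value updates; return {name: reason} for failures.

    values: current {name: float}; bounds: {name: (lo, hi)}.
    With strict=True the first failure is re-raised (updates applied before it
    are kept in this sketch); otherwise failures are collected and returned.
    """
    failures = {}
    for name, new in updates.items():
        try:
            if name not in values:
                raise KeyError(f"unknown parameter {name!r}")
            lo, hi = bounds[name]
            if not lo <= new <= hi:
                raise ValueError(f"{name}={new} outside [{lo}, {hi}]")
            values[name] = new
        except (KeyError, ValueError) as exc:
            if strict:
                raise
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


vals = {"G0": 1e5, "eta": 1e3}
bnds = {"G0": (1e3, 1e9), "eta": (1e-3, 1e9)}
report = update_values(vals, bnds, {"G0": 2e5, "tau": 1.0, "eta": -5.0}, strict=False)
print(vals)            # G0 updated, eta unchanged
print(sorted(report))  # ['eta', 'tau']
```

Keeping unknown names (KeyError) and constraint violations (ValueError) distinct in the report is what allows a migration script to separate schema drift from bad values.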
- get_bounds()[source]
Get bounds for all parameters.
- get_value(name)[source]
Get value of a specific parameter.
- unpack(*names)[source]
Extract multiple parameter values in a single call.
This method provides a concise way to extract several parameter values at once, reducing boilerplate in model implementations.
- Parameters:
*names (str) – Parameter names to extract
- Return type:
tuple
- Returns:
Tuple of parameter values in the same order as requested. Returns None for parameters with None values.
- Raises:
KeyError – If any parameter name is not found. The error message includes the missing name and lists available parameters.
Examples
Basic usage - extract multiple parameters in one line:
>>> params = ParameterSet()
>>> _ = params.add('x', value=1.5)
>>> _ = params.add('G0', value=100.0)
>>> _ = params.add('tau0', value=0.01)
>>> x, G0, tau0 = params.unpack('x', 'G0', 'tau0')
>>> x
1.5
>>> G0
100.0
Before (verbose):
x = params.get_value('x')
G0 = params.get_value('G0')
tau0 = params.get_value('tau0')
After (concise):
x, G0, tau0 = params.unpack('x', 'G0', 'tau0')
- __iter__()[source]
Iterate over parameter names.
- keys()[source]
Return an iterator over parameter names (dict-like interface).
- Returns:
Iterator over parameter names in order
Examples
>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5)
>>> _ = params.add('beta', value=1.0)
>>> list(params.keys())
['alpha', 'beta']
- values()[source]
Return an iterator over Parameter objects (dict-like interface).
- Returns:
Iterator over Parameter objects in order
Examples
>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5, units='')
>>> for param in params.values():
...     print(f"{param.name}: {param.value}")
alpha: 0.5
- items()[source]
Return an iterator over (name, Parameter) tuples (dict-like interface).
- Returns:
Iterator over (name, Parameter) tuples in order
Examples
>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5)
>>> for name, param in params.items():
...     print(f"{name}: {param.value}")
alpha: 0.5
- __getitem__(key)[source]
Get parameter by name using subscript notation.
- Parameters:
key (str) – Parameter name
- Return type:
Parameter
- Returns:
Parameter object
- Raises:
KeyError – If parameter not found
Examples
>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5)
>>> param = params['alpha']  # Get parameter object
>>> value = params['alpha'].value  # Get value
- __setitem__(key, value)[source]
Set parameter value using subscript notation.
- Raises:
KeyError – If parameter not found and value is float
ValueError – If value violates constraints
Examples
>>> params = ParameterSet()
>>> _ = params.add('alpha', value=0.5, bounds=(0, 1))
>>> params['alpha'] = 0.7  # Set value
>>> # Or replace the entire parameter:
>>> params['alpha'] = Parameter('alpha', value=0.8, bounds=(0, 1))
- classmethod from_dict(data)[source]
Create from dictionary representation.
Uses Parameter.from_dict() to preserve constraints (not just bounds).
- Return type:
ParameterSet
ParameterConstraint
- class rheojax.core.parameters.ParameterConstraint(type, min_value=None, max_value=None, value=None, relation=None, other_param=None, validator=None)[source]
Bases: object
Constraint on a parameter value.
Types:
"bounds": Min/max value bounds
"positive": Must be > 0
"integer": Must be an integer
"fixed": Fixed to a specific value
"relative": Relative to another parameter
"custom": Custom validator function
- type: str
- validate(value, context=None)[source]
Check if value satisfies the constraint.
- __init__(type, min_value=None, max_value=None, value=None, relation=None, other_param=None, validator=None)
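The six constraint types listed above map naturally onto simple predicates. The sketch below is a hypothetical dispatcher (`check_constraint`) written only to illustrate those semantics. Several details are assumptions: the relation strings ('lt', 'le', 'gt', 'ge'), the use of a `context` dict holding other parameters' values, and the renaming of the fixed value to `fixed` to avoid clashing with the checked value.

```python
import operator

# Assumed relation names for "relative" constraints.
RELATIONS = {"lt": operator.lt, "le": operator.le, "gt": operator.gt, "ge": operator.ge}


def check_constraint(ctype, value, *, min_value=None, max_value=None, fixed=None,
                     relation=None, other_param=None, validator=None, context=None):
    """Return True if value satisfies a constraint of the given type."""
    if ctype == "bounds":
        return ((min_value is None or value >= min_value)
                and (max_value is None or value <= max_value))
    if ctype == "positive":
        return value > 0
    if ctype == "integer":
        return float(value).is_integer()
    if ctype == "fixed":
        return value == fixed
    if ctype == "relative":
        other = (context or {}).get(other_param)
        if other is None:
            return True  # cannot check without the other parameter's value
        return RELATIONS[relation](value, other)
    if ctype == "custom":
        return bool(validator(value))
    raise ValueError(f"unknown constraint type {ctype!r}")


print(check_constraint("bounds", 5.0, min_value=0.0, max_value=10.0))  # True
print(check_constraint("relative", 0.3, relation="lt", other_param="beta",
                       context={"beta": 0.8}))                          # True
```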
ParameterOptimizer
- class rheojax.core.parameters.ParameterOptimizer(parameters, use_jax=False, track_history=False)[source]
Bases: object
Optimizer for parameter fitting.
Optimizer for parameter fitting with JAX gradient support.
- __init__(parameters, use_jax=False, track_history=False)[source]
Initialize parameter optimizer.
- property n_parameters: int
Number of parameters.
- set_objective(objective)[source]
Set objective function to minimize.
- Parameters:
objective (Callable) – Function that takes parameter values and returns a scalar
- evaluate(values)[source]
Evaluate objective at given values.
- Parameters:
values (numpy.typing.ArrayLike) – Parameter values
- Returns:
Objective function value
- compute_gradient(values)[source]
Compute gradient of objective.
- Parameters:
values (numpy.typing.ArrayLike) – Parameter values
- Returns:
Gradient vector
- add_constraint(constraint)[source]
Add optimization constraint.
- Parameters:
constraint (Callable) – Function that returns >= 0 for valid values
- validate_constraints(values)[source]
Check if constraints are satisfied.
- Parameters:
values (numpy.typing.ArrayLike) – Parameter values
- Return type:
bool
- Returns:
True if all constraints satisfied
- set_callback(callback)[source]
Set optimization callback.
- Parameters:
callback (Callable) – Function called after each iteration
- step(values, iteration=None)[source]
Perform one optimization step.
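The ParameterOptimizer workflow can be sketched without JAX: set an objective, evaluate it, take gradient steps, and check constraints that must return >= 0. `TinyOptimizer` is a hypothetical stand-in. It uses central finite differences where rheojax can use JAX gradients, and plain gradient descent for step(). Neither the step rule nor the learning rate comes from the documentation.

```python
class TinyOptimizer:
    """Toy optimizer: objective + constraints (valid when c(x) >= 0) + gradient steps."""

    def __init__(self, n_parameters, learning_rate=0.1, eps=1e-6):
        self.n_parameters = n_parameters
        self.learning_rate = learning_rate
        self.eps = eps
        self.objective = None
        self.constraints = []

    def set_objective(self, objective):
        self.objective = objective

    def add_constraint(self, constraint):
        self.constraints.append(constraint)

    def evaluate(self, values):
        return float(self.objective(values))

    def compute_gradient(self, values):
        # Central finite differences stand in for JAX autodiff here.
        grad = []
        for i in range(self.n_parameters):
            up = list(values)
            down = list(values)
            up[i] += self.eps
            down[i] -= self.eps
            grad.append((self.evaluate(up) - self.evaluate(down)) / (2 * self.eps))
        return grad

    def validate_constraints(self, values):
        return all(c(values) >= 0 for c in self.constraints)

    def step(self, values):
        grad = self.compute_gradient(values)
        return [v - self.learning_rate * g for v, g in zip(values, grad)]


opt = TinyOptimizer(n_parameters=2)
opt.set_objective(lambda p: (p[0] - 3.0) ** 2 + (p[1] + 1.0) ** 2)
opt.add_constraint(lambda p: p[0])  # require p[0] >= 0
x = [0.0, 0.0]
for _ in range(100):
    x = opt.step(x)
print([round(v, 4) for v in x], opt.validate_constraints(x))  # [3.0, -1.0] True
```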
Test Modes
Test mode detection for rheological data.
This module provides automatic detection of rheological test modes based on data characteristics, units, and metadata.
- class rheojax.core.test_modes.DeformationMode(*values)[source]
Bases: StrEnum
Deformation geometry mode for rheological measurements.
This enum classifies the type of mechanical deformation applied during a rheological measurement. Shear-based instruments (rotational rheometers) measure G*(ω), while tensile/bending/compression instruments (DMTA/DMA) measure E*(ω). The relationship is:
\[E^*(\omega) = 2(1 + \nu) \, G^*(\omega)\]
where \(\nu\) is Poisson’s ratio of the material.
- SHEAR = 'shear'
- TENSION = 'tension'
- BENDING = 'bending'
- COMPRESSION = 'compression'
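The conversion above is a pointwise real scaling of the complex modulus, so G′ and G″ are scaled by the same factor. Below is a minimal sketch with Python complex numbers. `shear_to_young` and `young_to_shear` are hypothetical names. The documented array-level utilities live in rheojax.utils.modulus_conversion, whose signatures are not shown here. The check that -1 < ν ≤ 0.5 (the range for isotropic materials) is a choice of this sketch.

```python
def shear_to_young(g_star, poisson_ratio=0.5):
    """E*(w) = 2 (1 + nu) G*(w), applied point by point."""
    if not -1.0 < poisson_ratio <= 0.5:
        raise ValueError("Poisson's ratio must lie in (-1, 0.5]")
    factor = 2.0 * (1.0 + poisson_ratio)
    return [factor * g for g in g_star]


def young_to_shear(e_star, poisson_ratio=0.5):
    """Inverse conversion: G*(w) = E*(w) / (2 (1 + nu))."""
    factor = 2.0 * (1.0 + poisson_ratio)
    return [e / factor for e in e_star]


g_star = [1.0e5 + 2.0e4j, 3.0e5 + 1.5e5j]
e_star = shear_to_young(g_star, poisson_ratio=0.5)  # incompressible: E* = 3 G*
print(e_star)  # [(300000+60000j), (900000+450000j)]
```

Because the factor is real, tan δ = G″/G′ = E″/E′ is unchanged by the conversion.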
- class rheojax.core.test_modes.TestModeEnum(*values)[source]
Bases: StrEnum
Enumeration of rheological test modes.
Note: Named TestModeEnum (not TestMode) to avoid pytest collection warnings. Pytest treats classes starting with ‘Test’ and ending without ‘Enum’ as test classes.
- Note on EPM/Flow protocols:
FLOW_CURVE: Steady-state stress vs shear rate (same physics as ROTATION)
STARTUP: Transient stress evolution at constant shear rate
ROTATION: Generic rotational/steady shear mode (legacy)
- RELAXATION = 'relaxation'
- CREEP = 'creep'
- OSCILLATION = 'oscillation'
- LAOS = 'laos'
- ROTATION = 'rotation'
- FLOW_CURVE = 'flow_curve'
- STARTUP = 'startup'
- UNKNOWN = 'unknown'
- rheojax.core.test_modes.RheoTestMode
alias of TestModeEnum
- rheojax.core.test_modes.TestMode
alias of TestModeEnum
- rheojax.core.test_modes.is_monotonic_increasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1)[source]
Check if data is mostly monotonically increasing.
- Return type:
bool
- Returns:
True if data is mostly monotonically increasing
- rheojax.core.test_modes.is_monotonic_decreasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1)[source]
Check if data is mostly monotonically decreasing.
- Return type:
bool
- Returns:
True if data is mostly monotonically decreasing
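The "mostly monotonic" checks accept a small fraction of violations, which keeps noisy relaxation or creep curves classifiable. The sketch below is one plausible reading of the signature: `tolerance` absorbs tiny negative steps, and `allow_fraction` caps the share of steps that may go the wrong way. That interpretation, the `mostly_increasing`/`mostly_decreasing` names, and the example data are assumptions, not the documented algorithm.

```python
def mostly_increasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1):
    """True if at most allow_fraction of consecutive steps go the wrong way."""
    steps = [b - a for a, b in zip(data, data[1:])]
    if not steps:
        return False
    if strict:
        bad = sum(1 for d in steps if d <= tolerance)  # must rise by > tolerance
    else:
        bad = sum(1 for d in steps if d < -tolerance)  # small dips are tolerated
    return bad / len(steps) <= allow_fraction


def mostly_decreasing(data, **kwargs):
    """Mirror image: a decreasing series is an increasing series negated."""
    return mostly_increasing([-v for v in data], **kwargs)


creep = [0.0, 1.0, 1.8, 2.5, 2.4, 3.1, 3.6, 4.0, 4.3, 4.5, 4.7]  # one noisy dip
print(mostly_increasing(creep))                      # 1 of 10 steps bad -> True
print(mostly_increasing(creep, allow_fraction=0.0))  # False
```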
- rheojax.core.test_modes.detect_test_mode(rheo_data)[source]
Detect rheological test mode from data characteristics.
The detection algorithm uses the following heuristics:
1. Check metadata['test_mode'] if explicitly provided.
2. Check domain and units:
   frequency domain with rad/s or Hz x-units → OSCILLATION
   time domain with 1/s or s^-1 x-units → ROTATION
3. Check monotonicity of the y data for time-domain data:
   monotonic decreasing → RELAXATION
   monotonic increasing → CREEP
4. Fall back to UNKNOWN if the data are ambiguous.
- Parameters:
rheo_data (RheoData) – RheoData object to analyze
- Return type:
TestModeEnum
- Returns:
Detected TestMode
- Raises:
ValueError – If rheo_data is invalid
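The heuristic cascade above can be sketched as a plain function over (x, y, domain, x_units, metadata). This illustrates the documented order of checks only; it is not the rheojax source. The accepted unit strings and the treatment of ambiguous data are simplifying assumptions. So is the strict every-step monotonicity test, where the real code uses the tolerant "mostly monotonic" checks. `guess_test_mode` is hypothetical.

```python
def guess_test_mode(x, y, domain="time", x_units=None, metadata=None):
    """Follow the documented order: metadata, then domain/units, then monotonicity."""
    metadata = metadata or {}
    # 1. An explicit test mode in metadata wins.
    if "test_mode" in metadata:
        return metadata["test_mode"]
    units = (x_units or "").strip().lower()
    # 2. Domain and x-units.
    if domain == "frequency" and units in {"rad/s", "hz"}:
        return "oscillation"
    if domain == "time" and units in {"1/s", "s^-1"}:
        return "rotation"  # x is a shear rate, not a time
    # 3. Monotonicity of y for time-domain data (simplified: every step same sign).
    if domain == "time" and len(y) > 1:
        steps = [b - a for a, b in zip(y, y[1:])]
        if all(d < 0 for d in steps):
            return "relaxation"
        if all(d > 0 for d in steps):
            return "creep"
    # 4. Nothing matched.
    return "unknown"


t = [0.1, 1.0, 10.0]
print(guess_test_mode(t, [1e5, 5e4, 1e4], x_units="s"))     # relaxation
print(guess_test_mode(t, [1e-6, 3e-6, 9e-6], x_units="s"))  # creep
print(guess_test_mode([1.0, 10.0], [1, 2], domain="frequency", x_units="rad/s"))  # oscillation
```

Setting metadata['test_mode'] explicitly bypasses every data-based guess, which is why the RheoData examples above recommend embedding the mode.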
- rheojax.core.test_modes.validate_test_mode(test_mode)[source]
Validate and convert test mode to TestMode enum.
- Parameters:
test_mode (str | TestModeEnum) – Test mode as string or TestMode enum
- Return type:
TestModeEnum
- Returns:
TestMode enum
- Raises:
ValueError – If test_mode is invalid
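Because TestModeEnum is a string enum, validating user input amounts to a value lookup. The sketch below uses a `str`-mixin Enum (similar in spirit to StrEnum, which requires Python 3.11) with a hypothetical subset of members. The `Mode` and `validate_mode` names and the whitespace/case-insensitive matching are assumptions, not documented behavior.

```python
from enum import Enum


class Mode(str, Enum):
    """Hypothetical stand-in for a few TestModeEnum members."""
    RELAXATION = "relaxation"
    CREEP = "creep"
    OSCILLATION = "oscillation"


def validate_mode(test_mode):
    """Return the enum member for a string or member; raise ValueError otherwise."""
    if isinstance(test_mode, Mode):
        return test_mode
    try:
        return Mode(str(test_mode).strip().lower())
    except ValueError:
        valid = ", ".join(m.value for m in Mode)
        raise ValueError(f"invalid test mode {test_mode!r}; expected one of: {valid}") from None


print(validate_mode("Creep") is Mode.CREEP)  # True
print(Mode.CREEP == "creep")                 # True: members compare equal to their strings
```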
- rheojax.core.test_modes.get_compatible_test_modes(model_name)[source]
Get compatible test modes for a given model.
Queries the ModelRegistry to determine which test modes are supported by the specified model, using the Protocol-Driven Inventory System.
- Parameters:
model_name (str) – Name of the rheological model
- Return type:
list
- Returns:
List of compatible TestMode values
- rheojax.core.test_modes.suggest_models_for_test_mode(test_mode)[source]
Suggest appropriate models for a given test mode.
Queries the ModelRegistry to find models compatible with the specified test mode using the Protocol-Driven Inventory System.
- Parameters:
test_mode (TestModeEnum) – Detected test mode
- Return type:
list[str]
- Returns:
List of recommended model names
TestMode
- rheojax.core.test_modes.TestMode
Enumeration of rheological test modes.
Values:
RELAXATION: Stress relaxation test
CREEP: Creep compliance test
OSCILLATION: Oscillatory (SAOS/LAOS) test
ROTATION: Steady shear (flow curve) test
FLOW_CURVE: Steady-state stress vs shear rate
STARTUP: Transient stress at constant shear rate
LAOS: Large Amplitude Oscillatory Shear
UNKNOWN: Unknown or ambiguous test type
alias of TestModeEnum
DeformationMode
- class rheojax.core.test_modes.DeformationMode(*values)[source]
Deformation geometry for rheological measurements. Controls whether models work with shear modulus G* or Young’s modulus E*.
Values:
SHEAR: Rotational rheometer geometry (measures G*)TENSION: DMTA/DMA tensile geometry (measures E*)BENDING: DMTA/DMA bending geometry (measures E*)COMPRESSION: DMTA/DMA compression geometry (measures E*)
Conversion:
\[E^*(\omega) = 2(1 + \nu) \, G^*(\omega)\]
where \(\nu\) is Poisson's ratio of the material.
Usage with models:
from rheojax.models import Maxwell

model = Maxwell()
model.fit(
    omega,
    E_star,
    test_mode='oscillation',
    deformation_mode='tension',
    poisson_ratio=0.5,  # rubber
)
See rheojax.utils.modulus_conversion for array-level conversion utilities.
- SHEAR = 'shear'
- TENSION = 'tension'
- BENDING = 'bending'
- COMPRESSION = 'compression'
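For a single complex modulus value, the conversion is a scalar multiplication. The sketch below applies it in plain Python, assuming a real, frequency-independent Poisson's ratio. The helper names are hypothetical and are not the rheojax.utils.modulus_conversion API.

```python
def shear_to_tensile(g_star: complex, nu: float) -> complex:
    """E*(ω) = 2(1 + ν) G*(ω), assuming a real, frequency-independent ν."""
    return 2.0 * (1.0 + nu) * g_star


def tensile_to_shear(e_star: complex, nu: float) -> complex:
    """Inverse relation: G*(ω) = E*(ω) / (2(1 + ν))."""
    return e_star / (2.0 * (1.0 + nu))


# Nearly incompressible rubber (ν = 0.5): E* = 3 G*
G_star = 1000.0 + 200.0j  # G' = 1000 Pa, G'' = 200 Pa
E_star = shear_to_tensile(G_star, nu=0.5)
print(E_star)  # (3000+600j)
```

Because a real ν scales G′ and G″ by the same factor, the loss tangent tan δ = G″/G′ is unchanged by the conversion.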
Functions¶
- rheojax.core.test_modes.detect_test_mode(rheo_data)[source]
Detect rheological test mode from data characteristics.
The detection algorithm uses the following heuristics:
1. Check metadata['test_mode'] if explicitly provided.
2. Check domain and units:
   - frequency domain with rad/s or Hz → OSCILLATION
   - time domain with 1/s or s^-1 x-units → ROTATION
3. Check monotonicity for time-domain data:
   - monotonic decreasing → RELAXATION
   - monotonic increasing → CREEP
4. Fall back to UNKNOWN if ambiguous.
- Parameters:
rheo_data (RheoData) – RheoData object to analyze
- Return type:
- Returns:
Detected TestMode
- Raises:
ValueError – If rheo_data is invalid
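The heuristic cascade can be sketched as a short function. This is an illustrative restatement, not RheoJAX's implementation. It works on plain lists and strings instead of RheoData and TestMode. It also requires every step to move in the same direction, whereas the library uses the tolerance-based is_monotonic_* helpers.

```python
def detect_mode(y, domain, x_units=None, metadata=None):
    """Illustrative version of the documented detection order (hypothetical helper)."""
    metadata = metadata or {}
    # 1. An explicit metadata entry always wins.
    if "test_mode" in metadata:
        return metadata["test_mode"]
    # 2. Domain and x-units.
    if domain == "frequency" and x_units in ("rad/s", "Hz"):
        return "oscillation"
    if domain == "time" and x_units in ("1/s", "s^-1"):
        return "rotation"
    # 3. Monotonicity of time-domain data (strict, every step).
    if domain == "time" and len(y) > 1:
        steps = [b - a for a, b in zip(y, y[1:])]
        if all(s < 0 for s in steps):
            return "relaxation"
        if all(s > 0 for s in steps):
            return "creep"
    # 4. Nothing matched: ambiguous.
    return "unknown"


print(detect_mode([1000, 800, 600], "time", "s"))  # relaxation
```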
- rheojax.core.test_modes.validate_test_mode(test_mode)[source]
Validate and convert test mode to TestMode enum.
- Parameters:
test_mode (str | TestModeEnum) – Test mode as string or TestMode enum
- Return type:
- Returns:
TestMode enum
- Raises:
ValueError – If test_mode is invalid
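A typical string-or-enum normalization looks like the sketch below. The member names follow the TestMode listing above. The lowercase string values and the case-insensitive matching are assumptions, and RheoTestMode is a stand-in name rather than RheoJAX's TestModeEnum.

```python
from enum import Enum


class RheoTestMode(str, Enum):
    """Stand-in for TestModeEnum; the string values are assumed."""
    RELAXATION = "relaxation"
    CREEP = "creep"
    OSCILLATION = "oscillation"
    ROTATION = "rotation"
    FLOW_CURVE = "flow_curve"
    STARTUP = "startup"
    LAOS = "laos"
    UNKNOWN = "unknown"


def to_test_mode(test_mode):
    """Accept an enum member or a (case-insensitive) string; reject anything else."""
    if isinstance(test_mode, RheoTestMode):
        return test_mode
    try:
        return RheoTestMode(str(test_mode).strip().lower())
    except ValueError:
        valid = ", ".join(m.value for m in RheoTestMode)
        raise ValueError(
            f"Invalid test_mode {test_mode!r}; expected one of: {valid}"
        ) from None
```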
- rheojax.core.test_modes.is_monotonic_increasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1)[source]
Check if data is mostly monotonically increasing.
- Parameters:
- Return type:
- Returns:
True if data is mostly monotonically increasing
- rheojax.core.test_modes.is_monotonic_decreasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1)[source]
Check if data is mostly monotonically decreasing.
- Parameters:
- Return type:
- Returns:
True if data is mostly monotonically decreasing
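The docstrings do not spell out how strict, tolerance, and allow_fraction interact. One plausible reading, sketched below purely as an assumption, is to count the consecutive steps that go the wrong way by more than tolerance and accept the series if that count is at most allow_fraction of all steps. With strict=True, a step must rise by more than tolerance, so flat steps also count as violations. The function names are hypothetical.

```python
def mostly_increasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1):
    """True if at most `allow_fraction` of consecutive steps break the trend."""
    if len(data) < 2:
        return True
    steps = [b - a for a, b in zip(data, data[1:])]
    if strict:
        # Each step must rise by more than `tolerance`; flat steps violate.
        violations = sum(1 for s in steps if s <= tolerance)
    else:
        # Only drops larger than `tolerance` count as violations.
        violations = sum(1 for s in steps if s < -tolerance)
    return violations / len(steps) <= allow_fraction


def mostly_decreasing(data, strict=False, tolerance=1e-10, allow_fraction=0.1):
    """Decreasing check by negating the data and reusing the increasing test."""
    return mostly_increasing([-v for v in data], strict, tolerance, allow_fraction)
```

Under this reading, a noisy creep curve with one small dip in ten steps still counts as increasing at the default allow_fraction=0.1.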
Bayesian Inference¶
The Bayesian inference module provides NumPyro NUTS sampling capabilities with NLSQ warm-start for all rheological models through the BayesianMixin class.
BayesianMixin¶
- class rheojax.core.bayesian.BayesianMixin[source]
Bases: object
Mixin class providing Bayesian inference capabilities via NumPyro NUTS.
This mixin adds methods for Bayesian parameter estimation to any class that has a parameters attribute (ParameterSet). It implements:
- NUTS sampling for posterior inference
- Prior sampling for prior predictive checks
- Credible interval computation (highest density intervals)
- Convergence diagnostics (R-hat, ESS)
The mixin is designed to be composed with model classes, typically through BaseModel. All 20+ rheological models automatically inherit these capabilities when BaseModel is extended with BayesianMixin.
- Requirements:
Class must have parameters attribute (ParameterSet)
Class must define model_function(X, params) method for predictions
Class must have X_data and y_data attributes when fitting
Example
>>> class MyModel(BayesianMixin):
...     def __init__(self):
...         self.parameters = ParameterSet()
...         self.parameters.add("a", bounds=(0, 10))
...         self.X_data = None
...         self.y_data = None
...
...     def model_function(self, X, params):
...         return params[0] * X
...
>>> model = MyModel()
>>> model.X_data = X
>>> model.y_data = y
>>> result = model.fit_bayesian(X, y)
Mixin class that adds Bayesian inference capabilities to models. Provides:
NLSQ -> NUTS warm-start workflow (2-5x faster convergence)
Automatic prior specification from parameter bounds
Credible interval calculation
Model function for NumPyro NUTS sampling
- parameters: ParameterSet
- sample_prior(num_samples=1000, seed=None)[source]
Sample from prior distributions over parameter bounds.
Samples from uniform prior distributions defined by parameter bounds. This is useful for prior predictive checks and understanding the prior’s influence on the posterior.
- Parameters:
- Return type:
- Returns:
Dictionary mapping parameter names to arrays of prior samples. Each array has shape (num_samples,) and dtype float64.
- Raises:
AttributeError – If class doesn’t have parameters attribute
ValueError – If any parameter lacks bounds
Example
>>> model = MyModel()
>>> prior_samples = model.sample_prior(num_samples=500, seed=42)
>>> print(prior_samples["a"].shape)  # (500,)
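Because the priors are uniform over each parameter's bounds, prior sampling reduces to independent uniform draws. The sketch below shows this with the standard library. It takes a plain {name: (lower, upper)} dict instead of a ParameterSet and returns lists rather than float64 arrays. The function name is hypothetical.

```python
import random


def sample_uniform_prior(bounds, num_samples=1000, seed=None):
    """Draw num_samples independent Uniform(lower, upper) values per parameter."""
    rng = random.Random(seed)  # seeded generator -> reproducible draws
    samples = {}
    for name, (lower, upper) in bounds.items():
        if lower is None or upper is None:
            # A uniform prior needs finite bounds, mirroring the documented ValueError.
            raise ValueError(f"Parameter {name!r} lacks bounds")
        samples[name] = [rng.uniform(lower, upper) for _ in range(num_samples)]
    return samples


prior = sample_uniform_prior({"a": (0.0, 10.0)}, num_samples=500, seed=42)
print(len(prior["a"]))  # 500
```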
- get_credible_intervals(posterior_samples, credibility=0.95)[source]
Compute highest density intervals (HDI) for posterior samples.
Computes the highest posterior density interval for each parameter, which is the shortest interval containing the specified probability mass. This is preferred over equal-tailed intervals for skewed distributions.
- Parameters:
- Return type:
- Returns:
Dictionary mapping parameter names to (lower, upper) credible interval tuples. All values are float64.
Example
>>> result = model.fit_bayesian(X, y)
>>> intervals_95 = model.get_credible_intervals(
...     result.posterior_samples, credibility=0.95
... )
>>> print(f"95% CI for a: {intervals_95['a']}")
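For a unimodal posterior, the HDI can be estimated from samples in three steps. Sort the samples, slide a window that holds the required fraction of points, and keep the narrowest window. The sketch below implements that standard sample-based estimator in plain Python. Whether RheoJAX computes the HDI this way or delegates to a library is not stated here, and the function name is hypothetical.

```python
import math


def sample_hdi(samples, credibility=0.95):
    """Narrowest interval containing `credibility` of the sorted samples."""
    xs = sorted(samples)
    n = len(xs)
    width = max(1, math.ceil(credibility * n))  # points inside the interval
    if width >= n:
        return xs[0], xs[-1]
    # Start index of the window with the smallest span.
    best = min(range(n - width + 1), key=lambda i: xs[i + width - 1] - xs[i])
    return xs[best], xs[best + width - 1]


print(sample_hdi([100.0, 0.0, 2.0, 1.0], credibility=0.75))  # (0.0, 2.0)
```

On this skewed sample, the 75% HDI excludes the outlier at 100, which an equal-tailed interval of the same coverage could not do as tightly.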
- precompile_bayesian(X=None, y=None, test_mode=None, num_chains=4)[source]
Precompile NUTS kernel to eliminate JIT overhead in subsequent calls.
Triggers JIT compilation of the NumPyro model by running a minimal sampling run (1 warmup step, 1 sample). This caches the compiled kernel so that subsequent fit_bayesian() calls are 2-5x faster.
- Parameters:
X (ndarray | RheoData | None) – Sample input data for determining array shapes. If None, uses a default 10-point linspace [0.01, 100].
y (ndarray | None) – Sample output data. If None, generates dummy data.
test_mode (str | TestModeEnum | None) – Test mode to precompile for. If None, defaults to 'relaxation'.
num_chains (int) – Number of MCMC chains (default: 4).
- Returns:
Compilation time in seconds.
- Return type:
Example
>>> from rheojax.models import Maxwell
>>> model = Maxwell()
>>> compile_time = model.precompile_bayesian(test_mode='relaxation')
>>> print(f"Compiled in {compile_time:.1f}s")
>>> # Now fit_bayesian() will be faster
>>> result = model.fit_bayesian(X, y)  # No compilation overhead
- fit_bayesian(X, y=None, num_warmup=1000, num_samples=2000, num_chains=4, initial_values=None, test_mode=None, seed=None, **nuts_kwargs)[source]
Perform Bayesian inference using NumPyro NUTS sampler.
Runs NUTS (No-U-Turn Sampler) to obtain posterior samples for model parameters. Supports warm-starting from NLSQ point estimates for faster convergence. Uses uniform priors over parameter bounds.
Multi-chain sampling is enabled by default (num_chains=4) to provide reliable convergence diagnostics (R-hat, ESS) and parallel execution on multi-GPU systems.
Important: test_mode is captured in the model_function closure so that posteriors are correct for every test mode (relaxation, creep, oscillation).
- Parameters:
X (ndarray | RheoData) – Independent variable data (input features) or RheoData object
y (ndarray | None) – Dependent variable data (observations to fit). If X is RheoData, y is ignored and extracted from X.
num_warmup (int) – Number of warmup/burn-in iterations (default: 1000)
num_samples (int) – Number of posterior samples per chain (default: 2000)
num_chains (int) – Number of MCMC chains (default: 4). Multiple chains enable proper R-hat computation and parallel execution.
initial_values (dict[str, float] | None) – Optional dict of initial parameter values for warm-start (e.g., from NLSQ). Keys are parameter names.
test_mode (str | TestModeEnum | None) – Explicit test mode (e.g., 'relaxation', 'creep', 'oscillation'). If None, inferred from RheoData.metadata['test_mode'] or defaults to 'relaxation'. Overrides RheoData metadata if provided.
seed (int | None) – Random seed for reproducibility. If None, uses seed=0 for deterministic results. Set to different values for independent runs.
**nuts_kwargs – Additional arguments passed to the NUTS sampler (e.g., target_accept_prob, chain_method)
- Return type:
BayesianResult
- Returns:
BayesianResult containing posterior_samples, summary, diagnostics.
Example
>>> result = model.fit_bayesian(X, y, test_mode='oscillation')
>>> print(result.diagnostics["r_hat"])  # Each value should be < 1.01
>>>
>>> # For production: use num_chains=4 (default)
>>> result = model.fit_bayesian(X, y, num_chains=4)
BayesianResult¶
- class rheojax.core.bayesian.BayesianResult(posterior_samples, summary, diagnostics, num_samples, num_chains, mcmc=None, model_comparison=<factory>, _inference_data=None, _inference_data_ll=None)[source]
Bases: object
Results from Bayesian inference with NUTS sampling.
This dataclass stores the complete results of NumPyro NUTS sampling, including posterior samples, summary statistics, convergence diagnostics, and placeholders for future model comparison metrics.
- posterior_samples
Dictionary mapping parameter names to arrays of posterior samples (shape: [num_samples * num_chains, ]). All arrays are float64.
- summary
Dictionary with summary statistics for each parameter. Contains nested dicts with ‘mean’, ‘std’, and quantiles.
- diagnostics
Dictionary with convergence diagnostics including:
- r_hat: Gelman-Rubin statistic for each parameter (dict)
- ess: Effective sample size for each parameter (dict)
- divergences: Number of divergent transitions (int)
- num_samples
Number of posterior samples per chain (after warmup).
- num_chains
Number of MCMC chains used in sampling.
- mcmc
NumPyro MCMC object containing full sampling information including NUTS-specific diagnostics (energy, divergences, tree depth). Required for ArviZ visualization with full diagnostics.
- model_comparison
Dictionary for model comparison metrics (WAIC, LOO). Currently a placeholder for future implementation.
- _inference_data
Cached ArviZ InferenceData object. Automatically created on first access via to_inference_data(). Do not set manually.
Example
>>> result = model.fit_bayesian(X, y)
>>> print(result.summary["a"]["mean"])
>>> print(result.diagnostics["r_hat"]["a"])
>>> # Convert to ArviZ InferenceData for advanced plotting
>>> idata = result.to_inference_data()
Dataclass storing complete Bayesian inference results.
Attributes:
posterior_samples: Dict mapping parameter names to posterior samples (float64 arrays)
summary: Dict with summary statistics (mean, std, quantiles) for each parameter
diagnostics: Convergence diagnostics including R-hat, ESS, divergences
model_comparison: WAIC / LOO model comparison metrics (if computed; currently a placeholder)
mcmc: NumPyro MCMC object; an ArviZ InferenceData view for advanced diagnostics is available via to_inference_data()
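R-hat compares between-chain and within-chain variance, and values near 1 indicate that the chains agree. The sketch below implements the classic (non-split) Gelman-Rubin formula with the standard library for illustration. NumPyro and ArviZ report split and, in ArviZ's case, rank-normalized variants, so their numbers will differ somewhat from this version. The function name is hypothetical.

```python
import math
from statistics import mean, variance


def gelman_rubin(chains):
    """Classic potential scale reduction factor for one parameter.

    chains: list of m >= 2 equal-length sequences of draws (n >= 2 each).
    """
    n = len(chains[0])
    W = mean(variance(c) for c in chains)           # mean within-chain variance
    B_over_n = variance([mean(c) for c in chains])  # between-chain variance / n
    var_hat = (n - 1) / n * W + B_over_n            # pooled variance estimate
    return math.sqrt(var_hat / W)


stuck = [[0.0, 1.0, 0.0, 1.0], [10.0, 11.0, 10.0, 11.0]]
print(gelman_rubin(stuck) > 1.01)  # True: the two chains sit in different places
```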
- diagnostics: DiagnosticsDict
- num_samples: int
- num_chains: int
- __post_init__()[source]
Validate result after initialization.
- to_inference_data(log_likelihood=False)[source]
Convert to ArviZ InferenceData format for advanced visualization.
Converts the NumPyro MCMC result to ArviZ InferenceData format, which enables access to ArviZ’s comprehensive plotting and diagnostic tools. The conversion preserves all NUTS-specific diagnostics including energy, divergences, and tree depth information.
The InferenceData object is cached after first conversion to avoid repeated conversion overhead. The log_likelihood=False and log_likelihood=True variants are cached independently.
- Parameters:
log_likelihood (bool) – If True, compute pointwise log-likelihood for WAIC/LOO model comparison (az.waic(), az.loo()). This re-evaluates the model for all samples (~600-800 ms slower). Default False for faster conversion when only plotting.
- Returns:
ArviZ InferenceData with the following groups:
posterior: Posterior samples for all parameters
sample_stats: NUTS diagnostics (energy, divergences, etc.)
log_likelihood: Only when log_likelihood=True
Additional groups as available from NumPyro
- Return type:
- Raises:
ImportError – If arviz is not installed
ValueError – If MCMC object was not stored (older results)
Example
>>> import arviz as az
>>> result = model.fit_bayesian(X, y)
>>> idata = result.to_inference_data()  # Fast: no log-likelihood
>>> az.plot_trace(idata)
>>>
>>> # For model comparison (slower):
>>> idata_ll = result.to_inference_data(log_likelihood=True)
>>> az.waic(idata_ll)
Note
Requires the arviz package (pip install arviz). The MCMC object must be present (fit_bayesian stores it automatically).
- __init__(posterior_samples, summary, diagnostics, num_samples, num_chains, mcmc=None, model_comparison=<factory>, _inference_data=None, _inference_data_ll=None)
JAX Configuration¶
JAX configuration and safe import mechanism for float64 precision.
This module provides utilities to ensure JAX operates in float64 mode by enforcing that NLSQ is imported before JAX and explicitly enabling float64.
NLSQ v0.2.1+ uses float32 by default with automatic precision fallback. RheoJAX explicitly enables float64 for numerical stability in rheological calculations.
- Critical Configuration Steps:
1. Import nlsq (required for GPU-accelerated optimization)
2. Import JAX
3. Enable float64: jax.config.update("jax_enable_x64", True)
- Usage:
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
This replaces direct JAX imports throughout the RheoJAX codebase.
- rheojax.core.jax_config.suppress_glyph_warnings()[source]¶
Suppress matplotlib font glyph warnings.
These warnings are purely cosmetic — plots render correctly, the glyph is just displayed as a box or skipped. Common when using Unicode subscripts (e.g., σ₀, τ₀) with fonts that lack those glyphs. This is harmless for headless batch runs and provides no actionable information to users.
Call explicitly rather than relying on module-level side effects.
safe_import_jax() calls this automatically.
- Return type:
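Silencing one family of warnings is usually done with a message-based filter from the standard warnings module. The sketch below is an assumption about how such a function could work, not RheoJAX's code. It assumes the Matplotlib warning is a UserWarning whose text starts with "Glyph" and contains "missing from". The function name is hypothetical.

```python
import warnings


def silence_missing_glyph_warnings():
    """Ignore UserWarnings whose message looks like 'Glyph ... missing from ...'."""
    # `message` is a regular expression matched against the start of the text.
    warnings.filterwarnings(
        "ignore", message=r"Glyph .* missing from", category=UserWarning
    )
```

The filter is process-wide. To limit it to one plotting call, wrap that call in warnings.catch_warnings(), which restores the previous filters on exit.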
- rheojax.core.jax_config.verify_float64()[source]¶
Verify that JAX is operating in float64 mode.
This function checks that JAX’s default dtype is float64. It should be called after JAX has been imported to validate the configuration.
- Raises:
RuntimeError – If JAX is not in float64 mode.
- Return type:
Example
>>> import nlsq
>>> import jax
>>> jax.config.update("jax_enable_x64", True)
>>> from rheojax.core.jax_config import verify_float64
>>> verify_float64()  # Validates float64 mode
- rheojax.core.jax_config.safe_import_jax()[source]¶
Safely import JAX with float64 precision enforcement.
This function ensures that NLSQ has been imported before JAX and explicitly enables float64 precision. NLSQ v0.2.1+ uses float32 by default, so RheoJAX must explicitly configure JAX for float64.
It uses a thread-safe singleton pattern to cache validation results and avoid repeated checks.
- Returns:
A tuple of (jax, jax.numpy) modules for use.
- Return type:
- Raises:
ImportError – If NLSQ has not been imported before calling this function.
RuntimeError – If float64 mode cannot be enabled.
Example
>>> # Correct usage (NLSQ imported first at package level)
>>> import nlsq
>>> from rheojax.core.jax_config import safe_import_jax
>>> jax, jnp = safe_import_jax()
>>> arr = jnp.array([1.0, 2.0, 3.0])  # Operates in float64
Note
The rheojax package automatically imports NLSQ and configures JAX in __init__.py, so users don’t need to worry about configuration. This function is for internal use by RheoJAX modules.
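The "validate once, then serve from cache" behavior can be sketched as a small factory using double-checked locking. Everything here is hypothetical. The real function checks nlsq, imports jax and jax.numpy, and enables float64 with jax.config.update("jax_enable_x64", True). So that this sketch runs without JAX, standard-library modules stand in for those packages.

```python
import importlib
import sys
import threading


def make_checked_importer(required_first, module_name, submodule_name):
    """Return a function that imports (module, submodule) once, after checking
    that `required_first` is already in sys.modules (hypothetical helper)."""
    lock = threading.Lock()
    cache = {}

    def checked_import():
        if "result" in cache:  # fast path: already validated, no locking
            return cache["result"]
        with lock:
            if "result" not in cache:  # re-check: another thread may have won
                if required_first not in sys.modules:
                    raise ImportError(
                        f"{required_first} must be imported before {module_name}"
                    )
                module = importlib.import_module(module_name)
                submodule = importlib.import_module(submodule_name)
                # The real function would also enable float64 at this point.
                cache["result"] = (module, submodule)
        return cache["result"]

    return checked_import


# Stand-ins: "os" plays nlsq, "json" / "json.decoder" play jax / jax.numpy.
load = make_checked_importer("os", "json", "json.decoder")
mod, sub = load()
print(mod.__name__, sub.__name__)  # json json.decoder
```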
- rheojax.core.jax_config.lazy_import(module_name)[source]¶
Return a lazy proxy for module_name.
The real import is deferred until the first attribute access on the returned object. This is safe for modules that are only used inside method bodies (not at module-level scope for decorators, base classes, etc.).
Example:
diffrax = lazy_import("diffrax")
# ... later, inside a method ...
solver = diffrax.Tsit5()  # triggers ``import diffrax`` on first use
- Return type:
_LazyModule
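The proxy pattern behind lazy_import can be sketched with a small class whose __getattr__ performs the import on first use. The class and function names are hypothetical. RheoJAX's _LazyModule may differ, for example in thread safety or in how it handles dir() and repr().

```python
import importlib


class LazyModuleProxy:
    """Defers `import module_name` until the first attribute access."""

    def __init__(self, module_name):
        self._module_name = module_name
        self._module = None

    def __getattr__(self, attr):
        # Only called for attributes not found normally, so the two
        # instance attributes set in __init__ never reach this method.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, attr)


def lazy_module(module_name):
    return LazyModuleProxy(module_name)


json_lazy = lazy_module("json")   # nothing imported yet
print(json_lazy.dumps({"a": 1}))  # {"a": 1}  (the import happens here)
```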
- rheojax.core.jax_config.reset_validation()[source]¶
Reset validation state (for testing purposes only).
This function is intended for use in tests that need to simulate different import scenarios. It should not be used in production code.
Warning
This is not thread-safe and should only be used in single-threaded test environments.
- Return type:
The JAX configuration module ensures float64 precision throughout the JAX stack by enforcing proper import order (NLSQ must be imported before JAX).
- rheojax.core.jax_config.safe_import_jax()[source]
Safe JAX import that verifies NLSQ was imported first for float64 precision.
Usage:
# CORRECT - Always use in RheoJAX modules
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

# INCORRECT - Never import JAX directly in RheoJAX modules
import jax  # Bypasses the NLSQ import-order check and float64 setup
- rheojax.core.jax_config.verify_float64()[source]
Verify JAX is operating in float64 mode. Raises exception if not.
Registry¶
Plugin registry system for models and transforms.
This module provides a registry system for discovering, registering, and managing models and transforms as plugins, enabling extensibility of the rheojax package.
- class rheojax.core.registry.PluginType(*values)[source]¶
Bases: Enum
Types of plugins that can be registered.
- MODEL = 'model'¶
- TRANSFORM = 'transform'¶
- class rheojax.core.registry.PluginInfo(name, plugin_class, plugin_type, metadata, doc=None, protocols=<factory>, deformation_modes=<factory>, transform_type=None)[source]¶
Bases: object
Information about a registered plugin.
- plugin_type: PluginType¶
- deformation_modes: list[DeformationMode]¶
- transform_type: TransformType | None = None¶
- __init__(name, plugin_class, plugin_type, metadata, doc=None, protocols=<factory>, deformation_modes=<factory>, transform_type=None)¶
- class rheojax.core.registry.Registry[source]¶
Bases: object
Central registry for models and transforms.
This class manages plugin registration, discovery, and retrieval for all models and transforms in the rheojax package.
- classmethod get_instance()[source]¶
Get the singleton registry instance.
- Return type:
- Returns:
The global Registry instance
- register(name, plugin_class, plugin_type, metadata=None, validate=False, force=False, protocols=None, deformation_modes=None, transform_type=None)[source]¶
Register a plugin in the registry.
- Parameters:
name (str) – Unique name for the plugin
plugin_class (type) – The plugin class to register
plugin_type (PluginType | str) – Type of plugin (MODEL or TRANSFORM)
metadata (dict[str, Any] | None) – Optional metadata dictionary
validate (bool) – Whether to validate the plugin interface
force (bool) – Whether to overwrite an existing registration
protocols (list[Protocol | str] | None) – List of supported protocols (for models)
deformation_modes (list[DeformationMode | str] | None) – List of supported deformation modes (for models)
transform_type (TransformType | str | None) – Type of transform (for transforms)
- Raises:
ValueError – If plugin is already registered (and force=False) or invalid
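The duplicate check and the force flag can be illustrated with a toy singleton. MiniRegistry is hypothetical. Keying entries by (plugin_type, name), so that a model and a transform can share a name, is an assumption suggested by the name/plugin_type pairs in the method signatures.

```python
class MiniRegistry:
    """Toy singleton registry keyed by (plugin_type, name); hypothetical."""

    _instance = None

    def __init__(self):
        self._plugins = {}

    @classmethod
    def get_instance(cls):
        # Not thread-safe; a production version would guard this with a lock.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def register(self, name, plugin_class, plugin_type, force=False):
        key = (plugin_type, name)
        if key in self._plugins and not force:
            raise ValueError(f"{plugin_type} {name!r} is already registered")
        self._plugins[key] = plugin_class

    def get(self, name, plugin_type):
        return self._plugins.get((plugin_type, name))
```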
- get_info(name, plugin_type)[source]¶
Get full information about a registered plugin.
- Parameters:
name (str) – Name of the plugin
plugin_type (PluginType | str) – Type of plugin
- Return type:
- Returns:
PluginInfo object or None if not found
- unregister(name, plugin_type)[source]¶
Remove a plugin from the registry.
- Parameters:
name (str) – Name of the plugin to remove
plugin_type (PluginType | str) – Type of plugin
- get_all()[source]¶
Get all registered plugins with their types.
- Return type:
dict[str, tuple[type, PluginType]]
- Returns:
Dictionary mapping plugin names to (class, type) tuples
- __len__()[source]¶
Get total number of registered plugins.
- Return type:
- Returns:
Total count of models and transforms
- discover(module_name)[source]¶
Discover and register plugins from a module.
- Parameters:
module_name (str) – Name of the module to import and scan
- discover_directory(path)[source]¶
Discover plugins in a directory.
- Parameters:
path (str) – Path to directory containing plugin modules
- create_instance(name, plugin_type, *args, **kwargs)[source]¶
Create an instance of a registered plugin.
- Parameters:
name (str) – Name of the plugin
plugin_type (PluginType | str) – Type of plugin
*args – Positional arguments for plugin constructor
**kwargs – Keyword arguments for plugin constructor
- Return type:
- Returns:
Instance of the plugin class
- Raises:
KeyError – If plugin not found
- find_compatible(protocol=None, deformation_mode=None, transform_type=None, **criteria)[source]¶
Find plugins matching certain criteria.
- Parameters:
protocol (Protocol | str | None) – Filter models by supported protocol
deformation_mode (DeformationMode | str | None) – Filter models by supported deformation mode
transform_type (TransformType | str | None) – Filter transforms by type
**criteria – Additional criteria to match against plugin metadata
- Return type:
- Returns:
List of plugin names matching all criteria
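The documented filtering, where a plugin must match every supplied criterion, can be sketched as a linear scan. PluginRecord, find_matching, the "family" metadata key, and the model names are all hypothetical. Treating an unset filter as "no constraint" is the natural reading of the None defaults.

```python
from dataclasses import dataclass, field


@dataclass
class PluginRecord:
    """Hypothetical, trimmed-down stand-in for PluginInfo."""
    name: str
    protocols: list = field(default_factory=list)
    deformation_modes: list = field(default_factory=list)
    metadata: dict = field(default_factory=dict)


def find_matching(records, protocol=None, deformation_mode=None, **criteria):
    """Return names of records that satisfy every supplied filter (AND semantics)."""
    names = []
    for rec in records:
        if protocol is not None and protocol not in rec.protocols:
            continue
        if deformation_mode is not None and deformation_mode not in rec.deformation_modes:
            continue
        if any(rec.metadata.get(key) != value for key, value in criteria.items()):
            continue
        names.append(rec.name)
    return names


records = [
    PluginRecord("model_a", ["relaxation", "oscillation"], ["shear", "tension"]),
    PluginRecord("model_b", ["oscillation"], ["shear"], {"family": "fractional"}),
]
print(find_matching(records, protocol="oscillation", deformation_mode="tension"))  # ['model_a']
```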
- transform(name=None, transform_type=None, **metadata)[source]¶
Decorator for registering a transform.
- Parameters:
name (str | None) – Optional name for the transform (uses class name if not provided)
transform_type (TransformType | str | None) – Type of transform
**metadata – Additional metadata for the transform
- Returns:
Decorator function
- class rheojax.core.registry.ModelRegistry[source]¶
Bases: object
Convenient interface for model registration and creation.
This class provides a simplified API specifically for models, delegating to the main Registry singleton.
Example
>>> @ModelRegistry.register('maxwell', protocols=['relaxation'])
... class Maxwell(BaseModel):
...     pass
>>>
>>> model = ModelRegistry.create('maxwell')
>>> models = ModelRegistry.find(protocol='relaxation')
- classmethod register(name, protocols=None, deformation_modes=None, **metadata)[source]¶
Decorator for registering a model.
- Parameters:
name (str) – Name for the model
protocols (list[Protocol | str] | None) – List of supported protocols
deformation_modes (list[DeformationMode | str] | None) – List of supported deformation modes (shear, tension, bending, compression). Models with the oscillation protocol that work in G-space can support all four modes via automatic E* ↔ G* conversion.
**metadata – Additional metadata for the model
- Returns:
Decorator function
Example
>>> @ModelRegistry.register('maxwell',
...     protocols=['relaxation', 'oscillation'],
...     deformation_modes=['shear', 'tension', 'bending', 'compression'])
... class Maxwell(BaseModel):
...     pass
- classmethod create(name, *args, **kwargs)[source]¶
Create a model instance by name (factory method).
- Parameters:
name (str) – Name of the model to create
*args – Positional arguments for model constructor
**kwargs – Keyword arguments for model constructor
- Return type:
- Returns:
Instance of the model class
- Raises:
KeyError – If model not found
Example
>>> model = ModelRegistry.create('maxwell')
- classmethod list_models()[source]¶
List all registered model names (discovery).
Example
>>> models = ModelRegistry.list_models()
>>> print(models)
['maxwell', 'zener', 'springpot', ...]
- classmethod find(protocol=None, deformation_mode=None, **criteria)[source]¶
Find models matching criteria.
- classmethod get_info(name)[source]¶
Get information about a registered model.
- Parameters:
name (str) – Name of the model
- Return type:
- Returns:
PluginInfo object with model details
Example
>>> info = ModelRegistry.get_info('maxwell')
>>> print(info.protocols)
[<Protocol.RELAXATION: 'relaxation'>, ...]
- classmethod for_protocol(protocol)[source]¶
Get all models supporting a given protocol.
- classmethod compatible_models(data)[source]¶
Find models compatible with a RheoData instance.
Reads data.test_mode (maps to protocol) and data.metadata.get("deformation_mode") to filter models.
- Parameters:
data (Any) – RheoData instance with test_mode metadata.
- Return type:
- Returns:
List of PluginInfo objects for compatible models.
- class rheojax.core.registry.TransformRegistry[source]¶
Bases: object
Convenient interface for transform registration and creation.
This class provides a simplified API specifically for transforms, delegating to the main Registry singleton.
Example
>>> @TransformRegistry.register('fft_analysis', type='spectral')
... class RheoAnalysis(BaseTransform):
...     pass
>>>
>>> transform = TransformRegistry.create('fft_analysis')
>>> transforms = TransformRegistry.find(type='spectral')
- classmethod register(name, type=None, **metadata)[source]¶
Decorator for registering a transform.
- Parameters:
name (str) – Name for the transform
type (TransformType | str | None) – Type of transform (TransformType)
**metadata – Additional metadata for the transform
- Returns:
Decorator function
Example
>>> @TransformRegistry.register('fft_analysis', type='spectral')
... class RheoAnalysis(BaseTransform):
...     pass
- classmethod create(name, *args, **kwargs)[source]¶
Create a transform instance by name (factory method).
- Parameters:
name (str) – Name of the transform to create
*args – Positional arguments for transform constructor
**kwargs – Keyword arguments for transform constructor
- Return type:
- Returns:
Instance of the transform class
- Raises:
KeyError – If transform not found
Example
>>> transform = TransformRegistry.create('fft_analysis')
- classmethod list_transforms()[source]¶
List all registered transform names (discovery).
Example
>>> transforms = TransformRegistry.list_transforms()
>>> print(transforms)
['fft_analysis', 'mastercurve', 'owchirp', ...]
- classmethod find(type=None, **criteria)[source]¶
Find transforms matching criteria.
- Parameters:
type (TransformType | str | None) – Filter by transform type
**criteria – Additional metadata criteria
- Return type:
- Returns:
List of matching transform names
- classmethod get_info(name)[source]¶
Get information about a registered transform.
- Parameters:
name (str) – Name of the transform
- Return type:
- Returns:
PluginInfo object with transform details
Example
>>> info = TransformRegistry.get_info('fft_analysis')
>>> print(info.transform_type)
TransformType.SPECTRAL
The registry system provides a centralized way to discover and instantiate models and transforms by name, protocol, or deformation mode.
Inventory (Protocols & Capabilities)¶
Core definitions for the Protocol-Driven Inventory System.
This module defines the classifications used to categorize models and transforms in the RheoJAX registry. It provides the type system for:
1. Models (via Protocol)
2. Transforms (via TransformType)
- class rheojax.core.inventory.Protocol(*values)[source]¶
Bases: StrEnum
Rheological experimental protocols supported by models.
A protocol defines a specific type of experiment or measurement that a model is capable of simulating or fitting.
- FLOW_CURVE = 'flow_curve'¶
- CREEP = 'creep'¶
- RELAXATION = 'relaxation'¶
- STARTUP = 'startup'¶
- OSCILLATION = 'oscillation'¶
- LAOS = 'laos'¶
- class rheojax.core.inventory.TransformType(*values)[source]¶
Bases: StrEnum
Categories of data transformation operations.
Transforms are classified by their mathematical operation on the data domain.
- SPECTRAL = 'spectral'¶
- SUPERPOSITION = 'superposition'¶
- DECOMPOSITION = 'decomposition'¶
- ANALYSIS = 'analysis'¶
- PROCESSING = 'processing'¶
Examples¶
Creating RheoData¶
import numpy as np
from rheojax.core import RheoData
# Simple time-domain data
time = np.array([0.1, 1.0, 10.0])
stress = np.array([1000, 800, 600])
data = RheoData(
x=time,
y=stress,
x_units="s",
y_units="Pa",
domain="time"
)
# Complex frequency-domain data
omega = np.logspace(-2, 2, 50)
Gp = 1000 * omega**0.5
Gpp = 500 * omega**0.3
G_star = Gp + 1j * Gpp
freq_data = RheoData(
x=omega,
y=G_star,
x_units="rad/s",
y_units="Pa",
domain="frequency"
)
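For complex data like freq_data, the storage and loss moduli are the real and imaginary parts of G*. These are the values the y_real and y_imag properties return. The sketch below computes them, along with the loss tangent tan δ = G″/G′ and the phase angle, using plain Python complex numbers so that it runs without RheoJAX or NumPy. With RheoData, you would read freq_data.y_real and freq_data.y_imag instead.

```python
import math

omega = [0.1, 1.0, 10.0]  # rad/s
# Same synthetic model as above: G' = 1000 ω^0.5, G'' = 500 ω^0.3
G_star = [1000 * w**0.5 + 1j * 500 * w**0.3 for w in omega]

G_prime = [g.real for g in G_star]         # storage modulus G'
G_double_prime = [g.imag for g in G_star]  # loss modulus G''
tan_delta = [gpp / gp for gp, gpp in zip(G_prime, G_double_prime)]
phase_deg = [math.degrees(math.atan2(g.imag, g.real)) for g in G_star]

print(tan_delta[1])  # 0.5 at ω = 1 rad/s
```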
Working with Parameters¶
from rheojax.core import Parameter, ParameterSet
# Create parameter set
params = ParameterSet()
params.add(
name="E",
value=1000.0,
bounds=(100, 10000),
units="Pa",
description="Elastic modulus"
)
params.add(
name="tau",
value=1.0,
bounds=(0.01, 100),
units="s",
description="Relaxation time"
)
# Get/set values
E_value = params.get_value("E")
params.set_value("tau", 2.5)
# Array interface
values = params.get_values() # [1000.0, 2.5]
params.set_values([2000, 1.5])
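The get_values/set_values pair maps named parameters to a flat array in insertion order, which is the form optimizers consume. The sketch below illustrates that idea. MiniParameterSet is hypothetical, and validating every value against its bounds before assigning any of them is an assumption rather than documented RheoJAX behavior.

```python
class MiniParameterSet:
    """Ordered name -> (value, bounds) store; insertion order = array order."""

    def __init__(self):
        self._params = {}

    @staticmethod
    def _check(name, value, bounds):
        lower, upper = bounds
        if not lower <= value <= upper:
            raise ValueError(f"{name}={value} outside bounds {bounds}")

    def add(self, name, value, bounds):
        self._check(name, value, bounds)
        self._params[name] = {"value": float(value), "bounds": bounds}

    def get_value(self, name):
        return self._params[name]["value"]

    def set_value(self, name, value):
        self._check(name, value, self._params[name]["bounds"])
        self._params[name]["value"] = float(value)

    def get_values(self):
        return [p["value"] for p in self._params.values()]

    def set_values(self, values):
        if len(values) != len(self._params):
            raise ValueError("expected one value per parameter")
        # Validate everything first so a bad entry leaves the set unchanged.
        for (name, p), v in zip(self._params.items(), values):
            self._check(name, v, p["bounds"])
        for p, v in zip(self._params.values(), values):
            p["value"] = float(v)


params = MiniParameterSet()
params.add("E", 1000.0, (100, 10000))
params.add("tau", 1.0, (0.01, 100))
params.set_value("tau", 2.5)
print(params.get_values())  # [1000.0, 2.5]
params.set_values([2000, 1.5])
print(params.get_values())  # [2000.0, 1.5]
```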
Test Mode Detection¶
from rheojax.core.test_modes import detect_test_mode, TestMode
# Automatic detection
mode = detect_test_mode(data)
print(mode) # TestMode.RELAXATION
# Check test mode
if data.test_mode == TestMode.RELAXATION:
print("This is a stress relaxation test")
Using Base Classes¶
from rheojax.core import BaseModel, ParameterSet
import jax.numpy as jnp
class MaxwellModel(BaseModel):
def __init__(self, E=1000.0, tau=1.0):
super().__init__()
self.parameters.add("E", value=E, bounds=(1, 1e6), units="Pa")
self.parameters.add("tau", value=tau, bounds=(0.01, 1000), units="s")
def _fit(self, X, y, **kwargs):
# Fitting implementation
return self
def _predict(self, X):
E = self.parameters.get_value("E")
tau = self.parameters.get_value("tau")
return E * jnp.exp(-X / tau)
# Use model
model = MaxwellModel()
model.fit(time, stress)
predictions = model.predict(time)
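The _fit stub above returns immediately without fitting. For low-noise single-exponential data, one simple way to fill it is a linear least-squares fit of ln G against t, since ln G = ln E - t/τ. This log-linearized fit is a swapped-in illustration and not how RheoJAX fits models, which use NLSQ nonlinear least squares. It requires strictly positive data and weights noise differently from a direct fit. The function name is hypothetical.

```python
import math


def fit_single_exponential(t, G):
    """Fit G(t) = E * exp(-t / tau) via least squares on ln G = ln E - t / tau."""
    logs = [math.log(g) for g in G]  # requires every G > 0
    n = len(t)
    t_mean = sum(t) / n
    log_mean = sum(logs) / n
    sxx = sum((ti - t_mean) ** 2 for ti in t)
    sxy = sum((ti - t_mean) * (li - log_mean) for ti, li in zip(t, logs))
    slope = sxy / sxx                      # slope = -1 / tau
    intercept = log_mean - slope * t_mean  # intercept = ln E
    return math.exp(intercept), -1.0 / slope


t = [0.1, 0.5, 1.0, 2.0, 5.0]
G = [1000.0 * math.exp(-ti / 2.0) for ti in t]
E, tau = fit_single_exponential(t, G)
print(round(E, 6), round(tau, 6))  # 1000.0 2.0
```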
Bayesian Inference¶
from rheojax.models import Maxwell
import numpy as np
# Generate synthetic relaxation data (G0 = 1e5 Pa, relaxation time 1 s, noise std 1e3 Pa)
rng = np.random.default_rng(0)
t = np.linspace(0.1, 10, 50)
G_data = 1e5 * np.exp(-t / 1.0) + rng.normal(0, 1e3, size=t.shape)
# 1. NLSQ optimization (fast point estimate)
model = Maxwell()
model.fit(t, G_data)
print(f"NLSQ: G0={model.parameters.get_value('G0'):.3e}")
# 2. Bayesian inference with warm-start
result = model.fit_bayesian(
t, G_data,
num_warmup=1000,
num_samples=2000,
num_chains=1
)
# 3. Analyze results
print(f"Posterior mean: G0={result.summary['G0']['mean']:.3e} +/- {result.summary['G0']['std']:.3e}")
print(f"Convergence: R-hat={result.diagnostics['r_hat']['G0']:.4f}")
# 4. Get credible intervals
intervals = model.get_credible_intervals(result.posterior_samples, credibility=0.95)
print(f"G0 95% CI: [{intervals['G0'][0]:.3e}, {intervals['G0'][1]:.3e}]")