Architecture Overview
This document describes the design principles and architectural decisions behind rheojax.
Design Philosophy
JAX-First Design
All numerical operations in rheojax use JAX for:
- Automatic differentiation: exact gradients for optimization, with no hand-derived derivatives
- JIT compilation: XLA-compiled code whose speed can approach hand-optimized native code
- GPU/TPU support: transparent acceleration on whatever hardware is available
- Vectorization: automatic batching and parallelization (e.g. jax.vmap)
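# ALWAYS use safe_import_jax(); never import jax directly
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()  # Ensures float64 is enabled

@jax.jit
def relaxation_modulus(t, E, tau):
    """JIT-compiled single-mode relaxation modulus G(t) = E exp(-t / tau)."""
    return E * jnp.exp(-t / tau)

# Automatic gradient with respect to E and tau.
# jax.grad requires a scalar output, so evaluate it at a scalar time t.
grad_fn = jax.grad(relaxation_modulus, argnums=(1, 2))
dG_dE, dG_dtau = grad_fn(0.5, 2.0, 1.5)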
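The gradients can be checked against the hand-derived derivatives ∂G/∂E = exp(-t/τ) and ∂G/∂τ = E·t/τ²·exp(-t/τ). In the sketch below, `jax.config.update("jax_enable_x64", True)` stands in for `safe_import_jax()` so that the example runs without rheojax installed. This substitution is an assumption for illustration only, not a recommendation for package code.

```python
import jax
import jax.numpy as jnp

# Stand-in for safe_import_jax(): enable float64 before creating any arrays
jax.config.update("jax_enable_x64", True)


@jax.jit
def relaxation_modulus(t, E, tau):
    """Single-mode relaxation modulus G(t) = E exp(-t / tau)."""
    return E * jnp.exp(-t / tau)


# Gradient with respect to E (argnum 1) and tau (argnum 2)
grad_fn = jax.grad(relaxation_modulus, argnums=(1, 2))

t, E, tau = 0.5, 2.0, 1.5
dG_dE, dG_dtau = grad_fn(t, E, tau)

# Hand-derived derivatives for comparison
expected_dE = jnp.exp(-t / tau)
expected_dtau = E * t / tau**2 * jnp.exp(-t / tau)

print(float(dG_dE), float(expected_dE))
print(float(dG_dtau), float(expected_dtau))
```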
Scikit-learn API Compatibility
Models follow scikit-learn conventions:
- .fit(X, y): fit the model to data
- .predict(X): make predictions
- .score(X, y): evaluate goodness of fit
- .get_params() / .set_params(): manage parameters
# Familiar API (time and stress are 1-D arrays of measured data)
model = FractionalMaxwell(n_elements=5)
model.fit(time, stress)
predictions = model.predict(time)
r2 = model.score(time, stress)
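Under scikit-learn's conventions, `fit()` returns the estimator itself so that calls can be chained. Parameters are exposed through `get_params()`/`set_params()`, and `score()` for a regressor returns the coefficient of determination R². The toy class below illustrates that contract for a single exponential decay. It is a hypothetical, NumPy-only sketch (`ToyRelaxationModel` is not part of rheojax) and not the `BaseModel` implementation. Its log-linear fit assumes strictly positive data.

```python
import numpy as np


class ToyRelaxationModel:
    """Hypothetical estimator for G(t) = E * exp(-t / tau)."""

    def __init__(self, E=1.0, tau=1.0):
        self.E = E
        self.tau = tau

    def get_params(self):
        return {"E": self.E, "tau": self.tau}

    def set_params(self, **params):
        for name, value in params.items():
            if name not in self.get_params():
                raise ValueError(f"Unknown parameter: {name}")
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        # ln G = ln E - t / tau is linear in t, so a least-squares
        # line through (t, ln G) recovers E and tau (requires y > 0).
        slope, intercept = np.polyfit(np.asarray(X), np.log(np.asarray(y)), 1)
        self.E = float(np.exp(intercept))
        self.tau = float(-1.0 / slope)
        return self  # returning self enables model.fit(...).predict(...)

    def predict(self, X):
        return self.E * np.exp(-np.asarray(X) / self.tau)

    def score(self, X, y):
        # Coefficient of determination R^2, as in scikit-learn regressors
        y = np.asarray(y)
        ss_res = np.sum((y - self.predict(X)) ** 2)
        ss_tot = np.sum((y - y.mean()) ** 2)
        return 1.0 - ss_res / ss_tot


time = np.linspace(0.0, 5.0, 50)
stress = 2.0 * np.exp(-time / 1.5)
model = ToyRelaxationModel().fit(time, stress)
print(model.get_params(), model.score(time, stress))
```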
Core Architecture
Module Structure
rheojax/
|--- core/ # Core abstractions
| |--- base.py # BaseModel (BayesianMixin), BaseTransform
| |--- bayesian.py # Bayesian inference engine (NumPyro NUTS)
| |--- data.py # RheoData container (JAX-native)
| |--- parameters.py # ParameterSet with bounds/priors
| |--- test_modes.py # Test mode detection + DeformationMode enum
| |--- registry.py # ModelRegistry + TransformRegistry
| |--- inventory.py # Protocol enum + model capability discovery
| \--- jax_config.py # safe_import_jax() for float64
|--- models/ # 53 rheological models across 22 families
| |--- classical/ # Maxwell, Zener, SpringPot
| |--- fractional/ # FML, FZSS, FMG, Burgers, etc. (11 models)
| |--- flow/ # PowerLaw, Carreau, HB, Bingham, Cross, CY
| |--- multimode/ # GeneralizedMaxwell (Prony series)
| |--- giesekus/ # SingleMode, MultiMode (tensor ODE)
| |--- sgr/ # SGRConventional, SGRGeneric
| |--- fluidity/ # Local, Nonlocal + Saramito variants (4)
| |--- epm/ # LatticeEPM, TensorialEPM
| |--- ikh/ # MIKH, MLIKH
| |--- fikh/ # FIKH, FMLIKH (fractional IKH)
| |--- dmt/ # DMTLocal, DMTNonlocal (thixotropy)
| |--- hl/ # HebraudLequeux
| |--- stz/ # STZConventional
| |--- spp/ # SPPYieldStress (LAOS)
| |--- itt_mct/ # ITTMCTSchematic, Isotropic (MCT)
| |--- tnt/ # SingleMode, Cates, LoopBridge, etc. (5)
| |--- vlb/ # Local, MultiNetwork, Variant, Nonlocal
| |--- hvm/ # HVMLocal (vitrimer, 3 subnetworks)
| \--- hvnm/ # HVNMLocal (nanocomposite, 4 subnetworks)
|--- transforms/ # 7 data transforms
| |--- fft.py # FFT spectral analysis
| |--- mastercurve.py # TTS + auto shift factors
| |--- owchirp.py # OWChirp LAOS analysis
| |--- srfs.py # Strain-Rate Frequency Superposition
| |--- spp.py # SPP decomposition
| |--- mutation_number.py # Mutation number
| \--- smooth_derivative.py # Savitzky-Golay derivatives
|--- utils/ # Utilities
| |--- optimization.py # NLSQ interface (5-270x vs scipy)
| |--- prony.py # Prony series decomposition
| |--- mct_kernels.py # MCT numerical kernels
| |--- modulus_conversion.py # E* ↔ G* conversion (DMTA)
| \--- initialization/ # Smart parameter initialization
|--- pipeline/ # High-level workflows
| |--- base.py # Pipeline (fluent API)
| |--- bayesian.py # BayesianPipeline (ArviZ diagnostics)
| \--- workflows.py # Batch processing
|--- io/ # File I/O
| |--- readers/ # TRIOS, CSV, Excel, Anton Paar, auto-detect
| \--- writers/ # HDF5, Excel writers
|--- visualization/ # Plotting (3 styles)
|--- logging/ # Structured logging (JAX-safe)
\--- gui/ # PyQt/PySide6 interface
Component Relationships
+---------------------------------------------+
| User Applications / GUI |
\------------------+---------------------------+
|
+-----------------v---------------------------+
| Pipeline / BayesianPipeline |
| Fluent workflows + batch processing |
\------------------+---------------------------+
|
+------------+------------+
| | |
+----v----+ +---v-----+ +---v-----------+
| Models | |Transforms| | Visualization |
| (53) | | (7) | | (3 styles) |
\----+----+ \----+-----+ \----+----------+
| | |
+----v-----------v------------v--------+
| Core Components |
| RheoData, Parameters, BayesianMixin |
| Registry, DeformationMode, Logging |
\-----+--------------------------+-------+
| |
+----v-----+ +----v-----+
| I/O | | Utils |
| (readers | | (NLSQ, |
| writers)| | Prony) |
\----------+ \----------+
| |
+----v--------------------------v----+
| JAX / NLSQ / NumPyro / ArviZ |
| (Numerical Foundation) |
\--------------------------------------+
Base Class Hierarchy
Model Hierarchy
BaseModel (ABC) + BayesianMixin
|--- Classical: Maxwell, Zener, SpringPot
|--- Fractional: FML, FZSS, FMG, FKV, Burgers, etc. (11 models)
|--- Flow: PowerLaw, Carreau, CarreauYasuda, Cross, HB, Bingham
|--- GeneralizedMaxwell (multi-mode Prony series)
|--- Giesekus: Single/MultiMode (tensor ODE)
|--- SGR: Conventional, GENERIC (soft glassy rheology)
|--- Fluidity: Local, Nonlocal + Saramito Local/Nonlocal
|--- EPM: Lattice, Tensorial (elasto-plastic)
|--- IKH/FIKH: MIKH, MLIKH, FIKH, FMLIKH (kinematic hardening)
|--- DMT: Local, Nonlocal (thixotropy)
|--- HL, STZ, SPP (single-model families)
|--- ITT-MCT: Schematic, Isotropic (mode-coupling theory)
|--- TNT: SingleMode, Cates, LoopBridge, MultiSpecies, StickyRouse
|--- VLBBase → VLBLocal, MultiNetwork, Variant, Nonlocal
|--- VLBBase → HVMBase → HVMLocal (vitrimer)
\--- VLBBase → HVMBase → HVNMBase → HVNMLocal (nanocomposite)
All 53 models support: .fit(), .predict(), .fit_bayesian(),
.sample_prior(), .get_credible_intervals()
DMTA: 41+ oscillation-capable models accept deformation_mode='tension'
for automatic E* ↔ G* conversion
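For an isotropic material, the tensile and shear moduli are linked through Poisson's ratio ν by E* = 2(1 + ν)G*. In the incompressible limit ν = 0.5, often assumed for polymer melts and rubbers, this reduces to E* = 3G*. The minimal sketch below implements that relation, assuming a real, frequency-independent ν. The function names are hypothetical, and this is not the code in `modulus_conversion.py`.

```python
import numpy as np


def shear_to_tensile(G_star, poisson_ratio=0.5):
    """E* = 2 (1 + nu) G* for an isotropic material."""
    return 2.0 * (1.0 + poisson_ratio) * np.asarray(G_star)


def tensile_to_shear(E_star, poisson_ratio=0.5):
    """G* = E* / (2 (1 + nu)), the inverse of shear_to_tensile."""
    return np.asarray(E_star) / (2.0 * (1.0 + poisson_ratio))


# Complex moduli written as G' + i G'' (illustrative values in Pa)
G_star = np.array([1.0e5 + 2.0e4j, 3.0e5 + 5.0e4j])
E_star = shear_to_tensile(G_star)      # incompressible limit: E* = 3 G*
print(E_star)
print(tensile_to_shear(E_star))        # recovers G_star
```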
Transform Hierarchy
BaseTransform (ABC)
|--- FFT # Spectral analysis
|--- Mastercurve # TTS (WLF, Arrhenius, auto shift)
|--- OWChirp # LAOS frequency-domain analysis
|--- MutationNumber # Material evolution during a measurement (quasi-stability check)
|--- SmoothDerivative # Savitzky-Golay smoothing
|--- SRFS # Strain-rate frequency superposition
\--- SPP # Sequence of Physical Processes (LAOS)
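The WLF and Arrhenius shift-factor models named for Mastercurve have standard closed forms:

- WLF: log10 a_T = -C1 (T - T_ref) / (C2 + T - T_ref)
- Arrhenius: ln a_T = (E_a / R)(1/T - 1/T_ref)

The sketch below computes both and forms the reduced frequency ω·a_T. The defaults C1 = 17.44 and C2 = 51.6 K are the classical "universal" constants quoted for T_ref = T_g and are used here only for illustration. The function names are hypothetical, not rheojax API.

```python
import numpy as np

R = 8.314462618  # gas constant, J/(mol K)


def wlf_shift(T, T_ref, C1=17.44, C2=51.6):
    """Return a_T from the WLF equation (only T - T_ref enters, so K or degC both work)."""
    dT = np.asarray(T, dtype=float) - T_ref
    return 10.0 ** (-C1 * dT / (C2 + dT))


def arrhenius_shift(T, T_ref, Ea):
    """Return a_T from the Arrhenius equation (temperatures in K, Ea in J/mol)."""
    T = np.asarray(T, dtype=float)
    return np.exp(Ea / R * (1.0 / T - 1.0 / T_ref))


T_ref = 373.15
temps = np.array([353.15, 373.15, 393.15])
a_T = wlf_shift(temps, T_ref)
print(a_T)  # greater than 1 below T_ref, exactly 1 at T_ref, less than 1 above

# Time-temperature superposition: reduced frequency omega_r = a_T * omega
omega = np.logspace(-1, 2, 4)  # rad/s, measured at 353.15 K
omega_reduced = omega * a_T[0]
```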
Extension Points
Adding New Models
To add a custom model, inherit from BaseModel and implement _fit() and _predict():
from rheojax.core import BaseModel, ParameterSet
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()  # float64-safe JAX import

class CustomModel(BaseModel):
    """Custom rheological model: single exponential relaxation."""

    def __init__(self, param1=1.0, param2=1.0):
        super().__init__()
        # Define parameters with bounds and units
        self.parameters = ParameterSet()
        self.parameters.add(
            "param1",
            value=param1,
            bounds=(0.1, 10),
            units="Pa",
        )
        self.parameters.add(
            "param2",
            value=param2,
            bounds=(0.01, 100),
            units="s",
        )

    @staticmethod
    def _model(X, p1, p2):
        """Model equation (helper added for this example): p1 * exp(-X / p2)."""
        return p1 * jnp.exp(-X / p2)

    def _fit(self, X, y, **kwargs):
        """Implement fitting logic."""
        from rheojax.utils.optimization import nlsq_optimize

        def objective(params):
            # The objective must use the trial values passed in by the
            # optimizer (assumed ordered as param1, param2), not the
            # currently stored parameter values.
            p1, p2 = params
            predictions = self._model(X, p1, p2)
            return jnp.sum((predictions - y) ** 2)

        nlsq_optimize(objective, self.parameters, use_jax=True)
        return self

    def _predict(self, X):
        """Implement prediction logic."""
        p1 = self.parameters.get_value("param1")
        p2 = self.parameters.get_value("param2")
        return self._model(X, p1, p2)
Adding New Transforms
To add a custom transform, inherit from BaseTransform and implement _transform() and _inverse_transform():
from rheojax.core import BaseTransform, RheoData

class CustomTransform(BaseTransform):
    """Custom data transform: scales y by a constant factor."""

    def __init__(self, param=1.0):
        super().__init__()
        self.param = param

    def _transform(self, data):
        """Implement forward transform."""
        # Scale the dependent variable; x is left unchanged
        y_transformed = data.y * self.param
        # Return a new RheoData so the input is not modified
        return RheoData(
            x=data.x,
            y=y_transformed,
            x_units=data.x_units,
            y_units=data.y_units,
            domain=data.domain,
            metadata=data.metadata.copy(),
        )

    def _inverse_transform(self, data):
        """Implement inverse transform (undo the scaling)."""
        y_original = data.y / self.param
        return RheoData(
            x=data.x,
            y=y_original,
            x_units=data.x_units,
            y_units=data.y_units,
            domain=data.domain,
            metadata=data.metadata.copy(),
        )
Registry Pattern
Models and transforms are registered so they can be discovered by name:
from rheojax.core import BaseModel, BaseTransform
from rheojax.core.inventory import Protocol
from rheojax.core.registry import ModelRegistry, TransformRegistry
from rheojax.core.test_modes import DeformationMode

# Register model with protocols and optional deformation modes
@ModelRegistry.register(
    name="CustomModel",
    protocols=[Protocol.RELAXATION, Protocol.CREEP, Protocol.OSCILLATION],
    deformation_modes=[DeformationMode.SHEAR, DeformationMode.TENSION],
)
class CustomModel(BaseModel):
    pass  # _fit/_predict omitted here for brevity

# Register transform
@TransformRegistry.register(name="CustomTransform")
class CustomTransform(BaseTransform):
    pass  # _transform/_inverse_transform omitted here for brevity

# Discover registered components
models = ModelRegistry.list_models()
transforms = TransformRegistry.list_transforms()
# Instantiate by name
model = ModelRegistry.get_model("CustomModel")
transform = TransformRegistry.get_transform("CustomTransform")
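The decorator-based registry follows a common Python pattern. A dictionary maps names to classes, and the decorator records each class when its module is imported, then returns the class unchanged. The minimal sketch below shows that pattern. It is hypothetical (`SimpleRegistry` is not rheojax's `ModelRegistry`), and protocols are stored as plain strings to keep it self-contained.

```python
class SimpleRegistry:
    """Hypothetical name -> class registry with per-entry metadata."""

    def __init__(self):
        self._entries = {}

    def register(self, name, **metadata):
        def decorator(cls):
            if name in self._entries:
                raise KeyError(f"{name!r} is already registered")
            self._entries[name] = {"cls": cls, **metadata}
            return cls  # the decorated class is returned unchanged
        return decorator

    def list_names(self):
        return sorted(self._entries)

    def create(self, name, *args, **kwargs):
        # Look up the class by name and instantiate it
        return self._entries[name]["cls"](*args, **kwargs)

    def find(self, **criteria):
        # Names whose metadata lists contain every requested value
        return [
            name
            for name, entry in self._entries.items()
            if all(value in entry.get(key, ()) for key, value in criteria.items())
        ]


models = SimpleRegistry()


@models.register("Exponential", protocols=["relaxation", "creep"])
class Exponential:
    def __init__(self, tau=1.0):
        self.tau = tau


print(models.list_names())
print(models.find(protocols="creep"))
instance = models.create("Exponential", tau=2.0)
```

Storing metadata next to each class lets capability queries (such as "which models support creep?") run without instantiating anything.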
Data Flow
Typical Analysis Workflow
1. Load Data
|--- Auto-detect format (auto_read)
|--- Parse file
\--- Create RheoData
2. Preprocess
|--- Detect test mode
|--- Validate data
|--- Apply transforms (smooth, filter)
\--- Convert to appropriate domain
3. Model Fitting
|--- Select model (manual or auto)
|--- Set initial parameters
|--- Optimize parameters (JAX gradients)
\--- Store fitted model
4. Analysis
|--- Make predictions
|--- Compute residuals
|--- Calculate metrics
\--- Cross-validate
5. Visualization
|--- Plot data and fit
|--- Plot residuals
\--- Save figures
6. Export
|--- Save results (HDF5, Excel)
\--- Export parameters
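Steps 1 through 4 and step 6 can be traced end to end on synthetic data using only NumPy. The sketch below is a hypothetical stand-in, and `run_workflow` is illustrative only:

- A generated relaxation curve replaces `auto_read`.
- A log-linear least-squares fit replaces the JAX/NLSQ optimizer.
- Visualization is skipped.
- The "exported" parameters are returned as a dictionary.

```python
import numpy as np


def run_workflow(t, G):
    # 2. Preprocess: validate, keep finite positive points, sort by time
    t, G = np.asarray(t, dtype=float), np.asarray(G, dtype=float)
    if t.shape != G.shape:
        raise ValueError("t and G must have the same shape")
    mask = np.isfinite(t) & np.isfinite(G) & (G > 0)
    order = np.argsort(t[mask])
    t, G = t[mask][order], G[mask][order]

    # 3. Fit G(t) = E exp(-t / tau) via the linear relation ln G = ln E - t / tau
    slope, intercept = np.polyfit(t, np.log(G), 1)
    E, tau = np.exp(intercept), -1.0 / slope

    # 4. Analysis: predictions, residuals, metrics
    pred = E * np.exp(-t / tau)
    resid = G - pred
    r2 = 1.0 - np.sum(resid**2) / np.sum((G - G.mean()) ** 2)
    rmse = np.sqrt(np.mean(resid**2))

    # 6. Export: a plain dict, e.g. for writing to HDF5 or Excel later
    return {"E": float(E), "tau": float(tau), "r2": float(r2), "rmse": float(rmse)}


# 1. "Load" data: synthetic curve with 1 % multiplicative noise and one bad point
rng = np.random.default_rng(0)
t = np.linspace(0.01, 10.0, 200)
G = 1.0e4 * np.exp(-t / 2.0) * (1.0 + 0.01 * rng.standard_normal(t.size))
G[5] = np.nan
print(run_workflow(t, G))
```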
JAX Integration Details
Array Handling
rheojax supports both NumPy and JAX arrays seamlessly:
import numpy as np
import jax.numpy as jnp
from rheojax.core import RheoData
# NumPy arrays
data_np = RheoData(x=np.array([1, 2, 3]), y=np.array([10, 20, 30]))
# JAX arrays
data_jax = RheoData(x=jnp.array([1, 2, 3]), y=jnp.array([10, 20, 30]))
# Convert between them
data_jax = data_np.to_jax()
data_np = data_jax.to_numpy()
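Precision matters when moving data between NumPy and JAX. JAX defaults to 32-bit floats unless 64-bit mode is enabled, which is what `safe_import_jax()` ensures. The sketch below sets the public `jax_enable_x64` flag directly as a stand-in for `safe_import_jax()` (an assumption made so the example is self-contained). It also casts NumPy integer input to float explicitly.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Stand-in for safe_import_jax(): must run before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

x_np = np.array([1, 2, 3])                     # integer dtype in NumPy
x_jax = jnp.asarray(x_np, dtype=jnp.float64)   # cast explicitly to float64
x_back = np.asarray(x_jax)                     # JAX -> NumPy (host copy)

print(x_np.dtype, x_jax.dtype, x_back.dtype)
```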
JIT Compilation
Functions are JIT-compiled for performance:
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()

@jax.jit
def model_function(t, params):
    """JIT-compiled model function."""
    E, tau = params
    return E * jnp.exp(-t / tau)

t = jnp.linspace(0.0, 10.0, 100)
params = jnp.array([1.0e3, 2.0])  # E (Pa), tau (s)
# First call: tracing + XLA compilation + execution (comparatively slow)
result1 = model_function(t, params)
# Subsequent calls with the same shapes/dtypes reuse the cached compiled code (fast)
result2 = model_function(t, params)
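JAX caches one compiled version per combination of input shapes and dtypes. Calling a jitted function with a new array shape therefore triggers a fresh trace and compilation. A Python side effect inside the function runs only while tracing, which gives an informal way to observe this behavior. The `trace_count` counter below is purely illustrative.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def model_function(t, params):
    global trace_count
    trace_count += 1  # Python side effect: executes only during tracing
    E, tau = params
    return E * jnp.exp(-t / tau)


params = jnp.array([2.0, 1.5])
model_function(jnp.linspace(0.0, 1.0, 3), params)  # traces and compiles
model_function(jnp.linspace(1.0, 2.0, 3), params)  # same shape: cached
model_function(jnp.linspace(0.0, 1.0, 4), params)  # new shape: retraces
print(trace_count)
```

In practice, padding inputs to a few fixed sizes avoids repeated recompilation when data lengths vary.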
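Automatic Differentiation
JAX provides automatic gradients:
from rheojax.core.jax_config import safe_import_jax
jax, jnp = safe_import_jax()
# model_function is the jitted function from the JIT example; t, params and
# y_observed are the time grid, parameter array and measured data
def objective(params):
    """Sum of squared residuals to minimize."""
    predictions = model_function(t, params)
    return jnp.sum((predictions - y_observed) ** 2)
# Compute the gradient automatically (same shape as params)
grad_fn = jax.grad(objective)
gradients = grad_fn(params)
# Use in optimization (parameters is the model's ParameterSet)
from rheojax.utils.optimization import nlsq_optimize
result = nlsq_optimize(objective, parameters, use_jax=True)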
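The sketch below makes the link between automatic differentiation and least-squares fitting concrete. It runs a few Gauss-Newton iterations in which `jax.jacfwd` supplies the Jacobian of the residual vector. This is a textbook stand-in for illustration, not the algorithm used by `nlsq_optimize`. It assumes noise-free synthetic data and a starting guess near the solution.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, as safe_import_jax() would ensure


def model_function(t, params):
    E, tau = params
    return E * jnp.exp(-t / tau)


t = jnp.linspace(0.0, 5.0, 40)
y_observed = model_function(t, jnp.array([2.0, 1.5]))  # synthetic, noise-free


def residuals(params):
    return model_function(t, params) - y_observed


jacobian = jax.jacfwd(residuals)  # returns an (n_points, n_params) matrix


@jax.jit
def gauss_newton_step(params):
    r = residuals(params)
    J = jacobian(params)
    # Solve the linearized least-squares problem J @ delta ~= -r
    delta, *_ = jnp.linalg.lstsq(J, -r)
    return params + delta


params = jnp.array([1.8, 1.3])  # initial guess
for _ in range(20):
    params = gauss_newton_step(params)
print(params)
```

Production solvers add damping (Levenberg-Marquardt) and bound handling so that poor starting guesses do not diverge.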
Performance Optimization
Best Practices
Use JAX for heavy computation
# Convert to JAX arrays
data_jax = data.to_jax()
# JIT compile functions
@jax.jit
def heavy_computation(x):
    return jnp.sum(jnp.exp(-x))
Vectorize operations
# Good: vectorized
result = jnp.exp(-time / tau)
# Bad: loop
result = jnp.array([jnp.exp(-t / tau) for t in time])
Avoid Python loops in hot paths
# Good: use vmap
batch_fn = jax.vmap(single_fn)
results = batch_fn(inputs)
# Bad: Python loop
results = [single_fn(inp) for inp in inputs]
Profile before optimizing
import time
# Call compute_function once beforehand so compilation time is not measured
start = time.perf_counter()
result = compute_function(data)
# JAX dispatches work asynchronously: wait for it before stopping the clock
result.block_until_ready()  # assumes result is a JAX array
elapsed = time.perf_counter() - start
print(f"Time: {elapsed:.3f}s")
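The vectorization and profiling advice can be combined in one short check. The sketch builds a batched function with `jax.vmap`, confirms that it matches the Python-loop version, and times it only after a warm-up call. It uses `block_until_ready()` because JAX dispatches work asynchronously. `single_fn` and the grid sizes are hypothetical. No timings are quoted because they depend on hardware.

```python
import time

import jax
import jax.numpy as jnp


def single_fn(tau):
    """Relaxation curve on a fixed time grid for one relaxation time."""
    t = jnp.linspace(0.0, 10.0, 1000)
    return jnp.exp(-t / tau)


taus = jnp.linspace(0.5, 5.0, 256)

batched = jax.jit(jax.vmap(single_fn))                  # one vectorized, compiled call
looped = jnp.stack([single_fn(tau) for tau in taus])    # Python loop, for comparison

batched(taus).block_until_ready()                       # warm-up: excludes compile time
start = time.perf_counter()
result = batched(taus).block_until_ready()              # wait for async dispatch
elapsed = time.perf_counter() - start

print(result.shape, f"{elapsed * 1e3:.3f} ms")
print(bool(jnp.allclose(result, looped)))
```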
Memory Management
JAX uses device arrays that may reside on a GPU:
import numpy as np
import jax.numpy as jnp
# Create array (may be placed on a GPU)
x = jnp.array([1, 2, 3])
# Copy to host memory as a NumPy array if needed
x_cpu = np.array(x)
# Drop the reference so JAX can free the device buffer
del x
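JAX provides `jax.devices()`, `jax.device_get()`, and `jax.device_put()` for checking where arrays live and moving them explicitly. The sketch below uses only those public calls. The devices listed depend on the installed JAX backend.

```python
import numpy as np
import jax
import jax.numpy as jnp

print(jax.devices())            # CPU only, or GPU/TPU devices if available

x = jnp.array([1.0, 2.0, 3.0])
print(x.devices())              # set of devices holding x

x_host = jax.device_get(x)      # explicit device -> host transfer (NumPy array)
x_cpu = jax.device_put(x, jax.devices("cpu")[0])  # place a copy on the CPU backend
del x                           # drop the reference so the buffer can be freed
```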
Testing Strategy
Unit Tests
Each module has comprehensive unit tests:
# tests/core/test_data.py
import numpy as np
from rheojax.core import RheoData

def test_rheodata_creation():
    """Test RheoData initialization."""
    data = RheoData(x=np.array([1, 2, 3]), y=np.array([10, 20, 30]))
    assert len(data.x) == 3
    assert data.shape == (3,)
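For numerical code, unit tests usually compare against an analytic result with a tolerance rather than testing for exact equality. The minimal pytest sketch below shows that pattern. `maxwell_relaxation` is a hypothetical pure function, not a rheojax API.

```python
import numpy as np
import pytest


def maxwell_relaxation(t, E, tau):
    """Hypothetical model function under test: G(t) = E exp(-t / tau)."""
    return E * np.exp(-np.asarray(t) / tau)


@pytest.mark.parametrize("E, tau", [(1.0, 1.0), (1.0e5, 0.01), (3.0, 20.0)])
def test_value_at_one_relaxation_time(E, tau):
    # At t = tau the modulus has decayed to E / e
    np.testing.assert_allclose(maxwell_relaxation(tau, E, tau), E / np.e, rtol=1e-12)


def test_initial_value_and_monotone_decay():
    t = np.linspace(0.0, 5.0, 50)
    G = maxwell_relaxation(t, E=2.0, tau=1.5)
    assert G[0] == pytest.approx(2.0)
    assert np.all(np.diff(G) < 0)
```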
Integration Tests
Test complete workflows:
# tests/test_workflows.py
# Assumes auto_read and Maxwell have already been imported from rheojax
def test_complete_analysis():
    """Test full analysis workflow."""
    # Load data
    data = auto_read("test_data.txt")
    # Fit model
    model = Maxwell()
    model.fit(data.x, data.y)
    # Predict
    predictions = model.predict(data.x)
    # Verify goodness of fit (R^2)
    assert model.score(data.x, data.y) > 0.9
Test Coverage
Aim for more than 90% test coverage (the --cov options come from the pytest-cov plugin):
# Run tests with coverage
pytest --cov=rheojax --cov-report=html
# View the coverage report (macOS; use xdg-open on Linux)
open htmlcov/index.html
Documentation Standards
Docstring Format
Use NumPy-style docstrings:
def function_name(param1, param2):
    """Short description.

    Longer description with more details about what the function does.

    Parameters
    ----------
    param1 : type
        Description of param1
    param2 : type
        Description of param2

    Returns
    -------
    return_type
        Description of return value

    Raises
    ------
    ValueError
        When parameter is invalid

    Examples
    --------
    >>> result = function_name(1, 2)
    >>> print(result)
    3

    Notes
    -----
    Additional implementation notes.

    References
    ----------
    .. [1] Author, "Title", Journal, Year
    """
    # Body consistent with the doctest in the Examples section
    return param1 + param2
Type Hints
Use type hints for clarity:
from typing import List, Union
import numpy as np
import jax.numpy as jnp
from rheojax.core import RheoData

ArrayLike = Union[np.ndarray, jnp.ndarray, List[float]]

def process_data(
    x: ArrayLike,
    y: ArrayLike,
    method: str = "default",
) -> RheoData:
    """Process rheological data."""
    ...
Future Extensions
Phase 2: Models and Transforms
20+ rheological models
Master curve generation
FFT analysis
OWChirp signal processing
Mutation number calculation
Phase 3: Advanced Features
Bayesian parameter estimation
Uncertainty quantification
Multi-objective optimization
Parallel batch processing
GPU-accelerated model fitting
Phase 4: Machine Learning
Neural network surrogate models
Active learning for parameter estimation
Automated model selection
Transfer learning for similar materials
See Also
Contributing to RheoJAX - Contribution guidelines
../user_guide/core_concepts - Core concepts
API Reference - API documentation
JAX documentation - JAX details