Modular API Tutorial

The Modular API provides direct access to models and transforms for maximum flexibility and control. Use this API when you need fine-grained parameter manipulation, custom optimization workflows, or complex analysis pipelines.

When to Use the Modular API

Use the Modular API for:

  • Custom parameter initialization and bounds

  • Non-standard optimization algorithms

  • Complex parameter constraints

  • Direct manipulation of model equations

  • Integration with external libraries

  • Research and algorithm development

  • Teaching model fundamentals

Use the Pipeline API for:

  • Standard workflows and rapid prototyping

  • Batch processing

  • Quick exploratory analysis

  • Production code with error handling

The Modular API gives you complete control at the cost of more verbose code.

Core Components

ModelRegistry

The ModelRegistry provides centralized model management:

from rheojax.core.registry import ModelRegistry

# List all available models
available_models = ModelRegistry.list_models()
print(f"Available models: {available_models}")

# Get model information
info = ModelRegistry.get_info('maxwell')
print(f"Description: {info.description}")
print(f"Parameters: {info.metadata.get('parameters')}")

# Create model instance
model = ModelRegistry.create('maxwell')

# Alternative: direct import
from rheojax.models import Maxwell
model = Maxwell()
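A registry like this is essentially a name-to-class lookup table. The sketch below shows the general pattern with a hypothetical `SimpleRegistry` and a toy model class (`ToyMaxwell`); it is not rheojax's implementation, and the method names are assumptions chosen for illustration.

```python
from typing import Callable, Dict, List, Type


class SimpleRegistry:
    """Hypothetical registry: maps string keys to classes so objects can be created by name."""

    def __init__(self) -> None:
        self._classes: Dict[str, Type] = {}

    def register(self, name: str) -> Callable[[Type], Type]:
        """Class decorator that stores the decorated class under `name`."""
        def decorator(cls: Type) -> Type:
            if name in self._classes:
                raise ValueError(f"{name!r} is already registered")
            self._classes[name] = cls
            return cls
        return decorator

    def names(self) -> List[str]:
        return sorted(self._classes)

    def create(self, name: str, *args, **kwargs):
        if name not in self._classes:
            raise KeyError(f"unknown name {name!r}; available: {self.names()}")
        # Instantiate the registered class with any constructor arguments
        return self._classes[name](*args, **kwargs)


models = SimpleRegistry()


@models.register("maxwell")
class ToyMaxwell:
    def __init__(self, G: float = 1e5, eta: float = 1e3) -> None:
        self.G = G
        self.eta = eta


m = models.create("maxwell", G=2e5)
print(models.names(), type(m).__name__, m.G)
```

Registering through a decorator keeps the name next to the class definition, which is why both "create by name" and "direct import" can work for the same class.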

TransformRegistry

Similarly for transforms:

from rheojax.core.registry import TransformRegistry

# List transforms
transforms = TransformRegistry.list_transforms()

# Create transform
fft = TransformRegistry.create('fft_analysis')

# Alternative: direct import
from rheojax.transforms import FFTAnalysis
fft = FFTAnalysis()

Working with Models

Direct Model Instantiation

Create and configure models directly:

from rheojax.models import Maxwell, Zener, FractionalMaxwellGel
import numpy as np

# Create model instance
maxwell = Maxwell()

# Inspect default parameters
print(maxwell.parameters)
# Output: ParameterSet with G_s and eta_s

# Get parameter details
G_s_param = maxwell.parameters.get_parameter('G_s')
print(f"Name: {G_s_param.name}")
print(f"Units: {G_s_param.units}")
print(f"Bounds: {G_s_param.bounds}")
print(f"Value: {G_s_param.value}")

Setting Initial Parameters

Control parameter initialization:

from rheojax.models import Maxwell

maxwell = Maxwell()

# Set individual parameters
maxwell.parameters.set_value('G_s', 1e5)      # Pa
maxwell.parameters.set_value('eta_s', 1e3)    # Pa·s

# Set multiple parameters
maxwell.parameters.set_values({
    'G_s': 1e5,
    'eta_s': 1e3
})

# Get parameter values
G_s = maxwell.parameters.get_value('G_s')
eta_s = maxwell.parameters.get_value('eta_s')

# Get all parameters as dict
params_dict = maxwell.parameters.to_dict()
print(params_dict)

Setting Parameter Bounds

Control optimization search space:

from rheojax.models import FractionalMaxwellGel

model = FractionalMaxwellGel()

# Set bounds for each parameter
model.parameters.set_bounds('G_s', min_value=1e3, max_value=1e7)
model.parameters.set_bounds('V', min_value=1e2, max_value=1e6)
model.parameters.set_bounds('alpha', min_value=0.1, max_value=0.9)

# Alternative: assign the bounds attribute on the Parameter object directly
model.parameters.get_parameter('G_s').bounds = (1e3, 1e7)

# Get bounds
bounds = model.parameters.get_bounds('alpha')
print(f"Alpha bounds: {bounds}")
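Bounds end up as box constraints for the optimizer. As a sketch of what that means (using SciPy's `least_squares`, not rheojax's own fitting code), per-parameter bounds can be gathered into lower/upper arrays. The parameter table and the springpot-style model |G*| = V·omega^alpha are illustrative assumptions.

```python
import numpy as np
from scipy.optimize import least_squares

# Hypothetical parameter table: name -> (initial value, (lower, upper))
params = {
    "V": (1e3, (1e2, 1e6)),
    "alpha": (0.5, (0.1, 0.9)),
}
names = list(params)
x0 = np.array([params[n][0] for n in names])
lower = np.array([params[n][1][0] for n in names])
upper = np.array([params[n][1][1] for n in names])
x0 = np.clip(x0, lower, upper)  # least_squares rejects a starting point outside the bounds

omega = np.logspace(-2, 2, 40)   # rad/s
y = 5e3 * omega**0.3             # noise-free springpot data: |G*| = V * omega**alpha


def residuals(p):
    V, alpha = p
    # Relative residuals so every decade of |G*| counts equally
    return (V * omega**alpha - y) / y


res = least_squares(residuals, x0, bounds=(lower, upper), x_scale="jac")
fitted = dict(zip(names, res.x))
print(fitted)
```

`x_scale="jac"` helps here because V and alpha differ by several orders of magnitude.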

Parameter Constraints

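Build a ParameterSet by hand when you need constraints beyond simple bounds, including relative constraints between parameters: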

from rheojax.core.parameters import Parameter, ParameterSet

params = ParameterSet()

# Add parameters with constraints
params.add(Parameter(
    name='G_s',
    value=1e5,
    bounds=(1e3, 1e7),
    constraints=['positive']
))

# Relative constraint (e.g., G_s > G_p)
params.add(Parameter(
    name='G_p',
    value=1e4,
    bounds=(1e2, 1e6),
    constraints=[
        'positive',
        ('relative', 'G_s', 'less_than')  # G_p < G_s
    ]
))

# Validate constraints
is_valid = params.validate()
if not is_valid:
    violations = params.get_constraint_violations()
    print(f"Constraint violations: {violations}")
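A different way to guarantee a relative constraint such as G_p < G_s is to reparameterize, so that any point inside ordinary box bounds is automatically valid. This is a general optimization technique, not necessarily how rheojax enforces `('relative', ...)` constraints; `to_physical` and `from_physical` are hypothetical helpers.

```python
import numpy as np


def to_physical(G_s, ratio):
    """Map optimizer variables (G_s, ratio), 0 <= ratio < 1, to moduli with G_p < G_s."""
    return G_s, ratio * G_s


def from_physical(G_s, G_p):
    """Inverse map; only defined when 0 < G_p < G_s."""
    if not 0 < G_p < G_s:
        raise ValueError("expected 0 < G_p < G_s")
    return G_s, G_p / G_s


# Any point inside the box G_s in [1e3, 1e7], ratio in [0, 1) satisfies G_p < G_s
rng = np.random.default_rng(0)
samples = [to_physical(rng.uniform(1e3, 1e7), rng.uniform(0.0, 1.0)) for _ in range(1000)]
print(all(G_p < G_s for G_s, G_p in samples))
```

The optimizer then works on (G_s, ratio) with plain bounds, and no constraint can be violated during the search.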

Fitting Models

Basic Fitting

Fit model to data:

from rheojax.models import Maxwell
from rheojax.io import auto_load
import numpy as np

# Load data
data = auto_load('oscillation_data.txt')
X = data.x  # Frequency: confirm whether the file uses Hz or rad/s (omega = 2*pi*f)
y = data.y  # Complex modulus |G*|

# Create and fit model
maxwell = Maxwell()
maxwell.fit(X, y)

# Access fitted parameters
G_s = maxwell.parameters.get_value('G_s')
eta_s = maxwell.parameters.get_value('eta_s')
print(f"G_s = {G_s:.2e} Pa")
print(f"eta_s = {eta_s:.2e} Pa·s")

# Make predictions
y_pred = maxwell.predict(X)

# Calculate fit quality
r2 = maxwell.score(X, y)
print(f"R^2 = {r2:.4f}")
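The R^2 labels suggest that `score()` returns the coefficient of determination, R^2 = 1 - SS_res/SS_tot; that is an assumption here. A NumPy sketch of that definition:

```python
import numpy as np


def r_squared(y, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y = np.asarray(y, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y - y_pred) ** 2)        # residual sum of squares
    ss_tot = np.sum((y - np.mean(y)) ** 2)    # total sum of squares about the mean
    return 1.0 - ss_res / ss_tot


y = np.array([1.0, 2.0, 3.0, 4.0])
print(r_squared(y, y))                         # perfect prediction
print(r_squared(y, np.full_like(y, y.mean()))) # predicting the mean everywhere
```

Because rheological data often span several decades, R^2 on a linear scale is dominated by the largest values. Computing it on log10 values, `r_squared(np.log10(y), np.log10(y_pred))`, weights all decades more evenly.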

Custom Initial Guesses

Provide data-driven initialization:

from rheojax.models import FractionalMaxwellGel
import numpy as np

# Analyze data to inform initial guess
G_min = np.min(np.abs(y))
G_max = np.max(np.abs(y))

model = FractionalMaxwellGel()

# Set initial guess
model.parameters.set_values({
    'G_s': G_min * 0.8,      # Rubbery modulus ~ low-freq plateau
    'V': G_max * 2,          # Fractional viscosity ~ high-freq behavior
    'alpha': 0.5             # Mid-range fractional order
})

# Set bounds
model.parameters.set_bounds('G_s', min_value=G_min*0.1, max_value=G_max*2)
model.parameters.set_bounds('V', min_value=G_min*0.1, max_value=G_max*10)
model.parameters.set_bounds('alpha', min_value=0.1, max_value=0.9)

# Fit with custom initialization
model.fit(X, y)
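For a Maxwell-like material, the G'/G'' crossover gives a direct data-driven starting point: G' = G'' = G/2 at omega_c = 1/tau, so G ≈ 2·G'(omega_c) and eta = G·tau ≈ G/omega_c. The sketch below applies this to synthetic G' and G'' (an assumption: your data include both moduli and show a single crossover).

```python
import numpy as np

# Synthetic Maxwell data for illustration: G = 1e5 Pa, tau = 0.1 s
G_true, tau_true = 1e5, 0.1
omega = np.logspace(-2, 3, 60)                 # rad/s
wt = omega * tau_true
G_p = G_true * wt**2 / (1 + wt**2)             # storage modulus G'
G_pp = G_true * wt / (1 + wt**2)               # loss modulus G''

# For a Maxwell element ln(G'/G'') = ln(omega*tau), which crosses zero at omega_c = 1/tau
log_ratio = np.log(G_p / G_pp)
lw = np.log(omega)
i = np.flatnonzero(np.diff(np.sign(log_ratio)))[0]   # last point before the crossover
# Linear interpolation of the zero crossing in ln(omega)
lw_c = lw[i] - log_ratio[i] * (lw[i + 1] - lw[i]) / (log_ratio[i + 1] - log_ratio[i])
omega_c = np.exp(lw_c)
G_c = np.exp(np.interp(lw_c, lw, np.log(G_p)))       # G' (= G'') at the crossover

# At the crossover G' = G'' = G/2, and tau = eta/G = 1/omega_c
G_init = 2.0 * G_c
eta_init = G_init / omega_c
print(f"omega_c = {omega_c:.3g} rad/s, G_init = {G_init:.3e} Pa, eta_init = {eta_init:.3e} Pa·s")
```

These estimates are only starting values; the fit refines them.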

Multi-Start Optimization

Try multiple initial guesses to avoid local minima:

from rheojax.models import Zener
import numpy as np

# Generate multiple initial guesses
n_starts = 5
best_score = -np.inf
best_model = None

for i in range(n_starts):
    model = Zener()

    # Random initialization, sampled log-uniformly because the ranges span decades
    G_s_init = 10 ** np.random.uniform(3, 6)    # 1e3 .. 1e6
    G_p_init = 10 ** np.random.uniform(2, 5)    # 1e2 .. 1e5
    eta_p_init = 10 ** np.random.uniform(1, 4)  # 1e1 .. 1e4

    model.parameters.set_values({
        'G_s': G_s_init,
        'G_p': G_p_init,
        'eta_p': eta_p_init
    })

    # Fit
    model.fit(X, y)

    # Check score
    score = model.score(X, y)
    if score > best_score:
        best_score = score
        best_model = model

print(f"Best R^2 = {best_score:.4f}")
print(f"Best parameters: {best_model.parameters.to_dict()}")
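The same multi-start idea, written against SciPy's `least_squares` rather than rheojax, on synthetic Maxwell data. Two choices are worth noting: starts are drawn uniformly in log10 space (log-uniform in the physical parameters), and the fit itself runs on log10 parameters, which keeps them positive and similarly scaled. The bounds and data are illustrative assumptions.

```python
import numpy as np
from scipy.optimize import least_squares


def maxwell_abs(omega, G, eta):
    """|G*| of a Maxwell element: G * w*tau / sqrt(1 + (w*tau)**2), tau = eta / G."""
    wt = omega * eta / G
    return G * wt / np.sqrt(1.0 + wt**2)


omega = np.logspace(-2, 3, 50)               # rad/s
y = maxwell_abs(omega, 1e5, 1e4)             # synthetic, noise-free (tau = 0.1 s)


def residuals(log_p):
    G, eta = 10.0 ** log_p                   # optimize log10 parameters
    # Log residuals weight every decade of |G*| equally
    return np.log(maxwell_abs(omega, G, eta)) - np.log(y)


rng = np.random.default_rng(0)
lower = np.array([3.0, 1.0])                 # log10 bounds: G in [1e3, 1e7], eta in [1e1, 1e6]
upper = np.array([7.0, 6.0])

best = None
for _ in range(8):
    start = rng.uniform(lower, upper)        # uniform in log10 = log-uniform in G, eta
    fit = least_squares(residuals, start, bounds=(lower, upper))
    if best is None or fit.cost < best.cost:
        best = fit                           # keep the lowest-cost solution

G_fit, eta_fit = 10.0 ** best.x
print(f"G = {G_fit:.3e} Pa, eta = {eta_fit:.3e} Pa·s")
```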

Custom Optimization

Use custom optimization algorithms:

from rheojax.models import Maxwell
from rheojax.utils.optimization import nlsq_optimize
import jax.numpy as jnp
import jax

# Create model
maxwell = Maxwell()

# Define custom objective function
@jax.jit
def objective(params_array):
    """Custom objective with weights or constraints."""
    G_s, eta_s = params_array

    # Maxwell complex modulus: G* = G_s * i*omega*tau / (1 + i*omega*tau)
    omega = X
    tau = eta_s / G_s
    G_star = G_s * (1j * omega * tau) / (1 + 1j * omega * tau)
    y_pred = jnp.abs(G_star)

    # Weighted residuals (e.g., emphasize low frequency)
    weights = 1.0 / (1.0 + omega)  # Higher weight at low freq
    residuals = (y - y_pred) * weights

    return jnp.sum(residuals**2)

# Get initial parameters
p0 = jnp.array([
    maxwell.parameters.get_value('G_s'),
    maxwell.parameters.get_value('eta_s')
])

# Optimize
result = nlsq_optimize(objective, maxwell.parameters, use_jax=True)

# Update model with optimized parameters
maxwell.parameters.set_values({
    'G_s': result.x[0],
    'eta_s': result.x[1]
})
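The Maxwell element's complex modulus is G*(omega) = G·i·omega·tau / (1 + i·omega·tau) with tau = eta/G. Checking its limits numerically is a cheap sanity test before building any custom objective around it: |G*| ≈ eta·omega well below 1/tau, |G*| ≈ G well above it, and G' = G'' = G/2 at omega = 1/tau. The parameter values are arbitrary.

```python
import numpy as np


def maxwell_complex_modulus(omega, G, eta):
    """G*(omega) = G * (i*omega*tau) / (1 + i*omega*tau), with tau = eta / G."""
    iwt = 1j * omega * (eta / G)
    return G * iwt / (1 + iwt)


G, eta = 1e5, 1e3            # tau = 0.01 s, so 1/tau = 100 rad/s
low, high = 1e-4, 1e6        # rad/s, far below and far above 1/tau
G_low = maxwell_complex_modulus(low, G, eta)
G_high = maxwell_complex_modulus(high, G, eta)
G_cross = maxwell_complex_modulus(G / eta, G, eta)

print(abs(G_low) / (eta * low))     # viscous limit: ratio close to 1
print(abs(G_high) / G)              # elastic plateau: ratio close to 1
print(G_cross.real, G_cross.imag)   # crossover: both equal G/2
```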

Working with Transforms

Direct Transform Usage

Apply transforms directly to RheoData:

from rheojax.transforms import FFTAnalysis, SmoothDerivative
from rheojax.core import RheoData
from rheojax.io import auto_load

# Load time-series data
data = auto_load('time_series.txt')

# Apply smoothing
smoother = SmoothDerivative(method='savgol', window=11, order=2)
data_smooth = smoother.transform(data)

# Apply FFT
fft = FFTAnalysis(window='hann', detrend=True)
freq_data = fft.transform(data_smooth)

# Access results
G_prime = freq_data.metadata['G_prime']
G_double_prime = freq_data.metadata['G_double_prime']
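How FFTAnalysis derives G' and G'' is not shown here. One standard approach for oscillatory data is to take the ratio of the stress and strain Fourier coefficients at the drive frequency, which gives G* = G' + iG''. A NumPy sketch on synthetic signals (all values are assumptions; the record covers an integer number of periods so the drive frequency lands exactly on an FFT bin):

```python
import numpy as np

# Illustrative oscillatory test: strain amplitude gamma0, moduli G', G''
f0, fs, n = 1.0, 200.0, 2000          # drive frequency (Hz), sampling rate (Hz), samples
t = np.arange(n) / fs                 # 10 s = exactly 10 periods
w = 2 * np.pi * f0
gamma0, G_p_true, G_pp_true = 0.01, 2e4, 5e3
strain = gamma0 * np.sin(w * t)
stress = gamma0 * (G_p_true * np.sin(w * t) + G_pp_true * np.cos(w * t))

freqs = np.fft.rfftfreq(n, d=1 / fs)
k = np.argmin(np.abs(freqs - f0))     # index of the fundamental
G_star = np.fft.rfft(stress)[k] / np.fft.rfft(strain)[k]   # complex modulus G' + iG''
G_prime, G_double_prime = G_star.real, G_star.imag
print(f"G' = {G_prime:.1f} Pa, G'' = {G_double_prime:.1f} Pa")
```

With real data the record rarely holds an exact number of periods, which is one reason to apply a window such as Hann before the FFT.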

Transform Composition

Chain transforms manually:

from rheojax.transforms import SmoothDerivative, FFTAnalysis
from rheojax.core.base import TransformPipeline

# Create pipeline
pipeline = TransformPipeline([
    SmoothDerivative(method='savgol', window=11, order=2),
    FFTAnalysis(window='hann', detrend=True)
])

# Apply pipeline
result = pipeline.transform(data)

# Alternative: operator overloading
pipeline = SmoothDerivative(method='savgol', window=11, order=2) + \
           FFTAnalysis(window='hann', detrend=True)

result = pipeline.transform(data)
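The `+` operator can compose transforms by overloading `__add__`. The sketch below shows one way that can work, with hypothetical `Step` and `Chain` classes; it has not been checked against rheojax's TransformPipeline.

```python
from typing import Callable


class Step:
    """Hypothetical transform wrapping a function; `+` chains steps."""

    def __init__(self, fn: Callable):
        self.fn = fn

    def transform(self, data):
        return self.fn(data)

    def __add__(self, other: "Step") -> "Chain":
        return Chain([self, other])


class Chain(Step):
    """Applies its steps left to right."""

    def __init__(self, steps):
        self.steps = list(steps)

    def transform(self, data):
        for step in self.steps:
            data = step.transform(data)
        return data

    def __add__(self, other: Step) -> "Chain":
        return Chain(self.steps + [other])   # extend a flat list instead of nesting


double = Step(lambda xs: [2 * v for v in xs])
shift = Step(lambda xs: [v + 1 for v in xs])
print((double + shift).transform([1, 2, 3]))   # double first, then shift
print((shift + double).transform([1, 2, 3]))   # order matters
```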

Inverse Transforms

Some transforms are invertible:

from rheojax.transforms import FFTAnalysis
from rheojax.io import auto_load
import numpy as np

# Load time-domain data
time_data = auto_load('time_series.txt')

fft = FFTAnalysis()

# Forward transform
freq_data = fft.transform(time_data)

# Inverse transform
time_data_reconstructed = fft.inverse_transform(freq_data)

# Check reconstruction error
error = np.mean(np.abs(time_data.y - time_data_reconstructed.y))
print(f"Reconstruction error: {error:.2e}")
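For reference, a plain NumPy real-FFT round trip (not FFTAnalysis) reproduces a signal to machine precision only if the original length is passed back to the inverse. If a transform also applies a window or detrending, the reconstruction will differ unless those steps are undone.

```python
import numpy as np

rng = np.random.default_rng(1)
y = rng.standard_normal(1001)          # odd length on purpose

Y = np.fft.rfft(y)
y_default = np.fft.irfft(Y)            # default n = 2*(len(Y) - 1) = 1000: one sample short
y_back = np.fft.irfft(Y, n=y.size)     # pass the original length explicitly

error = np.mean(np.abs(y - y_back))
print(y_default.size, y_back.size, f"{error:.1e}")
```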

Custom Fitting Workflows

Sequential Parameter Estimation

Fit parameters in stages for better convergence:

from rheojax.models import FractionalMaxwellModel
import numpy as np

model = FractionalMaxwellModel()

# Stage 1: Fix alpha, fit G_s and V
model.parameters.get_parameter('alpha').fixed = True
model.parameters.set_value('alpha', 0.5)

model.fit(X, y)

# Stage 2: Fix G_s and V, optimize alpha
model.parameters.get_parameter('G_s').fixed = True
model.parameters.get_parameter('V').fixed = True
model.parameters.get_parameter('alpha').fixed = False

model.fit(X, y)

# Stage 3: Optimize all together
for param in model.parameters.parameters.values():
    param.fixed = False

model.fit(X, y)

print("Final parameters:")
print(model.parameters.to_dict())
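Fixing a parameter amounts to optimizing over the free subset while splicing the fixed values back in. A SciPy sketch of that mechanism, with a hypothetical `fit_free` helper and an illustrative log-linear springpot model (log10|G*| = log10 V + alpha·log10 omega); rheojax's `fixed` flag may be implemented differently.

```python
import numpy as np
from scipy.optimize import least_squares


def fit_free(residual_fn, p, free):
    """Optimize only the entries of p where `free` is True; the rest stay fixed."""
    p = np.asarray(p, dtype=float).copy()
    free = np.asarray(free, dtype=bool)

    def reduced(q):
        full = p.copy()
        full[free] = q          # splice the free values into the full vector
        return residual_fn(full)

    p[free] = least_squares(reduced, p[free]).x
    return p


# Illustrative data: V = 5e3, alpha = 0.3
log_w = np.linspace(-1.0, 2.0, 30)
log_y = np.log10(5e3) + 0.3 * log_w


def residuals(p):
    log_V, alpha = p
    return log_V + alpha * log_w - log_y


p = np.array([3.0, 0.5])
stage1 = fit_free(residuals, p, [True, False])       # stage 1: alpha fixed at 0.5
stage2 = fit_free(residuals, stage1, [False, True])  # stage 2: log10(V) fixed
final = fit_free(residuals, stage2, [True, True])    # stage 3: everything free
print(stage1, stage2, final)
```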

Fitting with Exact Gradients

Leverage JAX automatic differentiation, which computes exact gradients of the objective without hand-derived formulas or finite differences:

from rheojax.models import Maxwell
from rheojax.utils.optimization import nlsq_optimize
import jax
import jax.numpy as jnp

maxwell = Maxwell()

# Define a differentiable objective (JAX operations only)
@jax.jit
def objective(params_array):
    G_s, eta_s = params_array
    tau = eta_s / G_s
    G_star = G_s * (1j * X * tau) / (1 + 1j * X * tau)  # Maxwell G*
    y_pred = jnp.abs(G_star)
    return jnp.sum((y - y_pred)**2)

# Compute gradient automatically
grad_fn = jax.grad(objective)

# Check gradient
p0 = jnp.array([1e5, 1e3])
gradient = grad_fn(p0)
print(f"Gradient at p0: {gradient}")

# Optimize using gradient
result = nlsq_optimize(objective, maxwell.parameters,
                        use_jax=True, method='L-BFGS-B')
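To see what `jax.grad` computes here, the same gradient can be derived by hand. With u = omega·tau = eta·omega/G, the Maxwell magnitude is m = eta·omega/sqrt(1 + u^2), so dm/dG = u^3/(1 + u^2)^1.5 and dm/deta = omega/(1 + u^2)^1.5, and the sum of squared errors has gradient -2·Σ(y - m)·dm/dp. A NumPy sketch comparing this with central finite differences (synthetic data, values assumed):

```python
import numpy as np

omega = np.logspace(-2, 3, 40)                  # rad/s


def maxwell_abs(G, eta):
    u = eta * omega / G                         # u = omega * tau
    return eta * omega / np.sqrt(1.0 + u**2)   # equals G*u / sqrt(1 + u**2)


y = maxwell_abs(1.2e5, 8e2)                     # "data" from other parameters: nonzero gradient


def sse(p):
    G, eta = p
    return np.sum((y - maxwell_abs(G, eta)) ** 2)


def sse_grad(p):
    """Hand-derived gradient of the sum of squared errors."""
    G, eta = p
    u = eta * omega / G
    resid = y - maxwell_abs(G, eta)
    dm_dG = u**3 / (1.0 + u**2) ** 1.5
    dm_deta = omega / (1.0 + u**2) ** 1.5
    return np.array([-2.0 * np.sum(resid * dm_dG), -2.0 * np.sum(resid * dm_deta)])


def central_difference(f, p, rel_step=1e-6):
    p = np.asarray(p, dtype=float)
    g = np.zeros_like(p)
    for j in range(p.size):
        h = rel_step * abs(p[j])
        e = np.zeros_like(p)
        e[j] = h
        g[j] = (f(p + e) - f(p - e)) / (2.0 * h)
    return g


p0 = np.array([1e5, 1e3])
g_exact = sse_grad(p0)
g_fd = central_difference(sse, p0)
print(g_exact, g_fd)
```

Automatic differentiation yields the exact gradient without this derivation, which is the practical advantage for more complex (e.g. fractional) models.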

Cross-Validation

Assess model generalization:

from rheojax.models import Maxwell, Zener
import numpy as np
from sklearn.model_selection import KFold

# K-fold cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)

model_classes = [Maxwell, Zener]
cv_scores = {cls.__name__: [] for cls in model_classes}

for cls in model_classes:
    for train_idx, test_idx in kf.split(X):
        X_train, X_test = X[train_idx], X[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]

        # Fresh instance per fold, so no fold starts from another fold's fit
        model = cls()
        model.fit(X_train, y_train)

        # Score on held-out data
        score = model.score(X_test, y_test)
        cv_scores[cls.__name__].append(score)

# Report cross-validation scores
print("Cross-Validation R^2 Scores:")
for model_name, scores in cv_scores.items():
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    print(f"  {model_name}: {mean_score:.4f} +/- {std_score:.4f}")

Model Comparison

Systematically compare models:

from rheojax.models import (Maxwell, Zener, SpringPot,
                         FractionalMaxwellGel, FractionalKelvinVoigt)
import numpy as np
import pandas as pd

# Models to compare
models = [
    Maxwell(),
    Zener(),
    SpringPot(),
    FractionalMaxwellGel(),
    FractionalKelvinVoigt()
]

# Fit all models and collect metrics
results = []

for model in models:
    model_name = type(model).__name__

    # Fit
    model.fit(X, y)

    # Metrics
    y_pred = model.predict(X)
    residuals = y - y_pred
    r2 = model.score(X, y)
    rmse = np.sqrt(np.mean(residuals**2))
    n_params = len(model.parameters)

    # Information criteria
    n = len(y)
    rss = np.sum(residuals**2)
    aic = n * np.log(rss/n) + 2 * n_params
    bic = n * np.log(rss/n) + n_params * np.log(n)

    results.append({
        'Model': model_name,
        'N_params': n_params,
        'R^2': r2,
        'RMSE': rmse,
        'AIC': aic,
        'BIC': bic
    })

# Create comparison table
df = pd.DataFrame(results)
df = df.sort_values('AIC')  # Sort by AIC (lower is better)

print("\nModel Comparison:")
print(df.to_string(index=False))

# Best model by AIC
best_model_name = df.iloc[0]['Model']
print(f"\nBest model (AIC): {best_model_name}")
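The AIC and BIC expressions above are the least-squares forms, which assume independent Gaussian errors. A small standalone check of how their penalty terms behave: with identical residuals, one extra parameter raises AIC by exactly 2 and BIC by ln(n). Only differences between models fitted to the same data are meaningful; the absolute values are not.

```python
import numpy as np


def information_criteria(residuals, n_params):
    """Least-squares AIC and BIC: n*ln(RSS/n) + penalty."""
    r = np.asarray(residuals, dtype=float)
    n = r.size
    rss = np.sum(r**2)
    aic = n * np.log(rss / n) + 2 * n_params
    bic = n * np.log(rss / n) + n_params * np.log(n)
    return aic, bic


rng = np.random.default_rng(0)
r = rng.normal(size=50)                 # the same residuals for both "models"
aic2, bic2 = information_criteria(r, 2)
aic3, bic3 = information_criteria(r, 3)
print(aic3 - aic2, bic3 - bic2)
```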

Advanced Parameter Management

Parameter Sensitivity Analysis

Analyze how sensitive predictions are to parameters:

from rheojax.models import Maxwell
import numpy as np
import matplotlib.pyplot as plt

maxwell = Maxwell()
maxwell.fit(X, y)

# Baseline parameters
G_s_base = maxwell.parameters.get_value('G_s')
eta_s_base = maxwell.parameters.get_value('eta_s')

# Vary G_s
G_s_range = np.linspace(G_s_base*0.5, G_s_base*1.5, 10)
predictions = []

for G_s_test in G_s_range:
    maxwell.parameters.set_value('G_s', G_s_test)
    y_pred = maxwell.predict(X)
    predictions.append(y_pred)

# Restore the fitted value so later code uses the original fit
maxwell.parameters.set_value('G_s', G_s_base)

# Plot sensitivity
fig, ax = plt.subplots(figsize=(10, 6))
for i, G_s_test in enumerate(G_s_range):
    shade = 0.3 + 0.7 * (i / len(G_s_range))
    ax.loglog(X, predictions[i], alpha=shade,
              label=f'G_s = {G_s_test:.2e}')

ax.loglog(X, y, 'ko', markersize=8, label='Data')
ax.set_xlabel('Frequency (rad/s)')
ax.set_ylabel('|G*| (Pa)')
ax.legend()
ax.set_title('Sensitivity to G_s')
plt.show()

Confidence Intervals

Estimate parameter uncertainty:

from rheojax.models import Maxwell
from rheojax.utils.optimization import calculate_confidence_intervals
import numpy as np

maxwell = Maxwell()
maxwell.fit(X, y)

# Calculate 95% confidence intervals
ci = calculate_confidence_intervals(maxwell, X, y, alpha=0.05)

print("95% Confidence Intervals:")
for param_name, (lower, upper) in ci.items():
    value = maxwell.parameters.get_value(param_name)
    rel_error = (upper - lower) / (2 * value) * 100
    print(f"  {param_name}: {value:.2e} [{lower:.2e}, {upper:.2e}] "
          f"(+/-{rel_error:.1f}%)")
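The internals of `calculate_confidence_intervals` are not documented here. One common approach is linearized (asymptotic) intervals from the Jacobian J at the optimum: cov ≈ s^2·(J^T J)^-1 with s^2 = RSS/(n - k), and p ± t(0.975, n - k)·sqrt(diag(cov)). The sketch uses SciPy on synthetic log-linear data (an assumption), where the result is exactly the ordinary-least-squares interval; for nonlinear models it is an approximation.

```python
import numpy as np
from scipy.optimize import least_squares
from scipy.stats import t as student_t

# Synthetic data: log10|G*| = log10(V) + alpha*log10(omega) + noise (illustrative)
rng = np.random.default_rng(0)
omega = np.logspace(-1, 2, 30)
log_y = np.log10(5e3) + 0.3 * np.log10(omega) + rng.normal(scale=0.01, size=omega.size)


def residuals(p):
    log_V, alpha = p
    return log_V + alpha * np.log10(omega) - log_y


fit = least_squares(residuals, x0=[3.0, 0.5])
n, k = omega.size, fit.x.size
dof = n - k
s2 = 2.0 * fit.cost / dof                       # fit.cost is 0.5 * sum(residuals**2)
cov = s2 * np.linalg.inv(fit.jac.T @ fit.jac)   # linearized covariance at the optimum
se = np.sqrt(np.diag(cov))
t_crit = student_t.ppf(0.975, dof)              # two-sided 95%

ci = {name: (v - t_crit * s, v + t_crit * s)
      for name, v, s in zip(["log10_V", "alpha"], fit.x, se)}
for name, (lower, upper) in ci.items():
    print(f"{name}: [{lower:.4f}, {upper:.4f}]")
```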

Parameter Correlation

Check for parameter correlation:

from rheojax.models import Zener
import numpy as np

zener = Zener()
zener.fit(X, y)

# Bootstrap to estimate correlation
n_bootstrap = 100
param_samples = {name: [] for name in zener.parameters.parameter_names}

for i in range(n_bootstrap):
    # Resample data
    indices = np.random.choice(len(X), size=len(X), replace=True)
    X_boot = X[indices]
    y_boot = y[indices]

    # Fit
    model_boot = Zener()
    model_boot.fit(X_boot, y_boot)

    # Store parameters
    for name in param_samples.keys():
        param_samples[name].append(model_boot.parameters.get_value(name))

# Calculate correlation matrix
import pandas as pd

df = pd.DataFrame(param_samples)
corr = df.corr()

print("Parameter Correlation Matrix:")
print(corr)

# |r| > 0.9 between two parameters suggests they are poorly identifiable
# (near-redundant) from this data
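A covariance matrix (from the bootstrap samples or a Jacobian) converts to correlations by dividing by the outer product of the standard deviations. The sketch below, on an assumed straight-line fit of log10|G*| against log10(omega) over 10 to 1000 rad/s, also shows that a strong correlation can be a parameterization artifact: the intercept and slope are highly correlated when the intercept is defined at omega = 1 rad/s (outside the data), but uncorrelated once log10(omega) is centered.

```python
import numpy as np


def cov_to_corr(cov):
    """Convert a covariance matrix to a correlation matrix."""
    cov = np.asarray(cov, dtype=float)
    sd = np.sqrt(np.diag(cov))
    return cov / np.outer(sd, sd)


x = np.linspace(1.0, 3.0, 21)                                  # log10(omega), 10..1000 rad/s
A_raw = np.column_stack([np.ones_like(x), x])                  # intercept at omega = 1 rad/s
A_centered = np.column_stack([np.ones_like(x), x - x.mean()])  # intercept at the data's centre

# For linear least squares the covariance is proportional to (A^T A)^-1;
# the proportionality constant cancels in the correlation
corr_raw = cov_to_corr(np.linalg.inv(A_raw.T @ A_raw))
corr_centered = cov_to_corr(np.linalg.inv(A_centered.T @ A_centered))
print(corr_raw[0, 1], corr_centered[0, 1])
```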

Serialization and Persistence

Saving Models

Save fitted models for later use:

from rheojax.models import FractionalMaxwellGel
import pickle

# Fit model
model = FractionalMaxwellGel()
model.fit(X, y)

# Save to file
with open('fitted_model.pkl', 'wb') as f:
    pickle.dump(model, f)

# Load model
with open('fitted_model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

# Use loaded model
y_pred = loaded_model.predict(X)

Model Export/Import

Export model parameters as JSON:

import json

import rheojax.models
from rheojax.models import FractionalMaxwellGel

# Fit model
model = FractionalMaxwellGel()
model.fit(X, y)

# Export parameters (cast to float: NumPy/JAX scalars are not always JSON-serializable)
model_dict = {
    'model_type': type(model).__name__,
    'parameters': {k: float(v) for k, v in model.parameters.to_dict().items()},
    'metadata': {
        'fit_date': '2025-10-24',
        'r2': float(model.score(X, y)),
        'data_source': 'experiment_01.txt'
    }
}

with open('model_params.json', 'w') as f:
    json.dump(model_dict, f, indent=2)

# Import parameters
with open('model_params.json', 'r') as f:
    loaded_dict = json.load(f)

# Reconstruct model from the stored class name
# (registry keys such as 'maxwell' may differ from class names, so look the class up directly)
model_cls = getattr(rheojax.models, loaded_dict['model_type'])
model_reconstructed = model_cls()
model_reconstructed.parameters.set_values(loaded_dict['parameters'])
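If fitted values come back as NumPy or JAX scalars or arrays (an assumption about what `to_dict()` and `score()` return), `json.dump` can fail: `np.float32`, for example, is not JSON-serializable. A small recursive converter, written as a hypothetical helper, handles nested dictionaries:

```python
import json

import numpy as np


def to_jsonable(obj):
    """Recursively convert NumPy/array-like values into plain Python types for json."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.generic):
        return obj.item()                       # NumPy scalar -> Python scalar
    if hasattr(obj, "__array__"):
        return np.asarray(obj).tolist()         # ndarray or other array-like (e.g. JAX)
    return obj


params = {"G_s": np.float32(1.5e5), "alpha": np.float64(0.42), "grid": np.array([1.0, 2.0])}
text = json.dumps(to_jsonable(params))
print(text)
```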

Integration with External Libraries

scikit-learn Compatibility

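rheojax models expose scikit-learn-style fit, predict, and score methods. To use scikit-learn tools such as GridSearchCV, which tune constructor arguments, wrap a model in a BaseEstimator subclass: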

from rheojax.models import Maxwell
from sklearn.model_selection import GridSearchCV
from sklearn.base import BaseEstimator

# Wrap rheojax model for scikit-learn
class RheoEstimator(BaseEstimator):
    def __init__(self, G_s=1e5, eta_s=1e3):
        self.G_s = G_s
        self.eta_s = eta_s

    def fit(self, X, y):
        self.model_ = Maxwell()
        self.model_.parameters.set_values({
            'G_s': self.G_s,
            'eta_s': self.eta_s
        })
        self.model_.fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)

    def score(self, X, y):
        return self.model_.score(X, y)

# Grid search over parameters
param_grid = {
    'G_s': [1e4, 1e5, 1e6],
    'eta_s': [1e2, 1e3, 1e4]
}

grid_search = GridSearchCV(RheoEstimator(), param_grid, cv=3)
grid_search.fit(X, y)

print(f"Best parameters: {grid_search.best_params_}")
print(f"Best score: {grid_search.best_score_:.4f}")

JAX Integration

Direct use of JAX arrays and operations:

from rheojax.models import Maxwell
import jax.numpy as jnp
import jax

# Create JAX arrays
X_jax = jnp.array(X)
y_jax = jnp.array(y)

maxwell = Maxwell()
maxwell.fit(X_jax, y_jax)  # Works with JAX arrays

# JIT compile predictions
@jax.jit
def predict_jit(freq, G_s, eta_s):
    # Maxwell |G*| with G* = G_s * i*omega*tau / (1 + i*omega*tau)
    tau = eta_s / G_s
    G_star = G_s * (1j * freq * tau) / (1 + 1j * freq * tau)
    return jnp.abs(G_star)

# Vectorize over parameters
G_s_array = jnp.array([1e4, 1e5, 1e6])
eta_s_array = jnp.array([1e2, 1e3, 1e4])

predictions = jax.vmap(lambda g, e: predict_jit(X_jax, g, e))(
    G_s_array, eta_s_array
)

Best Practices

Parameter Initialization

Always provide reasonable initial guesses:

import numpy as np

# Good: data-driven initialization
G_typical = np.median(np.abs(y))
model.parameters.set_value('G_s', G_typical * 0.5)

# Bad: no initialization (uses arbitrary defaults)
# model.fit(X, y)  # May fail or converge slowly

Bounds Setting

Set physical bounds to constrain optimization:

# Good: physical bounds
model.parameters.set_bounds('G_s', min_value=1e2, max_value=1e8)
model.parameters.set_bounds('eta_s', min_value=1e0, max_value=1e6)

# Bad: unbounded (may give non-physical results)
# model.fit(X, y)

Validation

Always validate fitted models:

# Check parameter values
params = model.parameters.to_dict()
for name, value in params.items():
    if value <= 0:
        print(f"Warning: {name} = {value} is non-physical!")

# Check fit quality
r2 = model.score(X, y)
if r2 < 0.9:
    print(f"Warning: Poor fit (R^2 = {r2:.3f})")

# Visual inspection
import matplotlib.pyplot as plt
plt.loglog(X, y, 'o', label='Data')
plt.loglog(X, model.predict(X), '-', label='Model')
plt.legend()
plt.show()

Documentation

Document custom workflows:

import copy

import numpy as np


def fit_with_validation(model, X, y, n_starts=5):
    """Fit model with multi-start optimization and validation.

    Parameters
    ----------
    model : BaseModel
        Model to fit
    X : array
        Independent variable
    y : array
        Dependent variable
    n_starts : int
        Number of random starts

    Returns
    -------
    model : BaseModel
        Best fitted model
    metrics : dict
        Fit quality metrics
    """
    best_score = -np.inf
    best_model = None

    for i in range(n_starts):
        # Random initialization within each parameter's bounds
        for param in model.parameters.parameters.values():
            if param.bounds is not None:
                low, high = param.bounds
                param.value = np.random.uniform(low, high)

        # Fit
        model.fit(X, y)

        # Keep a copy of the best fit: `model` is re-fitted on the next start
        score = model.score(X, y)
        if score > best_score:
            best_score = score
            best_model = copy.deepcopy(model)

    # Calculate metrics
    y_pred = best_model.predict(X)
    metrics = {
        'r2': best_score,
        'rmse': np.sqrt(np.mean((y - y_pred)**2)),
        'parameters': best_model.parameters.to_dict()
    }

    return best_model, metrics

Common Patterns

Pattern 1: Custom Weighted Fitting

from rheojax.models import Maxwell
from rheojax.utils.optimization import nlsq_optimize
import jax
import jax.numpy as jnp

@jax.jit
def weighted_objective(params_array, X, y, weights):
    G_s, eta_s = params_array
    tau = eta_s / G_s
    G_star = G_s * (1j * X * tau) / (1 + 1j * X * tau)  # Maxwell G*
    y_pred = jnp.abs(G_star)
    residuals = (y - y_pred) * weights
    return jnp.sum(residuals**2)

# Emphasize low frequency
weights = 1.0 / (1.0 + X)

maxwell = Maxwell()
result = nlsq_optimize(
    lambda p: weighted_objective(p, X, y, weights),
    maxwell.parameters,
    use_jax=True
)

Pattern 2: Hierarchical Model Selection

from rheojax.models import Maxwell, Zener, FractionalMaxwellGel

# Start simple, increase complexity if needed
models_hierarchy = [Maxwell(), Zener(), FractionalMaxwellGel()]

for model in models_hierarchy:
    model.fit(X, y)
    r2 = model.score(X, y)

    if r2 > 0.95:  # Satisfactory fit
        print(f"Selected model: {type(model).__name__} (R^2 = {r2:.4f})")
        break
else:
    print("Warning: No satisfactory fit found")

Pattern 3: Ensemble Prediction

from rheojax.models import Maxwell, Zener, SpringPot
import numpy as np

# Fit multiple models
models = [Maxwell(), Zener(), SpringPot()]
for model in models:
    model.fit(X, y)

# Ensemble prediction (average)
predictions = np.array([m.predict(X) for m in models])
ensemble_pred = np.mean(predictions, axis=0)

# Weighted ensemble (by R^2; models with negative R^2 get zero weight)
weights = np.clip([m.score(X, y) for m in models], 0, None)
weights /= np.sum(weights)
weighted_ensemble = np.average(predictions, axis=0, weights=weights)

Summary

The Modular API provides complete control over:

  1. Model instantiation and parameter management

  2. Custom optimization algorithms and objectives

  3. Transform composition and data preprocessing

  4. Advanced fitting workflows (multi-start, sequential, hierarchical)

  5. Integration with external libraries (scikit-learn, JAX)

For standard workflows, use the Pipeline API (/user_guide/pipeline_api).

Next Steps

  • /user_guide/pipeline_api - High-level workflow API

  • /user_guide/multi_technique_fitting - Multi-technique fitting with shared parameters

  • Models API - Complete model API reference

  • Core Module (rheojax.core) - Core classes (ParameterSet, RheoData, etc.)