"""Base classes for models and transforms with JAX support.
This module provides abstract base classes that define consistent interfaces
for all models and transforms in the rheojax package, with full JAX support.
"""
from __future__ import annotations
import copy
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any
import numpy as np
from rheojax.core.bayesian import BayesianMixin, BayesianResult
from rheojax.core.deformation_converter import DeformationModeConverter
from rheojax.core.fit_orchestrator import FitOrchestrator
from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import Parameter, ParameterSet
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger
# Module-level logger shared by everything defined in this module.
logger = get_logger(__name__)
# JAX is imported through the project helper instead of a plain ``import jax``
# (the helper enforces float64).
jax, jnp = safe_import_jax()
# Type alias for arrays (accepts both NumPy and JAX arrays).
# Note: ``jnp`` is only bound at runtime by safe_import_jax(), so the alias is
# spelled with np.ndarray for the benefit of static type checkers.
type ArrayLike = np.ndarray
[docs]
class BaseModel(BayesianMixin, ABC):
    """Abstract base class for all rheological models.

    This class defines the standard interface that all models must implement,
    supporting JAX arrays, scikit-learn style APIs (``fit``, ``predict``,
    ``score``, ``get_params``, ``set_params``), and Bayesian inference via
    NumPyro NUTS.

    Subclasses must implement the abstract hooks ``_fit()`` and ``_predict()``.
    ``predict()`` delegates to ``_predict()``; ``fit()`` hands the whole fitting
    workflow to ``FitOrchestrator``.

    All models inherit Bayesian capabilities from BayesianMixin, including:

    - fit_bayesian(): Bayesian parameter estimation using NUTS
    - sample_prior(): Sample from prior distributions
    - get_credible_intervals(): Compute highest density intervals

    The fit() method uses NLSQ optimization by default for fast point estimation,
    which can be used to warm-start Bayesian inference.
    """
[docs]
def __init__(self):
    """Set up an unfitted model with empty parameters and no cached results."""
    logger.debug("Initializing model", model=self.__class__.__name__)

    # Core estimator state: parameter container and fitted flag.
    self.parameters = ParameterSet()
    self.fitted_ = False

    # Cached inference outputs (NLSQ optimization / Bayesian inference);
    # both stay None until the corresponding fit has run.
    self._nlsq_result = None
    self._bayesian_result = None

    # Data kept on the instance for Bayesian inference.
    self.X_data = None
    self.y_data = None

    # Protocol state for Bayesian forwarding.
    self._last_fit_kwargs: dict = {}

    # Deformation-mode bookkeeping; the Poisson ratio starts at 0.5.
    self._deformation_mode: DeformationMode | None = None
    self._poisson_ratio: float = 0.5

    # Ordered cache for closures; starts empty.
    self._closure_cache: OrderedDict = OrderedDict()
@abstractmethod
def _fit(self, X: ArrayLike, y: ArrayLike, **kwargs) -> BaseModel:
    """Fit hook that every concrete model has to provide.

    Args:
        X: Input features
        y: Target values
        **kwargs: Additional fitting options

    Returns:
        self for method chaining
    """
@abstractmethod
def _predict(self, X: ArrayLike, **kwargs) -> ArrayLike:
    """Prediction hook that every concrete model has to provide.

    Args:
        X: Input features
        **kwargs: Additional prediction options

    Returns:
        Predictions
    """
def _standard_nlsq_fit(
    self,
    X: ArrayLike,
    y: ArrayLike,
    model_fn,
    *,
    test_mode=None,
    default_test_mode=None,
    normalize: bool = True,
    **kwargs,
) -> BaseModel:
    """Standard NLSQ fitting pipeline for models with a stateless model_fn.

    Handles: RheoData unpacking, test_mode resolution/caching,
    objective creation, optimization, result validation, fitted_ flag.

    Test mode precedence:

    1. An explicit ``test_mode`` argument always wins, including when ``X``
       is a ``RheoData`` (mirrors ``fit_bayesian``, where an explicit mode
       overrides the RheoData metadata).
    2. Otherwise ``RheoData.test_mode`` when ``X`` is a ``RheoData``.
    3. Otherwise ``"oscillation"`` for complex ``y``.
    4. Otherwise ``default_test_mode``, or ``"relaxation"`` if that is None.

    Args:
        X: Input array (or RheoData)
        y: Target array (ignored if X is RheoData)
        model_fn: Stateless function(x, params) -> prediction.
            Called by the optimizer; must capture test_mode from enclosing scope.
        test_mode: Optional test mode override
        default_test_mode: Default test mode when not provided by data or
            arguments. If None, defaults to 'relaxation'.
        normalize: Whether to normalize the objective (default True)
        **kwargs: Only ``use_jax`` (default True), ``method`` (default
            'auto') and ``max_iter`` (default 1000) are read and forwarded
            to ``nlsq_optimize``; any other key is ignored here.

    Returns:
        self for method chaining

    Raises:
        RuntimeError: If the optimizer reports failure and the final cost is
            non-finite or larger than ``1e6 * len(x)``.
    """
    from rheojax.core.data import RheoData
    from rheojax.logging import log_fit
    from rheojax.utils.optimization import (
        create_least_squares_objective,
        nlsq_optimize,
    )

    # --- 1. Unpack RheoData vs raw arrays and resolve the test mode ---
    if isinstance(X, RheoData):
        x_np = np.asarray(X.x, dtype=float)
        y_raw = np.asarray(X.y)
        # An explicit override takes precedence over the mode carried by
        # the data, so the cached mode agrees with what the caller asked for.
        resolved_test_mode = test_mode if test_mode is not None else X.test_mode
    else:
        x_np = np.asarray(X, dtype=float)
        y_raw = np.asarray(y)
        if test_mode is not None:
            resolved_test_mode = test_mode
        elif np.iscomplexobj(y_raw):
            resolved_test_mode = "oscillation"
        elif default_test_mode is not None:
            resolved_test_mode = default_test_mode
        else:
            resolved_test_mode = "relaxation"

    # Complex targets are kept complex; everything else is cast to float.
    y_np = y_raw.astype(np.complex128 if np.iscomplexobj(y_raw) else float)

    # --- 2. Cache test_mode for Bayesian pipeline ---
    self._test_mode = resolved_test_mode

    # Enum-like modes are logged by name, plain values via str().
    test_mode_str = (
        resolved_test_mode.name
        if hasattr(resolved_test_mode, "name")
        else str(resolved_test_mode)
    )
    # A 0-d input has no leading axis to report.
    data_shape = (int(x_np.shape[0]),) if x_np.ndim > 0 else None
    x_data = jnp.array(x_np)
    y_data = jnp.array(y_np)

    # Read each optimizer option once so the log line and the actual call
    # cannot drift apart.
    optimizer_options = {
        "use_jax": kwargs.get("use_jax", True),
        "method": kwargs.get("method", "auto"),
        "max_iter": kwargs.get("max_iter", 1000),
    }

    # --- 3. Optimize ---
    with log_fit(
        logger,
        model=self.__class__.__name__,
        data_shape=data_shape,
        test_mode=test_mode_str,
    ):
        logger.debug(
            "Creating least squares objective",
            normalize=normalize,
        )
        objective = create_least_squares_objective(
            model_fn, x_data, y_data, normalize=normalize
        )
        logger.debug("Starting NLSQ optimization", **optimizer_options)
        try:
            result = nlsq_optimize(
                objective,
                self.parameters,
                **optimizer_options,
            )
        except Exception as e:
            logger.error(
                "NLSQ optimization raised exception",
                error_type=type(e).__name__,
                error_message=str(e),
                exc_info=True,
            )
            raise

        # --- 4. Validate ---
        if not result.success:
            # Only a non-finite or very large final cost is treated as a
            # hard failure; otherwise the fit is kept with a warning.
            if not np.isfinite(result.fun) or result.fun > 1e6 * len(x_np):
                logger.error(
                    "Optimization failed",
                    message=result.message,
                    iterations=getattr(result, "nit", None),
                )
                raise RuntimeError(
                    f"Optimization failed: {result.message}. "
                    f"Try adjusting initial values, bounds, or max_iter."
                )
            else:
                logger.warning(
                    "Optimization did not fully converge",
                    message=result.message,
                    model=self.__class__.__name__,
                )

        self._nlsq_result = result
        self.fitted_ = True
        logger.debug(
            "Optimization completed successfully",
            iterations=getattr(result, "nit", None),
            final_cost=getattr(result, "fun", None),
        )
    return self
def _detect_optimization_strategy(
    self,
    X: ArrayLike,
    use_log_residuals: bool | None,
    use_multi_start: bool | None,
    n_starts: int,
) -> tuple[bool, bool]:
    """Fill in unset optimization switches from the span of the input data.

    Explicit (non-None) user settings are never changed. Unset switches are
    derived from the number of decades covered by ``X``: log-residuals above
    8 decades, multi-start above 10 decades. If range detection fails for any
    reason, unset switches fall back to False.

    Args:
        X: Input data (array or RheoData)
        use_log_residuals: User-specified setting or None for auto-detect
        use_multi_start: User-specified setting or None for auto-detect
        n_starts: Number of starts for multi-start optimization (logged only)

    Returns:
        Tuple of (use_log_residuals, use_multi_start) with defaults applied
    """
    # Nothing to detect when the caller pinned both switches.
    if not (use_log_residuals is None or use_multi_start is None):
        return use_log_residuals, use_multi_start

    model_name = self.__class__.__name__
    try:
        from rheojax.core.data import RheoData
        from rheojax.utils.data_quality import detect_data_range_decades

        decades = detect_data_range_decades(X.x if isinstance(X, RheoData) else X)

        if use_log_residuals is None:
            use_log_residuals = bool(decades > 8.0)
            if use_log_residuals:
                logger.info(
                    "Auto-enabling log-residuals for wide range",
                    model=model_name,
                    decades=f"{decades:.1f}",
                )

        if use_multi_start is None:
            use_multi_start = bool(decades > 10.0)
            if use_multi_start:
                logger.info(
                    "Auto-enabling multi-start optimization for very wide range",
                    model=model_name,
                    decades=f"{decades:.1f}",
                    n_starts=n_starts,
                )
    except Exception as e:
        # Best effort: detection problems must never block a fit.
        logger.debug(
            "Range detection failed",
            model=model_name,
            error=str(e),
        )
        if use_log_residuals is None:
            use_log_residuals = False
        if use_multi_start is None:
            use_multi_start = False

    return use_log_residuals, use_multi_start
def _check_compatibility(
    self,
    X: ArrayLike,
    y: ArrayLike,
    test_mode: str | None,
) -> dict | None:
    """Run the model-data compatibility check on a best-effort basis.

    ``X``/``y`` are forwarded as ``t``/``G_t`` for ``'relaxation'`` and as
    ``omega``/``G_star`` for ``'oscillation'``; for any other mode all four
    data arguments are None.

    Args:
        X: Input data
        y: Target data
        test_mode: Test mode ('relaxation', 'oscillation', etc.)

    Returns:
        Compatibility dict if check succeeds, None otherwise
    """
    try:
        from rheojax.utils.compatibility import check_model_compatibility

        relaxation = test_mode == "relaxation"
        oscillation = test_mode == "oscillation"
        return check_model_compatibility(
            model=self,
            t=X if relaxation else None,
            G_t=y if relaxation else None,
            omega=X if oscillation else None,
            G_star=y if oscillation else None,
            test_mode=test_mode,
        )
    except Exception as exc:
        # Any failure (including a missing helper module) is only logged.
        logger.debug(
            "Compatibility check failed",
            model=self.__class__.__name__,
            error=str(exc),
        )
    return None
def _make_error_result(self, test_mode: str | None, error: Exception):
    """Build a :class:`FitResult` that records a failed fit attempt.

    The result carries the current parameter values and units, no
    optimization result, and the error text and type in ``metadata``.

    Args:
        test_mode: Test mode of the failed fit; stored as ``protocol``.
        error: Exception raised by the fit.

    Returns:
        FitResult describing the failure.
    """
    from rheojax.core.fit_result import FitResult

    param_set = self.parameters
    names = list(param_set.keys())
    values = {name: param_set.get_value(name) for name in names}
    # A missing or empty ``units`` attribute is normalised to "".
    units = {name: getattr(param_set[name], "units", "") or "" for name in names}
    class_name = self.__class__.__name__

    return FitResult(
        model_name=getattr(self, "_registry_name", class_name),
        model_class_name=class_name,
        protocol=test_mode,
        params=values,
        params_units=units,
        n_params=len(names),
        optimization_result=None,
        metadata={
            "error": str(error),
            "error_type": type(error).__name__,
        },
    )
def _enhance_error_with_compatibility(
    self,
    error: RuntimeError,
    X: ArrayLike,
    y: ArrayLike,
    test_mode: str | None,
) -> RuntimeError:
    """Append model-data compatibility details to an optimization error.

    Only errors whose text contains "Optimization failed" or
    "did not converge" are considered. The original error is returned
    unchanged when the compatibility check is unavailable, reports the model
    as compatible, or the message cannot be formatted.

    Args:
        error: Original RuntimeError
        X: Input data
        y: Target data
        test_mode: Test mode

    Returns:
        Enhanced RuntimeError or original if enhancement fails
    """
    error_msg = str(error)
    markers = ("Optimization failed", "did not converge")
    if not any(marker in error_msg for marker in markers):
        return error

    compatibility = self._check_compatibility(X, y, test_mode)
    if compatibility is None:
        return error
    # A missing "compatible" key counts as compatible.
    if compatibility.get("compatible", True):
        return error

    try:
        from rheojax.utils.compatibility import format_compatibility_message

        compat_msg = format_compatibility_message(compatibility)
        note = (
            "Note: This model may not be appropriate for your data. "
            "In model comparison pipelines, it's normal for some models "
            "to fail when their underlying physics doesn't match the material behavior."
        )
        return RuntimeError(
            f"{error_msg}\n\n"
            "Model-data compatibility issue detected:\n"
            f"{compat_msg}\n\n"
            f"{note}"
        )
    except Exception as exc:
        logger.debug(
            "Failed to enhance error with compatibility info",
            error=str(exc),
        )
    return error
[docs]
def fit(
    self,
    X: ArrayLike,
    y: ArrayLike,
    method: str = "nlsq",
    check_compatibility: bool = False,
    use_log_residuals: bool | None = None,
    use_multi_start: bool | None = None,
    n_starts: int = 5,
    perturb_factor: float = 0.3,
    deformation_mode: str | DeformationMode | None = None,
    poisson_ratio: float = 0.5,
    auto_init: bool = False,
    return_result: bool = False,
    check_physics: bool = False,
    uncertainty: str | None = None,
    **kwargs,
) -> BaseModel | Any:
    """Fit the model to data using NLSQ optimization.

    This method uses NLSQ (GPU-accelerated nonlinear least squares) by default
    for fast point estimation. The optimization result is stored for potential
    warm-starting of Bayesian inference. The whole workflow is delegated to
    ``FitOrchestrator().execute``; every argument below is forwarded to it
    unchanged.

    For very wide frequency ranges (>10 decades), multi-start optimization is
    automatically enabled to escape local minima.

    Args:
        X: Input features
        y: Target values
        method: Optimization method ('nlsq' by default for compatibility)
        check_compatibility: Whether to check model-data compatibility before
            fitting. If True, warns when model may not be appropriate for data.
            Default is False for backward compatibility.
        use_log_residuals: Whether to use log-space residuals for fitting.
            Recommended for wide frequency ranges (>8 decades) to prevent
            optimizer bias. If None (default), automatically detected based
            on data range. Explicit True/False overrides auto-detection.
        use_multi_start: Whether to use multi-start optimization to escape
            local minima. Recommended for very wide ranges (>10 decades).
            If None (default), automatically enabled for >10 decades.
        n_starts: Number of random starts for multi-start optimization (default: 5)
        perturb_factor: Perturbation magnitude for multi-start random starts (default: 0.3).
            Parameters are perturbed by ± perturb_factor * (value or range).
            Larger values (0.7-0.9) explore wider parameter space.
        deformation_mode: Deformation mode of the data, forwarded to the
            orchestrator (default: None). Presumably used to convert
            tensile data to shear before fitting, as ``fit_bayesian`` does
            — confirm against ``FitOrchestrator``.
        poisson_ratio: Poisson's ratio forwarded together with
            ``deformation_mode`` (default: 0.5).
        auto_init: If True, calls ``auto_p0()`` to estimate initial parameters
            from data before running the optimizer (default: False).
        return_result: If True, returns a ``FitResult`` instead of ``self``.
            This intentionally breaks method chaining for workflows that need
            structured result objects (default: False).
        check_physics: If True, runs post-fit physics validation and emits
            ``RheoJaxPhysicsWarning`` for any violations (default: False).
        uncertainty: Post-fit uncertainty method. ``"hessian"`` for fast
            Cramér-Rao bounds, ``"bootstrap"`` for residual bootstrap CIs,
            or ``None`` to skip (default: None).
        **kwargs: Additional fitting options passed to _fit()

    Returns:
        ``self`` for method chaining (default), or ``FitResult`` if
        ``return_result=True``.

    Example:
        >>> model = Maxwell()
        >>> model.fit(t, G_data)  # Uses NLSQ by default
        >>> model.fit(t, G_data, method='nlsq', max_iter=1000)
        >>> model.fit(t, G_data, check_compatibility=True)  # Check compatibility
        >>> model.fit(omega, G_star, use_log_residuals=True)  # Force log-residuals
        >>> model.fit(mastercurve, None, use_multi_start=True, n_starts=10)  # Multi-start
        >>> result = model.fit(t, G_data, return_result=True)  # Structured result
        >>> result = model.fit(t, G_data, auto_init=True, check_physics=True,
        ...                    return_result=True)  # Full pipeline
    """
    # Named options are gathered first; they can never collide with
    # ``kwargs`` because each one is bound to an explicit parameter.
    named_options = {
        "method": method,
        "check_compatibility": check_compatibility,
        "use_log_residuals": use_log_residuals,
        "use_multi_start": use_multi_start,
        "n_starts": n_starts,
        "perturb_factor": perturb_factor,
        "deformation_mode": deformation_mode,
        "poisson_ratio": poisson_ratio,
        "auto_init": auto_init,
        "return_result": return_result,
        "check_physics": check_physics,
        "uncertainty": uncertainty,
    }
    orchestrator = FitOrchestrator()
    return orchestrator.execute(self, X, y, **named_options, **kwargs)
[docs]
def precompile(
    self,
    test_mode: str = "relaxation",
    X: ArrayLike | None = None,
    y: ArrayLike | None = None,
) -> float:
    """Precompile NLSQ residual functions to eliminate JIT cold-start.

    Triggers JIT compilation by running a minimal fit (``max_iter=1``)
    with dummy data. Afterwards the parameter values, ``fitted_``,
    ``_nlsq_result``, ``_test_mode`` and ``_last_fit_kwargs`` are restored,
    so the throw-away fit is not visible through ``popt_``, ``pcov_`` or
    ``get_nlsq_result()``. Restoration runs in a ``finally`` block and
    therefore also happens if the fit is interrupted.

    This is useful for interactive sessions or benchmarks where the
    ~870ms first-fit JIT overhead should be excluded.

    Args:
        test_mode: Test mode to precompile for (default: 'relaxation').
        X: Optional input data for shape inference. If None, uses a
            10-point logspace array.
        y: Optional output data. If None, generates ones matching X.

    Returns:
        Compilation time in seconds.

    Example:
        >>> model = Maxwell()
        >>> t = model.precompile(test_mode='relaxation')
        >>> print(f"Compiled in {t:.2f}s")
        >>> model.fit(X, y)  # No JIT overhead
    """
    import time

    logger.info("Starting NLSQ precompilation", model=self.__class__.__name__)

    # Save current state (params, fitted, NLSQ result, test_mode, fit kwargs)
    saved_params = {
        name: self.parameters.get_value(name) for name in self.parameters
    }
    saved_fitted = self.fitted_
    # The standard NLSQ pipeline stores its result on the instance, so the
    # dummy fit would otherwise replace a genuine earlier result.
    saved_nlsq_result = getattr(self, "_nlsq_result", None)
    _had_test_mode = hasattr(self, "_test_mode")  # BASE-003: track if attr existed
    saved_test_mode = getattr(self, "_test_mode", None)
    _raw = getattr(self, "_last_fit_kwargs", None)
    saved_last_fit_kwargs = copy.deepcopy(_raw) if _raw is not None else None

    # Generate dummy data if not provided
    if X is None:
        X = np.logspace(-2, 2, 10, dtype=np.float64)
    X_arr = np.asarray(X, dtype=np.float64)
    if y is None:
        y = np.ones_like(X_arr, dtype=np.float64)

    start_time = time.perf_counter()
    try:
        self._fit(X_arr, y, test_mode=test_mode, max_iter=1)
    except Exception as e:
        # A failed dummy fit is tolerated: the goal is compilation only.
        logger.warning(
            "NLSQ precompilation fit failed — JIT may still have compiled",
            error=str(e),
        )
    finally:
        compile_time = time.perf_counter() - start_time

        # Restore original state
        for name, value in saved_params.items():
            if value is not None:
                self.parameters.set_value(name, value)
            else:
                # Unset values are written back directly on the Parameter.
                self.parameters._parameters[name].value = None
        self.fitted_ = saved_fitted
        self._nlsq_result = saved_nlsq_result
        # BASE-003: only restore _test_mode if the attribute existed before
        if _had_test_mode:
            self._test_mode = saved_test_mode
        elif hasattr(self, "_test_mode"):
            del self._test_mode
        # Restore original _last_fit_kwargs (may be None or empty dict {})
        self._last_fit_kwargs = saved_last_fit_kwargs

    logger.info(
        "NLSQ precompilation completed",
        compile_time_seconds=compile_time,
        model=self.__class__.__name__,
    )
    return compile_time
[docs]
def fit_bayesian(  # extends BayesianMixin signature with DMTA params
    self,
    X: ArrayLike,
    y: ArrayLike | None = None,
    num_warmup: int = 1000,
    num_samples: int = 2000,
    num_chains: int = 4,
    initial_values: dict[str, float] | None = None,
    test_mode: str | None = None,
    seed: int | None = None,
    deformation_mode: str | DeformationMode | None = None,
    poisson_ratio: float = 0.5,
    **nuts_kwargs,
) -> BayesianResult:
    """Perform Bayesian inference using NumPyro NUTS sampler.

    This method delegates to BayesianMixin.fit_bayesian() to run NUTS sampling
    for Bayesian parameter estimation. If initial_values is not provided and
    the model has been previously fitted with fit(), the NLSQ point estimates
    are automatically used for warm-starting.

    Before delegating, this wrapper unpacks RheoData, resolves the
    deformation mode, converts ``y`` to shear when a mode is resolved, and
    stores the data on ``self.X_data`` / ``self.y_data``. The returned
    result is also kept on ``self._bayesian_result``.

    Multi-chain sampling is enabled by default (num_chains=4) to provide
    reliable convergence diagnostics (R-hat, ESS) and parallel execution
    on multi-GPU systems.

    Args:
        X: Independent variable data (input features) or RheoData object
        y: Dependent variable data (observations to fit). If X is RheoData
            and y is None, y is extracted from X.
        num_warmup: Number of warmup/burn-in iterations (default: 1000)
        num_samples: Number of posterior samples per chain (default: 2000)
        num_chains: Number of MCMC chains (default: 4). Multiple chains
            enable proper R-hat computation and parallel execution.
            Chain method is auto-selected: 'parallel' on multi-GPU,
            'vectorized' on single GPU/CPU.
        initial_values: Optional dict of initial parameter values for
            warm-start. If None and model is fitted, uses the current
            (non-None) parameter values.
        test_mode: Explicit test mode (e.g., 'relaxation', 'creep', 'oscillation').
            If None and X is RheoData, taken from ``X.test_mode``; otherwise
            passed on as None. Overrides RheoData metadata if provided.
        seed: Random seed for reproducibility. If None, uses seed=0 for
            deterministic results. Set to different values for independent runs.
        deformation_mode: Deformation mode of the supplied data. If None, it
            is taken from ``X.metadata['deformation_mode']`` when X is
            RheoData, and otherwise from the mode stored on the model (a
            warning is logged in that case). When the converter resolves a
            mode, ``y`` is passed through
            ``DeformationModeConverter.convert_to_shear`` before sampling.
        poisson_ratio: Poisson's ratio handed to the converter (default 0.5).
            Replaced by the stored ratio whenever ``deformation_mode`` falls
            back to the stored mode.
        **nuts_kwargs: Additional arguments passed to NUTS sampler
            (e.g., target_accept_prob, chain_method)

    Returns:
        BayesianResult containing posterior samples, summary statistics,
        and convergence diagnostics (R-hat, ESS, divergences)

    Raises:
        Exception: Any error raised by the underlying sampler is logged and
            re-raised unchanged.

    Example:
        >>> model = Maxwell()
        >>> # Warm-start from NLSQ with explicit mode
        >>> model.fit(t, G_data, test_mode='relaxation')  # NLSQ optimization
        >>> result = model.fit_bayesian(t, G_data, test_mode='relaxation')
        >>>
        >>> # RheoData with embedded mode (recommended)
        >>> rheo_data = RheoData(x=omega, y=G_star, metadata={'test_mode': 'oscillation'})
        >>> result = model.fit_bayesian(rheo_data)
        >>>
        >>> # Or provide explicit initial values
        >>> result = model.fit_bayesian(
        ...     t, G_data,
        ...     initial_values={'G0': 1e5, 'eta': 1e3},
        ...     test_mode='creep'
        ... )
    """
    # Get data shape for logging (falls back to len(), then to (1,), so
    # shapeless inputs cannot make the log call fail)
    _shape = getattr(X, "shape", None)
    data_shape = (
        _shape
        if _shape is not None
        else (len(X) if hasattr(X, "__len__") else (1,))
    )
    logger.debug(
        "Entering fit_bayesian",
        model=self.__class__.__name__,
        data_shape=data_shape,
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        test_mode=test_mode,
    )
    # --- RheoData unpacking ---
    # Explicit arguments win; the RheoData only fills in what is still None.
    from rheojax.core.data import RheoData

    if isinstance(X, RheoData):
        if y is None:
            y = X.y
        if test_mode is None:
            test_mode = X.test_mode
        if deformation_mode is None:
            deformation_mode = X.metadata.get("deformation_mode", None)
        # X is rebound last because the lines above still read from the
        # RheoData object.
        X = jnp.array(X.x)
    # --- Deformation mode: fall back to prior fit() if not given ---
    if deformation_mode is None:
        deformation_mode = getattr(self, "_deformation_mode", None)
        if deformation_mode is not None:
            # The stored ratio belongs to the stored mode, so it replaces
            # the argument here.
            poisson_ratio = getattr(self, "_poisson_ratio", poisson_ratio)
            logger.warning(
                "fit_bayesian() using deformation_mode='%s' from prior "
                "fit(). Pass deformation_mode explicitly if this is not "
                "intended.",
                str(deformation_mode),
            )
    # --- Convert E* -> G* via shared converter ---
    resolved_dm = DeformationModeConverter.resolve_deformation_mode(
        deformation_mode
    )
    if resolved_dm is not None:
        # Remember the mode and ratio on the instance for later calls.
        self._deformation_mode = resolved_dm
        self._poisson_ratio = poisson_ratio
        y = DeformationModeConverter.convert_to_shear(
            y, resolved_dm, poisson_ratio, self.__class__.__name__
        )
    # Store data for model_function access
    self.X_data = X
    self.y_data = y
    # Defensive unwrap in case either stored value is still a RheoData
    # (X has already been unpacked above).
    from rheojax.core.data import RheoData as _RheoData

    if isinstance(self.X_data, _RheoData):
        self.X_data = self.X_data.x
    if isinstance(self.y_data, _RheoData):
        self.y_data = self.y_data.y
    # Auto warm-start from NLSQ if available and no explicit initial values
    if initial_values is None and self.fitted_:
        # Extract current parameter values as initial values, filtering out None
        initial_values = {
            name: v
            for name in self.parameters
            if (v := self.parameters.get_value(name)) is not None
        }
        logger.debug(
            "Using NLSQ warm-start for Bayesian inference",
            model=self.__class__.__name__,
            initial_values=initial_values,
        )
    # Call BayesianMixin implementation with multi-chain parallelization.
    # deformation_mode / poisson_ratio are consumed here and not forwarded.
    try:
        result = super().fit_bayesian(
            X,
            y,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            initial_values=initial_values,
            test_mode=test_mode,
            seed=seed,
            **nuts_kwargs,
        )
        # Store result for later access
        self._bayesian_result = result
        # Log completion with diagnostics (either may be absent)
        r_hat = result.diagnostics.get("r_hat") if result.diagnostics else None
        ess = result.diagnostics.get("ess") if result.diagnostics else None
        logger.info(
            "Bayesian fit completed",
            model=self.__class__.__name__,
            num_warmup=num_warmup,
            num_samples=num_samples,
            num_chains=num_chains,
            r_hat=r_hat,
            ess=ess,
        )
        logger.debug(
            "Exiting fit_bayesian",
            model=self.__class__.__name__,
            diagnostics=result.diagnostics,
        )
        return result
    except Exception as e:
        # Log with traceback, then let the caller handle the failure.
        logger.error(
            "Bayesian fit failed",
            model=self.__class__.__name__,
            error=str(e),
            exc_info=True,
        )
        raise
[docs]
def predict(
    self,
    X: ArrayLike,
    test_mode: str | None = None,
    deformation_mode: str | DeformationMode | None = None,
    poisson_ratio: float | None = None,
    **kwargs,
) -> ArrayLike:
    """Make predictions.

    The call is read-only with respect to the model: ``fitted_`` is never
    changed and a temporary ``test_mode`` override is undone afterwards.

    Args:
        X: Input features. Scalars and 0-d arrays are accepted as far as this
            wrapper is concerned (shape logging no longer fails on them).
        test_mode: Optional test mode ('oscillation', 'relaxation', 'creep', 'flow').
            If provided, it is passed to ``_predict`` via ``kwargs`` and, when
            the model already has a ``_test_mode`` attribute, temporarily
            assigned to it for the duration of the call.
            Useful for data generation without fitting.
        deformation_mode: Optional deformation mode for output conversion.
            If None, uses the mode stored from fit(). If tensile, converts
            G* predictions to E* space.
        poisson_ratio: Poisson's ratio for conversion. If None, uses value
            stored from fit() (default 0.5).
        **kwargs: Additional arguments passed to the internal _predict method.

    Returns:
        Model predictions (in E* space if deformation_mode is tensile)

    Raises:
        Exception: Any error from ``_predict`` or the conversion is logged
            and re-raised unchanged.
    """
    # Shape is only needed for logging; it must never make predict() fail.
    # ``shape or (len(X),)`` raised TypeError for 0-d arrays (shape ``()`` is
    # falsy) and for plain scalars.
    _shape = getattr(X, "shape", None)
    if _shape is not None:
        x_shape = _shape
    elif hasattr(X, "__len__"):
        x_shape = (len(X),)
    else:
        x_shape = (1,)
    logger.debug(
        "Predict called",
        model=self.__class__.__name__,
        x_shape=x_shape,
        test_mode=test_mode,
        kwargs=kwargs,
    )

    # Check if parameters are set manually (allow predict without fit)
    # but do NOT permanently mutate self.fitted_ — predict() must be read-only
    _effectively_fitted = self.fitted_
    if not _effectively_fitted and len(self.parameters) > 0:
        if not any(p.value is None for p in self.parameters._parameters.values()):
            _effectively_fitted = True
            logger.debug(
                "Parameters set manually — proceeding with predict "
                "(model not marked as fitted)",
                model=self.__class__.__name__,
            )

    # Set test_mode if provided (for data generation without fitting)
    _had_test_mode = hasattr(self, "_test_mode")
    _old_test_mode = getattr(self, "_test_mode", None)
    if test_mode is not None:
        if _had_test_mode:
            self._test_mode = test_mode
        # Pass test_mode via kwargs for models that read it from kwargs
        kwargs["test_mode"] = test_mode

    try:
        # ADR-004: All _predict() signatures now accept **kwargs,
        # so we can call directly without try/except/retry.
        result = self._predict(X, **kwargs)

        # Convert G* -> E* if tensile deformation mode. Stored values are
        # read with defaults, as fit_bayesian() does.
        dm = DeformationModeConverter.resolve_deformation_mode(
            deformation_mode
            if deformation_mode is not None
            else getattr(self, "_deformation_mode", None)
        )
        nu = (
            poisson_ratio
            if poisson_ratio is not None
            else getattr(self, "_poisson_ratio", 0.5)
        )
        result = DeformationModeConverter.convert_from_shear(
            result, dm, nu, self.__class__.__name__
        )
        logger.debug(
            "Predict completed",
            model=self.__class__.__name__,
            output_shape=getattr(result, "shape", None),
        )
        return result
    except Exception as e:
        logger.error(
            "Predict failed",
            model=self.__class__.__name__,
            error=str(e),
            exc_info=True,
        )
        raise
    finally:
        # R10-BASE-002: Restore original _test_mode to avoid side effects.
        # If _test_mode was created as a side effect during _predict(), delete it.
        if test_mode is not None:
            if _had_test_mode:
                self._test_mode = _old_test_mode
            elif hasattr(self, "_test_mode"):
                del self._test_mode
[docs]
def fit_predict(self, X: ArrayLike, y: ArrayLike, **kwargs) -> ArrayLike:
    """Fit model and return predictions.

    Args:
        X: Input features
        y: Target values
        **kwargs: Additional fitting options, forwarded to ``fit()`` only

    Returns:
        Model predictions on training data
    """
    # Shape is only needed for logging. ``shape or (len(X),)`` raised
    # TypeError for 0-d arrays (shape ``()`` is falsy) and for inputs
    # without ``__len__``, so resolve it the way fit_bayesian() does.
    _shape = getattr(X, "shape", None)
    if _shape is not None:
        data_shape = _shape
    elif hasattr(X, "__len__"):
        data_shape = (len(X),)
    else:
        data_shape = (1,)
    logger.debug(
        "fit_predict called",
        model=self.__class__.__name__,
        data_shape=data_shape,
    )
    # The return value of fit() is deliberately ignored so that options such
    # as return_result=True cannot change what this method returns.
    self.fit(X, y, **kwargs)
    return self.predict(X)
[docs]
def get_nlsq_result(self):
    """Return the NLSQ optimization result kept on the model.

    Returns:
        The stored result object from the NLSQ fit, or None if no result
        has been stored.

    Example:
        >>> model.fit(t, G_data)
        >>> result = model.get_nlsq_result()
        >>> if result is not None:
        ...     print(f"Converged: {result.success}")
    """
    stored_result = self._nlsq_result
    return stored_result
@property
def pcov_(self):
    """Parameter covariance matrix from NLSQ fit.

    Returns:
        The ``pcov`` attribute of the stored NLSQ result (documented as an
        ndarray of shape (n_params, n_params)), or None if no result is
        stored.
    """
    # Compare against None explicitly: a result object that happens to be
    # falsy must not be mistaken for "not fitted".
    result = self._nlsq_result
    return result.pcov if result is not None else None
@property
def popt_(self):
    """Optimal parameter values from NLSQ fit.

    Returns:
        The ``x`` attribute of the stored NLSQ result (documented as an
        ndarray of shape (n_params,)), or None if no result is stored.
    """
    # Compare against None explicitly: a result object that happens to be
    # falsy must not be mistaken for "not fitted".
    result = self._nlsq_result
    return result.x if result is not None else None
[docs]
def get_parameter_uncertainties(self):
"""Get standard errors for fitted parameters from NLSQ covariance.
Returns:
dict of {param_name: std_error}, or None if covariance unavailable
"""
if self._nlsq_result is None or self._nlsq_result.pcov is None:
return None
std_errors = self._nlsq_result.get_parameter_uncertainties()
if std_errors is None:
return None
param_names = list(self.parameters.keys())
return dict(zip(param_names, std_errors, strict=True))
[docs]
def get_bayesian_result(self) -> BayesianResult | None:
    """Return the Bayesian inference result kept on the model.

    Returns:
        BayesianResult from fit_bayesian(), or None if not run

    Example:
        >>> model.fit_bayesian(t, G_data)
        >>> result = model.get_bayesian_result()
        >>> print(result.diagnostics['r_hat'])
    """
    stored_result = self._bayesian_result
    return stored_result
[docs]
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """Get model parameters.

    Args:
        deep: Accepted for scikit-learn API compatibility; not used by this
            implementation.

    Returns:
        Dictionary of parameter names and values; empty when the model has
        no ``parameters`` attribute or the parameter set is empty.
    """
    # Guard clause: nothing to report without a populated parameter set.
    if not hasattr(self, "parameters") or not len(self.parameters) > 0:
        return {}

    values = {}
    for name in self.parameters.keys():
        values[name] = self.parameters[name].value
    return values
[docs]
def set_params(self, **params) -> BaseModel:
    """Set model parameters.

    Names that are not present in ``self.parameters`` are skipped, as
    before, but are now reported with a warning so that a misspelled
    parameter name does not go unnoticed. No exception is raised for them.

    Args:
        **params: Parameter names and values

    Returns:
        self for method chaining
    """
    logger.debug(
        "set_params called",
        model=self.__class__.__name__,
        params=params,
    )
    ignored = []
    if hasattr(self, "parameters"):
        for name, value in params.items():
            if name in self.parameters:
                self.parameters.set_value(name, value)
            else:
                ignored.append(name)
    else:
        # Without a parameter container nothing can be applied.
        ignored.extend(params)

    if ignored:
        logger.warning(
            "set_params ignored unknown parameter names",
            model=self.__class__.__name__,
            unknown=ignored,
        )
    return self
[docs]
def score(self, X: ArrayLike, y: ArrayLike) -> float:
    """Compute model score (R² by default).

    Predictions come from ``self.predict(X)``. For complex data the sums of
    squares are built from the magnitudes of the differences.

    Args:
        X: Input features
        y: True target values

    Returns:
        R² coefficient as a float, or NaN when the total sum of squares is
        zero (constant ``y``) or the computed R² is NaN.
    """
    y_hat = self.predict(X)

    # JAX arrays are copied into NumPy arrays before scoring.
    if isinstance(y_hat, jnp.ndarray):
        y_hat = np.array(y_hat)
    if isinstance(y, jnp.ndarray):
        y = np.array(y)

    residual = y - y_hat
    centered = y - np.mean(y)
    # For complex data (e.g., oscillatory shear), use magnitudes.
    if np.iscomplexobj(y) or np.iscomplexobj(y_hat):
        residual = np.abs(residual)
        centered = np.abs(centered)
    ss_res = np.sum(residual**2)
    ss_tot = np.sum(centered**2)

    if ss_tot == 0:
        # All y values are constant — R² is undefined
        logger.warning("R² undefined for constant data (ss_tot=0)")
        return np.nan

    r2 = 1 - (ss_res / ss_tot)
    if np.isnan(r2):
        # e.g. predictions contain NaN/Inf
        logger.warning(
            "R² is NaN — predictions may contain NaN/Inf values",
            model=self.__class__.__name__,
        )
        return np.nan
    return float(np.real(r2))
[docs]
def to_dict(self) -> dict[str, Any]:
    """Serialize model to dictionary.

    Returns:
        Dictionary with the keys ``"class"`` (class name), ``"parameters"``
        (``self.parameters.to_dict()``, or ``{}`` when the model has no
        parameters) and ``"fitted"`` (the ``fitted_`` flag).
    """
    has_parameters = hasattr(self, "parameters") and len(self.parameters) > 0
    if has_parameters:
        serialized_parameters = self.parameters.to_dict()
    else:
        serialized_parameters = {}

    payload: dict[str, Any] = {"class": self.__class__.__name__}
    payload["parameters"] = serialized_parameters
    payload["fitted"] = self.fitted_
    return payload
[docs]
@classmethod
def from_dict(cls, data: dict[str, Any]) -> BaseModel:
    """Create model from dictionary.

    The model is built with ``cls()``. A ``"parameters"`` entry, if present,
    replaces the model's parameter set via ``ParameterSet.from_dict``; the
    ``"fitted"`` entry (default False) sets ``fitted_``. Other keys are not
    read.

    Args:
        data: Dictionary representation

    Returns:
        Model instance
    """
    instance = cls()
    if "parameters" in data:
        restored = ParameterSet.from_dict(data["parameters"])
        instance.parameters = restored
    instance.fitted_ = data.get("fitted", False)

    logger.debug(
        "Model created from dict",
        model=cls.__name__,
        fitted=instance.fitted_,
    )
    return instance
[docs]
def __repr__(self) -> str:
    """Return ``ClassName(name=value, ...)`` built from ``get_params()``."""
    pairs = [f"{name}={value}" for name, value in self.get_params().items()]
    joined = ", ".join(pairs)
    return f"{self.__class__.__name__}({joined})"
# Public API of this module: names exported by ``from <module> import *``.
# NOTE(review): "BaseTransform" and "TransformPipeline" are not defined in the
# portion of the file visible here — confirm they are defined elsewhere in
# this module, otherwise a star-import will fail on them.
__all__ = [
    "BaseModel",
    "BaseTransform",
    "TransformPipeline",
    "Parameter",
    "ParameterSet",
]