Source code for rheojax.utils.optimization

"""Optimization utilities for parameter fitting using NLSQ.

This module provides GPU-accelerated optimization using the NLSQ package
(https://github.com/imewei/NLSQ). NLSQ provides 5-270x speedup over scipy
through JAX JIT compilation and automatic differentiation.

Critical: This module imports NLSQ, which must be imported before JAX to
enable float64 precision mode. The rheojax package handles this automatically
in __init__.py.

Example:
    >>> from rheojax.core.parameters import ParameterSet
    >>> from rheojax.utils.optimization import nlsq_optimize
    >>>
    >>> # Set up parameters
    >>> params = ParameterSet()
    >>> params.add("x", value=1.0, bounds=(0, 10))
    >>>
    >>> # Define objective function
    >>> def objective(values):
    ...     x = values[0]
    ...     return (x - 5.0) ** 2
    >>>
    >>> # Optimize
    >>> result = nlsq_optimize(objective, params, use_jax=True)
    >>> print(f"Optimal x: {result.x[0]:.4f}")
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

import nlsq
import numpy as np

from rheojax.core.jax_config import safe_import_jax
from rheojax.core.parameters import ParameterSet
from rheojax.logging import get_logger

logger = get_logger(__name__)

# Safe JAX import (verifies NLSQ was imported first)
jax, jnp = safe_import_jax()


type ArrayLike = np.ndarray | list | float

# OPT-002/OPT-020: Module-level frozenset of RheoJAX-specific kwargs that must
# be filtered before forwarding to NLSQ/SciPy optimizers (prevents TypeError).
_RHEOJAX_RESERVED_KWARGS: frozenset[str] = frozenset(
    {
        "test_mode",
        "deformation_mode",
        "poisson_ratio",
        "seed",
        "method",
        "num_warmup",
        "num_samples",
        "num_chains",
        "gamma_dot",
        "sigma",
        "sigma_applied",
        "gamma_0",
        "omega_laos",
        "return_components",
        "return_full",
        "t_wait",
        "n_cycles",
        # FIKH/FMLIKH-specific protocol kwargs (F-003)
        "strain",
        "sigma_0",
        "T_init",
        "T",
        # Additional protocol kwargs (R3-U-005)
        "omega",
        "lam_init",
        "lam_0",
        "sigma_init",
        "points_per_cycle",
    }
)


[docs] def make_fd_differentiable( fn: Callable, eps: float = 1e-7, ) -> Callable: """Wrap a function with a finite-difference custom JVP. This enables ``jax.jacfwd`` (forward-mode AD) for functions that cannot be traced by JAX's autodiff — e.g. diffrax ODE solvers which use ``custom_vjp`` and are therefore incompatible with ``jacfwd``. The wrapper computes JVPs via central differences: ``(f(x+εv) - f(x-εv)) / 2ε``. When combined with ``jax.jacfwd``, this effectively computes the full Jacobian via ``vmap``'d perturbations in a **single batched XLA call** — much faster than scipy's sequential finite differences. Args: fn: Function ``(x_data, params) -> predictions``. Only the ``params`` argument (index 1) is differentiated; ``x_data`` passes through. eps: Perturbation size for central differences. Returns: A function with identical signature but a custom JVP rule for ``params``. Example:: # Before: NLSQ fails with TypeError on ODE models objective = create_least_squares_objective(model_fn, x, y) # After: finite-difference JVP makes it NLSQ-compatible objective = create_least_squares_objective( make_fd_differentiable(model_fn), x, y ) """ @jax.custom_jvp def wrapped(x_data, params): return fn(x_data, params) @wrapped.defjvp def wrapped_jvp(primals, tangents): x_data, params = primals _, params_dot = tangents # Primal output y = wrapped(x_data, params) # Relative perturbation: scale eps by each parameter's magnitude to # avoid catastrophic cancellation for large-magnitude params (e.g. # eta_p ~ 100 Pa·s). Without this, abs perturbation 1e-7 on a value # of 100 gives relative perturbation 1e-9, at float64 precision limit. scale = jnp.maximum(jnp.abs(params), 1.0) * eps # Central-difference JVP: J·v ≈ (f(x + h) - f(x - h)) / (2 * ||h||) # where h = scale ⊙ v (element-wise product). # When jacfwd sends v = e_i (unit basis), h = scale_i * e_i, so # ||h|| = scale_i and this gives column i of J correctly. h = scale * params_dot y_plus = fn(x_data, params + h) y_minus = fn(x_data, params - h) h_norm = jnp.sqrt(jnp.sum(h**2) + 1e-300) y_dot = (y_plus - y_minus) / (2.0 * h_norm) return y, y_dot return wrapped
def _validate_optimization_result( result: OptimizationResult, residuals: np.ndarray, y_data: np.ndarray | None = None, mse_threshold: float = 1e18, ) -> None: """Validate optimization result against pathological outcomes. Checks for non-finite or astronomically large residuals that indicate the optimizer "succeeded" numerically but produced meaningless parameters. Args: result: OptimizationResult to validate (uses result.fun for RSS). residuals: Residual vector at the optimal point. y_data: Original y data array. When provided, uses len(y_data) as the observation count so that complex data (where residuals has length 2N) is not penalised with an inflated denominator. mse_threshold: Maximum allowed mean squared error (default: 1e18). The default accommodates GPa-scale moduli data where raw residuals can be O(1e9 Pa), yielding MSS/N ~ O(1e18 Pa²). For normalized residuals the threshold is effectively unbounded (RSS/N ≈ O(1) << 1e18). Only truly diverged fits (non-finite or beyond 1e18) are rejected. Raises: RuntimeError: If MSE exceeds threshold or is non-finite. """ if residuals.size == 0: raise RuntimeError( "Optimization produced empty residual vector. " "Check that the objective function returns a non-empty array." ) residual_count = len(y_data) if y_data is not None else residuals.size # R10-OPT-001: auto-scale threshold by data magnitude so that the MSE # check works for both dimensionless log-residuals and raw Pa-scale data. # Fallback to 1e18 when y_data is unavailable or all-zero. if y_data is not None and len(y_data) > 0: y_scale = float(np.max(np.abs(np.asarray(y_data)))) if y_scale > 0: mse_threshold = max(mse_threshold, 1e6 * y_scale**2) mse = result.fun / residual_count if not np.isfinite(mse) or mse > mse_threshold: logger.error( "Optimization failed: residual norm extremely large", mean_squared_error=float(mse) if np.isfinite(mse) else "inf", residual_count=residual_count, rss=float(result.fun), ) raise RuntimeError( "Optimization failed: residual norm remains extremely large. " "Try providing better initial values, looser bounds, or scaling the data." ) def _extract_bounds( parameters: ParameterSet, ) -> tuple[np.ndarray, tuple[np.ndarray, np.ndarray]]: """Extract initial values and bounds arrays from a ParameterSet. # TODO: Planned enhancement — log-transform bounded parameters before # passing to the optimizer and inverse-transform on the way out. This # would improve conditioning for parameters that span many orders of # magnitude (e.g. moduli 1e0–1e9, time constants 1e-6–1e3). Requires # careful integration with the NLSQ library's bound handling and the # complex-residual split path in create_least_squares_objective. Converts ParameterSet bounds (which may contain None) into the (lower_array, upper_array) format expected by SciPy and NLSQ. Args: parameters: ParameterSet with initial values and bounds. Returns: (x0, (lower_bounds, upper_bounds)) where x0 is the initial values array and bounds are float64 arrays with -inf/+inf for missing bounds. """ x0 = np.asarray(parameters.get_values(), dtype=np.float64) bounds_list = parameters.get_bounds() lower_list: list[float] = [] upper_list: list[float] = [] for bound_pair in bounds_list: if bound_pair is None or (bound_pair[0] is None and bound_pair[1] is None): lower_list.append(-np.inf) upper_list.append(np.inf) else: lower_list.append(bound_pair[0] if bound_pair[0] is not None else -np.inf) upper_list.append(bound_pair[1] if bound_pair[1] is not None else np.inf) lower = np.asarray(lower_list, dtype=np.float64) upper = np.asarray(upper_list, dtype=np.float64) return x0, (lower, upper) def _run_scipy_least_squares( objective: Callable[[np.ndarray], float | np.ndarray], x0: np.ndarray, bounds: tuple[np.ndarray, np.ndarray], ftol: float, xtol: float, gtol: float, max_iter: int, compute_covariance: bool = True, ) -> OptimizationResult: """Run SciPy's TRF least squares and return an OptimizationResult. Shared implementation for both the explicit method='scipy' path and the NLSQ failure fallback path. Computes covariance from the Jacobian when available. Args: objective: Residual function (values -> residual vector or scalar). x0: Initial parameter values. bounds: (lower_bounds, upper_bounds) arrays. ftol: Function tolerance. xtol: Parameter tolerance. gtol: Gradient tolerance. max_iter: Maximum iterations (max_nfev = max_iter * 10). compute_covariance: Whether to compute the covariance matrix from the Jacobian (default: True). Set False to skip SVD and save time when covariance is not required. Returns: OptimizationResult with optimal parameters and covariance. """ from scipy.optimize import least_squares as scipy_least_squares def residual_fn(values: np.ndarray) -> np.ndarray: res = objective(values) res = np.asarray(res) if np.iscomplexobj(res): res = np.concatenate([np.real(res), np.imag(res)]) res = res.astype(np.float64) # Guard against NaN/Inf from ODE solvers — replace with large finite # penalty so scipy can still attempt to optimize (gradient-guided away # from the bad region). Without this, scipy raises ValueError at init. # UTILS-001: Apply nan_to_num first so NaN gets 1e10 (not 0.0 from sign(NaN)) if not np.all(np.isfinite(res)): res = np.nan_to_num(res, nan=1e10, posinf=1e10, neginf=-1e10) return res scipy_result = scipy_least_squares( residual_fn, x0, bounds=bounds, ftol=ftol, xtol=xtol, gtol=gtol, max_nfev=max_iter * 10, method="trf", ) cost_value = getattr(scipy_result, "cost", None) jac = None pcov = None # OPT-05: Evaluate residuals once and reuse for both covariance and final # stats — avoids a second (potentially expensive) objective call. # P1-5: Include residuals, y_data, n_data, and _is_complex_split so that # downstream statistics (R², AIC, adj-R²) can be computed correctly. final_residuals = residual_fn(scipy_result.x) rss = float(np.sum(final_residuals**2)) if scipy_result.jac is not None and compute_covariance: jac = np.asarray(scipy_result.jac, dtype=np.float64) pcov = compute_covariance_from_jacobian(jac, final_residuals) return OptimizationResult( x=np.asarray(scipy_result.x, dtype=np.float64), fun=(float(2.0 * scipy_result.cost) if hasattr(scipy_result, "cost") else rss), jac=jac, pcov=pcov, success=bool(scipy_result.success), message=str(scipy_result.message), nit=int(getattr(scipy_result, "nit", scipy_result.nfev)), nfev=int(scipy_result.nfev), njev=int(getattr(scipy_result, "njev", 0)), optimality=( float(getattr(scipy_result, "optimality", np.nan)) if getattr(scipy_result, "optimality", None) is not None else None ), active_mask=( np.asarray(scipy_result.active_mask) if getattr(scipy_result, "active_mask", None) is not None else None ), cost=float(cost_value) if cost_value is not None else None, residuals=final_residuals, ) def _run_differential_evolution( objective: Callable[[np.ndarray], float | np.ndarray], x0: np.ndarray, bounds: tuple[np.ndarray, np.ndarray], max_iter: int, ) -> OptimizationResult: """Run SciPy's differential_evolution as a last-resort global optimizer. Used as the final fallback in the ``workflow="auto_global"`` chain after both NLSQ and SciPy TRF least-squares have failed to converge. Differential evolution performs a population-based global search that is robust to non-convex and multimodal objective landscapes. The objective is wrapped to return a scalar RSS so that ``differential_evolution`` (which minimises a scalar) can consume the same residual functions used elsewhere. Args: objective: Residual function (params -> residual vector or scalar). x0: Initial parameter values — seeded into the initial population so that the global search starts from (or near) the best local solution found so far. bounds: (lower_bounds, upper_bounds) arrays. Any infinite bound is clamped to ±1e10 so that ``differential_evolution`` can draw a finite initial population. max_iter: Maximum number of generations (``maxiter`` in SciPy). Returns: OptimizationResult compatible with the rest of the optimisation chain. """ from scipy.optimize import differential_evolution lower, upper = bounds # differential_evolution requires finite bounds; clamp ±inf to ±1e10. finite_lower = np.where(np.isfinite(lower), lower, -1e10) finite_upper = np.where(np.isfinite(upper), upper, 1e10) de_bounds = list(zip(finite_lower.tolist(), finite_upper.tolist(), strict=True)) # Seed x0 into the initial population so DE starts near the best local # solution. We reshape x0 into a (1, n_params) array that # differential_evolution accepts via ``init``. x0_pop = x0.reshape(1, -1) def scalar_objective(values: np.ndarray) -> float: res = np.asarray(objective(values)) if np.iscomplexobj(res): res = np.concatenate([np.real(res), np.imag(res)]) res = res.astype(np.float64) if not np.all(np.isfinite(res)): res = np.nan_to_num(res, nan=1e10, posinf=1e10, neginf=-1e10) return float(np.sum(res**2)) logger.info( "Running differential_evolution global optimizer", n_params=len(x0), maxiter=max_iter, ) de_result = differential_evolution( scalar_objective, de_bounds, maxiter=max_iter, tol=1e-6, seed=0, polish=True, # final local refinement with L-BFGS-B x0=x0_pop, # seed best local solution into initial population ) x_opt = np.asarray(de_result.x, dtype=np.float64) rss = float(de_result.fun) return OptimizationResult( x=x_opt, fun=rss, jac=None, pcov=None, success=bool(de_result.success), message=str(de_result.message), nit=int(getattr(de_result, "nit", 0)), nfev=int(getattr(de_result, "nfev", 0)), njev=0, optimality=None, active_mask=None, cost=rss, ) def compute_covariance_from_jacobian( jac: np.ndarray, residuals: np.ndarray | None = None, n_data: int | None = None, ) -> np.ndarray | None: """Compute parameter covariance matrix from Jacobian via SVD. Uses SVD-based Moore-Penrose pseudo-inverse for numerical stability: pcov = VT.T @ diag(1/s²) @ VT Scaled by residual variance when residuals provided: pcov *= RSS / (n_data - n_params) Args: jac: Jacobian matrix (m x n), where m = data points, n = parameters residuals: Optional residual vector for scaling n_data: Number of data points (default: inferred from jac.shape[0]) Returns: Covariance matrix (n x n), or None if computation fails """ if jac is None or jac.size == 0: logger.debug("Jacobian is None or empty, cannot compute covariance") return None try: jac = np.asarray(jac, dtype=np.float64) m, n = jac.shape # m = data points, n = parameters logger.debug( "Computing covariance from Jacobian", jacobian_shape=(m, n), n_data_points=m, n_params=n, ) # SVD of Jacobian: J = U @ S @ VT U, s, VT = np.linalg.svd(jac, full_matrices=False) logger.debug( "SVD computed", singular_values_range=(float(s.min()), float(s.max())), condition_number=float(s.max() / s.min()) if s.min() > 0 else float("inf"), ) # Filter near-zero singular values threshold = np.finfo(np.float64).eps * max(m, n) * s[0] # Use safe division to avoid RuntimeWarning: divide by zero s_safe = np.where(s > threshold, s, np.inf) s_inv_sq = np.where(s > threshold, 1.0 / (s_safe**2), 0.0) n_filtered = np.sum(s <= threshold) if n_filtered > 0: logger.debug( "Filtered near-zero singular values", n_filtered=int(n_filtered), threshold=float(threshold), ) # Compute covariance: (J.T @ J)^-1 = VT.T @ diag(1/s²) @ VT # OPT-COV-001: Broadcasting (VT.T * s_inv_sq) avoids allocating an N×N # diagonal matrix. Equivalent to VT.T @ diag(s_inv_sq) @ VT but ~1.5× # faster and uses O(N) instead of O(N²) scratch memory. pcov = (VT.T * s_inv_sq) @ VT # Scale by residual variance if available if residuals is not None: residuals = np.asarray(residuals, dtype=np.float64).ravel() rss = np.sum(residuals**2) n_data_actual = n_data if n_data is not None else m dof = n_data_actual - n # degrees of freedom if dof > 0: pcov = pcov * (rss / dof) logger.debug( "Scaled covariance by residual variance", rss=float(rss), degrees_of_freedom=dof, scale_factor=float(rss / dof), ) # Validate result if not np.all(np.isfinite(pcov)): logger.warning( "Covariance matrix contains inf/nan, returning None", has_inf=bool(np.any(np.isinf(pcov))), has_nan=bool(np.any(np.isnan(pcov))), ) return None logger.debug( "Covariance computation completed", pcov_shape=pcov.shape, pcov_diagonal_range=( float(np.diag(pcov).min()), float(np.diag(pcov).max()), ), ) return pcov except Exception as e: logger.error( "Failed to compute covariance from Jacobian", error=str(e), exc_info=True, ) return None
[docs] @dataclass class OptimizationResult: """Result from optimization with NLSQ 0.6.6 CurveFitResult-compatible properties. This dataclass stores the results of NLSQ optimization, including optimal parameter values, objective function value, convergence information, and statistical metrics compatible with NLSQ 0.6.6's CurveFitResult. Attributes: x: Optimal parameter values (float64 array) fun: Objective function value at optimum (RSS = sum of squared residuals) jac: Jacobian (gradient) at optimum pcov: Parameter covariance matrix (n_params x n_params) success: Whether optimization converged successfully message: Status message from optimizer nit: Number of iterations nfev: Number of function evaluations njev: Number of Jacobian evaluations optimality: Optimality metric (gradient norm) active_mask: Active bound constraints at solution cost: Final cost value grad: Final gradient nlsq_result: Full NLSQ result dictionary (for advanced diagnostics) residuals: Residual vector (y_data - y_pred) for statistical metrics y_data: Original dependent variable data (for R² computation) n_data: Number of data points (for AIC/BIC computation) diagnostics: Model health diagnostics (NLSQ 0.6.6, when compute_diagnostics=True) Statistical Properties (NLSQ 0.6.6 CurveFitResult compatible): r_squared: Coefficient of determination (R²) adj_r_squared: Adjusted R² accounting for number of parameters rmse: Root mean squared error mae: Mean absolute error aic: Akaike Information Criterion bic: Bayesian Information Criterion Methods: confidence_intervals(alpha): Compute parameter confidence intervals prediction_interval(x_new, alpha): Compute prediction intervals (NLSQ 0.6.6) get_parameter_uncertainties(): Get standard errors from covariance diagonal """ x: np.ndarray fun: float jac: np.ndarray | None = None pcov: np.ndarray | None = None success: bool = False message: str = "" nit: int = 0 nfev: int = 0 njev: int = 0 optimality: float | None = None active_mask: np.ndarray | None = None cost: float | None = None grad: np.ndarray | None = None nlsq_result: dict[str, Any] | None = field(default=None, repr=False) # Fields for statistical metrics (NLSQ 0.6.6 compatibility) residuals: np.ndarray | None = field(default=None, repr=False) y_data: np.ndarray | None = field(default=None, repr=False) n_data: int | None = None # NLSQ 0.6.6 fields for native delegation diagnostics: dict[str, Any] | None = field(default=None, repr=False) _curve_fit_result: Any | None = field(default=None, repr=False) _model_fn: Callable | None = field(default=None, repr=False) _x_data: np.ndarray | None = field(default=None, repr=False) # OPT-AIC-BIC-001: True when residuals are concatenated [real, imag] from # complex data. Used by _resolve_n_data() to halve the residual vector # length to recover the true observation count N. _is_complex_split: bool = field(default=False, repr=False) # P1-6: When residuals are normalized (divided by weights), store the # normalization weights so that R²/AIC/BIC can un-normalize for correct # statistics. Shape matches residuals. None when residuals are raw. _normalization_weights: np.ndarray | None = field(default=None, repr=False) def _resolve_n_data(self) -> int: """Resolve the true observation count N. Priority: n_data > len(y_data) > residual-length (halved when ``_is_complex_split`` is True, since the residual vector is ``[real, imag]`` with length 2N). """ if self.n_data is not None: return self.n_data if self.y_data is not None: return len(self.y_data) if self.residuals is not None: raw_len = len(self.residuals) return raw_len // 2 if self._is_complex_split else raw_len return 0 # ========================================================================= # Statistical Properties (NLSQ 0.6.0 CurveFitResult compatible) # ========================================================================= @property def r_squared(self) -> float | None: """Coefficient of determination (R²). Measures goodness of fit. Range: (-∞, 1], where 1 is perfect fit. R² = 1 - SS_res / SS_tot where SS_res = sum((y - y_pred)²) and SS_tot = sum((y - y_mean)²) Returns: R² value, or None if residuals/y_data not available """ if self.residuals is None or self.y_data is None: return None # Handle complex data by using magnitude y_data = np.asarray(self.y_data) residuals = np.asarray(self.residuals) if np.iscomplexobj(residuals): residuals = np.abs(residuals) # P1-6: Un-normalize residuals when normalization weights are stored, # so that ss_res is in the same units as ss_tot (raw data units). if self._normalization_weights is not None: residuals = residuals * np.asarray(self._normalization_weights) if len(residuals) == 2 * len(y_data): half = len(y_data) ss_res = np.sum(residuals[:half] ** 2) + np.sum(residuals[half:] ** 2) if np.iscomplexobj(y_data): ss_tot = np.sum((y_data.real - np.mean(y_data.real)) ** 2) + np.sum( (y_data.imag - np.mean(y_data.imag)) ** 2 ) else: ss_tot = np.sum((y_data - np.mean(y_data)) ** 2) else: ss_res = np.sum(residuals**2) if np.iscomplexobj(y_data): y_data = np.abs(y_data) ss_tot = np.sum((y_data - np.mean(y_data)) ** 2) if ss_tot == 0: logger.warning( "Total sum of squares is zero (constant data). R² undefined." ) return np.nan return float(1 - (ss_res / ss_tot)) @property def adj_r_squared(self) -> float | None: """Adjusted R² accounting for number of parameters. Adj R² = 1 - (1 - R²) * (n - 1) / (n - p - 1) where n is number of data points and p is number of parameters. Returns: Adjusted R² value, or None if cannot be computed """ r2 = self.r_squared if r2 is None: return None n = ( self.n_data if self.n_data is not None else (len(self.y_data) if self.y_data is not None else None) ) if n is None or n == 0: return None p = self.x.size if n - p - 1 <= 0: logger.warning("Not enough degrees of freedom for adjusted R².") return np.nan return float(1 - (1 - r2) * (n - 1) / (n - p - 1)) @property def rmse(self) -> float | None: """Root mean squared error. RMSE = sqrt(mean(residuals²)) Returns: RMSE value, or None if residuals not available """ if self.residuals is None: return None residuals = np.asarray(self.residuals) if np.iscomplexobj(residuals): residuals = np.abs(residuals) return float(np.sqrt(np.mean(residuals**2))) @property def mae(self) -> float | None: """Mean absolute error. MAE = mean(abs(residuals)) More robust to outliers than RMSE. Returns: MAE value, or None if residuals not available """ if self.residuals is None: return None residuals = np.asarray(self.residuals) return float(np.mean(np.abs(residuals))) @property def aic(self) -> float | None: """Akaike Information Criterion. AIC = 2k + n*ln(RSS/n) where k is number of parameters, n is number of data points, and RSS is residual sum of squares. Lower is better. Used for model selection. Returns: AIC value, or None if cannot be computed """ if self.residuals is None: return None residuals = np.asarray(self.residuals) if np.iscomplexobj(residuals): residuals = np.abs(residuals) # P1-6/P1-7: Un-normalize residuals for correct AIC computation. if self._normalization_weights is not None: residuals = residuals * np.asarray(self._normalization_weights) n = self._resolve_n_data() if n == 0: return None k = self.x.size rss = np.sum(residuals**2) if rss <= 0: logger.warning("RSS ≤ 0, AIC undefined.") return np.nan return float(2 * k + n * np.log(rss / n)) @property def bic(self) -> float | None: """Bayesian Information Criterion. BIC = k*ln(n) + n*ln(RSS/n) where k is number of parameters, n is number of data points, and RSS is residual sum of squares. Lower is better. Penalizes model complexity more than AIC. Returns: BIC value, or None if cannot be computed """ if self.residuals is None: return None residuals = np.asarray(self.residuals) if np.iscomplexobj(residuals): residuals = np.abs(residuals) # P1-6/P1-7: Un-normalize residuals for correct BIC computation. if self._normalization_weights is not None: residuals = residuals * np.asarray(self._normalization_weights) n = self._resolve_n_data() if n == 0: return None k = self.x.size rss = np.sum(residuals**2) if rss <= 0: logger.warning("RSS ≤ 0, BIC undefined.") return np.nan return float(k * np.log(n) + n * np.log(rss / n)) # ========================================================================= # Statistical Methods (NLSQ 0.6.0 CurveFitResult compatible) # =========================================================================
[docs] def confidence_intervals(self, alpha: float = 0.95) -> np.ndarray | None: """Compute parameter confidence intervals. Parameters ---------- alpha : float, optional Confidence level (default: 0.95 for 95% CI). Returns ------- intervals : ndarray or None Array of shape (n_params, 2) with [lower, upper] bounds for each parameter, or None if covariance not available. Examples -------- >>> result = nlsq_optimize(objective, params) >>> ci = result.confidence_intervals(alpha=0.95) >>> if ci is not None: ... for i, (lower, upper) in enumerate(ci): ... print(f"Parameter {i}: [{lower:.3f}, {upper:.3f}]") """ if self.pcov is None: return None from scipy import stats n = self._resolve_n_data() p = self.x.size # Degrees of freedom dof = max(n - p, 1) # t-value for confidence level t_val = stats.t.ppf((1 + alpha) / 2, dof) # Standard errors from covariance diagonal # OPT-006: Guard against negative covariance diagonals from # near-singular Jacobians perr = np.sqrt(np.maximum(np.diag(self.pcov), 0.0)) # Confidence intervals intervals = np.zeros((p, 2)) intervals[:, 0] = self.x - t_val * perr # Lower bound intervals[:, 1] = self.x + t_val * perr # Upper bound return intervals
[docs] def get_parameter_uncertainties(self) -> np.ndarray | None: """Get standard errors for parameters from covariance diagonal. Returns ------- uncertainties : ndarray or None Standard errors for each parameter, or None if covariance not available. Examples -------- >>> result = nlsq_optimize(objective, params) >>> std_errs = result.get_parameter_uncertainties() >>> if std_errs is not None: ... for i, se in enumerate(std_errs): ... print(f"Parameter {i}: {result.x[i]:.4f} ± {se:.4f}") """ if self.pcov is None: return None return np.sqrt(np.maximum(np.diag(self.pcov), 0.0))
[docs] def prediction_interval( self, x_new: np.ndarray | None = None, alpha: float = 0.95, ) -> np.ndarray | None: """Compute prediction intervals for new x values. Prediction intervals account for both parameter uncertainty and observation noise, providing bounds where future observations are expected to fall with the specified probability. Parameters ---------- x_new : ndarray or None, optional New x values for prediction. If None, uses original x_data. alpha : float, optional Confidence level for intervals (default: 0.95 for 95% PI). Returns ------- intervals : ndarray or None Array of shape (n_points, 2) with [lower, upper] bounds for each point, or None if prediction intervals cannot be computed. Notes ----- When a native NLSQ CurveFitResult is available (from nlsq_curve_fit), this method delegates to NLSQ's prediction_interval for accuracy. Otherwise, it falls back to a manual computation using covariance propagation. Examples -------- >>> result = nlsq_curve_fit(model, x_data, y_data, params) >>> pi = result.prediction_interval(x_new, alpha=0.95) >>> if pi is not None: ... for i, (lower, upper) in enumerate(pi): ... print(f"x={x_new[i]:.2f}: [{lower:.3f}, {upper:.3f}]") """ # Delegate to native NLSQ CurveFitResult when available if self._curve_fit_result is not None: try: return self._curve_fit_result.prediction_interval(x_new, alpha) except Exception as e: logger.debug( "Native prediction_interval failed, using fallback", error=str(e), ) # Fallback: manual computation requires model function and data if self._model_fn is None or self.pcov is None: logger.debug("Cannot compute prediction interval: missing model_fn or pcov") return None x_eval = x_new if x_new is not None else self._x_data if x_eval is None: logger.debug("Cannot compute prediction interval: no x data") return None from scipy import stats x_eval = np.asarray(x_eval, dtype=np.float64) n = self._resolve_n_data() p = self.x.size dof = max(n - p, 1) # t-value for prediction interval t_val = stats.t.ppf((1 + alpha) / 2, dof) # Compute predictions and standard errors via numerical differentiation try: y_pred = np.asarray(self._model_fn(x_eval, self.x)) # Estimate prediction variance using residual variance + parameter uncertainty if self.residuals is not None: residuals = np.asarray(self.residuals) if np.iscomplexobj(residuals): residuals = np.abs(residuals) mse = np.sum(residuals**2) / dof else: mse = (self.fun / n) if n > 0 else 0.0 # P2-Fit-3: Use leverage-weighted prediction intervals when the # Jacobian is available: h_i = diag(J @ inv(J^T J) @ J^T), then # pred_std_i = sqrt(mse * (1 + h_i)). Falls back to constant MSE # when Jacobian is unavailable (approximate but conservative). if self.jac is not None: try: J = np.asarray(self.jac, dtype=np.float64) # Compute J at x_new if different from training data # For now, only use training-point leverages when x_new matches JtJ = J.T @ J JtJ_inv = np.linalg.inv(JtJ + 1e-12 * np.eye(JtJ.shape[0])) hat_matrix_diag = np.sum((J @ JtJ_inv) * J, axis=1) # hat_matrix_diag has shape (n_residuals,); map to x_eval length if len(hat_matrix_diag) == len(x_eval): pred_std = np.sqrt(mse * (1.0 + hat_matrix_diag)) else: # Residual length != x_eval length (e.g. complex-split) pred_std = np.sqrt(mse) * np.ones_like(y_pred) except Exception: pred_std = np.sqrt(mse) * np.ones_like(y_pred) else: pred_std = np.sqrt(mse) * np.ones_like(y_pred) intervals = np.zeros((len(x_eval), 2)) intervals[:, 0] = y_pred - t_val * pred_std intervals[:, 1] = y_pred + t_val * pred_std return intervals except Exception as e: logger.warning( "Failed to compute prediction interval", error=str(e), ) return None
[docs] @classmethod def from_curve_fit_result( cls, curve_fit_result: Any, y_data: np.ndarray | None = None, model_fn: Callable | None = None, x_data: np.ndarray | None = None, ) -> OptimizationResult: """Create OptimizationResult from NLSQ 0.6.6 CurveFitResult. This factory method preserves the native CurveFitResult for property delegation, enabling access to NLSQ 0.6.6's statistical methods like prediction_interval() without reimplementation. Parameters ---------- curve_fit_result : CurveFitResult Result from nlsq.curve_fit() call. y_data : ndarray, optional Original dependent variable data (for complex data handling). model_fn : callable, optional Model function f(x, params) for prediction intervals. x_data : ndarray, optional Original independent variable data for prediction intervals. Returns ------- result : OptimizationResult Result with native delegation to CurveFitResult properties. Examples -------- >>> curve_fit_result = nlsq.curve_fit(model_fn, x, y, p0=p0) >>> result = OptimizationResult.from_curve_fit_result( ... curve_fit_result, y_data=y, model_fn=model_fn, x_data=x ... ) >>> print(result.r_squared) # Delegates to native >>> pi = result.prediction_interval(x_new) # Delegates to native """ # Extract standard fields popt = np.asarray(curve_fit_result.popt, dtype=np.float64) pcov = ( np.asarray(curve_fit_result.pcov, dtype=np.float64) if curve_fit_result.pcov is not None else None ) success = getattr(curve_fit_result, "success", True) message = getattr(curve_fit_result, "message", "Converged") nfev = getattr(curve_fit_result, "nfev", 0) njev = getattr(curve_fit_result, "njev", 0) cost = getattr(curve_fit_result, "cost", None) # Get residuals from native result residuals = None if hasattr(curve_fit_result, "residuals"): residuals = np.asarray(curve_fit_result.residuals) # Get diagnostics if available diagnostics = getattr(curve_fit_result, "diagnostics", None) # Compute RSS from residuals or cost if residuals is not None: rss = float(np.sum(residuals**2)) elif cost is not None: # scipy/NLSQ convention: cost = 0.5 * sum(residuals²) rss = float(2.0 * cost) else: rss = 0.0 # Handle y_data y_data_np = np.asarray(y_data) if y_data is not None else None n_data = len(y_data_np) if y_data_np is not None else None # OPT-AIC-BIC-001: detect complex-split residuals (length 2N from # concatenated [real, imag]). Only flag when y_data is absent and # residuals are real-typed with even length — callers who pass y_data # get the authoritative n_data from len(y_data) instead. _complex_split = False if n_data is None and residuals is not None: _complex_split = ( not np.iscomplexobj(residuals) and len(residuals) % 2 == 0 and len(residuals) > 2 and y_data_np is not None and np.iscomplexobj(y_data_np) ) # If y_data is complex but residuals are real, residuals were split if y_data_np is not None and np.iscomplexobj(y_data_np): n_data = len(y_data_np) # Otherwise use residual length as-is (no unsafe halving) result = cls( x=popt, fun=rss, jac=None, pcov=pcov, success=bool(success), message=str(message), nit=int(nfev), nfev=int(nfev), njev=int(njev), optimality=None, active_mask=None, cost=float(cost) if cost is not None else rss, grad=None, nlsq_result=None, residuals=residuals, y_data=y_data_np, n_data=n_data, diagnostics=diagnostics, _curve_fit_result=curve_fit_result, _model_fn=model_fn, _x_data=np.asarray(x_data) if x_data is not None else None, _is_complex_split=_complex_split, ) return result
[docs] @classmethod def from_nlsq( cls, nlsq_result: dict[str, Any], residuals: np.ndarray | None = None, y_data: np.ndarray | None = None, compute_covariance: bool = True, ) -> OptimizationResult: """Create OptimizationResult from NLSQ result dictionary. Args: nlsq_result: Result dictionary from nlsq.LeastSquares.least_squares residuals: Optional residual vector for covariance scaling and metrics y_data: Optional original y data for R² computation compute_covariance: Whether to compute the covariance matrix via SVD (default: True). Set False to skip SVD when CIs are not needed. Returns: OptimizationResult instance with fields extracted from NLSQ result """ # Extract common fields x = np.asarray(nlsq_result.get("x", []), dtype=np.float64) # NLSQ v0.6.10: 'fun' is now the residual vector (array), not a scalar. # Always prefer 'cost' (scalar); fall back to sum-of-squares of 'fun'. _cost_raw = nlsq_result.get("cost") if _cost_raw is not None: fun = float(_cost_raw) else: _fun_raw = nlsq_result.get("fun", 0.0) _fun_arr = np.asarray(_fun_raw) fun = float(_fun_arr) if _fun_arr.ndim == 0 else float(np.sum(_fun_arr**2)) success = bool(nlsq_result.get("success", False)) message = str(nlsq_result.get("message", "")) nfev = int(nlsq_result.get("nfev", 0)) njev = int(nlsq_result.get("njev", 0)) # Extract NLSQ-specific fields jac = nlsq_result.get("jac") if jac is not None: jac = np.asarray(jac, dtype=np.float64) grad = nlsq_result.get("grad") if grad is not None: grad = np.asarray(grad, dtype=np.float64) optimality = nlsq_result.get("optimality") if optimality is not None: optimality = float(optimality) active_mask = nlsq_result.get("active_mask") if active_mask is not None: active_mask = np.asarray(active_mask) # Note: NLSQ uses 'nfev' for iterations in some contexts nit = int(nlsq_result.get("nit", nlsq_result.get("nfev", 0))) # OPT-06: Compute covariance from Jacobian only when requested. # Skipping the SVD is a significant speedup for fast fits where # confidence intervals are not needed. pcov = None if jac is not None and compute_covariance: pcov = compute_covariance_from_jacobian(jac, residuals) # Store residuals as numpy array for statistical properties residuals_np = None if residuals is not None: residuals_np = np.asarray(residuals, dtype=np.float64) # Handle scalar residuals (0-d arrays) if residuals_np.ndim == 0: residuals_np = residuals_np.reshape(1) # Store y_data for R² computation y_data_np = None n_data = None _complex_split = False if y_data is not None: y_data_np = np.asarray(y_data) n_data = len(y_data_np) # Detect complex-split residuals: y_data is complex but residuals are real if ( np.iscomplexobj(y_data_np) and residuals_np is not None and not np.iscomplexobj(residuals_np) ): _complex_split = True return cls( x=x, fun=fun, jac=jac, pcov=pcov, success=success, message=message, nit=nit, nfev=nfev, njev=njev, optimality=optimality, active_mask=active_mask, cost=fun, # NLSQ uses 'cost' terminology grad=grad, nlsq_result=nlsq_result, residuals=residuals_np, y_data=y_data_np, n_data=n_data, _is_complex_split=_complex_split, )
[docs] def nlsq_optimize( objective: Callable[[np.ndarray], float | np.ndarray], parameters: ParameterSet, method: str = "auto", use_jax: bool = True, max_iter: int = 1000, ftol: float = 1e-6, xtol: float = 1e-6, gtol: float = 1e-6, # NLSQ 0.6.6 parameters workflow: str = "auto", auto_bounds: bool = False, stability: str | bool = False, fallback: bool = False, compute_diagnostics: bool = False, # OPT-06: make covariance computation lazy to avoid SVD on every fit compute_covariance: bool = True, **kwargs, ) -> OptimizationResult: """Optimize objective function using NLSQ (GPU-accelerated). This function provides GPU-accelerated nonlinear least squares optimization using the NLSQ package. It achieves 5-270x speedup over scipy through JAX JIT compilation and automatic differentiation. The objective function should accept parameter values as a 1D array and return a scalar value (minimization) or vector of residuals (least squares). Args: objective: Objective function to minimize. Takes parameter values as array and returns scalar or residual vector. Should use jax.numpy for operations to enable GPU acceleration and automatic differentiation. parameters: ParameterSet with initial values and bounds method: Optimization method. Options: - "auto": Automatically select based on bounds (default) - "trf": Trust Region Reflective (supports bounds) - "lm": Levenberg-Marquardt (no bounds) - "scipy": Use SciPy's least_squares directly (bypasses NLSQ). Use this for models that use Diffrax ODE solvers which are incompatible with NLSQ's forward-mode autodiff. NLSQ internally selects the best algorithm regardless of this parameter. use_jax: Whether to use JAX for gradient computation (default: True). Should always be True for GPU acceleration and float64 precision. max_iter: Maximum number of iterations (default: 1000) ftol: Function tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ's mixed precision management. xtol: Parameter tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ's mixed precision management. gtol: Gradient tolerance for convergence (default: 1e-6). Relaxed from 1e-8 due to NLSQ's mixed precision management. workflow: NLSQ 0.6.6 workflow selection (default: "auto"): - "auto": Memory-aware local optimization (default) - "auto_global": Global optimization with bounds exploration - "hpc": HPC mode with checkpointing support auto_bounds: Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics. stability: Numerical stability checks (default: False): - 'auto': Check and automatically fix stability issues - 'check': Check and warn but don't fix - False: Skip stability checks fallback: Enable NLSQ's native fallback strategies (default: False). When True, tries alternative approaches if optimization fails. Note: RheoJAX also has its own SciPy fallback independent of this. compute_diagnostics: Compute model health diagnostics (default: False). When True, result.diagnostics includes identifiability analysis, gradient health, and other diagnostic information. compute_covariance: Whether to compute the parameter covariance matrix (default: True). The covariance is derived from an SVD of the Jacobian at the solution. Set False to skip this step when confidence intervals and parameter uncertainties are not needed, avoiding one full SVD per fit. **kwargs: Additional arguments passed to nlsq.LeastSquares.least_squares Returns: OptimizationResult with optimal parameters, convergence info, and optional diagnostics (when compute_diagnostics=True). Raises: ValueError: If objective is not callable or parameters is not ParameterSet Example: >>> from rheojax.core.parameters import ParameterSet >>> params = ParameterSet() >>> params.add("a", value=1.0, bounds=(0, 10)) >>> params.add("b", value=1.0, bounds=(0, 10)) >>> >>> def objective(values): ... a, b = values ... return (a - 5.0) ** 2 + (b - 3.0) ** 2 >>> >>> result = nlsq_optimize(objective, params) >>> print(result.x) # Should be close to [5.0, 3.0] >>> >>> # With NLSQ 0.6.6 features >>> result = nlsq_optimize( ... objective, params, ... workflow="auto_global", # Global optimization ... stability="auto", # Auto-fix stability issues ... compute_diagnostics=True # Get diagnostics ... ) >>> print(result.diagnostics) Notes: - This function automatically handles float64 precision through NLSQ - JAX JIT compilation provides 5-270x speedup over scipy - Automatic differentiation eliminates need for manual Jacobian - Bounds are automatically extracted from ParameterSet - Parameters are updated in-place with optimal values """ # Validate inputs if not callable(objective): raise ValueError("objective must be callable") if not isinstance(parameters, ParameterSet): raise ValueError("parameters must be ParameterSet") # Get initial values and bounds from ParameterSet x0, nlsq_bounds = _extract_bounds(parameters) original_values = x0.copy() # If method='scipy', use SciPy directly (bypasses NLSQ autodiff issues with Diffrax) if method == "scipy": logger.info( "Using SciPy least_squares directly (method='scipy')", n_params=len(x0), ) result = _run_scipy_least_squares( objective, x0, nlsq_bounds, ftol, xtol, gtol, max_iter, compute_covariance=compute_covariance, ) parameters.set_values(result.x) return result logger.info( "Starting NLSQ optimization", n_params=len(x0), method=method, max_iter=max_iter, ftol=ftol, xtol=xtol, gtol=gtol, workflow=workflow, auto_bounds=auto_bounds, stability=stability, fallback=fallback, compute_diagnostics=compute_diagnostics, ) logger.debug( "Initial parameter values", x0=x0.tolist() if hasattr(x0, "tolist") else list(x0), lower_bounds=nlsq_bounds[0].tolist(), upper_bounds=nlsq_bounds[1].tolist(), ) # NLSQ expects a residual function that returns a vector of residuals # The objective function from create_least_squares_objective() now returns # a proper residual vector, so we use it directly # NLSQ will minimize sum(residuals²) internally # Set up NLSQ optimization parameters nlsq_kwargs: dict[str, Any] = { "fun": objective, # Now a residual function returning vector "x0": x0, "bounds": nlsq_bounds, "method": "trf", # Trust Region Reflective (supports bounds) "ftol": ftol, "xtol": xtol, "gtol": gtol, "max_nfev": max_iter * 10, # NLSQ uses max_nfev for iteration limit "verbose": 0, } # Add NLSQ 0.6.6 parameters (feature detection for backward compatibility) # Note: LeastSquares.least_squares may not support all curve_fit params if workflow != "auto": nlsq_kwargs["workflow"] = workflow if auto_bounds: nlsq_kwargs["auto_bounds"] = auto_bounds if stability: nlsq_kwargs["stability"] = stability if fallback: nlsq_kwargs["fallback"] = fallback if compute_diagnostics: nlsq_kwargs["compute_diagnostics"] = compute_diagnostics # Merge with user-provided kwargs, filtering out rheojax-specific ones # that are not valid NLSQ optimizer parameters (prevents TypeError in NLSQ) clean_kwargs = { k: v for k, v in kwargs.items() if k not in _RHEOJAX_RESERVED_KWARGS } nlsq_kwargs.update(clean_kwargs) def _scipy_fallback(initial_guess: np.ndarray) -> OptimizationResult: """Fallback chain when NLSQ fails. For workflow="auto_global": SciPy TRF first, then differential_evolution as the final global-search fallback. For all other workflows: SciPy TRF only. """ logger.info("Using SciPy least_squares fallback") scipy_result = _run_scipy_least_squares( objective, initial_guess, nlsq_bounds, ftol, xtol, gtol, max_iter, compute_covariance=compute_covariance, ) if workflow == "auto_global" and not scipy_result.success: logger.info( "SciPy TRF fallback did not converge; trying differential_evolution" ) de_result = _run_differential_evolution( objective, scipy_result.x, nlsq_bounds, max_iter ) de_result.message = ( f"[DE fallback] {de_result.message} " f"(SciPy TRF: {scipy_result.message})" ) return de_result return scipy_result # Create NLSQ optimizer instance and run optimization try: logger.debug("Creating NLSQ optimizer instance") optimizer = nlsq.LeastSquares() nlsq_result = optimizer.least_squares(**nlsq_kwargs) logger.debug( "NLSQ optimization completed", success=nlsq_result.get("success", False), nfev=nlsq_result.get("nfev", 0), cost=float(nlsq_result.get("cost", 0.0)), ) except ( RuntimeError, TypeError, ValueError, FloatingPointError, OverflowError, ) as e: logger.warning( "NLSQ optimization raised exception, falling back to SciPy", error=str(e), exc_info=True, ) result = _scipy_fallback(x0) result.message = f"[SciPy fallback] {result.message} (NLSQ failed: {e})" # Compute residuals for validation (OPT-001) _residuals_fb_raw = np.asarray(objective(result.x)) if np.iscomplexobj(_residuals_fb_raw): residuals_fb = np.concatenate( [np.real(_residuals_fb_raw), np.imag(_residuals_fb_raw)] ).astype(np.float64) else: residuals_fb = _residuals_fb_raw.astype(np.float64) result.fun = float(np.sum(residuals_fb**2)) # P1-6: Propagate normalization weights for correct R²/AIC/BIC _nw = getattr(objective, "_normalization_weights", None) if _nw is not None: result._normalization_weights = _nw _validate_optimization_result(result, residuals_fb) # Write back optimal params so model state reflects the fit parameters.set_values(result.x) return result # Compute residuals at optimal point for covariance scaling. # P2-Fit-2: Reuse residuals from NLSQ result when available (avoids an # extra objective evaluation that can be expensive for ODE models). x_opt = np.asarray(nlsq_result.get("x", x0), dtype=np.float64) # NLSQ v0.6.10: 'fun' is the residual vector (array); 'fun_vector' and # 'residuals' keys were removed. Prefer the cached residual from the result # dict to avoid an expensive extra objective evaluation (especially costly # for ODE-based models). _cached_fun = nlsq_result.get("fun_vector", nlsq_result.get("residuals")) if _cached_fun is None: _fun_field = nlsq_result.get("fun") if _fun_field is not None and np.asarray(_fun_field).ndim >= 1: _cached_fun = _fun_field if _cached_fun is not None: residuals_raw = np.asarray(_cached_fun) else: residuals_raw = np.asarray(objective(x_opt)) if np.iscomplexobj(residuals_raw): residuals_np = np.concatenate( [np.real(residuals_raw), np.imag(residuals_raw)] ).astype(np.float64) else: residuals_np = residuals_raw.astype(np.float64) # Convert NLSQ result to OptimizationResult (with residuals for covariance) # OPT-06: forward compute_covariance so SVD is skipped when not requested. result = OptimizationResult.from_nlsq( nlsq_result, residuals=residuals_np, compute_covariance=compute_covariance, ) # P1-6: Propagate normalization weights from the objective closure so that # R²/AIC/BIC can un-normalize residuals for correct statistics. _nw = getattr(objective, "_normalization_weights", None) if _nw is not None: result._normalization_weights = _nw # Store diagnostics if available (NLSQ 0.6.6+) if hasattr(nlsq_result, "diagnostics") or "diagnostics" in nlsq_result: result.diagnostics = nlsq_result.get( "diagnostics", getattr(nlsq_result, "diagnostics", None) ) # P2-Fit-9: Trigger SciPy fallback on ANY NLSQ failure when fallback is # enabled, not just the "inner optimization loop exceeded" message. if not result.success and fallback: logger.warning( "NLSQ did not converge; retrying with SciPy least_squares for stability.", nlsq_message=result.message, ) # OPT-003: warm-start from NLSQ's best result, not stale x0 fallback_result = _scipy_fallback(x_opt) fallback_result.message = ( f"[SciPy fallback] {fallback_result.message} (NLSQ inner loop limit)" ) # OPT-004: validate the fallback result _residuals_fb_raw = np.asarray(objective(fallback_result.x)) if np.iscomplexobj(_residuals_fb_raw): residuals_fb = np.concatenate( [np.real(_residuals_fb_raw), np.imag(_residuals_fb_raw)] ).astype(np.float64) else: residuals_fb = _residuals_fb_raw.astype(np.float64) fallback_result.fun = float(np.sum(residuals_fb**2)) _validate_optimization_result(fallback_result, residuals_fb) parameters.set_values(fallback_result.x) return fallback_result # Ensure x is float64 result.x = np.asarray(result.x, dtype=np.float64) # Compute RSS = sum(residuals²) # OPT-008: Use np.sum instead of jnp.sum on numpy array to avoid host-device round-trip result.fun = float(np.sum(residuals_np**2)) # Guard against false "success" with astronomically large residuals try: _validate_optimization_result(result, residuals_np) except RuntimeError: parameters.set_values(original_values) raise # Update ParameterSet with optimal values parameters.set_values(result.x) logger.info( "Optimization completed successfully", success=result.success, rss=float(result.fun), nfev=result.nfev, nit=result.nit, r_squared=result.r_squared, ) logger.debug( "Final parameter values", x_opt=result.x.tolist(), message=result.message, ) return result
[docs] def nlsq_multistart_optimize( objective: Callable[[np.ndarray], float | np.ndarray], parameters: ParameterSet, n_starts: int = 5, perturb_factor: float = 0.3, method: str = "auto", use_jax: bool = True, max_iter: int = 1000, ftol: float = 1e-6, xtol: float = 1e-6, gtol: float = 1e-6, verbose: bool = False, parallel: bool = True, n_workers: int | None = None, y_data: np.ndarray | None = None, **kwargs, ) -> OptimizationResult: """Multi-start optimization to escape local minima. For complex objective functions (e.g., mastercurves with 10+ decades), single optimization runs may converge to poor local minima even from good initial guesses. This function performs multiple optimization runs from different starting points and returns the best result. Strategy: 1. First attempt: Use current parameter values (from smart initialization) 2. Additional attempts: Random perturbations around initial values (parallel) 3. Return result with lowest final cost (best fit) Performance: With parallel=True (default), achieves 2-4x speedup for 5-10 starts by running optimizations concurrently. JAX releases the GIL during computation, enabling effective thread-based parallelism. Args: objective: Objective function to minimize parameters: ParameterSet with initial values and bounds n_starts: Number of random starts (default: 5) perturb_factor: Perturbation factor for random starts (default: 0.3) Parameters are perturbed by ± perturb_factor * (value or range) method: Optimization method (default: "auto") use_jax: Whether to use JAX (default: True) max_iter: Max iterations per start (default: 1000) ftol: Function tolerance (default: 1e-6) xtol: Parameter tolerance (default: 1e-6) gtol: Gradient tolerance (default: 1e-6) verbose: Print progress messages (default: False) parallel: Run additional starts in parallel (default: True) n_workers: Number of parallel workers (default: min(n_starts-1, 4)) **kwargs: Additional arguments for nlsq_optimize Returns: OptimizationResult with best parameters from all starts Example: >>> # For mastercurve data (12+ decades) >>> result = nlsq_multistart_optimize( ... objective, parameters, n_starts=5, verbose=True ... ) >>> print(f"Best cost: {result.fun:.3e}") """ import os from concurrent.futures import ThreadPoolExecutor, as_completed # Store original parameter values original_values = parameters.get_values() bounds_list = parameters.get_bounds() param_names = list(parameters.keys()) logger.info( "Starting multi-start optimization", n_starts=n_starts, perturb_factor=perturb_factor, n_params=len(original_values), parallel=parallel, ) # First attempt: Use smart initialization values (sequential) if verbose: logger.info("Multi-start optimization: Attempt 1 (smart initialization)") best_result = nlsq_optimize( objective, parameters, method=method, use_jax=use_jax, max_iter=max_iter, ftol=ftol, xtol=xtol, gtol=gtol, **kwargs, ) best_cost = best_result.fun logger.debug( "First attempt completed", cost=float(best_cost), success=best_result.success, ) if verbose: logger.info(f" Cost: {best_cost:.3e}, Success: {best_result.success}") # If only 1 start requested, return early if n_starts <= 1: return best_result # Generate all perturbed starting points def generate_perturbed_values(rng: np.random.Generator) -> list[float]: """Generate perturbed initial values using the provided RNG.""" perturbed = [] for orig_val, bounds in zip(original_values, bounds_list, strict=True): if bounds is None or (bounds[0] is None and bounds[1] is None): perturbation = rng.uniform(-perturb_factor, perturb_factor) if abs(orig_val) < 1e-30: # Additive perturbation for zero-valued parameters new_val = perturbation else: new_val = orig_val * (1.0 + perturbation) else: lower = bounds[0] if bounds[0] is not None else orig_val - abs(orig_val) upper = bounds[1] if bounds[1] is not None else orig_val + abs(orig_val) range_size = upper - lower perturbation = rng.uniform( -perturb_factor * range_size, perturb_factor * range_size ) new_val = np.clip(orig_val + perturbation, lower, upper) perturbed.append(new_val) return perturbed def run_single_optimization( start_idx: int, initial_values: list[float] ) -> tuple[int, OptimizationResult | None]: """Run a single optimization from given starting point.""" # Create a fresh ParameterSet copy for thread-safe operation. # Use initial_values[i] as the starting value for each parameter instead # of a hardcoded 0.0 — a zero initializer triggers a spurious clamping # RuntimeWarning for parameters whose lower bound is > 0 (e.g. (1e-3, 1e6)). # The perturbed values are always within bounds (generated by np.clip) so # no clamping occurs and the separate set_values() call is unnecessary. params_copy = ParameterSet() for name, bounds, init_val in zip( param_names, bounds_list, initial_values, strict=True ): # Handle None values in bounds bounds_tuple: tuple[float, float] | None = None if bounds is not None: if bounds[0] is not None and bounds[1] is not None: bounds_tuple = (float(bounds[0]), float(bounds[1])) params_copy.add(name=name, value=float(init_val), bounds=bounds_tuple) try: result = nlsq_optimize( objective, params_copy, method=method, use_jax=use_jax, max_iter=max_iter, ftol=ftol, xtol=xtol, gtol=gtol, **kwargs, ) return start_idx, result except Exception as e: logger.warning( "Multi-start attempt failed", attempt=start_idx + 1, error=str(e), ) return start_idx, None # Prepare all starting points root_rng = np.random.default_rng(seed=42) all_starts = [ generate_perturbed_values(np.random.default_rng(root_rng.integers(2**31))) for _ in range(1, n_starts) ] if parallel and n_starts > 2: # Parallel execution for additional starts if n_workers is None: n_workers = min(n_starts - 1, min(4, os.cpu_count() or 1)) logger.debug( "Running parallel multi-start", n_workers=n_workers, n_additional_starts=n_starts - 1, ) with ThreadPoolExecutor(max_workers=n_workers) as executor: futures = { executor.submit(run_single_optimization, i, starts): i for i, starts in enumerate(all_starts, start=1) } for future in as_completed(futures): start_idx, result = future.result() if result is not None: logger.debug( "Multi-start attempt completed", attempt=start_idx + 1, cost=float(result.fun), success=result.success, ) if verbose: logger.info( f" Attempt {start_idx + 1}: Cost: {result.fun:.3e}, " f"Success: {result.success}" ) if result.fun < best_cost: best_result = result best_cost = result.fun logger.debug( "New best result found", attempt=start_idx + 1, best_cost=float(best_cost), ) if verbose: logger.info(f" -> New best! Cost: {best_cost:.3e}") else: if verbose: logger.warning(f" Attempt {start_idx + 1} failed") else: # Sequential execution (original behavior) for i, perturbed_values in enumerate(all_starts, start=1): logger.debug("Starting multi-start attempt", attempt=i + 1, total=n_starts) if verbose: logger.info( f"Multi-start optimization: Attempt {i + 1} (random perturbation)" ) start_idx, result = run_single_optimization(i, perturbed_values) if result is not None: logger.debug( "Multi-start attempt completed", attempt=i + 1, cost=float(result.fun), success=result.success, ) if verbose: logger.info(f" Cost: {result.fun:.3e}, Success: {result.success}") if result.fun < best_cost: best_result = result best_cost = result.fun logger.debug( "New best result found", attempt=i + 1, best_cost=float(best_cost), ) if verbose: logger.info(f" -> New best! Cost: {best_cost:.3e}") else: if verbose: logger.warning(f" Attempt {i + 1} failed") # Restore best parameters parameters.set_values(best_result.x) if y_data is not None and best_result.y_data is None: best_result.y_data = np.asarray(y_data) # R8-OPT-002: also set n_data for AIC/BIC/adj_R² computation if y_data is not None and getattr(best_result, "n_data", None) is None: best_result.n_data = len(np.asarray(y_data)) logger.info( "Multi-start optimization completed", best_cost=float(best_cost), n_starts=n_starts, final_success=best_result.success, parallel=parallel, ) if verbose: logger.info( f"\nMulti-start completed: Best cost = {best_cost:.3e} " f"({n_starts} starts, parallel={parallel})" ) return best_result
def nlsq_optimize_global( objective: Callable[[np.ndarray], float | np.ndarray], parameters: ParameterSet, **kwargs, ) -> OptimizationResult: """Global optimization using NLSQ 0.6.6 workflow='auto_global'. Convenience function for global optimization that explores parameter space more thoroughly using the NLSQ 0.6.6 global optimization workflow. Args: objective: Objective function to minimize. Takes parameter values as array and returns scalar or residual vector. parameters: ParameterSet with initial values and bounds **kwargs: Additional arguments passed to nlsq_optimize Returns: OptimizationResult with optimal parameters from global search Example: >>> from rheojax.core.parameters import ParameterSet >>> params = ParameterSet() >>> params.add("a", value=1.0, bounds=(0, 10)) >>> params.add("b", value=1.0, bounds=(0, 10)) >>> >>> result = nlsq_optimize_global(objective, params) >>> print(f"Global optimum: {result.x}") Notes: - Uses workflow='auto_global' for bounds-aware global exploration - More thorough but slower than standard local optimization - Useful for multi-modal objective functions """ return nlsq_optimize( objective, parameters, workflow="auto_global", **kwargs, )
[docs] def nlsq_curve_fit( model_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], x_data: np.ndarray, y_data: np.ndarray, parameters: ParameterSet, auto_bounds: bool = False, stability: str | bool = False, fallback: bool = False, compute_diagnostics: bool = False, multistart: bool = False, n_starts: int = 10, workflow: str = "auto", **kwargs, ) -> OptimizationResult: """Curve fitting using NLSQ 0.6.6 curve_fit() API with advanced features. This function provides access to NLSQ 0.6.6's enhanced curve_fit() features including auto-bounds, stability checks, fallback strategies, model diagnostics, and workflow selection. It returns an OptimizationResult with CurveFitResult-compatible statistical properties (r_squared, rmse, aic, bic, prediction_interval, etc.). Args: model_fn: Model function f(x, params_array) -> y_pred. Takes x_data and parameter array, returns predictions. x_data: Independent variable data y_data: Dependent variable data (observations) parameters: ParameterSet with initial values and bounds auto_bounds: Enable automatic parameter bounds inference (default: False). When True, reasonable bounds are inferred based on data characteristics. stability: Numerical stability checks (default: False). - 'auto': Check and automatically fix stability issues - 'check': Check and warn but don't fix - False: Skip stability checks fallback: Enable automatic fallback strategies (default: False). When True, tries alternative approaches if optimization fails. compute_diagnostics: Compute model health diagnostics (default: False). When True, result includes identifiability analysis, gradient health, etc. multistart: Enable multi-start optimization (default: False). When True, explores multiple starting points to find global optimum. n_starts: Number of starting points for multi-start (default: 10). workflow: NLSQ 0.6.6 workflow selection (default: "auto"): - "auto": Memory-aware local optimization (default) - "auto_global": Global optimization with bounds exploration - "hpc": HPC mode with checkpointing support **kwargs: Additional arguments passed to nlsq.curve_fit() Returns: OptimizationResult with CurveFitResult-compatible statistical properties: - r_squared, adj_r_squared, rmse, mae, aic, bic - confidence_intervals(alpha) method - prediction_interval(x_new, alpha) method (NLSQ 0.6.6 native) - get_parameter_uncertainties() method Example: >>> from rheojax.core.parameters import ParameterSet >>> from rheojax.utils.optimization import nlsq_curve_fit >>> >>> def model(x, params): ... a, b = params ... return a * np.exp(-b * x) >>> >>> params = ParameterSet() >>> params.add("a", value=1.0, bounds=(0, 10)) >>> params.add("b", value=0.5, bounds=(0, 5)) >>> >>> result = nlsq_curve_fit( ... model, x_data, y_data, params, ... auto_bounds=True, ... stability='auto', ... fallback=True, ... compute_diagnostics=True, ... ) >>> print(f"R² = {result.r_squared:.4f}") >>> print(f"RMSE = {result.rmse:.4f}") >>> ci = result.confidence_intervals(alpha=0.95) >>> >>> # Prediction intervals (NLSQ 0.6.6) >>> pi = result.prediction_interval(x_new, alpha=0.95) >>> print(f"95% PI: [{pi[0, 0]:.3f}, {pi[0, 1]:.3f}]") Notes: - This function uses nlsq.curve_fit() directly (not LeastSquares.least_squares()) - The model function signature is ``f(x, params_array)`` not ``f(x, *params)`` - Results delegate to native CurveFitResult for prediction_interval() calls - Results include all CurveFitResult properties for model comparison """ import nlsq as nlsq_module logger.info( "Starting curve fit", n_params=len(parameters), n_data=len(x_data), auto_bounds=auto_bounds, stability=stability, multistart=multistart, workflow=workflow, ) # Extract p0 and bounds from ParameterSet p0, (lower, upper) = _extract_bounds(parameters) # Convert x_data and y_data to numpy arrays x_data_np = np.asarray(x_data, dtype=np.float64) y_data_np = np.asarray(y_data) # Preserve complex type if present # Create wrapper function f(x, *params) -> y for nlsq.curve_fit # NLSQ curve_fit expects f(x, p0, p1, ...) not f(x, params_array) def f_wrapper(x, *params_tuple): params_array = jnp.asarray(params_tuple, dtype=jnp.float64) return model_fn(x, params_array) # Build kwargs for nlsq.curve_fit curve_fit_kwargs: dict[str, Any] = { "p0": p0, "bounds": (lower, upper), "auto_bounds": auto_bounds, "stability": stability, "fallback": fallback, "compute_diagnostics": compute_diagnostics, "multistart": multistart, "n_starts": n_starts, } # Add workflow parameter (NLSQ 0.6.6) if workflow != "auto": curve_fit_kwargs["workflow"] = workflow # OPT-002: Filter RheoJAX-specific kwargs before forwarding to nlsq.curve_fit clean_kwargs = { k: v for k, v in kwargs.items() if k not in _RHEOJAX_RESERVED_KWARGS } curve_fit_kwargs.update(clean_kwargs) try: # Call nlsq.curve_fit() - returns CurveFitResult (tuple unpacking compatible) curve_fit_result = nlsq_module.curve_fit( f_wrapper, x_data_np, y_data_np, **curve_fit_kwargs ) # CurveFitResult supports both tuple unpacking and attribute access # We need to check if it's a tuple (popt, pcov) or CurveFitResult object is_curve_fit_result = not isinstance(curve_fit_result, tuple) if isinstance(curve_fit_result, tuple): popt, pcov = curve_fit_result else: # It's a CurveFitResult object popt = np.asarray(curve_fit_result.popt) pcov = curve_fit_result.pcov # Compute residuals and y_pred at optimal point (for complex data handling) y_pred = model_fn(x_data_np, popt) y_pred_np = np.asarray(y_pred) # Compute residuals (handle complex data) if np.iscomplexobj(y_data_np): if np.iscomplexobj(y_pred_np): # Both complex: residuals for real and imaginary parts residuals_real = np.real(y_data_np) - np.real(y_pred_np) residuals_imag = np.imag(y_data_np) - np.imag(y_pred_np) residuals = np.concatenate([residuals_real, residuals_imag]) else: # Complex data, real pred: use magnitude residuals = np.abs(y_data_np) - y_pred_np else: if np.iscomplexobj(y_pred_np): # Real data, complex pred: use magnitude of pred residuals = y_data_np - np.abs(y_pred_np) else: # Both real residuals = y_data_np - y_pred_np # Create OptimizationResult - use factory for native CurveFitResult delegation if is_curve_fit_result and hasattr(curve_fit_result, "prediction_interval"): # NLSQ 0.6.6+ CurveFitResult with native property delegation result = OptimizationResult.from_curve_fit_result( curve_fit_result, y_data=y_data_np, model_fn=model_fn, x_data=x_data_np, ) # Override residuals for complex data handling result.residuals = residuals result.fun = float(np.sum(residuals**2)) result.cost = result.fun else: # Legacy tuple result or no native delegation success = ( True if isinstance(curve_fit_result, tuple) else getattr(curve_fit_result, "success", True) ) message = ( "Optimization converged successfully" if isinstance(curve_fit_result, tuple) else getattr(curve_fit_result, "message", "Converged") ) nfev = ( 0 if isinstance(curve_fit_result, tuple) else getattr(curve_fit_result, "nfev", 0) ) njev = ( 0 if isinstance(curve_fit_result, tuple) else getattr(curve_fit_result, "njev", 0) ) diagnostics = ( None if isinstance(curve_fit_result, tuple) else getattr(curve_fit_result, "diagnostics", None) ) result = OptimizationResult( x=np.asarray(popt, dtype=np.float64), fun=float(np.sum(residuals**2)), jac=None, # curve_fit doesn't return Jacobian directly pcov=np.asarray(pcov, dtype=np.float64) if pcov is not None else None, success=bool(success), message=str(message), nit=int(nfev), nfev=int(nfev), njev=int(njev), optimality=None, active_mask=None, cost=float(np.sum(residuals**2)), grad=None, nlsq_result=None, residuals=residuals, y_data=y_data_np, # UTIL-011: Use len(y_data_np) not len(residuals) — for complex data, # residuals is 2N (real + imag parts concatenated) while actual # observations are N. AIC/BIC must use the observation count N. n_data=len(y_data_np), diagnostics=diagnostics, ) # Update ParameterSet with optimal values parameters.set_values(result.x) logger.info( "Curve fit completed successfully", r_squared=result.r_squared, rmse=result.rmse, success=result.success, ) logger.debug( "Curve fit results", popt=result.x.tolist(), rss=float(result.fun), aic=result.aic, bic=result.bic, ) return result except (RuntimeError, ValueError, FloatingPointError, OverflowError) as e: logger.warning( "nlsq.curve_fit() failed, falling back to nlsq_optimize", error=str(e), exc_info=True, ) # Fallback to nlsq_optimize with residual-based objective objective = create_least_squares_objective(model_fn, x_data_np, y_data_np) if multistart: result = nlsq_multistart_optimize( objective, parameters, n_starts=n_starts, **kwargs ) else: result = nlsq_optimize(objective, parameters, **kwargs) # Annotate that this result came from a curve_fit→nlsq_optimize fallback result.message = ( f"[curve_fit→nlsq_optimize fallback] {getattr(result, 'message', '')}" ) # Preserve y_data for R² calculation (not set by nlsq_optimize fallback) result.y_data = y_data_np result._model_fn = model_fn result._x_data = x_data_np return result
[docs] def optimize_with_bounds( objective: Callable[[np.ndarray], float | np.ndarray], x0: np.ndarray, bounds: list[tuple[float | None, float | None]], use_jax: bool = True, **kwargs, ) -> OptimizationResult: """Optimize objective function with parameter bounds. Lower-level optimization function that works with arrays instead of ParameterSet. Useful for custom optimization workflows. Args: objective: Objective function to minimize x0: Initial parameter values bounds: List of (min, max) tuples for each parameter use_jax: Whether to use JAX for gradients (default: True) **kwargs: Additional arguments passed to nlsq_optimize Returns: OptimizationResult with optimal parameters Example: >>> def objective(x): ... return x[0]**2 + x[1]**2 >>> result = optimize_with_bounds( ... objective, ... x0=np.array([1.0, 1.0]), ... bounds=[(0, 5), (0, 5)] ... ) """ # Create temporary ParameterSet for interface consistency params = ParameterSet() for i, (val, bound) in enumerate(zip(x0, bounds, strict=True)): # Handle None values in bounds bounds_tuple: tuple[float, float] | None = None if bound is not None: if bound[0] is not None and bound[1] is not None: bounds_tuple = (float(bound[0]), float(bound[1])) params.add(name=f"p{i}", value=val, bounds=bounds_tuple) # Use main optimization function return nlsq_optimize(objective, params, use_jax=use_jax, **kwargs)
def fit_with_nlsq( residual_fn: Callable[[np.ndarray], np.ndarray], x0: np.ndarray, bounds: tuple[np.ndarray, np.ndarray] | None = None, **kwargs, ) -> OptimizationResult: """Fit using nonlinear least squares with residual function. Convenience function for fitting models using a residual function that takes parameter array and returns residual vector. Args: residual_fn: Function that takes parameter array and returns residuals x0: Initial parameter values as 1D array bounds: Tuple of (lower, upper) bound arrays, or None for unbounded **kwargs: Additional arguments passed to optimize_with_bounds Returns: OptimizationResult with optimal parameters in .x attribute """ # Convert bounds format: (lower_array, upper_array) -> list of tuples if bounds is not None: lower, upper = bounds bounds_list: list[tuple[float | None, float | None]] = [ (float(lo), float(hi)) for lo, hi in zip(lower, upper, strict=True) ] else: bounds_list = [(None, None)] * len(x0) return optimize_with_bounds(residual_fn, x0, bounds_list, **kwargs)
[docs] def residual_sum_of_squares( y_true: ArrayLike, y_pred: ArrayLike, normalize: bool = True ) -> float: """Compute residual sum of squares (RSS). Handles both real and complex data correctly. For complex data (e.g., oscillatory shear with G' + iG"), computes RSS for both real and imaginary parts separately and returns the sum. Args: y_true: True values (real or complex) y_pred: Predicted values (real or complex) normalize: Whether to normalize by y_true (relative error) Returns: RSS value (scalar, maintains float64 precision) Example: >>> y_true = np.array([1.0, 2.0, 3.0]) >>> y_pred = np.array([1.1, 2.1, 2.9]) >>> rss = residual_sum_of_squares(y_true, y_pred) """ # Use JAX operations if inputs are JAX arrays for gradient support # OPT-013: Use hasattr check instead of isinstance(x, jnp.ndarray) if hasattr(y_pred, "devices") or hasattr(y_true, "devices"): # Convert to JAX arrays (preserving complex type) y_true_jax = jnp.asarray(y_true) y_pred_jax = jnp.asarray(y_pred) # Check if data is complex y_true_is_complex = jnp.iscomplexobj(y_true_jax) y_pred_is_complex = jnp.iscomplexobj(y_pred_jax) if y_pred_is_complex: if y_true_is_complex: # Both complex: fit real and imaginary parts separately residuals_real = jnp.real(y_pred_jax) - jnp.real(y_true_jax) residuals_imag = jnp.imag(y_pred_jax) - jnp.imag(y_true_jax) if normalize: residuals_real = residuals_real / jnp.maximum( jnp.abs(jnp.real(y_true_jax)), 1e-10 ) residuals_imag = residuals_imag / jnp.maximum( jnp.abs(jnp.imag(y_true_jax)), 1e-10 ) rss = jnp.sum(residuals_real**2) + jnp.sum(residuals_imag**2) else: # Complex predictions, real data: fit to magnitude y_pred_magnitude = jnp.abs(y_pred_jax) y_true_jax = jnp.asarray(y_true_jax, dtype=jnp.float64) residuals = y_pred_magnitude - y_true_jax if normalize: residuals = residuals / jnp.maximum(jnp.abs(y_true_jax), 1e-10) rss = jnp.sum(residuals**2) else: # Real predictions if y_true_is_complex: # Real predictions, complex data: fit to magnitude of data y_true_magnitude = jnp.abs(y_true_jax) y_pred_jax = jnp.asarray(y_pred_jax, dtype=jnp.float64) residuals = y_pred_jax - y_true_magnitude if normalize: residuals = residuals / jnp.maximum(y_true_magnitude, 1e-10) rss = jnp.sum(residuals**2) else: # Both real: standard case y_true_jax = jnp.asarray(y_true_jax, dtype=jnp.float64) y_pred_jax = jnp.asarray(y_pred_jax, dtype=jnp.float64) residuals = y_pred_jax - y_true_jax if normalize: residuals = residuals / jnp.maximum(jnp.abs(y_true_jax), 1e-10) rss = jnp.sum(residuals**2) # Return scalar JAX array, don't convert to Python float (breaks gradients) return rss else: # NumPy path y_true_np = np.asarray(y_true) y_pred_np = np.asarray(y_pred) # Check if data is complex y_true_is_complex = np.iscomplexobj(y_true_np) y_pred_is_complex = np.iscomplexobj(y_pred_np) if y_pred_is_complex: if y_true_is_complex: # Both complex: fit real and imaginary parts separately residuals_real = np.real(y_pred_np) - np.real(y_true_np) residuals_imag = np.imag(y_pred_np) - np.imag(y_true_np) if normalize: with np.errstate(divide="ignore", invalid="ignore"): residuals_real = residuals_real / np.maximum( np.abs(np.real(y_true_np)), 1e-10 ) residuals_imag = residuals_imag / np.maximum( np.abs(np.imag(y_true_np)), 1e-10 ) rss = float(np.sum(residuals_real**2) + np.sum(residuals_imag**2)) else: # Complex predictions, real data: fit to magnitude y_pred_magnitude = np.abs(y_pred_np) residuals = y_pred_magnitude - y_true_np if normalize: with np.errstate(divide="ignore", invalid="ignore"): residuals = residuals / np.maximum(np.abs(y_true_np), 1e-10) rss = float(np.sum(residuals**2)) else: # Real predictions if y_true_is_complex: # Real predictions, complex data: fit to magnitude of data y_true_magnitude = np.abs(y_true_np) residuals = y_pred_np - y_true_magnitude if normalize: with np.errstate(divide="ignore", invalid="ignore"): residuals = residuals / np.maximum(y_true_magnitude, 1e-10) rss = float(np.sum(residuals**2)) else: # Both real: standard case y_true_np = np.asarray(y_true_np, dtype=np.float64) y_pred_np = np.asarray(y_pred_np, dtype=np.float64) residuals = y_pred_np - y_true_np if normalize: with np.errstate(divide="ignore", invalid="ignore"): residuals = residuals / np.maximum(np.abs(y_true_np), 1e-10) rss = float(np.sum(residuals**2)) return rss
[docs] class ResidualFunction: """Callable wrapper for residual functions that carries normalization metadata. This replaces the fragile pattern of attaching ``_normalization_weights`` as a function attribute (which breaks if the function is wrapped by decorators, ``functools.wraps``, ``jax.jit``, etc.). The class is fully transparent to callers — it behaves like a plain function but safely exposes the weights. """ __slots__ = ("_fn", "_normalization_weights")
[docs] def __init__( self, fn: Callable[[np.ndarray], np.ndarray], normalization_weights: np.ndarray | None = None, ) -> None: self._fn = fn self._normalization_weights = normalization_weights
def __call__(self, params: np.ndarray) -> np.ndarray: return self._fn(params)
[docs] def create_least_squares_objective( model_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], x_data: np.ndarray, y_data: np.ndarray, normalize: bool = True, use_log_residuals: bool = False, ) -> ResidualFunction: """Create residual function for NLSQ least-squares fitting. IMPORTANT: This now returns a RESIDUAL FUNCTION (vector output), not a scalar objective. NLSQ minimizes sum(residuals²), so this provides per-point residuals to the optimizer, which enables proper gradient computation and weighting. For complex data (e.g., G* = G' + iG"), returns stacked real and imaginary residuals: [real₁, ..., real_n, imag₁, ..., imag_n] with shape (2N,). For real data, returns residuals with shape (N,). **Log-space residuals (NEW)**: For rheological data spanning many decades (e.g., mastercurves with 8+ decades), use `use_log_residuals=True` to compute residuals in log10 space. This gives equal weight to all frequency ranges and prevents optimizer bias toward high-modulus regions. Args: model_fn: Model function that takes (x_data, parameters) and returns predictions x_data: Independent variable data y_data: Dependent variable data (observations, may be complex) normalize: Whether to use relative error (default: True) use_log_residuals: Whether to compute residuals in log10 space (default: False). Recommended for data spanning >8 decades. Formula: ``residual = log10(abs(y_pred)) - log10(abs(y_data))`` Returns: Residual function that takes parameters and returns residual vector Example: >>> def linear_model(x, params): ... a, b = params ... return a * x + b >>> x = np.array([1, 2, 3, 4, 5]) >>> y = np.array([2.1, 4.0, 5.9, 8.1, 10.0]) >>> residual_fn = create_least_squares_objective(linear_model, x, y) >>> # Now use with nlsq_optimize - it receives proper residual vector >>> >>> # For mastercurve data (wide frequency range): >>> residual_fn_log = create_least_squares_objective( ... model_fn, omega, G_star, use_log_residuals=True ... ) """ # Convert to JAX arrays and detect if complex x_data_jax = jnp.asarray(x_data, dtype=jnp.float64) # Preserve complex type for y_data y_data_is_complex = jnp.iscomplexobj(y_data) or np.iscomplexobj(y_data) if y_data_is_complex: y_data_jax = jnp.asarray(y_data, dtype=jnp.complex128) else: y_data_jax = jnp.asarray(y_data, dtype=jnp.float64) # P2-Fit-1: Compute a relative normalization floor from the overall data # magnitude. At the elastic plateau G'' can be orders of magnitude below # G', so an absolute floor of 1e-10 Pa biases the optimizer. Using a # relative floor (1e-10 * max|y|) keeps the weighting proportional. _norm_floor = jnp.float64(1e-10) if normalize and not use_log_residuals: _max_modulus = jnp.max(jnp.abs(y_data_jax)) _norm_floor = jnp.maximum(jnp.float64(1e-10), jnp.float64(1e-10) * _max_modulus) # OPT-03: Determine the model's output format once at construction time # using jax.eval_shape (zero-cost abstract evaluation, no actual compute). # The three static flags below replace the per-call jnp.iscomplexobj / # y_pred.ndim checks that were executed on every optimizer iteration. # Fallback to dynamic dispatch if eval_shape fails (e.g. model has # side-effects or requires concrete values). _static_dispatch: str | None = None # "2d", "complex", "real", or None (dynamic) _static_y_data_is_2d: bool = y_data_jax.ndim == 2 and y_data_jax.shape[-1] == 2 try: # Use a single-element x probe and a single-element params probe. # eval_shape traces abstractly — values are irrelevant, only dtypes # and shapes propagate. The 1-element params probe may not match the # real parameter count, so only the output *dtype* and *ndim* are # trusted from this probe (not the full output shape). This is # sufficient for the dispatch decision (complex vs real vs 2D). _x_probe = ( x_data_jax[:1] if x_data_jax.ndim >= 1 and x_data_jax.shape[0] > 0 else x_data_jax ) _p_probe = jnp.zeros(1, dtype=jnp.float64) _out_shape = jax.eval_shape(model_fn, _x_probe, _p_probe) _out_is_complex = jnp.issubdtype(_out_shape.dtype, jnp.complexfloating) _out_is_2d = _out_shape.ndim == 2 and _out_shape.shape[-1] == 2 if _out_is_2d: _static_dispatch = "2d" elif _out_is_complex: _static_dispatch = "complex" else: _static_dispatch = "real" except Exception: # Dynamic dispatch fallback: eval_shape failed (model has side-effects, # requires concrete values, or uses Python control flow on shape). _static_dispatch = None def residuals(params: np.ndarray) -> np.ndarray: """Compute residual vector for all data points.""" # OPT-02: Avoid unconditional host→device transfer on every iteration. # NLSQ passes NumPy arrays; JAX's own gradient passes jax.Array. # isinstance check is a Python-level no-op for JAX arrays. if not isinstance(params, jax.Array): params_jax = jnp.asarray(params, dtype=jnp.float64) else: params_jax = params # Get model predictions y_pred = model_fn(x_data_jax, params_jax) # OPT-03: Use statically-determined dispatch path baked at closure # construction time. Falls back to the original dynamic check when # eval_shape could not determine the output format. if _static_dispatch is not None: y_pred_is_2d = _static_dispatch == "2d" y_pred_is_complex = _static_dispatch == "complex" else: # Dynamic fallback (only reached when eval_shape failed) y_pred_is_complex = jnp.iscomplexobj(y_pred) y_pred_is_2d = y_pred.ndim == 2 and y_pred.shape[-1] == 2 if y_pred_is_2d: # Case 1: 2D [G', G"] format (e.g., FractionalZenerSolidSolid) if y_data_is_complex: # Fit to real and imaginary parts separately if use_log_residuals: # Log-space residuals: abs() intentionally strips sign because # G' and G'' are physically positive; normalize is implicit resid_real = jnp.log10( jnp.maximum(jnp.abs(y_pred[:, 0]), 1e-20) ) - jnp.log10(jnp.maximum(jnp.abs(jnp.real(y_data_jax)), 1e-20)) resid_imag = jnp.log10( jnp.maximum(jnp.abs(y_pred[:, 1]), 1e-20) ) - jnp.log10(jnp.maximum(jnp.abs(jnp.imag(y_data_jax)), 1e-20)) else: resid_real = y_pred[:, 0] - jnp.real(y_data_jax) resid_imag = y_pred[:, 1] - jnp.imag(y_data_jax) if normalize: resid_real = resid_real / jnp.maximum( jnp.abs(jnp.real(y_data_jax)), _norm_floor ) resid_imag = resid_imag / jnp.maximum( jnp.abs(jnp.imag(y_data_jax)), _norm_floor ) return jnp.concatenate([resid_real, resid_imag]) else: if _static_y_data_is_2d: # Both (N, 2): fit both columns independently (stacked residuals) if use_log_residuals: # Log-space residuals: abs() intentionally strips sign because # G' and G'' are physically positive; normalize is implicit resid_col0 = jnp.log10( jnp.maximum(jnp.abs(y_pred[:, 0]), 1e-20) ) - jnp.log10(jnp.maximum(jnp.abs(y_data_jax[:, 0]), 1e-20)) resid_col1 = jnp.log10( jnp.maximum(jnp.abs(y_pred[:, 1]), 1e-20) ) - jnp.log10(jnp.maximum(jnp.abs(y_data_jax[:, 1]), 1e-20)) else: resid_col0 = y_pred[:, 0] - y_data_jax[:, 0] resid_col1 = y_pred[:, 1] - y_data_jax[:, 1] if normalize: resid_col0 = resid_col0 / jnp.maximum( jnp.abs(y_data_jax[:, 0]), _norm_floor ) resid_col1 = resid_col1 / jnp.maximum( jnp.abs(y_data_jax[:, 1]), _norm_floor ) return jnp.concatenate([resid_col0, resid_col1]) else: # (N, 2) pred, (N,) data: fit to magnitude |G*| y_pred_magnitude = jnp.sqrt( y_pred[:, 0] ** 2 + y_pred[:, 1] ** 2 + 1e-30 ) _resid = y_pred_magnitude - y_data_jax if normalize: _resid = _resid / jnp.maximum(jnp.abs(y_data_jax), _norm_floor) return _resid elif y_pred_is_complex: # Case 2: Complex predictions (G' + iG") if y_data_is_complex: # Both complex: fit real and imaginary parts separately if use_log_residuals: # Log-space residuals for rheological data (mastercurves) # Use magnitudes to avoid log of negative numbers resid_real = jnp.log10( jnp.maximum(jnp.abs(jnp.real(y_pred)), 1e-20) ) - jnp.log10(jnp.maximum(jnp.abs(jnp.real(y_data_jax)), 1e-20)) resid_imag = jnp.log10( jnp.maximum(jnp.abs(jnp.imag(y_pred)), 1e-20) ) - jnp.log10(jnp.maximum(jnp.abs(jnp.imag(y_data_jax)), 1e-20)) # Note: normalize has no effect in log space (already relative) else: # Linear residuals (default) resid_real = jnp.real(y_pred) - jnp.real(y_data_jax) resid_imag = jnp.imag(y_pred) - jnp.imag(y_data_jax) if normalize: resid_real = resid_real / jnp.maximum( jnp.abs(jnp.real(y_data_jax)), _norm_floor ) resid_imag = resid_imag / jnp.maximum( jnp.abs(jnp.imag(y_data_jax)), _norm_floor ) return jnp.concatenate([resid_real, resid_imag]) else: # Complex predictions, real data: fit to magnitude |G*| # This is the common case for oscillation mode fitting y_pred_magnitude = jnp.abs(y_pred) _resid = y_pred_magnitude - y_data_jax if normalize: _resid = _resid / jnp.maximum(jnp.abs(y_data_jax), _norm_floor) return _resid else: # Case 3: Real predictions if y_data_is_complex: # Real predictions, complex data: this is unusual but handle it # Fit to magnitude of data y_data_magnitude = jnp.abs(y_data_jax) if use_log_residuals: # Log-space residuals _resid = jnp.log10(jnp.maximum(jnp.abs(y_pred), 1e-20)) - jnp.log10( jnp.maximum(y_data_magnitude, 1e-20) ) else: _resid = y_pred - y_data_magnitude if normalize: _resid = _resid / jnp.maximum(y_data_magnitude, _norm_floor) return _resid else: # Both real: standard case if use_log_residuals: # Log-space residuals for rheological data # Handle both positive and negative values by using absolute value _resid = jnp.log10(jnp.maximum(jnp.abs(y_pred), 1e-20)) - jnp.log10( jnp.maximum(jnp.abs(y_data_jax), 1e-20) ) else: _resid = y_pred - y_data_jax if normalize: _resid = _resid / jnp.maximum(jnp.abs(y_data_jax), _norm_floor) return _resid # P1-6: Compute normalization weights so that downstream # OptimizationResult can un-normalize for correct R²/AIC/BIC. # OPT-WGT-001: Compute weights in pure NumPy from the original y_data to # avoid unnecessary JAX dispatch + host-device transfer. The _norm_floor # value is already available as a Python float (converted from the JAX # scalar via float()). This saves ~10-25 µs per create_least_squares_objective # call (measured: 10.7 µs → 0.75 µs for real data; 26.6 µs → 2.5 µs for # complex data). weights: np.ndarray | None = None if normalize and not use_log_residuals: # Use float(_norm_floor) to get the Python scalar; avoids a JAX roundtrip. _norm_floor_f = float(_norm_floor) # Work from the original y_data (possibly still numpy) to avoid a device # transfer. np.asarray on a numpy array is zero-copy. y_data_np = np.asarray(y_data) if y_data_is_complex: w_real = np.maximum(np.abs(np.real(y_data_np)), _norm_floor_f) w_imag = np.maximum(np.abs(np.imag(y_data_np)), _norm_floor_f) weights = np.concatenate([w_real, w_imag]) elif y_data_np.ndim == 2 and y_data_np.shape[-1] == 2: w0 = np.maximum(np.abs(y_data_np[:, 0]), _norm_floor_f) w1 = np.maximum(np.abs(y_data_np[:, 1]), _norm_floor_f) weights = np.concatenate([w0, w1]) else: weights = np.maximum(np.abs(y_data_np), _norm_floor_f) return ResidualFunction(residuals, normalization_weights=weights)
# Convenience aliases for compatibility with different naming conventions optimize = nlsq_optimize # Generic name fit_parameters = nlsq_optimize # More descriptive for model fitting __all__ = [ "OptimizationResult", "ResidualFunction", "make_fd_differentiable", "nlsq_optimize", "nlsq_multistart_optimize", "nlsq_curve_fit", "optimize_with_bounds", "residual_sum_of_squares", "create_least_squares_objective", "optimize", "fit_parameters", ]