# Source code for rheojax.core.parameters

"""Parameter management system for models and transforms.

This module provides classes for managing parameters, constraints,
and optimization support for rheological models.
"""

from __future__ import annotations

import math
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

import numpy as np

from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger

# Safe JAX import (enforces float64)
# NOTE(review): safe_import_jax is defined in rheojax.core.jax_config and is
# opaque from here; it returns the (jax, jax.numpy) pair used below.
jax, jnp = safe_import_jax()
# Set unconditionally: the import above is not guarded, so reaching this line
# means the call returned. Presumably the flag is kept so JAX-dependent code
# paths can still be gated on it — confirm against callers.
HAS_JAX = True

# Module-level logger
# Calls in this module pass structured keyword fields (parameter=..., value=...).
logger = get_logger(__name__)


# Typing helper: static type checkers see the real jax.numpy module, while at
# runtime the name is bound to NumPy so that no extra import is performed.
if TYPE_CHECKING:  # pragma: no cover - typing helper only
    import jax.numpy as jnp_typing
else:
    jnp_typing = np


# Alias for the array-like inputs accepted by the helpers in this module:
# NumPy arrays, JAX arrays (via jnp_typing), and plain lists/tuples.
type ArrayLike = np.ndarray | jnp_typing.ndarray | list | tuple


def _coerce_array(values: ArrayLike) -> np.ndarray:
    """Convert array-like inputs to NumPy arrays without altering callers."""
    if isinstance(values, np.ndarray):
        return values
    if HAS_JAX and isinstance(values, jnp.ndarray):
        return np.asarray(values)
    return np.asarray(values)


@dataclass
class ParameterConstraint:
    """Constraint on a parameter value.

    Attributes:
        type: Constraint kind; one of 'bounds', 'positive', 'integer',
            'fixed', 'relative', 'custom'.
        min_value: Lower limit (used by 'bounds'); ``None`` means unbounded.
        max_value: Upper limit (used by 'bounds'); ``None`` means unbounded.
        value: Required value (used by 'fixed').
        relation: 'less_than', 'greater_than' or 'equal' (used by 'relative').
        other_param: Name of the parameter compared against (used by
            'relative'); looked up in the validation context.
        validator: Predicate applied to the value (used by 'custom').
    """

    type: str  # 'bounds', 'positive', 'integer', 'fixed', 'relative', 'custom'
    min_value: float | None = None
    max_value: float | None = None
    value: float | None = None  # For fixed constraints
    relation: str | None = None  # For relative constraints
    other_param: str | None = None  # For relative constraints
    validator: Callable[[float], bool] | None = None  # For custom constraints

    def to_dict(self) -> dict[str, Any]:
        """Serialize constraint to a dictionary.

        Returns:
            A dict always containing ``"type"``. ``min_value``/``max_value``
            are emitted only for 'bounds' constraints; ``relation``,
            ``other_param`` and ``value`` only when they are not ``None``.
            The ``validator`` callable is never serialized.
        """
        d: dict[str, Any] = {"type": self.type}
        if self.type == "bounds":
            d["min_value"] = self.min_value
            d["max_value"] = self.max_value
        for key in ("relation", "other_param", "value"):
            field_value = getattr(self, key)
            if field_value is not None:
                d[key] = field_value
        return d

    def validate(self, value: float, context: dict[str, float] | None = None) -> bool:
        """Check if value satisfies the constraint.

        Args:
            value: Value to validate
            context: Context with other parameter values (for relative
                constraints)

        Returns:
            True if constraint is satisfied. A 'relative' constraint that
            cannot be checked (no context, or ``other_param`` missing from
            it) and a 'custom' constraint without a validator both pass.

        Raises:
            ValueError: If ``type`` is not a known constraint type, or if a
                'relative' constraint is evaluated with an unknown
                ``relation`` (previously this silently passed, so a typo
                disabled the constraint).
        """
        # NaN/Inf bypass IEEE 754 comparisons — reject unconditionally
        if not np.isfinite(value):
            return False

        kind = self.type

        if kind == "bounds":
            if self.min_value is not None and value < self.min_value:
                logger.debug(
                    "Bound check failed: value below minimum",
                    constraint_type=self.type,
                    value=value,
                    min_value=self.min_value,
                )
                return False
            if self.max_value is not None and value > self.max_value:
                logger.debug(
                    "Bound check failed: value above maximum",
                    constraint_type=self.type,
                    value=value,
                    max_value=self.max_value,
                )
                return False
            return True

        if kind == "positive":
            return value > 0

        if kind == "integer":
            return float(value).is_integer()

        if kind == "fixed":
            return value == self.value

        if kind == "relative":
            if not context or self.other_param not in context:
                return True  # Can't validate without context
            other_value = context[self.other_param]
            if self.relation == "less_than":
                return value < other_value
            if self.relation == "greater_than":
                return value > other_value
            if self.relation == "equal":
                return value == other_value
            # Mirror the unknown-type handling below instead of passing.
            raise ValueError(f"Unknown relation: {self.relation!r}")

        if kind == "custom":
            return self.validator(value) if self.validator else True

        raise ValueError(f"Unknown constraint type: {self.type!r}")
class Parameter:
    """Single parameter with value, bounds, and metadata.

    A Parameter represents a model parameter with support for bounds
    validation, units tracking, and constraint enforcement. Parameters can be
    used in both NLSQ optimization and Bayesian inference workflows.

    Attributes:
        name: Parameter identifier used for lookup and serialization.
        value: Current parameter value (may be None if unset).
        bounds: Lower and upper bounds as tuple (min, max).
        units: Physical units string for display (e.g., "Pa", "s").
        description: Human-readable description.
        constraints: List of ParameterConstraint objects for validation.
        prior: Optional prior specification dict (serialized as-is).

    Example:
        >>> param = Parameter("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
        >>> param.value = 2e5  # Validated against bounds
        >>> param.validate(param.value)
        True
    """

    __slots__ = (
        "name",
        "_bounds",
        "units",
        "description",
        "constraints",
        "_value",
        "_clamp_on_set",
        "_was_clamped",
        "prior",
    )

    def __init__(
        self,
        name: str,
        value: float | None = None,
        bounds: tuple[float, float] | None = None,
        units: str | None = None,
        description: str | None = None,
        constraints: list[ParameterConstraint] | None = None,
    ) -> None:
        """Create a parameter.

        Args:
            name: Parameter identifier.
            value: Initial value; clamped into ``bounds`` (with a
                RuntimeWarning) if it lies outside them.
            bounds: Optional (min, max) tuple; min must be <= max.
            units: Optional units string.
            description: Optional human-readable description.
            constraints: Optional constraints; the list is copied.

        Raises:
            ValueError: If bounds are inverted, or the value is non-numeric
                or non-finite.
        """
        self.name = name
        self._bounds: tuple[float, float] | None = bounds
        self.units = units
        self.description = description
        self.constraints = list(constraints) if constraints else []
        self._value: float | None = None
        self._clamp_on_set = False
        self._was_clamped = False
        self.prior: dict[str, Any] | None = None

        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Creating parameter",
                parameter=name,
                bounds=bounds,
                units=units,
            )

        self._initialize(value)

    @property
    def bounds(self) -> tuple[float, float] | None:
        """Get parameter bounds."""
        return self._bounds

    @bounds.setter
    def bounds(self, new_bounds: tuple[float, float] | None) -> None:
        """Set parameter bounds and sync any bounds constraint.

        No ordering check is performed here and the current value is not
        re-validated against the new bounds.
        """
        self._bounds = new_bounds
        # Sync ALL bounds constraints — iterate the full list so that every
        # "bounds" constraint is updated, not only the first one.
        if not hasattr(self, "constraints"):
            return
        for c in self.constraints:
            if getattr(c, "type", None) != "bounds":
                continue
            if new_bounds is not None:
                c.min_value = new_bounds[0]
                c.max_value = new_bounds[1]
            else:
                c.min_value = None
                c.max_value = None

    def _initialize(self, value: float | None) -> None:
        """Normalize bounds, register the bounds constraint, set the value.

        Raises:
            ValueError: If the lower bound exceeds the upper bound.
        """
        if self.bounds is not None:
            lower, upper = self.bounds
            lower = float(lower)
            upper = float(upper)
            if lower > upper:
                logger.error(
                    "Invalid bounds: lower > upper",
                    parameter=self.name,
                    bounds=(lower, upper),
                )
                raise ValueError(
                    f"Invalid bounds for parameter '{self.name}': {(lower, upper)}"
                )
            self.bounds = (lower, upper)

            # Add bounds as constraint if specified and not already present.
            # R12-B-015: the enclosing check is `is not None`, not truthiness,
            # so a bounds tuple of (0.0, 0.0) — while degenerate — is handled.
            if not any(c.type == "bounds" for c in self.constraints):
                self.constraints.insert(
                    0,
                    ParameterConstraint(
                        type="bounds",
                        min_value=self.bounds[0],
                        max_value=self.bounds[1],
                    ),
                )

        if value is not None:
            # Clamping applies only to this initial assignment; the flag is
            # reset even if the setter raises.
            self._clamp_on_set = True
            try:
                self.value = value
            finally:
                self._clamp_on_set = False

    @property
    def value(self) -> float | None:
        """Get parameter value."""
        return self._value

    @value.setter
    def value(self, val: float | None) -> None:
        """Set parameter value with validation.

        ``None`` clears the value. Otherwise the value is converted with
        ``float()``, must be finite, and must lie within ``bounds``. During
        construction an out-of-bounds value is clamped with a RuntimeWarning
        instead of raising.

        Raises:
            TypeError: If ``val`` is a JAX tracer.
            ValueError: If ``val`` is non-numeric, non-finite, or (after
                construction) out of bounds.
        """
        if val is None:
            self._value = None
            self._was_clamped = False
            return

        # R5-JAX-002: Use isinstance check instead of hasattr duck-typing.
        # The old guard (hasattr "aval") also rejected concrete jax.Array
        # scalars which have .aval but are NOT tracers.
        # Uses the module-level `jax` instead of re-importing per assignment.
        if HAS_JAX and isinstance(val, jax.core.Tracer):
            raise TypeError(
                "Cannot set parameter value to a JAX traced value. "
                "Call set_value() outside of @jax.jit."
            )

        try:
            numeric_val = float(val)
        except (TypeError, ValueError) as exc:
            logger.error(
                "Failed to convert value to numeric",
                parameter=self.name,
                value=val,
                exc_info=True,
            )
            raise ValueError(
                f"Parameter '{self.name}' requires a numeric value"
            ) from exc

        if not np.isfinite(numeric_val):
            logger.error(
                "Non-finite value received",
                parameter=self.name,
                value=numeric_val,
            )
            raise ValueError(f"Parameter '{self.name}' received non-finite value")

        clamped_during_init = False
        if self.bounds is not None:
            lower, upper = self.bounds
            _debug = logger.isEnabledFor(10)  # logging.DEBUG == 10
            if _debug:
                logger.debug(
                    "Bound check",
                    parameter=self.name,
                    value=numeric_val,
                    bounds=self.bounds,
                )
            below = numeric_val < lower
            above = (not below) and numeric_val > upper
            if self._clamp_on_set:
                if below or above:
                    side = "below" if below else "above"
                    edge = "lower" if below else "upper"
                    limit = lower if below else upper
                    warnings.warn(
                        f"Parameter '{self.name}' initialized {side} bounds; clamped to {limit}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    if _debug:
                        logger.debug(
                            f"Value clamped to {edge} bound",
                            parameter=self.name,
                            original_value=numeric_val,
                            clamped_value=limit,
                        )
                    numeric_val = limit
                    clamped_during_init = True
            elif below or above:
                logger.error(
                    "Value out of bounds",
                    parameter=self.name,
                    value=numeric_val,
                    bounds=self.bounds,
                )
                raise ValueError(f"Value {numeric_val} out of bounds {self.bounds}")

        # R12-B-017: _was_clamped is only meaningful at initialization time
        # (when _clamp_on_set=True). For all subsequent assignments it is
        # reset to False because clamping is not applied — out-of-bounds
        # values raise ValueError instead.
        self._was_clamped = clamped_during_init if self._clamp_on_set else False
        self._value = numeric_val

    @property
    def was_clamped(self) -> bool:
        """Return True if the last assignment clamped the value."""
        return self._was_clamped

    def validate(self, value: float, context: dict[str, float] | None = None) -> bool:
        """Validate value against all constraints.

        Args:
            value: Value to validate
            context: Context with other parameter values

        Returns:
            True if all constraints are satisfied
        """
        for constraint in self.constraints:
            if not constraint.validate(value, context):
                logger.debug(
                    "Constraint validation failed",
                    parameter=self.name,
                    value=value,
                    constraint_type=constraint.type,
                )
                return False
        return True

    def __hash__(self) -> int:
        """Make Parameter hashable for use as dict keys.

        Returns:
            Hash of (name, bounds, units); the value is excluded. Note that
            bounds can be reassigned, which changes the hash.
        """
        return hash((self.name, self.bounds, self.units))

    def __eq__(self, other: object) -> bool:
        """Check equality with another Parameter.

        Matches __hash__: identity-based on (name, bounds, units). Value is
        excluded because it changes during fitting while the parameter
        identity remains the same.

        Args:
            other: Object to compare with

        Returns:
            True if parameters have the same identity; NotImplemented for
            non-Parameter operands.
        """
        if not isinstance(other, Parameter):
            return NotImplemented
        return (self.name, self.bounds, self.units) == (
            other.name,
            other.bounds,
            other.units,
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary representation.

        ``prior`` and ``constraints`` are included only when set/non-empty.
        """
        d: dict[str, Any] = {
            "name": self.name,
            "value": self.value,
            "bounds": self.bounds,
            "units": self.units,
            "description": self.description,
        }
        if self.prior is not None:
            d["prior"] = self.prior
        if self.constraints:
            d["constraints"] = [c.to_dict() for c in self.constraints]
        return d

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Parameter:
        """Create from dictionary representation.

        Custom constraint validators are not serialized and therefore are
        not restored.

        Args:
            data: Mapping as produced by ``to_dict()``; ``"name"`` is required.

        Returns:
            The reconstructed Parameter.
        """
        raw_bounds = data.get("bounds")
        param = cls(
            name=data["name"],
            value=data.get("value"),
            bounds=tuple(raw_bounds) if raw_bounds else None,
            units=data.get("units"),
            description=data.get("description"),
        )

        for c_data in data.get("constraints", ()):
            # Skip a serialized bounds constraint only when __init__ already
            # created one from bounds=; appending another would duplicate it.
            # Without a bounds tuple the constraint must be kept, otherwise
            # it is lost on round-trip.
            if c_data["type"] == "bounds" and any(
                c.type == "bounds" for c in param.constraints
            ):
                continue
            param.constraints.append(
                ParameterConstraint(
                    type=c_data["type"],
                    min_value=c_data.get("min_value"),
                    max_value=c_data.get("max_value"),
                    relation=c_data.get("relation"),
                    other_param=c_data.get("other_param"),
                    value=c_data.get("value"),
                )
            )

        if data.get("prior") is not None:
            param.prior = data["prior"]

        return param
class ParameterSet:
    """Collection of parameters for a model or transform.

    A ParameterSet manages multiple Parameter objects with dict-like access,
    batch operations, and serialization support. Insertion order is kept and
    used by all array-based operations.

    Key Features:
        - Dict-like access: ``params["G0"]`` or ``params.get("G0")``
        - Batch operations: ``get_values()``, ``set_values()``, ``get_bounds()``
        - Unpack helper: ``G0, eta = params.unpack("G0", "eta")``
        - Serialization: ``to_dict()`` / ``from_dict()``

    Example:
        >>> params = ParameterSet()
        >>> _ = params.add("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
        >>> _ = params.add("eta", value=1e3, bounds=(1e-3, 1e9), units="Pa*s")
        >>> G0, eta = params.unpack("G0", "eta")
        >>> print(f"G0={G0:.2e}, eta={eta:.2e}")
        G0=1.00e+05, eta=1.00e+03
    """

    __slots__ = ("_parameters", "_order", "_has_relative_constraints")

    def __init__(self):
        """Initialize empty parameter set."""
        self._parameters: dict[str, Parameter] = {}
        self._order: list[str] = []
        self._has_relative_constraints: bool = False
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug("ParameterSet created")

    def add(
        self,
        name: str,
        value: float | None = None,
        bounds: tuple[float, float] | None = None,
        units: str | None = None,
        description: str | None = None,
        constraints: list[ParameterConstraint] | None = None,
        overwrite: bool = False,
    ) -> Parameter:
        """Add a parameter to the set.

        Args:
            name: Parameter name
            value: Initial value
            bounds: Value bounds (min, max)
            units: Parameter units
            description: Parameter description
            constraints: List of constraints
            overwrite: If True, silently overwrite an existing parameter
                without emitting a warning. Default is False (warns on
                overwrite).

        Returns:
            The created Parameter object
        """
        _debug = logger.isEnabledFor(10)  # logging.DEBUG == 10
        if _debug:
            logger.debug(
                "Adding parameter to set",
                operation="add",
                parameter=name,
                value=value,
                bounds=bounds,
            )

        # R8-PARAMS-002: warn on silent overwrite (unless overwrite=True is explicit)
        # R10-PARAMS-001: expose overwrite flag so callers can suppress the warning
        if name in self._parameters and not overwrite:
            warnings.warn(
                f"Parameter '{name}' already exists and will be overwritten",
                stacklevel=2,
            )

        param = Parameter(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            description=description,
            constraints=constraints or [],
        )
        self._parameters[name] = param
        # An overwritten parameter keeps its original position.
        if name not in self._order:
            self._order.append(name)

        # Track whether any relative constraints exist
        if constraints and any(c.type == "relative" for c in constraints):
            self._has_relative_constraints = True

        if _debug:
            logger.debug(
                "Parameter added",
                operation="add",
                params=list(self._parameters.keys()),
            )
        return param

    def get(self, name: str) -> Parameter | None:
        """Get a parameter by name.

        Args:
            name: Parameter name

        Returns:
            Parameter object or None if not found
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Getting parameter",
                operation="get",
                parameter=name,
            )
        return self._parameters.get(name)

    def _context(self) -> dict[str, float]:
        """Return name → value for every parameter that has a value set."""
        return {
            p.name: p.value for p in self._parameters.values() if p.value is not None
        }

    def set_value(self, name: str, value: float):
        """Set parameter value.

        Args:
            name: Parameter name
            value: New value

        Raises:
            KeyError: If parameter not found
            ValueError: If value violates constraints
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting parameter value",
                operation="set_value",
                parameter=name,
                value=value,
            )

        if name not in self._parameters:
            logger.error(
                "Parameter not found",
                parameter=name,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(f"Parameter '{name}' not found")

        param = self._parameters[name]

        # Always build the context so that relative constraints added after
        # initial parameter creation are not bypassed. The cost is negligible
        # (one dict comprehension) compared to the validation itself.
        if not param.validate(value, self._context()):
            logger.error(
                "Value violates constraints",
                parameter=name,
                value=value,
            )
            raise ValueError(
                f"Value {value} violates constraints for parameter '{name}'"
            )

        param.value = value

    def set_bounds(self, name: str, bounds: tuple[float, float]):
        """Set bounds for a parameter.

        Args:
            name: Parameter name
            bounds: Tuple of (min, max) values

        Raises:
            KeyError: If parameter not found
            ValueError: If bounds are invalid
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting parameter bounds",
                operation="set_bounds",
                parameter=name,
                bounds=bounds,
            )

        if name not in self._parameters:
            logger.error(
                "Parameter not found",
                parameter=name,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(f"Parameter '{name}' not found")

        min_val, max_val = bounds
        if min_val > max_val:
            logger.error(
                "Invalid bounds: min > max",
                parameter=name,
                min_val=min_val,
                max_val=max_val,
            )
            raise ValueError(
                f"Invalid bounds: min ({min_val}) must be <= max ({max_val})"
            )

        # bounds.setter auto-syncs the associated bounds constraint
        self._parameters[name].bounds = bounds

    def get_values(self) -> np.ndarray:
        """Get all parameter values as array.

        Parameters without a value contribute a default derived from their
        bounds (a warning is logged for each).

        Returns:
            float64 array of parameter values in insertion order
        """
        values = []
        for name in self._order:
            param = self._parameters[name]
            if param.value is not None:
                values.append(param.value)
                continue

            # Use geometric mean of positive bounds for scale-invariant
            # default; fall back to arithmetic midpoint for bounds that
            # include zero or are negative. 0.0 is outside bounds for most
            # parameters and causes NLSQ to start from an infeasible point.
            lo, hi = param.bounds if param.bounds else (0.0, 1.0)
            lo = lo if lo is not None else 0.0
            hi = hi if hi is not None else 1.0
            if lo > 0 and hi > 0:
                default = math.sqrt(lo * hi)
            else:
                default = (lo + hi) / 2.0
            logger.warning(
                "Parameter has no value set, using bounds midpoint",
                parameter=name,
                bounds=param.bounds,
                default=default,
            )
            values.append(default)

        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Getting all parameter values",
                operation="get_values",
                params=list(self._parameters.keys()),
                num_params=len(values),
            )
        return np.array(values, dtype=np.float64)

    def set_values(self, values: ArrayLike | dict[str, float]):
        """Set parameter values from array or dictionary.

        Values are applied one at a time; if one fails, earlier ones remain
        applied.

        Args:
            values: Array of values in order, or dict mapping names to values

        Raises:
            ValueError: If wrong number of values (array), unknown parameter
                (dict), or a value violates constraints
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting multiple parameter values",
                operation="set_values",
                params=list(self._parameters.keys()),
            )

        if isinstance(values, dict):
            for name, value in values.items():
                if name not in self._parameters:
                    logger.error(
                        "Unknown parameter in dict",
                        parameter=name,
                        available_params=list(self._parameters.keys()),
                    )
                    raise ValueError(f"Unknown parameter: {name}")
                self.set_value(name, float(value))
            return

        values = np.atleast_1d(values)
        if len(values) != len(self._order):
            logger.error(
                "Wrong number of values",
                expected=len(self._order),
                got=len(values),
            )
            raise ValueError(
                f"Expected {len(self._order)} values, got {len(values)}"
            )

        for name, value in zip(self._order, values, strict=False):
            self.set_value(name, float(value))

    def update(
        self,
        values: dict[str, float],
        *,
        strict: bool = True,
    ) -> dict[str, str]:
        """Apply a batch of name→value updates with optional failure tolerance.

        Replacement for the ``for k, v in d.items(): try: set_value(k, v)
        except: logger.warning(...)`` pattern. Unknown names and rejected
        values are reported separately so the caller can tell them apart.

        Args:
            values: Mapping of parameter name → new value.
            strict: When ``True`` (default), re-raises the first failure.
                When ``False``, collects every failure — unknown name,
                non-numeric value, constraint violation, or rejection by the
                parameter's own setter — into the returned dict and
                continues with the remaining entries.

        Returns:
            Dict of ``{name: reason}`` for entries that failed. Empty when
            all succeeded (including when ``values`` is empty).

        Raises:
            KeyError: (strict=True) if any name is unknown.
            ValueError: (strict=True) if any value is invalid or violates
                constraints.
            TypeError: (strict=True) if a value cannot be converted by
                ``float()``.
        """
        failures: dict[str, str] = {}
        for name, value in values.items():
            if name not in self._parameters:
                if strict:
                    raise KeyError(f"Parameter '{name}' not found")
                failures[name] = (
                    f"unknown parameter (available: "
                    f"{sorted(self._parameters.keys())})"
                )
                continue

            param = self._parameters[name]

            # Conversion can fail for non-numeric input; in tolerant mode
            # that is a reportable failure, not an exception.
            try:
                numeric = float(value)
            except (TypeError, ValueError) as exc:
                if strict:
                    raise
                failures[name] = f"value {value!r} is not numeric ({exc})"
                continue

            if not param.validate(numeric, self._context()):
                if strict:
                    raise ValueError(
                        f"Value {value} violates constraints for parameter '{name}'"
                    )
                failures[name] = f"value {value} violates constraints"
                continue

            # The setter enforces the parameter's own bounds independently
            # of the constraint list, so it can still reject the value.
            try:
                param.value = numeric
            except (TypeError, ValueError) as exc:
                if strict:
                    raise
                failures[name] = f"value {value} rejected ({exc})"
        return failures

    def get_bounds(self) -> list[tuple[float | None, float | None]]:
        """Get bounds for all parameters.

        Returns:
            List of (min, max) tuples in insertion order; ``(None, None)``
            for parameters without bounds
        """
        bounds: list[tuple[float | None, float | None]] = [
            self._parameters[name].bounds
            if self._parameters[name].bounds
            else (None, None)
            for name in self._order
        ]

        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Getting all parameter bounds",
                operation="get_bounds",
                params=list(self._parameters.keys()),
                num_params=len(bounds),
            )
        return bounds

    def get_value(self, name: str) -> float | None:
        """Get value of a specific parameter.

        Args:
            name: Parameter name

        Returns:
            Parameter value, or None if the parameter is missing or unset
        """
        param = self.get(name)
        return param.value if param else None

    def unpack(self, *names: str) -> tuple[float | None, ...]:
        """Extract multiple parameter values in a single call.

        Args:
            *names: Parameter names to extract

        Returns:
            Tuple of parameter values in the same order as requested.
            Returns None for parameters with None values.

        Raises:
            KeyError: If any parameter name is not found. The error message
                includes the missing name and lists available parameters.

        Examples:
            >>> params = ParameterSet()
            >>> _ = params.add('x', value=1.5)
            >>> _ = params.add('G0', value=100.0)
            >>> x, G0 = params.unpack('x', 'G0')
            >>> x
            1.5
            >>> G0
            100.0
        """
        values = []
        for name in names:
            if name not in self._parameters:
                available = list(self._parameters.keys())
                raise KeyError(
                    f"Parameter '{name}' not found. "
                    f"Available parameters: {available}"
                )
            values.append(self.get_value(name))
        return tuple(values)

    def __len__(self) -> int:
        """Number of parameters."""
        return len(self._parameters)

    def __contains__(self, name: str) -> bool:
        """Check if parameter exists."""
        return name in self._parameters

    def __iter__(self):
        """Iterate over parameter names."""
        return iter(self._order)

    def keys(self):
        """Return an iterator over parameter names (dict-like interface).

        Returns:
            Iterator over parameter names in order
        """
        return iter(self._order)

    def values(self):
        """Return an iterator over Parameter objects (dict-like interface).

        Returns:
            Iterator over Parameter objects in order
        """
        for name in self._order:
            yield self._parameters[name]

    def items(self):
        """Return an iterator over (name, Parameter) tuples (dict-like interface).

        Returns:
            Iterator over (name, Parameter) tuples in order
        """
        for name in self._order:
            yield name, self._parameters[name]

    def __getitem__(self, key: str) -> Parameter:
        """Get parameter by name using subscript notation.

        Args:
            key: Parameter name

        Returns:
            Parameter object

        Raises:
            KeyError: If parameter not found
        """
        if key not in self._parameters:
            logger.error(
                "Parameter not found in subscript access",
                parameter=key,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(f"Parameter '{key}' not found in ParameterSet")
        return self._parameters[key]

    def __setitem__(self, key: str, value: float | Parameter):
        """Set parameter value using subscript notation.

        Assigning a Parameter object replaces (or adds) the entry under
        ``key``; assigning a number updates the existing parameter's value.

        Args:
            key: Parameter name
            value: New value (float) or Parameter object

        Raises:
            KeyError: If parameter not found and value is float
            ValueError: If value violates constraints
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting parameter via subscript",
                operation="__setitem__",
                parameter=key,
            )

        if isinstance(value, Parameter):
            # Replace entire parameter
            self._parameters[key] = value
            if key not in self._order:
                self._order.append(key)
            if not self._has_relative_constraints and value.constraints:
                if any(c.type == "relative" for c in value.constraints):
                    self._has_relative_constraints = True
            return

        # Set value only
        if key not in self._parameters:
            logger.error(
                "Parameter not found for subscript assignment",
                parameter=key,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(
                f"Parameter '{key}' not found. Use add() to create new parameters."
            )
        self.set_value(key, float(value))

    def to_dict(self) -> dict[str, dict[str, Any]]:
        """Convert to dictionary representation."""
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Converting ParameterSet to dict",
                operation="to_dict",
                params=list(self._parameters.keys()),
            )
        return {name: self._parameters[name].to_dict() for name in self._order}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ParameterSet:
        """Create from dictionary representation.

        Uses Parameter.from_dict() to preserve constraints (not just bounds).
        Entries whose value is not a dict are skipped.
        """
        # R8-PARAMS-003: NOTE — _was_clamped and _clamp_on_set flags are not
        # preserved across serialization. Clamping behavior may differ after
        # round-trip.
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Creating ParameterSet from dict",
                operation="from_dict",
                params=list(data.keys()),
            )

        params = cls()
        for name, param_data in data.items():
            if not isinstance(param_data, dict):
                continue
            # The mapping key wins over any "name" stored inside the entry.
            param = Parameter.from_dict({**param_data, "name": name})
            params._parameters[name] = param
            params._order.append(name)

        params._has_relative_constraints = any(
            c.type == "relative"
            for _param in params._parameters.values()
            for c in _param.constraints
        )
        return params
class SharedParameterSet:
    """Manages parameters shared across multiple models.

    Holds the shared Parameter objects, a per-parameter list of linked
    objects that receive value updates, and named groups of parameter names.
    """

    def __init__(self):
        """Initialize shared parameter set."""
        self._shared: dict[str, Parameter] = {}
        self._links: dict[str, list[Any]] = {}  # Parameter -> list of linked objects
        self._groups: dict[str, list[str]] = {}  # Group name -> parameter names
        logger.debug("SharedParameterSet created")

    def add_shared(
        self,
        name: str,
        value: float | None = None,
        bounds: tuple[float, float] | None = None,
        units: str | None = None,
        constraints: list[ParameterConstraint] | None = None,
        group: str | None = None,
    ) -> Parameter:
        """Add a shared parameter.

        Re-adding an existing name replaces the parameter and resets its
        list of linked objects.

        Args:
            name: Parameter name
            value: Initial value
            bounds: Value bounds
            units: Parameter units
            constraints: Parameter constraints
            group: Optional group name; created on first use

        Returns:
            The created Parameter
        """
        logger.debug(
            "Adding shared parameter",
            operation="add_shared",
            parameter=name,
            value=value,
            bounds=bounds,
            group=group,
        )

        param = Parameter(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            constraints=constraints or [],
        )
        self._shared[name] = param
        self._links[name] = []

        if group:
            members = self._groups.setdefault(group, [])
            # Re-adding a parameter must not list it twice in its group.
            if name not in members:
                members.append(name)

        logger.debug(
            "Shared parameter added",
            operation="add_shared",
            params=list(self._shared.keys()),
        )
        return param

    def set_value(self, name: str, value: float):
        """Set shared parameter value and propagate it to linked objects.

        Args:
            name: Parameter name
            value: New value

        Raises:
            KeyError: If the shared parameter does not exist
            ValueError: If value violates constraints
        """
        logger.debug(
            "Setting shared parameter value",
            operation="set_value",
            parameter=name,
            value=value,
        )

        if name not in self._shared:
            logger.error(
                "Shared parameter not found",
                parameter=name,
                available_params=list(self._shared.keys()),
            )
            raise KeyError(f"Shared parameter '{name}' not found")

        param = self._shared[name]

        # Validate
        if not param.validate(value):
            logger.error(
                "Value violates constraints for shared parameter",
                parameter=name,
                value=value,
            )
            raise ValueError(
                f"Value {value} violates constraints for parameter '{name}'"
            )

        param.value = value

        # Update linked models/parameter sets
        for linked in self._links.get(name, []):
            if (
                hasattr(linked, "set_value")
                and hasattr(linked, "__contains__")
                and name in linked
            ):
                # This is a ParameterSet with the parameter
                linked.set_value(name, value)
            elif hasattr(linked, "parameters") and name in linked.parameters:
                # This is a model with parameters
                linked.parameters.set_value(name, value)

    def get_value(self, name: str) -> float | None:
        """Get shared parameter value.

        Args:
            name: Parameter name

        Returns:
            Parameter value, or None if the parameter is missing or unset
        """
        param = self._shared.get(name)
        return param.value if param else None

    def get_linked_models(self, param_name: str) -> list[Any]:
        """Get models linked to a parameter.

        Args:
            param_name: Parameter name

        Returns:
            List of linked models (the internal list for a known parameter,
            a fresh empty list otherwise)
        """
        return self._links.get(param_name, [])

    def create_group(self, group_name: str, param_names: list[str]):
        """Create a parameter group.

        An existing group with the same name is replaced.

        Args:
            group_name: Name for the group
            param_names: Parameter names to include
        """
        logger.debug(
            "Creating parameter group",
            operation="create_group",
            group=group_name,
            params=param_names,
        )
        # Store a copy: add_shared(group=...) appends to the stored list and
        # must not mutate the caller's argument.
        self._groups[group_name] = list(param_names)

    def get_group(self, group_name: str) -> list[str]:
        """Get parameters in a group.

        Args:
            group_name: Group name

        Returns:
            List of parameter names in group (empty if the group is unknown)
        """
        return self._groups.get(group_name, [])

    def __contains__(self, name: str) -> bool:
        """Check if shared parameter exists."""
        return name in self._shared
class ParameterOptimizer:
    """Optimizer for parameter fitting.

    Wraps a ParameterSet together with an objective function, optional
    inequality constraints, an optional per-step callback, and an optional
    history of visited points.
    """

    def __init__(
        self,
        parameters: ParameterSet,
        use_jax: bool = False,
        track_history: bool = False,
    ):
        """Initialize parameter optimizer.

        Args:
            parameters: ParameterSet to optimize
            use_jax: Whether to use JAX for optimization
            track_history: Whether to track optimization history
        """
        self.parameters = parameters
        self.use_jax = use_jax and HAS_JAX
        self.track_history = track_history
        self.history: list[dict[str, Any]] = []
        self.objective: Callable | None = None
        self.constraints: list[Callable] = []
        self.callback: Callable | None = None
        logger.debug(
            "ParameterOptimizer created",
            num_params=len(parameters),
            use_jax=self.use_jax,
            track_history=track_history,
        )

    @property
    def n_parameters(self) -> int:
        """Number of parameters."""
        return len(self.parameters)

    def get_values(self) -> np.ndarray:
        """Get current parameter values."""
        return self.parameters.get_values()

    def get_bounds(self) -> list[tuple[float | None, float | None]]:
        """Get parameter bounds."""
        return self.parameters.get_bounds()

    def set_objective(self, objective: Callable):
        """Set objective function to minimize.

        Args:
            objective: Function that takes parameter values and returns scalar
        """
        logger.debug(
            "Setting objective function",
            operation="set_objective",
        )
        self.objective = objective

    def evaluate(self, values: ArrayLike) -> float:
        """Evaluate objective at given values.

        Args:
            values: Parameter values

        Returns:
            Objective function value (array results are converted with
            ``float()``)

        Raises:
            ValueError: If no objective function has been set
        """
        if self.objective is None:
            logger.error(
                "No objective function set",
            )
            raise ValueError("No objective function set")

        result = self.objective(values)
        # Convert to float if needed
        if isinstance(result, (np.ndarray, jnp.ndarray)):
            result = float(result)
        return result

    def compute_gradient(self, values: ArrayLike) -> np.ndarray:
        """Compute gradient of objective.

        Uses ``jax.grad`` when JAX is enabled for this optimizer, otherwise a
        forward finite difference.

        Args:
            values: Parameter values

        Returns:
            Gradient vector

        Raises:
            ValueError: If no objective function has been set
        """
        logger.debug(
            "Computing gradient",
            operation="compute_gradient",
            use_jax=self.use_jax,
        )

        if self.objective is None:
            # Same failure mode as evaluate(), for both gradient paths.
            logger.error(
                "No objective function set",
            )
            raise ValueError("No objective function set")

        if not self.use_jax or not HAS_JAX:
            # Numerical gradient
            eps = 1e-8
            # Work on a float64 copy: with an integer array the in-place
            # perturbation below would be truncated away, so the difference
            # would be taken over a zero step.
            base = np.array(_coerce_array(values), dtype=np.float64)
            n = len(base)
            grad = np.zeros(n)
            if n == 0:
                return grad
            # The unperturbed value is the same for every component.
            f = self.evaluate(base)
            for i in range(n):
                probe = base.copy()
                # Scale the step with the magnitude so it is not lost to
                # rounding for large values (e.g. 1e9 + 1e-8 == 1e9).
                probe[i] += eps * max(1.0, abs(base[i]))
                # Divide by the step actually representable in float64.
                step = probe[i] - base[i]
                grad[i] = (self.evaluate(probe) - f) / step
            return grad

        # JAX automatic differentiation; jax.grad requires float inputs.
        grad_fn = jax.grad(self.objective)
        return np.array(grad_fn(jnp.array(values, dtype=float)))

    def add_constraint(self, constraint: Callable):
        """Add optimization constraint.

        Args:
            constraint: Function that returns >= 0 for valid values
        """
        logger.debug(
            "Adding optimization constraint",
            operation="add_constraint",
            num_constraints=len(self.constraints) + 1,
        )
        self.constraints.append(constraint)

    def validate_constraints(self, values: ArrayLike) -> bool:
        """Check if constraints are satisfied.

        Args:
            values: Parameter values

        Returns:
            True if all constraints satisfied (vacuously True with none)
        """
        values_array = _coerce_array(values)
        for constraint in self.constraints:
            if constraint(values_array) < 0:
                logger.debug(
                    "Constraint validation failed",
                    operation="validate_constraints",
                )
                return False
        return True

    def set_callback(self, callback: Callable):
        """Set optimization callback.

        Args:
            callback: Function called after each iteration with
                (iteration, values, objective_value)
        """
        logger.debug(
            "Setting optimization callback",
            operation="set_callback",
        )
        self.callback = callback

    def step(self, values: ArrayLike, iteration: int | None = None):
        """Perform one optimization step.

        Writes ``values`` into the parameter set, evaluates the objective,
        optionally records history, and invokes the callback.

        Args:
            values: Current parameter values
            iteration: Current iteration number
        """
        # Update parameters
        coerced_values = _coerce_array(values)
        self.parameters.set_values(coerced_values)

        # Evaluate objective
        obj_value = self.evaluate(coerced_values)

        # Track history
        if self.track_history:
            # Use `is not None` guard so iteration=0 (first step) is stored
            # correctly. An `or` sentinel would coerce 0 → len(history),
            # recording the wrong iteration number when history is non-empty.
            effective_iter = iteration if iteration is not None else len(self.history)
            self.history.append(
                {
                    "iteration": effective_iter,
                    "values": coerced_values.copy(),
                    "objective": obj_value,
                }
            )

        # Call callback
        if self.callback:
            # Same guard: iteration=0 must reach the callback as 0.
            self.callback(
                iteration if iteration is not None else 0, coerced_values, obj_value
            )

    def get_history(self) -> list[dict[str, Any]]:
        """Get optimization history.

        Returns:
            List of history dictionaries (the internal list, not a copy)
        """
        return self.history