"""Parameter management system for models and transforms.
This module provides classes for managing parameters, constraints,
and optimization support for rheological models.
"""
from __future__ import annotations
import math
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
from rheojax.core.jax_config import safe_import_jax
from rheojax.logging import get_logger
# Safe JAX import (enforces float64). Executed once, at module import time.
jax, jnp = safe_import_jax()
# Hard-coded True: the rest of this module assumes the call above succeeded.
# NOTE(review): confirm safe_import_jax() raises (rather than returning
# placeholders) when JAX is unavailable; otherwise this flag is misleading.
HAS_JAX = True
# Module-level logger. Call sites in this module pass structured keyword
# fields (parameter=..., value=..., bounds=...) alongside the message.
logger = get_logger(__name__)
if TYPE_CHECKING:  # pragma: no cover - typing helper only
    import jax.numpy as jnp_typing
else:
    # The ArrayLike alias below references ``jnp_typing.ndarray``; outside
    # of type checking that resolves to ``numpy.ndarray``.
    jnp_typing = np
# Array-like inputs accepted by helpers in this module (Python 3.12 ``type``
# statement; the alias value is evaluated lazily, on first access).
type ArrayLike = np.ndarray | jnp_typing.ndarray | list | tuple
def _coerce_array(values: ArrayLike) -> np.ndarray:
"""Convert array-like inputs to NumPy arrays without altering callers."""
if isinstance(values, np.ndarray):
return values
if HAS_JAX and isinstance(values, jnp.ndarray):
return np.asarray(values)
return np.asarray(values)
@dataclass
class ParameterConstraint:
    """Constraint on a parameter value.

    Supported ``type`` values and the fields each one reads:

    - ``"bounds"``: ``min_value`` / ``max_value`` (either may be ``None``)
    - ``"positive"``: value must be strictly greater than 0
    - ``"integer"``: value must have no fractional part
    - ``"fixed"``: value must equal ``value`` exactly
    - ``"relative"``: value is compared with ``context[other_param]`` using
      ``relation`` (``"less_than"``, ``"greater_than"`` or ``"equal"``)
    - ``"custom"``: ``validator(value)`` must be truthy
    """

    type: str  # 'bounds', 'positive', 'integer', 'fixed', 'relative', 'custom'
    min_value: float | None = None
    max_value: float | None = None
    value: float | None = None  # For fixed constraints
    relation: str | None = None  # For relative constraints
    other_param: str | None = None  # For relative constraints
    validator: Callable[[float], bool] | None = None  # For custom constraints

    # Not annotated, so dataclass does not turn these into fields.
    _KNOWN_TYPES = frozenset(
        {"bounds", "positive", "integer", "fixed", "relative", "custom"}
    )

    def to_dict(self) -> dict[str, Any]:
        """Serialize the constraint to a plain dictionary.

        ``validator`` is not serialized (callables are not data), so a
        ``"custom"`` constraint cannot be fully restored from this output.

        Returns:
            Dict with ``"type"`` plus whichever optional fields are relevant.
        """
        d: dict[str, Any] = {"type": self.type}
        if self.type == "bounds":
            d["min_value"] = self.min_value
            d["max_value"] = self.max_value
        if self.relation is not None:
            d["relation"] = self.relation
        if self.other_param is not None:
            d["other_param"] = self.other_param
        if self.value is not None:
            d["value"] = self.value
        return d

    def validate(self, value: float, context: dict[str, float] | None = None) -> bool:
        """Check whether ``value`` satisfies this constraint.

        Args:
            value: Candidate value.
            context: Other parameters' values keyed by name. Only
                ``"relative"`` constraints read it; when it is missing, empty,
                or lacks ``other_param`` the relative check is skipped (passes).

        Returns:
            True if the constraint is satisfied. Non-finite values (NaN/Inf)
            always fail.

        Raises:
            ValueError: If ``type`` is not a supported constraint type, or if
                an evaluable ``"relative"`` constraint has an unsupported
                ``relation`` (a typo would otherwise disable the constraint).
        """
        # Check configuration first so a NaN value cannot hide a bad type.
        if self.type not in self._KNOWN_TYPES:
            raise ValueError(f"Unknown constraint type: {self.type!r}")
        # NaN/Inf bypass IEEE 754 comparisons — reject unconditionally
        if not np.isfinite(value):
            return False

        if self.type == "bounds":
            if self.min_value is not None and value < self.min_value:
                logger.debug(
                    "Bound check failed: value below minimum",
                    constraint_type=self.type,
                    value=value,
                    min_value=self.min_value,
                )
                return False
            if self.max_value is not None and value > self.max_value:
                logger.debug(
                    "Bound check failed: value above maximum",
                    constraint_type=self.type,
                    value=value,
                    max_value=self.max_value,
                )
                return False
            return True
        if self.type == "positive":
            return bool(value > 0)
        if self.type == "integer":
            return float(value).is_integer()
        if self.type == "fixed":
            return bool(value == self.value)
        if self.type == "relative":
            if not context or self.other_param not in context:
                return True  # Can't validate without context
            other_value = context[self.other_param]
            if self.relation == "less_than":
                return bool(value < other_value)
            if self.relation == "greater_than":
                return bool(value > other_value)
            if self.relation == "equal":
                return bool(value == other_value)
            raise ValueError(
                f"Unknown relation for relative constraint: {self.relation!r}"
            )
        # Remaining type is "custom"; without a validator it imposes nothing.
        if self.validator is None:
            return True
        return bool(self.validator(value))
class Parameter:
    """Single parameter with value, bounds, and metadata.

    A Parameter represents a model parameter with support for bounds validation,
    units tracking, and constraint enforcement. Parameters can be used in both
    NLSQ optimization and Bayesian inference workflows.

    Attributes:
        name: Parameter identifier used for lookup and serialization.
        value: Current parameter value (may be None if unset).
        bounds: Lower and upper bounds as tuple (min, max).
        units: Physical units string for display (e.g., "Pa", "s").
        description: Human-readable description.
        constraints: List of ParameterConstraint objects for validation.
        prior: Optional prior specification; stored and serialized as-is,
            not interpreted by this class.

    Example:
        >>> param = Parameter("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
        >>> param.value = 2e5  # Validated against bounds
        >>> param.validate(param.value)
        True
    """

    __slots__ = (
        "name",
        "_bounds",
        "units",
        "description",
        "constraints",
        "_value",
        "_clamp_on_set",
        "_was_clamped",
        "prior",
    )

    def __init__(
        self,
        name: str,
        value: float | None = None,
        bounds: tuple[float, float] | None = None,
        units: str | None = None,
        description: str | None = None,
        constraints: list[ParameterConstraint] | None = None,
    ) -> None:
        """Create a parameter.

        Args:
            name: Parameter identifier.
            value: Initial value. An initial value outside ``bounds`` is
                clamped to the nearest bound with a ``RuntimeWarning``; later
                assignments outside ``bounds`` raise instead.
            bounds: ``(lower, upper)``; coerced to floats.
            units: Physical units string.
            description: Human-readable description.
            constraints: Extra constraints. The list is copied, but the
                constraint objects are shared with the caller (the ``bounds``
                setter mutates ``"bounds"`` constraints in place).

        Raises:
            ValueError: If a bound is NaN, ``lower > upper``, or ``value`` is
                non-numeric or non-finite.
        """
        self.name = name
        self._bounds: tuple[float, float] | None = bounds
        self.units = units
        self.description = description
        self.constraints = list(constraints) if constraints else []
        self._value: float | None = None
        self._clamp_on_set = False
        self._was_clamped = False
        self.prior: dict[str, Any] | None = None
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Creating parameter",
                parameter=name,
                bounds=bounds,
                units=units,
            )
        self._initialize(value)

    @property
    def bounds(self) -> tuple[float, float] | None:
        """Get parameter bounds."""
        return self._bounds

    @bounds.setter
    def bounds(self, new_bounds: tuple[float, float] | None) -> None:
        """Set bounds and keep every ``"bounds"`` constraint in sync.

        Raises:
            ValueError: If ``new_bounds`` does not have exactly two entries, or
                both entries are set and ``lower > upper``.
        """
        if new_bounds is not None:
            # Always store a tuple: __hash__/__eq__ use ``bounds``, and a list
            # would make the Parameter unhashable.
            new_bounds = tuple(new_bounds)
            if len(new_bounds) != 2:
                raise ValueError(
                    f"Bounds for parameter '{self.name}' must be (min, max), "
                    f"got {new_bounds}"
                )
            lower, upper = new_bounds
            if lower is not None and upper is not None and lower > upper:
                raise ValueError(
                    f"Invalid bounds for parameter '{self.name}': {new_bounds}"
                )
        self._bounds = new_bounds
        # Sync ALL bounds constraints, not only the first one. ``constraints``
        # may be unset if this runs before __init__ finished.
        for c in getattr(self, "constraints", ()):
            if getattr(c, "type", None) != "bounds":
                continue
            if new_bounds is not None:
                c.min_value, c.max_value = new_bounds
            else:
                c.min_value = None
                c.max_value = None

    def _initialize(self, value: float | None) -> None:
        """Normalize bounds, register a bounds constraint, and set the value.

        Raises:
            ValueError: If a bound is NaN or ``lower > upper``, or the value
                setter rejects ``value``.
        """
        if self.bounds is not None:
            lower, upper = (float(b) for b in self.bounds)
            # NaN bounds make every comparison False, which would silently
            # accept any value, so they are rejected like inverted bounds.
            if math.isnan(lower) or math.isnan(upper) or lower > upper:
                logger.error(
                    "Invalid bounds: NaN or lower > upper",
                    parameter=self.name,
                    bounds=(lower, upper),
                )
                raise ValueError(
                    f"Invalid bounds for parameter '{self.name}': {(lower, upper)}"
                )
            self.bounds = (lower, upper)
        # Add bounds as constraint if specified and not already present.
        # R12-B-015: use `is not None` instead of truthiness so that a
        # degenerate bounds tuple such as (0.0, 0.0) is still handled.
        if self.bounds is not None and not any(
            c.type == "bounds" for c in self.constraints
        ):
            self.constraints.insert(
                0,
                ParameterConstraint(
                    type="bounds",
                    min_value=self.bounds[0],
                    max_value=self.bounds[1],
                ),
            )
        if value is not None:
            # Initial values are clamped (not rejected); reset the flag even
            # if the setter raises.
            self._clamp_on_set = True
            try:
                self.value = value
            finally:
                self._clamp_on_set = False

    @property
    def value(self) -> float | None:
        """Get parameter value."""
        return self._value

    @value.setter
    def value(self, val: float | None) -> None:
        """Set the value, converting it to ``float`` and checking bounds.

        During construction out-of-bounds values are clamped with a
        ``RuntimeWarning``; afterwards they raise ``ValueError``.

        Raises:
            TypeError: If ``val`` is a JAX tracer.
            ValueError: If ``val`` is non-numeric, non-finite, or outside
                ``bounds`` after construction.
        """
        if val is None:
            self._value = None
            self._was_clamped = False
            return
        # R5-JAX-002: Use an isinstance check instead of hasattr duck-typing;
        # concrete jax.Array scalars (which have .aval) are accepted, only
        # tracers are rejected. The module-level ``jax`` is reused rather than
        # re-importing on every assignment (this setter runs per fit step).
        if isinstance(val, jax.core.Tracer):
            raise TypeError(
                "Cannot set parameter value to a JAX traced value. "
                "Call set_value() outside of @jax.jit."
            )
        try:
            numeric_val = float(val)
        except (TypeError, ValueError) as exc:
            logger.error(
                "Failed to convert value to numeric",
                parameter=self.name,
                value=val,
                exc_info=True,
            )
            raise ValueError(
                f"Parameter '{self.name}' requires a numeric value"
            ) from exc
        if not np.isfinite(numeric_val):
            logger.error(
                "Non-finite value received",
                parameter=self.name,
                value=numeric_val,
            )
            raise ValueError(f"Parameter '{self.name}' received non-finite value")
        clamped_during_init = False
        if self.bounds is not None:
            lower, upper = self.bounds
            _debug = logger.isEnabledFor(10)  # logging.DEBUG == 10
            if _debug:
                logger.debug(
                    "Bound check",
                    parameter=self.name,
                    value=numeric_val,
                    bounds=self.bounds,
                )
            if self._clamp_on_set:
                if numeric_val < lower:
                    warnings.warn(
                        f"Parameter '{self.name}' initialized below bounds; clamped to {lower}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    if _debug:
                        logger.debug(
                            "Value clamped to lower bound",
                            parameter=self.name,
                            original_value=numeric_val,
                            clamped_value=lower,
                        )
                    numeric_val = lower
                    clamped_during_init = True
                elif numeric_val > upper:
                    warnings.warn(
                        f"Parameter '{self.name}' initialized above bounds; clamped to {upper}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                    if _debug:
                        logger.debug(
                            "Value clamped to upper bound",
                            parameter=self.name,
                            original_value=numeric_val,
                            clamped_value=upper,
                        )
                    numeric_val = upper
                    clamped_during_init = True
            elif numeric_val < lower or numeric_val > upper:
                logger.error(
                    "Value out of bounds",
                    parameter=self.name,
                    value=numeric_val,
                    bounds=self.bounds,
                )
                raise ValueError(f"Value {numeric_val} out of bounds {self.bounds}")
        # R12-B-017: _was_clamped is only meaningful at initialization time
        # (_clamp_on_set=True). Later assignments never clamp (out-of-bounds
        # raises above), so clamped_during_init is always False for them.
        self._was_clamped = clamped_during_init
        self._value = numeric_val

    @property
    def was_clamped(self) -> bool:
        """Return True if the last assignment clamped the value."""
        return self._was_clamped

    def validate(self, value: float, context: dict[str, float] | None = None) -> bool:
        """Validate a candidate value against all constraints.

        Args:
            value: Value to validate (not necessarily the current value).
            context: Other parameter values keyed by name, used by
                ``"relative"`` constraints.

        Returns:
            True if all constraints are satisfied.
        """
        for constraint in self.constraints:
            if not constraint.validate(value, context):
                logger.debug(
                    "Constraint validation failed",
                    parameter=self.name,
                    value=value,
                    constraint_type=constraint.type,
                )
                return False
        return True

    def __hash__(self) -> int:
        """Hash on ``(name, bounds, units)`` so Parameters can be dict keys.

        These attributes are all reassignable (``bounds`` via its setter), so
        do not change them while the Parameter is used as a key or set member.
        """
        return hash((self.name, self.bounds, self.units))

    def __eq__(self, other: object) -> bool:
        """Check identity-equality with another Parameter.

        Matches __hash__: compares (name, bounds, units). Value is excluded
        because it changes during fitting while the parameter identity
        remains the same.

        Args:
            other: Object to compare with.

        Returns:
            True if the parameters have the same identity attributes, or
            ``NotImplemented`` for non-Parameter objects.
        """
        if not isinstance(other, Parameter):
            return NotImplemented
        return (
            self.name == other.name
            and self.bounds == other.bounds
            and self.units == other.units
        )

    def __repr__(self) -> str:
        """Return a concise developer representation."""
        return (
            f"{type(self).__name__}(name={self.name!r}, value={self.value!r}, "
            f"bounds={self.bounds!r}, units={self.units!r})"
        )

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary representation (see :meth:`from_dict`)."""
        d: dict[str, Any] = {
            "name": self.name,
            "value": self.value,
            "bounds": self.bounds,
            "units": self.units,
            "description": self.description,
        }
        if self.prior is not None:
            d["prior"] = self.prior
        if self.constraints:
            d["constraints"] = [c.to_dict() for c in self.constraints]
        return d

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Parameter:
        """Create a Parameter from :meth:`to_dict` output.

        Args:
            data: Mapping with ``"name"``; ``"value"``, ``"bounds"``,
                ``"units"``, ``"description"``, ``"constraints"`` and
                ``"prior"`` are optional.

        Returns:
            The reconstructed Parameter.

        Note:
            ``"custom"`` constraints come back without their validator
            (callables are not serialized), so they no longer restrict values;
            a warning is logged when that happens.
        """
        raw_bounds = data.get("bounds")
        # Explicit None/length test instead of truthiness: array-valued bounds
        # (e.g. read back from HDF5) would raise "truth value is ambiguous".
        bounds = (
            tuple(raw_bounds)
            if raw_bounds is not None and len(raw_bounds) > 0
            else None
        )
        param = cls(
            name=data["name"],
            value=data.get("value"),
            bounds=bounds,
            units=data.get("units"),
            description=data.get("description"),
        )
        for c_data in data.get("constraints") or []:
            c_type = c_data["type"]
            # __init__ already created (and synced) a bounds constraint from
            # ``bounds``; re-adding it would duplicate. Without ``bounds`` no
            # constraint was created, so a serialized one must be kept.
            if c_type == "bounds" and param.bounds is not None:
                continue
            if c_type == "custom":
                logger.warning(
                    "Custom constraint restored without its validator",
                    parameter=param.name,
                )
            param.constraints.append(
                ParameterConstraint(
                    type=c_type,
                    min_value=c_data.get("min_value"),
                    max_value=c_data.get("max_value"),
                    relation=c_data.get("relation"),
                    other_param=c_data.get("other_param"),
                    value=c_data.get("value"),
                )
            )
        if data.get("prior") is not None:
            param.prior = data["prior"]
        return param
class ParameterSet:
    """Collection of parameters for a model or transform.

    A ParameterSet manages multiple Parameter objects with dict-like access,
    batch operations, and serialization support. It is the primary interface
    for working with model parameters in RheoJAX.

    Key Features:
        - Dict-like access: ``params["G0"]`` or ``params.get("G0")``
        - Batch operations: ``get_values()``, ``set_values()``, ``get_bounds()``
        - Unpack helper: ``G0, eta = params.unpack("G0", "eta")``
        - Serialization: ``to_dict()`` / ``from_dict()`` for JSON/HDF5

    Example:
        >>> params = ParameterSet()
        >>> _ = params.add("G0", value=1e5, bounds=(1e3, 1e9), units="Pa")
        >>> _ = params.add("eta", value=1e3, bounds=(1e-3, 1e9), units="Pa*s")
        >>> G0, eta = params.unpack("G0", "eta")
        >>> print(f"G0={G0:.2e}, eta={eta:.2e}")
        G0=1.00e+05, eta=1.00e+03

    See Also:
        Parameter: Individual parameter class.
        SharedParameterSet: For multi-model parameter sharing.
    """

    __slots__ = ("_parameters", "_order", "_has_relative_constraints")

    def __init__(self):
        """Initialize an empty parameter set."""
        self._parameters: dict[str, Parameter] = {}
        self._order: list[str] = []  # insertion order used by array APIs
        self._has_relative_constraints: bool = False
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug("ParameterSet created")

    def add(
        self,
        name: str,
        value: float | None = None,
        bounds: tuple[float, float] | None = None,
        units: str | None = None,
        description: str | None = None,
        constraints: list[ParameterConstraint] | None = None,
        overwrite: bool = False,
    ) -> Parameter:
        """Add a parameter to the set.

        Args:
            name: Parameter name
            value: Initial value
            bounds: Value bounds (min, max)
            units: Parameter units
            description: Parameter description
            constraints: List of constraints
            overwrite: If True, silently overwrite an existing parameter without
                emitting a warning. Default is False (warns on overwrite).

        Returns:
            The created Parameter object. An overwritten parameter keeps its
            original position in the ordering.
        """
        _debug = logger.isEnabledFor(10)  # logging.DEBUG == 10
        if _debug:
            logger.debug(
                "Adding parameter to set",
                operation="add",
                parameter=name,
                value=value,
                bounds=bounds,
            )
        # R8-PARAMS-002: warn on silent overwrite (unless overwrite=True is explicit)
        # R10-PARAMS-001: expose overwrite flag so callers can suppress the warning
        if name in self._parameters and not overwrite:
            warnings.warn(
                f"Parameter '{name}' already exists and will be overwritten",
                stacklevel=2,
            )
        param = Parameter(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            description=description,
            constraints=constraints or [],
        )
        self._parameters[name] = param
        if name not in self._order:
            self._order.append(name)
        # Track whether any relative constraints exist.
        if constraints and any(c.type == "relative" for c in constraints):
            self._has_relative_constraints = True
        if _debug:
            logger.debug(
                "Parameter added",
                operation="add",
                params=list(self._parameters.keys()),
            )
        return param

    def get(self, name: str) -> Parameter | None:
        """Get a parameter by name.

        Args:
            name: Parameter name

        Returns:
            Parameter object or None if not found
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Getting parameter",
                operation="get",
                parameter=name,
            )
        return self._parameters.get(name)

    def _apply_updates(self, updates: dict[str, float]) -> None:
        """Validate every update against the post-update state, then assign.

        Relative constraints see the *new* values of all updated parameters
        (and current values of the rest), so batch updates do not depend on
        the order in which parameters are visited. No parameter is modified
        unless every update passes constraint validation.
        """
        context = {
            p.name: p.value for p in self._parameters.values() if p.value is not None
        }
        context.update(updates)
        for name, new_value in updates.items():
            if not self._parameters[name].validate(new_value, context):
                logger.error(
                    "Value violates constraints",
                    parameter=name,
                    value=new_value,
                )
                raise ValueError(
                    f"Value {new_value} violates constraints for parameter '{name}'"
                )
        for name, new_value in updates.items():
            # The value setter still enforces numeric/finite/bounds checks.
            self._parameters[name].value = new_value

    def set_value(self, name: str, value: float):
        """Set a single parameter value.

        The value is validated against the parameter's constraints, with the
        other parameters' current values as context for relative constraints.

        Args:
            name: Parameter name
            value: New value

        Raises:
            KeyError: If parameter not found
            ValueError: If value violates constraints or bounds
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting parameter value",
                operation="set_value",
                parameter=name,
                value=value,
            )
        if name not in self._parameters:
            logger.error(
                "Parameter not found",
                parameter=name,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(f"Parameter '{name}' not found")
        # Context is always built so relative constraints added after the
        # parameter was created are not bypassed.
        self._apply_updates({name: value})

    def set_bounds(self, name: str, bounds: tuple[float, float]):
        """Set bounds for a parameter.

        The current value is not changed; a warning is logged if it now lies
        outside the new bounds.

        Args:
            name: Parameter name
            bounds: Tuple of (min, max) values

        Raises:
            KeyError: If parameter not found
            ValueError: If bounds are invalid (min > max or NaN)
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting parameter bounds",
                operation="set_bounds",
                parameter=name,
                bounds=bounds,
            )
        if name not in self._parameters:
            logger.error(
                "Parameter not found",
                parameter=name,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(f"Parameter '{name}' not found")
        min_val, max_val = bounds
        # ``not <=`` (rather than ``>``) also rejects NaN bounds, which would
        # otherwise accept every value because all NaN comparisons are False.
        if not min_val <= max_val:
            logger.error(
                "Invalid bounds: min > max",
                parameter=name,
                min_val=min_val,
                max_val=max_val,
            )
            raise ValueError(
                f"Invalid bounds: min ({min_val}) must be <= max ({max_val})"
            )
        param = self._parameters[name]
        # Pass a tuple so the Parameter stays hashable even if a list was
        # given; the bounds setter auto-syncs the associated bounds constraint.
        param.bounds = (min_val, max_val)
        if param.value is not None and not min_val <= param.value <= max_val:
            logger.warning(
                "Current value lies outside the new bounds",
                parameter=name,
                value=param.value,
                bounds=(min_val, max_val),
            )

    def get_values(self) -> np.ndarray:
        """Get all parameter values as a float64 array in set order.

        Unset values are replaced by a default derived from the bounds: the
        geometric mean when both bounds are positive, otherwise the
        arithmetic midpoint ((0, 1) when no bounds are set).

        Returns:
            Array of parameter values in order
        """
        values = []
        for name in self._order:
            param = self._parameters[name]
            if param.value is not None:
                values.append(param.value)
                continue
            # 0.0 is outside bounds for most parameters and causes NLSQ to
            # start from an infeasible point, so derive a default instead.
            lo, hi = param.bounds if param.bounds is not None else (0.0, 1.0)
            lo = lo if lo is not None else 0.0
            hi = hi if hi is not None else 1.0
            if lo > 0 and hi > 0:
                default = math.sqrt(lo * hi)  # scale-invariant for log-spaced bounds
            else:
                default = (lo + hi) / 2.0
            logger.warning(
                "Parameter has no value set, using bounds midpoint",
                parameter=name,
                bounds=param.bounds,
                default=default,
            )
            values.append(default)
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Getting all parameter values",
                operation="get_values",
                params=list(self._parameters.keys()),
                num_params=len(values),
            )
        return np.array(values, dtype=np.float64)

    def set_values(self, values: ArrayLike | dict[str, float]):
        """Set parameter values from an array or a dictionary.

        All values are converted and validated before any parameter is
        modified, and relative constraints are checked against the state
        after the update, so the result does not depend on parameter order.

        Args:
            values: Array of values in set order, or dict mapping parameter
                names to values (a subset of the parameters is allowed).

        Raises:
            ValueError: If wrong number of values (array), unknown parameter
                (dict), or a value violates its constraints.
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting multiple parameter values",
                operation="set_values",
                params=list(self._parameters.keys()),
            )
        if isinstance(values, dict):
            for name in values:
                if name not in self._parameters:
                    logger.error(
                        "Unknown parameter in dict",
                        parameter=name,
                        available_params=list(self._parameters.keys()),
                    )
                    raise ValueError(f"Unknown parameter: {name}")
            updates = {name: float(value) for name, value in values.items()}
        else:
            values = np.atleast_1d(values)
            if len(values) != len(self._order):
                logger.error(
                    "Wrong number of values",
                    expected=len(self._order),
                    got=len(values),
                )
                raise ValueError(
                    f"Expected {len(self._order)} values, got {len(values)}"
                )
            updates = {
                name: float(value)
                for name, value in zip(self._order, values, strict=True)
            }
        self._apply_updates(updates)

    def get_bounds(self) -> list[tuple[float | None, float | None]]:
        """Get bounds for all parameters in set order.

        Returns:
            List of (min, max) tuples; ``(None, None)`` for unbounded parameters
        """
        bounds: list[tuple[float | None, float | None]] = [
            self._parameters[name].bounds or (None, None) for name in self._order
        ]
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Getting all parameter bounds",
                operation="get_bounds",
                params=list(self._parameters.keys()),
                num_params=len(bounds),
            )
        return bounds

    def get_value(self, name: str) -> float | None:
        """Get value of a specific parameter.

        Args:
            name: Parameter name

        Returns:
            Parameter value, or None if unset or the parameter is unknown
        """
        param = self.get(name)
        return param.value if param else None

    def unpack(self, *names: str) -> tuple[float | None, ...]:
        """Extract multiple parameter values in a single call.

        This method provides a concise way to extract several parameter values
        at once, reducing boilerplate in model implementations.

        Args:
            *names: Parameter names to extract

        Returns:
            Tuple of parameter values in the same order as requested.
            Returns None for parameters with None values.

        Raises:
            KeyError: If any parameter name is not found. The error message
                includes the missing name and lists available parameters.

        Examples:
            Basic usage - extract multiple parameters in one line:

            >>> params = ParameterSet()
            >>> _ = params.add('x', value=1.5)
            >>> _ = params.add('G0', value=100.0)
            >>> _ = params.add('tau0', value=0.01)
            >>> x, G0, tau0 = params.unpack('x', 'G0', 'tau0')
            >>> x
            1.5
            >>> G0
            100.0

            Before (verbose)::

                x = params.get_value('x')
                G0 = params.get_value('G0')
                tau0 = params.get_value('tau0')

            After (concise)::

                x, G0, tau0 = params.unpack('x', 'G0', 'tau0')
        """
        values = []
        for name in names:
            if name not in self._parameters:
                available = list(self._parameters.keys())
                raise KeyError(
                    f"Parameter '{name}' not found. "
                    f"Available parameters: {available}"
                )
            values.append(self.get_value(name))
        return tuple(values)

    def __len__(self) -> int:
        """Number of parameters."""
        return len(self._parameters)

    def __contains__(self, name: str) -> bool:
        """Check if parameter exists."""
        return name in self._parameters

    def __iter__(self):
        """Iterate over parameter names in order."""
        return iter(self._order)

    def __repr__(self) -> str:
        """Show parameter names and current values in order."""
        body = ", ".join(
            f"{name}={self._parameters[name].value!r}" for name in self._order
        )
        return f"{type(self).__name__}({body})"

    def keys(self):
        """Return an iterator over parameter names (dict-like interface).

        Returns:
            Iterator over parameter names in order

        Examples:
            >>> params = ParameterSet()
            >>> _ = params.add('alpha', value=0.5)
            >>> _ = params.add('beta', value=1.0)
            >>> list(params.keys())
            ['alpha', 'beta']
        """
        return iter(self._order)

    def values(self):
        """Return an iterator over Parameter objects (dict-like interface).

        Returns:
            Iterator over Parameter objects in order

        Examples:
            >>> params = ParameterSet()
            >>> _ = params.add('alpha', value=0.5, units='')
            >>> for param in params.values():
            ...     print(f"{param.name}: {param.value}")
            alpha: 0.5
        """
        for name in self._order:
            yield self._parameters[name]

    def items(self):
        """Return an iterator over (name, Parameter) tuples (dict-like interface).

        Returns:
            Iterator over (name, Parameter) tuples in order

        Examples:
            >>> params = ParameterSet()
            >>> _ = params.add('alpha', value=0.5)
            >>> for name, param in params.items():
            ...     print(f"{name}: {param.value}")
            alpha: 0.5
        """
        for name in self._order:
            yield name, self._parameters[name]

    def __getitem__(self, key: str) -> Parameter:
        """Get parameter by name using subscript notation.

        Args:
            key: Parameter name

        Returns:
            Parameter object

        Raises:
            KeyError: If parameter not found

        Examples:
            >>> params = ParameterSet()
            >>> _ = params.add('alpha', value=0.5)
            >>> param = params['alpha']  # Get parameter object
            >>> value = params['alpha'].value  # Get value
        """
        if key not in self._parameters:
            logger.error(
                "Parameter not found in subscript access",
                parameter=key,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(f"Parameter '{key}' not found in ParameterSet")
        return self._parameters[key]

    def __setitem__(self, key: str, value: float | Parameter):
        """Set parameter value using subscript notation.

        Args:
            key: Parameter name
            value: New value (float) or Parameter object (replaces or adds
                the parameter under ``key``)

        Raises:
            KeyError: If parameter not found and value is float
            ValueError: If value violates constraints

        Examples:
            >>> params = ParameterSet()
            >>> _ = params.add('alpha', value=0.5, bounds=(0, 1))
            >>> params['alpha'] = 0.7  # Set value
            >>> # Or replace entire parameter:
            >>> params['alpha'] = Parameter('alpha', value=0.8, bounds=(0, 1))
        """
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Setting parameter via subscript",
                operation="__setitem__",
                parameter=key,
            )
        if isinstance(value, Parameter):
            # Replace (or add) the entire parameter.
            self._parameters[key] = value
            if key not in self._order:
                self._order.append(key)
            if any(c.type == "relative" for c in value.constraints):
                self._has_relative_constraints = True
            return
        if key not in self._parameters:
            logger.error(
                "Parameter not found for subscript assignment",
                parameter=key,
                available_params=list(self._parameters.keys()),
            )
            raise KeyError(
                f"Parameter '{key}' not found. Use add() to create new parameters."
            )
        self.set_value(key, float(value))

    def to_dict(self) -> dict[str, dict[str, Any]]:
        """Convert to a dictionary keyed by parameter name, in set order."""
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Converting ParameterSet to dict",
                operation="to_dict",
                params=list(self._parameters.keys()),
            )
        return {name: self._parameters[name].to_dict() for name in self._order}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> ParameterSet:
        """Create from a dictionary representation (see :meth:`to_dict`).

        Uses Parameter.from_dict() to preserve constraints (not just bounds).
        Entries whose value is not a dict are skipped.
        """
        # R8-PARAMS-003: NOTE — _was_clamped and _clamp_on_set flags are not
        # preserved across serialization. Clamping behavior may differ after round-trip.
        if logger.isEnabledFor(10):  # logging.DEBUG == 10
            logger.debug(
                "Creating ParameterSet from dict",
                operation="from_dict",
                params=list(data.keys()),
            )
        params = cls()
        for name, param_data in data.items():
            if not isinstance(param_data, dict):
                continue
            # The mapping key is authoritative for the parameter name.
            param = Parameter.from_dict({**param_data, "name": name})
            params._parameters[name] = param
            params._order.append(name)
        params._has_relative_constraints = any(
            c.type == "relative"
            for param in params._parameters.values()
            for c in param.constraints
        )
        return params
class SharedParameterSet:
    """Manages parameters shared across multiple models.

    Each shared parameter is a :class:`Parameter` keyed by name. Models and
    :class:`ParameterSet` objects can be linked to a name; :meth:`set_value`
    validates the new value and then pushes it to every linked object that
    holds a parameter with the same name. Names can also be organized into
    groups.
    """

    def __init__(self):
        """Initialize an empty shared parameter set."""
        self._shared: dict[str, Parameter] = {}
        self._links: dict[str, list[Any]] = {}  # parameter name -> linked objects
        self._groups: dict[str, list[str]] = {}  # group name -> parameter names
        logger.debug("SharedParameterSet created")

    def add_shared(
        self,
        name: str,
        value: float | None = None,
        bounds: tuple[float, float] | None = None,
        units: str | None = None,
        constraints: list[ParameterConstraint] | None = None,
        group: str | None = None,
        overwrite: bool = False,
    ) -> Parameter:
        """Add (or replace) a shared parameter.

        Replacing an existing parameter keeps the objects already linked to
        its name, so value propagation continues to work.

        Args:
            name: Parameter name
            value: Initial value
            bounds: Value bounds
            units: Parameter units
            constraints: Parameter constraints
            group: Optional group name to add the parameter to
            overwrite: If True, replace an existing parameter without a
                warning (mirrors ``ParameterSet.add``).

        Returns:
            The created Parameter
        """
        logger.debug(
            "Adding shared parameter",
            operation="add_shared",
            parameter=name,
            value=value,
            bounds=bounds,
            group=group,
        )
        if name in self._shared and not overwrite:
            warnings.warn(
                f"Shared parameter '{name}' already exists and will be overwritten",
                stacklevel=2,
            )
        param = Parameter(
            name=name,
            value=value,
            bounds=bounds,
            units=units,
            constraints=constraints or [],
        )
        self._shared[name] = param
        # Links are keyed by name; dropping them on re-add would silently stop
        # propagation to already-linked models.
        self._links.setdefault(name, [])
        if group:
            members = self._groups.setdefault(group, [])
            if name not in members:
                members.append(name)
        logger.debug(
            "Shared parameter added",
            operation="add_shared",
            params=list(self._shared.keys()),
        )
        return param

    def _link(self, obj: Any, param_name: str) -> None:
        """Attach ``obj`` to ``param_name`` unless that exact object is linked."""
        if param_name not in self._shared:
            logger.error(
                "Shared parameter not found for linking",
                parameter=param_name,
                available_params=list(self._shared.keys()),
            )
            raise KeyError(f"Shared parameter '{param_name}' not found")
        links = self._links.setdefault(param_name, [])
        # Identity rather than ``in`` (==): distinct models that compare equal,
        # or objects whose __eq__ returns arrays, must each be linked.
        if not any(existing is obj for existing in links):
            links.append(obj)

    def link_model(self, model: Any, param_name: str):
        """Link a model to a shared parameter.

        Args:
            model: Model to link; it receives updates through its
                ``parameters`` attribute.
            param_name: Name of shared parameter

        Raises:
            KeyError: If the shared parameter does not exist
        """
        logger.debug(
            "Linking model to shared parameter",
            operation="link_model",
            parameter=param_name,
        )
        self._link(model, param_name)

    def link_parameter_set(self, param_set: ParameterSet, param_name: str):
        """Link a parameter set to a shared parameter.

        Args:
            param_set: ParameterSet to link
            param_name: Name of shared parameter

        Raises:
            KeyError: If the shared parameter does not exist
        """
        logger.debug(
            "Linking ParameterSet to shared parameter",
            operation="link_parameter_set",
            parameter=param_name,
        )
        self._link(param_set, param_name)

    def set_value(self, name: str, value: float):
        """Set a shared parameter value and propagate it to linked objects.

        Args:
            name: Parameter name
            value: New value

        Raises:
            KeyError: If the shared parameter does not exist
            ValueError: If value violates constraints (here or in a linked set)
        """
        logger.debug(
            "Setting shared parameter value",
            operation="set_value",
            parameter=name,
            value=value,
        )
        if name not in self._shared:
            logger.error(
                "Shared parameter not found",
                parameter=name,
                available_params=list(self._shared.keys()),
            )
            raise KeyError(f"Shared parameter '{name}' not found")
        param = self._shared[name]
        # No context: relative constraints cannot be checked at this level.
        if not param.validate(value):
            logger.error(
                "Value violates constraints for shared parameter",
                parameter=name,
                value=value,
            )
            raise ValueError(
                f"Value {value} violates constraints for parameter '{name}'"
            )
        param.value = value
        for linked in self._links.get(name, []):
            if (
                hasattr(linked, "set_value")
                and hasattr(linked, "__contains__")
                and name in linked
            ):
                # A ParameterSet-like object holding the parameter.
                linked.set_value(name, value)
            elif hasattr(linked, "parameters") and name in linked.parameters:
                # A model exposing its ParameterSet as ``parameters``.
                linked.parameters.set_value(name, value)

    def get_value(self, name: str) -> float | None:
        """Get a shared parameter value.

        Args:
            name: Parameter name

        Returns:
            Parameter value, or None if unset or unknown
        """
        param = self._shared.get(name)
        return param.value if param else None

    def get_linked_models(self, param_name: str) -> list[Any]:
        """Get the objects linked to a parameter.

        Args:
            param_name: Parameter name

        Returns:
            A new list of linked objects (mutating it does not change links)
        """
        return list(self._links.get(param_name, []))

    def create_group(self, group_name: str, param_names: list[str]):
        """Create (or replace) a parameter group.

        Args:
            group_name: Name for the group
            param_names: Parameter names to include (copied)
        """
        logger.debug(
            "Creating parameter group",
            operation="create_group",
            group=group_name,
            params=param_names,
        )
        self._groups[group_name] = list(param_names)

    def get_group(self, group_name: str) -> list[str]:
        """Get the parameter names in a group.

        Args:
            group_name: Group name

        Returns:
            A new list of parameter names (empty if the group is unknown)
        """
        return list(self._groups.get(group_name, []))

    def __contains__(self, name: str) -> bool:
        """Check if shared parameter exists."""
        return name in self._shared
class ParameterOptimizer:
    """Objective, gradient, and bookkeeping support for parameter fitting.

    Holds an objective, optional inequality constraints and a callback for a
    :class:`ParameterSet`. It evaluates the objective and its gradient, and
    :meth:`step` writes caller-supplied values back into the set and records
    history. Choosing the next values is left to the caller.
    """

    def __init__(
        self,
        parameters: ParameterSet,
        use_jax: bool = False,
        track_history: bool = False,
    ):
        """Initialize parameter optimizer.

        Args:
            parameters: ParameterSet to optimize
            use_jax: Whether to use JAX autodiff for gradients
            track_history: Whether to record each step in ``history``
        """
        self.parameters = parameters
        self.use_jax = use_jax and HAS_JAX
        self.track_history = track_history
        self.history: list[dict[str, Any]] = []
        self.objective: Callable | None = None
        self.constraints: list[Callable] = []
        self.callback: Callable | None = None
        logger.debug(
            "ParameterOptimizer created",
            num_params=len(parameters),
            use_jax=self.use_jax,
            track_history=track_history,
        )

    @property
    def n_parameters(self) -> int:
        """Number of parameters."""
        return len(self.parameters)

    def get_values(self) -> np.ndarray:
        """Get current parameter values."""
        return self.parameters.get_values()

    def get_bounds(self) -> list[tuple[float | None, float | None]]:
        """Get parameter bounds."""
        return self.parameters.get_bounds()

    def set_objective(self, objective: Callable):
        """Set objective function to minimize.

        Args:
            objective: Function that takes parameter values and returns scalar
        """
        logger.debug(
            "Setting objective function",
            operation="set_objective",
        )
        self.objective = objective

    def evaluate(self, values: ArrayLike) -> float:
        """Evaluate the objective at given values.

        Args:
            values: Parameter values

        Returns:
            Objective value as a Python float

        Raises:
            ValueError: If no objective is set
        """
        if self.objective is None:
            logger.error(
                "No objective function set",
            )
            raise ValueError("No objective function set")
        # Always coerce so NumPy/JAX scalars, 0-d arrays and ints all come
        # back as float, matching the annotated return type.
        return float(self.objective(values))

    @staticmethod
    def _forward_difference(
        fn: Callable[[np.ndarray], float],
        x: ArrayLike,
        rel_step: float = 1e-8,
    ) -> np.ndarray:
        """Forward-difference gradient of ``fn`` at the 1-D point ``x``.

        The step for coordinate ``i`` is ``rel_step * max(1, |x[i]|)`` so it
        remains representable for large values (a fixed 1e-8 step vanishes
        when added to ~1e9), and the quotient divides by the step actually
        realized in floating point.
        """
        # Float copy: integer inputs cannot take fractional steps, and the
        # caller's array must not be modified.
        point = np.array(x, dtype=np.float64)
        f0 = fn(point)  # loop-invariant
        grad = np.zeros(point.shape[0])
        for i, xi in enumerate(point):
            shifted = point.copy()
            shifted[i] = xi + rel_step * max(1.0, abs(xi))
            realized = shifted[i] - xi
            grad[i] = (fn(shifted) - f0) / realized
        return grad

    def compute_gradient(self, values: ArrayLike) -> np.ndarray:
        """Compute the gradient of the objective.

        Uses ``jax.grad`` when ``use_jax`` is enabled, otherwise a
        relative-step forward difference.

        Args:
            values: Parameter values

        Returns:
            Gradient vector (a new NumPy array)

        Raises:
            ValueError: If no objective is set
        """
        logger.debug(
            "Computing gradient",
            operation="compute_gradient",
            use_jax=self.use_jax,
        )
        if not self.use_jax or not HAS_JAX:
            return self._forward_difference(self.evaluate, _coerce_array(values))
        if self.objective is None:
            logger.error(
                "No objective function set",
            )
            raise ValueError("No objective function set")
        grad_fn = jax.grad(self.objective)
        # jax.grad only differentiates floating-point inputs; np.array makes
        # a writable host copy of the result.
        return np.array(grad_fn(jnp.asarray(values, dtype=jnp.float64)))

    def add_constraint(self, constraint: Callable):
        """Add an optimization constraint.

        Args:
            constraint: Function returning a value (or array of values) that
                is >= 0 where the parameters are valid
        """
        logger.debug(
            "Adding optimization constraint",
            operation="add_constraint",
            num_constraints=len(self.constraints) + 1,
        )
        self.constraints.append(constraint)

    def validate_constraints(self, values: ArrayLike) -> bool:
        """Check whether all constraints are satisfied.

        Args:
            values: Parameter values

        Returns:
            True if every component of every constraint is >= 0. NaN
            results count as violations.
        """
        values_array = _coerce_array(values)
        for constraint in self.constraints:
            result = np.asarray(constraint(values_array))
            # ``all(>= 0)`` rather than ``< 0``: supports vector-valued
            # constraints and treats NaN as a violation.
            if not np.all(result >= 0):
                logger.debug(
                    "Constraint validation failed",
                    operation="validate_constraints",
                )
                return False
        return True

    def set_callback(self, callback: Callable):
        """Set optimization callback.

        Args:
            callback: Called as ``callback(iteration, values, objective)``
                after each :meth:`step`
        """
        logger.debug(
            "Setting optimization callback",
            operation="set_callback",
        )
        self.callback = callback

    def step(self, values: ArrayLike, iteration: int | None = None):
        """Record one optimization step.

        Writes ``values`` into the parameter set, evaluates the objective,
        optionally appends to ``history``, and invokes the callback.

        Args:
            values: Current parameter values
            iteration: Current iteration number; defaults to the number of
                recorded history entries
        """
        coerced_values = _coerce_array(values)
        self.parameters.set_values(coerced_values)
        obj_value = self.evaluate(coerced_values)
        # Resolve the index once so history and callback agree. ``is not
        # None`` (not ``or``) keeps an explicit iteration=0.
        effective_iter = iteration if iteration is not None else len(self.history)
        if self.track_history:
            self.history.append(
                {
                    "iteration": effective_iter,
                    # Copy: _coerce_array may return the caller's own array.
                    "values": coerced_values.copy(),
                    "objective": obj_value,
                }
            )
        if self.callback is not None:
            self.callback(effective_iter, coerced_values, obj_value)

    def get_history(self) -> list[dict[str, Any]]:
        """Get optimization history.

        Returns:
            The internal list of history dictionaries
        """
        return self.history