"""Plugin registry system for models and transforms.
This module provides a registry system for discovering, registering, and managing
models and transforms as plugins, enabling extensibility of the rheojax package.
"""
from __future__ import annotations
import importlib
import inspect
import os
import threading
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from rheojax.core.inventory import Protocol, TransformType
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger
logger = get_logger(__name__)

# R11-REG-001: serialises Registry.discover_directory, which temporarily
# prepends a directory to sys.path while it imports plugin modules.
_discover_lock: threading.Lock = threading.Lock()
class PluginType(Enum):
    """Kinds of plugin the registry can hold.

    Members are looked up by their lowercase string value, for example
    ``PluginType("model")``.
    """

    # Stored in Registry._models.
    MODEL = "model"
    # Stored in Registry._transforms.
    TRANSFORM = "transform"
@dataclass
class PluginInfo:
    """Record describing one registered plugin.

    Attributes:
        name: Registry key the plugin was registered under.
        plugin_class: The plugin class object itself.
        plugin_type: Whether the plugin is a model or a transform.
        metadata: Free-form metadata supplied at registration time.
        doc: Documentation text; when omitted it is taken from the
            plugin class via ``inspect.getdoc``.
        protocols: Protocols the plugin supports (used for models).
        deformation_modes: Deformation modes the plugin supports (used for
            models).
        transform_type: Category of the plugin (used for transforms).
    """

    name: str
    plugin_class: type
    plugin_type: PluginType
    metadata: dict[str, Any]
    doc: str | None = None
    protocols: list[Protocol] = field(default_factory=list)
    deformation_modes: list[DeformationMode] = field(default_factory=list)
    transform_type: TransformType | None = None

    def __post_init__(self) -> None:
        """Populate ``doc`` from the plugin class when it was not supplied."""
        # Keep an explicit doc, and skip lookup when there is no class.
        if self.doc is not None or not self.plugin_class:
            return
        self.doc = inspect.getdoc(self.plugin_class)
class Registry:
    """Central registry for models and transforms.

    This class manages plugin registration, discovery, and retrieval for
    all models and transforms in the rheojax package. It is a singleton:
    every ``Registry()`` call returns the same instance, so registrations
    are shared process-wide.
    """

    _instance: Registry | None = None
    _lock: threading.Lock = threading.Lock()
    _models: dict[str, PluginInfo]
    _transforms: dict[str, PluginInfo]

    def __new__(cls):
        """Return the shared instance, creating it on first use (thread-safe).

        The unlocked ``None`` check keeps the common path lock-free. The
        re-check under ``cls._lock`` ensures only one thread builds the
        instance and its backing dictionaries.
        """
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._models = {}
                    cls._instance._transforms = {}
        return cls._instance

    def _normalize_plugin_type(self, plugin_type: PluginType | str) -> PluginType:
        """Normalize arbitrary plugin type inputs to the PluginType enum.

        Strings are matched case-insensitively against the enum values
        ("model" / "transform"). Non-string inputs are returned unchanged.

        Raises:
            ValueError: If a string does not name a plugin type.
        """
        if isinstance(plugin_type, str):
            try:
                return PluginType(plugin_type.lower())
            except ValueError as exc:
                raise ValueError(
                    f"Invalid plugin type: {plugin_type}. Must be 'model' or 'transform'"
                ) from exc
        return plugin_type

    def _registry_for(
        self, plugin_type: PluginType | str
    ) -> tuple[PluginType, dict[str, PluginInfo]]:
        """Return the normalized plugin type and its backing registry mapping.

        Raises:
            ValueError: If *plugin_type* is not a recognised plugin type.
        """
        plugin_enum = self._normalize_plugin_type(plugin_type)
        if plugin_enum == PluginType.MODEL:
            return plugin_enum, self._models
        if plugin_enum == PluginType.TRANSFORM:
            return plugin_enum, self._transforms
        raise ValueError(f"Unsupported plugin type: {plugin_type}")

    @classmethod
    def get_instance(cls) -> Registry:
        """Get the singleton registry instance.

        Returns:
            The global Registry instance
        """
        return cls()

    @staticmethod
    def _coerce_enum(
        value: Any, enum_cls: type[Enum], label: str, plugin_name: str
    ) -> Any:
        """Convert one user-supplied value to a member of *enum_cls*.

        Members of *enum_cls* pass through unchanged, and strings are looked
        up by enum value. Returns ``None`` for an unknown string (after
        logging a warning) and, silently, for values of any other type.
        """
        if isinstance(value, enum_cls):
            return value
        if isinstance(value, str):
            try:
                return enum_cls(value)
            except ValueError:
                logger.warning(f"Invalid {label} '{value}' for plugin '{plugin_name}'")
        return None

    @classmethod
    def _coerce_enum_list(
        cls, values: Any, enum_cls: type[Enum], label: str, plugin_name: str
    ) -> list[Any]:
        """Coerce every entry of *values* with ``_coerce_enum``, dropping failures."""
        coerced = (
            cls._coerce_enum(value, enum_cls, label, plugin_name)
            for value in (values or ())
        )
        return [member for member in coerced if member is not None]

    def register(
        self,
        name: str,
        plugin_class: type,
        plugin_type: PluginType | str,
        metadata: dict[str, Any] | None = None,
        validate: bool = False,
        force: bool = False,
        protocols: list[Protocol | str] | None = None,
        deformation_modes: list[DeformationMode | str] | None = None,
        transform_type: TransformType | str | None = None,
    ):
        """Register a plugin in the registry.

        Args:
            name: Unique name for the plugin
            plugin_class: The plugin class to register
            plugin_type: Type of plugin (MODEL or TRANSFORM)
            metadata: Optional metadata dictionary. It is shallow-copied, so
                later mutation by the caller does not leak into the registry.
            validate: Whether to validate the plugin interface
            force: Whether to overwrite existing registration (logs a warning)
            protocols: List of supported protocols (for models). Invalid
                strings are logged and skipped.
            deformation_modes: List of supported deformation modes (for
                models). Invalid strings are logged and skipped.
            transform_type: Type of transform (for transforms). An invalid
                string is logged and stored as ``None``.

        Raises:
            ValueError: If plugin is already registered (and force=False) or invalid
        """
        plugin_enum, registry = self._registry_for(plugin_type)

        if name in registry:
            if not force:
                raise ValueError(
                    f"Plugin '{name}' is already registered as a {plugin_enum.value}"
                )
            logger.warning(
                "Overwriting existing registration",
                name=name,
                plugin_type=plugin_enum.value,
            )

        if validate:
            self._validate_plugin(plugin_class, plugin_enum)

        normalized_protocols = self._coerce_enum_list(
            protocols, Protocol, "protocol", name
        )
        normalized_deformation_modes = self._coerce_enum_list(
            deformation_modes, DeformationMode, "deformation_mode", name
        )
        normalized_transform_type = (
            self._coerce_enum(transform_type, TransformType, "transform_type", name)
            if transform_type
            else None
        )

        registry[name] = PluginInfo(
            name=name,
            plugin_class=plugin_class,
            plugin_type=plugin_enum,
            metadata=dict(metadata) if metadata else {},
            protocols=normalized_protocols,
            deformation_modes=normalized_deformation_modes,
            transform_type=normalized_transform_type,
        )

    def _validate_plugin(self, plugin_class: type, plugin_type: PluginType):
        """Validate that a plugin implements the required interface.

        Models must define ``fit`` and ``predict``. Transforms must define
        ``transform``. Only attribute presence is checked, not signatures.

        Args:
            plugin_class: The plugin class to validate
            plugin_type: Expected plugin type

        Raises:
            ValueError: If plugin doesn't implement required interface
        """
        if plugin_type == PluginType.MODEL:
            for method in ("fit", "predict"):
                if not hasattr(plugin_class, method):
                    raise ValueError(
                        f"Model plugin does not implement required interface: missing '{method}' method"
                    )
        elif plugin_type == PluginType.TRANSFORM:
            if not hasattr(plugin_class, "transform"):
                raise ValueError(
                    "Transform plugin does not implement required interface: missing 'transform' method"
                )

    def get(
        self,
        name: str,
        plugin_type: PluginType | str,
        raise_on_missing: bool = False,
    ) -> type | None:
        """Retrieve a registered plugin class.

        Args:
            name: Name of the plugin
            plugin_type: Type of plugin
            raise_on_missing: Whether to raise error if not found

        Returns:
            The plugin class, or None if not found

        Raises:
            KeyError: If plugin not found and raise_on_missing=True
        """
        plugin_enum, registry = self._registry_for(plugin_type)
        info = registry.get(name)
        if info is not None:
            return info.plugin_class
        if raise_on_missing:
            raise KeyError(
                f"Plugin '{name}' not found in registry for type {plugin_enum.value}"
            )
        return None

    def get_info(self, name: str, plugin_type: PluginType | str) -> PluginInfo | None:
        """Get full information about a registered plugin.

        Args:
            name: Name of the plugin
            plugin_type: Type of plugin

        Returns:
            PluginInfo object or None if not found
        """
        _, registry = self._registry_for(plugin_type)
        return registry.get(name)

    def get_all_models(self) -> list[str]:
        """Get list of all registered model names (in registration order).

        Returns:
            List of model names
        """
        return list(self._models)

    def unregister(self, name: str, plugin_type: PluginType | str):
        """Remove a plugin from the registry; unknown names are ignored.

        Args:
            name: Name of the plugin to remove
            plugin_type: Type of plugin
        """
        _, registry = self._registry_for(plugin_type)
        registry.pop(name, None)

    def get_all(self) -> dict[str, tuple[type, PluginType]]:
        """Get all registered plugins with their types.

        If a model and a transform share a name, the transform entry wins
        because transforms are added after models.

        Returns:
            Dictionary mapping plugin names to (class, type) tuples
        """
        result: dict[str, tuple[type, PluginType]] = {
            name: (info.plugin_class, PluginType.MODEL)
            for name, info in self._models.items()
        }
        result.update(
            (name, (info.plugin_class, PluginType.TRANSFORM))
            for name, info in self._transforms.items()
        )
        return result

    def clear(self):
        """Clear all registered plugins."""
        self._models.clear()
        self._transforms.clear()

    def __len__(self) -> int:
        """Get total number of registered plugins.

        Returns:
            Total count of models and transforms
        """
        return len(self._models) + len(self._transforms)

    def __contains__(self, name: str) -> bool:
        """Check if a plugin name is registered.

        Args:
            name: Plugin name to check

        Returns:
            True if registered as either model or transform
        """
        return name in self._models or name in self._transforms

    def get_stats(self) -> dict[str, int]:
        """Get registration statistics.

        Returns:
            Dictionary with counts of registered plugins
        """
        return {
            "total": len(self),
            "models": len(self._models),
            "transforms": len(self._transforms),
        }

    def discover(self, module_name: str):
        """Discover and register plugins from a module.

        Every class attribute of the module (including classes the module
        merely imports) is classified by duck typing. Classes with both
        ``fit`` and ``predict`` become models; otherwise, classes with
        ``transform`` become transforms. Each is registered under its
        attribute name. Names that are already registered are left untouched.

        Args:
            module_name: Name of the module to import and scan. A module that
                fails to import is skipped (logged at debug level).
        """
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:
            logger.debug(
                "Skipping plugin discovery for unimportable module",
                module=module_name,
                error=str(exc),
            )
            return

        for attr_name, obj in inspect.getmembers(module, inspect.isclass):
            if hasattr(obj, "fit") and hasattr(obj, "predict"):
                kind = PluginType.MODEL
            elif hasattr(obj, "transform"):
                kind = PluginType.TRANSFORM
            else:
                continue
            _, registry = self._registry_for(kind)
            if attr_name in registry:
                # Expected: duplicate class during discovery; keep the first.
                continue
            self.register(attr_name, obj, kind, validate=False)

    def discover_directory(self, path: str):
        """Discover plugins in a directory.

        Each ``*.py`` file directly inside *path* whose name does not start
        with ``_`` is imported as a top-level module (by file stem) and
        scanned with ``discover``. During the scan, *path* is temporarily
        prepended to ``sys.path``. Files are processed in sorted order, so
        duplicate class names resolve deterministically.

        Args:
            path: Path to directory containing plugin modules. Paths that do
                not exist or are not directories are ignored.
        """
        if not os.path.isdir(path):
            return
        # R11-REG-001: Wrap sys.path mutation in a lock to prevent races
        # when multiple threads discover directories concurrently.
        import sys

        with _discover_lock:
            sys.path.insert(0, path)
            try:
                # NOTE(review): importlib returns a module already present in
                # sys.modules under the same name instead of loading the file
                # from *path*; generic file names may therefore be shadowed.
                for filename in sorted(os.listdir(path)):
                    if filename.endswith(".py") and not filename.startswith("_"):
                        self.discover(filename[:-3])
            finally:
                # Remove temporary path (use remove() not pop(0) for thread safety)
                try:
                    sys.path.remove(path)
                except ValueError:
                    pass

    def create_instance(
        self, name: str, plugin_type: PluginType | str, *args, **kwargs
    ) -> Any:
        """Create an instance of a registered plugin.

        Args:
            name: Name of the plugin
            plugin_type: Type of plugin
            *args: Positional arguments for plugin constructor
            **kwargs: Keyword arguments for plugin constructor

        Returns:
            Instance of the plugin class

        Raises:
            KeyError: If plugin not found
        """
        plugin_class = self.get(name, plugin_type, raise_on_missing=True)
        if plugin_class is None:
            raise RuntimeError(
                f"Registry returned None for plugin '{name}' of type {plugin_type}"
            )
        return plugin_class(*args, **kwargs)

    def find_compatible(
        self,
        protocol: Protocol | str | None = None,
        deformation_mode: DeformationMode | str | None = None,
        transform_type: TransformType | str | None = None,
        **criteria,
    ) -> list[str]:
        """Find plugins matching certain criteria.

        Models are searched only when *transform_type* is None. Transforms
        are searched only when *protocol* is None.

        Args:
            protocol: Filter models by supported protocol
            deformation_mode: Filter models by supported deformation mode.
                Models registered without deformation modes are treated as
                shear-only.
            transform_type: Filter transforms by type
            **criteria: Additional criteria to match against plugin metadata
                (exact equality; a key missing from metadata never matches)

        Returns:
            List of plugin names matching all criteria

        Raises:
            ValueError: If a string filter is not a valid enum value.
        """
        compatible: list[str] = []

        if transform_type is None:
            # Convert the filters once rather than once per registered model.
            target_proto = None
            if protocol:
                target_proto = (
                    Protocol(protocol) if isinstance(protocol, str) else protocol
                )
            target_dm = None
            if deformation_mode is not None:
                target_dm = (
                    DeformationMode(deformation_mode)
                    if isinstance(deformation_mode, str)
                    else deformation_mode
                )

            for name, info in self._models.items():
                if target_proto is not None and target_proto not in info.protocols:
                    continue
                if target_dm is not None:
                    if info.deformation_modes:
                        if target_dm not in info.deformation_modes:
                            continue
                    elif target_dm != DeformationMode.SHEAR:
                        # Models without explicit deformation_modes are shear-only
                        logger.debug(
                            "Excluding model without deformation_modes for non-shear query",
                            model=name,
                            requested_mode=str(target_dm),
                        )
                        continue
                if self._matches_criteria(info.metadata, criteria):
                    compatible.append(name)

        if protocol is None:
            target_type = None
            if transform_type:
                target_type = (
                    TransformType(transform_type)
                    if isinstance(transform_type, str)
                    else transform_type
                )

            for name, info in self._transforms.items():
                if target_type is not None and info.transform_type != target_type:
                    continue
                if self._matches_criteria(info.metadata, criteria):
                    compatible.append(name)

        return compatible

    def _matches_criteria(
        self, metadata: dict[str, Any], criteria: dict[str, Any]
    ) -> bool:
        """Check if metadata matches all criteria.

        Args:
            metadata: Plugin metadata
            criteria: Criteria to match (every key must be present and equal)

        Returns:
            True if all criteria match (vacuously True for no criteria)
        """
        return all(
            key in metadata and metadata[key] == value
            for key, value in criteria.items()
        )

    def export_state(self) -> dict[str, Any]:
        """Export registry state for serialization.

        Enum fields are exported as their ``.value`` strings so that
        ``import_state`` can reconstruct them. Metadata dicts are shallow
        copies.

        Returns:
            Dictionary representation of registry state
        """
        return {
            "models": {
                name: {
                    "class_name": info.plugin_class.__name__,
                    "module": info.plugin_class.__module__,
                    "metadata": dict(info.metadata),
                    "protocols": [p.value for p in info.protocols],
                    "deformation_modes": [dm.value for dm in info.deformation_modes],
                }
                for name, info in self._models.items()
            },
            "transforms": {
                name: {
                    "class_name": info.plugin_class.__name__,
                    "module": info.plugin_class.__module__,
                    "metadata": dict(info.metadata),
                    "transform_type": (
                        info.transform_type.value if info.transform_type else None
                    ),
                }
                for name, info in self._transforms.items()
            },
        }

    @staticmethod
    def _resolve_exported_class(name: str, entry: dict[str, Any]) -> type | None:
        """Import the class described by an ``export_state`` entry.

        Returns ``None`` (after logging a warning) when the module cannot be
        imported, the class is missing, or the entry lacks required keys.
        """
        try:
            module = importlib.import_module(entry["module"])
            return getattr(module, entry["class_name"])
        except (ImportError, AttributeError, KeyError) as exc:
            logger.warning(
                "Skipping plugin that could not be restored",
                name=name,
                error=str(exc),
            )
            return None

    def import_state(self, state: dict[str, Any]):
        """Import registry state from serialization.

        Existing registrations with the same name are overwritten. Entries
        whose class cannot be resolved are skipped with a warning.

        Args:
            state: Dictionary representation of registry state, as produced
                by ``export_state``
        """
        for name, entry in state.get("models", {}).items():
            plugin_class = self._resolve_exported_class(name, entry)
            if plugin_class is None:
                continue
            self.register(
                name,
                plugin_class,
                PluginType.MODEL,
                metadata=entry.get("metadata", {}),
                force=True,
                protocols=entry.get("protocols", []),
                deformation_modes=entry.get("deformation_modes", []),
            )

        for name, entry in state.get("transforms", {}).items():
            plugin_class = self._resolve_exported_class(name, entry)
            if plugin_class is None:
                continue
            self.register(
                name,
                plugin_class,
                PluginType.TRANSFORM,
                metadata=entry.get("metadata", {}),
                force=True,
                transform_type=entry.get("transform_type"),
            )

    @staticmethod
    def _first_doc_line(info: PluginInfo) -> str:
        """Return the first line of a plugin's doc, or an empty string."""
        return info.doc.partition("\n")[0] if info.doc else ""

    def inventory(self) -> dict[str, Any]:
        """Get full inventory of registered plugins.

        Returns:
            Dictionary with ``"models"`` (protocol value -> model names),
            ``"transforms"`` (transform-type value -> transform names), and
            ``"all_models"`` / ``"all_transforms"`` (per-plugin summary
            dicts whose ``"description"`` is the first docstring line)
        """
        result: dict[str, Any] = {
            "models": {p.value: [] for p in Protocol},
            "transforms": {t.value: [] for t in TransformType},
            "all_models": [],
            "all_transforms": [],
        }

        for name, info in self._models.items():
            result["all_models"].append(
                {
                    "name": name,
                    "class": info.plugin_class.__name__,
                    "description": self._first_doc_line(info),
                    "protocols": [p.value for p in info.protocols],
                    "deformation_modes": [dm.value for dm in info.deformation_modes],
                }
            )
            for p in info.protocols:
                result["models"][p.value].append(name)

        for name, info in self._transforms.items():
            result["all_transforms"].append(
                {
                    "name": name,
                    "class": info.plugin_class.__name__,
                    "description": self._first_doc_line(info),
                    "type": info.transform_type.value if info.transform_type else None,
                }
            )
            if info.transform_type:
                result["transforms"][info.transform_type.value].append(name)

        return result

    # Decorator methods for easy registration
    def model(
        self,
        name: str | None = None,
        protocols: list[Protocol | str] | None = None,
        **metadata,
    ):
        """Decorator for registering a model.

        Args:
            name: Optional name for the model (uses class name if not provided)
            protocols: List of supported protocols
            **metadata: Additional metadata for the model

        Returns:
            Decorator function that registers the class and returns it unchanged
        """

        def decorator(cls):
            self.register(
                name or cls.__name__,
                cls,
                PluginType.MODEL,
                metadata=metadata,
                protocols=protocols,
            )
            return cls

        return decorator
class ModelRegistry:
    """Convenient interface for model registration and creation.

    This class provides a simplified API specifically for models,
    delegating to the main Registry singleton.

    Example:
        >>> @ModelRegistry.register('maxwell', protocols=['relaxation'])
        ... class Maxwell(BaseModel):
        ...     pass
        >>> model = ModelRegistry.create('maxwell')
        >>> models = ModelRegistry.find(protocol='relaxation')
    """

    _registry: Registry | None = None

    @classmethod
    def _get_registry(cls) -> Registry:
        """Get the global registry instance (cached on the class after first use)."""
        if cls._registry is None:
            cls._registry = Registry.get_instance()
        return cls._registry

    @classmethod
    def _infos_for(cls, names: list[str]) -> list[PluginInfo]:
        """Look up the model PluginInfo for each name, skipping missing entries."""
        registry = cls._get_registry()
        infos = (registry.get_info(n, PluginType.MODEL) for n in names)
        return [info for info in infos if info is not None]

    @classmethod
    def register(
        cls,
        name: str,
        protocols: list[Protocol | str] | None = None,
        deformation_modes: list[DeformationMode | str] | None = None,
        **metadata,
    ):
        """Decorator for registering a model.

        Args:
            name: Name for the model
            protocols: List of supported protocols
            deformation_modes: List of supported deformation modes (shear, tension,
                bending, compression). Models with oscillation protocol that work
                in G-space can support all 4 modes via automatic E*<->G* conversion.
            **metadata: Additional metadata for the model

        Returns:
            Decorator function that registers the class and returns it unchanged

        Example:
            >>> @ModelRegistry.register('maxwell',
            ...     protocols=['relaxation', 'oscillation'],
            ...     deformation_modes=['shear', 'tension', 'bending', 'compression'])
            ... class Maxwell(BaseModel):
            ...     pass
        """

        def decorator(model_class):
            cls._get_registry().register(
                name,
                model_class,
                PluginType.MODEL,
                metadata=metadata,
                protocols=protocols,
                deformation_modes=deformation_modes,
            )
            return model_class

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        """Create a model instance by name (factory method).

        If *name* is not registered yet, all model modules are imported via
        ``rheojax.models._ensure_all_registered`` before retrying.

        Args:
            name: Name of the model to create
            *args: Positional arguments for model constructor
            **kwargs: Keyword arguments for model constructor

        Returns:
            Instance of the model class

        Raises:
            KeyError: If model not found

        Example:
            >>> model = ModelRegistry.create('maxwell')
        """
        registry = cls._get_registry()
        # If the model isn't registered yet, eagerly import all model modules
        # to trigger @ModelRegistry.register decorators (lazy-import fallback)
        if registry.get(name, PluginType.MODEL) is None:
            from rheojax.models import _ensure_all_registered

            _ensure_all_registered()
        return registry.create_instance(name, PluginType.MODEL, *args, **kwargs)

    @classmethod
    def list_models(cls) -> list[str]:
        """List all registered model names (discovery).

        Returns:
            List of registered model names

        Example:
            >>> models = ModelRegistry.list_models()
            >>> print(models)
            ['maxwell', 'zener', 'springpot', ...]
        """
        return cls._get_registry().get_all_models()

    @classmethod
    def find(
        cls,
        protocol: Protocol | str | None = None,
        deformation_mode: DeformationMode | str | None = None,
        **criteria,
    ) -> list[str]:
        """Find models matching criteria.

        Args:
            protocol: Filter by supported protocol
            deformation_mode: Filter by supported deformation mode
            **criteria: Additional metadata criteria

        Returns:
            List of matching model names. If *protocol* is None, transform
            names matching *criteria* are also included, because this
            delegates to ``Registry.find_compatible``.
        """
        return cls._get_registry().find_compatible(
            protocol=protocol, deformation_mode=deformation_mode, **criteria
        )

    @classmethod
    def get_info(cls, name: str) -> PluginInfo | None:
        """Get information about a registered model.

        Falls back to importing all model modules once when *name* is not
        registered yet.

        Args:
            name: Name of the model

        Returns:
            PluginInfo object with model details, or None if still unknown

        Example:
            >>> info = ModelRegistry.get_info('maxwell')
            >>> print(info.protocols)
            [<Protocol.RELAXATION: 'relaxation'>, ...]
        """
        registry = cls._get_registry()
        info = registry.get_info(name, PluginType.MODEL)
        if info is None:
            from rheojax.models import _ensure_all_registered

            _ensure_all_registered()
            info = registry.get_info(name, PluginType.MODEL)
        return info

    @classmethod
    def for_protocol(cls, protocol: Protocol | str) -> list[PluginInfo]:
        """Get all models supporting a given protocol.

        Args:
            protocol: Protocol to filter by (e.g. ``"relaxation"``).

        Returns:
            List of PluginInfo objects for matching models.
        """
        names = cls._get_registry().find_compatible(protocol=protocol)
        return cls._infos_for(names)

    @classmethod
    def compatible_models(cls, data: Any) -> list[PluginInfo]:
        """Find models compatible with a RheoData instance.

        Reads ``data.test_mode`` (maps to protocol), falling back to
        ``data.metadata["test_mode"]`` and then
        ``data.metadata["detected_test_mode"]``. It also reads
        ``data.metadata.get("deformation_mode")`` to filter models. A missing
        or non-mapping ``metadata`` attribute is treated as empty.

        Args:
            data: RheoData instance with test_mode metadata.

        Returns:
            List of PluginInfo objects for compatible models.
        """
        from collections.abc import Mapping

        raw_metadata = getattr(data, "metadata", None)
        # Guard against metadata=None (or other non-mappings), which would
        # otherwise fail with AttributeError on ``.get``.
        metadata: Mapping[str, Any] = (
            raw_metadata if isinstance(raw_metadata, Mapping) else {}
        )

        # Extract protocol from test_mode
        test_mode = getattr(data, "test_mode", None)
        if test_mode is None:
            test_mode = metadata.get("test_mode")
        if test_mode is None:
            test_mode = metadata.get("detected_test_mode")
        protocol = test_mode if test_mode else None

        names = cls._get_registry().find_compatible(
            protocol=protocol,
            deformation_mode=metadata.get("deformation_mode"),
        )
        return cls._infos_for(names)

    @classmethod
    def model_info(cls, name: str) -> Any:
        """Get aggregated model information including parameter metadata.

        Temporarily instantiates the model to read parameter names, bounds,
        and units. Returns a ``ModelInfo`` dataclass.

        Args:
            name: Registry name of the model.

        Returns:
            ModelInfo instance with full model metadata.

        Raises:
            KeyError: If model is not registered.
        """
        from rheojax.core.fit_result import ModelInfo

        return ModelInfo.from_registry(name)

    @classmethod
    def unregister(cls, name: str):
        """Unregister a model (unknown names are ignored).

        Args:
            name: Name of the model to remove
        """
        cls._get_registry().unregister(name, PluginType.MODEL)