# Source code for rheojax.core.registry

"""Plugin registry system for models and transforms.

This module provides a registry system for discovering, registering, and managing
models and transforms as plugins, enabling extensibility of the rheojax package.
"""

from __future__ import annotations

import importlib
import inspect
import os
import threading
from dataclasses import dataclass, field
from enum import Enum
from typing import Any

from rheojax.core.inventory import Protocol, TransformType
from rheojax.core.test_modes import DeformationMode
from rheojax.logging import get_logger

# Module logger shared by every registry class below.
logger = get_logger(__name__)

# R11-REG-001: serialises the temporary sys.path insert/remove pair performed
# by Registry.discover_directory so concurrent discoveries cannot interleave.
_discover_lock: threading.Lock = threading.Lock()


class PluginType(Enum):
    """Kinds of plugin that can live in the registry.

    Members carry lowercase string values, so ``PluginType("model")`` resolves
    to :attr:`MODEL`.
    """

    # Fit/predict style plugins.
    MODEL = "model"
    # Data-transforming plugins.
    TRANSFORM = "transform"
@dataclass
class PluginInfo:
    """Record describing a single registered plugin.

    Attributes:
        name: Registry key under which the plugin is stored.
        plugin_class: The registered class itself.
        plugin_type: Whether the plugin is a model or a transform.
        metadata: Free-form metadata supplied at registration time.
        doc: Documentation string; filled from ``plugin_class`` when omitted.
        protocols: Supported protocols (models).
        deformation_modes: Supported deformation modes (models).
        transform_type: Category of the transform (transforms).
    """

    name: str
    plugin_class: type
    plugin_type: PluginType
    metadata: dict[str, Any]
    doc: str | None = None
    protocols: list[Protocol] = field(default_factory=list)
    deformation_modes: list[DeformationMode] = field(default_factory=list)
    transform_type: TransformType | None = None

    def __post_init__(self):
        """Populate ``doc`` from the plugin class when no doc was given."""
        # Keep an explicitly supplied doc, and skip falsy plugin classes.
        if self.doc is not None or not self.plugin_class:
            return
        self.doc = inspect.getdoc(self.plugin_class)
class Registry:
    """Central registry for models and transforms.

    This class manages plugin registration, discovery, and retrieval for all
    models and transforms in the rheojax package. ``Registry()`` always
    returns the same shared instance (see ``__new__``).
    """

    _instance: Registry | None = None
    _lock: threading.Lock = threading.Lock()
    _models: dict[str, PluginInfo]
    _transforms: dict[str, PluginInfo]

    def __new__(cls):
        """Ensure singleton pattern (thread-safe)."""
        # Double-checked locking: the unlocked test keeps repeat calls cheap,
        # the locked re-test stops two threads from both creating an instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
                    cls._instance._models = {}
                    cls._instance._transforms = {}
        return cls._instance

    def _normalize_plugin_type(self, plugin_type: PluginType | str) -> PluginType:
        """Normalize arbitrary plugin type inputs to the PluginType enum.

        Strings are lower-cased before lookup, so ``"MODEL"`` is accepted.

        Raises:
            ValueError: If a string does not name a known plugin type.
        """
        if isinstance(plugin_type, str):
            try:
                return PluginType(plugin_type.lower())
            except ValueError as exc:
                raise ValueError(
                    f"Invalid plugin type: {plugin_type}. Must be 'model' or 'transform'"
                ) from exc
        return plugin_type

    def _registry_for(
        self, plugin_type: PluginType | str
    ) -> tuple[PluginType, dict[str, PluginInfo]]:
        """Return the normalized plugin type and its backing registry mapping.

        Raises:
            ValueError: If ``plugin_type`` is neither a model nor a transform.
        """
        plugin_enum = self._normalize_plugin_type(plugin_type)
        if plugin_enum == PluginType.MODEL:
            return plugin_enum, self._models
        if plugin_enum == PluginType.TRANSFORM:
            return plugin_enum, self._transforms
        raise ValueError(f"Unsupported plugin type: {plugin_type}")

    @classmethod
    def get_instance(cls) -> Registry:
        """Get the singleton registry instance.

        Returns:
            The global Registry instance
        """
        return cls()

    @staticmethod
    def _normalize_members(
        values: list[Any] | None,
        enum_cls: type[Enum],
        label: str,
        plugin_name: str,
    ) -> list[Any]:
        """Coerce strings / enum members into ``enum_cls`` members.

        Strings that are not valid values are logged and dropped; items that
        are neither strings nor ``enum_cls`` members are dropped silently.
        """
        normalized: list[Any] = []
        for value in values or []:
            if isinstance(value, str):
                try:
                    normalized.append(enum_cls(value))
                except ValueError:
                    logger.warning(
                        f"Invalid {label} '{value}' for plugin '{plugin_name}'"
                    )
            elif isinstance(value, enum_cls):
                normalized.append(value)
        return normalized

    @staticmethod
    def _coerce_legacy_enum_string(value: Any, enum_cls: type[Enum]) -> Any:
        """Translate a legacy ``"EnumName.MEMBER"`` string into the member's value.

        Earlier versions of ``export_state`` stored ``str(member)``, which for a
        plain ``Enum`` is ``"EnumName.MEMBER"`` rather than the value accepted
        by ``enum_cls(value)``. Any other input is returned unchanged.
        """
        if isinstance(value, str):
            prefix, sep, member_name = value.partition(".")
            if (
                sep
                and prefix == enum_cls.__name__
                and member_name in enum_cls.__members__
            ):
                return enum_cls[member_name].value
        return value

    def register(
        self,
        name: str,
        plugin_class: type,
        plugin_type: PluginType | str,
        metadata: dict[str, Any] | None = None,
        validate: bool = False,
        force: bool = False,
        protocols: list[Protocol | str] | None = None,
        deformation_modes: list[DeformationMode | str] | None = None,
        transform_type: TransformType | str | None = None,
    ):
        """Register a plugin in the registry.

        Args:
            name: Unique name for the plugin
            plugin_class: The plugin class to register
            plugin_type: Type of plugin (MODEL or TRANSFORM)
            metadata: Optional metadata dictionary
            validate: Whether to validate the plugin interface
            force: Whether to overwrite existing registration
            protocols: List of supported protocols (for models); invalid
                strings are logged and skipped
            deformation_modes: List of supported deformation modes (for
                models); invalid strings are logged and skipped
            transform_type: Type of transform (for transforms); an invalid
                string is logged and stored as ``None``

        Raises:
            ValueError: If plugin is already registered (and force=False), the
                plugin type is unknown, or validation fails
        """
        plugin_enum, registry = self._registry_for(plugin_type)

        if name in registry:
            if not force:
                raise ValueError(
                    f"Plugin '{name}' is already registered as a {plugin_enum.value}"
                )
            logger.warning(
                "Overwriting existing registration",
                name=name,
                plugin_type=plugin_enum.value,
            )

        if validate:
            self._validate_plugin(plugin_class, plugin_enum)

        normalized_protocols = self._normalize_members(
            protocols, Protocol, "protocol", name
        )
        normalized_deformation_modes = self._normalize_members(
            deformation_modes, DeformationMode, "deformation_mode", name
        )

        normalized_transform_type = None
        if transform_type:
            if isinstance(transform_type, str):
                try:
                    normalized_transform_type = TransformType(transform_type)
                except ValueError:
                    logger.warning(
                        f"Invalid transform_type '{transform_type}' for plugin '{name}'"
                    )
            elif isinstance(transform_type, TransformType):
                normalized_transform_type = transform_type

        registry[name] = PluginInfo(
            name=name,
            plugin_class=plugin_class,
            plugin_type=plugin_enum,
            metadata=metadata or {},
            protocols=normalized_protocols,
            deformation_modes=normalized_deformation_modes,
            transform_type=normalized_transform_type,
        )

    def _validate_plugin(self, plugin_class: type, plugin_type: PluginType):
        """Validate that a plugin implements the required interface.

        Models need ``fit`` and ``predict``; transforms need ``transform``.

        Args:
            plugin_class: The plugin class to validate
            plugin_type: Expected plugin type

        Raises:
            ValueError: If plugin doesn't implement required interface
        """
        if plugin_type == PluginType.MODEL:
            for method in ("fit", "predict"):
                if not hasattr(plugin_class, method):
                    raise ValueError(
                        f"Model plugin does not implement required interface: missing '{method}' method"
                    )
        elif plugin_type == PluginType.TRANSFORM:
            if not hasattr(plugin_class, "transform"):
                raise ValueError(
                    "Transform plugin does not implement required interface: missing 'transform' method"
                )

    def get(
        self,
        name: str,
        plugin_type: PluginType | str,
        raise_on_missing: bool = False,
    ) -> type | None:
        """Retrieve a registered plugin class.

        Args:
            name: Name of the plugin
            plugin_type: Type of plugin
            raise_on_missing: Whether to raise error if not found

        Returns:
            The plugin class, or None if not found

        Raises:
            KeyError: If plugin not found and raise_on_missing=True
        """
        plugin_enum, registry = self._registry_for(plugin_type)
        info = registry.get(name)
        if info is not None:
            return info.plugin_class
        if raise_on_missing:
            raise KeyError(
                f"Plugin '{name}' not found in registry for type {plugin_enum.value}"
            )
        return None

    def get_info(self, name: str, plugin_type: PluginType | str) -> PluginInfo | None:
        """Get full information about a registered plugin.

        Args:
            name: Name of the plugin
            plugin_type: Type of plugin

        Returns:
            PluginInfo object or None if not found
        """
        _, registry = self._registry_for(plugin_type)
        return registry.get(name)

    def get_all_models(self) -> list[str]:
        """Get list of all registered model names (in registration order)."""
        return list(self._models)

    def get_all_transforms(self) -> list[str]:
        """Get list of all registered transform names (in registration order)."""
        return list(self._transforms)

    def unregister(self, name: str, plugin_type: PluginType | str):
        """Remove a plugin from the registry; unknown names are ignored.

        Args:
            name: Name of the plugin to remove
            plugin_type: Type of plugin
        """
        _, registry = self._registry_for(plugin_type)
        registry.pop(name, None)

    def get_all(self) -> dict[str, tuple[type, PluginType]]:
        """Get all registered plugins with their types.

        If a name is registered both as a model and as a transform, the
        transform entry wins because transforms are added last.

        Returns:
            Dictionary mapping plugin names to (class, type) tuples
        """
        result: dict[str, tuple[type, PluginType]] = {
            name: (info.plugin_class, PluginType.MODEL)
            for name, info in self._models.items()
        }
        result.update(
            (name, (info.plugin_class, PluginType.TRANSFORM))
            for name, info in self._transforms.items()
        )
        return result

    def clear(self):
        """Clear all registered plugins."""
        self._models.clear()
        self._transforms.clear()

    def __len__(self) -> int:
        """Return the total count of registered models and transforms."""
        return len(self._models) + len(self._transforms)

    def __contains__(self, name: str) -> bool:
        """Return True if ``name`` is registered as either a model or a transform."""
        return name in self._models or name in self._transforms

    def get_stats(self) -> dict[str, int]:
        """Get registration statistics.

        Returns:
            Dictionary with counts of registered plugins
        """
        return {
            "total": len(self),
            "models": len(self._models),
            "transforms": len(self._transforms),
        }

    def discover(self, module_name: str):
        """Discover and register plugins from a module.

        Every class attribute of the module is scanned. This includes classes
        imported into it, not only classes defined there. Classes with ``fit``
        and ``predict`` are registered as models; otherwise classes with
        ``transform`` are registered as transforms. Names that are already
        registered are left untouched. A module that cannot be imported is
        skipped.

        Args:
            module_name: Name of the module to import and scan
        """
        try:
            module = importlib.import_module(module_name)
        except ImportError as exc:
            # Best-effort discovery: record why the module contributed nothing.
            logger.debug(
                "Skipping plugin module that failed to import",
                module=module_name,
                error=str(exc),
            )
            return

        for name, obj in inspect.getmembers(module, inspect.isclass):
            if hasattr(obj, "fit") and hasattr(obj, "predict"):
                plugin_enum = PluginType.MODEL
            elif hasattr(obj, "transform"):
                plugin_enum = PluginType.TRANSFORM
            else:
                continue
            _, registry = self._registry_for(plugin_enum)
            if name in registry:
                continue  # Expected: duplicate class during discovery
            self.register(name, obj, plugin_enum, validate=False)

    def discover_directory(self, path: str):
        """Discover plugins in a directory.

        Each ``*.py`` file whose name does not start with ``_`` is imported by
        its bare module name, with ``path`` temporarily prepended to
        ``sys.path``. Files are processed in sorted order so registration is
        deterministic. Paths that do not exist or are not directories are
        ignored.

        Args:
            path: Path to directory containing plugin modules
        """
        # isdir (not exists): os.listdir on a regular file would raise.
        if not os.path.isdir(path):
            return

        import sys

        # R11-REG-001: Wrap sys.path mutation in a lock to prevent races
        # when multiple threads discover directories concurrently.
        with _discover_lock:
            sys.path.insert(0, path)
            try:
                for filename in sorted(os.listdir(path)):
                    if filename.endswith(".py") and not filename.startswith("_"):
                        self.discover(filename[:-3])
            finally:
                # remove() rather than pop(0): another entry may have been
                # inserted ahead of ours while modules were importing.
                try:
                    sys.path.remove(path)
                except ValueError:
                    pass

    def create_instance(
        self, name: str, plugin_type: PluginType | str, *args, **kwargs
    ) -> Any:
        """Create an instance of a registered plugin.

        Args:
            name: Name of the plugin
            plugin_type: Type of plugin
            *args: Positional arguments for plugin constructor
            **kwargs: Keyword arguments for plugin constructor

        Returns:
            Instance of the plugin class

        Raises:
            KeyError: If plugin not found
        """
        plugin_class = self.get(name, plugin_type, raise_on_missing=True)
        if plugin_class is None:
            raise RuntimeError(
                f"Registry returned None for plugin '{name}' of type {plugin_type}"
            )
        return plugin_class(*args, **kwargs)

    def find_compatible(
        self,
        protocol: Protocol | str | None = None,
        deformation_mode: DeformationMode | str | None = None,
        transform_type: TransformType | str | None = None,
        **criteria,
    ) -> list[str]:
        """Find plugins matching certain criteria.

        Models are searched only when ``transform_type`` is None; transforms
        only when ``protocol`` is None. Models that declare no deformation
        modes are treated as shear-only.

        Args:
            protocol: Filter models by supported protocol
            deformation_mode: Filter models by supported deformation mode
            transform_type: Filter transforms by type
            **criteria: Additional criteria to match against plugin metadata

        Returns:
            List of plugin names matching all criteria

        Raises:
            ValueError: If a string filter is not a valid enum value
        """
        compatible: list[str] = []

        if transform_type is None:
            # Parse the filters once instead of once per registered model.
            target_proto = None
            if protocol:
                target_proto = (
                    Protocol(protocol) if isinstance(protocol, str) else protocol
                )
            target_dm = None
            if deformation_mode is not None:
                target_dm = (
                    DeformationMode(deformation_mode)
                    if isinstance(deformation_mode, str)
                    else deformation_mode
                )

            for name, info in self._models.items():
                if target_proto is not None and target_proto not in info.protocols:
                    continue
                if target_dm is not None:
                    if info.deformation_modes:
                        if target_dm not in info.deformation_modes:
                            continue
                    elif target_dm != DeformationMode.SHEAR:
                        # Models without explicit deformation_modes are shear-only
                        logger.debug(
                            "Excluding model without deformation_modes for non-shear query",
                            model=name,
                            requested_mode=str(target_dm),
                        )
                        continue
                if self._matches_criteria(info.metadata, criteria):
                    compatible.append(name)

        if protocol is None:
            target_type = None
            if transform_type:
                target_type = (
                    TransformType(transform_type)
                    if isinstance(transform_type, str)
                    else transform_type
                )
            for name, info in self._transforms.items():
                if target_type is not None and info.transform_type != target_type:
                    continue
                if self._matches_criteria(info.metadata, criteria):
                    compatible.append(name)

        return compatible

    def _matches_criteria(
        self, metadata: dict[str, Any], criteria: dict[str, Any]
    ) -> bool:
        """Return True if every criteria key is present in metadata with an equal value."""
        for key, value in criteria.items():
            if key not in metadata or metadata[key] != value:
                return False
        return True

    def export_state(self) -> dict[str, Any]:
        """Export registry state for serialization.

        Enum fields are stored by ``.value`` so that ``import_state`` (which
        parses them with ``EnumType(value)``) can restore them.

        Returns:
            Dictionary representation of registry state
        """
        return {
            "models": {
                name: {
                    "class_name": info.plugin_class.__name__,
                    "module": info.plugin_class.__module__,
                    "metadata": info.metadata,
                    "protocols": [p.value for p in info.protocols],
                    "deformation_modes": [dm.value for dm in info.deformation_modes],
                }
                for name, info in self._models.items()
            },
            "transforms": {
                name: {
                    "class_name": info.plugin_class.__name__,
                    "module": info.plugin_class.__module__,
                    "metadata": info.metadata,
                    "transform_type": (
                        info.transform_type.value if info.transform_type else None
                    ),
                }
                for name, info in self._transforms.items()
            },
        }

    def import_state(self, state: dict[str, Any]):
        """Import registry state from serialization.

        Entries whose module or class can no longer be imported are skipped.
        Existing registrations with the same name are overwritten. Legacy
        ``"EnumName.MEMBER"`` strings from older exports are also accepted.

        Args:
            state: Dictionary representation of registry state
        """
        for name, info in state.get("models", {}).items():
            try:
                module = importlib.import_module(info["module"])
                plugin_class = getattr(module, info["class_name"])
            except (ImportError, AttributeError):
                continue
            self.register(
                name,
                plugin_class,
                PluginType.MODEL,
                metadata=info.get("metadata", {}),
                force=True,
                protocols=[
                    self._coerce_legacy_enum_string(p, Protocol)
                    for p in info.get("protocols") or []
                ],
                deformation_modes=[
                    self._coerce_legacy_enum_string(dm, DeformationMode)
                    for dm in info.get("deformation_modes") or []
                ],
            )

        for name, info in state.get("transforms", {}).items():
            try:
                module = importlib.import_module(info["module"])
                plugin_class = getattr(module, info["class_name"])
            except (ImportError, AttributeError):
                continue
            self.register(
                name,
                plugin_class,
                PluginType.TRANSFORM,
                metadata=info.get("metadata", {}),
                force=True,
                transform_type=self._coerce_legacy_enum_string(
                    info.get("transform_type"), TransformType
                ),
            )

    def inventory(self) -> dict[str, Any]:
        """Get full inventory of registered plugins.

        Returns:
            Dictionary with models (by protocol) and transforms (by type)
        """
        inventory: dict[str, Any] = {
            "models": {p.value: [] for p in Protocol},
            "transforms": {t.value: [] for t in TransformType},
            "all_models": [],
            "all_transforms": [],
        }

        for name, info in self._models.items():
            inventory["all_models"].append(
                {
                    "name": name,
                    "class": info.plugin_class.__name__,
                    # First line of the docstring serves as a short description.
                    "description": info.doc.split("\n")[0] if info.doc else "",
                    "protocols": [p.value for p in info.protocols],
                    "deformation_modes": [dm.value for dm in info.deformation_modes],
                }
            )
            for p in info.protocols:
                inventory["models"][p.value].append(name)

        for name, info in self._transforms.items():
            inventory["all_transforms"].append(
                {
                    "name": name,
                    "class": info.plugin_class.__name__,
                    "description": info.doc.split("\n")[0] if info.doc else "",
                    "type": info.transform_type.value if info.transform_type else None,
                }
            )
            if info.transform_type:
                inventory["transforms"][info.transform_type.value].append(name)

        return inventory

    # Decorator methods for easy registration

    def model(
        self,
        name: str | None = None,
        protocols: list[Protocol | str] | None = None,
        deformation_modes: list[DeformationMode | str] | None = None,
        **metadata,
    ):
        """Decorator for registering a model.

        Args:
            name: Optional name for the model (uses class name if not provided)
            protocols: List of supported protocols
            deformation_modes: List of supported deformation modes
            **metadata: Additional metadata for the model

        Returns:
            Decorator function
        """

        def decorator(cls):
            self.register(
                name or cls.__name__,
                cls,
                PluginType.MODEL,
                metadata=metadata,
                protocols=protocols,
                deformation_modes=deformation_modes,
            )
            return cls

        return decorator

    def transform(
        self,
        name: str | None = None,
        transform_type: TransformType | str | None = None,
        **metadata,
    ):
        """Decorator for registering a transform.

        Args:
            name: Optional name for the transform (uses class name if not provided)
            transform_type: Type of transform
            **metadata: Additional metadata for the transform

        Returns:
            Decorator function
        """

        def decorator(cls):
            self.register(
                name or cls.__name__,
                cls,
                PluginType.TRANSFORM,
                metadata=metadata,
                transform_type=transform_type,
            )
            return cls

        return decorator
class ModelRegistry:
    """Convenient interface for model registration and creation.

    This class provides a simplified API specifically for models, delegating
    to the main Registry singleton.

    Example:
        >>> @ModelRegistry.register('maxwell', protocols=['relaxation'])
        ... class Maxwell(BaseModel):
        ...     pass
        >>> model = ModelRegistry.create('maxwell')
        >>> models = ModelRegistry.find(protocol='relaxation')
    """

    _registry: Registry | None = None

    @classmethod
    def _get_registry(cls) -> Registry:
        """Get the global registry instance (cached on first use)."""
        if cls._registry is None:
            cls._registry = Registry.get_instance()
        return cls._registry

    @classmethod
    def _infos(cls, names: list[str]) -> list[PluginInfo]:
        """Resolve names to model PluginInfo objects, skipping non-models."""
        registry = cls._get_registry()
        # Single lookup per name (the filter and the result share it).
        return [
            info
            for info in (registry.get_info(n, PluginType.MODEL) for n in names)
            if info is not None
        ]

    @classmethod
    def register(
        cls,
        name: str,
        protocols: list[Protocol | str] | None = None,
        deformation_modes: list[DeformationMode | str] | None = None,
        **metadata,
    ):
        """Decorator for registering a model.

        Args:
            name: Name for the model
            protocols: List of supported protocols
            deformation_modes: List of supported deformation modes
                (shear, tension, bending, compression). Models with oscillation
                protocol that work in G-space can support all 4 modes via
                automatic E*<->G* conversion.
            **metadata: Additional metadata for the model

        Returns:
            Decorator function

        Example:
            >>> @ModelRegistry.register('maxwell',
            ...     protocols=['relaxation', 'oscillation'],
            ...     deformation_modes=['shear', 'tension', 'bending', 'compression'])
            ... class Maxwell(BaseModel):
            ...     pass
        """

        def decorator(model_class):
            registry = cls._get_registry()
            registry.register(
                name,
                model_class,
                PluginType.MODEL,
                metadata=metadata,
                protocols=protocols,
                deformation_modes=deformation_modes,
            )
            return model_class

        return decorator

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        """Create a model instance by name (factory method).

        Args:
            name: Name of the model to create
            *args: Positional arguments for model constructor
            **kwargs: Keyword arguments for model constructor

        Returns:
            Instance of the model class

        Raises:
            KeyError: If model not found

        Example:
            >>> model = ModelRegistry.create('maxwell')
        """
        registry = cls._get_registry()
        # If the model isn't registered yet, eagerly import all model modules
        # to trigger @ModelRegistry.register decorators (lazy-import fallback)
        if registry.get(name, PluginType.MODEL) is None:
            from rheojax.models import _ensure_all_registered

            _ensure_all_registered()
        return registry.create_instance(name, PluginType.MODEL, *args, **kwargs)

    @classmethod
    def list_models(cls) -> list[str]:
        """List all registered model names (discovery).

        Returns:
            List of registered model names

        Example:
            >>> models = ModelRegistry.list_models()
        """
        registry = cls._get_registry()
        return registry.get_all_models()

    @classmethod
    def find(
        cls,
        protocol: Protocol | str | None = None,
        deformation_mode: DeformationMode | str | None = None,
        **criteria,
    ) -> list[str]:
        """Find models matching criteria.

        Args:
            protocol: Filter by supported protocol
            deformation_mode: Filter by supported deformation mode
            **criteria: Additional metadata criteria

        Returns:
            List of matching model names
        """
        registry = cls._get_registry()
        return registry.find_compatible(
            protocol=protocol, deformation_mode=deformation_mode, **criteria
        )

    @classmethod
    def get_info(cls, name: str) -> PluginInfo | None:
        """Get information about a registered model.

        Falls back to importing all model modules once if ``name`` is unknown.

        Args:
            name: Name of the model

        Returns:
            PluginInfo object with model details, or None if still not found

        Example:
            >>> info = ModelRegistry.get_info('maxwell')
        """
        registry = cls._get_registry()
        info = registry.get_info(name, PluginType.MODEL)
        if info is None:
            from rheojax.models import _ensure_all_registered

            _ensure_all_registered()
            info = registry.get_info(name, PluginType.MODEL)
        return info

    @classmethod
    def for_protocol(cls, protocol: Protocol | str) -> list[PluginInfo]:
        """Get all models supporting a given protocol.

        Args:
            protocol: Protocol to filter by (e.g. ``"relaxation"``).

        Returns:
            List of PluginInfo objects for matching models.
        """
        names = cls._get_registry().find_compatible(protocol=protocol)
        return cls._infos(names)

    @classmethod
    def compatible_models(cls, data: Any) -> list[PluginInfo]:
        """Find models compatible with a RheoData instance.

        Reads ``data.test_mode`` (maps to protocol) and
        ``data.metadata.get("deformation_mode")`` to filter models. When
        ``test_mode`` is missing, ``metadata["test_mode"]`` and then
        ``metadata["detected_test_mode"]`` are tried. A missing or non-mapping
        ``metadata`` is treated as empty.

        Args:
            data: RheoData instance with test_mode metadata.

        Returns:
            List of PluginInfo objects for compatible models.
        """
        from collections.abc import Mapping

        raw_metadata = getattr(data, "metadata", None)
        # Guard both lookups: metadata may be None or not a mapping at all.
        metadata = raw_metadata if isinstance(raw_metadata, Mapping) else {}

        test_mode = getattr(data, "test_mode", None)
        if test_mode is None:
            test_mode = metadata.get("test_mode")
        if test_mode is None:
            test_mode = metadata.get("detected_test_mode")
        protocol = test_mode if test_mode else None

        names = cls._get_registry().find_compatible(
            protocol=protocol,
            deformation_mode=metadata.get("deformation_mode"),
        )
        return cls._infos(names)

    @classmethod
    def model_info(cls, name: str) -> Any:
        """Get aggregated model information including parameter metadata.

        Delegates to ``ModelInfo.from_registry``, which temporarily
        instantiates the model to read parameter names, bounds, and units.

        Args:
            name: Registry name of the model.

        Returns:
            ModelInfo instance with full model metadata.

        Raises:
            KeyError: If model is not registered.
        """
        from rheojax.core.fit_result import ModelInfo

        return ModelInfo.from_registry(name)

    @classmethod
    def unregister(cls, name: str):
        """Unregister a model.

        Args:
            name: Name of the model to remove
        """
        registry = cls._get_registry()
        registry.unregister(name, PluginType.MODEL)
class TransformRegistry:
    """Transform-only facade over the shared :class:`Registry` singleton.

    Offers decorator registration, a factory, and lookup helpers restricted
    to ``PluginType.TRANSFORM`` entries.

    Example:
        >>> @TransformRegistry.register('fft_analysis', type='spectral')
        ... class RheoAnalysis(BaseTransform):
        ...     pass
        >>> transform = TransformRegistry.create('fft_analysis')
        >>> transforms = TransformRegistry.find(type='spectral')
    """

    _registry: Registry | None = None

    @classmethod
    def _get_registry(cls) -> Registry:
        """Return the shared Registry, fetching and caching it on first use."""
        cached = cls._registry
        if cached is not None:
            return cached
        cls._registry = Registry.get_instance()
        return cls._registry

    @classmethod
    def register(cls, name: str, type: TransformType | str | None = None, **metadata):
        """Build a class decorator that registers a transform under ``name``.

        Args:
            name: Registry key for the transform
            type: Transform category (TransformType member or its string value)
            **metadata: Extra metadata stored alongside the registration

        Returns:
            A decorator that registers the class and returns it unchanged

        Example:
            >>> @TransformRegistry.register('fft_analysis', type='spectral')
            ... class RheoAnalysis(BaseTransform):
            ...     pass
        """

        def _attach(transform_class):
            cls._get_registry().register(
                name,
                transform_class,
                PluginType.TRANSFORM,
                metadata=metadata,
                transform_type=type,
            )
            return transform_class

        return _attach

    @classmethod
    def create(cls, name: str, *args, **kwargs) -> Any:
        """Instantiate the transform registered as ``name``.

        Args:
            name: Registry key of the transform
            *args: Positional arguments forwarded to the constructor
            **kwargs: Keyword arguments forwarded to the constructor

        Returns:
            A new instance of the registered transform class

        Raises:
            KeyError: If no transform is registered under ``name``

        Example:
            >>> transform = TransformRegistry.create('fft_analysis')
        """
        return cls._get_registry().create_instance(
            name, PluginType.TRANSFORM, *args, **kwargs
        )

    @classmethod
    def list_transforms(cls) -> list[str]:
        """Return the names of every registered transform.

        Example:
            >>> transforms = TransformRegistry.list_transforms()
        """
        return cls._get_registry().get_all_transforms()

    @classmethod
    def find(cls, type: TransformType | str | None = None, **criteria) -> list[str]:
        """Return transform names matching a type and/or metadata criteria.

        Args:
            type: Optional transform category filter
            **criteria: Metadata key/value pairs that must all match

        Returns:
            Names of matching transforms
        """
        return cls._get_registry().find_compatible(transform_type=type, **criteria)

    @classmethod
    def get_info(cls, name: str) -> PluginInfo | None:
        """Return the PluginInfo for transform ``name``, or None if absent.

        Example:
            >>> info = TransformRegistry.get_info('fft_analysis')
        """
        return cls._get_registry().get_info(name, PluginType.TRANSFORM)

    @classmethod
    def unregister(cls, name: str):
        """Remove the transform registered as ``name`` (no-op if absent)."""
        cls._get_registry().unregister(name, PluginType.TRANSFORM)