Source code for base_attentive.registry.model_registry

"""Model registry for V2 assemblers."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable

from ..backend import normalize_backend_name


@dataclass(frozen=True, eq=True, repr=True)
class ModelRegistration:
    """Immutable record describing one registered model assembler.

    Instances are frozen (fields cannot be reassigned after creation) and
    compare equal when all of their fields are equal.
    """

    # Model identifier the builder is registered under.
    key: str
    # Backend name the builder is registered for (normalized, or "generic").
    backend: str
    # Callable stored for this key/backend pair.
    builder: Callable[..., Any]
    # Optional free-form human-readable description.
    description: str = ""
    # Flag recorded at registration time; stored as-is on the record.
    experimental: bool = False


class ModelRegistry:
    """Registry of backend-specific model assemblers.

    Registrations are stored per model ``key`` and, within a key, per
    backend name. The special backend name ``"generic"`` is handled by the
    registry itself (it is never passed to ``normalize_backend_name``) and
    acts as a fallback that :meth:`resolve` can return for any backend.
    """

    def __init__(self):
        # key -> {backend name ("generic" or normalized) -> registration}
        self._registrations: dict[
            str, dict[str, ModelRegistration]
        ] = {}

    @staticmethod
    def _normalize_backend(backend: str) -> str:
        """Return the registry's canonical name for ``backend``.

        ``"generic"`` is passed through unchanged; every other name is
        delegated to ``normalize_backend_name``. All public methods go
        through this helper so that ``"generic"`` is treated the same way
        by ``register``, ``has`` and ``resolve``.
        """
        if backend == "generic":
            return "generic"
        return normalize_backend_name(backend)

    def register(
        self,
        key: str,
        builder: Callable[..., Any],
        *,
        backend: str = "generic",
        description: str = "",
        experimental: bool = False,
        replace: bool = False,
    ) -> ModelRegistration:
        """Register ``builder`` for ``key`` on ``backend``.

        Args:
            key: Model identifier.
            builder: Callable stored on the registration.
            backend: Backend name. The default ``"generic"`` registers a
                fallback that :meth:`resolve` may return for any backend.
            description: Free-form description stored on the registration.
            experimental: Flag stored on the registration.
            replace: When true, overwrite an existing registration for the
                same ``key``/backend pair instead of raising.

        Returns:
            The newly stored :class:`ModelRegistration`.

        Raises:
            KeyError: If ``key`` is already registered for the backend and
                ``replace`` is false.
        """
        normalized_backend = self._normalize_backend(backend)
        by_backend = self._registrations.setdefault(key, {})
        if normalized_backend in by_backend and not replace:
            raise KeyError(
                f"Model {key!r} is already registered for backend "
                f"{normalized_backend!r}."
            )
        registration = ModelRegistration(
            key=key,
            backend=normalized_backend,
            builder=builder,
            description=description,
            experimental=experimental,
        )
        by_backend[normalized_backend] = registration
        return registration

    def has(self, key: str, *, backend: str | None = None) -> bool:
        """Return whether ``key`` is registered.

        Args:
            key: Model identifier.
            backend: When given, only an exact registration for this backend
                counts; no fallback to ``"generic"`` is applied here. When
                ``None``, any registration for ``key`` counts.
        """
        by_backend = self._registrations.get(key)
        if by_backend is None:
            return False
        if backend is None:
            return True
        return self._normalize_backend(backend) in by_backend

    def resolve(
        self,
        key: str,
        *,
        backend: str,
        allow_generic: bool = True,
    ) -> ModelRegistration:
        """Return the registration for ``key`` on ``backend``.

        Args:
            key: Model identifier.
            backend: Backend to resolve for; ``"generic"`` is accepted and
                looks up the generic registration directly.
            allow_generic: When true, fall back to the ``"generic``"
                registration if no backend-specific one exists.

        Returns:
            The matching :class:`ModelRegistration`.

        Raises:
            KeyError: If ``key`` is unknown, or if no registration matches
                the backend (and, when allowed, no generic fallback exists).
        """
        # Normalize first, matching register/has, so "generic" is never
        # handed to normalize_backend_name.
        normalized_backend = self._normalize_backend(backend)
        by_backend = self._registrations.get(key)
        if not by_backend:
            raise KeyError(f"Unknown model key: {key!r}.")

        registration = by_backend.get(normalized_backend)
        if registration is None and allow_generic:
            registration = by_backend.get("generic")
        if registration is not None:
            return registration

        available = ", ".join(sorted(by_backend))
        raise KeyError(
            f"Model {key!r} is not registered for backend "
            f"{normalized_backend!r}. Available: {available}."
        )
# Module-level registry instance shared by every importer of this module.
DEFAULT_MODEL_REGISTRY = ModelRegistry()

# Public API of this module.
__all__ = [
    "ModelRegistration",
    "ModelRegistry",
    "DEFAULT_MODEL_REGISTRY",
]