# SPDX-License-Identifier: Apache-2.0
# Author: LKouadio <etanoyau@gmail.com>
"""Lazy backend runtime abstraction for Base-Attentive.
This package exposes backend selection, capability inspection, and helper
utilities without importing all backend implementations eagerly.
"""
from __future__ import annotations
import importlib
import os
import warnings
from typing import Any
__all__ = [
"Backend",
"TensorFlowBackend",
"JaxBackend",
"TorchBackend",
"PyTorchBackend",
"get_backend",
"set_backend",
"get_available_backends",
"get_backend_capabilities",
"normalize_backend_name",
"detect_available_backends",
"select_best_backend",
"ensure_default_backend",
"get_backend_version",
"check_tensorflow_compatibility",
"check_torch_compatibility",
"parse_version",
"version_at_least",
"get_torch_device",
"get_torch_version",
"torch_is_available",
"TorchDeviceManager",
"_has_module",
"_import_module",
]
_CURRENT_BACKEND = None
# Eagerly bind torch_utils symbols so they are always real objects in
# this module's namespace regardless of test-ordering or sys.modules state.
try:
from base_attentive.backend.torch_utils import ( # noqa: E402
TorchDeviceManager,
check_torch_compatibility,
get_torch_device,
get_torch_version,
torch_is_available,
)
except Exception:
pass
def _module(name: str):
return importlib.import_module(name)
def _backend_classes() -> dict[str, type]:
implementations = _module(
"base_attentive.backend.implementations"
)
return {
"tensorflow": implementations.TensorFlowBackend,
"jax": implementations.JaxBackend,
"torch": implementations.TorchBackend,
"pytorch": implementations.PyTorchBackend,
}
[docs]
def normalize_backend_name(name: str | None) -> str:
detector = _module("base_attentive.backend.detector")
return detector.normalize_backend_name(name)
# detector wrappers -------------------------------------------------------
[docs]
def detect_available_backends():
detector = _module("base_attentive.backend.detector")
return detector.detect_available_backends()
[docs]
def select_best_backend(
prefer: str | None = None,
require_supported: bool = True,
):
detector = _module("base_attentive.backend.detector")
return detector.select_best_backend(
prefer=prefer,
require_supported=require_supported,
)
[docs]
def ensure_default_backend(
auto_install: bool = False,
install_tensorflow: bool = True,
) -> str:
detector = _module("base_attentive.backend.detector")
return detector.ensure_default_backend(
auto_install=auto_install,
install_tensorflow=install_tensorflow,
)
[docs]
def get_available_backends():
detector = _module("base_attentive.backend.detector")
return detector.get_available_backends()
# version wrappers --------------------------------------------------------
[docs]
def get_backend_version(name: str):
version_check = _module(
"base_attentive.backend.version_check"
)
return version_check.get_backend_version(name)
[docs]
def check_tensorflow_compatibility():
version_check = _module(
"base_attentive.backend.version_check"
)
return version_check.check_tensorflow_compatibility()
[docs]
def check_torch_compatibility():
version_check = _module(
"base_attentive.backend.version_check"
)
return version_check.check_torch_compatibility()
[docs]
def parse_version(version: str):
version_check = _module(
"base_attentive.backend.version_check"
)
return version_check.parse_version(version)
[docs]
def version_at_least(version: str, minimum: str):
version_check = _module(
"base_attentive.backend.version_check"
)
return version_check.version_at_least(version, minimum)
# core API ----------------------------------------------------------------
[docs]
def get_backend(name: str | None = None):
global _CURRENT_BACKEND
requested_name = name
if name is None:
env_name = os.environ.get("BASE_ATTENTIVE_BACKEND")
if env_name is None:
env_name = os.environ.get("KERAS_BACKEND")
if env_name is None and _CURRENT_BACKEND is not None:
return _CURRENT_BACKEND
if env_name is None or not str(env_name).strip():
raise RuntimeError(
"BaseAttentive backend is not configured. Set "
"BASE_ATTENTIVE_BACKEND to one of: tensorflow, torch, jax, or auto."
)
name = env_name
normalized = normalize_backend_name(name)
backends = _backend_classes()
detector = _module("base_attentive.backend.detector")
auto_install = os.environ.get(
"BASE_ATTENTIVE_AUTO_INSTALL", "0"
).strip().lower() in {"1", "true", "yes", "on"}
if normalized == "auto":
normalized = detector.ensure_default_backend(
auto_install=auto_install,
install_tensorflow=True,
)
if normalized not in backends:
raise ValueError(
f"Unknown backend: {name}. Available: {list(backends.keys())}"
)
backend_cls = backends[normalized]
try:
backend = backend_cls()
except ImportError as exc:
if auto_install:
detector.install_backend_runtime(normalized)
backend = backend_cls()
else:
available = detector.get_available_backends()
install_cmd = detector.backend_install_command(normalized)
raise ValueError(
f"Backend '{normalized}' is not available. "
f"Available backends: {available}. "
f"Install it with: {install_cmd}. "
"Or set BASE_ATTENTIVE_AUTO_INSTALL=1."
) from exc
if requested_name is None:
_CURRENT_BACKEND = backend
return backend
[docs]
def get_backend_capabilities(
name: str | None = None,
) -> dict[str, Any]:
backends = _backend_classes()
if name is None:
try:
backend = get_backend()
caps = backend.get_capabilities()
caps.setdefault(
"name", getattr(backend, "name", "unknown")
)
caps.setdefault(
"framework",
getattr(
backend,
"framework",
getattr(backend, "name", "unknown"),
),
)
caps.setdefault(
"available",
backend.is_available()
if hasattr(backend, "is_available")
else True,
)
caps.setdefault(
"uses_keras_runtime",
getattr(backend, "uses_keras_runtime", False),
)
caps.setdefault(
"experimental",
getattr(backend, "experimental", False),
)
caps.setdefault(
"supports_base_attentive",
getattr(
backend, "supports_base_attentive", False
),
)
caps.setdefault(
"supports_base_attentive_v2",
getattr(
backend,
"supports_base_attentive_v2",
False,
),
)
caps.setdefault(
"blockers",
list(getattr(backend, "blockers", ())),
)
caps.setdefault(
"v2_blockers",
list(getattr(backend, "v2_blockers", ())),
)
caps.setdefault(
"version",
get_backend_version(
getattr(backend, "name", "tensorflow")
),
)
return caps
except Exception:
name = os.environ.get(
"BASE_ATTENTIVE_BACKEND"
) or os.environ.get(
"KERAS_BACKEND",
"tensorflow",
)
normalized = normalize_backend_name(name)
if normalized not in backends:
raise ValueError(
f"Unknown backend: {name}. Available: {list(backends.keys())}"
)
backend_cls = backends[normalized]
try:
backend = backend_cls(load_runtime=False)
caps = backend.get_capabilities()
caps.setdefault(
"name", getattr(backend, "name", normalized)
)
caps.setdefault(
"framework",
getattr(backend_cls, "framework", normalized),
)
caps.setdefault(
"available",
backend.is_available()
if hasattr(backend, "is_available")
else True,
)
caps.setdefault(
"uses_keras_runtime",
getattr(backend, "uses_keras_runtime", False),
)
caps.setdefault(
"experimental",
getattr(backend, "experimental", False),
)
caps.setdefault(
"supports_base_attentive",
getattr(
backend, "supports_base_attentive", False
),
)
caps.setdefault(
"supports_base_attentive_v2",
getattr(
backend, "supports_base_attentive_v2", False
),
)
caps.setdefault(
"blockers", list(getattr(backend, "blockers", ()))
)
caps.setdefault(
"v2_blockers",
list(getattr(backend, "v2_blockers", ())),
)
caps["version"] = get_backend_version(normalized)
return caps
except Exception as exc:
return {
"name": normalized,
"framework": getattr(
backend_cls, "framework", normalized
),
"available": False,
"uses_keras_runtime": getattr(
backend_cls,
"uses_keras_runtime",
False,
),
"experimental": getattr(
backend_cls, "experimental", False
),
"supports_base_attentive": getattr(
backend_cls,
"supports_base_attentive",
False,
),
"supports_base_attentive_v2": getattr(
backend_cls,
"supports_base_attentive_v2",
False,
),
"blockers": list(
getattr(backend_cls, "blockers", ())
),
"v2_blockers": list(
getattr(backend_cls, "v2_blockers", ())
),
"version": get_backend_version(normalized),
"error": str(exc),
}
[docs]
def set_backend(name: str):
global _CURRENT_BACKEND
normalized = normalize_backend_name(name)
if normalized == "tensorflow":
is_compatible, msg = check_tensorflow_compatibility()
if not is_compatible:
warnings.warn(msg, RuntimeWarning, stacklevel=2)
base = _module("base_attentive.backend.base")
loaded_backend = base._read_loaded_keras_backend()
if loaded_backend and loaded_backend != normalized:
warnings.warn(
"Keras is already loaded with backend "
f"'{loaded_backend}'. Restart Python after switching to "
f"'{normalized}' for the change to take full effect.",
RuntimeWarning,
stacklevel=2,
)
_CURRENT_BACKEND = get_backend(normalized)
os.environ["BASE_ATTENTIVE_BACKEND"] = normalized
os.environ["KERAS_BACKEND"] = normalized
return _CURRENT_BACKEND
def _auto_initialize():
env_name = os.environ.get("BASE_ATTENTIVE_BACKEND")
if env_name is None:
env_name = os.environ.get("KERAS_BACKEND")
if env_name is None or not str(env_name).strip():
raise RuntimeError(
"BaseAttentive backend is not configured. Set BASE_ATTENTIVE_BACKEND first."
)
if normalize_backend_name(env_name) == "auto":
chosen = ensure_default_backend(
auto_install=os.environ.get("BASE_ATTENTIVE_AUTO_INSTALL", "0").strip().lower() in {"1", "true", "yes", "on"},
install_tensorflow=True,
)
return set_backend(chosen)
return set_backend(env_name)
# lazy attribute surface --------------------------------------------------
_LAZY_ATTRS = {
"_BACKENDS": ("base_attentive.backend.detector", "_BACKENDS"),
"Backend": ("base_attentive.backend.base", "Backend"),
"TensorFlowBackend": (
"base_attentive.backend.implementations",
"TensorFlowBackend",
),
"JaxBackend": (
"base_attentive.backend.implementations",
"JaxBackend",
),
"TorchBackend": (
"base_attentive.backend.implementations",
"TorchBackend",
),
"PyTorchBackend": (
"base_attentive.backend.implementations",
"PyTorchBackend",
),
"TorchDeviceManager": (
"base_attentive.backend.torch_utils",
"TorchDeviceManager",
),
"get_torch_device": (
"base_attentive.backend.torch_utils",
"get_torch_device",
),
"get_torch_version": (
"base_attentive.backend.torch_utils",
"get_torch_version",
),
"torch_is_available": (
"base_attentive.backend.torch_utils",
"torch_is_available",
),
"_has_module": (
"base_attentive.backend.detector",
"_has_module",
),
"_import_module": (
"base_attentive.backend.detector",
"_import_module",
),
}
def __getattr__(name: str):
target = _LAZY_ATTRS.get(name)
if target is None:
raise AttributeError(
f"module {__name__!r} has no attribute {name!r}"
)
module_name, attr_name = target
value = getattr(_module(module_name), attr_name)
globals()[name] = value
return value
def __dir__() -> list[str]:
return sorted(set(globals()) | set(__all__))