# Source code for base_attentive._bootstrap

"""
Internal runtime bootstrap for base_attentive.

This module centralizes runtime setup, backend resolution,
and lightweight compatibility helpers used across the package.

It is intentionally private and should not be treated as part of the
stable public API.
"""

from __future__ import annotations

import importlib
import os
import sys
from types import SimpleNamespace
from typing import Any

from ._runtime_requirements import (
    backend_install_command,
    backend_packages,
)

import numpy as np

# Names exported by ``from base_attentive._bootstrap import *``.
# The underscore-prefixed ``_ORIGINAL_KERAS_DEPS_GETATTR`` is listed on
# purpose; ``ensure_runtime_backend`` is defined below but not exported here.
__all__ = [
    "KERAS_BACKEND",
    "KERAS_DEPS",
    "_ORIGINAL_KERAS_DEPS_GETATTR",
    "dependency_message",
    "get_backend",
    "set_backend",
    "get_available_backends",
    "get_backend_capabilities",
    "enable_eager_runtime_imports",
]


# Opt-in switch read once at import time: only the exact value "1" in
# BASE_ATTENTIVE_EAGER_RUNTIME enables eager runtime imports.
_EAGER_RUNTIME_SETTING = os.getenv("BASE_ATTENTIVE_EAGER_RUNTIME", "0")
_ALLOW_EAGER_RUNTIME_IMPORTS = _EAGER_RUNTIME_SETTING == "1"
del _EAGER_RUNTIME_SETTING


def enable_eager_runtime_imports(
    enabled: bool = True,
) -> None:
    """Switch eager importing of the real Keras / TensorFlow modules on or off.

    The choice is stored both in the module-level flag and in the
    ``BASE_ATTENTIVE_EAGER_RUNTIME`` environment variable ("1" or "0").
    """
    global _ALLOW_EAGER_RUNTIME_IMPORTS
    flag = bool(enabled)
    _ALLOW_EAGER_RUNTIME_IMPORTS = flag
    if flag:
        os.environ["BASE_ATTENTIVE_EAGER_RUNTIME"] = "1"
    else:
        os.environ["BASE_ATTENTIVE_EAGER_RUNTIME"] = "0"


def _runtime_imports_permitted() -> bool:
    """Tell whether real runtime modules may be imported right now.

    Whenever ``pytest`` is present in ``sys.modules`` the answer is always
    ``False`` so that loaded-module lookups and fallback symbols are used
    instead of importing TensorFlow or Keras eagerly. Otherwise the
    module-level eager-import flag decides.
    """
    running_under_pytest = "pytest" in sys.modules
    return (
        not running_under_pytest
        and _ALLOW_EAGER_RUNTIME_IMPORTS
    )


# Accepted spellings of a backend name mapped to the canonical name.
_BACKEND_ALIASES = dict(
    tf="tensorflow",
    tensorflow="tensorflow",
    jax="jax",
    torch="torch",
    pytorch="torch",
)


def _normalize_configured_backend(name: str | None) -> str | None:
    """Normalize a user-supplied backend name.

    Parameters
    ----------
    name:
        Raw backend name (any case, surrounding whitespace allowed) or
        ``None``.

    Returns
    -------
    str | None
        ``None`` for ``None`` / blank input, ``"auto"`` for ``"auto"``, the
        canonical name for a known alias, and the lower-cased input for any
        other value. The special value ``"keras"`` defers to the
        ``KERAS_BACKEND`` environment variable.
    """
    if name is None:
        return None
    normalized = str(name).strip().lower()
    if not normalized:
        return None
    if normalized == "auto":
        return "auto"
    if normalized == "keras":
        configured = os.environ.get("KERAS_BACKEND")
        # ``KERAS_BACKEND=keras`` would point back at itself and recurse
        # without end; treat such a self-reference as "nothing configured".
        if configured is None or configured.strip().lower() == "keras":
            return None
        return _normalize_configured_backend(configured)
    return _BACKEND_ALIASES.get(normalized, normalized)


def _resolve_runtime_backend() -> str | None:
    """Return the configured backend name, or ``None`` when nothing is set.

    ``BASE_ATTENTIVE_BACKEND`` is consulted first and ``KERAS_BACKEND``
    second; the first one that normalizes to a non-``None`` value wins.
    """
    for variable in ("BASE_ATTENTIVE_BACKEND", "KERAS_BACKEND"):
        resolved = _normalize_configured_backend(
            os.environ.get(variable)
        )
        if resolved is not None:
            return resolved
    return None


def _auto_install_enabled() -> bool:
    """Report whether ``BASE_ATTENTIVE_AUTO_INSTALL`` holds a truthy value.

    Accepted values (case-insensitive, surrounding whitespace ignored) are
    ``1``, ``true``, ``yes`` and ``on``; an unset variable counts as ``0``.
    """
    raw_flag = os.environ.get("BASE_ATTENTIVE_AUTO_INSTALL", "0")
    cleaned = raw_flag.strip().lower()
    return cleaned in ("1", "true", "yes", "on")


def _configured_backend_display() -> str:
    """Return the configured backend name, or ``"<unset>"`` when there is none."""
    backend = _resolve_runtime_backend()
    if backend:
        return backend
    return "<unset>"


def _backend_not_configured_message(module_name: str) -> str:
    """Build the error text shown when no backend has been configured.

    ``module_name`` is interpolated verbatim to name the component that
    needed a backend.
    """
    sentences = (
        f"BaseAttentive backend is not configured for {module_name}.",
        "Set BASE_ATTENTIVE_BACKEND to one of: tensorflow, torch, jax, or auto.",
        "Example: BASE_ATTENTIVE_BACKEND=torch.",
        "If you want deferred installation when a runtime is missing, set"
        " BASE_ATTENTIVE_AUTO_INSTALL=1.",
    )
    return " ".join(sentences)


def _backend_missing_message(module_name: str, backend_name: str) -> str:
    """Build the error text for a configured backend whose runtime is absent.

    The message names the backend's packages (falling back to the backend
    name when the package list is empty), the install command, and a note
    that depends on whether automatic installation is enabled.
    """
    package_list = ", ".join(backend_packages(backend_name))
    packages = package_list or backend_name
    install_cmd = backend_install_command(backend_name)
    if _auto_install_enabled():
        auto_install_note = (
            " Automatic installation is enabled (BASE_ATTENTIVE_AUTO_INSTALL=1), "
            "so BaseAttentive will try to install it when runtime resolution is attempted."
        )
    else:
        auto_install_note = (
            " Set BASE_ATTENTIVE_AUTO_INSTALL=1 to allow deferred installation when needed."
        )
    headline = (
        f"BaseAttentive backend '{backend_name}' is configured for {module_name}, "
        f"but its runtime is not installed ({packages}). Install it with: `{install_cmd}`."
    )
    return headline + auto_install_note


# Backend name resolved from the environment when this module is imported;
# the empty string means that nothing was configured.
KERAS_BACKEND = _resolve_runtime_backend() or ""


def _set_runtime_backend(name: str | None) -> str | None:
    """Record ``name`` as the active backend and mirror it to the environment.

    The normalized name is stored in the module-level ``KERAS_BACKEND``
    (empty string when there is none) and written to both
    ``BASE_ATTENTIVE_BACKEND`` and ``KERAS_BACKEND``; when the name
    normalizes to nothing, both variables are removed instead.

    Returns the normalized name (possibly ``None``).
    """
    global KERAS_BACKEND
    resolved = _normalize_configured_backend(name)
    KERAS_BACKEND = resolved or ""
    for variable in ("BASE_ATTENTIVE_BACKEND", "KERAS_BACKEND"):
        if resolved:
            os.environ[variable] = resolved
        else:
            os.environ.pop(variable, None)
    return resolved


def ensure_runtime_backend(module_name: str) -> str:
    """Make sure the configured backend can be used and return its name.

    Parameters
    ----------
    module_name:
        Name of the component requesting a runtime; used in error messages.

    Returns
    -------
    str
        The backend that was selected and recorded via
        ``_set_runtime_backend``.

    Raises
    ------
    ImportError
        If no backend is configured, if ``"auto"`` resolution fails, or if
        the configured backend is unavailable and cannot be (or may not be)
        installed.
    """
    backend_name = _resolve_runtime_backend()
    if backend_name is None:
        raise ImportError(_backend_not_configured_message(module_name))

    # Imported lazily, at call time rather than at module import.
    detector = importlib.import_module("base_attentive.backend.detector")
    may_install = _auto_install_enabled()

    if configured_is_auto := (backend_name == "auto"):
        try:
            selected = detector.ensure_default_backend(
                auto_install=may_install,
                install_tensorflow=True,
            )
        except RuntimeError as exc:
            raise ImportError(
                dependency_message(module_name)
            ) from exc
        _set_runtime_backend(selected)
        return selected

    if backend_name not in detector.get_available_backends():
        if not may_install:
            raise ImportError(
                _backend_missing_message(module_name, backend_name)
            )
        try:
            detector.install_backend_runtime(backend_name)
        except RuntimeError as exc:
            raise ImportError(
                _backend_missing_message(module_name, backend_name)
            ) from exc

    _set_runtime_backend(backend_name)
    return backend_name


def get_backend(name: str | None = None):
    """Forward to ``base_attentive.backend.get_backend`` (imported on call)."""
    module = importlib.import_module("base_attentive.backend")
    return module.get_backend(name)


def set_backend(name: str):
    """Forward to ``base_attentive.backend.set_backend`` (imported on call)."""
    module = importlib.import_module("base_attentive.backend")
    return module.set_backend(name)


def get_available_backends():
    """Forward to ``base_attentive.backend.get_available_backends`` (imported on call)."""
    module = importlib.import_module("base_attentive.backend")
    return module.get_available_backends()


def get_backend_capabilities(name: str | None = None):
    """Forward to ``base_attentive.backend.get_backend_capabilities`` (imported on call)."""
    module = importlib.import_module("base_attentive.backend")
    return module.get_backend_capabilities(name)


def _safe_import(module_name: str):
    """Import ``module_name`` and return it, or ``None`` if the import raises.

    Any ``Exception`` raised during the import is swallowed on purpose so
    that optional runtimes can be probed without failing.
    """
    try:
        module = importlib.import_module(module_name)
    except Exception:
        module = None
    return module


def _resolve_scalar(value: Any) -> Any:
    """Reduce ``value`` to a plain scalar when that is possible.

    Built-in ``int``/``float``/``bool``/``str`` values are returned as-is.
    Otherwise ``value.item()`` is tried (errors ignored), and finally the
    ``value`` attribute is returned, or ``None`` when it does not exist.
    """
    plain_types = (int, float, bool, str)
    if not isinstance(value, plain_types):
        if hasattr(value, "item"):
            try:
                return value.item()
            except Exception:
                # Fall through to the ``value`` attribute lookup below.
                pass
        return getattr(value, "value", None)
    return value


def _get_static_value(value: Any) -> Any:
    """Return a concrete value for ``value`` if one can be determined.

    A scalar obtained from ``_resolve_scalar`` is preferred. Failing that,
    and only when TensorFlow is already present in ``sys.modules``, its
    ``get_static_value`` is consulted. ``None`` is returned when nothing
    could be determined, when that call raises, or when it merely hands
    back the very same object.
    """
    resolved = _resolve_scalar(value)
    if resolved is not None:
        return resolved

    # Never imports TensorFlow: only an already-loaded module is used.
    tf_module = sys.modules.get("tensorflow")
    if tf_module is None:
        return None

    extractor = getattr(tf_module, "get_static_value", None)
    if not callable(extractor):
        return None

    try:
        static = extractor(value)
    except Exception:
        return None
    return None if static is value else static


def _normalize_dtype(dtype: Any) -> Any:
    """Turn a runtime dtype object into a Keras-friendly representation.

    ``None`` and strings pass through unchanged. Otherwise a string ``name``
    attribute is used if present; next NumPy is asked to interpret the
    ``as_numpy_dtype`` attribute (when present) and then the object itself.
    If every attempt fails the original object is returned.
    """
    if dtype is None or isinstance(dtype, str):
        return dtype

    label = getattr(dtype, "name", None)
    if isinstance(label, str):
        return label

    numpy_hint = getattr(dtype, "as_numpy_dtype", None)
    candidates = (dtype,) if numpy_hint is None else (numpy_hint, dtype)
    for candidate in candidates:
        try:
            return np.dtype(candidate).name
        except Exception:
            continue
    return dtype


class _KerasAutographExperimental:
    """No-op stand-in exposing a ``do_not_convert`` decorator."""

    @staticmethod
    def do_not_convert(func=None, **kwargs):
        """Return ``func`` untouched, usable with or without parentheses.

        Called without a function (optionally with keyword arguments, which
        are ignored) it returns a decorator that hands its argument back.
        """
        if func is not None:
            return func

        def _passthrough(inner):
            return inner

        return _passthrough


class _KerasAutographNamespace:
    """Namespace object whose ``experimental`` attribute holds the no-op autograph shim."""

    # One shared instance, created when the class body is executed.
    experimental = _KerasAutographExperimental()


class _KerasDebuggingNamespace:
    """Fallback ``debugging`` namespace offering a best-effort ``assert_equal``."""

    @staticmethod
    def assert_equal(actual, expected, message="", name=None):
        """Raise ``AssertionError`` when both values are known and differ.

        Each argument is reduced with ``_get_static_value``; when either
        one cannot be determined the check is skipped. ``name`` is accepted
        but not used. Always returns ``None``.
        """
        left = _get_static_value(actual)
        right = _get_static_value(expected)
        both_known = left is not None and right is not None
        if both_known and left != right:
            raise AssertionError(
                message or f"{left} != {right}"
            )
        return None


class _KerasLinalgNamespace:
    """Fallback ``linalg`` namespace that delegates to TensorFlow when importable."""

    @staticmethod
    def band_part(x, num_lower, num_upper):
        """Call ``tf.linalg.band_part``; raise ``ImportError`` without TensorFlow."""
        tensorflow = _safe_import("tensorflow")
        if tensorflow is None:
            raise ImportError(
                "linalg.band_part is only available with TensorFlow installed."
            )
        return tensorflow.linalg.band_part(x, num_lower, num_upper)


class _KerasDeps:
    """Resolve Keras symbols across Keras 3 and TensorFlow namespaces.

    Attribute access on an instance goes through ``__getattr__``, which
    tries, in order: special-case shims, the standalone ``keras`` package,
    the TensorFlow root namespace, and the
    ``base_attentive._keras_fallback`` module. Resolved values are cached
    per name; the cache is dropped whenever the configured backend or the
    loaded ``keras`` / ``tensorflow`` module objects change.
    """

    # Requested name -> name actually looked up in the runtime namespaces.
    _OP_ALIASES = {
        "concat": "concatenate",
        "floordiv": "floor_divide",
        "reduce_mean": "mean",
        "reduce_sum": "sum",
        "reduce_max": "max",
        "range": "arange",
    }

    # Keras sub-namespaces searched in order; ``None`` means the root.
    _SEARCH_PATHS = (
        None,
        "layers",
        "losses",
        "activations",
        "initializers",
        "models",
        "ops",
        "random",
        "saving",
        "utils",
    )

    def __init__(self):
        self._cache: dict[str, Any] = {}
        self._fallback_runtime = None
        self._cache_state: tuple[Any, ...] | None = None

    def _current_state(self) -> tuple[Any, ...]:
        """Return a fingerprint of everything cached lookups depend on."""
        return (
            KERAS_BACKEND,
            id(sys.modules.get("keras")),
            id(sys.modules.get("tensorflow")),
        )

    def _maybe_reset_cache(self) -> None:
        """Drop cached symbols when the backend or loaded runtimes changed."""
        state = self._current_state()
        if self._cache_state != state:
            self._cache.clear()
            self._cache_state = state

    def _load_fallback_runtime(self):
        """Import the fallback module on first use; may return ``None``."""
        if self._fallback_runtime is None:
            self._fallback_runtime = _safe_import(
                "base_attentive._keras_fallback"
            )
        return self._fallback_runtime

    def _load_keras_root(self):
        """Return standalone Keras only when already loaded or explicitly enabled."""
        loaded = sys.modules.get("keras")
        if loaded is not None:
            return loaded
        if _runtime_imports_permitted():
            return _safe_import("keras")
        return None

    def _load_namespace(self, root: Any, name: str | None):
        """Return ``root.<name>`` (or ``root`` itself when ``name`` is ``None``).

        When the attribute is missing, the submodule ``<root>.<name>`` is
        imported, but only if ``root`` is a real, loaded module object.
        Returns ``None`` when nothing is found.
        """
        if root is None:
            return None
        if name is None:
            return root

        namespace = getattr(root, name, None)
        if namespace is not None:
            return namespace

        module_name = getattr(root, "__name__", None)
        loaded_root = (
            sys.modules.get(module_name)
            if module_name
            else None
        )
        # Only walk into submodules when ``root`` is the actual loaded module
        # object, not an arbitrary stand-in that merely advertises a ``__name__``.
        if (
            module_name
            and loaded_root is root
            and isinstance(root, type(sys))
        ):
            return _safe_import(f"{module_name}.{name}")
        return None

    def _load_tensorflow(self):
        """Return TensorFlow, but only when it is the configured backend.

        An already-loaded module is preferred; a fresh import happens only
        when runtime imports are permitted.
        """
        if KERAS_BACKEND != "tensorflow":
            return None
        loaded = sys.modules.get("tensorflow")
        if loaded is not None:
            return loaded
        if _runtime_imports_permitted():
            return _safe_import("tensorflow")
        return None

    def _resolve_special(self, name: str) -> Any:
        """Resolve names that need a shim or a specific lookup order.

        Returns ``None`` when ``name`` is not one of the special cases
        (and also for ``"newaxis"``, whose value is ``None``).
        """
        fallback = self._load_fallback_runtime()
        if name == "autograph":
            return _KerasAutographNamespace()
        if name == "debugging":
            tf = self._load_tensorflow()
            if tf is not None and hasattr(tf, "debugging"):
                return tf.debugging
            return _KerasDebuggingNamespace()
        if name == "newaxis":
            return None
        if name == "bool":
            return np.bool_
        if name == "float32":
            return np.float32
        if name == "int32":
            return np.int32
        if name == "Assert":
            tf = self._load_tensorflow()
            if tf is not None and hasattr(tf, "Assert"):
                return tf.Assert
            # Last resort: a stub that simply returns the condition.
            return getattr(
                fallback,
                "Assert",
                lambda condition, data=None, summarize=None, name=None: (
                    condition
                ),
            )
        if name == "Tensor":
            tf = self._load_tensorflow()
            if tf:
                return getattr(tf, "Tensor", object)
            return getattr(fallback, "Tensor", object)
        if name == "TensorShape":
            tf = self._load_tensorflow()
            if tf:
                return getattr(tf, "TensorShape", tuple)
            return getattr(fallback, "TensorShape", tuple)
        if name == "Reduction":
            keras = self._load_keras_root()
            losses = self._load_namespace(keras, "losses")
            reduction = getattr(losses, "Reduction", None)
            if reduction is not None:
                return reduction
            return getattr(
                fallback,
                "Reduction",
                SimpleNamespace(
                    AUTO="auto", SUM="sum", NONE="none"
                ),
            )
        if name == "get_static_value":
            return _get_static_value
        if name == "linalg":
            tf = self._load_tensorflow()
            if tf is not None and hasattr(tf, "linalg"):
                return tf.linalg
            return getattr(
                fallback,
                "linalg",
                _KerasLinalgNamespace(),
            )
        return None

    def _resolve_from_keras(self, name: str) -> Any:
        """Look ``name`` up in standalone Keras; ``None`` when not found."""
        keras = self._load_keras_root()
        if keras is None:
            return None

        if name == "register_keras_serializable":
            for namespace_name in ("saving", "utils"):
                namespace = self._load_namespace(
                    keras,
                    namespace_name,
                )
                value = getattr(
                    namespace,
                    "register_keras_serializable",
                    None,
                )
                if value is not None:
                    return value

        if name == "get":
            losses = self._load_namespace(keras, "losses")
            value = getattr(losses, "get", None)
            if value is not None:
                return value

        if name == "activations":
            return self._load_namespace(keras, "activations")
        if name == "random":
            return self._load_namespace(keras, "random")

        target_name = self._OP_ALIASES.get(name, name)
        ops = self._load_namespace(keras, "ops")

        if name == "constant":
            convert_to_tensor = getattr(
                ops,
                "convert_to_tensor",
                None,
            )
            if callable(convert_to_tensor):

                # ``constant`` is emulated with ``ops.convert_to_tensor``.
                def _constant(value, dtype=None):
                    normalized = _normalize_dtype(dtype)
                    if normalized is None:
                        return convert_to_tensor(value)
                    return convert_to_tensor(
                        value, dtype=normalized
                    )

                return _constant

        if name == "cast":
            cast = getattr(ops, "cast", None)
            if callable(cast):
                # Extra keyword arguments are accepted but not forwarded.
                return lambda value, dtype, **kwargs: cast(
                    value,
                    _normalize_dtype(dtype),
                )

        if ops is not None:
            value = getattr(ops, target_name, None)
            if value is not None:
                return value

        for namespace_name in self._SEARCH_PATHS:
            namespace = self._load_namespace(
                keras, namespace_name
            )
            if namespace is None:
                continue
            value = getattr(namespace, target_name, None)
            if value is not None:
                return value
        return None

    def _resolve_from_tensorflow(self, name: str) -> Any:
        """Look ``name`` up on the TensorFlow root module; ``None`` if absent."""
        tf = self._load_tensorflow()
        if tf is None:
            return None

        target_name = self._OP_ALIASES.get(name, name)
        if hasattr(tf, target_name):
            return getattr(tf, target_name)

        # Avoid walking through ``tf.keras`` here. When TensorFlow exposes Keras
        # via its internal lazy loader, attribute access can recurse while it is
        # deciding whether to bridge to Keras 3. For model/layer namespaces we
        # rely on standalone ``keras`` above; TensorFlow is only used here for
        # root-level TF ops and dtypes.
        return None

    def _resolve_from_fallback(self, name: str) -> Any:
        """Look ``name`` up in the bundled fallback module; ``None`` if absent."""
        fallback = self._load_fallback_runtime()
        if fallback is None:
            return None

        namespace_map = {
            "activations": "activations",
            "random": "random",
            "register_keras_serializable": "register_keras_serializable",
            "get": "get",
        }
        target_name = namespace_map.get(
            name,
            self._OP_ALIASES.get(name, name),
        )
        return getattr(fallback, target_name, None)

    def __getattr__(self, name: str) -> Any:
        """Resolve ``name`` lazily from the available runtimes.

        Raises
        ------
        AttributeError
            For dunder names and for this object's own state attributes
            when they are missing from the instance.
        ImportError
            When the symbol cannot be resolved from any runtime; the
            message explains how to configure or install a backend.
        """
        # ``__getattr__`` only runs after normal lookup has failed. Dunder
        # probes (``__setstate__``, ``__wrapped__``, ...) never name a Keras
        # symbol, so answer them the standard way instead of importing
        # runtimes or raising ImportError from inside ``hasattr``.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        # If the state attributes are missing (instance created without
        # ``__init__``, e.g. by ``copy``), touching them below would
        # re-enter this method without end.
        if name in ("_cache", "_cache_state", "_fallback_runtime"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        self._maybe_reset_cache()
        if name in self._cache:
            return self._cache[name]
        # ``newaxis`` legitimately resolves to ``None``, which the chain
        # below would treat as "not found"; handle it up front.
        if name == "newaxis":
            self._cache[name] = None
            return None

        value = self._resolve_special(name)
        if value is None:
            value = self._resolve_from_keras(name)
        if value is None:
            value = self._resolve_from_tensorflow(name)
        if value is None:
            value = self._resolve_from_fallback(name)
        if value is None:
            configured_backend = _resolve_runtime_backend()
            if configured_backend is None:
                raise ImportError(
                    _backend_not_configured_message("base_attentive runtime")
                )
            if configured_backend == "auto":
                raise ImportError(
                    "BaseAttentive backend is set to 'auto', but no suitable "
                    "runtime has been resolved yet. Install one of: "
                    "tensorflow keras, torch keras, or jax jaxlib keras; "
                    "or set BASE_ATTENTIVE_AUTO_INSTALL=1."
                )
            raise ImportError(
                _backend_missing_message(
                    "base_attentive runtime", configured_backend
                )
            )

        self._cache[name] = value
        return value


# Shared resolver instance created at import time.
KERAS_DEPS = _KerasDeps()
# Reference to the class's unpatched ``__getattr__``.
# NOTE(review): presumably kept so the original can be restored after it has
# been replaced (e.g. in tests) - confirm against callers.
_ORIGINAL_KERAS_DEPS_GETATTR = _KerasDeps.__getattr__


def dependency_message(module_name: str) -> str:
    """Return a dependency hint for missing runtime packages.

    Parameters
    ----------
    module_name:
        Name of the component that needs a runtime; interpolated into the
        message.

    Returns
    -------
    str
        The "not configured" message when no backend is set, an
        explanation of ``auto`` resolution when the backend is ``"auto"``,
        and otherwise the "runtime missing" message for the configured
        backend.
    """
    configured_backend = _resolve_runtime_backend()
    if configured_backend is None:
        return _backend_not_configured_message(module_name)
    if configured_backend == "auto":
        return (
            f"BaseAttentive backend is set to 'auto' for {module_name}. "
            "A runtime will be chosen on first use from the installed backends. "
            "If none is installed, install one of: `tensorflow keras`, `torch keras`, "
            "or `jax jaxlib keras`; or set BASE_ATTENTIVE_AUTO_INSTALL=1."
        )
    return _backend_missing_message(module_name, configured_backend)