Backend Guide

BaseAttentive is built on Keras 3, which delegates execution to a backend framework. The full model path is supported on TensorFlow; JAX and Torch (PyTorch) are experimental.

Support status

Backend           Status         Notes
---------------   ------------   --------------------------------------------------
TensorFlow        Stable         Full BaseAttentive model path; recommended for
                                 training and deployment.
JAX               Experimental   Runtime abstraction available; full model path
                                 under evaluation.
Torch (PyTorch)   Experimental   Runtime abstraction available; TorchDeviceManager
                                 for GPU/MPS management. Full model path under
                                 evaluation.

Selecting a backend

Set the backend before importing base_attentive or keras:

export KERAS_BACKEND=tensorflow
export BASE_ATTENTIVE_BACKEND=tensorflow

Or in Python:

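import os
# Set both variables before importing keras or base_attentive.
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["BASE_ATTENTIVE_BACKEND"] = "tensorflow"
from base_attentive import BaseAttentive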

The package also exposes helper functions for runtime inspection and control:

from base_attentive import (
    get_backend,
    get_available_backends,
    get_backend_capabilities,
    set_backend,
    normalize_backend_name,
)

set_backend("tensorflow")
backend = get_backend()
print(backend.name)
print(get_available_backends())
print(get_backend_capabilities("jax"))
print(normalize_backend_name("tf"))   # -> "tensorflow"
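normalize_backend_name maps shorthand names such as "tf" onto canonical backend names. The sketch below shows one way that kind of normalization can work: a case-insensitive lookup table. normalize_name_sketch and its alias table are hypothetical and may not match the aliases base_attentive actually accepts.

```python
# Hypothetical alias table: an assumption for illustration, not
# base_attentive's actual mapping.
_ALIASES = {
    "tf": "tensorflow",
    "tensorflow": "tensorflow",
    "jax": "jax",
    "torch": "torch",
    "pytorch": "torch",
}


def normalize_name_sketch(name: str) -> str:
    """Map a user-supplied backend name to a canonical one (hypothetical)."""
    key = name.strip().lower()
    try:
        return _ALIASES[key]
    except KeyError:
        raise ValueError(f"Unknown backend: {name!r}") from None


print(normalize_name_sketch("TF"))       # tensorflow
print(normalize_name_sketch("PyTorch"))  # torch
```

Raising on unknown names, rather than passing them through, surfaces typos such as "tensorfow" before any import is attempted.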

How backend resolution works

When you call get_backend() without an explicit name, BaseAttentive checks in this order:

  1. BASE_ATTENTIVE_BACKEND environment variable

  2. KERAS_BACKEND environment variable

  3. The backend previously set in the current Python process

  4. tensorflow (default)
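The precedence above can be expressed as a small pure function. This is a sketch of the documented order, not base_attentive's internal code. resolve_backend_sketch and its process_backend argument are hypothetical, and the sketch does not normalize aliases.

```python
import os
from typing import Mapping, Optional

DEFAULT_BACKEND = "tensorflow"


def resolve_backend_sketch(
    env: Mapping[str, str],
    process_backend: Optional[str] = None,
) -> str:
    """Return a backend name following the documented precedence (hypothetical)."""
    # 1. The package-specific variable wins.
    if env.get("BASE_ATTENTIVE_BACKEND"):
        return env["BASE_ATTENTIVE_BACKEND"]
    # 2. Fall back to Keras' own backend variable.
    if env.get("KERAS_BACKEND"):
        return env["KERAS_BACKEND"]
    # 3. A backend chosen earlier in this process (e.g. via set_backend).
    if process_backend:
        return process_backend
    # 4. Default.
    return DEFAULT_BACKEND


print(resolve_backend_sketch(os.environ))
```

Taking the environment as a parameter, instead of reading os.environ directly, makes the precedence easy to test with plain dictionaries.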

V2 Backend Detection and Selection

New in v1.0.0: richer backend detection and selection utilities.

from base_attentive import (
    detect_available_backends,
    select_best_backend,
    ensure_default_backend,
)

# Inspect all installed backends
info = detect_available_backends()
# Returns: {'tensorflow': {'available': True, 'version': '2.14.0'}, ...}
for name, details in info.items():
    print(f"{name}: available={details.get('available')}, "
          f"version={details.get('version')}")

# Select the best available backend automatically
best = select_best_backend(require_supported=True)
print(f"Best backend: {best}")

# Ensure a default backend is ready (optionally attempt auto-install)
name = ensure_default_backend(auto_install=False)
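You may want a fixed preference order instead of whatever select_best_backend chooses. In that case you can filter the detection results yourself. This sketch assumes the return shape shown in the comment above. pick_backend_sketch and its default order (TensorFlow first, matching the support table) are illustrative assumptions. select_best_backend's actual criteria may differ.

```python
from typing import Dict, Iterable, Optional

# Shape assumed from the detect_available_backends comment:
# {'tensorflow': {'available': True, 'version': '...'}, ...}
BackendInfo = Dict[str, Dict[str, object]]


def pick_backend_sketch(
    info: BackendInfo,
    preference: Iterable[str] = ("tensorflow", "jax", "torch"),
) -> Optional[str]:
    """Return the first preferred backend marked available, else None (hypothetical)."""
    for name in preference:
        if info.get(name, {}).get("available"):
            return name
    return None


# Illustrative input only; in practice pass detect_available_backends().
sample = {
    "tensorflow": {"available": False},
    "jax": {"available": True},
    "torch": {"available": True},
}
print(pick_backend_sketch(sample))  # jax
```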

Version Compatibility Checks

from base_attentive.backend import (
    check_tensorflow_compatibility,
    check_torch_compatibility,
    get_backend_version,
    version_at_least,
)

# Check if installed TF version is compatible
ok, msg = check_tensorflow_compatibility()
print(msg)   # e.g. "TensorFlow 2.14.0 is compatible"

# Check PyTorch compatibility (requires >= 2.0.0)
ok, msg = check_torch_compatibility()

# Get installed version string
ver = get_backend_version("tensorflow")  # e.g. "2.14.0"

# Version comparison utility
ok = version_at_least("2.13.0", "2.12.0")  # True

Current recommendations

  • Use TensorFlow for training, testing, and model serialization.

  • Use JAX or Torch for backend experiments and research; treat the full BaseAttentive execution path as exploratory on those runtimes.
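To make scripts flag runs on an experimental backend, you can add a small guard. This is a sketch only. warn_if_experimental and STABLE_BACKENDS are hypothetical names that mirror the support table above; they are not part of base_attentive.

```python
import warnings

# Hypothetical: mirrors the support table (only TensorFlow is stable).
STABLE_BACKENDS = {"tensorflow"}


def warn_if_experimental(backend_name: str) -> bool:
    """Emit a UserWarning for non-stable backends; return True if stable."""
    if backend_name in STABLE_BACKENDS:
        return True
    warnings.warn(
        f"Backend {backend_name!r} is experimental for the full "
        "BaseAttentive model path; prefer 'tensorflow' for training "
        "and serialization.",
        UserWarning,
        stacklevel=2,
    )
    return False


# In a real script you might pass get_backend().name here.
warn_if_experimental("jax")
```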

Accelerated Inference on TensorFlow

Wrap the model with make_fast_predict_fn for repeated inference:

from base_attentive import BaseAttentive, make_fast_predict_fn
import numpy as np

model = BaseAttentive(
    static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
    output_dim=1, forecast_horizon=24,
)

x_static  = np.random.randn(32, 4).astype('float32')
x_dynamic = np.random.randn(32, 100, 8).astype('float32')
x_future  = np.random.randn(32, 24, 6).astype('float32')

fast_predict = make_fast_predict_fn(
    model,
    warmup_inputs=[x_static, x_dynamic, x_future],
)
predictions = fast_predict([x_static, x_dynamic, x_future])

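This helper wraps the model call in tf.function and runs it with training=False. Because tf.function can trace a new graph for each new input shape, keep batch and sequence shapes as stable as possible to avoid repeated retracing. For training, XLA compilation is controlled by the jit_compile argument, for example model.compile(..., jit_compile="auto").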
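A final partial batch of a different size is a common source of extra traces. One workaround is to pad every batch to a fixed size and discard the padded rows afterwards. pad_to_batch and predict_fixed_batch are hypothetical helpers, not part of base_attentive. The sketch assumes predict_fn takes a list of arrays and returns an array whose first axis is the batch. Zero padding is only safe if the model treats samples independently along the batch axis.

```python
import numpy as np


def pad_to_batch(x: np.ndarray, batch_size: int) -> np.ndarray:
    """Zero-pad the leading (batch) axis of `x` up to `batch_size`."""
    n = x.shape[0]
    if n > batch_size:
        raise ValueError(f"got {n} samples, more than batch_size={batch_size}")
    pad = [(0, batch_size - n)] + [(0, 0)] * (x.ndim - 1)
    return np.pad(x, pad)  # constant mode, fills with zeros


def predict_fixed_batch(predict_fn, inputs, batch_size=32):
    """Run `predict_fn` on inputs padded to `batch_size`, then trim the output."""
    n = inputs[0].shape[0]
    padded = [pad_to_batch(x, batch_size) for x in inputs]
    outputs = np.asarray(predict_fn(padded))
    return outputs[:n]  # drop predictions for the padding rows
```

With the objects from the example above, predict_fixed_batch(fast_predict, [x_static[:5], x_dynamic[:5], x_future[:5]]) would call fast_predict on batches of 32 and return predictions for the 5 real samples only.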

See Also