JAX Backend

BaseAttentive supports JAX through Keras 3. JAX can run on CPU, GPU (CUDA), and TPU.

Installation

pip install "base-attentive[jax]"

Or manually:

pip install "jax>=0.4.0" "jaxlib>=0.4.0"

For GPU (CUDA 12):

pip install "jax[cuda12]"

For TPU:

pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
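Whichever command you use, you can confirm what pip installed without importing JAX by reading package metadata from the standard library. This is a minimal sketch: installed_version is a hypothetical helper, not part of BaseAttentive or JAX.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent.

    Hypothetical helper: it reads pip metadata only and never imports the package.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


for name in ("jax", "jaxlib"):
    version = installed_version(name)
    print(f"{name}: {version if version is not None else 'not installed'}")
```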

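Selecting the JAX Backend

Set KERAS_BACKEND before importing keras or base_attentive. Keras 3 reads the variable when it is first imported, so changing it later in the same process has no effect.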

import os
os.environ["KERAS_BACKEND"] = "jax"

from base_attentive import BaseAttentive

model = BaseAttentive(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
)
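Because the backend is fixed when Keras is first imported, setup code that runs later can silently end up on a different backend. A small guard, sketched below under the assumption that you prefer an explicit error, makes that mistake visible. select_keras_backend is a hypothetical helper, and it only checks the environment variable, not Keras's own configuration file.

```python
import os
import sys


def select_keras_backend(name):
    """Set KERAS_BACKEND, refusing to switch once Keras is already imported.

    Hypothetical helper: Keras picks its backend at import time, so a later
    change to the environment variable would be silently ignored.
    """
    if "keras" in sys.modules and os.environ.get("KERAS_BACKEND") != name:
        raise RuntimeError(
            "Keras is already imported; set KERAS_BACKEND before importing "
            "keras or base_attentive."
        )
    os.environ["KERAS_BACKEND"] = name


# Call this at the very top of your entry point, before any model imports.
select_keras_backend("jax")
```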

Training Example

import os
os.environ["KERAS_BACKEND"] = "jax"

import numpy as np
from base_attentive import BaseAttentive

model = BaseAttentive(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
    embed_dim=32,
    num_heads=4,
)

model.compile(optimizer="adam", loss="mse")

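# 32 samples: 4 static features, 100 past steps of 8 dynamic features,
# and 24 future steps (= forecast_horizon) of 6 future features.
x_static  = np.random.randn(32, 4).astype("float32")
x_dynamic = np.random.randn(32, 100, 8).astype("float32")
x_future  = np.random.randn(32, 24, 6).astype("float32")
# Targets: forecast_horizon steps with output_dim values each.
y         = np.random.randn(32, 24, 1).astype("float32")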

model.fit([x_static, x_dynamic, x_future], y, epochs=3)
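In the example above, every array dimension is tied to a constructor argument. A pre-flight check can catch mismatches before fit starts tracing. The sketch below is hypothetical (check_input_shapes is not a BaseAttentive function) and assumes exactly the layouts used in this example: static (batch, static_input_dim), dynamic (batch, time, dynamic_input_dim), future (batch, forecast_horizon, future_input_dim) and targets (batch, forecast_horizon, output_dim). Whether the model accepts other layouts is not covered here.

```python
import numpy as np


def check_input_shapes(x_static, x_dynamic, x_future, y, *, static_input_dim,
                       dynamic_input_dim, future_input_dim, output_dim,
                       forecast_horizon):
    """Raise ValueError if an array's shape disagrees with the model settings.

    Hypothetical helper: the expected layouts mirror the training example,
    with a free number of past time steps in x_dynamic.
    """
    batch = x_static.shape[0]
    expected = {
        "x_static": (x_static.shape, (batch, static_input_dim)),
        "x_dynamic": (x_dynamic.shape,
                      (batch, x_dynamic.shape[1], dynamic_input_dim)),
        "x_future": (x_future.shape,
                     (batch, forecast_horizon, future_input_dim)),
        "y": (y.shape, (batch, forecast_horizon, output_dim)),
    }
    for name, (actual, wanted) in expected.items():
        if tuple(actual) != wanted:
            raise ValueError(f"{name} has shape {tuple(actual)}, expected {wanted}")


# Same shapes and settings as the training example: this passes silently.
check_input_shapes(
    np.zeros((32, 4), "float32"),
    np.zeros((32, 100, 8), "float32"),
    np.zeros((32, 24, 6), "float32"),
    np.zeros((32, 24, 1), "float32"),
    static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
    output_dim=1, forecast_horizon=24,
)
```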

Device Inspection

import jax
print(jax.devices())           # e.g. [CpuDevice(id=0)]
print(jax.default_backend())   # "cpu", "gpu", or "tpu"

Compatibility Check

from base_attentive.backend import get_backend_version, version_at_least

ver = get_backend_version("jax")
ok = version_at_least(ver, "0.4.0")
print(f"JAX {ver} compatible: {ok}")

Minimum required versions: jax 0.4.0, jaxlib 0.4.0.
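Version strings have to be compared numerically rather than as text: as plain strings, "0.4.30" sorts before "0.4.4". The sketch below shows one standard-library way to do a numeric check. It is an illustration, not necessarily how version_at_least is implemented; version_tuple and numeric_at_least are hypothetical names, and the sketch drops pre-release or local suffixes such as "rc1" or "+cuda" instead of ordering them (the packaging library's Version class handles those properly).

```python
import re


def version_tuple(version):
    """Convert a version string into a tuple of its leading integer components.

    Simplified sketch: parsing stops at the first component that is not a
    plain integer, so suffixes such as 'rc1', '.dev20240101' or '+cuda' are
    dropped rather than ordered.
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.group() != piece:
            break  # e.g. '30rc1': keep 30, ignore the rest
    return tuple(parts)


def numeric_at_least(installed, minimum):
    """Return True if installed >= minimum, compared component by component."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '0.4' and '0.4.0' compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


print("0.4.30" >= "0.4.4")                  # False: plain string comparison is wrong
print(numeric_at_least("0.4.30", "0.4.4"))  # True
print(numeric_at_least("0.3.25", "0.4.0"))  # False
```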

Troubleshooting

JAX not found

If import jax raises ModuleNotFoundError, install JAX into the active environment:

pip install jax jaxlib
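If pip reports JAX as installed but import jax still fails, pip and Python often belong to different environments. The standard-library check sketched below (jax_import_status is a hypothetical name) prints which interpreter is running and whether it can locate the jax package without importing it. Installing with that interpreter's own pip, via python -m pip, avoids the mismatch.

```python
import importlib.util
import sys


def jax_import_status():
    """Report the running interpreter and whether it can find the jax package.

    Hypothetical helper: find_spec locates the package without importing it,
    so no accelerator backend is initialized.
    """
    spec = importlib.util.find_spec("jax")
    return {
        "python": sys.executable,
        "jax_found": spec is not None,
        "jax_location": spec.origin if spec is not None else None,
    }


status = jax_import_status()
print(f"Interpreter: {status['python']}")
if status["jax_found"]:
    print(f"jax found at {status['jax_location']}")
else:
    print("jax is not importable here; install it with this interpreter:")
    print(f"  {status['python']} -m pip install jax jaxlib")
```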

GPU not detected in JAX

Verify the CUDA build:

import jax
print(jax.devices("gpu"))

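If this raises a RuntimeError instead of listing a GPU device, the installed JAX has no working CUDA backend. Reinstall with the CUDA 12 extra (the pip wheels bundle the CUDA libraries but still need a compatible NVIDIA driver):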

pip install "jax[cuda12]"

XLA compilation warnings

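JAX traces the model's computation and compiles it with XLA the first time it runs with a given input shape, so the first model.predict call (and the first training step) is noticeably slower than later ones. This warm-up latency is expected. Calls with new input shapes, such as a different batch size, trigger another compilation.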
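To keep one-time compilation out of latency measurements, time the first call separately from the rest. The sketch below uses a stand-in workload so it runs without a model; time_calls is a hypothetical helper, and with a real model you would pass something like lambda: model.predict(inputs, verbose=0), with inputs shaped as in the training example.

```python
import time


def time_calls(fn, repeats=3):
    """Call fn() `repeats` times and return the wall-clock duration of each call.

    With a JAX-backed Keras model, the first entry typically includes tracing
    and XLA compilation; later entries reflect steady-state cost as long as
    input shapes stay the same.
    """
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


# Stand-in workload; replace with e.g. lambda: model.predict(inputs, verbose=0).
durations = time_calls(lambda: sum(range(100_000)))
first, rest = durations[0], durations[1:]
print(f"first call: {first:.4f}s, later calls: {[round(d, 4) for d in rest]}")
```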

See Also