JAX Backend
BaseAttentive supports JAX through Keras 3. JAX can run on CPU, GPU (CUDA), and TPU.
Installation
pip install "base-attentive[jax]"
Or manually:
pip install "jax>=0.4.0" "jaxlib>=0.4.0"
For GPU (CUDA 12):
pip install "jax[cuda12]"
For TPU:
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
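After installing, you can run a quick smoke test to confirm that JAX imports and executes on whichever backend it selected. This sketch uses only public jax APIs and does not depend on BaseAttentive:

```python
import jax
import jax.numpy as jnp

# A tiny computation dispatched to JAX's default backend.
x = jnp.arange(4.0)       # [0., 1., 2., 3.]
total = jnp.dot(x, x)     # 0 + 1 + 4 + 9 = 14

print(jax.__version__, jax.default_backend(), float(total))
```

On a GPU or TPU machine, `jax.default_backend()` should report "gpu" or "tpu". If it reports "cpu", the accelerator build was probably not picked up (see Troubleshooting).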
Selecting the JAX Backend
import os
# Must be set before base_attentive (and therefore Keras) is imported.
os.environ["KERAS_BACKEND"] = "jax"
from base_attentive import BaseAttentive
model = BaseAttentive(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
)
Training Example
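import os
os.environ["KERAS_BACKEND"] = "jax"
import numpy as np
from base_attentive import BaseAttentive
model = BaseAttentive(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
    embed_dim=32,
    num_heads=4,
)
model.compile(optimizer="adam", loss="mse")
# Synthetic batch of 32 samples; trailing dimensions match the constructor arguments.
x_static = np.random.randn(32, 4).astype("float32")        # (batch, static_input_dim)
x_dynamic = np.random.randn(32, 100, 8).astype("float32")  # (batch, time steps, dynamic_input_dim)
x_future = np.random.randn(32, 24, 6).astype("float32")    # (batch, forecast_horizon, future_input_dim)
y = np.random.randn(32, 24, 1).astype("float32")           # (batch, forecast_horizon, output_dim)
# Inputs are passed as a list in the order [static, dynamic, future].
model.fit([x_static, x_dynamic, x_future], y, epochs=3)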
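In the training example, the trailing dimension of each array follows the constructor arguments. Static inputs are (batch, static_input_dim). Dynamic inputs are (batch, time steps, dynamic_input_dim). Future inputs and targets span forecast_horizon steps. The helper below is hypothetical and not part of base_attentive. It generates random arrays with those shapes, so you can change the configuration without editing each line by hand. The example does not pass the dynamic sequence length (100 above) to the constructor, so it is a free `lookback` parameter here:

```python
import numpy as np


def make_dummy_inputs(batch_size, lookback, static_input_dim, dynamic_input_dim,
                      future_input_dim, forecast_horizon, output_dim, seed=0):
    """Hypothetical helper: random float32 arrays shaped like the training example."""
    rng = np.random.default_rng(seed)
    x_static = rng.standard_normal((batch_size, static_input_dim), dtype=np.float32)
    x_dynamic = rng.standard_normal((batch_size, lookback, dynamic_input_dim), dtype=np.float32)
    x_future = rng.standard_normal((batch_size, forecast_horizon, future_input_dim), dtype=np.float32)
    y = rng.standard_normal((batch_size, forecast_horizon, output_dim), dtype=np.float32)
    # Same [static, dynamic, future] ordering as the model.fit call above.
    return [x_static, x_dynamic, x_future], y


inputs, y = make_dummy_inputs(
    batch_size=32, lookback=100, static_input_dim=4, dynamic_input_dim=8,
    future_input_dim=6, forecast_horizon=24, output_dim=1,
)
print([a.shape for a in inputs], y.shape)
```

With a model built as in the training example, `model.fit(inputs, y, epochs=3)` would then replace the hand-written arrays.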
Device Inspection
import jax
print(jax.devices()) # e.g. [CpuDevice(id=0)]
print(jax.default_backend()) # "cpu", "gpu", or "tpu"
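Beyond listing devices, you can count them by platform and check where an array actually lives. The sketch below uses public JAX APIs: `Device.platform`, `jax.device_put`, and `jax.Array.devices()`. The last of these is available in recent JAX releases:

```python
from collections import Counter

import jax
import jax.numpy as jnp

devices = jax.devices()  # devices of the default backend
platforms = Counter(d.platform for d in devices)
print(jax.default_backend(), dict(platforms), jax.device_count())

# Explicitly place an array on the first device and confirm its placement.
x = jax.device_put(jnp.arange(8.0), devices[0])
print(x.devices())  # a set containing devices[0]
```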
Compatibility Check
from base_attentive.backend import get_backend_version, version_at_least
ver = get_backend_version("jax")
ok = version_at_least(ver, "0.4.0")
print(f"JAX {ver} compatible: {ok}")
Minimum required versions: jax 0.4.0, jaxlib 0.4.0.
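`get_backend_version` and `version_at_least` are BaseAttentive helpers, and their internals are not shown here. The hypothetical standalone sketch below is not the library's implementation. It illustrates why version checks need numeric rather than string comparison, and it reads installed versions with the standard library's importlib.metadata:

```python
import re
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple[int, ...]:
    """Leading numeric release components: '0.4.31.dev20240701' -> (0, 4, 31)."""
    m = re.match(r"\d+(?:\.\d+)*", v)
    if m is None:
        raise ValueError(f"unparseable version: {v!r}")
    return tuple(int(p) for p in m.group(0).split("."))


def at_least(installed: str, minimum: str) -> bool:
    a, b = parse_release(installed), parse_release(minimum)
    n = max(len(a), len(b))
    # Pad with zeros so that "0.4" compares equal to "0.4.0".
    a += (0,) * (n - len(a))
    b += (0,) * (n - len(b))
    return a >= b


# Plain string comparison gets multi-digit components wrong.
print("0.10.0" >= "0.4.0")          # False
print(at_least("0.10.0", "0.4.0"))  # True

for pkg in ("jax", "jaxlib"):
    try:
        ver = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
        continue
    print(f"{pkg} {ver} >= 0.4.0: {at_least(ver, '0.4.0')}")
```

This sketch ignores pre-release tags, so "0.4.0rc1" counts as 0.4.0. For exact PEP 440 ordering, `packaging.version.Version` is the usual choice.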
Troubleshooting
JAX not found
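If import jax fails with ModuleNotFoundError, install JAX into the same environment that runs your code:
python -m pip install jax jaxlib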
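A common cause of "JAX not found" is installing into a different interpreter from the one that runs your code. This minimal diagnostic sketch uses only the standard library:

```python
import importlib.util
import sys

# find_spec returns None when a top-level module cannot be found on sys.path.
missing = [name for name in ("jax", "jaxlib") if importlib.util.find_spec(name) is None]

print("Interpreter:", sys.executable)
if missing:
    print("Missing:", ", ".join(missing))
    print(f"Install with: {sys.executable} -m pip install jax jaxlib")
else:
    print("jax and jaxlib are importable")
```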
GPU not detected in JAX
Check which GPU devices JAX can see:
import jax
print(jax.devices("gpu"))
If this raises a RuntimeError (or returns an empty list), the installed jaxlib is CPU-only. Install the CUDA-enabled build:
pip install "jax[cuda12]"
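In the JAX versions we are aware of, `jax.devices("gpu")` raises a RuntimeError instead of returning an empty list when no GPU backend is available. A defensive wrapper handles both cases. The name list_gpus is hypothetical:

```python
import jax


def list_gpus():
    """Return GPU devices, or [] when no GPU backend is available."""
    try:
        return jax.devices("gpu")
    except RuntimeError:
        # Raised by CPU-only installs or when the CUDA backend fails to initialize.
        return []


gpus = list_gpus()
if gpus:
    print(f"{len(gpus)} GPU(s):", gpus)
else:
    print("No GPU backend; default backend is", jax.default_backend())
```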
XLA compilation warnings
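With the JAX backend, Keras traces the model's step functions and JIT-compiles them with XLA on first use. As a result, the first model.fit, model.evaluate, or model.predict call is noticeably slower than later ones, and this warm-up latency is expected.
Later calls with the same input shapes and dtypes reuse the compiled code. A new input shape, such as a smaller final batch, triggers another compilation.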
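JAX keys its compile cache on input shapes and dtypes, and you can observe this directly with jax.jit. A Python side effect inside a jitted function runs only while JAX traces it, so a counter shows exactly when compilation happens. This standalone sketch uses a plain MSE function rather than BaseAttentive's internals:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def mse(pred, target):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on cached calls
    return jnp.mean((pred - target) ** 2)


pred = jnp.ones((32, 24, 1))
target = jnp.zeros((32, 24, 1))

mse(pred, target).block_until_ready()  # first call: trace + XLA compile
mse(pred, target).block_until_ready()  # same shapes: cached executable
print(trace_count)  # 1

mse(pred[:16], target[:16]).block_until_ready()  # new shape: retrace + recompile
print(trace_count)  # 2
```

The same rule applies to the Keras training loop. With 1000 samples and batch_size=32, the final batch has 8 samples, so it may trigger one extra compilation during the first epoch.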
See Also
Backend Guide — Backend overview and selection
Installation — Full installation instructions