Quick Start Guide
Installation
Install with the TensorFlow backend:
pip install "base-attentive[tensorflow]"
Or from source for development:
git clone https://github.com/earthai-tech/base-attentive.git
cd base-attentive
pip install -e ".[dev,tensorflow]"
Your First Model
This minimal example covers model creation, fitting, and prediction:
import numpy as np
from base_attentive import BaseAttentive
# 1. Create model instance
model = BaseAttentive(
    static_input_dim=4,    # 4 static features
    dynamic_input_dim=8,   # 8 dynamic features
    future_input_dim=6,    # 6 future features
    output_dim=2,          # 2 output variables
    forecast_horizon=24,   # 24 time steps ahead
)
# 2. Prepare data
batch_size = 32
lookback = 100
static_features = np.random.randn(batch_size, 4).astype('float32')
dynamic_features = np.random.randn(batch_size, lookback, 8).astype('float32')
future_features = np.random.randn(batch_size, 24, 6).astype('float32')
targets = np.random.randn(batch_size, 24, 2).astype('float32')
# 3. Compile and train
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
model.fit(
    [static_features, dynamic_features, future_features],
    targets,
    epochs=10,
    batch_size=32,
    verbose=1,
)
# 4. Make predictions
preds = model.predict([static_features, dynamic_features, future_features])
print(f"Shape: {preds.shape}") # (32, 24, 2)
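The random arrays above stand in for real data. With an actual multivariate series, the dynamic inputs and targets are usually cut from sliding windows. The following is a minimal NumPy sketch of that step; `make_windows` is a hypothetical helper, not part of base-attentive, and it assumes the variables to forecast are the first `n_targets` columns of the series.

```python
import numpy as np

def make_windows(series, lookback, horizon, n_targets):
    """Cut a (time, features) array into lookback/horizon window pairs.

    Hypothetical helper for illustration; not part of base-attentive.
    Assumes the first `n_targets` columns are the variables to forecast.
    """
    series = np.asarray(series, dtype="float32")
    n_windows = series.shape[0] - lookback - horizon + 1
    if n_windows <= 0:
        raise ValueError("series is too short for the requested windows")
    # History: `lookback` consecutive steps starting at each offset.
    dynamic = np.stack([series[i:i + lookback] for i in range(n_windows)])
    # Targets: the `horizon` steps that immediately follow each history.
    targets = np.stack([
        series[i + lookback:i + lookback + horizon, :n_targets]
        for i in range(n_windows)
    ])
    return dynamic, targets

series = np.random.randn(500, 8)
dynamic_w, targets_w = make_windows(series, lookback=100, horizon=24, n_targets=2)
print(dynamic_w.shape, targets_w.shape)
```

The resulting arrays have the same layout as `dynamic_features` and `targets` in the example above; static and future features would still need to be assembled per window.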
Understanding Inputs
Static Features (batch_size, static_dim)
Time-invariant properties:
static = np.array([
    [40.7128, -74.0060, 10, 2020],   # NYC: lat, lon, elev, year
    [34.0522, -118.243, 285, 2019],  # LA
], dtype='float32')
# Shape: (2, 4)
Dynamic Features (batch_size, time_steps, dynamic_dim)
Historical time series:
dynamic = np.random.randn(2, 100, 8).astype('float32')
# Shape: (2, 100, 8)
Future Features (batch_size, forecast_horizon, future_dim)
Known future values:
future = np.random.randn(2, 24, 6).astype('float32')
# Shape: (2, 24, 6)
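Shape mismatches between the three inputs are easy to introduce, so it can help to check them before calling the model. The sketch below is one way to do that; `check_inputs` is a hypothetical helper (not a library function) and its checks simply mirror the shapes documented above.

```python
import numpy as np

def check_inputs(static, dynamic, future, forecast_horizon):
    """Validate the three input arrays against the documented layouts.

    Hypothetical helper for illustration; not part of base-attentive.
    Returns (static_dim, dynamic_dim, future_dim) on success.
    """
    if static.ndim != 2:
        raise ValueError(f"static must be 2-D, got shape {static.shape}")
    if dynamic.ndim != 3 or future.ndim != 3:
        raise ValueError("dynamic and future must both be 3-D")
    batch_sizes = {static.shape[0], dynamic.shape[0], future.shape[0]}
    if len(batch_sizes) != 1:
        raise ValueError(f"batch sizes differ: {sorted(batch_sizes)}")
    if future.shape[1] != forecast_horizon:
        raise ValueError(
            f"future has {future.shape[1]} steps, expected {forecast_horizon}"
        )
    return static.shape[1], dynamic.shape[2], future.shape[2]

static = np.zeros((2, 4), dtype="float32")
dynamic = np.zeros((2, 100, 8), dtype="float32")
future = np.zeros((2, 24, 6), dtype="float32")
dims = check_inputs(static, dynamic, future, forecast_horizon=24)
print(dims)
```

The returned tuple lines up with the `static_input_dim`, `dynamic_input_dim`, and `future_input_dim` constructor arguments.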
Output Formats
Point Forecast
By default, the model returns a single point prediction per horizon step and output variable:
predictions = model([static, dynamic, future])
print(predictions.shape)  # (2, 24, 2): (batch, horizon, output_dim) for the two-sample arrays above
Probabilistic Forecasts with Quantiles
Include quantiles for uncertainty estimates:
model_q = BaseAttentive(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=2,
    forecast_horizon=24,
    quantiles=[0.1, 0.5, 0.9],
)
preds = model_q([static, dynamic, future])
print(preds.shape)  # (2, 24, 3, 2): (batch, horizon, quantiles, output_dim)
lower = preds[:, :, 0, :] # 10th percentile
median = preds[:, :, 1, :] # 50th percentile
upper = preds[:, :, 2, :] # 90th percentile
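Quantile outputs are normally trained and scored with the pinball (quantile) loss rather than plain MSE. This guide does not say which loss the library expects, so the NumPy sketch below is only an illustration of the standard formula applied to the (batch, horizon, quantiles, output_dim) layout; `pinball_loss` is a hypothetical name.

```python
import numpy as np

def pinball_loss(y_true, y_pred, quantiles):
    """Mean pinball loss.

    y_true: (batch, horizon, output_dim)
    y_pred: (batch, horizon, n_quantiles, output_dim)
    Illustrative only; not taken from base-attentive.
    """
    q = np.asarray(quantiles, dtype="float64").reshape(1, 1, -1, 1)
    # Positive error means the forecast was too low for that quantile.
    error = y_true[:, :, None, :] - y_pred
    return float(np.mean(np.maximum(q * error, (q - 1.0) * error)))

quantiles = [0.1, 0.5, 0.9]
y_true = np.random.randn(2, 24, 2)
y_pred = np.random.randn(2, 24, 3, 2)
print(pinball_loss(y_true, y_pred, quantiles))
```

Under-prediction is penalized by a factor of q and over-prediction by 1 - q, which is what pushes each output slice toward its own quantile.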
Encoder Objective
Use objective to choose the encoder design:
Hybrid (Default)
Multi-scale LSTM with attention — suitable for longer sequences:
model = BaseAttentive(..., objective="hybrid")
Transformer
Pure self-attention — better parallelism on shorter sequences:
model = BaseAttentive(..., objective="transformer")
Operational Mode Shortcuts
The mode parameter applies a pre-configured combination of settings:
# TFT-like (Temporal Fusion Transformer style)
model = BaseAttentive(
    static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
    output_dim=1, forecast_horizon=24,
    mode="tft",
)
# PIHALNet-like (Physics-Informed HAL style)
model = BaseAttentive(
    static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
    output_dim=1, forecast_horizon=24,
    mode="pihal",
)
Valid values: "tft", "tft_like", "pihal", "pihal_like", or None (the default, which leaves the configuration manual).
Attention Levels
Use attention_levels to declare which decoder attention mechanisms to enable:
# All three (default when attention_levels=None)
model = BaseAttentive(..., attention_levels=None)
# Cross-attention only
model = BaseAttentive(..., attention_levels="cross")
# Cross + hierarchical
model = BaseAttentive(..., attention_levels=["cross", "hierarchical"])
# Integer shortcuts: 1=cross, 2=hierarchical, 3=memory
model = BaseAttentive(..., attention_levels=1)
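The accepted forms above (None, a name, a list of names, or an integer shortcut) all describe a subset of the same three mechanisms. The sketch below shows one way such values could be resolved to a list of names. It is an illustration of the documented mapping only; `normalize_attention_levels` is a hypothetical function and the library's own parsing may differ.

```python
ALL_LEVELS = ("cross", "hierarchical", "memory")

def normalize_attention_levels(levels):
    """Resolve None, a name, an integer shortcut, or a list to level names.

    Illustrative only; mirrors the forms documented above.
    """
    if levels is None:
        return list(ALL_LEVELS)  # default: all three mechanisms
    if isinstance(levels, (str, int)):
        levels = [levels]
    resolved = []
    for item in levels:
        # Integer shortcuts: 1=cross, 2=hierarchical, 3=memory.
        if isinstance(item, int) and not isinstance(item, bool):
            if not 1 <= item <= len(ALL_LEVELS):
                raise ValueError(f"unknown attention level index: {item}")
            item = ALL_LEVELS[item - 1]
        if item not in ALL_LEVELS:
            raise ValueError(f"unknown attention level: {item!r}")
        if item not in resolved:
            resolved.append(item)
    return resolved

print(normalize_attention_levels(None))
print(normalize_attention_levels(["cross", "hierarchical"]))
print(normalize_attention_levels(1))
```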
Multi-Scale Aggregation
Control how temporal features are aggregated across LSTM scales:
model = BaseAttentive(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
    scales=[1, 2, 4],          # 3 temporal resolutions
    multi_scale_agg="last",    # 'last', 'average', 'flatten', 'concat'
    final_agg="last",          # final sequence step aggregation
)
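As a rough picture of what these options could mean, the NumPy sketch below assumes that each scale subsamples the sequence with that stride, that 'last' keeps the final step, and that 'average' takes the mean over time. These are assumptions made for illustration: the real model runs an LSTM per scale, and the exact behavior should be checked against the Architecture Guide. `multi_scale_features` is a hypothetical name.

```python
import numpy as np

def multi_scale_features(x, scales, agg="last"):
    """Subsample x at each stride, aggregate over time, concatenate features.

    x: (batch, time_steps, features). Illustrative only; no LSTM is applied.
    """
    per_scale = []
    for stride in scales:
        sub = x[:, ::stride, :]  # coarser view of the same sequence
        if agg == "last":
            per_scale.append(sub[:, -1, :])
        elif agg == "average":
            per_scale.append(sub.mean(axis=1))
        else:
            raise ValueError(f"unsupported agg in this sketch: {agg!r}")
    return np.concatenate(per_scale, axis=-1)

x = np.random.randn(2, 100, 8).astype("float32")
out = multi_scale_features(x, scales=[1, 2, 4], agg="last")
print(out.shape)
```

With three scales and eight features per step, the concatenated result has 24 features per sample in this sketch.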
Serialization
Save and Load Models
model.save('my_model.keras')
from keras import models
loaded = models.load_model('my_model.keras')
preds = loaded([static, dynamic, future])
Get / Restore Configuration
config = model.get_config()
new_model = BaseAttentive.from_config(config)
# Create a variant without mutating the original
variant = model.reconfigure({"encoder_type": "transformer"})
Common Patterns
Training with Early Stopping
import keras
model.compile(optimizer='adam', loss='mse', metrics=['mae'])
history = model.fit(
    [static_features, dynamic_features, future_features],
    targets,
    validation_split=0.2,
    epochs=50,
    batch_size=32,
    callbacks=[
        keras.callbacks.EarlyStopping(
            monitor='val_loss', patience=5, restore_best_weights=True,
        )
    ],
)
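To make the effect of patience=5 concrete, the pure-Python sketch below replays a list of validation losses and reports where training would stop and which epoch's weights would be restored. It is a simplified model of the callback (no min_delta, zero-based epochs), not Keras code; `early_stopping_epoch` is a hypothetical name.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stopped_epoch, best_epoch) for a sequence of validation losses.

    Simplified illustration of patience-based stopping; epochs are zero-based.
    """
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # Improvement: remember it and reset the patience counter.
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training ends at the last epoch.
    return len(val_losses) - 1, best_epoch

losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.74, 0.73, 0.9]
stopped, best = early_stopping_epoch(losses, patience=5)
print(stopped, best)  # 7 2
```

Here the loss bottoms out at epoch 2, five epochs pass without improvement, and training halts at epoch 7; with restore_best_weights=True the model would keep the epoch-2 weights.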
Confidence Interval Visualization
import numpy as np
import matplotlib.pyplot as plt
model_ci = BaseAttentive(
    static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
    output_dim=2, forecast_horizon=24,
    quantiles=[0.025, 0.5, 0.975],
)
preds = model_ci([static, dynamic, future])
lower_ci = preds[:, :, 0, :]
point = preds[:, :, 1, :]
upper_ci = preds[:, :, 2, :]
t = np.arange(24)
plt.fill_between(t, lower_ci[0, :, 0], upper_ci[0, :, 0],
                 alpha=0.3, label='95% CI')
plt.plot(t, point[0, :, 0], 'r-', label='Median')
plt.legend()
plt.show()
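A plotted band only means something once the model has been trained, and even then it is worth checking numerically. One common check is empirical coverage: the fraction of true values that fall inside the interval, which should be near 0.95 for the 2.5% and 97.5% quantiles. The sketch below uses synthetic arrays in place of real forecasts; `interval_coverage` is a hypothetical helper.

```python
import numpy as np

def interval_coverage(y_true, lower, upper):
    """Fraction of targets inside [lower, upper]; all arrays share one shape."""
    inside = (y_true >= lower) & (y_true <= upper)
    return float(inside.mean())

# Synthetic stand-ins: standard-normal targets and a fixed +/-1.96 band.
rng = np.random.default_rng(0)
y_true = rng.normal(size=(200, 24, 2))
lower = np.full_like(y_true, -1.96)
upper = np.full_like(y_true, 1.96)
print(f"coverage: {interval_coverage(y_true, lower, upper):.3f}")
```

With real outputs, `lower` and `upper` would be the `lower_ci` and `upper_ci` slices from the example above, converted to NumPy arrays, and `y_true` the held-out targets.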
Using BaseAttentive as a Kernel
Wrap BaseAttentive inside a larger model to add domain-specific logic:
import keras
from base_attentive import BaseAttentive
class RobustForecastModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.kernel = BaseAttentive(
            static_input_dim=4, dynamic_input_dim=8, future_input_dim=6,
            output_dim=2, forecast_horizon=24,
        )
        self.residual_head = keras.layers.Dense(2)
    def call(self, inputs, training=False):
        _, dynamic_x, _ = inputs
        base_forecast = self.kernel(inputs, training=training)
        residual = self.residual_head(keras.ops.mean(dynamic_x, axis=1))
        return base_forecast + keras.ops.expand_dims(residual, axis=1)
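The last line of call relies on broadcasting: the residual has one value per sample and output variable, and the expanded axis lets it be added to every horizon step. The NumPy sketch below reproduces only that shape arithmetic, with a random weight matrix standing in for the Dense layer; it does not use the model itself.

```python
import numpy as np

batch, lookback, horizon, out_dim = 32, 100, 24, 2
base_forecast = np.zeros((batch, horizon, out_dim), dtype="float32")
dynamic_x = np.random.randn(batch, lookback, 8).astype("float32")

# Stand-in for Dense(2): a fixed (8, 2) weight matrix without bias.
weights = np.random.randn(8, out_dim).astype("float32")
residual = dynamic_x.mean(axis=1) @ weights       # shape (32, 2)

# Insert a horizon axis of length 1 so the residual broadcasts over all steps.
combined = base_forecast + residual[:, None, :]   # shape (32, 24, 2)
print(residual.shape, combined.shape)
```

The same correction is applied at every forecast step. If the kernel were built with quantiles, its output would have four axes and the residual would need a second inserted axis to broadcast correctly.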
For a fuller guide see Usage. For ensemble and physics-guided patterns see Applications and Use Cases.
Backend Selection
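Set the backend through the environment before launching Python:
export KERAS_BACKEND=tensorflow
python your_script.py
Or set it in Python, before the first import of keras or base_attentive:
import os
os.environ['KERAS_BACKEND'] = 'tensorflow'
from base_attentive import BaseAttentive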
Supported backends: tensorflow (stable), jax (experimental), torch (experimental).
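Because the variable is read when Keras is first imported, a typo or a late assignment tends to fail quietly. The standard-library sketch below validates the name against the three backends listed above and warns if Keras is already loaded; `select_backend` is a hypothetical helper, not part of base-attentive.

```python
import os
import sys
import warnings

SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch")

def select_backend(name):
    """Validate `name` and export it as KERAS_BACKEND.

    Hypothetical helper; call it before importing keras or base_attentive.
    """
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"unsupported backend {name!r}; choose from {SUPPORTED_BACKENDS}"
        )
    if "keras" in sys.modules:
        # The backend is fixed at import time, so this call comes too late.
        warnings.warn("keras is already imported; the change may have no effect")
    os.environ["KERAS_BACKEND"] = name
    return name

select_backend("tensorflow")
print(os.environ["KERAS_BACKEND"])
```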
Faster TensorFlow Inference
from base_attentive import make_fast_predict_fn
fast_predict = make_fast_predict_fn(
    model,
    warmup_inputs=[static_features, dynamic_features, future_features],
)
predictions = fast_predict([static_features, dynamic_features, future_features])
This is TensorFlow-specific and wraps inference with tf.function.
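Whether the wrapped function is actually faster depends on model size, batch size, and hardware, so it is worth measuring. The sketch below is a small timing harness using only the standard library; `time_call` is a hypothetical helper, and a trivial callable stands in for the model so the snippet runs on its own.

```python
import time

def time_call(fn, inputs, repeats=20):
    """Return the mean wall-clock seconds per call of fn(inputs)."""
    fn(inputs)  # one untimed call so any tracing or compilation is excluded
    start = time.perf_counter()
    for _ in range(repeats):
        fn(inputs)
    return (time.perf_counter() - start) / repeats

# Stand-in callable; replace with model.predict or fast_predict.
mean_s = time_call(lambda xs: [sum(x) for x in xs], [[1.0, 2.0], [3.0]], repeats=5)
print(f"{mean_s:.6f} s per call")
```

Calling `time_call(model.predict, inputs)` and `time_call(fast_predict, inputs)` with the same input list gives a like-for-like comparison.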
Next Steps
Explore configurations: See Configuration Guide
Understand architecture: See Architecture Guide
API reference: See API Reference
Extended usage and kernel patterns: See Usage
Application blueprints: See Applications and Use Cases