Usage
This page explains how to configure, train, and extend BaseAttentive in v2. It covers every configuration path in detail — from the one-liner quick-start through to the full spec-based assembly API — with worked examples at each level.
Quick Start
The shortest possible working example:
import numpy as np
from base_attentive import BaseAttentive
model = BaseAttentive(
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
)
model.compile(optimizer="adam", loss="mse")
x_static = np.random.randn(32, 4).astype("float32")
x_dynamic = np.random.randn(32, 100, 8).astype("float32")
x_future = np.random.randn(32, 24, 6).astype("float32")
y = np.random.randn(32, 24, 1).astype("float32")
model.fit([x_static, x_dynamic, x_future], y, epochs=5)
preds = model.predict([x_static, x_dynamic, x_future])
# shape: (32, 24, 1)
Input Contract
BaseAttentive expects three inputs in fixed order:
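| Slot | Shape | Content |
|---|---|---|
| x_static | (batch, static_input_dim) | Time-invariant properties (e.g. coordinates, site type) |
| x_dynamic | (batch, T, dynamic_input_dim) | Historical time series (any T ≥ 1) |
| x_future | (batch, forecast_horizon, future_input_dim) | Known future exogenous variables |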
Pass them as a list:
predictions = model([x_static, x_dynamic, x_future])
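The static and future slots can be omitted by setting static_input_dim or
future_input_dim to 0; in that case pass a zero-column array or None for the
missing slot. The dynamic slot is always required (dynamic_input_dim ≥ 1).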
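To catch contract violations before they surface as opaque backend errors, you can compare array shapes against the table above. The sketch below uses a hypothetical pure-Python helper, check_input_shapes, which is not part of base_attentive. It assumes that x_future spans exactly forecast_horizon steps, as in the Quick Start.

```python
def check_input_shapes(static_shape, dynamic_shape, future_shape, *,
                       static_input_dim, dynamic_input_dim,
                       future_input_dim, forecast_horizon):
    """Return a list of contract violations; an empty list means the shapes fit.

    Hypothetical helper, not part of base_attentive. Shapes are plain tuples
    (e.g. ndarray.shape); x_future is assumed to span forecast_horizon steps.
    """
    problems = []
    if len(static_shape) != 2 or static_shape[1] != static_input_dim:
        problems.append(f"x_static: expected (batch, {static_input_dim}), got {static_shape}")
    if (len(dynamic_shape) != 3 or dynamic_shape[1] < 1
            or dynamic_shape[2] != dynamic_input_dim):
        problems.append(
            f"x_dynamic: expected (batch, T>=1, {dynamic_input_dim}), got {dynamic_shape}")
    if (len(future_shape) != 3
            or tuple(future_shape[1:]) != (forecast_horizon, future_input_dim)):
        problems.append(
            f"x_future: expected (batch, {forecast_horizon}, {future_input_dim}), "
            f"got {future_shape}")
    # every slot must share the same batch dimension
    batches = {shape[0] for shape in (static_shape, dynamic_shape, future_shape) if shape}
    if len(batches) > 1:
        problems.append(f"batch sizes differ across slots: {sorted(batches)}")
    return problems


dims = dict(static_input_dim=4, dynamic_input_dim=8,
            future_input_dim=6, forecast_horizon=24)
print(check_input_shapes((32, 4), (32, 100, 8), (32, 24, 6), **dims))  # []
print(check_input_shapes((32, 4), (32, 100, 8), (32, 12, 6), **dims))
```

For the second call, the helper reports one problem because the future slot covers 12 steps instead of 24.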
Output Shapes
# Point forecast
model = BaseAttentive(..., output_dim=2, forecast_horizon=24)
# → (batch, 24, 2)
# Quantile forecast — quantile axis is second-to-last
model = BaseAttentive(..., output_dim=2, forecast_horizon=24, quantiles=[0.1, 0.5, 0.9])
# → (batch, 24, 3, 2) ← (batch, horizon, Q, output_dim)
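The two shape rules reduce to a single conditional. The following illustrative function (not a library call) computes the shape you should expect from model.predict:

```python
def expected_output_shape(batch, forecast_horizon, output_dim, quantiles=None):
    """Prediction shape for a point or quantile forecast (illustrative helper)."""
    if quantiles:
        # the quantile axis sits between the horizon and output_dim axes
        return (batch, forecast_horizon, len(quantiles), output_dim)
    return (batch, forecast_horizon, output_dim)


print(expected_output_shape(32, 24, 2))                   # (32, 24, 2)
print(expected_output_shape(32, 24, 2, [0.1, 0.5, 0.9]))  # (32, 24, 3, 2)
```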
V2 Configuration System
v2.0.0 introduces three ways to configure a model. They can be combined: higher-precedence levels override lower ones.
Precedence (low → high)
──────────────────────
1. Built-in defaults
2. Keyword arguments (embed_dim, objective, mode, …)
3. architecture_config dict
──────────────────────
(optional) BaseAttentiveSpec + BaseAttentiveComponentSpec ← full declarative path
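The precedence order behaves like a sequence of dictionary updates in which later levels overwrite earlier ones. The sketch below is a hypothetical illustration, not the library's resolution code. The alias from the objective keyword to the encoder_type key is an assumption, based on the Path 2 example where encoder_type overrides objective.

```python
# Hypothetical sketch of: built-in defaults < keyword arguments < architecture_config.
# KEYWORD_TO_KEY is an assumed alias table, not taken from base_attentive.
KEYWORD_TO_KEY = {"objective": "encoder_type"}


def resolve(defaults, keywords, architecture_config=None):
    resolved = dict(defaults)                        # 1. built-in defaults
    for name, value in keywords.items():             # 2. keyword arguments
        resolved[KEYWORD_TO_KEY.get(name, name)] = value
    resolved.update(architecture_config or {})       # 3. architecture_config wins
    return resolved


cfg = resolve(
    defaults={"encoder_type": "hybrid", "embed_dim": 32},
    keywords={"objective": "hybrid", "embed_dim": 64},
    architecture_config={"encoder_type": "transformer"},
)
print(cfg)  # {'encoder_type': 'transformer', 'embed_dim': 64}
```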
Path 1: Keyword arguments
The simplest and most common path. Pass parameters directly to
BaseAttentive:
model = BaseAttentive(
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
# ── capacity ────────────────────
embed_dim=64,
hidden_units=128,
num_heads=8,
num_encoder_layers=3,
# ── architecture ────────────────
objective="hybrid",
scales=[1, 2, 4],
multi_scale_agg="average",
attention_levels=["cross", "hierarchical"],
# ── regularisation ──────────────
dropout_rate=0.1,
use_vsn=True,
# ── output ──────────────────────
quantiles=[0.1, 0.5, 0.9],
)
Path 2: architecture_config dict
Use this dict for structural overrides that do not have a dedicated keyword argument, or to enforce settings regardless of the keyword values:
model = BaseAttentive(
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
objective="hybrid", # step-2 keyword
architecture_config={
"encoder_type": "transformer", # step-3 override — wins
"decoder_attention_stack": ["cross", "memory"],
"feature_processing": "dense", # skip VSN
},
)
Common architecture_config keys:
Key |
Effect |
|---|---|
|
|
|
list from |
|
|
|
|
|
|
|
|
Path 3: BaseAttentiveSpec + Assembly (declarative)
The fully declarative v2 path. Build a frozen configuration object first, then assemble the model from it. No Keras imports are needed to create the spec — useful for config files, hyperparameter search, or experiment tracking systems.
from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec
from base_attentive.assembly import BaseAttentiveV2Assembly
spec = BaseAttentiveSpec(
# ── dimensions ──────────────────────────────────────────────
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
# ── capacity ────────────────────────────────────────────────
embed_dim=64,
hidden_units=128,
attention_heads=8,
dropout_rate=0.1,
activation="relu",
layer_norm_epsilon=1e-6,
# ── backend + head ──────────────────────────────────────────
backend_name="tensorflow", # "tensorflow" | "torch" | "jax"
head_type="quantile",
quantiles=(0.1, 0.5, 0.9),
# ── component overrides (all optional) ──────────────────────
components=BaseAttentiveComponentSpec(
sequence_pooling="pool.last",
temporal_encoder="encoder.temporal_self_attention",
),
)
model = BaseAttentiveV2Assembly().build(spec)
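from base_attentive.components import CRPSLoss
# head_type="quantile" adds a quantile axis to the output, so use a quantile loss
model.compile(
optimizer="adam",
loss=CRPSLoss(mode="quantile", quantiles=[0.1, 0.5, 0.9]),
)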
Serialising a spec to JSON and reloading it:
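import json
from dataclasses import asdict
from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec
from base_attentive.assembly import BaseAttentiveV2Assembly
# Save: asdict() also converts the nested component spec to a plain dict,
# which spec.__dict__ would not do
with open("my_spec.json", "w") as f:
    json.dump(asdict(spec), f, indent=2)
# Reload: rebuild the nested component spec and restore tuple fields
with open("my_spec.json") as f:
    data = json.load(f)
if data.get("components") is not None:
    data["components"] = BaseAttentiveComponentSpec(**data["components"])
if data.get("quantiles") is not None:
    data["quantiles"] = tuple(data["quantiles"])  # JSON arrays load as lists
reloaded_spec = BaseAttentiveSpec(**data)
model2 = BaseAttentiveV2Assembly().build(reloaded_spec)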
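The round-trip relies on two details that are easy to miss: asdict recurses into nested dataclasses, and JSON has no tuple type. The self-contained sketch below demonstrates both using two hypothetical stand-in classes, MiniSpec and MiniComponentSpec, which are not part of base_attentive. It assumes that BaseAttentiveSpec is a frozen dataclass, as the Path 3 description suggests.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Optional, Tuple


@dataclass(frozen=True)
class MiniComponentSpec:  # hypothetical stand-in for BaseAttentiveComponentSpec
    sequence_pooling: Optional[str] = None
    temporal_encoder: Optional[str] = None


@dataclass(frozen=True)
class MiniSpec:  # hypothetical stand-in for BaseAttentiveSpec
    forecast_horizon: int = 1
    quantiles: Optional[Tuple[float, ...]] = None
    components: MiniComponentSpec = field(default_factory=MiniComponentSpec)


def spec_to_json(spec: MiniSpec) -> str:
    # asdict() turns the nested dataclass into a dict; json writes tuples as lists
    return json.dumps(asdict(spec), indent=2)


def spec_from_json(text: str) -> MiniSpec:
    data = json.loads(text)
    data["components"] = MiniComponentSpec(**data["components"])
    if data["quantiles"] is not None:
        data["quantiles"] = tuple(data["quantiles"])  # lists back to tuples
    return MiniSpec(**data)


spec = MiniSpec(forecast_horizon=24, quantiles=(0.1, 0.5, 0.9),
                components=MiniComponentSpec(sequence_pooling="pool.last"))
assert spec_from_json(spec_to_json(spec)) == spec
```

Without the two fix-ups in spec_from_json, components would come back as a plain dict and quantiles as a list. The reloaded object would then compare unequal to the original.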
BaseAttentiveComponentSpec field reference:
Field |
Default registry key |
|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Omit any field to use the registry default.
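The fallback rule can be sketched as a per-field lookup. In the example below, the Components class and the REGISTRY_DEFAULTS placeholder keys are hypothetical; this page does not list the real default registry keys. The field names come from the examples on this page.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass(frozen=True)
class Components:  # hypothetical; field names taken from the examples on this page
    sequence_pooling: Optional[str] = None
    temporal_encoder: Optional[str] = None


# Placeholder defaults for illustration only; not the library's real keys.
REGISTRY_DEFAULTS = {
    "sequence_pooling": "default.sequence_pooling",
    "temporal_encoder": "default.temporal_encoder",
}


def resolve_components(spec: Components) -> dict:
    """Map each component slot to a registry key, falling back to the default."""
    resolved = {}
    for f in fields(spec):
        override = getattr(spec, f.name)
        resolved[f.name] = override if override is not None else REGISTRY_DEFAULTS[f.name]
    return resolved


print(resolve_components(Components(temporal_encoder="encoder.dilated_causal")))
# {'sequence_pooling': 'default.sequence_pooling', 'temporal_encoder': 'encoder.dilated_causal'}
```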
Architecture Parameters
Required parameters
| Parameter | Type | Description |
|---|---|---|
| static_input_dim | int ≥ 0 | Number of static (time-invariant) features |
| dynamic_input_dim | int ≥ 1 | Number of historical time-series features |
| future_input_dim | int ≥ 0 | Number of known future covariate features |
| output_dim | int ≥ 1 | Number of output variables (default: 1) |
| forecast_horizon | int ≥ 1 | Number of future steps to predict (default: 1) |
Capacity parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| embed_dim | int | 32 | Shared embedding dimension for all projections |
| hidden_units | int | 64 | Dense hidden layer width |
|  | int | 64 | LSTM hidden size (hybrid mode only) |
|  | int | 32 | Attention projection dimension |
| num_heads | int | 4 | Multi-head attention head count |
| num_encoder_layers | int | 2 | Stacked encoder layer count |
|  | int | 100 | Memory bank entries (memory-augmented attention) |
Encoder parameters
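| Parameter | Type | Default | Description |
|---|---|---|---|
| objective | str |  | Encoder type, e.g. "hybrid" |
| scales | list[int] or "auto" | None | Temporal scales for multi-scale LSTM, e.g. [1, 2, 4] |
| multi_scale_agg | str |  | How to merge multi-scale outputs (see the multi_scale_agg choices under Multi-Scale Aggregation) |
| final_agg | str |  | Final sequence aggregation, e.g. "last" |
|  | int | 10 | Maximum dynamic time window size |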
Attention and decoder
| Parameter | Type | Default | Description |
|---|---|---|---|
| mode | str or None | None | Mode shortcut, e.g. "tft" or "pihal" |
| attention_levels | str, list, int, None | None (all three) | Decoder attention stack: name, list, integer shortcut, or None |
| architecture_config | dict or None | None | Structural overrides (highest precedence) |
Regularisation
| Parameter | Type | Default | Description |
|---|---|---|---|
| dropout_rate | float | 0.1 | Dropout probability |
|  | str |  |  |
|  | bool | False | Apply batch normalisation |
|  | bool | True | Residual connections throughout |
Feature processing
| Parameter | Type | Default | Description |
|---|---|---|---|
| use_vsn | bool | True | Enable Variable Selection Network (VSN) |
|  | int or None | None | Override VSN projection units (falls back to …) |
|  | bool | True | Apply Dynamic Time Warping alignment |
Output
| Parameter | Type | Default | Description |
|---|---|---|---|
| quantiles | list[float] or None | None | Quantile levels, e.g. [0.1, 0.5, 0.9] |
|  | int | 0 | Logging verbosity |
Mode Shortcuts
mode wires up encoder, attention stack, and decoder in one step so you
do not need to set objective, attention_levels, etc. individually:
# TFT-like: VSN + gated residuals + cross attention
model = BaseAttentive(..., mode="tft")
# PIHALNet-like: multi-scale LSTM + memory + hierarchical attention
model = BaseAttentive(..., mode="pihal")
You can still override individual parameters after setting mode:
model = BaseAttentive(
...,
mode="tft",
embed_dim=128, # larger capacity
dropout_rate=0.2, # more regularisation
quantiles=[0.05, 0.5, 0.95],
)
Attention Level Control
The decoder attention stack is set with attention_levels, which accepts
None, a level name, a list of names, or an integer shortcut. Calls grouped
under the same comment are equivalent:
# All three levels (default): these two calls are equivalent
model = BaseAttentive(..., attention_levels=None)
model = BaseAttentive(..., attention_levels=["cross", "hierarchical", "memory"])
# Single level: these two calls are equivalent
model = BaseAttentive(..., attention_levels="cross")
model = BaseAttentive(..., attention_levels=1) # integer: 1 = cross
# Two levels
model = BaseAttentive(..., attention_levels=["cross", "memory"])
# Integer shortcut selecting one level
model = BaseAttentive(..., attention_levels=2) # 2 = hierarchical
Integer shortcuts: 1 → cross, 2 → hierarchical,
3 → memory.
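The accepted forms can be normalised into one ordered list of level names. The sketch below is hypothetical and follows the integer mapping documented above, with each integer selecting a single level. The library's own parser may handle edge cases differently.

```python
LEVELS = ("cross", "hierarchical", "memory")
INT_SHORTCUTS = {1: "cross", 2: "hierarchical", 3: "memory"}  # as documented above


def normalise_attention_levels(value):
    """Hypothetical sketch: turn any accepted form into a list of level names."""
    if value is None:
        return list(LEVELS)  # default: all three levels
    if isinstance(value, bool):  # bool is a subclass of int; reject it explicitly
        raise TypeError("attention_levels must not be a bool")
    if isinstance(value, int):
        if value not in INT_SHORTCUTS:
            raise ValueError(f"integer shortcut must be 1, 2 or 3, got {value}")
        return [INT_SHORTCUTS[value]]
    if isinstance(value, str):
        value = [value]
    levels = [str(v).lower() for v in value]
    unknown = [v for v in levels if v not in LEVELS]
    if unknown:
        raise ValueError(f"unknown attention level(s): {unknown}")
    return levels


print(normalise_attention_levels(None))  # ['cross', 'hierarchical', 'memory']
print(normalise_attention_levels(1))     # ['cross']
```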
When to use each level:
| Level | When to use it |
|---|---|
| cross | Always include; it bridges encoded history and future context |
| hierarchical | Add when the data has nested temporal patterns (daily + weekly + seasonal) |
| memory | Add for long-range dependencies or when repeated patterns need retrieval |
Multi-Scale Aggregation
scales activates the multi-scale LSTM encoder (objective="hybrid"
required). Each integer in the list defines a temporal stride:
model = BaseAttentive(
...,
objective="hybrid",
scales=[1, 2, 4], # three LSTMs at stride ×1, ×2, ×4
multi_scale_agg="average", # merge strategy
final_agg="last", # sequence → vector
)
multi_scale_agg choices:
| Value | Effect |
|---|---|
|  | Keep the final hidden state of each scale; concatenate then project |
| "average" | Average all hidden states across time; merge |
|  | Flatten the full output sequence of each scale; project |
|  | Sum hidden states element-wise across time |
|  | Concatenate all time-step outputs end-to-end |
Use scales="auto" to let the model choose scales based on the
sequence length.
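To see what a stride does before any LSTM is involved, apply each scale to a toy sequence and summarise the strided view. The sketch below is illustrative only. It assumes that a stride of n means keeping every n-th time step, replaces the per-scale LSTM with the identity, and uses aggregation names chosen for the illustration that may not match the library's option strings.

```python
def multi_scale_summaries(sequence, scales, agg="average"):
    """Summarise a 1-D sequence at several temporal strides (illustrative only).

    Stand-in for a multi-scale encoder: instead of running one LSTM per
    scale, each strided view is reduced directly so the effect of the
    stride and of the aggregation rule is easy to inspect.
    """
    summaries = []
    for stride in scales:
        view = sequence[::stride]  # assumed stride semantics: every stride-th step
        if agg == "last":
            summaries.append(view[-1])
        elif agg == "average":
            summaries.append(sum(view) / len(view))
        elif agg == "sum":
            summaries.append(sum(view))
        else:
            raise ValueError(f"unsupported agg: {agg!r}")
    return summaries


seq = [1, 2, 3, 4, 5, 6, 7, 8]
print(multi_scale_summaries(seq, [1, 2, 4], agg="last"))     # [8, 7, 5]
print(multi_scale_summaries(seq, [1, 2, 4], agg="average"))  # [4.5, 4.0, 3.0]
```

Coarser strides see fewer, more widely spaced steps, so their summaries respond to slower trends.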
Training Patterns
Standard compile/fit
import keras
import numpy as np
from base_attentive import BaseAttentive
# x_static, x_dynamic, x_future and y are the arrays from the Quick Start
model = BaseAttentive(
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
# point forecast (no quantiles), so loss="mse" matches the (batch, 24, 1) target
embed_dim=64,
num_heads=8,
)
model.compile(
optimizer="adam",
loss="mse",
metrics=["mae"],
)
history = model.fit(
[x_static, x_dynamic, x_future],
y,
epochs=50,
batch_size=64,
validation_split=0.2,
callbacks=[
keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),
keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3),
],
)
Using CRPSLoss
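CRPSLoss supports three modes for probabilistic training. Choose the one that matches the model's output: quantile levels, a single Gaussian, or a Gaussian mixture.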
from base_attentive.components import CRPSLoss
# Quantile CRPS (pinball approximation)
model.compile(
optimizer="adam",
loss=CRPSLoss(mode="quantile", quantiles=[0.1, 0.5, 0.9]),
)
# Gaussian closed-form CRPS
model.compile(
optimizer="adam",
loss=CRPSLoss(mode="gaussian"),
)
# Gaussian Mixture Monte-Carlo CRPS
model.compile(
optimizer="adam",
loss=CRPSLoss(mode="mixture", mc_samples=50),
)
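The three modes are named after standard scoring rules. The pure-Python functions below compute those rules for single observations so you can check numbers by hand. They are independent sketches, not CRPSLoss's implementation, which may differ in details such as how quantiles are averaged or how samples are drawn.

```python
import math
import random


def pinball_loss(y, y_hat, q):
    """Quantile (pinball) loss for one observation at quantile level q."""
    e = y - y_hat
    return max(q * e, (q - 1) * e)


def gaussian_crps(y, mu, sigma):
    """Closed-form CRPS of N(mu, sigma^2) evaluated at observation y."""
    z = (y - mu) / sigma
    pdf = math.exp(-0.5 * z * z) / math.sqrt(2 * math.pi)
    cdf = 0.5 * (1 + math.erf(z / math.sqrt(2)))
    return sigma * (z * (2 * cdf - 1) + 2 * pdf - 1 / math.sqrt(math.pi))


def mixture_crps_mc(y, weights, mus, sigmas, n_samples=20000, seed=0):
    """Monte-Carlo CRPS = E|X - y| - 0.5 * E|X - X'| for a Gaussian mixture."""
    rng = random.Random(seed)

    def draw():
        k = rng.choices(range(len(weights)), weights=weights)[0]
        return rng.gauss(mus[k], sigmas[k])

    xs = [draw() for _ in range(n_samples)]
    xs_prime = [draw() for _ in range(n_samples)]  # independent copies for E|X - X'|
    term1 = sum(abs(x - y) for x in xs) / n_samples
    term2 = sum(abs(a - b) for a, b in zip(xs, xs_prime)) / n_samples
    return term1 - 0.5 * term2


print(pinball_loss(1.0, 0.0, 0.9))             # 0.9
print(round(gaussian_crps(0.0, 0.0, 1.0), 4))  # 0.2337
```

A one-component mixture is just a Gaussian, so mixture_crps_mc with weights=[1.0] should land close to gaussian_crps. This makes a quick consistency check for the Monte-Carlo estimator.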
Custom training loop
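When you need per-step control (e.g. physics constraints, gradient clipping, or multi-task losses), write the loop yourself. GradientTape is TensorFlow-specific, so this example assumes the TensorFlow backend:
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # GradientTape requires the TF backend
import keras
import tensorflow as tf
from base_attentive import BaseAttentive
model = BaseAttentive(
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
)
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.MeanSquaredError()
# x_static, x_dynamic, x_future and y are the arrays from the Quick Start
dataset = tf.data.Dataset.from_tensor_slices(
((x_static, x_dynamic, x_future), y)
).batch(64)
@tf.function
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        preds = model(list(x_batch), training=True)
        loss = loss_fn(y_batch, preds)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
for epoch in range(50):
    for x_batch, y_batch in dataset:
        loss = train_step(x_batch, y_batch)
    print(f"epoch {epoch + 1}: loss={float(loss):.4f}")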
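Gradient clipping, one of the motivations for a custom loop, usually means clipping by global norm. Keras 3 optimizers already accept clipnorm and global_clipnorm arguments (for example keras.optimizers.Adam(learning_rate=1e-3, global_clipnorm=1.0)), and TensorFlow provides tf.clip_by_global_norm for use between tape.gradient and apply_gradients. The backend-free sketch below shows what that rescaling computes, using lists of floats in place of tensors:

```python
import math


def clip_by_global_norm(grads, clip_norm):
    """Scale gradient vectors so their joint L2 norm is at most clip_norm.

    Returns (clipped_grads, global_norm_before_clipping).
    """
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if global_norm <= clip_norm or global_norm == 0.0:
        return [list(grad) for grad in grads], global_norm
    scale = clip_norm / global_norm  # one factor for all tensors keeps their direction
    return [[g * scale for g in grad] for grad in grads], global_norm


clipped, norm = clip_by_global_norm([[3.0], [4.0]], clip_norm=1.0)
print(norm)  # 5.0
```

A single shared scale factor preserves the direction of the full update vector. Clipping each tensor separately would not.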
Selecting a Backend in V2
Set the backend before importing BaseAttentive or Keras:
import os
os.environ["KERAS_BACKEND"] = "torch" # or "tensorflow" / "jax"
from base_attentive import BaseAttentive
model = BaseAttentive(...)
Or, using the runtime API (also before any model creation):
from base_attentive import set_backend
set_backend("jax")
from base_attentive import BaseAttentive
model = BaseAttentive(...)
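The ordering matters because Keras fixes its backend when keras is first imported. The hypothetical helper below (choose_keras_backend is not part of base_attentive) sets KERAS_BACKEND and refuses to switch once Keras is already loaded with a different backend. It uses keras.backend.backend() only when Keras is already in sys.modules.

```python
import os
import sys

SUPPORTED_BACKENDS = ("tensorflow", "torch", "jax")


def choose_keras_backend(name):
    """Hypothetical helper: set KERAS_BACKEND, or fail if it is too late to change."""
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"backend must be one of {SUPPORTED_BACKENDS}, got {name!r}")
    keras = sys.modules.get("keras")
    if keras is not None:
        # Keras is already loaded; its backend can no longer change in this process
        current = keras.backend.backend()
        if current != name:
            raise RuntimeError(
                f"Keras is already imported with backend {current!r}; "
                "set the backend before the first import or restart the process")
        return current
    os.environ["KERAS_BACKEND"] = name
    return name


choose_keras_backend("jax")  # call this before importing keras or base_attentive
```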
When building via BaseAttentiveSpec, declare the backend in the spec:
spec = BaseAttentiveSpec(
...,
backend_name="torch", # "tensorflow" | "torch" | "jax"
)
model = BaseAttentiveV2Assembly().build(spec)
See Backend Guide for installation details and device management.
Registering Custom Components
V2 makes every layer pluggable through the component registry. You can
replace any part of the model without subclassing BaseAttentive.
Step 1 — write the builder function
A builder receives a context (the BaseAttentiveSpec) plus any kwargs
the spec declares for that component:
from base_attentive.registry import DEFAULT_COMPONENT_REGISTRY
def my_dilated_encoder(*, context, units, hidden_units, **kw):
    """
    WaveNet-style dilated causal encoder.
    context: BaseAttentiveSpec gives dropout_rate, embed_dim, etc.
    """
    import keras
    inputs = keras.Input(shape=(None, units))
    x = inputs
    for rate in [1, 2, 4, 8]:
        x = keras.layers.Conv1D(
            filters=hidden_units,
            kernel_size=2,
            dilation_rate=rate,
            padding="causal",
            activation="relu",
        )(x)
        x = keras.layers.Dropout(context.dropout_rate)(x)
    return keras.Model(inputs, x, name="dilated_encoder")
DEFAULT_COMPONENT_REGISTRY.register(
"encoder.dilated_causal",
my_dilated_encoder,
backend="generic",
description="WaveNet-style dilated causal encoder.",
)
Step 2 — reference the key in a spec
from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec
from base_attentive.assembly import BaseAttentiveV2Assembly
spec = BaseAttentiveSpec(
static_input_dim=4,
dynamic_input_dim=8,
future_input_dim=6,
output_dim=1,
forecast_horizon=24,
embed_dim=64,
hidden_units=128,
components=BaseAttentiveComponentSpec(
temporal_encoder="encoder.dilated_causal",
),
)
model = BaseAttentiveV2Assembly().build(spec)
Step 3 — inspect the registry
from base_attentive.registry import DEFAULT_COMPONENT_REGISTRY
print(DEFAULT_COMPONENT_REGISTRY.list_keys())
# ['projection.static', 'projection.dynamic', ..., 'encoder.dilated_causal']
info = DEFAULT_COMPONENT_REGISTRY.get_info("encoder.dilated_causal")
print(info["description"]) # WaveNet-style dilated causal encoder.
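Underneath, a component registry is a mapping from string keys to builder functions plus some metadata. The minimal stand-in below (MiniRegistry is hypothetical and much simpler than DEFAULT_COMPONENT_REGISTRY) mirrors the calls used above so the lookup pattern is visible. A plain dict plays the role of the BaseAttentiveSpec context.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass(frozen=True)
class _Entry:
    builder: Callable
    backend: str
    description: str


class MiniRegistry:
    """Hypothetical stand-in showing the key -> builder lookup pattern."""

    def __init__(self):
        self._entries: Dict[str, _Entry] = {}

    def register(self, key, builder, *, backend="generic", description=""):
        if key in self._entries:
            raise KeyError(f"{key!r} is already registered")
        self._entries[key] = _Entry(builder, backend, description)

    def list_keys(self) -> List[str]:
        return sorted(self._entries)

    def get_info(self, key) -> dict:
        entry = self._entries[key]
        return {"backend": entry.backend, "description": entry.description}

    def build(self, key, *, context, **kwargs):
        # the builder receives the spec as `context` plus component-specific kwargs
        return self._entries[key].builder(context=context, **kwargs)


registry = MiniRegistry()
registry.register(
    "encoder.toy",
    lambda *, context, units: {"units": units, "dropout": context["dropout_rate"]},
    description="Toy encoder for illustration.",
)
print(registry.list_keys())                              # ['encoder.toy']
print(registry.get_info("encoder.toy")["description"])   # Toy encoder for illustration.
print(registry.build("encoder.toy", context={"dropout_rate": 0.1}, units=8))
# {'units': 8, 'dropout': 0.1}
```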
Using BaseAttentive as a Keras Kernel
Wrapper pattern (recommended)
Wrap BaseAttentive when you want to add extra heads, safety logic, or
fusion with other models, while keeping the serialised API separate:
import keras
from base_attentive import BaseAttentive
class ResidualForecastModel(keras.Model):
    """BaseAttentive kernel + learnable residual correction."""
    def __init__(self, forecast_horizon=24, output_dim=1):
        super().__init__()
        self.kernel = BaseAttentive(
            static_input_dim=4,
            dynamic_input_dim=8,
            future_input_dim=6,
            output_dim=output_dim,
            forecast_horizon=forecast_horizon,
            mode="tft",
        )
        self.context_pool = keras.layers.GlobalAveragePooling1D()
        self.residual_head = keras.Sequential([
            keras.layers.Dense(64, activation="relu"),
            keras.layers.RepeatVector(forecast_horizon),
            keras.layers.Dense(output_dim),
        ])
        self.gate = keras.layers.Dense(output_dim, activation="sigmoid")
    def call(self, inputs, training=False):
        static_x, dynamic_x, future_x = inputs
        base_forecast = self.kernel(inputs, training=training)
        # (batch, H, output_dim)
        context = self.context_pool(dynamic_x)  # (batch, D)
        residual = self.residual_head(context)  # (batch, H, output_dim)
        gate = keras.ops.expand_dims(
            self.gate(context), axis=1
        )  # (batch, 1, output_dim)
        return base_forecast + gate * residual
Direct inheritance (advanced)
Inherit when the new behaviour should be part of the same serialised model
(same get_config / from_config lifecycle):
from base_attentive import BaseAttentive
import keras
# Registration lets keras.saving.load_model find this class when reloading
@keras.saving.register_keras_serializable(package="custom")
class PhysicsAwareModel(BaseAttentive):
    def __init__(self, physics_weight=0.1, **kwargs):
        super().__init__(**kwargs)
        self.physics_weight = physics_weight
        self.correction_head = keras.layers.Dense(self.output_dim)
    def call(self, inputs, training=False):
        base_forecast = super().call(inputs, training=training)
        _, _, future_x = inputs
        correction = self.correction_head(
            keras.ops.mean(future_x, axis=1)
        )  # (batch, output_dim)
        correction = keras.ops.expand_dims(correction, axis=1)  # (batch, 1, output_dim)
        return base_forecast + self.physics_weight * correction
    def get_config(self):
        cfg = super().get_config()
        cfg["physics_weight"] = self.physics_weight
        return cfg
When to choose which path:
Wrapper: extra outputs, multi-task heads, or downstream logic that you do not want serialised together with the forecasting model.
Inheritance: the new behaviour must survive from_config round-trips and is inseparable from the forecasting logic.
Serialisation and Reconfiguration
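import keras
from base_attentive import BaseAttentive
# model is any BaseAttentive instance, e.g. the one from the Quick Start
# Save / restore config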
config = model.get_config()
cloned = BaseAttentive.from_config(config)
# Reconfigure without mutating the original
transformer_variant = model.reconfigure({"encoder_type": "transformer"})
# Full Keras save (weights + config)
model.save("my_model.keras")
loaded = keras.saving.load_model("my_model.keras")
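The "without mutating the original" guarantee comes down to copying the config before applying overrides. The sketch below shows one way such a method could work on a plain config dict. It is an assumption, since this page does not show how reconfigure is implemented. In practice the resulting dict would then be passed to BaseAttentive.from_config.

```python
import copy


def reconfigure(config, overrides):
    """Hypothetical sketch: return a new config with architecture overrides applied."""
    new_config = copy.deepcopy(config)  # never touch the caller's dict
    arch = dict(new_config.get("architecture_config") or {})
    arch.update(overrides)
    new_config["architecture_config"] = arch
    return new_config


base = {"embed_dim": 64, "architecture_config": {"encoder_type": "hybrid"}}
variant = reconfigure(base, {"encoder_type": "transformer"})
print(base["architecture_config"])     # {'encoder_type': 'hybrid'}
print(variant["architecture_config"])  # {'encoder_type': 'transformer'}
```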
Validation Helpers
from base_attentive.validation import (
validate_model_inputs,
maybe_reduce_quantiles_bh,
ensure_bh1,
)
# Check shapes and return normalised tensors
static, dynamic, future = validate_model_inputs(
[x_static, x_dynamic, x_future],
static_input_dim=4,
dynamic_input_dim=8,
verbose=1,
)
# Reduce quantile output to (batch, horizon, output_dim)
point_pred = maybe_reduce_quantiles_bh(predictions)
# Ensure (batch, horizon, 1) shape
reshaped = ensure_bh1(predictions)
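One common way to reduce a quantile forecast to a point forecast is to keep the median slice. The sketch below does this on nested lists with a hypothetical select_quantile function. maybe_reduce_quantiles_bh may use a different rule (for example averaging), so treat this as an illustration of the shape change only.

```python
def select_quantile(preds, quantiles, level=0.5):
    """Reduce nested (batch, horizon, Q, output_dim) lists to
    (batch, horizon, output_dim) by keeping the quantile closest to `level`.

    Hypothetical sketch; maybe_reduce_quantiles_bh may use a different rule.
    """
    if len(quantiles) == 0:
        raise ValueError("quantiles must not be empty")
    q_idx = min(range(len(quantiles)), key=lambda i: abs(quantiles[i] - level))
    return [[step[q_idx] for step in sample] for sample in preds]


# batch=1, horizon=2, Q=3, output_dim=1
preds = [[[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]]
print(select_quantile(preds, [0.1, 0.5, 0.9]))  # [[[2.0], [5.0]]]
```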
Accelerated Inference (TensorFlow)
Wrap repeated inference with make_fast_predict_fn to compile it once
with tf.function:
from base_attentive import make_fast_predict_fn
import numpy as np
fast_predict = make_fast_predict_fn(
model,
warmup_inputs=[x_static, x_dynamic, x_future],
)
predictions = fast_predict([x_static, x_dynamic, x_future])
Keep input shapes stable across calls for best tracing performance. For
training, model.compile(..., jit_compile="auto") may also accelerate
TensorFlow graphs.
Runtime Backend Utilities
from base_attentive import (
get_backend,
get_available_backends,
detect_available_backends,
select_best_backend,
)
print(get_backend()) # "tensorflow"
print(get_available_backends()) # ["tensorflow", "torch"]
info = detect_available_backends()
for name, details in info.items():
print(f"{name}: {details.get('version')}")
best = select_best_backend(require_supported=True)
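If you want to know which backends are installed without importing any of them (and without depending on base_attentive), the standard library is enough. In the sketch below, detect_installed is a hypothetical helper. It assumes the distribution name matches the module name, which holds for the default torch and jax packages but not for variants such as tensorflow-cpu, which are reported with version None.

```python
import importlib.metadata
import importlib.util

CANDIDATES = {"tensorflow": "tensorflow", "torch": "torch", "jax": "jax"}


def detect_installed(candidates=CANDIDATES):
    """Return {backend: version or None} for modules importable here.

    find_spec locates a top-level module without importing it, so this
    check stays cheap even when heavy frameworks are installed.
    """
    found = {}
    for backend, module in candidates.items():
        if importlib.util.find_spec(module) is None:
            continue
        try:
            found[backend] = importlib.metadata.version(module)
        except importlib.metadata.PackageNotFoundError:
            found[backend] = None  # importable, but no matching distribution metadata
    return found


print(detect_installed())
```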
Complete Parameter Reference
For the full list of every accepted keyword and its type, range, and default value, see Configuration Guide.
See Also
Architecture Guide — How the v2 registry / assembly system works
Configuration Guide — Full parameter reference
Applications and Use Cases — Domain-specific usage patterns
Backend Guide — Backend selection and device management
API Reference — Complete API docs