API Reference
Public package entry points
from base_attentive import BaseAttentive
from base_attentive import make_fast_predict_fn
from base_attentive import get_backend, set_backend
from base_attentive import (
    get_available_backends,
    get_backend_capabilities,
    normalize_backend_name,
    detect_available_backends,
    select_best_backend,
    ensure_default_backend,
)
from base_attentive.validation import (
    validate_model_inputs,
    maybe_reduce_quantiles_bh,
    ensure_bh1,
)
from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec
from base_attentive.registry import (
    ComponentRegistry, ModelRegistry,
    DEFAULT_COMPONENT_REGISTRY, DEFAULT_MODEL_REGISTRY,
)
Top-level package
Public package surface for base_attentive.
- class base_attentive.BaseAttentive(static_dim=None, dynamic_dim=None, future_dim=None, output_dim=1, forecast_horizon=1, mode=None, num_encoder_layers=2, quantiles=None, embed_dim=32, hidden_units=64, lstm_units=64, attention_units=32, num_heads=4, dropout_rate=0.1, lookback_window=10, memory_size=100, scales=None, multi_scale_agg='last', final_agg='last', activation='relu', use_residuals=True, use_vsn=True, vsn_units=None, use_batch_norm=False, apply_dtw=True, attention_stack=None, objective='hybrid', architecture_config=None, backend_name=None, component_overrides=None, verbose=0, output_mode=None, n_quantiles=None, name='BaseAttentive', *, static_input_dim=None, dynamic_input_dim=None, future_input_dim=None, max_window_size=None, attention_levels=None, **kwargs)[source]
Bases:
BaseAttentiveV2, NNLearner
Compatibility wrapper over the resolver-driven V2 model.
The facade keeps the legacy constructor signature intact, converts the payload into BaseAttentiveSpec, and delegates the actual model assembly and execution to BaseAttentiveV2.
- Parameters:
static_dim (int | None)
dynamic_dim (int | None)
future_dim (int | None)
output_dim (int)
forecast_horizon (int)
mode (str | None)
num_encoder_layers (int)
embed_dim (int)
hidden_units (int)
attention_units (int)
num_heads (int)
dropout_rate (float)
lookback_window (int)
memory_size (int)
multi_scale_agg (str)
final_agg (str)
activation (str)
use_residuals (bool)
use_vsn (bool)
vsn_units (int | None)
use_batch_norm (bool)
apply_dtw (bool)
objective (str)
backend_name (str | None)
verbose (int)
output_mode (str | None)
n_quantiles (int | None)
name (str)
static_input_dim (int | None)
dynamic_input_dim (int | None)
future_input_dim (int | None)
max_window_size (int | None)
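The facade idea described above (legacy keywords in, an immutable spec out) can be sketched in plain Python. This is only an illustration: SpecSketch and legacy_to_spec are hypothetical names, the spec is cut down to three fields, and both the precedence rule (the V2 keyword wins over the legacy one) and the error on a missing value are assumptions, not documented behavior.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for BaseAttentiveSpec (three fields only)."""
    static_input_dim: int
    dynamic_input_dim: int
    future_input_dim: int


def legacy_to_spec(static_dim=None, dynamic_dim=None, future_dim=None, *,
                   static_input_dim=None, dynamic_input_dim=None,
                   future_input_dim=None):
    """Map legacy keyword names onto the V2-style spec fields."""
    def pick(new, old, label):
        # Assumption: the V2 keyword takes precedence over the legacy one.
        value = new if new is not None else old
        if value is None:
            # Assumption: a missing dimension is an error in this sketch.
            raise ValueError(f"{label} must be provided")
        return int(value)

    return SpecSketch(
        static_input_dim=pick(static_input_dim, static_dim, "static_input_dim"),
        dynamic_input_dim=pick(dynamic_input_dim, dynamic_dim, "dynamic_input_dim"),
        future_input_dim=pick(future_input_dim, future_dim, "future_input_dim"),
    )


# Legacy and V2 keyword names can be mixed in one call.
spec = legacy_to_spec(static_dim=4, dynamic_dim=8, future_input_dim=6)
```

The resulting object is immutable, so whatever assembles the model can rely on the configuration not changing after construction.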
- base_attentive.dependency_message(module_name)[source]
Return a dependency hint for missing runtime packages.
- base_attentive.make_fast_predict_fn(model, *, jit_compile=True, reduce_retracing=True, warmup_inputs=None)[source]
Create a TensorFlow-traced prediction function for a Keras model.
The returned callable accepts the same input structure as model and always executes with training=False. This is useful when you want a reusable inference function with tf.function tracing and optional XLA compilation.
- Parameters:
model (Any) – A Keras-compatible model or layer that can be called as model(inputs, training=False).
jit_compile (bool, default=True) – Whether to request XLA JIT compilation for the traced prediction function.
reduce_retracing (bool, default=True) – Whether TensorFlow should reduce retracing when input structures are reused.
warmup_inputs (Any, optional) – Example inputs used to trigger tracing before the callable is returned.
- Returns:
A TensorFlow tf.function-wrapped prediction callable.
- Return type:
callable
- Raises:
RuntimeError – If the active package backend is not TensorFlow.
ImportError – If TensorFlow cannot be imported.
Core model
- class base_attentive.core.base_attentive.BaseAttentive(static_dim=None, dynamic_dim=None, future_dim=None, output_dim=1, forecast_horizon=1, mode=None, num_encoder_layers=2, quantiles=None, embed_dim=32, hidden_units=64, lstm_units=64, attention_units=32, num_heads=4, dropout_rate=0.1, lookback_window=10, memory_size=100, scales=None, multi_scale_agg='last', final_agg='last', activation='relu', use_residuals=True, use_vsn=True, vsn_units=None, use_batch_norm=False, apply_dtw=True, attention_stack=None, objective='hybrid', architecture_config=None, backend_name=None, component_overrides=None, verbose=0, output_mode=None, n_quantiles=None, name='BaseAttentive', *, static_input_dim=None, dynamic_input_dim=None, future_input_dim=None, max_window_size=None, attention_levels=None, **kwargs)[source]
Bases:
BaseAttentiveV2, NNLearner
Compatibility wrapper over the resolver-driven V2 model.
The facade keeps the legacy constructor signature intact, converts the payload into BaseAttentiveSpec, and delegates the actual model assembly and execution to BaseAttentiveV2.
- Parameters:
static_dim (int | None)
dynamic_dim (int | None)
future_dim (int | None)
output_dim (int)
forecast_horizon (int)
mode (str | None)
num_encoder_layers (int)
embed_dim (int)
hidden_units (int)
attention_units (int)
num_heads (int)
dropout_rate (float)
lookback_window (int)
memory_size (int)
multi_scale_agg (str)
final_agg (str)
activation (str)
use_residuals (bool)
use_vsn (bool)
vsn_units (int | None)
use_batch_norm (bool)
apply_dtw (bool)
objective (str)
backend_name (str | None)
verbose (int)
output_mode (str | None)
n_quantiles (int | None)
name (str)
static_input_dim (int | None)
dynamic_input_dim (int | None)
future_input_dim (int | None)
max_window_size (int | None)
V2 Configuration Schema
BaseAttentiveSpec is a frozen dataclass that fully describes a V2 model
without referencing any backend objects.
- class base_attentive.config.schema.BaseAttentiveSpec(static_input_dim, dynamic_input_dim, future_input_dim, output_dim=1, forecast_horizon=1, embed_dim=32, hidden_units=64, attention_heads=4, layer_norm_epsilon=1e-06, dropout_rate=0.0, activation='relu', backend_name='tensorflow', head_type='point', quantiles=(), lstm_units=64, attention_units=32, vsn_units=None, architecture=<factory>, runtime=<factory>, components=<factory>, extras=<factory>)[source]
Bases:
object
Backend-neutral configuration for BaseAttentive models.
- Parameters:
static_input_dim (int)
dynamic_input_dim (int)
future_input_dim (int)
output_dim (int)
forecast_horizon (int)
embed_dim (int)
hidden_units (int)
attention_heads (int)
layer_norm_epsilon (float)
dropout_rate (float)
activation (str)
backend_name (str)
head_type (str)
attention_units (int)
vsn_units (int | None)
architecture (BaseAttentiveArchitectureSpec)
runtime (BaseAttentiveRuntimeSpec)
components (BaseAttentiveComponentSpec)
- architecture: BaseAttentiveArchitectureSpec
- runtime: BaseAttentiveRuntimeSpec
- components: BaseAttentiveComponentSpec
- class base_attentive.config.schema.BaseAttentiveComponentSpec(static_projection='projection.static', dynamic_projection='projection.dynamic', future_projection='projection.future', dynamic_encoder='encoder.temporal_self_attention', future_encoder='encoder.temporal_self_attention', sequence_pooling='pool.mean', fusion='fusion.concat', hidden_projection='projection.hidden', point_head='head.point_forecast', quantile_head='head.quantile_forecast', static_processor='feature.static_processor', dynamic_processor='feature.dynamic_processor', future_processor='feature.future_processor', positional_encoder='embedding.positional', hybrid_encoder='encoder.hybrid_multiscale', dynamic_window='encoder.dynamic_window', decoder_cross_attention='decoder.cross_attention', decoder_hierarchical_attention='decoder.hierarchical_attention', decoder_memory_attention='decoder.memory_attention', decoder_fusion='fusion.multi_resolution_attention', multi_horizon_head='head.multi_horizon', quantile_distribution_head='head.quantile_distribution', final_pool_last='pool.final_last', final_pool_mean='pool.final_mean', final_pool_flatten='pool.final_flatten')[source]
Bases:
object
Logical component selections for resolver-driven models.
- Parameters:
static_projection (str)
dynamic_projection (str)
future_projection (str)
dynamic_encoder (str)
future_encoder (str)
sequence_pooling (str)
fusion (str)
hidden_projection (str)
point_head (str)
quantile_head (str)
static_processor (str)
dynamic_processor (str)
future_processor (str)
positional_encoder (str)
hybrid_encoder (str)
dynamic_window (str)
decoder_cross_attention (str)
decoder_hierarchical_attention (str)
decoder_memory_attention (str)
decoder_fusion (str)
multi_horizon_head (str)
quantile_distribution_head (str)
final_pool_last (str)
final_pool_mean (str)
final_pool_flatten (str)
Example:
from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec
spec = BaseAttentiveSpec(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
    embed_dim=32,
    hidden_units=64,
    attention_heads=4,
    dropout_rate=0.1,
    head_type="point",
    backend_name="tensorflow",
    components=BaseAttentiveComponentSpec(
        sequence_pooling="pool.last",
    ),
)
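Because the spec is described as a frozen dataclass, instances cannot be modified in place; a variant is derived by copying. The sketch below shows that behavior with the standard dataclasses module only. SpecSketch and ComponentsSketch are hypothetical, much-reduced stand-ins, and the "quantile" value for head_type is an assumption (the signature above only shows 'point').

```python
from dataclasses import FrozenInstanceError, dataclass, field, replace


@dataclass(frozen=True)
class ComponentsSketch:
    """Hypothetical stand-in for BaseAttentiveComponentSpec."""
    sequence_pooling: str = "pool.mean"


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for BaseAttentiveSpec."""
    dynamic_input_dim: int
    head_type: str = "point"
    quantiles: tuple = ()
    # Nested specs use a factory, like the <factory> defaults in the signature.
    components: ComponentsSketch = field(default_factory=ComponentsSketch)


base = SpecSketch(dynamic_input_dim=8)

# Direct assignment on a frozen dataclass raises FrozenInstanceError.
try:
    base.head_type = "quantile"
    mutated = True
except FrozenInstanceError:
    mutated = False

# dataclasses.replace builds a modified copy and leaves the original untouched.
# "quantile" is an assumed head_type value, used here for illustration only.
quantile_spec = replace(base, head_type="quantile", quantiles=(0.1, 0.5, 0.9))
```

Copy-on-change keeps one spec safe to share between several model instances.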
Registry
The registry stores named builder functions (for components) and assembler
functions (for complete models), keyed by (name, backend).
- class base_attentive.registry.ComponentRegistry[source]
Bases:
object
Registry of backend-specific component builders.
- register(key, builder, *, backend='generic', description='', experimental=False, replace=False)[source]
- class base_attentive.registry.ModelRegistry[source]
Bases:
object
Registry of backend-specific model assemblers.
Pre-populated default registries:
from base_attentive.registry import (
    DEFAULT_COMPONENT_REGISTRY,
    DEFAULT_MODEL_REGISTRY,
)
DEFAULT_COMPONENT_REGISTRY.list_registered()
DEFAULT_COMPONENT_REGISTRY.has("encoder.temporal_self_attention", backend="generic")
builder = DEFAULT_COMPONENT_REGISTRY.resolve(
    "encoder.temporal_self_attention", backend="generic",
)
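To make the (name, backend) keying concrete, here is a minimal registry sketch in plain Python. RegistrySketch is a hypothetical class that mirrors the method names used above (register, has, resolve, list_registered); its internals and error types are assumptions and may differ from the real ComponentRegistry.

```python
class RegistrySketch:
    """Hypothetical registry mapping (key, backend) pairs to builders."""

    def __init__(self):
        self._builders = {}

    def register(self, key, builder, *, backend="generic", replace=False):
        slot = (key, backend)
        # Refuse to silently overwrite an entry unless replace=True.
        if slot in self._builders and not replace:
            raise KeyError(f"{slot!r} is already registered")
        self._builders[slot] = builder

    def has(self, key, *, backend="generic"):
        return (key, backend) in self._builders

    def resolve(self, key, *, backend="generic"):
        # Raises KeyError when the pair is unknown (an assumption here).
        return self._builders[(key, backend)]

    def list_registered(self):
        return sorted(self._builders)


registry = RegistrySketch()
registry.register("pool.mean", lambda values: sum(values) / len(values))
pool = registry.resolve("pool.mean")
```

Keying on the pair lets the same logical name, such as "pool.mean", carry a different builder for each backend.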
Resolver / Assembly
- base_attentive.resolver.component_resolver.build_component(key, *, backend_context, registry=None, model_registry=None, allow_generic=True, spec=None, **kwargs)[source]
Resolve and build a component for the requested backend.
- Parameters:
key (str)
backend_context (BackendContext)
registry (ComponentRegistry | None)
model_registry (ModelRegistry | None)
allow_generic (bool)
spec (Any | None)
kwargs (Any)
- Return type:
- class base_attentive.resolver.assembly.BaseAttentiveV2Assembly(backend_context, static_projection, dynamic_projection, future_projection, dynamic_encoder, future_encoder, sequence_pool, fusion, hidden_projection, output_head, dropout=None, static_processor=None, dynamic_processor=None, future_processor=None, encoder_positional_encoding=None, future_positional_encoding=None, dynamic_window=None, decoder_input_projection=None, decoder_cross_attention=None, decoder_cross_postprocess=None, decoder_hierarchical_attention=None, decoder_memory_attention=None, decoder_fusion=None, residual_projection=None, decoder_residual_add=None, decoder_residual_norm=None, final_residual_add=None, final_residual_norm=None, final_pool=None, multi_horizon_head=None, quantile_distribution_head=None)[source]
Bases:
object
Resolved V2 model components.
The assembly keeps the original V2 field names for compatibility, while also exposing migrated component names used by the legacy-to-resolver rewrite.
- Parameters:
backend_context (BackendContext)
static_projection (Any | None)
dynamic_projection (Any)
future_projection (Any | None)
dynamic_encoder (Any | None)
future_encoder (Any | None)
sequence_pool (Any)
fusion (Any)
hidden_projection (Any)
output_head (Any)
dropout (Any | None)
static_processor (Any | None)
dynamic_processor (Any | None)
future_processor (Any | None)
encoder_positional_encoding (Any | None)
future_positional_encoding (Any | None)
dynamic_window (Any | None)
decoder_input_projection (Any | None)
decoder_cross_attention (Any | None)
decoder_cross_postprocess (Any | None)
decoder_hierarchical_attention (Any | None)
decoder_memory_attention (Any | None)
decoder_fusion (Any | None)
residual_projection (Any | None)
decoder_residual_add (Any | None)
decoder_residual_norm (Any | None)
final_residual_add (Any | None)
final_residual_norm (Any | None)
final_pool (Any | None)
multi_horizon_head (Any | None)
quantile_distribution_head (Any | None)
- backend_context: BackendContext
Backend utilities
Lazy backend runtime abstraction for Base-Attentive.
This package exposes backend selection, capability inspection, and helper utilities without importing all backend implementations eagerly.
- class base_attentive.backend.Backend(load_runtime=True)[source]
Bases:
object
Base class for runtime backend descriptors.
- Parameters:
load_runtime (bool)
- class base_attentive.backend.TensorFlowBackend(load_runtime=True)[source]
Bases:
Backend
TensorFlow-backed runtime.
- Parameters:
load_runtime (bool)
- class base_attentive.backend.JaxBackend(load_runtime=True)[source]
Bases:
Backend
Keras-on-JAX runtime descriptor.
- Parameters:
load_runtime (bool)
- class base_attentive.backend.TorchBackend(load_runtime=True)[source]
Bases:
Backend
Keras-on-Torch runtime descriptor.
- Parameters:
load_runtime (bool)
- class base_attentive.backend.PyTorchBackend(load_runtime=True)[source]
Bases:
TorchBackend
Backward-compatible alias for the Torch runtime.
- Parameters:
load_runtime (bool)
Core backend functions:
from base_attentive import get_backend, set_backend
from base_attentive import get_available_backends, get_backend_capabilities
from base_attentive import normalize_backend_name
b = get_backend()
print(b.name) # e.g. 'tensorflow'
set_backend("tensorflow")
get_available_backends() # ['tensorflow', ...]
get_backend_capabilities() # {'name': ..., 'version': ..., ...}
normalize_backend_name("tf") # -> "tensorflow"
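Name normalization of this kind is usually a lookup in an alias table. The sketch below is a guess at the shape of such a helper, not the package's implementation: normalize_name and _ALIASES are hypothetical, and only the "tf" to "tensorflow" mapping is confirmed by the example above. The "pytorch" alias is an assumption suggested by the PyTorchBackend alias class.

```python
# Hypothetical alias table; only "tf" -> "tensorflow" appears in the docs.
_ALIASES = {
    "tf": "tensorflow",
    "tensorflow": "tensorflow",
    "torch": "torch",
    "pytorch": "torch",  # assumed alias
    "jax": "jax",
}


def normalize_name(name):
    """Return a canonical backend name for a user-supplied alias."""
    key = str(name).strip().lower()
    try:
        return _ALIASES[key]
    except KeyError:
        known = ", ".join(sorted(set(_ALIASES.values())))
        raise ValueError(f"Unknown backend {name!r}; expected one of: {known}") from None


canonical = normalize_name("tf")
```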
Detection and selection:
from base_attentive import detect_available_backends, select_best_backend
from base_attentive import ensure_default_backend
info = detect_available_backends()
# {'tensorflow': {'available': True, 'version': '2.x'}, ...}
best = select_best_backend()
name = ensure_default_backend()
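A detection report shaped like the one in the comment above can be produced with the standard library alone. The following is a sketch under that assumption; detect is a hypothetical helper and the real detect_available_backends may probe backends differently (for example by importing them).

```python
from importlib import metadata, util


def detect(candidates=("tensorflow", "torch", "jax")):
    """Report which candidate packages are importable, without importing them."""
    report = {}
    for name in candidates:
        # find_spec locates a top-level module without executing it.
        available = util.find_spec(name) is not None
        version = None
        if available:
            try:
                version = metadata.version(name)
            except metadata.PackageNotFoundError:
                # Importable module whose distribution uses another name.
                version = None
        report[name] = {"available": available, "version": version}
    return report


info = detect()
```

Avoiding a real import keeps the check cheap, which matters because importing a deep-learning framework can take seconds.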
Version compatibility checking:
from base_attentive.backend import (
    check_tensorflow_compatibility,
    check_torch_compatibility,
    get_backend_version,
    version_at_least,
)
ok, msg = check_tensorflow_compatibility()
ok, msg = check_torch_compatibility()
ver = get_backend_version("tensorflow")
ok = version_at_least("2.13.0", "2.12.0")
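Version checks must compare numeric components, because plain string comparison would rank "2.9" above "2.12". The sketch below shows one way to do that; parse_version and at_least are hypothetical helpers, and the argument order (installed version first, minimum second) is an assumption read off the call above.

```python
import re


def parse_version(text):
    """Turn '2.13.0' (or '2.16.0rc1') into a comparable tuple of three ints."""
    parts = []
    for chunk in text.split(".")[:3]:
        match = re.match(r"\d+", chunk)  # keep only the leading digits
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)  # pad short versions such as '2.1'
    return tuple(parts)


def at_least(installed, minimum):
    """Assumed argument order: installed version first, required minimum second."""
    return parse_version(installed) >= parse_version(minimum)


ok = at_least("2.13.0", "2.12.0")
```

This ignores pre-release ordering, so a release candidate compares equal to its final release; a production check would more likely rely on the packaging library.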
PyTorch device utilities
- base_attentive.backend.torch_is_available()[source]
Check if PyTorch is installed and importable.
- Returns:
True if PyTorch is available.
- Return type:
- base_attentive.backend.get_torch_version()[source]
Get installed PyTorch version.
- Returns:
Version string (e.g., “2.0.1”) or None if not installed.
- Return type:
str or None
- base_attentive.backend.get_torch_device(prefer='cuda', verbose=True)[source]
Get the best available device for PyTorch computations.
- Parameters:
prefer ({'cuda', 'cpu', 'mps'}, default='cuda') – Preferred device type.
- ‘cuda’: NVIDIA GPU (with CUDA support)
- ‘cpu’: CPU
- ‘mps’: Apple Metal Performance Shaders (macOS)
verbose (bool, default=True) – Whether to log device selection info.
- Returns:
Device string for use with PyTorch (e.g., ‘cuda:0’, ‘cpu’).
- Return type:
Examples
>>> device = get_torch_device()
>>> # 'cuda:0' if available, else 'cpu'
>>> device = get_torch_device(prefer="cpu")
>>> # 'cpu' always
- class base_attentive.backend.TorchDeviceManager(prefer='cuda')[source]
Bases:
object
Utility class for managing PyTorch device selection and configuration.
Initialize device manager.
- Parameters:
prefer ({'cuda', 'cpu', 'mps'}, default='cuda') – Preferred device type.
- __init__(prefer='cuda')[source]
Initialize device manager.
- Parameters:
prefer ({'cuda', 'cpu', 'mps'}, default='cuda') – Preferred device type.
- get_available_devices()[source]
Get availability of different device types.
- Returns:
Mapping of device types to availability.
- Return type:
Example:
from base_attentive.backend import TorchDeviceManager, get_torch_device
device = get_torch_device(prefer="cuda", verbose=True)
manager = TorchDeviceManager(prefer="cuda")
print(manager.device)
print(manager.get_available_devices())
info = manager.get_device_info()
manager.clear_gpu_cache()
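The preference-with-fallback rule stated in the get_torch_device example ('cuda:0' if available, else 'cpu') can be written as a pure function. pick_device below is a hypothetical helper that takes availability flags as arguments so it runs without PyTorch; the fallback for 'mps' is assumed to follow the same rule.

```python
def pick_device(prefer="cuda", *, cuda=False, mps=False):
    """Return the preferred device if available, otherwise fall back to 'cpu'."""
    available = {"cuda": cuda, "mps": mps, "cpu": True}
    if prefer not in available:
        raise ValueError(f"prefer must be one of {sorted(available)}, got {prefer!r}")
    if not available[prefer]:
        return "cpu"  # CPU is always usable
    # The documented examples return an indexed string for CUDA.
    return "cuda:0" if prefer == "cuda" else prefer


device = pick_device("cuda", cuda=False)
```

With PyTorch installed, the two flags could be filled from torch.cuda.is_available() and torch.backends.mps.is_available().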
Learner mixin
Property and base class definitions for NN learners.
Validation helpers
Utilities for backend-agnostic tensor validation.
- base_attentive.validation.validate_model_inputs(inputs, static_input_dim=None, dynamic_input_dim=None, future_covariate_dim=None, forecast_horizon=None, error='raise', mode='strict', deep_check=None, model_name=None, verbose=0, **kwargs)[source]
Validate and homogenize input tensors for model workflows.
This entrypoint intentionally stays lightweight: it normalizes the input container shape and converts values into the active Keras runtime tensor type when a runtime is available. When no Keras runtime is configured, the raw values are returned unchanged.
None inputs are normalized to (None, None, None).
- Parameters:
- Return type:
- base_attentive.validation.maybe_reduce_quantiles_bh(x, *, name='tensor', axis=2, reduction='mean')[source]
Reduce a quantile axis when a backend tensor carries one.
- base_attentive.validation.ensure_bh1(x, *, name='tensor', dtype=None, reduce_axis=None, reduction='mean', allow_rank1=False)[source]
Ensure a tensor-like value has shape (B, H, 1).
from base_attentive.validation import (
    validate_model_inputs,
    maybe_reduce_quantiles_bh,
    ensure_bh1,
)
static, dynamic, future = validate_model_inputs(
    [x_static, x_dynamic, x_future],
    static_input_dim=4,
    dynamic_input_dim=8,
)
reduced = maybe_reduce_quantiles_bh(predictions)
reshaped = ensure_bh1(output)
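The shape bookkeeping behind the last two calls can be illustrated with nested Python lists standing in for tensors. Both helpers below are hypothetical; in particular, it is an assumption that reducing the quantile axis of a (B, H, Q) value yields (B, H), and that the (B, H, 1) step accepts either (B, H) or (B, H, 1) input.

```python
from statistics import fmean


def reduce_quantiles(x):
    """Average the last axis: (B, H, Q) nested lists -> (B, H)."""
    return [[fmean(step) for step in sample] for sample in x]


def ensure_bh1_sketch(x):
    """Accept (B, H) or (B, H, 1) nested lists and return (B, H, 1)."""
    out = []
    for sample in x:
        row = []
        for step in sample:
            if isinstance(step, (list, tuple)):
                if len(step) != 1:
                    raise ValueError("last axis must have size 1")
                row.append([float(step[0])])
            else:
                row.append([float(step)])  # add the trailing axis
        out.append(row)
    return out


# B=1 sample, H=2 horizon steps, Q=3 quantiles per step.
predictions = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]
point = ensure_bh1_sketch(reduce_quantiles(predictions))
```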
Runtime helpers
Runtime helpers for accelerated inference.
- base_attentive.runtime.make_fast_predict_fn(model, *, jit_compile=True, reduce_retracing=True, warmup_inputs=None)[source]
Create a TensorFlow-traced prediction function for a Keras model.
The returned callable accepts the same input structure as model and always executes with training=False. This is useful when you want a reusable inference function with tf.function tracing and optional XLA compilation.
- Parameters:
model (Any) – A Keras-compatible model or layer that can be called as model(inputs, training=False).
jit_compile (bool, default=True) – Whether to request XLA JIT compilation for the traced prediction function.
reduce_retracing (bool, default=True) – Whether TensorFlow should reduce retracing when input structures are reused.
warmup_inputs (Any, optional) – Example inputs used to trigger tracing before the callable is returned.
- Returns:
A TensorFlow tf.function-wrapped prediction callable.
- Return type:
callable
- Raises:
RuntimeError – If the active package backend is not TensorFlow.
ImportError – If TensorFlow cannot be imported.
from base_attentive import make_fast_predict_fn
fast_predict = make_fast_predict_fn(
    model,
    jit_compile=True,
    reduce_retracing=True,
    warmup_inputs=[x_static, x_dynamic, x_future],
)
predictions = fast_predict([x_static, x_dynamic, x_future])
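The calling contract of the returned function (inference mode only, with an optional warm-up call) can be shown without TensorFlow. The sketch below deliberately omits tf.function tracing and XLA; make_predict_fn and ToyModel are hypothetical names used to make the behavior observable.

```python
def make_predict_fn(model, *, warmup_inputs=None):
    """Return a closure that always calls the model with training=False."""
    def predict(inputs):
        return model(inputs, training=False)

    if warmup_inputs is not None:
        # One eager call up front; in the traced version this is where
        # tracing (and compilation, if requested) would be triggered.
        predict(warmup_inputs)
    return predict


class ToyModel:
    """Callable that records the training flag of every call."""

    def __init__(self):
        self.calls = []

    def __call__(self, inputs, training=False):
        self.calls.append(training)
        return [2 * value for value in inputs]


toy = ToyModel()
fast = make_predict_fn(toy, warmup_inputs=[0])
result = fast([1, 2, 3])
```

In a TensorFlow setting the inner function would presumably be wrapped with tf.function(jit_compile=..., reduce_retracing=...), so the warm-up pays the tracing cost before the first real request.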
Component utilities
from base_attentive.components.utils import (
    resolve_attn_levels,
    configure_architecture,
    resolve_fusion_mode,
)
resolve_attn_levels(None) # ['cross', 'hierarchical', 'memory']
resolve_attn_levels("cross") # ['cross']
resolve_attn_levels(1) # ['cross']
resolve_attn_levels(["cross", "memory"]) # ['cross', 'memory']
arch = configure_architecture(
    objective="hybrid",
    use_vsn=True,
    attention_levels=["cross", "hierarchical"],
)
resolve_fusion_mode(None) # 'integrated'
resolve_fusion_mode("disjoint") # 'disjoint'
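The four resolve_attn_levels calls above suggest a small normalization routine: None selects every level, a string or integer selects one, and a list is passed through. The sketch below reproduces those four documented results; resolve_levels is a hypothetical helper, and the mapping of the integers 2 and 3 to 'hierarchical' and 'memory' is an assumption (only 1 is shown in the examples).

```python
_ALL_LEVELS = ["cross", "hierarchical", "memory"]


def resolve_levels(spec=None):
    """Normalize None, a name, a 1-based index, or a list into level names."""
    if spec is None:
        return list(_ALL_LEVELS)
    items = [spec] if isinstance(spec, (str, int)) else list(spec)
    names = []
    for item in items:
        if isinstance(item, int):
            # Assumed: integers index the default order, starting at 1.
            if not 1 <= item <= len(_ALL_LEVELS):
                raise ValueError(f"level index out of range: {item}")
            item = _ALL_LEVELS[item - 1]
        if item not in _ALL_LEVELS:
            raise ValueError(f"unknown attention level: {item!r}")
        names.append(item)
    return names


levels = resolve_levels(["cross", "memory"])
```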
Key Components
Variable Selection Network
Multi-Scale LSTM
- class base_attentive.components.MultiScaleLSTM(lstm_units=None, scales=None, return_sequences=False, *, units=None, **kwargs)[source]
Bases:
Layer, NNLearner
MultiScaleLSTM layer applying multiple LSTMs at different sampling scales and concatenating their outputs [1]_.
Each LSTM can either return the full sequence or only the last hidden state, controlled by return_sequences. The user specifies scales to sub-sample the time dimension. For example, a scale of 2 processes every 2nd time step.
- Parameters:
lstm_units (int) – Number of units in each LSTM.
scales (list of int or str or None, optional) – List of scale factors. If ‘auto’ or None, defaults to [1] (no sub-sampling).
return_sequences (bool, optional) – If True, each LSTM returns the entire sequence. Otherwise, it returns only the last hidden state. Defaults to False.
**kwargs – Additional arguments passed to the parent Keras Layer.
units (int | None)
Notes
If return_sequences=False, the output is concatenated along features: \((B, \text{units} \times \text{num\_scales})\).
If return_sequences=True, a list of sequence outputs is returned. Each may have a different time dimension if scales differ.
Examples
>>> from base_attentive.components import MultiScaleLSTM
>>> import tensorflow as tf
>>> x = tf.random.normal((32, 20, 16))  # (B, T, D)
>>> # Instantiating a multi-scale LSTM
>>> mslstm = MultiScaleLSTM(
...     lstm_units=32,
...     scales=[1, 2],
...     return_sequences=False,
... )
>>> y = mslstm(x)  # shape => (32, 64)
>>> # because scale=1 and scale=2 each produce 32 units,
>>> # which are concatenated => 64
See also
DynamicTimeWindow – For slicing sequences before applying multi-scale LSTMs.
TemporalFusionTransformer – A complex model that can incorporate multi-scale modules.
References
- call(inputs, training=False)[source]
Forward pass that processes the input at multiple scales.
- Parameters:
inputs (tf.Tensor) – Shape (B, T, D).
training (bool, optional) – Training mode. Defaults to False.
- Returns:
If return_sequences=False, returns a single 2D tensor of shape (B, lstm_units * len(scales)).
If return_sequences=True, returns a list of 3D tensors, each with shape (B, T’, lstm_units), where T’ depends on the scale sub-sampling.
- Return type:
tf.Tensor or list of tf.Tensor
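The output shapes described above follow from simple arithmetic on the time axis. The sketch below works them out per example, leaving the batch dimension aside; it assumes that sub-sampling is a strided slice (every scale-th step starting at index 0), which matches the "every 2nd time step" description but is not confirmed by the source. Both helper names are hypothetical.

```python
def subsample(sequence, scale):
    """Keep every scale-th time step, starting with the first."""
    return sequence[::scale]


def multiscale_shapes(time_steps, lstm_units, scales, return_sequences=False):
    """Per-example output shapes of a multi-scale LSTM (batch axis omitted)."""
    lengths = [len(range(0, time_steps, scale)) for scale in scales]
    if return_sequences:
        # One sequence per scale, each with its own time length T'.
        return [(length, lstm_units) for length in lengths]
    # Last hidden states concatenated along the feature axis.
    return (lstm_units * len(scales),)


last_state_shape = multiscale_shapes(20, 32, [1, 2])
sequence_shapes = multiscale_shapes(20, 32, [1, 2], return_sequences=True)
```

With T=20, 32 units, and scales [1, 2], the concatenated output has 64 features, which agrees with the (32, 64) shape in the class example once the batch size of 32 is added.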
- get_config()[source]
Returns a config dictionary containing ‘lstm_units’, ‘scales’, and ‘return_sequences’.
- Returns:
Configuration dictionary.
- Return type:
Cross-Attention
- class base_attentive.components.CrossAttention(units, num_heads)[source]
Bases:
Layer, NNLearner
CrossAttention that attends source1 (query) to source2 (key/value) with optional masks.
attention_mask (Tensor, optional) – Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed directly to Keras MultiHeadAttention.
query_mask, value_mask (Tensor, optional) – 1D/2D masks (B, Tq) or (B, Tv). If provided and attention_mask is None, they are combined to form (B, Tq, Tv).
use_causal_mask (bool) – Forwarded to MHA. Default False.
- call(inputs, training=False, *, attention_mask=None, query_mask=None, value_mask=None, use_causal_mask=False, **kwargs)[source]
Forward pass of CrossAttention.
- Parameters:
inputs (list of tf.Tensor) – A list [source1, source2], each of shape (batch_size, time_steps, features).
training (bool, optional) – Indicates if the layer is in training mode (for dropout, if any). Defaults to False.
attention_mask (Tensor, optional) – Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed directly to Keras MultiHeadAttention.
query_mask (Tensor, optional) – 1D/2D mask of shape (B, Tq). If provided and attention_mask is None, it is combined with value_mask to form (B, Tq, Tv).
value_mask (Tensor, optional) – 1D/2D mask of shape (B, Tv). If provided and attention_mask is None, it is combined with query_mask to form (B, Tq, Tv).
use_causal_mask (bool) – Forwarded to MHA. Default False.
- Returns:
A tensor of shape (batch_size, time_steps, units) representing cross-attended features.
- Return type:
tf.Tensor
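Combining a (B, Tq) query mask with a (B, Tv) value mask into a (B, Tq, Tv) attention mask amounts to a per-example outer AND: position (i, j) is attendable only if query step i and value step j are both valid. The sketch below uses nested lists in place of tensors; combine_masks is a hypothetical helper and the layer's actual combination logic may differ in detail.

```python
def combine_masks(query_mask, value_mask):
    """Outer AND per example: (B, Tq) and (B, Tv) -> (B, Tq, Tv)."""
    return [
        [[bool(q) and bool(v) for v in v_row] for q in q_row]
        for q_row, v_row in zip(query_mask, value_mask)
    ]


# One example with Tq=3 (last query step padded) and Tv=2 (last value step padded).
attention_mask = combine_masks([[1, 1, 0]], [[1, 0]])
```

With tensors, the same result is typically obtained by broadcasting the masks reshaped to (B, Tq, 1) and (B, 1, Tv).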
Hierarchical Attention
Memory-Augmented Attention
- class base_attentive.components.MemoryAugmentedAttention(units, memory_size=1, num_heads=1)[source]
Bases:
Layer, NNLearner
Memory-augmented attention with optional masking.
Transformer Encoder Layer
- class base_attentive.components.TransformerEncoderLayer(embed_dim=None, num_heads=1, ffn_dim=None, dropout_rate=0.1, ffn_activation='relu', layer_norm_epsilon=1e-06, *, units=None, **kwargs)[source]
Bases:
Layer, NNLearner
A single layer of the Transformer Encoder.
- Parameters:
Transformer Decoder Layer
- class base_attentive.components.TransformerDecoderLayer(embed_dim=None, num_heads=1, ffn_dim=None, dropout_rate=0.1, ffn_activation='relu', layer_norm_epsilon=1e-06, *, units=None, **kwargs)[source]
Bases:
Layer, NNLearner
A single layer of the Transformer Decoder. (Arguments similar to TransformerEncoderLayer)
- Parameters:
Multi-Decoder
- class base_attentive.components.MultiDecoder(output_dim=None, num_horizons=None, *, units=None, num_heads=None)[source]
Bases:
Layer, NNLearner
MultiDecoder for multi-horizon forecasting [1]_.
This layer takes a single feature vector per example of shape \((B, F)\) and produces a separate output for each horizon step, resulting in \((B, H, O)\).
\[\mathbf{Y}_h = \text{Dense}_h(\mathbf{x}), \quad h \in \{1, \dots, H\}\]
Each horizon has its own decoder layer.
- Parameters:
Notes
This layer is particularly useful when you want separate parameters for each horizon, instead of a single shared head.
Examples
>>> from base_attentive.components import MultiDecoder
>>> import tensorflow as tf
>>> # Input of shape (batch_size, feature_dim)
>>> x = tf.random.normal((32, 128))
>>> # Instantiate multi-horizon decoder
>>> decoder = MultiDecoder(output_dim=1, num_horizons=3)
>>> # Output shape => (32, 3, 1)
>>> y = decoder(x)
See also
MultiModalEmbedding – Provides feature embeddings that can be fed into MultiDecoder.
QuantileDistributionModeling – Projects deterministic outputs into multiple quantiles per horizon.
References
[1] Lim, B., & Zohren, S. (2021). “Time-series forecasting with deep learning: a survey.” Philosophical Transactions of the Royal Society A, 379(2194), 20200209.
- __init__(output_dim=None, num_horizons=None, *, units=None, num_heads=None)[source]
Initialize the MultiDecoder.
- call(x, training=False)[source]
Forward pass: each horizon has a separate Dense layer.
- Parameters:
x (tf.Tensor) – A 2D tensor (B, F).
training (bool, optional) – Unused in this layer. Defaults to False.
- Returns:
A 3D tensor of shape (B, H, O).
- Return type:
tf.Tensor
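The per-horizon structure (one independent dense map per step h, stacked into (B, H, O)) can be written out in plain Python. MultiDecoderSketch is a hypothetical toy: weights are random, there is no activation or training, and the feature dimension is passed explicitly, whereas a Keras layer would normally infer it when it is first called.

```python
import random


class MultiDecoderSketch:
    """Toy multi-horizon head: one weight matrix and bias per horizon."""

    def __init__(self, feature_dim, output_dim, num_horizons, seed=0):
        rng = random.Random(seed)
        # weights[h] has shape (output_dim, feature_dim); nothing is shared across h.
        self.weights = [
            [[rng.uniform(-0.1, 0.1) for _ in range(feature_dim)]
             for _ in range(output_dim)]
            for _ in range(num_horizons)
        ]
        self.biases = [[0.0] * output_dim for _ in range(num_horizons)]

    def __call__(self, x):
        """Map (B, F) nested lists to (B, H, O)."""
        return [
            [
                [sum(w * v for w, v in zip(row, sample)) + b
                 for row, b in zip(matrix, bias)]
                for matrix, bias in zip(self.weights, self.biases)
            ]
            for sample in x
        ]


decoder = MultiDecoderSketch(feature_dim=4, output_dim=1, num_horizons=3)
y = decoder([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]])
```

Here y has shape (2, 3, 1): two examples, three horizons, one output each. Separate weights per horizon let each step learn its own mapping, at the cost of H times the parameters of a shared head.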
- get_config()[source]
Returns layer configuration for serialization.
- Returns:
Dictionary containing ‘output_dim’ and ‘num_horizons’.
- Return type: