API Reference

Public package entry points

from base_attentive import BaseAttentive
from base_attentive import make_fast_predict_fn
from base_attentive import get_backend, set_backend
from base_attentive import (
    get_available_backends,
    get_backend_capabilities,
    normalize_backend_name,
    detect_available_backends,
    select_best_backend,
    ensure_default_backend,
)
from base_attentive.validation import (
    validate_model_inputs,
    maybe_reduce_quantiles_bh,
    ensure_bh1,
)
from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec
from base_attentive.registry import (
    ComponentRegistry, ModelRegistry,
    DEFAULT_COMPONENT_REGISTRY, DEFAULT_MODEL_REGISTRY,
)

Top-level package

Public package surface for base_attentive.

class base_attentive.BaseAttentive(static_dim=None, dynamic_dim=None, future_dim=None, output_dim=1, forecast_horizon=1, mode=None, num_encoder_layers=2, quantiles=None, embed_dim=32, hidden_units=64, lstm_units=64, attention_units=32, num_heads=4, dropout_rate=0.1, lookback_window=10, memory_size=100, scales=None, multi_scale_agg='last', final_agg='last', activation='relu', use_residuals=True, use_vsn=True, vsn_units=None, use_batch_norm=False, apply_dtw=True, attention_stack=None, objective='hybrid', architecture_config=None, backend_name=None, component_overrides=None, verbose=0, output_mode=None, n_quantiles=None, name='BaseAttentive', *, static_input_dim=None, dynamic_input_dim=None, future_input_dim=None, max_window_size=None, attention_levels=None, **kwargs)[source]

Bases: BaseAttentiveV2, NNLearner

Compatibility wrapper over the resolver-driven V2 model.

The facade keeps the legacy constructor signature intact, converts the constructor arguments into a BaseAttentiveSpec, and delegates model assembly and execution to BaseAttentiveV2.
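To make the delegation concrete, here is a minimal sketch of how legacy constructor arguments could be folded into a frozen spec. SpecSketch and legacy_kwargs_to_spec are hypothetical names, and the precedence rule (an explicit *_input_dim value beats the matching *_dim value) is an assumption. Only the num_heads to attention_heads mapping comes from this reference, where BaseAttentiveSpec.num_heads is documented as a legacy alias for attention_heads.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in holding a few BaseAttentiveSpec-like fields."""

    static_input_dim: Optional[int]
    dynamic_input_dim: Optional[int]
    future_input_dim: Optional[int]
    attention_heads: int = 4
    forecast_horizon: int = 1


def legacy_kwargs_to_spec(
    static_dim: Optional[int] = None,
    dynamic_dim: Optional[int] = None,
    future_dim: Optional[int] = None,
    forecast_horizon: int = 1,
    num_heads: int = 4,
    *,
    static_input_dim: Optional[int] = None,
    dynamic_input_dim: Optional[int] = None,
    future_input_dim: Optional[int] = None,
) -> SpecSketch:
    def pick(new: Optional[int], legacy: Optional[int]) -> Optional[int]:
        # Assumption: an explicit *_input_dim value takes precedence over *_dim.
        return new if new is not None else legacy

    return SpecSketch(
        static_input_dim=pick(static_input_dim, static_dim),
        dynamic_input_dim=pick(dynamic_input_dim, dynamic_dim),
        future_input_dim=pick(future_input_dim, future_dim),
        attention_heads=num_heads,  # legacy num_heads -> spec attention_heads
        forecast_horizon=forecast_horizon,
    )


spec = legacy_kwargs_to_spec(static_dim=4, dynamic_dim=8, future_dim=6, num_heads=2)
```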

Parameters:
  • static_dim (int | None)

  • dynamic_dim (int | None)

  • future_dim (int | None)

  • output_dim (int)

  • forecast_horizon (int)

  • mode (str | None)

  • num_encoder_layers (int)

  • quantiles (list[float] | tuple[float, ...] | None)

  • embed_dim (int)

  • hidden_units (int)

  • lstm_units (int | tuple[int, ...])

  • attention_units (int)

  • num_heads (int)

  • dropout_rate (float)

  • lookback_window (int)

  • memory_size (int)

  • scales (list[int] | tuple[int, ...] | str | None)

  • multi_scale_agg (str)

  • final_agg (str)

  • activation (str)

  • use_residuals (bool)

  • use_vsn (bool)

  • vsn_units (int | None)

  • use_batch_norm (bool)

  • apply_dtw (bool)

  • attention_stack (str | list[str] | tuple[str, ...] | None)

  • objective (str)

  • architecture_config (dict[str, Any] | None)

  • backend_name (str | None)

  • component_overrides (dict[str, Any] | None)

  • verbose (int)

  • output_mode (str | None)

  • n_quantiles (int | None)

  • name (str)

  • static_input_dim (int | None)

  • dynamic_input_dim (int | None)

  • future_input_dim (int | None)

  • max_window_size (int | None)

  • attention_levels (str | list[str] | tuple[str, ...] | None)

get_config()[source]
Return type:

dict[str, Any]

classmethod from_config(config)[source]
Parameters:

config (dict[str, Any])

base_attentive.dependency_message(module_name)[source]

Return a dependency hint for missing runtime packages.

Parameters:

module_name (str)

Return type:

str

base_attentive.get_backend(name=None)[source]
Parameters:

name (str | None)

base_attentive.set_backend(name)[source]
Parameters:

name (str)

base_attentive.get_available_backends()[source]
base_attentive.get_backend_capabilities(name=None)[source]
Parameters:

name (str | None)

Return type:

dict[str, Any]

base_attentive.get_layer_class()[source]
base_attentive.get_model_class()[source]
base_attentive.register_keras_serializable(package='Custom', name=None)[source]
base_attentive.resolve_keras_dep(name, fallback=None)[source]
Return type:

Any

base_attentive.make_fast_predict_fn(model, *, jit_compile=True, reduce_retracing=True, warmup_inputs=None)[source]

Create a TensorFlow-traced prediction function for a Keras model.

The returned callable accepts the same input structure as model and always executes with training=False. This is useful when you want a reusable inference function with tf.function tracing and optional XLA compilation.

Parameters:
  • model (Any) – A Keras-compatible model or layer that can be called as model(inputs, training=False).

  • jit_compile (bool, default=True) – Whether to request XLA JIT compilation for the traced prediction function.

  • reduce_retracing (bool, default=True) – Whether TensorFlow should reduce retracing when input structures are reused.

  • warmup_inputs (Any, optional) – Example inputs used to trigger tracing before the callable is returned.

Returns:

A TensorFlow tf.function-wrapped prediction callable.

Return type:

callable

Raises:
  • RuntimeError – If the active package backend is not TensorFlow.

  • ImportError – If TensorFlow cannot be imported.
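The calling contract described above (same input structure, training always False, an optional warm-up call before the callable is returned) can be illustrated without TensorFlow. The sketch below is hypothetical and backend-free: it omits tf.function tracing and XLA entirely, and make_predict_fn_sketch and RecordingModel are illustrative names, not package APIs.

```python
from typing import Any, Callable, List, Optional


def make_predict_fn_sketch(
    model: Callable[..., Any], warmup_inputs: Optional[Any] = None
) -> Callable[[Any], Any]:
    """Calling-contract sketch only: no tf.function tracing and no XLA."""

    def predict(inputs: Any) -> Any:
        # Inference path: training is always pinned to False.
        return model(inputs, training=False)

    if warmup_inputs is not None:
        # In the TensorFlow helper this first call is what triggers tracing;
        # here it simply exercises the callable once before returning it.
        predict(warmup_inputs)
    return predict


class RecordingModel:
    """Toy stand-in for a Keras model that records the training flag."""

    def __init__(self) -> None:
        self.training_flags: List[Optional[bool]] = []

    def __call__(self, inputs: Any, training: Optional[bool] = None) -> Any:
        self.training_flags.append(training)
        return inputs


toy = RecordingModel()
fast_predict = make_predict_fn_sketch(toy, warmup_inputs=[0.0])
result = fast_predict([1.0, 2.0])
```

In the real helper the inner function is presumably wrapped with tf.function, whose jit_compile and reduce_retracing arguments share their names with the parameters above.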

Core model

class base_attentive.core.base_attentive.BaseAttentive(static_dim=None, dynamic_dim=None, future_dim=None, output_dim=1, forecast_horizon=1, mode=None, num_encoder_layers=2, quantiles=None, embed_dim=32, hidden_units=64, lstm_units=64, attention_units=32, num_heads=4, dropout_rate=0.1, lookback_window=10, memory_size=100, scales=None, multi_scale_agg='last', final_agg='last', activation='relu', use_residuals=True, use_vsn=True, vsn_units=None, use_batch_norm=False, apply_dtw=True, attention_stack=None, objective='hybrid', architecture_config=None, backend_name=None, component_overrides=None, verbose=0, output_mode=None, n_quantiles=None, name='BaseAttentive', *, static_input_dim=None, dynamic_input_dim=None, future_input_dim=None, max_window_size=None, attention_levels=None, **kwargs)[source]

Bases: BaseAttentiveV2, NNLearner

Same class documentation as base_attentive.BaseAttentive under Top-level package: the signature, parameters, and the get_config() and from_config() methods are identical.

V2 Configuration Schema

BaseAttentiveSpec is a frozen dataclass that fully describes a V2 model without referencing any backend objects.

class base_attentive.config.schema.BaseAttentiveSpec(static_input_dim, dynamic_input_dim, future_input_dim, output_dim=1, forecast_horizon=1, embed_dim=32, hidden_units=64, attention_heads=4, layer_norm_epsilon=1e-06, dropout_rate=0.0, activation='relu', backend_name='tensorflow', head_type='point', quantiles=(), lstm_units=64, attention_units=32, vsn_units=None, architecture=<factory>, runtime=<factory>, components=<factory>, extras=<factory>)[source]

Bases: object

Backend-neutral configuration for BaseAttentive models.

Parameters:
static_input_dim: int
dynamic_input_dim: int
future_input_dim: int
output_dim: int = 1
forecast_horizon: int = 1
embed_dim: int = 32
hidden_units: int = 64
attention_heads: int = 4
layer_norm_epsilon: float = 1e-06
dropout_rate: float = 0.0
activation: str = 'relu'
backend_name: str = 'tensorflow'
head_type: str = 'point'
quantiles: tuple[float, ...] = ()
lstm_units: int | tuple[int, ...] = 64
attention_units: int = 32
vsn_units: int | None = None
architecture: BaseAttentiveArchitectureSpec
runtime: BaseAttentiveRuntimeSpec
components: BaseAttentiveComponentSpec
extras: dict[str, Any]
property num_heads: int

Legacy alias for attention_heads.

property num_encoder_layers: int
property mode: str | None
property max_window_size: int
property memory_size: int
property scales: tuple[int, ...] | str
property multi_scale_agg: str
property final_agg: str
property use_residuals: bool
property use_batch_norm: bool
property apply_dtw: bool
property verbose: int
property objective: str
property use_vsn: bool
property attention_levels: tuple[str, ...]
class base_attentive.config.schema.BaseAttentiveComponentSpec(static_projection='projection.static', dynamic_projection='projection.dynamic', future_projection='projection.future', dynamic_encoder='encoder.temporal_self_attention', future_encoder='encoder.temporal_self_attention', sequence_pooling='pool.mean', fusion='fusion.concat', hidden_projection='projection.hidden', point_head='head.point_forecast', quantile_head='head.quantile_forecast', static_processor='feature.static_processor', dynamic_processor='feature.dynamic_processor', future_processor='feature.future_processor', positional_encoder='embedding.positional', hybrid_encoder='encoder.hybrid_multiscale', dynamic_window='encoder.dynamic_window', decoder_cross_attention='decoder.cross_attention', decoder_hierarchical_attention='decoder.hierarchical_attention', decoder_memory_attention='decoder.memory_attention', decoder_fusion='fusion.multi_resolution_attention', multi_horizon_head='head.multi_horizon', quantile_distribution_head='head.quantile_distribution', final_pool_last='pool.final_last', final_pool_mean='pool.final_mean', final_pool_flatten='pool.final_flatten')[source]

Bases: object

Logical component selections for resolver-driven models.

Parameters:
  • static_projection (str)

  • dynamic_projection (str)

  • future_projection (str)

  • dynamic_encoder (str)

  • future_encoder (str)

  • sequence_pooling (str)

  • fusion (str)

  • hidden_projection (str)

  • point_head (str)

  • quantile_head (str)

  • static_processor (str)

  • dynamic_processor (str)

  • future_processor (str)

  • positional_encoder (str)

  • hybrid_encoder (str)

  • dynamic_window (str)

  • decoder_cross_attention (str)

  • decoder_hierarchical_attention (str)

  • decoder_memory_attention (str)

  • decoder_fusion (str)

  • multi_horizon_head (str)

  • quantile_distribution_head (str)

  • final_pool_last (str)

  • final_pool_mean (str)

  • final_pool_flatten (str)

static_projection: str = 'projection.static'
dynamic_projection: str = 'projection.dynamic'
future_projection: str = 'projection.future'
dynamic_encoder: str = 'encoder.temporal_self_attention'
future_encoder: str = 'encoder.temporal_self_attention'
sequence_pooling: str = 'pool.mean'
fusion: str = 'fusion.concat'
hidden_projection: str = 'projection.hidden'
point_head: str = 'head.point_forecast'
quantile_head: str = 'head.quantile_forecast'
static_processor: str = 'feature.static_processor'
dynamic_processor: str = 'feature.dynamic_processor'
future_processor: str = 'feature.future_processor'
positional_encoder: str = 'embedding.positional'
hybrid_encoder: str = 'encoder.hybrid_multiscale'
dynamic_window: str = 'encoder.dynamic_window'
decoder_cross_attention: str = 'decoder.cross_attention'
decoder_hierarchical_attention: str = 'decoder.hierarchical_attention'
decoder_memory_attention: str = 'decoder.memory_attention'
decoder_fusion: str = 'fusion.multi_resolution_attention'
multi_horizon_head: str = 'head.multi_horizon'
quantile_distribution_head: str = 'head.quantile_distribution'
final_pool_last: str = 'pool.final_last'
final_pool_mean: str = 'pool.final_mean'
final_pool_flatten: str = 'pool.final_flatten'

Example:

from base_attentive.config import BaseAttentiveSpec, BaseAttentiveComponentSpec

spec = BaseAttentiveSpec(
    static_input_dim=4,
    dynamic_input_dim=8,
    future_input_dim=6,
    output_dim=1,
    forecast_horizon=24,
    embed_dim=32,
    hidden_units=64,
    attention_heads=4,
    dropout_rate=0.1,
    head_type="point",
    backend_name="tensorflow",
    components=BaseAttentiveComponentSpec(
        sequence_pooling="pool.last",
    ),
)
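Because BaseAttentiveSpec is a frozen dataclass, a variant is normally derived with the standard library's dataclasses.replace rather than by assigning to fields. The sketch below uses a hypothetical, trimmed-down MiniSpec (it does not import base_attentive) to show the frozen behaviour, a <factory> default, and a read-only legacy alias in the style of num_heads.

```python
import dataclasses
from dataclasses import dataclass, field
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class MiniSpec:
    """Hypothetical, trimmed-down analogue of BaseAttentiveSpec."""

    dynamic_input_dim: int
    attention_heads: int = 4
    quantiles: Tuple[float, ...] = ()
    # Mutable defaults need a factory, which is why Sphinx shows <factory>.
    extras: Dict[str, Any] = field(default_factory=dict)

    @property
    def num_heads(self) -> int:
        # Read-only legacy alias, mirroring BaseAttentiveSpec.num_heads.
        return self.attention_heads


spec = MiniSpec(dynamic_input_dim=8)
try:
    spec.attention_heads = 8  # frozen dataclass: in-place assignment is rejected
    frozen_rejected = False
except dataclasses.FrozenInstanceError:
    frozen_rejected = True

# Derive a modified copy instead of mutating the original.
wider = dataclasses.replace(spec, attention_heads=8)
```

For the spec built in the example above, dataclasses.replace(spec, attention_heads=8) would be the equivalent call.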

Registry

The registry stores named builder functions (for components) and assembler functions (for complete models), keyed by (name, backend).

class base_attentive.registry.ComponentRegistry[source]

Bases: object

Registry of backend-specific component builders.

register(key, builder, *, backend='generic', description='', experimental=False, replace=False)[source]
Return type:

ComponentRegistration

has(key, *, backend=None)[source]
Parameters:
  • key (str)

  • backend (str | None)

Return type:

bool

resolve(key, *, backend, allow_generic=True)[source]
Return type:

ComponentRegistration

list_keys()[source]
Return type:

list[str]

clone()[source]
Return type:

ComponentRegistry

class base_attentive.registry.ModelRegistry[source]

Bases: object

Registry of backend-specific model assemblers.

register(key, builder, *, backend='generic', description='', experimental=False, replace=False)[source]
Return type:

ModelRegistration

has(key, *, backend=None)[source]
Parameters:
  • key (str)

  • backend (str | None)

Return type:

bool

resolve(key, *, backend, allow_generic=True)[source]
Return type:

ModelRegistration

Pre-populated default registries:

from base_attentive.registry import (
    DEFAULT_COMPONENT_REGISTRY,
    DEFAULT_MODEL_REGISTRY,
)

DEFAULT_COMPONENT_REGISTRY.list_keys()  # names of all registered components
DEFAULT_COMPONENT_REGISTRY.has("encoder.temporal_self_attention", backend="generic")  # -> bool
registration = DEFAULT_COMPONENT_REGISTRY.resolve(
    "encoder.temporal_self_attention", backend="generic",
)  # returns a ComponentRegistration, not a built component
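Lookups are keyed by (key, backend), and resolve() accepts an allow_generic flag. A plausible reading, sketched below as an assumption rather than the package's implementation, is that an exact backend match wins and a 'generic' registration serves as the fallback. Registration and RegistrySketch are hypothetical names.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


@dataclass(frozen=True)
class Registration:
    key: str
    backend: str
    builder: Callable[..., Any]


class RegistrySketch:
    """Hypothetical registry keyed by (key, backend) with a generic fallback."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[str, str], Registration] = {}

    def register(self, key: str, builder: Callable[..., Any], *,
                 backend: str = "generic", replace: bool = False) -> Registration:
        slot = (key, backend)
        if slot in self._entries and not replace:
            raise KeyError(f"{key!r} is already registered for {backend!r}")
        self._entries[slot] = Registration(key, backend, builder)
        return self._entries[slot]

    def resolve(self, key: str, *, backend: str,
                allow_generic: bool = True) -> Registration:
        # Assumption: an exact backend match wins over the generic entry.
        if (key, backend) in self._entries:
            return self._entries[(key, backend)]
        if allow_generic and (key, "generic") in self._entries:
            return self._entries[(key, "generic")]
        raise KeyError(f"no builder for {key!r} on backend {backend!r}")

    def list_keys(self) -> List[str]:
        return sorted({key for key, _ in self._entries})


reg = RegistrySketch()
reg.register("pool.mean", lambda: "generic mean pool")
reg.register("pool.mean", lambda: "torch mean pool", backend="torch")
torch_pool = reg.resolve("pool.mean", backend="torch").builder()
jax_pool = reg.resolve("pool.mean", backend="jax").builder()  # falls back to generic
```

Building a component then amounts to calling the resolved builder, which appears to be the job build_component performs for the package registries.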

Resolver / Assembly

base_attentive.resolver.component_resolver.build_component(key, *, backend_context, registry=None, model_registry=None, allow_generic=True, spec=None, **kwargs)[source]

Resolve and build a component for the requested backend.

Return type:

Any

class base_attentive.resolver.assembly.BaseAttentiveV2Assembly(backend_context, static_projection, dynamic_projection, future_projection, dynamic_encoder, future_encoder, sequence_pool, fusion, hidden_projection, output_head, dropout=None, static_processor=None, dynamic_processor=None, future_processor=None, encoder_positional_encoding=None, future_positional_encoding=None, dynamic_window=None, decoder_input_projection=None, decoder_cross_attention=None, decoder_cross_postprocess=None, decoder_hierarchical_attention=None, decoder_memory_attention=None, decoder_fusion=None, residual_projection=None, decoder_residual_add=None, decoder_residual_norm=None, final_residual_add=None, final_residual_norm=None, final_pool=None, multi_horizon_head=None, quantile_distribution_head=None)[source]

Bases: object

Resolved V2 model components.

The assembly keeps the original V2 field names for compatibility, while also exposing migrated component names used by the legacy-to-resolver rewrite.

Parameters:
  • backend_context (BackendContext)

  • static_projection (Any | None)

  • dynamic_projection (Any)

  • future_projection (Any | None)

  • dynamic_encoder (Any | None)

  • future_encoder (Any | None)

  • sequence_pool (Any)

  • fusion (Any)

  • hidden_projection (Any)

  • output_head (Any)

  • dropout (Any | None)

  • static_processor (Any | None)

  • dynamic_processor (Any | None)

  • future_processor (Any | None)

  • encoder_positional_encoding (Any | None)

  • future_positional_encoding (Any | None)

  • dynamic_window (Any | None)

  • decoder_input_projection (Any | None)

  • decoder_cross_attention (Any | None)

  • decoder_cross_postprocess (Any | None)

  • decoder_hierarchical_attention (Any | None)

  • decoder_memory_attention (Any | None)

  • decoder_fusion (Any | None)

  • residual_projection (Any | None)

  • decoder_residual_add (Any | None)

  • decoder_residual_norm (Any | None)

  • final_residual_add (Any | None)

  • final_residual_norm (Any | None)

  • final_pool (Any | None)

  • multi_horizon_head (Any | None)

  • quantile_distribution_head (Any | None)

backend_context: BackendContext
static_projection: Any | None
dynamic_projection: Any
future_projection: Any | None
dynamic_encoder: Any | None
future_encoder: Any | None
sequence_pool: Any
fusion: Any
hidden_projection: Any
output_head: Any
dropout: Any | None = None
static_processor: Any | None = None
dynamic_processor: Any | None = None
future_processor: Any | None = None
encoder_positional_encoding: Any | None = None
future_positional_encoding: Any | None = None
dynamic_window: Any | None = None
decoder_input_projection: Any | None = None
decoder_cross_attention: Any | None = None
decoder_cross_postprocess: Any | None = None
decoder_hierarchical_attention: Any | None = None
decoder_memory_attention: Any | None = None
decoder_fusion: Any | None = None
residual_projection: Any | None = None
decoder_residual_add: Any | None = None
decoder_residual_norm: Any | None = None
final_residual_add: Any | None = None
final_residual_norm: Any | None = None
final_pool: Any | None = None
multi_horizon_head: Any | None = None
quantile_distribution_head: Any | None = None

Backend utilities

Lazy backend runtime abstraction for Base-Attentive.

This package exposes backend selection, capability inspection, and helper utilities without importing all backend implementations eagerly.

class base_attentive.backend.Backend(load_runtime=True)[source]

Bases: object

Base class for runtime backend descriptors.

Parameters:

load_runtime (bool)

name: str = 'base'
framework: str = 'unknown'
required_modules: tuple[str, ...] = ()
uses_keras_runtime: bool = False
experimental: bool = False
supports_base_attentive: bool = False
supports_base_attentive_v2: bool = False
blockers: tuple[str, ...] = ()
v2_blockers: tuple[str, ...] = ()
Tensor: Any = None
Layer: Any = None
Model: Any = None
Sequential: Any = None
Dense: Any = None
LSTM: Any = None
MultiHeadAttention: Any = None
LayerNormalization: Any = None
Dropout: Any = None
BatchNormalization: Any = None
is_available()[source]

Check whether the backend can be imported.

Return type:

bool

get_capabilities()[source]

Return a capability summary for the backend.

Return type:

dict[str, Any]
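The backend package is described as lazy, and is_available() checks whether a backend can be imported. One way to run such a check without importing heavy frameworks is importlib.util.find_spec over required_modules, as in the hypothetical BackendSketch below. The capability dictionary keys are illustrative and may differ from what get_capabilities() actually returns.

```python
import importlib.util
from typing import Any, Dict, Tuple


class BackendSketch:
    """Hypothetical descriptor mirroring a few Backend class attributes."""

    name: str = "base"
    framework: str = "unknown"
    required_modules: Tuple[str, ...] = ()
    experimental: bool = False
    supports_base_attentive_v2: bool = False

    def is_available(self) -> bool:
        # find_spec locates a top-level module without importing it.
        return all(
            importlib.util.find_spec(module) is not None
            for module in self.required_modules
        )

    def get_capabilities(self) -> Dict[str, Any]:
        return {
            "name": self.name,
            "framework": self.framework,
            "available": self.is_available(),
            "experimental": self.experimental,
            "supports_base_attentive_v2": self.supports_base_attentive_v2,
        }


class JaxBackendSketch(BackendSketch):
    name = "jax"
    framework = "jax"
    required_modules = ("keras", "jax")
    experimental = True
    supports_base_attentive_v2 = True


caps = JaxBackendSketch().get_capabilities()
```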

class base_attentive.backend.TensorFlowBackend(load_runtime=True)[source]

Bases: Backend

TensorFlow-backed runtime.

Parameters:

load_runtime (bool)

name: str = 'tensorflow'
framework: str = 'tensorflow'
required_modules: tuple[str, ...] = ('tensorflow',)
uses_keras_runtime: bool = True
supports_base_attentive: bool = True
supports_base_attentive_v2: bool = True
class base_attentive.backend.JaxBackend(load_runtime=True)[source]

Bases: Backend

Keras-on-JAX runtime descriptor.

Parameters:

load_runtime (bool)

name: str = 'jax'
framework: str = 'jax'
required_modules: tuple[str, ...] = ('keras', 'jax')
uses_keras_runtime: bool = True
experimental: bool = True
supports_base_attentive: bool = False
supports_base_attentive_v2: bool = True
blockers: tuple[str, ...] = ('BaseAttentive still contains TensorFlow-oriented compatibility paths.', 'The compat.tf helpers are still TensorFlow-specific.', 'Some runtime shape/assert checks still assume TensorFlow graph semantics.')
v2_blockers: tuple[str, ...] = ('Advanced encoder-decoder blocks are still being ported through the V2 registry path.', 'Cross-backend serialization parity for the full V2 model is still under validation.')
class base_attentive.backend.TorchBackend(load_runtime=True)[source]

Bases: Backend

Keras-on-Torch runtime descriptor.

Parameters:

load_runtime (bool)

name: str = 'torch'
framework: str = 'torch'
required_modules: tuple[str, ...] = ('keras', 'torch')
uses_keras_runtime: bool = True
experimental: bool = True
supports_base_attentive: bool = False
supports_base_attentive_v2: bool = True
blockers: tuple[str, ...] = ('BaseAttentive still contains TensorFlow-oriented compatibility paths.', 'The compat.tf helpers are still TensorFlow-specific.', 'Some runtime shape/assert checks still assume TensorFlow graph semantics.')
v2_blockers: tuple[str, ...] = ('Advanced encoder-decoder blocks are still being ported through the V2 registry path.', 'Cross-backend serialization parity for the full V2 model is still under validation.')
class base_attentive.backend.PyTorchBackend(load_runtime=True)[source]

Bases: TorchBackend

Backward-compatible alias for the Torch runtime.

Parameters:

load_runtime (bool)

name: str = 'pytorch'
base_attentive.backend.get_backend(name=None)[source]
Parameters:

name (str | None)

base_attentive.backend.set_backend(name)[source]
Parameters:

name (str)

base_attentive.backend.get_available_backends()[source]
base_attentive.backend.get_backend_capabilities(name=None)[source]
Parameters:

name (str | None)

Return type:

dict[str, Any]

base_attentive.backend.normalize_backend_name(name)[source]
Parameters:

name (str | None)

Return type:

str

base_attentive.backend.detect_available_backends()[source]
base_attentive.backend.select_best_backend(prefer=None, require_supported=True)[source]
Parameters:
  • prefer (str | None)

  • require_supported (bool)

base_attentive.backend.ensure_default_backend(auto_install=False, install_tensorflow=True)[source]
Parameters:
  • auto_install (bool)

  • install_tensorflow (bool)

Return type:

str

base_attentive.backend.get_backend_version(name)[source]
Parameters:

name (str)

base_attentive.backend.check_tensorflow_compatibility()[source]
base_attentive.backend.check_torch_compatibility()[source]
base_attentive.backend.parse_version(version)[source]
Parameters:

version (str)

base_attentive.backend.version_at_least(version, minimum)[source]
Parameters:
  • version (str)

  • minimum (str)
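Comparing versions as plain strings is unreliable ("2.9" sorts after "2.13"), so numeric parsing matters. The sketch below is a hypothetical stand-in for parse_version and version_at_least: it keeps only the leading numeric components (dropping local tags such as +cu118 and suffixes such as rc1) and zero-pads before comparing tuples. The real helpers may treat pre-releases differently.

```python
import re
from typing import Tuple


def parse_version_sketch(version: str) -> Tuple[int, ...]:
    """Hypothetical: keep leading numeric parts, e.g. '2.0.1+cu118' -> (2, 0, 1)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        parts.append(int(match.group()))
    return tuple(parts)


def version_at_least_sketch(version: str, minimum: str) -> bool:
    current = parse_version_sketch(version)
    required = parse_version_sketch(minimum)
    width = max(len(current), len(required))

    def pad(v: Tuple[int, ...]) -> Tuple[int, ...]:
        # Zero-pad so that '2.13' compares equal to '2.13.0'.
        return v + (0,) * (width - len(v))

    return pad(current) >= pad(required)
```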
base_attentive.backend.get_torch_device(prefer='cuda', verbose=True)[source]

Get the best available device for PyTorch computations.

Parameters:
  • prefer ({'cuda', 'cpu', 'mps'}, default='cuda') – Preferred device type:
      - ‘cuda’: NVIDIA GPU (with CUDA support)
      - ‘cpu’: CPU
      - ‘mps’: Apple Metal Performance Shaders (macOS)

  • verbose (bool, default=True) – Whether to log device selection info.

Returns:

Device string for use with PyTorch (e.g., ‘cuda:0’, ‘cpu’).

Return type:

str

Examples

>>> device = get_torch_device()
>>> # 'cuda:0' if available, else 'cpu'
>>> device = get_torch_device(prefer="cpu")
>>> # 'cpu' always
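The documented fallback (a CUDA device when available, otherwise CPU) can be written as a small pure function over an availability map. pick_device_sketch is hypothetical; how the real get_torch_device handles an unavailable 'mps' preference is not documented here, and the sketch simply falls back to 'cpu'.

```python
from typing import Mapping


def pick_device_sketch(prefer: str, available: Mapping[str, bool]) -> str:
    """Hypothetical: use the preferred accelerator if present, else the CPU."""
    if prefer not in {"cuda", "cpu", "mps"}:
        raise ValueError(f"unsupported device type: {prefer!r}")
    if prefer == "cuda" and available.get("cuda", False):
        return "cuda:0"  # first CUDA device, as in the example output
    if prefer == "mps" and available.get("mps", False):
        return "mps"
    return "cpu"
```

With PyTorch installed, the availability map could be filled from torch.cuda.is_available() and torch.backends.mps.is_available().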
base_attentive.backend.get_torch_version()[source]

Get installed PyTorch version.

Returns:

Version string (e.g., “2.0.1”) or None if not installed.

Return type:

str or None

base_attentive.backend.torch_is_available()[source]

Check if PyTorch is installed and importable.

Returns:

True if PyTorch is available.

Return type:

bool

class base_attentive.backend.TorchDeviceManager(prefer='cuda')[source]

Bases: object

Utility class for managing PyTorch device selection and configuration.

Parameters:

prefer ({'cuda', 'cpu', 'mps'}, default='cuda') – Preferred device type.

property device: str

Get the selected device.

set_device(device)[source]

Set the device explicitly.

Parameters:

device (str) – Device string or name.

Returns:

The set device string.

Return type:

str

get_available_devices()[source]

Get availability of different device types.

Returns:

Mapping of device types to availability.

Return type:

dict

get_device_info()[source]

Get detailed information about available devices.

Returns:

Device information including GPU count, names, memory, etc.

Return type:

dict

reset_cache()[source]

Clear PyTorch cache to free memory.

Return type:

None

Core backend functions:

from base_attentive import get_backend, set_backend
from base_attentive import get_available_backends, get_backend_capabilities
from base_attentive import normalize_backend_name

b = get_backend()
print(b.name)                        # e.g. 'tensorflow'
set_backend("tensorflow")
get_available_backends()             # ['tensorflow', ...]
get_backend_capabilities()           # {'name': ..., 'version': ..., ...}
normalize_backend_name("tf")         # -> "tensorflow"

Detection and selection:

from base_attentive import detect_available_backends, select_best_backend
from base_attentive import ensure_default_backend

info = detect_available_backends()
# {'tensorflow': {'available': True, 'version': '2.x'}, ...}

best = select_best_backend()
name = ensure_default_backend()

Version compatibility checking:

from base_attentive.backend import (
    check_tensorflow_compatibility,
    check_torch_compatibility,
    get_backend_version,
    version_at_least,
)

ok, msg = check_tensorflow_compatibility()
ok, msg = check_torch_compatibility()
ver     = get_backend_version("tensorflow")
ok      = version_at_least("2.13.0", "2.12.0")

PyTorch device utilities


Example:

from base_attentive.backend import TorchDeviceManager, get_torch_device

device = get_torch_device(prefer="cuda", verbose=True)  # e.g. 'cuda:0', else 'cpu'

manager = TorchDeviceManager(prefer="cuda")
print(manager.device)                   # selected device string
print(manager.get_available_devices())  # device type -> availability
info = manager.get_device_info()        # GPU count, names, memory, etc.
manager.reset_cache()                   # free cached PyTorch memory

Learner mixin

Property and base class definitions for NN learners.

class base_attentive.api.property.NNLearner[source]

Bases: object

Base class for neural network learners.

Provides parameter management, introspection, and a compact pretty-printer for NN components.

get_params(deep=True)[source]

Get the parameters for this learner.

Parameters:

deep (bool)

Return type:

dict[str, Any]

set_params(**params)[source]

Set the parameters of this learner.

Parameters:

params (Any)

Return type:

NNLearner
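get_params(deep=True) and set_params(**params) follow the naming of scikit-learn's estimator interface. The sketch below assumes the same conventions (parameters are the constructor arguments, nested values appear as outer__inner, and set_params returns self). LearnerSketch and ToyLayer are hypothetical, and NNLearner may differ in detail.

```python
import inspect
from typing import Any, Dict


class LearnerSketch:
    """Hypothetical scikit-learn-style parameter handling."""

    def get_params(self, deep: bool = True) -> Dict[str, Any]:
        # The constructor's named arguments define the parameter set.
        signature = inspect.signature(type(self).__init__)
        names = [
            p.name for p in signature.parameters.values()
            if p.name != "self"
            and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        ]
        params = {name: getattr(self, name) for name in names}
        if deep:
            # Expose nested learners' parameters as "outer__inner".
            for name, value in list(params.items()):
                if hasattr(value, "get_params"):
                    for sub, sub_value in value.get_params(deep=True).items():
                        params[f"{name}__{sub}"] = sub_value
        return params

    def set_params(self, **params: Any) -> "LearnerSketch":
        # Only top-level names are accepted; nested keys are not handled here.
        valid = self.get_params(deep=False)
        for key, value in params.items():
            if key not in valid:
                raise ValueError(f"invalid parameter {key!r}")
            setattr(self, key, value)
        return self  # returning self allows chaining


class ToyLayer(LearnerSketch):
    def __init__(self, units: int = 32, dropout_rate: float = 0.1) -> None:
        self.units = units
        self.dropout_rate = dropout_rate


layer = ToyLayer()
```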

Validation helpers

Utilities for backend-agnostic tensor validation.

base_attentive.validation.validate_model_inputs(inputs, static_input_dim=None, dynamic_input_dim=None, future_covariate_dim=None, forecast_horizon=None, error='raise', mode='strict', deep_check=None, model_name=None, verbose=0, **kwargs)[source]

Validate and homogenize input tensors for model workflows.

This entry point intentionally stays lightweight: it normalizes the input container shape and converts values into the active Keras runtime tensor type when a runtime is available. When no Keras runtime is configured, the raw values are returned unchanged. None inputs are normalized to (None, None, None).

Parameters:
  • inputs (Any | ndarray | list)

  • static_input_dim (int | None)

  • dynamic_input_dim (int | None)

  • future_covariate_dim (int | None)

  • forecast_horizon (int | None)

  • error (str)

  • mode (str)

  • deep_check (bool | None)

  • model_name (str | None)

  • verbose (int)

Return type:

Tuple[Any | None, Any | None, Any | None]
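The container normalization step of validate_model_inputs can be sketched in plain Python. normalize_inputs_sketch is hypothetical: it implements the documented None case and accepts exactly three items, whereas the real function also converts values to runtime tensors and takes error, mode, and dimension arguments that this sketch ignores.

```python
from typing import Any, Optional, Tuple

Triple = Tuple[Optional[Any], Optional[Any], Optional[Any]]


def normalize_inputs_sketch(inputs: Any) -> Triple:
    """Hypothetical: coerce the input container into (static, dynamic, future)."""
    if inputs is None:
        return (None, None, None)  # documented behaviour for None
    if isinstance(inputs, (list, tuple)) and len(inputs) == 3:
        static, dynamic, future = inputs
        return (static, dynamic, future)
    # Simplification: the real helper may accept other layouts.
    raise ValueError(
        f"expected None or a 3-item list/tuple, got {type(inputs).__name__}"
    )
```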

base_attentive.validation.maybe_reduce_quantiles_bh(x, *, name='tensor', axis=2, reduction='mean')[source]

Reduce a quantile axis when a backend tensor carries one.

Return type:

Any

base_attentive.validation.ensure_bh1(x, *, name='tensor', dtype=None, reduce_axis=None, reduction='mean', allow_rank1=False)[source]

Ensure a tensor-like value has shape (B, H, 1), where B is the batch size and H the forecast horizon.

Parameters:
  • x (Any)

  • name (str)

  • dtype (Any | None)

  • reduce_axis (int | None)

  • reduction (str | callable)

  • allow_rank1 (bool)

Return type:

Any

from base_attentive.validation import (
    validate_model_inputs,
    maybe_reduce_quantiles_bh,
    ensure_bh1,
)

static, dynamic, future = validate_model_inputs(
    [x_static, x_dynamic, x_future],
    static_input_dim=4,
    dynamic_input_dim=8,
)

reduced  = maybe_reduce_quantiles_bh(predictions)
reshaped = ensure_bh1(output)
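The shape contracts of maybe_reduce_quantiles_bh and ensure_bh1 can be illustrated with NumPy. Both functions below are hypothetical sketches: the quantile reduction assumes a rank-4 (B, H, Q, O) layout with quantiles on axis 2 (the documented default axis), and a rank-1 input is assumed to be a single series of length H.

```python
import numpy as np


def reduce_quantiles_sketch(x, axis: int = 2) -> np.ndarray:
    """Hypothetical: average out the quantile axis of a (B, H, Q, O) array."""
    arr = np.asarray(x)
    if arr.ndim != 4:
        return arr  # assumption: only rank-4 arrays carry a quantile axis
    return arr.mean(axis=axis)  # (B, H, Q, O) -> (B, H, O)


def ensure_bh1_sketch(x, allow_rank1: bool = False) -> np.ndarray:
    """Hypothetical: coerce an array to shape (B, H, 1)."""
    arr = np.asarray(x)
    if arr.ndim == 1:
        if not allow_rank1:
            raise ValueError(f"rank-1 input of shape {arr.shape} not allowed")
        return arr.reshape(1, -1, 1)  # one series of length H
    if arr.ndim == 2:
        return arr[:, :, np.newaxis]  # (B, H) -> (B, H, 1)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        return arr
    raise ValueError(f"cannot coerce shape {arr.shape} to (B, H, 1)")


preds = np.random.default_rng(0).normal(size=(2, 5, 3, 1))  # (B, H, Q, O)
point = ensure_bh1_sketch(reduce_quantiles_sketch(preds))
```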

Runtime helpers

Runtime helpers for accelerated inference.

base_attentive.runtime.make_fast_predict_fn(model, *, jit_compile=True, reduce_retracing=True, warmup_inputs=None)[source]

Documented identically to base_attentive.make_fast_predict_fn under Top-level package: it returns a tf.function-wrapped callable that always runs the model with training=False, and raises RuntimeError if the active backend is not TensorFlow or ImportError if TensorFlow cannot be imported.

from base_attentive import make_fast_predict_fn

fast_predict = make_fast_predict_fn(
    model,
    jit_compile=True,
    reduce_retracing=True,
    warmup_inputs=[x_static, x_dynamic, x_future],
)
predictions = fast_predict([x_static, x_dynamic, x_future])

Component utilities

from base_attentive.components.utils import (
    resolve_attn_levels,
    configure_architecture,
    resolve_fusion_mode,
)

resolve_attn_levels(None)                # ['cross', 'hierarchical', 'memory']
resolve_attn_levels("cross")             # ['cross']
resolve_attn_levels(1)                   # ['cross']
resolve_attn_levels(["cross", "memory"]) # ['cross', 'memory']

arch = configure_architecture(
    objective="hybrid",
    use_vsn=True,
    attention_levels=["cross", "hierarchical"],
)

resolve_fusion_mode(None)      # 'integrated'
resolve_fusion_mode("disjoint") # 'disjoint'

Key Components

Variable Selection Network

class base_attentive.components.VariableSelectionNetwork(num_inputs, units, dropout_rate=0.0, use_time_distributed=False, activation='elu', use_batch_norm=False, **kwargs)[source]

Bases: Layer, NNLearner

Applies a Gated Residual Network (GRN) to each input variable and learns per-variable importance weights.

Parameters:
  • num_inputs (int)

  • units (int)

  • dropout_rate (float)

  • use_time_distributed (bool)

  • activation (str)

  • use_batch_norm (bool)

build(input_shape)[source]

Builds internal GRNs and projection layers with explicit shapes.

call(inputs, context=None, training=False)[source]

Execute the forward pass with optional context.

get_config()[source]

Returns the layer configuration.

classmethod from_config(config)[source]

Creates layer from its config.
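The mechanism described above (a per-variable transformation combined through learned importance weights) can be illustrated with a minimal NumPy sketch in the style of Temporal Fusion Transformer variable selection. It is a conceptual illustration, not this library's implementation: each GRN is replaced by a single random dense layer with an ELU activation, and the function name `variable_selection` is hypothetical.

```python
import numpy as np


def elu(x):
    # ELU activation; np.minimum keeps exp() from overflowing for large positive x.
    return np.where(x > 0, x, np.exp(np.minimum(x, 0.0)) - 1.0)


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def variable_selection(inputs, units, rng):
    """Hypothetical VSN sketch.

    inputs: list of N arrays, each of shape (B, d_i).
    Returns the combined features (B, units) and the weights (B, N).
    """
    num_inputs = len(inputs)
    # Per-variable transform to `units` features (stand-in for one GRN per variable).
    transformed = [
        elu(x @ rng.normal(scale=0.1, size=(x.shape[-1], units))) for x in inputs
    ]
    stacked = np.stack(transformed, axis=1)  # (B, N, units)
    # Selection logits from all variables jointly (stand-in for the selection GRN).
    flat = np.concatenate(inputs, axis=-1)  # (B, sum(d_i))
    w_sel = rng.normal(scale=0.1, size=(flat.shape[-1], num_inputs))
    weights = softmax(flat @ w_sel, axis=-1)  # (B, N), each row sums to 1
    # Importance-weighted sum over the variable axis.
    combined = np.einsum("bn,bnu->bu", weights, stacked)  # (B, units)
    return combined, weights


rng = np.random.default_rng(0)
xs = [rng.normal(size=(8, 3)) for _ in range(5)]  # 5 variables, 3 features each
combined, weights = variable_selection(xs, units=16, rng=rng)
```

The returned `weights` play the role of per-variable importances: they are non-negative and sum to one for each example, which is what makes them interpretable as a soft selection.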

Multi-Scale LSTM

class base_attentive.components.MultiScaleLSTM(lstm_units=None, scales=None, return_sequences=False, *, units=None, **kwargs)[source]

Bases: Layer, NNLearner

MultiScaleLSTM layer that applies multiple LSTMs at different sampling scales and concatenates their outputs [1].

Each LSTM returns either the full sequence or only the last hidden state, as controlled by return_sequences. The scales argument sub-samples the time dimension; for example, a scale of 2 processes every second time step.

Parameters:
  • lstm_units (int) – Number of units in each LSTM.

  • scales (list of int or str or None, optional) – List of scale factors. If ‘auto’ or None, defaults to [1] (no sub-sampling).

  • return_sequences (bool, optional) – If True, each LSTM returns the entire sequence. Otherwise, it returns only the last hidden state. Defaults to False.

  • **kwargs – Additional arguments passed to the parent Keras Layer.

  • units (int | None)

Notes

  • If return_sequences=False, the output is concatenated along features: \((B, \text{units} \times \text{num\_scales})\).

  • If return_sequences=True, a list of sequence outputs is returned. Each may have a different time dimension if scales differ.

call(inputs, training=False)[source]

Forward pass, applying each LSTM at the specified scale.

get_config()[source]

Returns the layer’s configuration dict.

from_config(config)[source]

Builds the layer from the config dict.

Examples

>>> from base_attentive.components import MultiScaleLSTM
>>> import tensorflow as tf
>>> x = tf.random.normal((32, 20, 16))  # (B, T, D)
>>> # Instantiate a multi-scale LSTM
>>> mslstm = MultiScaleLSTM(
...     lstm_units=32,
...     scales=[1, 2],
...     return_sequences=False,
... )
>>> y = mslstm(x)  # shape => (32, 64)
>>> # scale=1 and scale=2 each produce 32 units,
>>> # which are concatenated => 64

See also

DynamicTimeWindow

For slicing sequences before applying multi-scale LSTMs.

TemporalFusionTransformer

A complex model that can incorporate multi-scale modules.

References

call(inputs, training=False)[source]

Forward pass that processes the input at multiple scales.

Parameters:
  • inputs (tf.Tensor) – Shape (B, T, D).

  • training (bool, optional) – Training mode. Defaults to False.

Returns:

  • If return_sequences=False, returns a single 2D tensor of shape (B, lstm_units * len(scales)).

  • If return_sequences=True, returns a list of 3D tensors, each with shape (B, T’, lstm_units), where T’ depends on the scale sub-sampling.

Return type:

tf.Tensor or list of tf.Tensor

get_config()[source]

Returns a config dictionary containing ‘lstm_units’, ‘scales’, and ‘return_sequences’.

Returns:

Configuration dictionary.

Return type:

dict

classmethod from_config(config)[source]

Builds MultiScaleLSTM from the given config dictionary.

Parameters:

config (dict) – Must include ‘lstm_units’, ‘scales’, ‘return_sequences’.

Returns:

A new instance of this layer.

Return type:

MultiScaleLSTM
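The sub-sampling and output shapes documented above can be checked with a small NumPy sketch. It assumes sub-sampling is plain strided slicing starting at the first time step (`x[:, ::scale, :]`); the library may choose a different offset. `subsample` and `multiscale_output_shape` are hypothetical helpers, and no LSTM is actually run.

```python
import numpy as np


def subsample(x, scale):
    # Keep every `scale`-th time step from t=0: (B, T, D) -> (B, ceil(T / scale), D).
    return x[:, ::scale, :]


def multiscale_output_shape(input_shape, lstm_units, scales, return_sequences):
    """Expected output shape(s), assuming one LSTM per scale as documented."""
    b, t, _ = input_shape
    if return_sequences:
        # One (B, T', lstm_units) tensor per scale, with T' = ceil(T / scale).
        return [(b, -(-t // s), lstm_units) for s in scales]
    # Last hidden states are concatenated along the feature axis.
    return (b, lstm_units * len(scales))


x = np.random.default_rng(0).normal(size=(32, 20, 16))
sub_shapes = [subsample(x, s).shape for s in (1, 2, 3)]
```

With T=20, scales 1, 2 and 3 keep 20, 10 and 7 time steps, which is why the per-scale sequence outputs can have different lengths when return_sequences=True.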

Cross-Attention

class base_attentive.components.CrossAttention(units, num_heads)[source]

Bases: Layer, NNLearner

Cross-attention layer in which source1 (the query) attends to source2 (the keys and values), with optional masks.

attention_mask : Tensor, optional

Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed directly to Keras MultiHeadAttention.

query_mask, value_mask : Tensor, optional

1D/2D masks (B, Tq) or (B, Tv). If provided and attention_mask is None, they are combined to form (B, Tq, Tv).

use_causal_mask : bool

Forwarded to MultiHeadAttention. Defaults to False.

call(inputs, training=False, *, attention_mask=None, query_mask=None, value_mask=None, use_causal_mask=False, **kwargs)[source]

Forward pass of CrossAttention.

Parameters:
  • inputs (list of tf.Tensor) – A list [source1, source2], each of shape (batch_size, time_steps, features).

  • training (bool, optional) – Indicates if the layer is in training mode (for dropout, if any). Defaults to False.

  • attention_mask (Tensor, optional) – Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed directly to Keras MultiHeadAttention.

  • query_mask (Tensor, optional) – 1D/2D query mask, typically of shape (B, Tq). If provided and attention_mask is None, the query and value masks are combined to form a (B, Tq, Tv) mask.

  • value_mask (Tensor, optional) – 1D/2D value mask, typically of shape (B, Tv). If provided and attention_mask is None, the query and value masks are combined to form a (B, Tq, Tv) mask.

  • use_causal_mask (bool) – Forwarded to MultiHeadAttention. Defaults to False.

Returns:

A tensor of shape (batch_size, time_steps, units) containing the cross-attended features, where time_steps is the length of source1 (the query).

Return type:

tf.Tensor

get_config()[source]
classmethod from_config(config)[source]
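The rule that query_mask and value_mask are combined into a (B, Tq, Tv) mask when attention_mask is None can be sketched with NumPy broadcasting. This assumes the combination is a logical AND and that True means "may attend" (the Keras MultiHeadAttention convention); `combine_query_value_masks` is a hypothetical helper, not part of the package.

```python
import numpy as np


def combine_query_value_masks(query_mask, value_mask):
    """Hypothetical helper: (B, Tq) and (B, Tv) masks -> (B, Tq, Tv).

    Assumes True (or 1) marks a valid position that may take part in attention.
    """
    q = np.asarray(query_mask, dtype=bool)[:, :, None]  # (B, Tq, 1)
    v = np.asarray(value_mask, dtype=bool)[:, None, :]  # (B, 1, Tv)
    # Pair (i, j) is allowed only if query step i and value step j are both valid.
    return q & v


query_mask = [[1, 1, 0]]  # B=1, Tq=3: the last query step is padding
value_mask = [[1, 0]]     # B=1, Tv=2: the last value step is padding
mask = combine_query_value_masks(query_mask, value_mask)
```

Each row of `mask[0]` corresponds to one query step, so a padded query step gets an all-False row and a padded value step is blocked in every row.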

Hierarchical Attention

class base_attentive.components.HierarchicalAttention(units, num_heads)[source]

Bases: Layer, NNLearner

Multi-head attention (MHA) over short-term and long-term inputs, with an optional mask for each.

call(inputs, training=False, *, short_mask=None, long_mask=None, use_causal_mask=False, **kwargs)[source]
Parameters:
  • training (bool)

  • short_mask (Any | None)

  • long_mask (Any | None)

  • use_causal_mask (bool)

get_config()[source]
classmethod from_config(config)[source]

Memory-Augmented Attention

class base_attentive.components.MemoryAugmentedAttention(units, memory_size=1, num_heads=1)[source]

Bases: Layer, NNLearner

Memory-augmented attention with optional masking.

Parameters:
  • units (int)

  • memory_size (int)

  • num_heads (int)

build(input_shape)[source]
call(inputs, training=False, *, attention_mask=None, query_mask=None, value_mask=None, use_causal_mask=False, **kwargs)[source]
Parameters:
  • training (bool)

  • attention_mask (Any | None)

  • query_mask (Any | None)

  • value_mask (Any | None)

  • use_causal_mask (bool)

get_config()[source]
classmethod from_config(config)[source]

Transformer Encoder Layer

class base_attentive.components.TransformerEncoderLayer(embed_dim=None, num_heads=1, ffn_dim=None, dropout_rate=0.1, ffn_activation='relu', layer_norm_epsilon=1e-06, *, units=None, **kwargs)[source]

Bases: Layer, NNLearner

A single layer of the Transformer Encoder.

Parameters:
  • embed_dim (int | None)

  • num_heads (int)

  • ffn_dim (int | None)

  • dropout_rate (float)

  • ffn_activation (str)

  • layer_norm_epsilon (float)

  • units (int | None)

call(x, training=False, attention_mask=None)[source]
Return type:

ndarray

get_config()[source]
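To make the roles of embed_dim, num_heads, ffn_dim and layer_norm_epsilon concrete, here is a minimal NumPy sketch of a standard post-norm Transformer encoder block: self-attention with a residual add and layer norm, then a ReLU feed-forward network with another residual add and layer norm. It is an illustration under stated assumptions, not the library's code: weights are random, dropout is omitted because it only applies in training, the layer norm has no learned scale or shift, and the library may use a different normalization order (for example pre-norm).

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize over the feature axis (no learned scale/shift in this sketch).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(x, num_heads, rng, mask=None):
    """Multi-head self-attention on x of shape (B, T, D)."""
    b, t, d = x.shape
    assert d % num_heads == 0, "embed_dim must be divisible by num_heads"
    dh = d // num_heads
    wq, wk, wv, wo = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(4))

    def split_heads(h):  # (B, T, D) -> (B, heads, T, dh)
        return h.reshape(b, t, num_heads, dh).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh)  # (B, heads, T, T)
    if mask is not None:  # mask: (B, T, T), True = may attend
        scores = np.where(mask[:, None, :, :], scores, -1e9)
    heads = softmax(scores) @ v  # (B, heads, T, dh)
    return heads.transpose(0, 2, 1, 3).reshape(b, t, d) @ wo


def encoder_layer(x, num_heads, ffn_dim, rng, layer_norm_epsilon=1e-6, mask=None):
    d = x.shape[-1]
    # Sub-layer 1: self-attention + residual + layer norm (post-norm).
    x = layer_norm(x + self_attention(x, num_heads, rng, mask), layer_norm_epsilon)
    # Sub-layer 2: position-wise ReLU feed-forward network + residual + layer norm.
    w1 = rng.normal(scale=d ** -0.5, size=(d, ffn_dim))
    w2 = rng.normal(scale=ffn_dim ** -0.5, size=(ffn_dim, d))
    ffn = np.maximum(x @ w1, 0.0) @ w2
    return layer_norm(x + ffn, layer_norm_epsilon)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 8))  # (B, T, embed_dim)
out = encoder_layer(x, num_heads=2, ffn_dim=16, rng=rng)
```

The output keeps the input shape (B, T, embed_dim), which is what allows several encoder layers to be stacked.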

Transformer Decoder Layer

class base_attentive.components.TransformerDecoderLayer(embed_dim=None, num_heads=1, ffn_dim=None, dropout_rate=0.1, ffn_activation='relu', layer_norm_epsilon=1e-06, *, units=None, **kwargs)[source]

Bases: Layer, NNLearner

A single layer of the Transformer Decoder. Its constructor arguments are similar to those of TransformerEncoderLayer.

Parameters:
  • embed_dim (int | None)

  • num_heads (int)

  • ffn_dim (int | None)

  • dropout_rate (float)

  • ffn_activation (str)

  • layer_norm_epsilon (float)

  • units (int | None)

call(x, enc_output=None, training=False, look_ahead_mask=None, padding_mask=None)[source]
Return type:

ndarray

get_config()[source]
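The decoder's call accepts a look_ahead_mask and a padding_mask. The sketch below shows one common way to build and combine such masks with NumPy. It assumes the Keras convention that True means "may attend"; some Transformer implementations use the opposite convention (1 = blocked), so check which one TransformerDecoderLayer expects before passing masks in. All function names here are hypothetical.

```python
import numpy as np


def look_ahead_mask(t):
    # (T, T) lower-triangular: query step i may attend to key steps j <= i.
    return np.tril(np.ones((t, t), dtype=bool))


def padding_mask(lengths, t):
    # (B, T): True for real steps, False for padding beyond each sequence length.
    return np.arange(t)[None, :] < np.asarray(lengths)[:, None]


def decoder_self_attention_mask(lengths, t):
    # Combine causal and key-padding constraints -> (B, T, T).
    causal = look_ahead_mask(t)[None, :, :]         # (1, T, T)
    keys_ok = padding_mask(lengths, t)[:, None, :]  # (B, 1, T)
    return causal & keys_ok


mask = decoder_self_attention_mask(lengths=[2, 3], t=3)
```

For the first sequence (length 2), the last query step can still see steps 0 and 1 but not the padded step 2; the second sequence has no padding, so only the causal constraint applies.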

Multi-Decoder

class base_attentive.components.MultiDecoder(output_dim=None, num_horizons=None, *, units=None, num_heads=None)[source]

Bases: Layer, NNLearner

MultiDecoder for multi-horizon forecasting [1].

This layer takes a single feature vector per example of shape \((B, F)\) and produces a separate output for each horizon step, resulting in \((B, H, O)\).

\[\mathbf{Y}_h = \text{Dense}_h(\mathbf{x}), \quad h \in \{1, \dots, H\}\]

Each horizon has its own decoder layer.

Parameters:
  • output_dim (int) – Number of output features for each horizon.

  • num_horizons (int) – Number of forecast horizons.

  • units (int | None)

  • num_heads (int | None)

Notes

This layer is particularly useful when you want separate parameters for each horizon, instead of a single shared head.

call(x, training=False)[source]

Forward pass that produces horizon-specific outputs.

get_config()[source]

Returns configuration for serialization.

from_config(config)[source]

Builds a new instance from config.

Examples

>>> from base_attentive.components import MultiDecoder
>>> import tensorflow as tf
>>> # Input of shape (batch_size, feature_dim)
>>> x = tf.random.normal((32, 128))
>>> # Instantiate multi-horizon decoder
>>> decoder = MultiDecoder(output_dim=1, num_horizons=3)
>>> # Output shape => (32, 3, 1)
>>> y = decoder(x)

See also

MultiModalEmbedding

Provides feature embeddings that can be fed into MultiDecoder.

QuantileDistributionModeling

Projects deterministic outputs into multiple quantiles per horizon.

References

__init__(output_dim=None, num_horizons=None, *, units=None, num_heads=None)[source]

Initialize the MultiDecoder.

Parameters:
  • output_dim (int) – Number of features each horizon decoder should output.

  • num_horizons (int) – Number of horizons to predict, each with its own Dense layer.

  • units (int | None)

  • num_heads (int | None)

call(x, training=False)[source]

Forward pass: each horizon has a separate Dense layer.

Parameters:
  • x (tf.Tensor) – A 2D tensor (B, F).

  • training (bool, optional) – Unused in this layer. Defaults to False.

Returns:

A 3D tensor of shape (B, H, O).

Return type:

tf.Tensor

get_config()[source]

Returns layer configuration for serialization.

Returns:

Dictionary containing ‘output_dim’ and ‘num_horizons’.

Return type:

dict

classmethod from_config(config)[source]

Create a new MultiDecoder from the config.

Parameters:

config (dict) – Contains ‘output_dim’, ‘num_horizons’.

Returns:

A new instance.

Return type:

MultiDecoder
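The equation Y_h = Dense_h(x) can be illustrated with a NumPy sketch in which each horizon has its own independent kernel and bias, and the per-horizon outputs are stacked into (B, H, O). This assumes each Dense_h is a plain linear layer with no activation; `multi_decoder` and the explicit weight lists are hypothetical and only mirror the documented shapes.

```python
import numpy as np


def multi_decoder(x, kernels, biases):
    """x: (B, F); kernels: H arrays of shape (F, O); biases: H arrays of shape (O,)."""
    # One independent affine map per horizon: Y_h = x @ W_h + b_h, shape (B, O).
    per_horizon = [x @ w + b for w, b in zip(kernels, biases)]
    # Stack along a new horizon axis -> (B, H, O).
    return np.stack(per_horizon, axis=1)


rng = np.random.default_rng(0)
B, F, H, O = 32, 128, 3, 1
x = rng.normal(size=(B, F))
kernels = [rng.normal(size=(F, O)) for _ in range(H)]
biases = [np.zeros(O) for _ in range(H)]
y = multi_decoder(x, kernels, biases)  # (32, 3, 1), as in the example above
```

Keeping a separate kernel per horizon is what distinguishes this design from a single shared head: the H slices of `y` are produced by unrelated parameters.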