# Source code for base_attentive.config.schema

"""Schema objects for BaseAttentive configuration."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class BaseAttentiveArchitectureSpec:
    """Logical architecture choices for BaseAttentive models.

    Immutable value object; every field has a default, so it can be
    built with no arguments. Field order is the positional constructor
    order and must stay as declared.

    Attributes
    ----------
    encoder_type
        Encoder selector string. Defaults to ``"transformer"``.
    decoder_attention_stack
        Tuple of attention-layer names for the decoder. Defaults to
        the single-element tuple ``("cross",)``.
    feature_processing
        Feature-processing selector string. Defaults to ``"dense"``.
    """

    # Selector for the encoder family.
    encoder_type: str = "transformer"

    # Decoder attention layers; a tuple keeps the frozen instance hashable.
    decoder_attention_stack: tuple[str, ...] = ("cross",)

    # Selector for how input features are processed.
    feature_processing: str = "dense"


@dataclass(frozen=True)
class BaseAttentiveRuntimeSpec:
    """Runtime behavior shared across legacy and V2 paths.

    Immutable value object; every field has a default, so it can be
    built with no arguments. Field order is the positional constructor
    order and must stay as declared.
    """

    # Optional mode selector; ``None`` means no mode was given.
    mode: str | None = None

    # Integer sizing knobs.
    num_encoder_layers: int = 2
    max_window_size: int = 10
    memory_size: int = 100

    # Either a tuple of ints or a string; the empty tuple is the default.
    scales: tuple[int, ...] | str = ()

    # Aggregation selectors; both default to ``"last"``.
    multi_scale_agg: str = "last"
    final_agg: str = "last"

    # Boolean feature switches.
    use_residuals: bool = True
    use_batch_norm: bool = False
    apply_dtw: bool = True

    # Verbosity level; ``0`` is the default.
    verbose: int = 0


@dataclass(frozen=True)
class BaseAttentiveComponentSpec:
    """Logical component selections for resolver-driven models.

    Immutable value object whose fields are all strings. Each default
    is a dotted ``"<group>.<name>"`` key (for example
    ``"projection.static"``). How the keys are resolved is not defined
    in this class -- presumably a component registry elsewhere consumes
    them; verify against the resolver.

    Every field has a default, so the spec can be built with no
    arguments. Field order is the positional constructor order and must
    stay as declared.
    """

    # Input projections.
    static_projection: str = "projection.static"
    dynamic_projection: str = "projection.dynamic"
    future_projection: str = "projection.future"

    # Encoders; dynamic and future inputs share the same default key.
    dynamic_encoder: str = "encoder.temporal_self_attention"
    future_encoder: str = "encoder.temporal_self_attention"

    # Pooling, fusion and hidden projection.
    sequence_pooling: str = "pool.mean"
    fusion: str = "fusion.concat"
    hidden_projection: str = "projection.hidden"

    # Output heads.
    point_head: str = "head.point_forecast"
    quantile_head: str = "head.quantile_forecast"

    # Planned legacy-to-resolver keys.
    static_processor: str = "feature.static_processor"
    dynamic_processor: str = "feature.dynamic_processor"
    future_processor: str = "feature.future_processor"
    positional_encoder: str = "embedding.positional"
    hybrid_encoder: str = "encoder.hybrid_multiscale"
    dynamic_window: str = "encoder.dynamic_window"
    decoder_cross_attention: str = "decoder.cross_attention"
    decoder_hierarchical_attention: str = (
        "decoder.hierarchical_attention"
    )
    decoder_memory_attention: str = "decoder.memory_attention"
    decoder_fusion: str = "fusion.multi_resolution_attention"
    multi_horizon_head: str = "head.multi_horizon"
    quantile_distribution_head: str = (
        "head.quantile_distribution"
    )
    final_pool_last: str = "pool.final_last"
    final_pool_mean: str = "pool.final_mean"
    final_pool_flatten: str = "pool.final_flatten"


@dataclass(frozen=True)
class BaseAttentiveSpec:
    """Backend-neutral configuration for BaseAttentive models.

    The three input dimensions are required and come first in the
    constructor; every other field has a default. Architecture, runtime
    and component choices live in nested spec objects, each created
    fresh per instance through ``default_factory``. The read-only
    properties below forward to those nested specs so callers can read
    the values directly from this object.

    Notes
    -----
    ``frozen=True`` only blocks rebinding attributes. The ``extras``
    dict itself stays mutable, and because a dict is unhashable,
    calling ``hash()`` on an instance raises ``TypeError``.
    """

    # Required input dimensions (no defaults).
    static_input_dim: int
    dynamic_input_dim: int
    future_input_dim: int

    # Output and model sizing.
    output_dim: int = 1
    forecast_horizon: int = 1
    embed_dim: int = 32
    hidden_units: int = 64
    attention_heads: int = 4
    layer_norm_epsilon: float = 1e-6
    dropout_rate: float = 0.0
    activation: str = "relu"
    backend_name: str = "tensorflow"
    head_type: str = "point"
    quantiles: tuple[float, ...] = ()

    # Legacy-compatible scalar fields.
    lstm_units: int | tuple[int, ...] = 64
    attention_units: int = 32
    vsn_units: int | None = None

    # Nested specs; default_factory gives each instance its own object.
    architecture: BaseAttentiveArchitectureSpec = field(
        default_factory=BaseAttentiveArchitectureSpec
    )
    runtime: BaseAttentiveRuntimeSpec = field(
        default_factory=BaseAttentiveRuntimeSpec
    )
    components: BaseAttentiveComponentSpec = field(
        default_factory=BaseAttentiveComponentSpec
    )

    # Free-form extra settings; a new empty dict per instance.
    extras: dict[str, Any] = field(default_factory=dict)

    @property
    def num_heads(self) -> int:
        """Legacy alias for ``attention_heads``."""
        return self.attention_heads

    # --- Forwarders to ``self.runtime`` -----------------------------

    @property
    def num_encoder_layers(self) -> int:
        """Return ``runtime.num_encoder_layers``."""
        return self.runtime.num_encoder_layers

    @property
    def mode(self) -> str | None:
        """Return ``runtime.mode``."""
        return self.runtime.mode

    @property
    def max_window_size(self) -> int:
        """Return ``runtime.max_window_size``."""
        return self.runtime.max_window_size

    @property
    def memory_size(self) -> int:
        """Return ``runtime.memory_size``."""
        return self.runtime.memory_size

    @property
    def scales(self) -> tuple[int, ...] | str:
        """Return ``runtime.scales``."""
        return self.runtime.scales

    @property
    def multi_scale_agg(self) -> str:
        """Return ``runtime.multi_scale_agg``."""
        return self.runtime.multi_scale_agg

    @property
    def final_agg(self) -> str:
        """Return ``runtime.final_agg``."""
        return self.runtime.final_agg

    @property
    def use_residuals(self) -> bool:
        """Return ``runtime.use_residuals``."""
        return self.runtime.use_residuals

    @property
    def use_batch_norm(self) -> bool:
        """Return ``runtime.use_batch_norm``."""
        return self.runtime.use_batch_norm

    @property
    def apply_dtw(self) -> bool:
        """Return ``runtime.apply_dtw``."""
        return self.runtime.apply_dtw

    @property
    def verbose(self) -> int:
        """Return ``runtime.verbose``."""
        return self.runtime.verbose

    # --- Forwarders to ``self.architecture`` ------------------------

    @property
    def objective(self) -> str:
        """Return ``architecture.encoder_type`` under the name ``objective``."""
        return self.architecture.encoder_type

    @property
    def use_vsn(self) -> bool:
        """Return True when ``architecture.feature_processing`` is ``"vsn"``."""
        return self.architecture.feature_processing == "vsn"

    @property
    def attention_levels(self) -> tuple[str, ...]:
        """Return ``architecture.decoder_attention_stack``."""
        return self.architecture.decoder_attention_stack


# Names exported by ``from base_attentive.config.schema import *``.
__all__ = [
    "BaseAttentiveArchitectureSpec",
    "BaseAttentiveComponentSpec",
    "BaseAttentiveRuntimeSpec",
    "BaseAttentiveSpec",
]