Source code for base_attentive.components.temporal

# SPDX-License-Identifier: Apache-2.0
# Author: LKouadio <etanoyau@gmail.com>
# Adapted from: earthai-tech/fusionlab-learn https://github.com/earthai-tech/gofast
# Modified for GeoPrior-v3 API

"""
Temporal modules:
    - MultiScaleLSTM
    - DynamicTimeWindow
    - aggregate_multiscale
    - aggregate_multiscale_on_3d
    - aggregate_time_window_output
"""

from __future__ import annotations

from ..api.property import NNLearner
from ..core.checks import validate_nested_param
from ..utils.deps_utils import ensure_pkg
from ._config import (
    DEP_MSG,
    KERAS_BACKEND,
    LSTM,
    Layer,
    concat,
    do_not_convert,
    register_keras_serializable,
)

__all__ = [
    "MultiScaleLSTM",
    "DynamicTimeWindow",
]
SERIALIZATION_PACKAGE = __name__


@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="MultiScaleLSTM"
)
class MultiScaleLSTM(Layer, NNLearner):
    r"""
    Bank of LSTMs, one per temporal sampling scale [1]_.

    The input is strided along the time axis once per entry of
    `scales` (a scale of 2 keeps every 2nd time step) and each strided
    view is fed to its own LSTM. Depending on `return_sequences`, the
    per-scale results are either concatenated along the feature axis
    or returned as a list.

    Parameters
    ----------
    lstm_units : int
        Number of units in each LSTM.
    scales : list of int or str or None, optional
        Scale factors used as time-axis strides. ``'auto'`` or
        ``None`` falls back to ``[1]`` (no sub-sampling).
    return_sequences : bool, optional
        If True, each LSTM returns its whole output sequence;
        otherwise only the last hidden state. Defaults to False.
    units : int, optional
        Keyword-only alias used when `lstm_units` is not given.
    **kwargs
        Additional arguments passed to the parent Keras `Layer`.

    Raises
    ------
    ValueError
        If neither `lstm_units` nor `units` is provided.

    Notes
    -----
    - With `return_sequences=False` the output is a single tensor
      concatenated along features:
      :math:`(B, \text{units} \times \text{num\_scales})`.
    - With `return_sequences=True` the output is a list of sequence
      tensors. Their time dimensions may differ when scales differ.

    Methods
    -------
    call(`inputs`, training=False)
        Forward pass, applying each LSTM at its scale.
    get_config()
        Returns the layer's configuration dict.
    from_config(`config`)
        Builds the layer from the config dict.

    Examples
    --------
    >>> from geoprior.nn.components import MultiScaleLSTM
    >>> import tensorflow as tf
    >>> x = tf.random.normal((32, 20, 16))  # (B, T, D)
    >>> mslstm = MultiScaleLSTM(
    ...     lstm_units=32,
    ...     scales=[1, 2],
    ...     return_sequences=False,
    ... )
    >>> y = mslstm(x)  # shape => (32, 64)
    >>> # scale=1 and scale=2 each yield 32 units => 64 after concat

    See Also
    --------
    DynamicTimeWindow
        For slicing sequences before applying multi-scale LSTMs.
    TemporalFusionTransformer
        A complex model that can incorporate multi-scale modules.

    References
    ----------
    .. [1] Lim, B., & Zohren, S. (2021). "Time-series forecasting with
           deep learning: a survey." *Philosophical Transactions of
           the Royal Society A*, 379(2194), 20200209.
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        lstm_units: int | None = None,
        scales: str | list[int] | None = None,
        return_sequences: bool = False,
        *,
        units: int | None = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # `units` only matters when `lstm_units` was not supplied.
        resolved_units = units if lstm_units is None else lstm_units
        if resolved_units is None:
            raise ValueError(
                "Provide `lstm_units` or `units`."
            )

        # 'auto' / None mean a single, un-strided LSTM.
        requested_scales = (
            [1] if scales is None or scales == "auto" else scales
        )
        validated_scales = validate_nested_param(
            requested_scales, list[int], "scales"
        )

        self.lstm_units = resolved_units
        self.scales = validated_scales
        self.return_sequences = return_sequences

        # One independent LSTM per scale, all with the same width.
        self.lstm_layers = [
            LSTM(
                resolved_units,
                return_sequences=return_sequences,
            )
            for _ in validated_scales
        ]

    @do_not_convert
    def call(self, inputs, training=False):
        r"""
        Apply every per-scale LSTM to a strided view of `inputs`.

        Parameters
        ----------
        ``inputs`` : tf.Tensor
            Shape (B, T, D).
        training : bool, optional
            Training mode, forwarded to each LSTM. Defaults to
            ``False``.

        Returns
        -------
        tf.Tensor or list of tf.Tensor
            - `return_sequences=False`: one 2D tensor of shape
              (B, lstm_units * len(scales)).
            - `return_sequences=True`: a list of 3D tensors, each of
              shape (B, T', lstm_units), where T' depends on the
              scale sub-sampling.
        """
        # aten::linalg_qr is not implemented on macOS MPS; opting in to
        # the PyTorch CPU fallback lets such ops run on CPU instead.
        # setdefault keeps any value the user already exported.
        device = getattr(inputs, "device", None)
        if device is not None and str(device).startswith("mps"):
            import os

            os.environ.setdefault(
                "PYTORCH_ENABLE_MPS_FALLBACK", "1"
            )

        # Stride the time axis by `scale`, then run the paired LSTM.
        per_scale = [
            lstm(inputs[:, ::scale, :], training=training)
            for scale, lstm in zip(self.scales, self.lstm_layers)
        ]

        if self.return_sequences:
            # Time lengths may differ per scale, so keep them apart.
            return per_scale

        # Each entry is (B, units) -> (B, units * len(scales)).
        return concat(per_scale, axis=-1)

    def get_config(self):
        r"""
        Return the parent config extended with this layer's settings.

        Returns
        -------
        dict
            Configuration dictionary including 'lstm_units', 'scales'
            and 'return_sequences'.
        """
        config = super().get_config().copy()
        config["lstm_units"] = self.lstm_units
        config["scales"] = self.scales
        config["return_sequences"] = self.return_sequences
        return config

    @classmethod
    def from_config(cls, config):
        r"""
        Build a MultiScaleLSTM from a configuration dictionary.

        Parameters
        ----------
        ``config`` : dict
            Must include 'lstm_units', 'scales', 'return_sequences'.
            Every entry is passed to the constructor as a keyword.

        Returns
        -------
        MultiScaleLSTM
            A new instance of this layer.
        """
        return cls(**config)
@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="DynamicTimeWindow"
)
class DynamicTimeWindow(Layer, NNLearner):
    r"""
    DynamicTimeWindow layer that slices the last `max_window_size`
    steps from the input sequence.

    This helps in focusing on the most recent time steps if the
    sequence is longer than `max_window_size`.

    .. math::
        \mathbf{Z} = \mathbf{X}[:, -W:, :]

    where `W` = `max_window_size`.

    Parameters
    ----------
    max_window_size : int, optional
        Number of time steps to keep from the end of the sequence.
        When omitted, `units` is used; if that is also omitted the
        window defaults to 1.
    units : int, optional
        Keyword-only alias used when `max_window_size` is not given.
    **kwargs
        Additional arguments passed to the parent Keras `Layer`
        (e.g. ``name``, ``dtype``, ``trainable``). Accepting them is
        required for ``from_config(layer.get_config())`` to work,
        because `get_config` includes the parent's entries.

    Notes
    -----
    This can be used for models that only need the last few time
    steps instead of the entire sequence.

    The slice is ``inputs[:, -W:, :]``. With a plain Python ``W = 0``
    the start index ``-0`` equals ``0``, so the whole sequence is
    kept rather than an empty one.

    Methods
    -------
    call(`inputs`, training=False)
        Slice the last `max_window_size` steps.
    get_config()
        Returns configuration dictionary.
    from_config(`config`)
        Recreates the layer from config.

    Examples
    --------
    >>> from geoprior.nn.components import DynamicTimeWindow
    >>> import tensorflow as tf
    >>> x = tf.random.normal((32, 50, 64))
    >>> # Keep last 10 time steps
    >>> dtw = DynamicTimeWindow(max_window_size=10)
    >>> y = dtw(x)
    >>> y.shape
    TensorShape([32, 10, 64])

    See Also
    --------
    MultiResolutionAttentionFusion
        Another layer that can be used after slicing to fuse
        temporal features.

    References
    ----------
    .. [1] Lim, B., & Zohren, S. (2021). "Time-series forecasting with
           deep learning: a survey." *Philosophical Transactions of
           the Royal Society A*, 379(2194), 20200209.
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        max_window_size: int | None = None,
        *,
        units: int | None = None,
        **kwargs,
    ):
        r"""
        Initialize the DynamicTimeWindow layer.

        Parameters
        ----------
        max_window_size : int, optional
            Number of steps to slice from the end of the sequence.
        units : int, optional
            Fallback for `max_window_size` when it is ``None``.
        **kwargs
            Forwarded to the parent Keras `Layer` constructor.
        """
        # Forward base-layer options (name, dtype, ...). Without this,
        # the config produced by get_config() cannot be fed back into
        # the constructor by from_config().
        super().__init__(**kwargs)

        if max_window_size is None:
            # Fall back to the alias, then to a one-step window.
            max_window_size = (
                units if units is not None else 1
            )
        self.max_window_size = max_window_size

    def call(self, inputs, training=False):
        r"""
        Forward pass that slices the last `max_window_size` steps.

        Parameters
        ----------
        ``inputs`` : tf.Tensor
            Tensor of shape :math:`(B, T, D)`.
        training : bool, optional
            Unused. Defaults to ``False``.

        Returns
        -------
        tf.Tensor
            A sliced tensor of shape :math:`(B, W, D)`
            where W = `max_window_size`.
        """
        return inputs[:, -self.max_window_size :, :]

    def get_config(self):
        r"""
        Returns configuration dictionary.

        Returns
        -------
        dict
            The parent layer's config extended with
            'max_window_size'.
        """
        config = super().get_config().copy()
        config.update(
            {"max_window_size": self.max_window_size}
        )
        return config

    @classmethod
    def from_config(cls, config):
        r"""
        Creates a new DynamicTimeWindow layer from config.

        Parameters
        ----------
        ``config`` : dict
            Must include 'max_window_size'. Any remaining entries
            are passed to the constructor as keyword arguments.

        Returns
        -------
        DynamicTimeWindow
            A new instance of this layer.
        """
        return cls(**config)