Source code for base_attentive.components.attention

# SPDX-License-Identifier: Apache-2.0
# Author: LKouadio <etanoyau@gmail.com>
# Adapted from: earthai-tech/fusionlab-learn — https://github.com/earthai-tech/gofast
# Modified for GeoPrior-v3 API.

"""
Attention-centric layers for FusionLab.
"""

from __future__ import annotations

from numbers import Integral, Real

from ..api.property import NNLearner
from ..compat.sklearn import Interval, validate_params
from ..compat.types import TensorLike
from ..utils.deps_utils import ensure_pkg
from ._config import (
    DEP_MSG,
    KERAS_BACKEND,
    Dense,
    Dropout,
    Layer,
    LayerNormalization,
    MultiHeadAttention,
    _logger,
    add,
    bool_dtype,
    cast,
    do_not_convert,
    expand_dims,
    logical_and,
    ones,
    ones_like,
    register_keras_serializable,
    shape,
    tile,
)
from .gating_norm import GatedResidualNetwork
from .misc import Activation

# Public API of this module. The legacy ``CrossAttention_`` and
# ``HierarchicalAttention_`` classes defined below are not listed here
# (presumably kept only so previously saved models can be deserialised).
__all__ = [
    "TemporalAttentionLayer",
    "CrossAttention",
    "MemoryAugmentedAttention",
    "HierarchicalAttention",
    "ExplainableAttention",
    "MultiResolutionAttentionFusion",
]
# Package key passed to ``register_keras_serializable`` for every layer
# in this module: the module's own dotted import path.
SERIALIZATION_PACKAGE = __name__


@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="TemporalAttentionLayer"
)
class TemporalAttentionLayer(Layer):
    """Temporal attention layer whose query is conditioned on a context.

    Forward pipeline (see :meth:`call`):

    1. If a static context vector is available, it is passed through
       ``context_grn``, expanded over the time axis and added to the
       query.
    2. Multi-head attention of the (conditioned) query over the
       attention source, then dropout, a residual connection with
       ``inputs`` and ``LayerNormalization``.
    3. A position-wise :class:`GatedResidualNetwork` (``output_grn``).

    Parameters
    ----------
    units : int
        Key dimension of the multi-head attention and output size of
        both internal GRNs. Must be >= 1.
    num_heads : int
        Number of attention heads. Must be >= 1.
    dropout_rate : float, default=0.0
        Dropout rate in ``[0, 1]``. It is used by the attention, by the
        extra dropout layer and by both GRNs.
    activation : str, default="elu"
        Activation forwarded to both GRNs, normalised through
        :class:`Activation`.
    use_batch_norm : bool, default=False
        Forwarded to both GRNs.
    **kwargs
        Standard Keras ``Layer`` keyword arguments (``name``, ...).

    Notes
    -----
    The residual ``add(inputs, ...)`` and the context addition require
    the last dimension of ``inputs`` to be broadcast-compatible with the
    attention / context outputs (typically equal to ``units``).
    """

    @validate_params(
        {
            # A size of 0 is not meaningful for keys or heads.
            "units": [
                Interval(Integral, 1, None, closed="left")
            ],
            "num_heads": [
                Interval(Integral, 1, None, closed="left")
            ],
            "dropout_rate": [
                Interval(Real, 0, 1, closed="both")
            ],
            "use_batch_norm": [bool],
        }
    )
    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        units: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        activation: str = "elu",
        use_batch_norm: bool = False,
        **kwargs,
    ):
        """Create the attention, normalisation and GRN sub-layers."""
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm
        # Canonical activation name; also what get_config() serialises.
        self.activation_str = Activation(
            activation
        ).activation_str

        self.multi_head_attention = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=units,
            dropout=dropout_rate,
            name="mha",
        )
        self.dropout = Dropout(
            dropout_rate, name="attn_dropout"
        )
        self.layer_norm1 = LayerNormalization(
            name="layer_norm_1"
        )

        # GRN that projects the static context vector to ``units``.
        self.context_grn = GatedResidualNetwork(
            units=units,
            dropout_rate=dropout_rate,
            activation=self.activation_str,
            use_batch_norm=self.use_batch_norm,
            name="context_grn",
        )

        # Final position-wise feed-forward GRN.
        self.output_grn = GatedResidualNetwork(
            units=units,
            dropout_rate=dropout_rate,
            activation=self.activation_str,
            use_batch_norm=self.use_batch_norm,
            name="output_grn",
        )

    @staticmethod
    def _is_single_shape(shape_like) -> bool:
        """Return True if *shape_like* is one tensor shape, not a list."""
        # A single shape contains only ints / None. A list of shapes
        # contains nested tuples or TensorShapes.
        return all(
            dim is None or isinstance(dim, Integral)
            for dim in tuple(shape_like)
        )

    def build(self, input_shape):
        """Build the sub-layers whose input shapes are known up front.

        ``input_shape`` is either the shape of ``inputs`` (e.g.
        ``(B, T, U)``) or a pair ``[inputs_shape, secondary_shape]``.
        Depending on the backend, a single shape may arrive as a plain
        ``tuple``. The two cases are therefore told apart by their
        element types rather than by the container type.

        Raises
        ------
        ValueError
            If the format is unrecognised or the main input has rank < 3.
        """
        if self._is_single_shape(input_shape):
            main_input_shape = tuple(input_shape)
        elif len(input_shape) == 2:
            main_input_shape = tuple(input_shape[0])
            secondary_shape = tuple(input_shape[1])
            # ``call`` only feeds a rank-2 secondary input to
            # ``context_grn``. A rank-3 one is an attention source.
            if (
                len(secondary_shape) == 2
                and not self.context_grn.built
            ):
                self.context_grn.build(secondary_shape)
        else:
            raise ValueError(
                "Unexpected input_shape format for build."
            )

        if len(main_input_shape) < 3:
            raise ValueError(
                "TemporalAttentionLayer expects input rank >= 3"
            )

        # output_grn consumes the layer-norm output, which has the same
        # shape as the main input.
        if not self.output_grn.built:
            self.output_grn.build(main_input_shape)

        # context_grn (when not built above) is built lazily on call.
        super().build(input_shape)

    def call(
        self, inputs, context_vector=None, training=False
    ):
        """Forward pass of the temporal attention layer.

        Parameters
        ----------
        inputs : Tensor or list/tuple of two tensors
            Main sequence ``(B, T, U)``, or ``[main, secondary]``. A
            rank-2 ``secondary`` is used as the context vector when
            ``context_vector`` is not given. Otherwise ``secondary`` is
            used as key/value of the attention.
        context_vector : Tensor, optional
            Static context, assumed ``(B, U_ctx)``, that conditions the
            query.
        training : bool, default=False
            Forwarded to the attention, dropout and GRN sub-layers.

        Returns
        -------
        Tensor
            Output of ``output_grn``, expected shape ``(B, T, units)``.
        """
        attention_source = None
        if (
            isinstance(inputs, (list, tuple))
            and len(inputs) == 2
        ):
            inputs, secondary = inputs
            if (
                context_vector is None
                and hasattr(secondary, "shape")
                and len(secondary.shape) == 2
            ):
                context_vector = secondary
            else:
                attention_source = secondary
        # Self-attention by default. This is assigned *after* unpacking
        # so it is never the original [main, context] list.
        if attention_source is None:
            attention_source = inputs

        query = inputs

        if context_vector is not None:
            processed_context = self.context_grn(
                x=context_vector,
                context=None,
                training=training,
            )
            # (B, units) -> (B, 1, units) so it broadcasts over time.
            context_expanded = expand_dims(
                processed_context, axis=1
            )
            query = add(inputs, context_expanded)

        attn_output = self.multi_head_attention(
            query=query,
            value=attention_source,
            key=attention_source,
            training=training,
        )

        attn_output_dropout = self.dropout(
            attn_output, training=training
        )
        # The residual uses the un-conditioned ``inputs``.
        x_attn = self.layer_norm1(
            add(inputs, attn_output_dropout)
        )

        # Position-wise feed-forward. The external context is not
        # re-injected here.
        return self.output_grn(
            x=x_attn,
            context=None,
            training=training,
        )

    def get_config(self):
        """Return the constructor arguments plus the base-layer config."""
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "num_heads": self.num_heads,
                "dropout_rate": self.dropout_rate,
                "activation": self.activation_str,
                "use_batch_norm": self.use_batch_norm,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from :meth:`get_config` output."""
        return cls(**config)


@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="CrossAttention_"
)
class CrossAttention_(Layer, NNLearner):
    r"""
    CrossAttention layer that attends one source
    sequence to another [1]_.

    This layer transforms two input sources,
    ``source1`` and ``source2``, into a shared
    dimensionality via separate dense layers,
    then applies multi-head attention using
    ``source1`` as the query and ``source2`` as
    both key and value. The last output dimension
    equals ``units``.

    .. math::
        \mathbf{H}_{\text{out}} = \text{MHA}(
            \mathbf{W}_{1}\,\mathbf{S}_1,\,
            \mathbf{W}_{2}\,\mathbf{S}_2,\,
            \mathbf{W}_{2}\,\mathbf{S}_2
        )

    where :math:`\mathbf{S}_1` and :math:`\mathbf{S}_2`
    are the two source sequences.

    Parameters
    ----------
    units : int
        Dimensionality for the internal projections
        of the query/key/value in multi-head attention.
    num_heads : int
        Number of attention heads.
    **kwargs
        Standard Keras ``Layer`` keyword arguments
        (``name``, ``dtype``, ``trainable``, ...).

    Notes
    -----
    Cross attention is particularly useful when
    focusing on how one sequence (the query) relates
    to another (the key/value). For example, in
    multi-modal time series settings, one might
    attend dynamic covariates to static ones or
    vice versa.

    Methods
    -------
    call(`inputs`, training=False)
        Forward pass of the cross-attention layer.
    get_config()
        Returns the configuration dictionary for
        serialization.
    from_config(`config`)
        Creates a new layer from the given config.

    Examples
    --------
    >>> from geoprior.nn.components import CrossAttention
    >>> import tensorflow as tf
    >>> # Two sequences of shape (batch_size, time_steps, features)
    >>> source1 = tf.random.normal((32, 10, 64))
    >>> source2 = tf.random.normal((32, 10, 64))
    >>> # Instantiate the CrossAttention layer
    >>> cross_attn = CrossAttention(units=64, num_heads=4)
    >>> # Forward pass
    >>> outputs = cross_attn([source1, source2])

    See Also
    --------
    HierarchicalAttention
        Another attention-based layer focusing on
        short/long-term sequences.
    MemoryAugmentedAttention
        Uses a learned memory matrix to enhance
        representations.

    References
    ----------
    .. [1] Vaswani, A., Shazeer, N., Parmar, N.,
           Uszkoreit, J., Jones, L., Gomez, A. N.,
           Kaiser, L., & Polosukhin, I. (2017).
           "Attention is all you need." In
           *Advances in Neural Information
           Processing Systems* (pp. 5998-6008).
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(self, units: int, num_heads: int, **kwargs):
        r"""
        Initialize the CrossAttention layer.

        Parameters
        ----------
        units : int
            Number of output units for the
            internal Dense projections and
            multi-head attention dimension.
        num_heads : int
            Number of attention heads to use
            in the multi-head attention module.
        **kwargs
            Base ``Layer`` arguments. Accepting them lets
            ``from_config`` round-trip the config produced by
            ``get_config``.
        """
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads
        # Dense layers to project each source
        self.source1_dense = Dense(units)
        self.source2_dense = Dense(units)
        # Multi-head attention
        self.cross_attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )

    @do_not_convert
    def call(self, inputs, training=False):
        r"""
        Forward pass of CrossAttention.

        Parameters
        ----------
        ``inputs`` : list of tf.Tensor
            A list [source1, source2], each of shape
            (batch_size, time_steps, features).
        training : bool, optional
            Forwarded to the multi-head attention.
            Defaults to ``False``.

        Returns
        -------
        tf.Tensor
            A tensor of shape (batch_size, time_steps,
            units) representing cross-attended features.
        """
        source1, source2 = inputs
        source1 = self.source1_dense(source1)
        source2 = self.source2_dense(source2)
        return self.cross_attention(
            query=source1,
            value=source2,
            key=source2,
            training=training,
        )

    def get_config(self):
        r"""
        Returns configuration dictionary for this
        layer.

        Returns
        -------
        dict
            Base-layer configuration plus 'units' and
            'num_heads' (both required by ``__init__``).
        """
        config = super().get_config().copy()
        config.update(
            {"units": self.units, "num_heads": self.num_heads}
        )
        return config

    @classmethod
    def from_config(cls, config):
        r"""
        Create a new CrossAttention layer from
        the given config dictionary.

        Parameters
        ----------
        ``config`` : dict
            Configuration as returned by
            ``get_config``.

        Returns
        -------
        CrossAttention
            A new instance of CrossAttention.
        """
        return cls(**config)


@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="CrossAttention"
)
class CrossAttention(Layer, NNLearner):
    r"""
    CrossAttention that attends ``source1`` (query) to ``source2``
    (key/value) with optional masks.

    Parameters
    ----------
    units : int
        Width of the Dense projections applied to both sources and the
        per-head key dimension of the attention.
    num_heads : int
        Number of attention heads.
    **kwargs
        Standard Keras ``Layer`` keyword arguments (``name``,
        ``dtype``, ``trainable``, ...).

    Call arguments
    --------------
    attention_mask : Tensor, optional
        Bool / 0-1 mask broadcastable to (B, Tq, Tv). Passed directly to
        Keras ``MultiHeadAttention``.
    query_mask, value_mask : Tensor, optional
        Per-step masks (B, Tq) or (B, Tv). If provided and
        ``attention_mask`` is None, they are combined to form
        (B, Tq, Tv); a missing side defaults to all True.
    use_causal_mask : bool
        Forwarded to MHA. Default False.
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(self, units: int, num_heads: int, **kwargs):
        # ``**kwargs`` lets ``from_config`` pass back the base-layer
        # entries (name, dtype, ...) emitted by ``get_config``.
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads
        self.source1_dense = Dense(units)
        self.source2_dense = Dense(units)
        self.cross_attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )

    @do_not_convert
    def call(
        self,
        inputs,
        training: bool = False,
        *,
        attention_mask: TensorLike | None = None,
        query_mask: TensorLike | None = None,
        value_mask: TensorLike | None = None,
        use_causal_mask: bool = False,
        **kwargs,
    ):
        r"""
        Forward pass of CrossAttention.

        Parameters
        ----------
        ``inputs`` : list of tf.Tensor
            A list [source1, source2] with shapes (B, Tq, Fq) and
            (B, Tv, Fv).
        training : bool, optional
            Forwarded to the multi-head attention. Defaults to ``False``.
        attention_mask, query_mask, value_mask, use_causal_mask
            See the class docstring.

        Returns
        -------
        tf.Tensor
            A tensor of shape (B, Tq, units) with cross-attended
            features.

        Raises
        ------
        TypeError
            If the attention call raises ``TypeError`` while
            ``use_causal_mask=True``. The causal mask is never dropped
            silently.
        """
        source1, source2 = inputs  # (B, Tq, Fq), (B, Tv, Fv)
        q = self.source1_dense(source1)
        kv = self.source2_dense(source2)

        # Build a (B, Tq, Tv) mask from per-sequence masks if needed.
        if attention_mask is None and (
            query_mask is not None or value_mask is not None
        ):
            if query_mask is None:
                query_mask = ones_like(
                    source1[..., 0], dtype=bool_dtype
                )
            if value_mask is None:
                value_mask = ones_like(
                    source2[..., 0], dtype=bool_dtype
                )
            qm = expand_dims(
                cast(query_mask, bool_dtype), axis=-1
            )  # (B, Tq, 1)
            vm = expand_dims(
                cast(value_mask, bool_dtype), axis=1
            )  # (B, 1, Tv)
            attention_mask = logical_and(qm, vm)

        try:
            return self.cross_attention(
                query=q,
                key=kv,
                value=kv,
                attention_mask=attention_mask,
                use_causal_mask=use_causal_mask,
                training=training,
            )
        except TypeError:
            # Presumably for MHA versions without ``use_causal_mask``.
            # Retry without that argument but keep the attention mask,
            # so padding is still honoured. A requested causal mask
            # cannot be emulated here, so the error is re-raised.
            if use_causal_mask:
                raise
            return self.cross_attention(
                query=q,
                key=kv,
                value=kv,
                attention_mask=attention_mask,
                training=training,
            )

    def get_config(self):
        """Return base config plus ``units`` and ``num_heads``."""
        cfg = super().get_config().copy()
        cfg.update(
            {"units": self.units, "num_heads": self.num_heads}
        )
        return cfg

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from :meth:`get_config` output."""
        return cls(**config)
@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="MemoryAugmentedAttention"
)
class MemoryAugmentedAttention(Layer, NNLearner):
    r"""Memory-augmented attention with optional masking.

    A trainable memory matrix of shape ``(memory_size, units)`` is tiled
    over the batch and used as key/value of a multi-head attention whose
    query is ``inputs``. The attention output is added back to
    ``inputs`` as a residual.

    Parameters
    ----------
    units : int
        Width of each memory slot and per-head key dimension.
    memory_size : int, default=1
        Number of memory slots.
    num_heads : int, default=1
        Number of attention heads.
    memory_initializer : str, default="glorot_uniform"
        Keras initializer name for the memory matrix. If all slots start
        identical (e.g. ``"zeros"``), the attention over them is uniform
        and every slot receives the same gradient, so the slots never
        differentiate. Pass ``"zeros"`` to reproduce the previous
        initialisation. Use a string so ``get_config`` stays
        serialisable.
    **kwargs
        Standard Keras ``Layer`` keyword arguments.

    Call arguments
    --------------
    attention_mask : Tensor, optional
        Mask broadcastable to (B, T, memory_size), forwarded to MHA.
    query_mask : Tensor, optional
        (B, T) mask over ``inputs``. Defaults to all True when only
        ``value_mask`` is given.
    value_mask : Tensor, optional
        (B, memory_size) mask over memory slots. Defaults to all True
        when only ``query_mask`` is given.
    use_causal_mask : bool
        Forwarded to MHA.
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        units: int,
        memory_size: int = 1,
        num_heads: int = 1,
        memory_initializer: str = "glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.memory_size = memory_size
        self.num_heads = num_heads
        self.memory_initializer = memory_initializer
        self.attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )

    def build(self, input_shape):
        """Create the trainable ``(memory_size, units)`` memory."""
        self.memory = self.add_weight(
            name="memory",
            shape=(self.memory_size, self.units),
            initializer=self.memory_initializer,
            trainable=True,
        )
        super().build(input_shape)

    @do_not_convert
    def call(
        self,
        inputs,
        training: bool = False,
        *,
        attention_mask: TensorLike | None = None,
        query_mask: TensorLike | None = None,
        value_mask: TensorLike | None = None,
        use_causal_mask: bool = False,
        **kwargs,
    ):
        """Attend ``inputs`` over the memory and add the result back.

        Raises
        ------
        TypeError
            If the attention call raises ``TypeError`` while
            ``use_causal_mask=True``. The causal mask is never dropped
            silently.
        """
        # inputs: assumed (B, T, U)
        batch_size = shape(inputs)[0]
        mem = expand_dims(self.memory, 0)  # (1, M, units)
        mem = tile(mem, [batch_size, 1, 1])  # (B, M, units)

        # Build attention_mask if only per-sequence masks are given.
        if attention_mask is None and (
            query_mask is not None or value_mask is not None
        ):
            if query_mask is None:
                query_mask = ones_like(
                    inputs[..., 0], dtype=bool_dtype
                )
            if value_mask is None:
                value_mask = ones(
                    (batch_size, self.memory_size),
                    dtype=bool_dtype,
                )
            qm = expand_dims(
                cast(query_mask, bool_dtype), -1
            )  # (B, T, 1)
            vm = expand_dims(
                cast(value_mask, bool_dtype), 1
            )  # (B, 1, M)
            attention_mask = logical_and(qm, vm)  # (B, T, M)

        try:
            mem_att = self.attention(
                query=inputs,
                key=mem,
                value=mem,
                attention_mask=attention_mask,
                use_causal_mask=use_causal_mask,
                training=training,
            )
        except TypeError:
            # Presumably for MHA versions without ``use_causal_mask``.
            # Keep the attention mask and refuse to silently drop a
            # requested causal mask.
            if use_causal_mask:
                raise
            mem_att = self.attention(
                query=inputs,
                key=mem,
                value=mem,
                attention_mask=attention_mask,
                training=training,
            )
        return mem_att + inputs

    def get_config(self):
        """Return base config plus all constructor arguments."""
        cfg = super().get_config().copy()
        cfg.update(
            {
                "units": self.units,
                "memory_size": self.memory_size,
                "num_heads": self.num_heads,
                "memory_initializer": self.memory_initializer,
            }
        )
        return cfg

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from :meth:`get_config` output."""
        return cls(**config)
@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="HierarchicalAttention_"
)
class HierarchicalAttention_(Layer, NNLearner):
    r"""
    Hierarchical Attention layer that processes short-term and
    long-term sequences separately using multi-head attention,
    then combines their outputs [1]_.

    This allows the model to focus on different aspects of the
    data in short-term and long-term contexts and aggregate the
    attention outputs for a more comprehensive representation.

    .. math::
        \mathbf{Z} = \text{MHA}(\mathbf{X}_{s})
        + \text{MHA}(\mathbf{X}_{l})

    where :math:`\mathbf{X}_{s}` and :math:`\mathbf{X}_{l}` are
    the short- and long-term sequences, respectively.

    Parameters
    ----------
    units : int
        Dimensionality of the projection for the attention keys,
        queries, and values.
    num_heads : int
        Number of attention heads to use in each multi-head
        attention sub-layer.
    **kwargs
        Standard Keras ``Layer`` keyword arguments.

    Notes
    -----
    Both sequences are projected to `units` by their own Dense
    layer. The final output is the sum of the short-term attention
    output and the long-term attention output, so both branches
    must produce broadcast-compatible shapes.

    Methods
    -------
    call(`inputs`, training=False)
        Forward pass. Expects a list `[short_term, long_term]` with
        shapes (B, T, D_s) and (B, T, D_l).
    get_config()
        Returns configuration dictionary for serialization.
    from_config(`config`)
        Recreates the layer from a config dict.

    Examples
    --------
    >>> from geoprior.nn.components import (
    ...     HierarchicalAttention,
    ... )
    >>> import tensorflow as tf
    >>> short_term = tf.random.normal((32, 10, 64))
    >>> long_term = tf.random.normal((32, 10, 64))
    >>> ha = HierarchicalAttention(units=64, num_heads=4)
    >>> outputs = ha([short_term, long_term])

    See Also
    --------
    MultiModalEmbedding
        Can precede attention by embedding multiple sources of
        input.
    LearnedNormalization
        Can be applied to short_term and long_term sequences prior
        to attention.

    References
    ----------
    .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J.,
           Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I.
           (2017). "Attention is all you need." In *Advances in
           Neural Information Processing Systems* (pp. 5998-6008).
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(self, units: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads
        # Dense layers for short/long sequences
        self.short_term_dense = Dense(units)
        self.long_term_dense = Dense(units)
        # Multi-head attention for short/long
        self.short_term_attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )
        self.long_term_attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )

    @do_not_convert
    def call(self, inputs, training=False):
        r"""
        Forward pass of the HierarchicalAttention.

        Parameters
        ----------
        ``inputs`` : list of tf.Tensor
            A list `[short_term, long_term]`. Each tensor should
            have shape :math:`(B, T, D)`.
        training : bool, optional
            Forwarded to both attention sub-layers. Defaults to
            ``False``.

        Returns
        -------
        tf.Tensor
            A tensor of shape :math:`(B, T, U)`, where `U = units`,
            representing the combined attention outputs.
        """
        short_term, long_term = inputs
        short_term = self.short_term_dense(short_term)
        long_term = self.long_term_dense(long_term)
        # Self-attention within each branch.
        short_term_attention = self.short_term_attention(
            short_term, short_term, training=training
        )
        long_term_attention = self.long_term_attention(
            long_term, long_term, training=training
        )
        return short_term_attention + long_term_attention

    def get_config(self):
        r"""
        Returns a dictionary of config parameters for
        serialization.

        Returns
        -------
        dict
            Base config plus 'units', 'num_heads' and the nested
            'short_term_dense' / 'long_term_dense' configs. The nested
            configs are informational only: ``from_config`` rebuilds
            the Dense layers from ``units``.
        """
        config = super().get_config().copy()
        config.update(
            {
                "units": self.units,
                "num_heads": self.num_heads,
                "short_term_dense": self.short_term_dense.get_config(),
                "long_term_dense": self.long_term_dense.get_config(),
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        r"""
        Recreates the HierarchicalAttention layer from a config
        dictionary.

        Parameters
        ----------
        ``config`` : dict
            Configuration dictionary.

        Returns
        -------
        HierarchicalAttention
            A new instance with the specified configuration.
        """
        config = dict(config)
        # Not constructor arguments; drop them before calling __init__.
        config.pop("short_term_dense", None)
        config.pop("long_term_dense", None)
        return cls(**config)
@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="HierarchicalAttention"
)
class HierarchicalAttention(Layer, NNLearner):
    r"""Short/long-term MHA with optional masks.

    Each branch is projected to ``units`` by its own Dense layer and
    self-attended by its own multi-head attention. The two branch
    outputs are summed.

    Parameters
    ----------
    units : int
        Projection width and per-head key dimension.
    num_heads : int
        Number of attention heads in each branch.
    **kwargs
        Standard Keras ``Layer`` keyword arguments.

    Call arguments
    --------------
    inputs : Tensor or [short_term, long_term]
        A single tensor is used for both branches.
    short_mask, long_mask : Tensor, optional
        Per-step masks (B, T) for each branch. Each is expanded to a
        (B, T, T) mask that is True only where both the query and the
        key step are valid.
    use_causal_mask : bool
        Forwarded to both attention sub-layers.
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(self, units: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads
        self.short_term_dense = Dense(units)
        self.long_term_dense = Dense(units)
        self.short_term_attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )
        self.long_term_attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )

    @do_not_convert
    def call(
        self,
        inputs,
        training: bool = False,
        *,
        short_mask: TensorLike | None = None,
        long_mask: TensorLike | None = None,
        use_causal_mask: bool = False,
        **kwargs,
    ):
        """Return the sum of the short- and long-term self-attentions."""
        if (
            isinstance(inputs, (list, tuple))
            and len(inputs) == 2
        ):
            short_term, long_term = inputs
        else:
            short_term = long_term = inputs

        short_proj = self.short_term_dense(short_term)
        long_proj = self.long_term_dense(long_term)

        def _expand_mask(step_mask):
            # (B, T) validity mask -> (B, T, T) pairwise mask.
            if step_mask is None:
                return None
            step_mask = cast(step_mask, bool_dtype)
            query_valid = expand_dims(step_mask, 2)  # (B, T, 1)
            key_valid = expand_dims(step_mask, 1)  # (B, 1, T)
            return logical_and(query_valid, key_valid)

        short_att = self.short_term_attention(
            query=short_proj,
            key=short_proj,
            value=short_proj,
            attention_mask=_expand_mask(short_mask),
            use_causal_mask=use_causal_mask,
            training=training,
        )
        long_att = self.long_term_attention(
            query=long_proj,
            key=long_proj,
            value=long_proj,
            attention_mask=_expand_mask(long_mask),
            use_causal_mask=use_causal_mask,
            training=training,
        )
        return short_att + long_att

    def get_config(self):
        """Return base config, ``units``, ``num_heads`` and Dense configs.

        The nested Dense configs are informational only. ``from_config``
        drops them and rebuilds the Dense layers from ``units``.
        """
        cfg = super().get_config().copy()
        cfg.update(
            {
                "units": self.units,
                "num_heads": self.num_heads,
                "short_term_dense": self.short_term_dense.get_config(),
                "long_term_dense": self.long_term_dense.get_config(),
            }
        )
        return cfg

    @classmethod
    def from_config(cls, config):
        """Recreate the layer from :meth:`get_config` output."""
        config = dict(config)
        config.pop("short_term_dense", None)
        config.pop("long_term_dense", None)
        return cls(**config)
@register_keras_serializable(
    SERIALIZATION_PACKAGE, name="ExplainableAttention"
)
class ExplainableAttention(Layer, NNLearner):
    r"""
    ExplainableAttention layer that returns attention scores from
    multi-head attention [1]_.

    This layer is useful for interpretability, providing insight
    into how the attention mechanism focuses on different time
    steps.

    .. math::
        \mathbf{A} = \text{MHA}(\mathbf{X},\,\mathbf{X})
        \rightarrow \text{attention\_scores}

    Here, :math:`\mathbf{X}` is an input tensor, and
    ``attention_scores`` is the matrix capturing attention weights.

    Parameters
    ----------
    num_heads : int
        Number of heads for multi-head attention.
    key_dim : int, optional
        Dimensionality of the query/key projections.
    units : int, optional
        Keyword-only alias used when ``key_dim`` is None.
    **kwargs
        Standard Keras ``Layer`` keyword arguments.

    Notes
    -----
    Unlike standard layers that return the transformation output,
    this layer specifically returns the attention score matrix for
    interpretability.

    Methods
    -------
    call(`inputs`, training=False)
        Forward pass that outputs only the attention scores.
    get_config()
        Returns the configuration for serialization.
    from_config(`config`)
        Creates a new instance from the given config.

    Examples
    --------
    >>> from geoprior.nn.components import (
    ...     ExplainableAttention,
    ... )
    >>> import tensorflow as tf
    >>> x = tf.random.normal((32, 10, 64))
    >>> ea = ExplainableAttention(num_heads=4, key_dim=64)
    >>> # Forward pass returns attention scores: (B, num_heads, T, T)
    >>> scores = ea(x)

    See Also
    --------
    CrossAttention
        Another attention variant for cross-sequence contexts.
    MultiResolutionAttentionFusion
        For fusing features via multi-head attention.

    References
    ----------
    .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J.,
           Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I.
           (2017). "Attention is all you need." In *Advances in
           Neural Information Processing Systems* (pp. 5998-6008).
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(
        self,
        num_heads: int,
        key_dim: int | None = None,
        *,
        units: int | None = None,
        **kwargs,
    ):
        r"""
        Initialize the ExplainableAttention layer.

        Raises
        ------
        ValueError
            If neither ``key_dim`` nor ``units`` is given.
        """
        # ``**kwargs`` lets ``from_config`` pass back base-layer entries.
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = (
            key_dim if key_dim is not None else units
        )
        if self.key_dim is None:
            raise ValueError(
                "Provide `key_dim` or `units` for ExplainableAttention."
            )
        self.attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=self.key_dim
        )

    @do_not_convert
    def call(self, inputs, training=False):
        r"""
        Forward pass that returns only the attention scores.

        Parameters
        ----------
        ``inputs`` : tf.Tensor or [query, value]
            Tensor of shape (B, T, D) for self-attention, or a pair
            for cross-attention.
        training : bool, optional
            Forwarded to the multi-head attention. Defaults to
            ``False``.

        Returns
        -------
        tf.Tensor
            Attention scores, (B, num_heads, T, T) for
            self-attention.
        """
        if (
            isinstance(inputs, (list, tuple))
            and len(inputs) == 2
        ):
            query, value = inputs
        else:
            query = value = inputs
        _, attention_scores = self.attention(
            query,
            value,
            return_attention_scores=True,
            training=training,
        )
        return attention_scores

    def get_config(self):
        r"""Return base config plus 'num_heads' and 'key_dim'."""
        config = super().get_config().copy()
        config.update(
            {
                "num_heads": self.num_heads,
                "key_dim": self.key_dim,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        r"""Create a new instance from :meth:`get_config` output."""
        return cls(**config)


@register_keras_serializable(
    SERIALIZATION_PACKAGE,
    name="MultiResolutionAttentionFusion",
)
class MultiResolutionAttentionFusion(Layer, NNLearner):
    r"""
    MultiResolutionAttentionFusion layer applying multi-head
    attention fusion over features [1]_.

    This layer merges or fuses features at different resolutions or
    sources via multi-head attention. For an input of shape
    `(B, T, D)`, the output shares the same shape.

    .. math::
        \mathbf{Z} = \text{MHA}(\mathbf{X}, \mathbf{X})

    Parameters
    ----------
    units : int
        Dimension of the key, query, and value projections.
    num_heads : int
        Number of attention heads.
    **kwargs
        Standard Keras ``Layer`` keyword arguments.

    Notes
    -----
    Typically used in multi-resolution contexts where time steps or
    multiple feature sets are merged.

    Methods
    -------
    call(`inputs`, training=False)
        Forward pass of the multi-head attention layer.
    get_config()
        Returns config for serialization.
    from_config(`config`)
        Reconstructs the layer from a config.

    Examples
    --------
    >>> from geoprior.nn.components import (
    ...     MultiResolutionAttentionFusion,
    ... )
    >>> import tensorflow as tf
    >>> x = tf.random.normal((32, 10, 64))
    >>> mraf = MultiResolutionAttentionFusion(
    ...     units=64, num_heads=4
    ... )
    >>> # Forward pass => (32, 10, 64)
    >>> y = mraf(x)

    See Also
    --------
    HierarchicalAttention
        Combines short and long-term sequences with attention.
    ExplainableAttention
        Another attention layer returning attention scores.

    References
    ----------
    .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J.,
           Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I.
           (2017). "Attention is all you need." In *Advances in
           Neural Information Processing Systems* (pp. 5998-6008).
    """

    @ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
    def __init__(self, units: int, num_heads: int, **kwargs):
        r"""Initialize the MultiResolutionAttentionFusion layer."""
        super().__init__(**kwargs)
        self.units = units
        self.num_heads = num_heads
        self.attention = MultiHeadAttention(
            num_heads=num_heads, key_dim=units
        )

    @do_not_convert
    def call(self, inputs, training=False):
        r"""
        Forward pass applying multi-head attention to fuse features.

        Parameters
        ----------
        ``inputs`` : tf.Tensor or [query, value]
            Tensor of shape (B, T, D) for self-attention, or a pair
            for cross-attention.
        training : bool, optional
            Forwarded to the multi-head attention. Defaults to
            ``False``.

        Returns
        -------
        tf.Tensor
            Fused features with the query's shape.
        """
        if (
            isinstance(inputs, (list, tuple))
            and len(inputs) == 2
        ):
            query, value = inputs
            return self.attention(
                query, value, training=training
            )
        return self.attention(
            inputs, inputs, training=training
        )

    def get_config(self):
        r"""Return base config plus 'units' and 'num_heads'."""
        config = super().get_config().copy()
        config.update(
            {"units": self.units, "num_heads": self.num_heads}
        )
        return config

    @classmethod
    def from_config(cls, config):
        r"""Create a new instance from :meth:`get_config` output."""
        return cls(**config)