# SPDX-License-Identifier: Apache-2.0
# Author: LKouadio <etanoyau@gmail.com>
# Adapted from: earthai-tech/fusionlab-learn — https://github.com/earthai-tech/gofast
# Modified for GeoPrior-v3 API.
"""
Attention-centric layers for FusionLab.
"""
from __future__ import annotations
from numbers import Integral, Real
from ..api.property import NNLearner
from ..compat.sklearn import Interval, validate_params
from ..compat.types import TensorLike
from ..utils.deps_utils import ensure_pkg
from ._config import (
DEP_MSG,
KERAS_BACKEND,
Dense,
Dropout,
Layer,
LayerNormalization,
MultiHeadAttention,
_logger,
add,
bool_dtype,
cast,
do_not_convert,
expand_dims,
logical_and,
ones,
ones_like,
register_keras_serializable,
shape,
tile,
)
from .gating_norm import GatedResidualNetwork
from .misc import Activation
__all__ = [
"TemporalAttentionLayer",
"CrossAttention",
"MemoryAugmentedAttention",
"HierarchicalAttention",
"ExplainableAttention",
"MultiResolutionAttentionFusion",
]
SERIALIZATION_PACKAGE = __name__
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="TemporalAttentionLayer"
)
class TemporalAttentionLayer(Layer):
"""Temporal Attention Layer conditioning query with context."""
@validate_params(
{
"units": [
Interval(Integral, 0, None, closed="left")
],
"num_heads": [
Interval(Integral, 0, None, closed="left")
],
"dropout_rate": [
Interval(Real, 0, 1, closed="both")
],
"use_batch_norm": [bool],
}
)
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(
self,
units: int,
num_heads: int,
dropout_rate: float = 0.0,
activation: str = "elu",
use_batch_norm: bool = False,
**kwargs,
):
"""Initializes the TemporalAttentionLayer."""
super().__init__(**kwargs)
self.units = units
self.num_heads = num_heads
self.dropout_rate = dropout_rate
self.use_batch_norm = use_batch_norm
self.activation_str = Activation(
activation
).activation_str
# --- Define Internal Layers ---
self.multi_head_attention = MultiHeadAttention(
num_heads=num_heads,
key_dim=units,
dropout=dropout_rate,
name="mha",
)
self.dropout = Dropout(
dropout_rate, name="attn_dropout"
)
self.layer_norm1 = LayerNormalization(
name="layer_norm_1"
)
# GRN to process the input context_vector
# Ensure this is a single instance, passing the activation string
self.context_grn = GatedResidualNetwork(
units=units, # Output matches main path 'units'
dropout_rate=dropout_rate,
activation=self.activation_str,
use_batch_norm=self.use_batch_norm,
name="context_grn",
# Note: GRN's internal activation handling should be fixed
)
# Final GRN (position-wise feedforward)
# Ensure this is also a single instance
self.output_grn = GatedResidualNetwork(
units=units,
dropout_rate=dropout_rate,
activation=self.activation_str,
use_batch_norm=self.use_batch_norm,
name="output_grn",
)
def build(self, input_shape):
"""Builds internal layers, especially GRNs."""
# input_shape corresponds to the main 'inputs' tensor (B, T, U)
if not isinstance(input_shape, (list, tuple)):
# If only main input shape is passed (common)
main_input_shape = tuple(input_shape)
elif len(input_shape) == 2:
# [inputs_shape, context_shape] rarelly happended
main_input_shape = tuple(input_shape[0])
# Optionally build context_grn if context_shape is known
context_shape = tuple(input_shape[1])
if not self.context_grn.built:
self.context_grn.build(context_shape)
else:
raise ValueError(
"Unexpected input_shape format for build."
)
if len(main_input_shape) < 3:
raise ValueError(
"TemporalAttentionLayer expects input rank >= 3"
)
# Define expected input shape for output_grn
# It receives output from layer_norm1, which has same shape as input
output_grn_input_shape = main_input_shape
# Explicitly build the output GRN if not already built
if not self.output_grn.built:
self.output_grn.build(output_grn_input_shape)
# Developer comment: Explicitly built output_grn.
# Build context_grn lazily during call or here
# Call the parent build method AFTER building sub-layers
super().build(input_shape)
# Developer comment: Layer built status should now be True.
def call(
self, inputs, context_vector=None, training=False
):
"""Forward pass of the temporal attention layer."""
# Input shapes: inputs=(B, T, U), context_vector=(B, U_ctx)
attention_source = inputs
if (
isinstance(inputs, (list, tuple))
and len(inputs) == 2
):
inputs, secondary = inputs
if (
context_vector is None
and hasattr(secondary, "shape")
and len(secondary.shape) == 2
):
context_vector = secondary
else:
attention_source = secondary
query = inputs # Default query
processed_context = None
# --- Process Context Vector (if provided) ---
if context_vector is not None:
# Pass context_vector as the main input 'x' to context_grn
processed_context = self.context_grn(
x=context_vector,
context=None, # No nested context for the context_grn itself
training=training,
)
# Output shape: (B, units)
# Expand context across time: (B, units) -> (B, 1, units)
context_expanded = expand_dims(
processed_context, axis=1
)
# Add to inputs (broadcasting handles time dimension)
query = add(inputs, context_expanded)
# Comment: Query now incorporates static context.
# --- Multi-Head Self-Attention ---
attn_output = self.multi_head_attention(
query=query,
value=attention_source,
key=attention_source,
training=training,
) # Shape: (B, T, units)
# --- Add & Norm (First Residual Connection) ---
attn_output_dropout = self.dropout(
attn_output, training=training
)
# Residual connection uses original 'inputs'
x_attn = self.layer_norm1(
add(inputs, attn_output_dropout)
)
# Shape: (B, T, units)
# --- Position-wise Feedforward (Final GRN) ---
# This GRN takes the output of the attention block as input 'x'
# It does not receive the external 'context_vector' here.
# --- DEBUG lines ---
_logger.debug(
"\nDEBUG>> About to call self.output_grn"
)
_logger.debug(
f"DEBUG>> Type of self.output_grn: {type(self.output_grn)}"
)
_logger.debug(
f"DEBUG>> Is self.output_grn callable: {callable(self.output_grn)}"
)
try:
# Try accessing an attribute expected on a Keras layer
_logger.debug(
f"DEBUG>> self.output_grn name: {self.output_grn.name}"
)
_logger.debug(
f"DEBUG>> self.output_grn built status: {self.output_grn.built}"
)
except AttributeError as ae:
_logger.debug(
f"DEBUG>> Failed to access attributes of self.output_grn: {ae}"
)
_logger.debug(
f"DEBUG>> Input x_attn shape: {shape(x_attn)}\n"
)
# --- End DEBUG lines ---
output = self.output_grn(
x=x_attn,
context=None, # No external context for the final GRN
training=training,
)
# Shape: (B, T, units)
return output
def get_config(self):
"""Returns the layer configuration."""
config = super().get_config()
config.update(
{
"units": self.units,
"num_heads": self.num_heads,
"dropout_rate": self.dropout_rate,
"activation": self.activation_str,
"use_batch_norm": self.use_batch_norm,
}
)
return config
@classmethod
def from_config(cls, config):
"""Creates layer from its config."""
return cls(**config)
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="CrossAttention_"
)
class CrossAttention_(Layer, NNLearner):
r"""
CrossAttention layer that attends one source
sequence to another [1]_.
This layer transforms two input sources,
``source1`` and ``source2``, into a shared
dimensionality via separate dense layers,
then applies multi-head attention using
``source1`` as the query and ``source2`` as
both key and value. The output shape depends
on the specified ``units``.
.. math::
\mathbf{H}_{\text{out}} = \text{MHA}(
\mathbf{W}_{1}\,\mathbf{S}_1,\,
\mathbf{W}_{2}\,\mathbf{S}_2,\,
\mathbf{W}_{2}\,\mathbf{S}_2
)
where :math:`\mathbf{S}_1` and :math:`\mathbf{S}_2`
are the two source sequences.
Parameters
----------
units : int
Dimensionality for the internal projections
of the query/key/value in multi-head attention.
num_heads : int
Number of attention heads.
Notes
-----
Cross attention is particularly useful when
focusing on how one sequence (the query) relates
to another (the key/value). For example, in
multi-modal time series settings, one might
attend dynamic covariates to static ones or
vice versa.
Methods
-------
call(`inputs`, training=False)
Forward pass of the cross-attention layer.
get_config()
Returns the configuration dictionary for
serialization.
from_config(`config`)
Creates a new layer from the given config.
Examples
--------
>>> from geoprior.nn.components import CrossAttention
>>> import tensorflow as tf
>>> # Two sequences of shape (batch_size, time_steps, features)
>>> source1 = tf.random.normal((32, 10, 64))
>>> source2 = tf.random.normal((32, 10, 64))
>>> # Instantiate the CrossAttention layer
>>> cross_attn = CrossAttention(units=64, num_heads=4)
>>> # Forward pass
>>> outputs = cross_attn([source1, source2])
See Also
--------
HierarchicalAttention
Another attention-based layer focusing on
short/long-term sequences.
MemoryAugmentedAttention
Uses a learned memory matrix to enhance
representations.
References
----------
.. [1] Vaswani, A., Shazeer, N., Parmar, N.,
Uszkoreit, J., Jones, L., Gomez, A. N.,
Kaiser, L., & Polosukhin, I. (2017).
"Attention is all you need." In
*Advances in Neural Information
Processing Systems* (pp. 5998-6008).
"""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(self, units: int, num_heads: int):
r"""
Initialize the CrossAttention layer.
Parameters
----------
units : int
Number of output units for the
internal Dense projections and
multi-head attention dimension.
num_heads : int
Number of attention heads to use
in the multi-head attention module.
"""
super().__init__()
self.units = units
# Dense layers to project each source
self.source1_dense = Dense(units)
self.source2_dense = Dense(units)
# Multi-head attention
self.cross_attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
@do_not_convert
def call(self, inputs, training=False):
r"""
Forward pass of CrossAttention.
Parameters
----------
``inputs`` : list of tf.Tensor
A list [source1, source2], each of shape
(batch_size, time_steps, features).
training : bool, optional
Indicates if the layer is in training
mode (for dropout, if any).
Defaults to ``False``.
Returns
-------
tf.Tensor
A tensor of shape (batch_size, time_steps,
units) representing cross-attended features.
"""
source1, source2 = inputs
# Project each source
source1 = self.source1_dense(source1)
source2 = self.source2_dense(source2)
# Apply cross attention
return self.cross_attention(
query=source1, value=source2, key=source2
)
def get_config(self):
r"""
Returns configuration dictionary for this
layer.
Returns
-------
dict
Configuration dictionary, including
'units'.
"""
config = super().get_config().copy()
config.update({"units": self.units})
return config
@classmethod
def from_config(cls, config):
r"""
Create a new CrossAttention layer from
the given config dictionary.
Parameters
----------
``config`` : dict
Configuration as returned by
``get_config``.
Returns
-------
CrossAttention
A new instance of CrossAttention.
"""
return cls(**config)
[docs]
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="CrossAttention"
)
class CrossAttention(Layer, NNLearner):
r"""
CrossAttention that attends ``source1`` (query) to ``source2``
(key/value) with optional masks.
attention_mask : Tensor, optional
Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed
directly to Keras ``MultiHeadAttention``.
query_mask, value_mask : Tensor, optional
1D/2D masks (B, Tq) or (B, Tv). If provided and
``attention_mask`` is None, they are combined to form
(B, Tq, Tv).
use_causal_mask : bool
Forwarded to MHA. Default False.
"""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(self, units: int, num_heads: int):
super().__init__()
self.units = units
self.source1_dense = Dense(units)
self.source2_dense = Dense(units)
self.cross_attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
[docs]
@do_not_convert
def call(
self,
inputs,
training: bool = False,
*,
attention_mask: TensorLike | None = None,
query_mask: TensorLike | None = None,
value_mask: TensorLike | None = None,
use_causal_mask: bool = False,
**kwargs,
):
r"""
Forward pass of CrossAttention.
Parameters
----------
``inputs`` : list of tf.Tensor
A list [source1, source2], each of shape
(batch_size, time_steps, features).
training : bool, optional
Indicates if the layer is in training
mode (for dropout, if any).
Defaults to ``False``.
attention_mask : Tensor, optional
Bool / 0‑1 mask broadcastable to (B, Tq, Tv). Passed
directly to Keras ``MultiHeadAttention``.
query_mask, value_mask : Tensor, optional
1D/2D masks (B, Tq) or (B, Tv). If provided and
``attention_mask`` is None, they are combined to form
(B, Tq, Tv).
use_causal_mask : bool
Forwarded to MHA. Default False.
Returns
-------
tf.Tensor
A tensor of shape (batch_size, time_steps,
units) representing cross-attended features.
"""
source1, source2 = (
inputs # shapes: (B, Tq, Fq), (B, Tv, Fv)
)
# Project to common dim
q = self.source1_dense(source1)
kv = self.source2_dense(source2)
# Build attention_mask if needed
if attention_mask is None and (
query_mask is not None or value_mask is not None
):
# default to all True if one side is None
if query_mask is None:
query_mask = ones_like(
source1[..., 0], dtype=bool_dtype
)
if value_mask is None:
value_mask = ones_like(
source2[..., 0], dtype=bool_dtype
)
qm = expand_dims(
cast(query_mask, bool_dtype), axis=-1
)
vm = expand_dims(
cast(value_mask, bool_dtype), axis=1
)
# (B, Tq, 1) & (B, 1, Tv) -> (B, Tq, Tv)
attention_mask = logical_and(qm, vm)
try:
return self.cross_attention(
query=q,
key=kv,
value=kv,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
training=training,
)
except TypeError:
return self.cross_attention(
q, kv, training=training
)
[docs]
def get_config(self):
cfg = super().get_config().copy()
cfg.update({"units": self.units})
return cfg
[docs]
@classmethod
def from_config(cls, config):
return cls(**config)
[docs]
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="MemoryAugmentedAttention"
)
class MemoryAugmentedAttention(Layer, NNLearner):
r"""Memory-augmented attention with optional masking."""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(
self,
units: int,
memory_size: int = 1,
num_heads: int = 1,
):
super().__init__()
self.units = units
self.memory_size = memory_size
self.attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
[docs]
def build(self, input_shape):
self.memory = self.add_weight(
name="memory",
shape=(self.memory_size, self.units),
initializer="zeros",
trainable=True,
)
super().build(input_shape)
[docs]
@do_not_convert
def call(
self,
inputs,
training: bool = False,
*,
attention_mask: TensorLike | None = None,
query_mask: TensorLike | None = None,
value_mask: TensorLike | None = None,
use_causal_mask: bool = False,
**kwargs,
):
# inputs: (B, T, U)
batch_size = shape(inputs)[0]
mem = expand_dims(self.memory, 0) # (1, M, U)
mem = tile(mem, [batch_size, 1, 1]) # (B, M, U)
# Build attention_mask if only per-sequence masks given
if attention_mask is None and (
query_mask is not None or value_mask is not None
):
if query_mask is None:
query_mask = ones_like(
inputs[..., 0], dtype=bool_dtype
)
if value_mask is None:
value_mask = ones(
(batch_size, self.memory_size),
dtype=bool_dtype,
)
qm = expand_dims(
cast(query_mask, bool_dtype), -1
) # (B,T,1)
vm = expand_dims(
cast(value_mask, bool_dtype), 1
) # (B,1,M)
attention_mask = logical_and(qm, vm) # (B,T,M)
try:
mem_att = self.attention(
query=inputs,
key=mem,
value=mem,
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
training=training,
)
except TypeError:
mem_att = self.attention(
inputs, mem, training=training
)
return mem_att + inputs
[docs]
def get_config(self):
cfg = super().get_config().copy()
cfg.update(
{
"units": self.units,
"memory_size": self.memory_size,
}
)
return cfg
[docs]
@classmethod
def from_config(cls, config):
return cls(**config)
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="HierarchicalAttention_"
)
class HierarchicalAttention_(Layer, NNLearner):
r"""
Hierarchical Attention layer that processes
short-term and long-term sequences separately
using multi-head attention, then combines
their outputs [1]_.
This allows the model to focus on different
aspects of the data in short-term and long-term
contexts and aggregate the attention outputs
for a more comprehensive representation.
.. math::
\mathbf{Z} = \text{MHA}(\mathbf{X}_{s})
+ \text{MHA}(\mathbf{X}_{l})
where :math:`\mathbf{X}_{s}` and
:math:`\mathbf{X}_{l}` are the short- and
long-term sequences, respectively.
Parameters
----------
units : int
Dimensionality of the projection for the
attention keys, queries, and values.
num_heads : int
Number of attention heads to use in each
multi-head attention sub-layer.
Notes
-----
The output shape depends on the last
dimension in the short and long sequences,
projected to `units`. The final output is
the sum of the short-term attention output
and the long-term attention output.
Methods
-------
call(`inputs`, training=False)
Forward pass. Expects a list `[short_term,
long_term]` with shapes
(B, T, D_s) and (B, T, D_l).
get_config()
Returns configuration dictionary for
serialization.
from_config(`config`)
Recreates the layer from a config dict.
Examples
--------
>>> from geoprior.nn.components import (
... HierarchicalAttention,
... )
>>> import tensorflow as tf
>>> # Suppose short_term and long_term have
... # shape (batch_size, time_steps, features).
>>> short_term = tf.random.normal((32, 10, 64))
>>> long_term = tf.random.normal((32, 10, 64))
>>> # Instantiate hierarchical attention
>>> ha = HierarchicalAttention(units=64, num_heads=4)
>>> # Forward pass
>>> outputs = ha([short_term, long_term])
See Also
--------
MultiModalEmbedding
Can precede attention by embedding
multiple sources of input.
LearnedNormalization
Can be applied to short_term and
long_term sequences prior to attention.
References
----------
.. [1] Vaswani, A., Shazeer, N., Parmar, N.,
Uszkoreit, J., Jones, L., Gomez, A. N.,
Kaiser, L., & Polosukhin, I. (2017).
"Attention is all you need."
In *Advances in Neural Information
Processing Systems* (pp. 5998-6008).
"""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(self, units: int, num_heads: int):
super().__init__()
self.units = units
# Dense layers for short/long sequences
self.short_term_dense = Dense(units)
self.long_term_dense = Dense(units)
# Multi-head attention for short/long
self.short_term_attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
self.long_term_attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
@do_not_convert
def call(self, inputs, training=False):
r"""
Forward pass of the HierarchicalAttention.
Parameters
----------
``inputs`` : list of tf.Tensor
A list `[short_term, long_term]`.
Each tensor should have shape
:math:`(B, T, D)`.
training : bool, optional
Indicates whether the layer is
in training mode. Defaults to
``False``.
Returns
-------
tf.Tensor
A tensor of shape :math:`(B, T, U)`,
where `U = units`, representing the
combined attention outputs.
"""
short_term, long_term = inputs
# Linear projections to unify
# dimensionality
short_term = self.short_term_dense(short_term)
long_term = self.long_term_dense(long_term)
# Multi-head attention on short_term
short_term_attention = self.short_term_attention(
short_term, short_term
)
# Multi-head attention on long_term
long_term_attention = self.long_term_attention(
long_term, long_term
)
# Combine
return short_term_attention + long_term_attention
def get_config(self):
r"""
Returns a dictionary of config
parameters for serialization.
Returns
-------
dict
Dictionary with 'units',
'short_term_dense' config,
and 'long_term_dense' config.
"""
config = super().get_config().copy()
config.update(
{
"units": self.units,
"short_term_dense": self.short_term_dense.get_config(),
"long_term_dense": self.long_term_dense.get_config(),
}
)
return config
@classmethod
def from_config(cls, config):
r"""
Recreates the HierarchicalAttention
layer from a config dictionary.
Parameters
----------
``config`` : dict
Configuration dictionary.
Returns
-------
HierarchicalAttention
A new instance with the
specified configuration.
"""
return cls(**config)
[docs]
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="HierarchicalAttention"
)
class HierarchicalAttention(Layer, NNLearner):
r"""Short/long-term MHA with optional masks."""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(self, units: int, num_heads: int):
super().__init__()
self.units = units
self.short_term_dense = Dense(units)
self.long_term_dense = Dense(units)
self.short_term_attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
self.long_term_attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
[docs]
@do_not_convert
def call(
self,
inputs,
training: bool = False,
*,
short_mask: TensorLike | None = None,
long_mask: TensorLike | None = None,
use_causal_mask: bool = False,
**kwargs,
):
# inputs: [short_term, long_term]
if (
isinstance(inputs, (list, tuple))
and len(inputs) == 2
):
short_term, long_term = inputs
else:
short_term = long_term = inputs
s = self.short_term_dense(short_term)
long_proj = self.long_term_dense(long_term)
# Build masks to (B, T, T) if provided as (B,T)
def _expand_mask(m):
if m is None:
return None
m = cast(m, bool_dtype)
qm = expand_dims(m, 1) # (B,1,T)
vm = expand_dims(m, 2) # (B,T,1)
return logical_and(vm, qm) # (B,T,T)
s_mask = _expand_mask(short_mask)
l_mask = _expand_mask(long_mask)
s_att = self.short_term_attention(
query=s,
key=s,
value=s,
attention_mask=s_mask,
use_causal_mask=use_causal_mask,
training=training,
)
l_att = self.long_term_attention(
query=long_proj,
key=long_proj,
value=long_proj,
attention_mask=l_mask,
use_causal_mask=use_causal_mask,
training=training,
)
return s_att + l_att
[docs]
def get_config(self):
cfg = super().get_config().copy()
cfg.update(
{
"units": self.units,
"short_term_dense": self.short_term_dense.get_config(),
"long_term_dense": self.long_term_dense.get_config(),
}
)
return cfg
[docs]
@classmethod
def from_config(cls, config):
return cls(**config)
@register_keras_serializable(
SERIALIZATION_PACKAGE, name="ExplainableAttention"
)
class ExplainableAttention(Layer, NNLearner):
r"""
ExplainableAttention layer that returns attention
scores from multi-head attention [1]_.
This layer is useful for interpretability,
providing insight into how the attention
mechanism focuses on different time steps.
.. math::
\mathbf{A} = \text{MHA}(\mathbf{X},\,\mathbf{X})
\rightarrow \text{attention\_scores}
Here, :math:`\mathbf{X}` is an input tensor,
and ``attention_scores`` is the matrix
capturing attention weights.
Parameters
----------
num_heads : int
Number of heads for multi-head attention.
key_dim : int
Dimensionality of the query/key projections.
Notes
-----
Unlike standard layers that return the
transformation output, this layer specifically
returns the attention score matrix for
interpretability.
Methods
-------
call(`inputs`, training=False)
Forward pass that outputs only the
attention scores.
get_config()
Returns the configuration for serialization.
from_config(`config`)
Creates a new instance from the given config.
Examples
--------
>>> from geoprior.nn.components import (
... ExplainableAttention,
... )
>>> import tensorflow as tf
>>> # Suppose we have input of shape (batch_size, time_steps, features)
>>> x = tf.random.normal((32, 10, 64))
>>> # Instantiate explainable attention
>>> ea = ExplainableAttention(num_heads=4, key_dim=64)
>>> # Forward pass returns attention scores: (B, num_heads, T, T)
>>> scores = ea(x)
See Also
--------
CrossAttention
Another attention variant for cross-sequence
contexts.
MultiResolutionAttentionFusion
For fusing features via multi-head attention.
References
----------
.. [1] Vaswani, A., Shazeer, N., Parmar, N.,
Uszkoreit, J., Jones, L., Gomez, A. N.,
Kaiser, L., & Polosukhin, I. (2017).
"Attention is all you need." In
*Advances in Neural Information
Processing Systems* (pp. 5998-6008).
"""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(
self,
num_heads: int,
key_dim: int | None = None,
*,
units: int | None = None,
):
r"""
Initialize the ExplainableAttention layer.
Parameters
----------
num_heads : int
Number of attention heads.
key_dim : int
Dimensionality of query/key projections
in multi-head attention.
"""
super().__init__()
self.num_heads = num_heads
self.key_dim = (
key_dim if key_dim is not None else units
)
if self.key_dim is None:
raise ValueError(
"Provide `key_dim` or `units` for ExplainableAttention."
)
# MultiHeadAttention, focusing on returning
# the attention scores
self.attention = MultiHeadAttention(
num_heads=num_heads, key_dim=self.key_dim
)
@do_not_convert
def call(self, inputs, training=False):
r"""
Forward pass that returns only the
attention scores.
Parameters
----------
``inputs`` : tf.Tensor
Tensor of shape (B, T, D).
training : bool, optional
Indicates training mode; not used in
this layer. Defaults to ``False``.
Returns
-------
tf.Tensor
Attention scores of shape
(B, num_heads, T, T).
"""
if (
isinstance(inputs, (list, tuple))
and len(inputs) == 2
):
query, value = inputs
else:
query = value = inputs
_, attention_scores = self.attention(
query, value, return_attention_scores=True
)
return attention_scores
def get_config(self):
r"""
Returns the layer configuration.
Returns
-------
dict
Dictionary containing 'num_heads'
and 'key_dim'.
"""
config = super().get_config().copy()
config.update(
{
"num_heads": self.num_heads,
"key_dim": self.key_dim,
}
)
return config
@classmethod
def from_config(cls, config):
r"""
Creates a new instance from the config
dictionary.
Parameters
----------
``config`` : dict
Configuration dictionary.
Returns
-------
ExplainableAttention
A new instance of this layer.
"""
return cls(**config)
@register_keras_serializable(
SERIALIZATION_PACKAGE,
name="MultiResolutionAttentionFusion",
)
class MultiResolutionAttentionFusion(Layer, NNLearner):
r"""
MultiResolutionAttentionFusion layer applying
multi-head attention fusion over features [1]_.
This layer merges or fuses features at different
resolutions or sources via multi-head attention.
The input is projected to shape `(B, T, D)`,
and the output shares the same shape.
.. math::
\mathbf{Z} = \text{MHA}(\mathbf{X}, \mathbf{X})
Parameters
----------
units : int
Dimension of the key, query, and value
projections.
num_heads : int
Number of attention heads.
Notes
-----
Typically used in multi-resolution contexts
where time steps or multiple feature sets
are merged.
Methods
-------
call(`inputs`, training=False)
Forward pass of the multi-head attention
layer.
get_config()
Returns config for serialization.
from_config(`config`)
Reconstructs the layer from a config.
Examples
--------
>>> from geoprior.nn.components import (
... MultiResolutionAttentionFusion,
... )
>>> import tensorflow as tf
>>> x = tf.random.normal((32, 10, 64))
>>> # Instantiate multi-resolution attention
>>> mraf = MultiResolutionAttentionFusion(
... units=64, num_heads=4
... )
>>> # Forward pass => (32, 10, 64)
>>> y = mraf(x)
See Also
--------
HierarchicalAttention
Combines short and long-term sequences
with attention.
ExplainableAttention
Another attention layer returning
attention scores.
References
----------
.. [1] Vaswani, A., Shazeer, N., Parmar, N.,
Uszkoreit, J., Jones, L., Gomez, A. N.,
Kaiser, L., & Polosukhin, I. (2017).
"Attention is all you need." In
*Advances in Neural Information
Processing Systems* (pp. 5998-6008).
"""
@ensure_pkg(KERAS_BACKEND or "keras", extra=DEP_MSG)
def __init__(self, units: int, num_heads: int):
r"""
Initialize the MultiResolutionAttentionFusion
layer.
Parameters
----------
units : int
Dimensionality for the attention
projections.
num_heads : int
Number of heads for multi-head
attention.
"""
super().__init__()
self.units = units
self.num_heads = num_heads
# MultiHeadAttention instance
self.attention = MultiHeadAttention(
num_heads=num_heads, key_dim=units
)
@do_not_convert
def call(self, inputs, training=False):
r"""
Forward pass applying multi-head attention
to fuse features.
Parameters
----------
``inputs`` : tf.Tensor
Tensor of shape (B, T, D).
training : bool, optional
Indicates training mode. Defaults to
``False``.
Returns
-------
tf.Tensor
Tensor of shape (B, T, D),
representing fused features.
"""
if (
isinstance(inputs, (list, tuple))
and len(inputs) == 2
):
query, value = inputs
return self.attention(query, value)
return self.attention(inputs, inputs)
def get_config(self):
r"""
Returns configuration dictionary with
'units' and 'num_heads'.
Returns
-------
dict
Configuration for serialization.
"""
config = super().get_config().copy()
config.update(
{"units": self.units, "num_heads": self.num_heads}
)
return config
@classmethod
def from_config(cls, config):
r"""
Instantiate a new
MultiResolutionAttentionFusion layer from
config.
Parameters
----------
``config`` : dict
Configuration dictionary.
Returns
-------
MultiResolutionAttentionFusion
A new instance of this layer.
"""
return cls(**config)