Source code for base_attentive.validation
"""Utilities for backend-agnostic tensor validation."""
from __future__ import annotations

from typing import Any, Callable, Optional, Tuple, Union

import numpy as np

from ..logging import get_logger
try:
from .._bootstrap import KERAS_BACKEND, KERAS_DEPS
except Exception: # pragma: no cover
KERAS_BACKEND = ""
KERAS_DEPS = None
# Names exported by a star-import of this module; underscore helpers stay private.
__all__ = [
    "validate_model_inputs", "maybe_reduce_quantiles_bh", "ensure_bh1",
]

# Module-scoped logger; validate_model_inputs writes its verbose shape report here.
_logger = get_logger(__name__)
def _has_runtime() -> bool:
    """Report whether a Keras backend name and its dependency bundle are both set."""
    if not KERAS_BACKEND:
        return False
    return KERAS_DEPS is not None
def _normalize_inputs(
    inputs: Union[Any, np.ndarray, list, tuple, None],
) -> list[Optional[Any]]:
    """Coerce ``inputs`` into exactly three slots: ``[static, dynamic, future]``.

    ``None`` yields three empty slots. A list or tuple contributes at most its
    first three items; any other value occupies the first (static) slot on its
    own. Slots that were not filled are padded with ``None``.
    """
    if inputs is None:
        items: list[Optional[Any]] = []
    elif isinstance(inputs, (list, tuple)):
        items = list(inputs[:3])
    else:
        items = [inputs]
    # ``items`` never holds more than three entries, so padding is non-negative.
    padding = [None] * (3 - len(items))
    return items + padding
[docs]
def validate_model_inputs(
    inputs: Union[Any, np.ndarray, list],
    static_input_dim: Optional[int] = None,
    dynamic_input_dim: Optional[int] = None,
    future_covariate_dim: Optional[int] = None,
    forecast_horizon: Optional[int] = None,
    error: str = "raise",
    mode: str = "strict",
    deep_check: Optional[bool] = None,
    model_name: Optional[str] = None,
    verbose: int = 0,
    **kwargs,
) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
    """
    Validate and homogenize input tensors for model workflows.

    This entrypoint intentionally stays lightweight: it normalizes the input
    container shape and converts values into the active Keras runtime tensor
    type when a runtime is available. When no Keras runtime is configured, the
    raw values are returned unchanged. ``None`` inputs are normalized to
    ``(None, None, None)``.

    Parameters
    ----------
    inputs : tensor-like, list, tuple or None
        Either a single input, which fills the static slot, or a sequence of
        up to three inputs ordered ``[static, dynamic, future]``. Items past
        the third are ignored and missing ones become ``None``.
    static_input_dim, dynamic_input_dim, future_covariate_dim, forecast_horizon : int, optional
        Accepted for API compatibility; not used by this implementation.
    error : str, default "raise"
        Policy when an input cannot be converted to a tensor. ``"raise"``
        raises :class:`ValueError`. ``"warn"`` logs a warning and keeps the
        raw value. Any other value keeps the raw value silently.
    mode, deep_check, model_name, **kwargs
        Accepted for API compatibility; not used by this implementation.
    verbose : int, default 0
        When greater than zero and a runtime is available, log the shape of
        each non-``None`` input after conversion.

    Returns
    -------
    tuple
        ``(static, dynamic, future)``, each a converted tensor, the raw value
        (on a tolerated conversion failure or without a runtime), or ``None``.

    Raises
    ------
    ValueError
        If ``error == "raise"`` and an input cannot be converted to a tensor.
    """
    normalized_inputs = _normalize_inputs(inputs)
    if not _has_runtime():
        return tuple(normalized_inputs)

    tensors = []
    for inp in normalized_inputs:
        if inp is None:
            tensors.append(None)
            continue
        try:
            converted = KERAS_DEPS.convert_to_tensor(inp)
        except Exception as exc:
            if error == "raise":
                raise ValueError(
                    f"Could not convert input to tensor: {exc}"
                ) from exc
            if error == "warn":
                # Surface the failure instead of swallowing it; the raw value
                # is still passed through so callers keep working.
                _logger.warning(
                    "Could not convert input to tensor; keeping raw value: %s",
                    exc,
                )
            converted = inp
        tensors.append(converted)

    static, dynamic, future = tensors
    if verbose > 0:
        _logger.info("Validating input tensors...")
        for fmt, value in (
            (" Static shape: %s", static),
            (" Dynamic shape: %s", dynamic),
            (" Future shape: %s", future),
        ):
            if value is not None:
                _logger.info(fmt, getattr(value, "shape", None))
    return static, dynamic, future
[docs]
def maybe_reduce_quantiles_bh(
    x: Any,
    *,
    name: str = "tensor",
    axis: int = 2,
    reduction: Union[str, Callable[..., Any]] = "mean",
) -> Any:
    """Reduce a quantile axis when a backend tensor carries one.

    Without a configured Keras runtime ``x`` is returned untouched. Otherwise
    ``x`` is converted to a tensor and reduced along ``axis`` when it has rank
    >= 4, or rank 3 with a statically known last dimension greater than 1. In
    all other cases the converted tensor is returned as is.

    Parameters
    ----------
    x : tensor-like
        Value to inspect and possibly reduce.
    name : str, default "tensor"
        Label used in error messages.
    axis : int, default 2
        Axis passed to the reduction.
    reduction : {"mean", "sum"} or callable, default "mean"
        Reduction to apply. A callable is invoked as ``reduction(x, axis=axis)``.

    Raises
    ------
    ValueError
        If a reduction is required and ``reduction`` is neither callable nor
        one of ``"mean"`` / ``"sum"``.
    """
    if not _has_runtime():
        return x
    x = KERAS_DEPS.convert_to_tensor(x)
    rank = len(getattr(x, "shape", ()))

    if rank >= 4:
        needs_reduction = True
    elif rank == 3:
        last_dim = x.shape[-1]
        # A static size of None (unknown) or 1 means there is nothing to reduce.
        needs_reduction = last_dim is not None and last_dim > 1
    else:
        needs_reduction = False

    if not needs_reduction:
        return x
    if callable(reduction):
        return reduction(x, axis=axis)
    if reduction == "mean":
        return KERAS_DEPS.reduce_mean(x, axis=axis)
    if reduction == "sum":
        return KERAS_DEPS.reduce_sum(x, axis=axis)
    # An unknown reduction used to fall through and return x unreduced,
    # silently hiding typos such as "max"; fail loudly instead.
    raise ValueError(
        f"Unsupported reduction {reduction!r} for {name!r}; "
        "expected 'mean', 'sum' or a callable."
    )
[docs]
def ensure_bh1(
    x: Any,
    *,
    name: str = "tensor",
    dtype: Optional[Any] = None,
    reduce_axis: Optional[int] = None,
    reduction: Union[str, Callable[..., Any]] = "mean",
    allow_rank1: bool = False,
) -> Any:
    """Ensure a tensor-like value has shape ``(B, H, 1)``.

    Without a configured Keras runtime ``x`` is returned untouched. Otherwise
    ``x`` is converted to a tensor and trailing singleton axes are appended
    until it is at least rank 3. When ``reduce_axis`` is given, that axis (of
    the padded tensor) is reduced, and the result is padded back to rank 3 so
    the trailing singleton is restored. Tensors of rank > 3 are not reshaped
    unless a reduction brings them down.

    Parameters
    ----------
    x : tensor-like
        Value to reshape.
    name : str, default "tensor"
        Label used in error messages.
    dtype : optional
        When given, the result is cast to this dtype.
    reduce_axis : int, optional
        Axis to reduce after padding; no reduction when ``None``.
    reduction : {"mean", "sum"} or callable, default "mean"
        Reduction used with ``reduce_axis``. A callable is invoked as
        ``reduction(x, axis=reduce_axis)``.
    allow_rank1 : bool, default False
        Accepted for API compatibility; currently has no effect (rank-1
        inputs are always padded).

    Raises
    ------
    ValueError
        If ``reduce_axis`` is given and ``reduction`` is neither callable nor
        one of ``"mean"`` / ``"sum"``.
    """
    if not _has_runtime():
        return x

    def _pad_to_rank3(tensor: Any) -> Any:
        # Append trailing singleton axes until the tensor is at least rank 3.
        while len(getattr(tensor, "shape", ())) < 3:
            tensor = KERAS_DEPS.expand_dims(tensor, axis=-1)
        return tensor

    x = _pad_to_rank3(KERAS_DEPS.convert_to_tensor(x))
    if reduce_axis is not None:
        if callable(reduction):
            x = reduction(x, axis=reduce_axis)
        elif reduction == "mean":
            x = KERAS_DEPS.reduce_mean(x, axis=reduce_axis)
        elif reduction == "sum":
            x = KERAS_DEPS.reduce_sum(x, axis=reduce_axis)
        else:
            raise ValueError(
                f"Unsupported reduction {reduction!r} for {name!r}; "
                "expected 'mean', 'sum' or a callable."
            )
        # The reduction is not asked to keep dims, so it may drop
        # ``reduce_axis`` (e.g. (B, H, Q) -> (B, H)); re-pad so the documented
        # (B, H, 1) layout holds. This is a no-op when rank is already >= 3.
        x = _pad_to_rank3(x)
    if dtype is not None:
        x = KERAS_DEPS.cast(x, dtype)
    return x