Source code for base_attentive.runtime

"""Runtime helpers for accelerated inference."""

from __future__ import annotations

import importlib
from typing import Any

from . import KERAS_BACKEND


def _load_tensorflow():
    """Import TensorFlow on demand and return the module object.

    The import is deferred to call time so that merely importing this
    module never requires TensorFlow to be installed.

    Returns
    -------
    module
        The ``tensorflow`` module.

    Raises
    ------
    ImportError
        If ``tensorflow`` cannot be imported. The underlying import error
        is chained as ``__cause__``.
    """
    module_name = "tensorflow"
    try:
        tf_module = importlib.import_module(module_name)
    except ImportError as err:
        message = (
            "TensorFlow is required to build a fast prediction function. "
            "Install TensorFlow and use the TensorFlow backend."
        )
        raise ImportError(message) from err
    return tf_module


def make_fast_predict_fn(
    model: Any,
    *,
    jit_compile: bool = True,
    reduce_retracing: bool = True,
    warmup_inputs: Any | None = None,
):
    """
    Create a TensorFlow-traced prediction function for a Keras model.

    The returned callable accepts the same input structure as ``model``
    and always executes with ``training=False``. This is useful when you
    want a reusable inference function with ``tf.function`` tracing and
    optional XLA compilation.

    Parameters
    ----------
    model : Any
        A Keras-compatible model or layer that can be called as
        ``model(inputs, training=False)``.
    jit_compile : bool, default=True
        Forwarded to ``tf.function(jit_compile=...)``; requests XLA JIT
        compilation of the traced prediction function.
    reduce_retracing : bool, default=True
        Forwarded to ``tf.function(reduce_retracing=...)``; when True,
        TensorFlow tries to avoid retracing, e.g. by tracing with more
        generic (relaxed) input shapes so that calls with differently
        shaped inputs can reuse an existing trace.
    warmup_inputs : Any, optional
        If given, the traced function is called once with these inputs
        before being returned, so tracing happens up front rather than on
        the first real prediction. The output of this warmup call is
        discarded.

    Returns
    -------
    callable
        The ``tf.function``-decorated prediction callable.

    Raises
    ------
    RuntimeError
        If the active package backend is not TensorFlow.
    ImportError
        If TensorFlow cannot be imported.

    Notes
    -----
    Any exception raised by ``model`` or by tracing/compilation during the
    warmup call propagates unchanged to the caller.
    """
    # Check the backend before touching TensorFlow so non-TensorFlow users
    # get an actionable message instead of an import or tracing failure.
    if KERAS_BACKEND != "tensorflow":
        raise RuntimeError(
            "make_fast_predict_fn requires the TensorFlow backend. "
            f"Current backend is {KERAS_BACKEND!r}. Set "
            "`KERAS_BACKEND=tensorflow` before importing base_attentive."
        )

    tf = _load_tensorflow()

    @tf.function(
        jit_compile=jit_compile,
        reduce_retracing=reduce_retracing,
    )
    def predict_fn(inputs):
        # Inference only: pin training=False so layers that branch on the
        # ``training`` flag take their inference path.
        return model(inputs, training=False)

    if warmup_inputs is not None:
        # Force the first trace (and, with jit_compile, XLA compilation)
        # now; the prediction itself is not needed.
        predict_fn(warmup_inputs)

    return predict_fn

# Public API of this module.
__all__ = [
    "make_fast_predict_fn",
]