{ "cells": [ { "cell_type": "markdown", "id": "55c5d230", "metadata": {}, "source": [ "# PADR-Net Flood Forecasting\n", "\n", "> **Notebook goal**: build, train, and interpret the new `PADRNet`\n", "> flood-forecasting application model. We create regional synthetic\n", "> flood events for **WAF**, **EAF**, and **SAF**, train PADR-Net with a\n", "> physics-aware loss, and evaluate depth forecasts, threshold skill,\n", "> mass bias, and spatial-style interpretation maps.\n", "\n", "This notebook is intentionally self-contained. The synthetic data are\n", "not meant to replace hydrodynamic simulation or observations; they are a\n", "compact teaching dataset that exposes the same API shape used with real\n", "forcing and flood-depth arrays.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4fde9708", "metadata": {}, "outputs": [], "source": [ "import os\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# PADR-Net currently provides TensorFlow and Torch backends. The main\n", "# tutorial uses TensorFlow because the training loop is compact and the\n", "# same public PADRNet factory is used for both backends.\n", "os.environ.setdefault(\"BASE_ATTENTIVE_BACKEND\", \"tensorflow\")\n", "os.environ.setdefault(\"KERAS_BACKEND\", \"tensorflow\")\n", "\n", "print(\"BASE_ATTENTIVE_BACKEND =\", os.environ[\"BASE_ATTENTIVE_BACKEND\"])\n", "print(\"KERAS_BACKEND =\", os.environ[\"KERAS_BACKEND\"])\n" ] }, { "cell_type": "markdown", "id": "0ba4426b", "metadata": {}, "source": [ "---\n", "\n", "## 1. Imports and Reproducibility\n", "\n", "The flood application exports the model factory, validated\n", "configuration, hydrological metrics, and backend-neutral physics\n", "helpers. 
The model output is a dictionary with three keys:\n", "\n", "- `depth`: multi-horizon flood depth forecast;\n", "- `exceedance_probability`: smooth threshold-exceedance probability;\n", "- `features`: latent event representation.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "9fbc69a2", "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", "\n", "import math\n", "from dataclasses import dataclass\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import tensorflow as tf\n", "\n", "from base_attentive import PADRNet, PADRNetConfig\n", "from base_attentive.applications.flood import (\n", "    critical_success_index,\n", "    delta_mass,\n", "    exceedance_probability,\n", "    linear_reservoir_response,\n", "    mass_balance_residual,\n", "    nash_sutcliffe_efficiency,\n", "    true_skill_statistic,\n", ")\n", "\n", "np.random.seed(42)\n", "tf.random.set_seed(42)\n", "tf.get_logger().setLevel(\"ERROR\")\n", "\n", "plt.rcParams.update(\n", "    {\n", "        \"figure.dpi\": 120,\n", "        \"axes.spines.top\": False,\n", "        \"axes.spines.right\": False,\n", "        \"axes.grid\": True,\n", "        \"grid.alpha\": 0.25,\n", "    }\n", ")\n" ] }, { "cell_type": "markdown", "id": "946318b6", "metadata": {}, "source": [ "---\n", "\n", "## 2. Synthetic Regional Flood Events\n", "\n", "We generate event hydrographs with a simple rainfall-storage response.\n", "Each region has a different response time, rainfall gain, basin area,\n", "slope, and imperviousness proxy. 
This gives PADR-Net a reason to use\n", "both dynamic sequences and static descriptors.\n", "\n", "The split mimics a paper workflow: early years are used for training,\n", "intermediate years for validation, and later years for testing.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4b391212", "metadata": {}, "outputs": [], "source": [ "@dataclass(frozen=True)\n", "class RegionSpec:\n", "    code: str\n", "    name: str\n", "    color: str\n", "    tau: float\n", "    gain: float\n", "    area: float\n", "    slope: float\n", "    impervious: float\n", "    storm_scale: float\n", "\n", "\n", "REGIONS = [\n", "    RegionSpec(\n", "        \"WAF\", \"West Africa\", \"#1f66b1\", 28.0, 0.020,\n", "        0.72, 0.36, 0.22, 1.15,\n", "    ),\n", "    RegionSpec(\n", "        \"EAF\", \"East Africa\", \"#f05a28\", 20.0, 0.018,\n", "        0.54, 0.52, 0.18, 1.00,\n", "    ),\n", "    RegionSpec(\n", "        \"SAF\", \"South Africa\", \"#2f8f3b\", 34.0, 0.015,\n", "        0.61, 0.31, 0.15, 0.92,\n", "    ),\n", "]\n", "\n", "LOOKBACK = 48\n", "HORIZON = 24\n", "TOTAL_STEPS = LOOKBACK + HORIZON\n", "INPUT_DIM = 8\n", "STATIC_DIM = 3\n", "FLOOD_THRESHOLD = 0.08\n", "\n", "\n", "def rolling_mean(values: np.ndarray, window: int) -> np.ndarray:\n", "    kernel = np.ones(window, dtype=float) / window\n", "    return np.convolve(values, kernel, mode=\"same\")\n", "\n", "\n", "def make_storm(rng: np.random.Generator, spec: RegionSpec) -> np.ndarray:\n", "    t = np.arange(TOTAL_STEPS)\n", "    rain = rng.gamma(1.2, 0.12, TOTAL_STEPS)\n", "    n_pulses = rng.integers(1, 4)\n", "    for _ in range(n_pulses):\n", "        center = rng.uniform(10, LOOKBACK + 8)\n", "        width = rng.uniform(2.5, 8.0)\n", "        height = rng.uniform(1.0, 4.0) * spec.storm_scale\n", "        rain += height * np.exp(-0.5 * ((t - center) / width) ** 2)\n", "    return np.maximum(rain, 0.0)\n", "\n", "\n", "def make_event(\n", "    rng: np.random.Generator,\n", "    spec: RegionSpec,\n", "    year: int,\n", ") -> dict[str, np.ndarray | str | int]:\n", "    rain = make_storm(rng, spec)\n", "    depth = 
linear_reservoir_response(\n", "        rain,\n", "        tau=spec.tau,\n", "        gain=spec.gain,\n", "        initial_depth=rng.uniform(0.0, 0.01),\n", "    )\n", "    depth *= 8.0\n", "    depth += 0.015 * np.sin(np.linspace(0, 2 * np.pi, TOTAL_STEPS))\n", "    depth += rng.normal(0.0, 0.002, TOTAL_STEPS)\n", "    depth = np.maximum(depth, 0.0)\n", "\n", "    antecedent = rolling_mean(rain, 6)\n", "    wetness = rolling_mean(rain, 18)\n", "    response_proxy = linear_reservoir_response(\n", "        rain,\n", "        tau=max(8.0, spec.tau * 0.65),\n", "        gain=spec.gain * 0.8,\n", "    )\n", "    hour = np.arange(TOTAL_STEPS) / 24.0\n", "    seasonal = (year - 2001) / 23.0\n", "\n", "    dynamic = np.column_stack(\n", "        [\n", "            rain,\n", "            antecedent,\n", "            wetness,\n", "            response_proxy,\n", "            np.sin(2 * np.pi * hour),\n", "            np.cos(2 * np.pi * hour),\n", "            np.full(TOTAL_STEPS, seasonal),\n", "            np.gradient(rain),\n", "        ]\n", "    )\n", "    static = np.array([spec.area, spec.slope, spec.impervious])\n", "    return {\n", "        \"region\": spec.code,\n", "        \"year\": year,\n", "        \"x\": dynamic[:LOOKBACK].astype(\"float32\"),\n", "        \"static\": static.astype(\"float32\"),\n", "        \"y\": depth[LOOKBACK:].astype(\"float32\")[:, None],\n", "        \"rain_future\": rain[LOOKBACK:].astype(\"float32\")[:, None],\n", "        \"depth_full\": depth.astype(\"float32\"),\n", "        \"rain_full\": rain.astype(\"float32\"),\n", "    }\n", "\n", "\n", "def make_dataset(n_events: int = 720, seed: int = 13):\n", "    rng = np.random.default_rng(seed)\n", "    events = []\n", "    years = rng.integers(2001, 2025, n_events)\n", "    for year in years:\n", "        spec = REGIONS[int(rng.integers(0, len(REGIONS)))]\n", "        events.append(make_event(rng, spec, int(year)))\n", "    return events\n", "\n", "\n", "events = make_dataset()\n", "print(f\"events: {len(events)}\")\n", "print(\"first event keys:\", sorted(events[0].keys()))\n" ] }, { "cell_type": "markdown", "id": "79ef693d", "metadata": {}, "source": [ "### Visual check: rainfall and flood response\n", "\n", "Before training a neural model, always inspect a few 
hydrographs. The\n", "vertical dotted line marks the forecast origin: PADR-Net receives the\n", "window to its left and predicts the window to its right.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "c488e33c", "metadata": {}, "outputs": [], "source": [ "def plot_event_examples(events, n_per_region=1):\n", "    # squeeze=False keeps `axes` two-dimensional even when n_per_region == 1.\n", "    fig, axes = plt.subplots(\n", "        len(REGIONS), n_per_region, figsize=(10, 6), sharex=True, squeeze=False\n", "    )\n", "    for row, spec in enumerate(REGIONS):\n", "        region_events = [e for e in events if e[\"region\"] == spec.code]\n", "        for col, event in enumerate(region_events[:n_per_region]):\n", "            ax = axes[row, col]\n", "            t = np.arange(TOTAL_STEPS)\n", "            ax2 = ax.twinx()\n", "            ax.bar(\n", "                t,\n", "                event[\"rain_full\"],\n", "                color=\"#8fbce6\",\n", "                alpha=0.45,\n", "                width=1.0,\n", "                label=\"rainfall\",\n", "            )\n", "            ax2.plot(\n", "                t,\n", "                event[\"depth_full\"],\n", "                color=spec.color,\n", "                lw=2.2,\n", "                label=\"depth\",\n", "            )\n", "            ax2.axhline(\n", "                FLOOD_THRESHOLD,\n", "                color=\"#8b1e1e\",\n", "                lw=1.2,\n", "                ls=\"--\",\n", "                label=\"threshold\",\n", "            )\n", "            ax.axvline(LOOKBACK, color=\"#222\", lw=1.0, ls=\":\")\n", "            ax.set_title(f\"{spec.code} event, year {event['year']}\")\n", "            ax.set_ylabel(\"rain\")\n", "            ax2.set_ylabel(\"depth (m)\")\n", "            ax.set_xlim(0, TOTAL_STEPS - 1)\n", "    axes[-1, 0].set_xlabel(\"time step\")\n", "    fig.tight_layout()\n", "    return fig\n", "\n", "plot_event_examples(events, n_per_region=2)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "bf4be5ef", "metadata": {}, "source": [ "---\n", "\n", "## 3. 
Build Arrays and Temporal Splits\n", "\n", "The arrays follow the PADR-Net input contract:\n", "\n", "- `X`: `(batch, lookback, input_dim)`;\n", "- `S`: `(batch, static_dim)`;\n", "- `Y`: `(batch, forecast_horizon, 1)`;\n", "- `P_future`: `(batch, forecast_horizon, 1)`, the future rainfall used only\n", "  by the physics terms of the training loss, not by the model itself.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "afa559dd", "metadata": {}, "outputs": [], "source": [ "X = np.stack([e[\"x\"] for e in events]).astype(\"float32\")\n", "S = np.stack([e[\"static\"] for e in events]).astype(\"float32\")\n", "Y = np.stack([e[\"y\"] for e in events]).astype(\"float32\")\n", "P_future = np.stack([e[\"rain_future\"] for e in events]).astype(\"float32\")\n", "years = np.array([e[\"year\"] for e in events])\n", "regions = np.array([e[\"region\"] for e in events])\n", "\n", "# Temporal split: train <= 2018, validation 2019-2020, test >= 2021.\n", "train_idx = np.where(years <= 2018)[0]\n", "val_idx = np.where((years >= 2019) & (years <= 2020))[0]\n", "test_idx = np.where(years >= 2021)[0]\n", "\n", "print(\"X:\", X.shape, \"S:\", S.shape, \"Y:\", Y.shape)\n", "print(\"train/val/test:\", len(train_idx), len(val_idx), len(test_idx))\n", "\n", "for spec in REGIONS:\n", "    count = np.sum(regions == spec.code)\n", "    print(f\"{spec.code}: {count} events\")\n" ] }, { "cell_type": "markdown", "id": "69136674", "metadata": {}, "source": [ "### Regional inventory plot\n", "\n", "This plot is not used by the model. 
It helps verify that all three\n", "regions are represented through time and that the shaded validation\n", "(2019-2020) and test (2021-2024) years contain events.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "86ef80f4", "metadata": {}, "outputs": [], "source": [ "def plot_event_inventory(years, regions):\n", "    year_grid = np.arange(2001, 2025)\n", "    counts = np.zeros((len(REGIONS), len(year_grid)), dtype=int)\n", "    for i, spec in enumerate(REGIONS):\n", "        for j, year in enumerate(year_grid):\n", "            counts[i, j] = np.sum((regions == spec.code) & (years == year))\n", "\n", "    fig, ax = plt.subplots(figsize=(11, 2.8))\n", "    im = ax.imshow(counts, cmap=\"Blues\", aspect=\"auto\")\n", "    # Columns are indexed by (year - 2001): shade validation, then test years.\n", "    ax.axvspan(2018.5 - 2001, 2020.5 - 2001, color=\"#f8c27c\", alpha=0.35)\n", "    ax.axvspan(2020.5 - 2001, 2024.5 - 2001, color=\"#ef9cae\", alpha=0.28)\n", "    ax.set_yticks(range(len(REGIONS)), [r.code for r in REGIONS])\n", "    ax.set_xticks(range(0, len(year_grid), 3), year_grid[::3], rotation=45)\n", "    ax.set_title(\"Synthetic flood event inventory by region and year\")\n", "    ax.set_xlabel(\"year\")\n", "    ax.set_ylabel(\"region\")\n", "    cbar = fig.colorbar(im, ax=ax, pad=0.01)\n", "    cbar.set_label(\"events\")\n", "    fig.tight_layout()\n", "    return fig\n", "\n", "plot_event_inventory(years, regions)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "d6b6c675", "metadata": {}, "source": [ "---\n", "\n", "## 4. Configure PADR-Net\n", "\n", "The validated `PADRNetConfig` records the input dimensions, attention\n", "capacity, forecast horizon, threshold, and physics weights. 
The public\n", "`PADRNet` factory returns a backend-specific model while preserving the\n", "same API.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b02aced7", "metadata": {}, "outputs": [], "source": [ "config = PADRNetConfig(\n", "    input_dim=INPUT_DIM,\n", "    static_dim=STATIC_DIM,\n", "    hidden_dim=48,\n", "    num_heads=4,\n", "    num_layers=2,\n", "    forecast_horizon=HORIZON,\n", "    dropout=0.05,\n", "    lambda_physics=0.25,\n", "    lambda_mass=0.05,\n", "    lambda_smooth=0.01,\n", "    flood_threshold=FLOOD_THRESHOLD,\n", "    reservoir_tau=24.0,\n", ")\n", "\n", "model = PADRNet(config, backend=\"tensorflow\")\n", "outputs = model(tf.zeros((2, LOOKBACK, INPUT_DIM)), tf.zeros((2, STATIC_DIM)))\n", "\n", "print(type(model).__name__)\n", "for key, value in outputs.items():\n", "    print(f\"{key:24s}\", tuple(value.shape))\n" ] }, { "cell_type": "markdown", "id": "e7d38ffe", "metadata": {}, "source": [ "---\n", "\n", "## 5. Physics-Aware Training Loop\n", "\n", "PADR-Net produces predictions; the training loop decides how to combine\n", "losses. Here we use:\n", "\n", "\\begin{align}\n", "\\mathcal{L} =\n", "\\mathcal{L}_{\\mathrm{mse}}\n", "+ \\lambda_{\\mathrm{phys}} \\lVert r_t \\rVert_2^2\n", "+ \\lambda_{\\mathrm{mass}} |\\Delta M|\n", "+ \\lambda_{\\mathrm{smooth}} \\lVert \\nabla_t \\hat{h} \\rVert_2^2.\n", "\\end{align}\n", "\n", "Here $r_t = (\\hat{h}_{t+1} - \\hat{h}_t) - (g\\,P_{t+1} - \\hat{h}_t / \\tau)$\n", "with rainfall $P$, gain $g$ (`GAIN_PROXY`), and time constant $\\tau$\n", "(`reservoir_tau`); $\\Delta M$ is the relative difference between the\n", "predicted and reference depth sums; and $\\nabla_t \\hat{h}$ is the\n", "step-to-step change in predicted depth. The squared norms are computed\n", "as means over batch and time steps.\n", "\n", "The residual is a simple rainfall-storage tendency. 
Real projects can\n", "replace it with a hydrodynamic operator, a differentiable routing layer,\n", "or a physics residual from a shallow-water-equation (SWE) or\n", "hydrological solver.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a31e3227", "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 48\n", "EPOCHS = 24\n", "\n", "\n", "def make_tf_dataset(indices, shuffle=False):\n", "    ds = tf.data.Dataset.from_tensor_slices(\n", "        (X[indices], S[indices], Y[indices], P_future[indices])\n", "    )\n", "    if shuffle:\n", "        ds = ds.shuffle(len(indices), seed=42, reshuffle_each_iteration=True)\n", "    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)\n", "\n", "\n", "train_ds = make_tf_dataset(train_idx, shuffle=True)\n", "val_ds = make_tf_dataset(val_idx, shuffle=False)\n", "test_ds = make_tf_dataset(test_idx, shuffle=False)\n", "\n", "try:\n", "    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-3)\n", "except ImportError:  # also covers ModuleNotFoundError, its subclass\n", "    from tensorflow.python.keras.optimizer_v2.adam import Adam\n", "\n", "    optimizer = Adam(learning_rate=3e-3)\n", "\n", "\n", "def physics_terms(y_true, y_pred, rain_future, cfg, gain_proxy):\n", "    pred = tf.squeeze(y_pred, axis=-1)\n", "    true = tf.squeeze(y_true, axis=-1)\n", "    rain = tf.squeeze(rain_future, axis=-1)\n", "\n", "    mse = tf.reduce_mean(tf.square(true - pred))\n", "    # Storage residual: predicted depth change minus the reservoir tendency.\n", "    dh = pred[:, 1:] - pred[:, :-1]\n", "    storage_rhs = gain_proxy * rain[:, 1:] - pred[:, :-1] / cfg.reservoir_tau\n", "    residual = dh - storage_rhs\n", "    physics = tf.reduce_mean(tf.square(residual))\n", "\n", "    # Relative absolute error of the depth sum over the forecast horizon.\n", "    mass = tf.reduce_mean(\n", "        tf.abs(tf.reduce_sum(pred, axis=1) - tf.reduce_sum(true, axis=1))\n", "        / (tf.reduce_sum(true, axis=1) + 1e-4)\n", "    )\n", "    smooth = tf.reduce_mean(tf.square(dh))\n", "    total = (\n", "        mse\n", "        + cfg.lambda_physics * physics\n", "        + cfg.lambda_mass * mass\n", "        + cfg.lambda_smooth * smooth\n", "    )\n", "    return total, {\"mse\": mse, \"physics\": physics, \"mass\": mass, \"smooth\": smooth}\n", "\n", "\n", "# Notebook-local 
teaching gain for the residual.\n", "GAIN_PROXY = 0.12\n", "\n", "\n", "@tf.function\n", "def train_step(xb, sb, yb, pb):\n", "    with tf.GradientTape() as tape:\n", "        out = model(xb, sb, training=True)\n", "        loss, terms = physics_terms(yb, out[\"depth\"], pb, config, GAIN_PROXY)\n", "    grads = tape.gradient(loss, model.trainable_variables)\n", "    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n", "    return loss, terms\n", "\n", "\n", "@tf.function\n", "def eval_step(xb, sb, yb, pb):\n", "    out = model(xb, sb, training=False)\n", "    return physics_terms(yb, out[\"depth\"], pb, config, GAIN_PROXY)\n", "\n", "\n", "def mean_epoch(ds, train=False):\n", "    # Pick the step function once, then average the per-batch values.\n", "    step = train_step if train else eval_step\n", "    values = []\n", "    parts = {\"mse\": [], \"physics\": [], \"mass\": [], \"smooth\": []}\n", "    for xb, sb, yb, pb in ds:\n", "        loss, terms = step(xb, sb, yb, pb)\n", "        values.append(float(loss.numpy()))\n", "        for key, val in terms.items():\n", "            parts[key].append(float(val.numpy()))\n", "    return float(np.mean(values)), {k: float(np.mean(v)) for k, v in parts.items()}\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7eadbe1f", "metadata": {}, "outputs": [], "source": [ "history = {\"train\": [], \"val\": [], \"mse\": [], \"physics\": [], \"mass\": []}\n", "\n", "for epoch in range(1, EPOCHS + 1):\n", "    train_loss, train_terms = mean_epoch(train_ds, train=True)\n", "    val_loss, val_terms = mean_epoch(val_ds, train=False)\n", "    history[\"train\"].append(train_loss)\n", "    history[\"val\"].append(val_loss)\n", "    history[\"mse\"].append(val_terms[\"mse\"])\n", "    history[\"physics\"].append(val_terms[\"physics\"])\n", "    history[\"mass\"].append(val_terms[\"mass\"])\n", "    if epoch == 1 or epoch % 4 == 0 or epoch == EPOCHS:\n", "        print(\n", "            f\"epoch {epoch:02d} | train={train_loss:.5f} \"\n", "            f\"val={val_loss:.5f} mse={val_terms['mse']:.5f} \"\n", "            f\"phys={val_terms['physics']:.5f}\"\n", "        )\n" ] }, { "cell_type": "code", "execution_count": null, "id": 
"83051ec8", "metadata": {}, "outputs": [], "source": [ "def plot_training_history(history):\n", "    epochs = np.arange(1, len(history[\"train\"]) + 1)\n", "    fig, axes = plt.subplots(1, 2, figsize=(11, 3.5))\n", "    axes[0].plot(epochs, history[\"train\"], label=\"train\", lw=2)\n", "    axes[0].plot(epochs, history[\"val\"], label=\"validation\", lw=2)\n", "    axes[0].set_title(\"PADR-Net training objective\")\n", "    axes[0].set_xlabel(\"epoch\")\n", "    axes[0].set_ylabel(\"loss\")\n", "    axes[0].legend()\n", "\n", "    axes[1].plot(epochs, history[\"mse\"], label=\"MSE\", lw=2)\n", "    axes[1].plot(epochs, history[\"physics\"], label=\"physics\", lw=2)\n", "    axes[1].plot(epochs, history[\"mass\"], label=\"mass\", lw=2)\n", "    axes[1].set_title(\"Validation loss components\")\n", "    axes[1].set_xlabel(\"epoch\")\n", "    axes[1].legend()\n", "    fig.tight_layout()\n", "    return fig\n", "\n", "plot_training_history(history)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "1dd56f1e", "metadata": {}, "source": [ "---\n", "\n", "## 6. Evaluation Metrics\n", "\n", "We evaluate continuous depth and event-threshold skill. 
`NSE` (Nash-Sutcliffe\n", "efficiency) measures hydrograph agreement, `CSI` (critical success index)\n", "and `TSS` (true skill statistic) measure classification skill at the\n", "flood threshold, and `delta_mass` reports the signed mass (accumulated\n", "depth) bias, shown here in percent.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "adf8d770", "metadata": {}, "outputs": [], "source": [ "def predict_arrays(indices):\n", "    out = model(\n", "        tf.convert_to_tensor(X[indices]),\n", "        tf.convert_to_tensor(S[indices]),\n", "        training=False,\n", "    )\n", "    pred = out[\"depth\"].numpy()\n", "    prob = out[\"exceedance_probability\"].numpy()\n", "    features = out[\"features\"].numpy()\n", "    return pred, prob, features\n", "\n", "\n", "pred_test, prob_test, features_test = predict_arrays(test_idx)\n", "y_test = Y[test_idx]\n", "region_test = regions[test_idx]\n", "\n", "rows = []\n", "for spec in REGIONS:\n", "    mask = region_test == spec.code\n", "    yt = y_test[mask].reshape(-1)\n", "    yp = pred_test[mask].reshape(-1)\n", "    rows.append(\n", "        {\n", "            \"region\": spec.code,\n", "            \"NSE\": nash_sutcliffe_efficiency(yt, yp),\n", "            \"CSI\": critical_success_index(yt, yp, threshold=FLOOD_THRESHOLD),\n", "            \"TSS\": true_skill_statistic(yt, yp, threshold=FLOOD_THRESHOLD),\n", "            \"delta_mass_%\": delta_mass(yt, yp),\n", "        }\n", "    )\n", "\n", "for row in rows:\n", "    print(\n", "        f\"{row['region']}: NSE={row['NSE']:.3f} \"\n", "        f\"CSI={row['CSI']:.3f} TSS={row['TSS']:.3f} \"\n", "        f\"DeltaM={row['delta_mass_%']:.2f}%\"\n", "    )\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2b93b435", "metadata": {}, "outputs": [], "source": [ "def plot_metric_bars(rows):\n", "    labels = [r[\"region\"] for r in rows]\n", "    metrics = [\"NSE\", \"CSI\", \"TSS\"]\n", "    x = np.arange(len(labels))\n", "    width = 0.24\n", "    fig, axes = plt.subplots(1, 2, figsize=(11, 3.8))\n", "    for i, metric in enumerate(metrics):\n", "        axes[0].bar(\n", "            x + (i - 1) * width,\n", "            [r[metric] for r in rows],\n", "            width,\n", "            label=metric,\n", "        )\n", "    axes[0].set_xticks(x, labels)\n", "    axes[0].set_ylim(-0.2, 1.05)\n", "    
axes[0].set_title(\"Regional forecast skill\")\n", "    axes[0].legend()\n", "\n", "    axes[1].bar(labels, [r[\"delta_mass_%\"] for r in rows], color=\"#7c6bb0\")\n", "    axes[1].axhline(0, color=\"#222\", lw=1)\n", "    axes[1].set_title(\"Mass bias\")\n", "    axes[1].set_ylabel(\"DeltaM (%)\")\n", "    fig.tight_layout()\n", "    return fig\n", "\n", "plot_metric_bars(rows)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "d69721c8", "metadata": {}, "source": [ "---\n", "\n", "## 7. Hydrograph Interpretation\n", "\n", "Each panel compares the reference flood depth with the PADR-Net forecast\n", "over the same lead times, for one test event per region. The shaded area\n", "marks where the reference is above the flood threshold, and the dashed\n", "line (right axis) shows the predicted exceedance probability.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "77dde9f8", "metadata": {}, "outputs": [], "source": [ "def plot_region_forecasts(indices, pred, prob):\n", "    fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)\n", "    test_events = [events[i] for i in indices]\n", "    for ax, spec in zip(axes, REGIONS):\n", "        candidates = [j for j, e in enumerate(test_events) if e[\"region\"] == spec.code]\n", "        # Choose a visible event near the regional 80th percentile of peak depth.\n", "        peaks = [float(Y[indices[j]].max()) for j in candidates]\n", "        chosen = candidates[int(np.argsort(peaks)[max(0, int(0.8 * len(peaks)) - 1)])]\n", "        ref = Y[indices[chosen], :, 0]\n", "        yh = pred[chosen, :, 0]\n", "        pp = prob[chosen, :, 0]\n", "        t = np.arange(1, HORIZON + 1)\n", "        ax.plot(t, ref, color=\"#222\", lw=2.2, label=\"reference\")\n", "        ax.plot(t, yh, color=spec.color, lw=2.2, label=\"PADR-Net\")\n", "        ax.fill_between(\n", "            t,\n", "            0,\n", "            ref,\n", "            where=ref >= FLOOD_THRESHOLD,\n", "            color=spec.color,\n", "            alpha=0.16,\n", "            label=\"reference flooded\",\n", "        )\n", "        ax2 = ax.twinx()\n", "        ax2.plot(t, pp, color=\"#9c6b21\", lw=1.6, ls=\"--\", label=\"exceedance\")\n", "        ax2.set_ylim(0, 1.05)\n", "        ax2.set_ylabel(\"prob.\")\n", "        ax.axhline(FLOOD_THRESHOLD, color=\"#8b1e1e\", lw=1.0, 
ls=\":\")\n", "        ax.set_title(f\"{spec.name} ({spec.code})\")\n", "        ax.set_ylabel(\"depth (m)\")\n", "        ax.legend(loc=\"upper left\")\n", "    axes[-1].set_xlabel(\"forecast lead time\")\n", "    fig.tight_layout()\n", "    return fig\n", "\n", "plot_region_forecasts(test_idx, pred_test, prob_test)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "7921595b", "metadata": {}, "source": [ "---\n", "\n", "## 8. Spatial-Style PADR-Net Diagnostic Map\n", "\n", "For papers and reports, hydrographs are not always enough. The helper\n", "below turns a regional peak depth into a synthetic flood-depth field so\n", "we can show the same idea as a 2x3 map:\n", "\n", "- first row: field scaled by the reference peak depth;\n", "- second row: field scaled by the PADR-Net peak depth for the same event;\n", "- shared colorbar: water depth in metres.\n", "\n", "With real data, replace `make_spatial_field` with your gridded reference\n", "and PADR-Net gridded forecast arrays.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6f1810c1", "metadata": {}, "outputs": [], "source": [ "def make_spatial_field(amplitude, region_index, n=80):\n", "    # A complex step makes np.mgrid return n points per axis, endpoints included.\n", "    yy, xx = np.mgrid[-1:1:complex(n), -1:1:complex(n)]\n", "    angle = [0.4, -0.2, 0.8][region_index]\n", "    xr = xx * np.cos(angle) - yy * np.sin(angle)\n", "    yr = xx * np.sin(angle) + yy * np.cos(angle)\n", "    basin = np.exp(-((xr / 0.55) ** 2 + (yr / 0.75) ** 2))\n", "    channel = np.exp(-(yr / 0.11) ** 2) * np.exp(-(xr / 0.9) ** 4)\n", "    tributary = 0.55 * np.exp(-((yr - 0.35 * xr) / 0.09) ** 2)\n", "    field = amplitude * (0.55 * basin + 0.35 * channel + 0.10 * tributary)\n", "    field[field < 0.002] = np.nan\n", "    return field\n", "\n", "\n", "def plot_spatial_diagnostic(indices, pred):\n", "    fig, axes = plt.subplots(2, 3, figsize=(12, 6.4), constrained_layout=True)\n", "    vmax = 0.0\n", "    fields = []\n", "    for i, spec in enumerate(REGIONS):\n", "        region_pos = np.where(regions[indices] == spec.code)[0]\n", "        peak_order = np.argsort(Y[indices[region_pos]].max(axis=(1, 2)))\n", 
"        chosen = region_pos[peak_order[int(0.75 * len(peak_order))]]\n", "        ref_peak = float(Y[indices[chosen]].max())\n", "        pred_peak = float(pred[chosen].max())\n", "        ref_field = make_spatial_field(ref_peak, i)\n", "        pred_field = make_spatial_field(pred_peak, i)\n", "        fields.append((ref_field, pred_field, spec))\n", "        vmax = max(vmax, np.nanmax(ref_field), np.nanmax(pred_field))\n", "\n", "    for col, (ref_field, pred_field, spec) in enumerate(fields):\n", "        for row, field in enumerate([ref_field, pred_field]):\n", "            ax = axes[row, col]\n", "            im = ax.imshow(field, cmap=\"Blues\", vmin=0, vmax=vmax)\n", "            ax.contour(\n", "                np.nan_to_num(field, nan=0.0),\n", "                levels=[FLOOD_THRESHOLD],\n", "                colors=[spec.color],\n", "                linewidths=1.6,\n", "            )\n", "            ax.set_xticks([])\n", "            ax.set_yticks([])\n", "            if row == 0:\n", "                ax.set_title(spec.code)\n", "            if col == 0:\n", "                ax.set_ylabel([\"Reference\", \"PADR-Net\"][row])\n", "    cbar = fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.86, pad=0.02)\n", "    cbar.set_label(\"water depth (m)\")\n", "    fig.suptitle(\"Regional peak-depth map diagnostic\", y=1.03, fontsize=14)\n", "    return fig\n", "\n", "plot_spatial_diagnostic(test_idx, pred_test)\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "385d6a2d", "metadata": {}, "source": [ "---\n", "\n", "## 9. Latent Features and Regional Transfer\n", "\n", "The `features` output can be used for diagnostics. Here we project the\n", "latent vectors onto their first two principal components, plot the events\n", "region by region, and color each point by its reference peak depth. 
A strong model should not merely memorize region identity; it\n", "should organize events by hydrological response and severity.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "b56ebec8", "metadata": {}, "outputs": [], "source": [ "def pca2(values):\n", "    centered = values - values.mean(axis=0, keepdims=True)\n", "    _, _, vt = np.linalg.svd(centered, full_matrices=False)\n", "    return centered @ vt[:2].T\n", "\n", "coords = pca2(features_test)\n", "peak_depth = y_test.max(axis=(1, 2))\n", "# Share one color scale across regions so the colorbar is valid for every point.\n", "vmin, vmax = float(peak_depth.min()), float(peak_depth.max())\n", "\n", "fig, ax = plt.subplots(figsize=(7.2, 5.2))\n", "# Marker shape encodes the region; color encodes the reference peak depth.\n", "for spec, marker in zip(REGIONS, [\"o\", \"s\", \"^\"]):\n", "    mask = region_test == spec.code\n", "    sc = ax.scatter(\n", "        coords[mask, 0],\n", "        coords[mask, 1],\n", "        c=peak_depth[mask],\n", "        cmap=\"viridis\",\n", "        vmin=vmin,\n", "        vmax=vmax,\n", "        marker=marker,\n", "        s=40,\n", "        alpha=0.85,\n", "        label=spec.code,\n", "        edgecolor=\"white\",\n", "        linewidth=0.4,\n", "    )\n", "ax.set_title(\"PADR-Net latent event space\")\n", "ax.set_xlabel(\"PC1\")\n", "ax.set_ylabel(\"PC2\")\n", "ax.legend(title=\"region\")\n", "cbar = fig.colorbar(sc, ax=ax)\n", "cbar.set_label(\"reference peak depth (m)\")\n", "fig.tight_layout()\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "b0269117", "metadata": {}, "source": [ "---\n", "\n", "## 10. Optional: PyTorch Backend Smoke Test\n", "\n", "The public factory is backend-neutral. The PyTorch implementation uses\n", "the same `PADRNetConfig` and returns the same dictionary keys. 
This cell\n", "only checks a forward pass; the training loop above remains\n", "TensorFlow-specific.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0692158f", "metadata": {}, "outputs": [], "source": [ "try:\n", "    import torch\n", "\n", "    torch_model = PADRNet(config, backend=\"torch\")\n", "    torch_out = torch_model(\n", "        torch.zeros(2, LOOKBACK, INPUT_DIM),\n", "        torch.zeros(2, STATIC_DIM),\n", "    )\n", "    print(type(torch_model).__name__)\n", "    for key, value in torch_out.items():\n", "        print(f\"{key:24s}\", tuple(value.shape))\n", "except Exception as exc:\n", "    print(\"PyTorch smoke test skipped:\", type(exc).__name__, exc)\n" ] }, { "cell_type": "markdown", "id": "a96cf9e2", "metadata": {}, "source": [ "---\n", "\n", "## 11. Exercises\n", "\n", "### Exercise 1: Flood threshold sensitivity\n", "\n", "Change `FLOOD_THRESHOLD` from `0.08` to `0.05` or `0.03`, rerun the\n", "notebook, and compare `CSI`, `TSS`, and the contour line in the spatial\n", "map. Which threshold makes false alarms more likely?\n", "\n", "### Exercise 2: Physics weight ablation\n", "\n", "Set `lambda_physics=0.0`, retrain, and compare the validation loss\n", "components and mass bias. Does the model become more accurate in MSE but\n", "less physically plausible?\n", "\n", "### Exercise 3: Leave-one-region-out transfer\n", "\n", "Train only on two regions and test on the held-out region. For example,\n", "train on WAF+SAF and test on EAF. Which static descriptors help the most\n", "when transferring to a new region?\n", "\n", "### Exercise 4: Replace the synthetic forcing\n", "\n", "Replace `make_dataset()` with real arrays from ERA5/GloFAS/SWE outputs.\n", "Keep the same shapes:\n", "\n", "- `X`: `(batch, lookback, input_dim)`;\n", "- `S`: `(batch, static_dim)`;\n", "- `Y`: `(batch, forecast_horizon, 1)`;\n", "- `P_future`: `(batch, forecast_horizon, 1)`, if you keep the physics loss.\n", "\n", "The PADR-Net API does not change when the data source changes.\n" ] }, { "cell_type": "markdown", "id": "0a5f3c19", "metadata": {}, "source": [ "---\n", "\n", "## 12. 
Summary\n", "\n", "This notebook introduced the complete PADR-Net workflow:\n", "\n", "1. create regional flood-event tensors;\n", "2. configure `PADRNetConfig` with validated parameters;\n", "3. instantiate `PADRNet` with the TensorFlow backend;\n", "4. train with prediction and physics-aware loss terms;\n", "5. evaluate NSE, CSI, TSS, and mass bias;\n", "6. interpret forecasts with hydrographs, spatial maps, and latent\n", " event-space diagnostics.\n", "\n", "The same structure can now be reused with real hydrodynamic reference\n", "outputs or observed flood-depth products.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 5 }