{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8dac1885",
   "metadata": {},
   "source": [
    "# Advanced PADR-Net Workflow: Transfer, Uncertainty, and Stress Testing\n",
    "\n",
    "> **Goal**: use `PADRNet` as an experiment framework, not only as a\n",
    "> single forecaster. We run leave-one-region-out transfer, physics-loss\n",
    "> ablation, threshold calibration, Monte-Carlo dropout uncertainty, and\n",
    "> rainfall-intensification stress tests.\n",
    "\n",
    "This notebook is the advanced companion to\n",
    "`14_padrnet_flood_forecasting.ipynb`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f007401",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "# Choose the backend (unless already set) before TensorFlow or Keras is imported.\n",
    "os.environ.setdefault(\"BASE_ATTENTIVE_BACKEND\", \"tensorflow\")\n",
    "os.environ.setdefault(\"KERAS_BACKEND\", \"tensorflow\")\n",
    "print(\"BASE_ATTENTIVE_BACKEND =\", os.environ[\"BASE_ATTENTIVE_BACKEND\"])\n",
    "print(\"KERAS_BACKEND =\", os.environ[\"KERAS_BACKEND\"])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc1e4a07",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 1. Imports\n",
    "\n",
    "The public API remains small: `PADRNetConfig` validates the model\n",
    "configuration and `PADRNet` creates the backend-specific model. The\n",
    "advanced workflow is built around experimental splits and diagnostics.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8a84b24",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from base_attentive import PADRNet, PADRNetConfig\n",
    "from base_attentive.applications.flood import (\n",
    "    critical_success_index,\n",
    "    delta_mass,\n",
    "    linear_reservoir_response,\n",
    "    mass_balance_residual,\n",
    "    nash_sutcliffe_efficiency,\n",
    "    true_skill_statistic,\n",
    ")\n",
    "\n",
    "np.random.seed(123)\n",
    "tf.random.set_seed(123)\n",
    "tf.get_logger().setLevel(\"ERROR\")\n",
    "plt.rcParams.update({\"figure.dpi\": 120, \"axes.grid\": True, \"grid.alpha\": 0.25})\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92ad5a1a",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 2. Regional Synthetic Data with Transfer Shift\n",
    "\n",
    "The generator makes WAF, EAF, and SAF hydrologically different: each\n",
    "region has its own reservoir time constant (`tau`), runoff `gain`, storm\n",
    "scale, and probability of a late convective burst, plus static\n",
    "descriptors (area, slope, impervious fraction). That regional shift is\n",
    "what makes transfer testing meaningful. Each event provides a 48-step\n",
    "lookback of 8 dynamic features, 3 static features, and a 24-step depth\n",
    "target.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe618ec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class RegionSpec:\n",
    "    code: str\n",
    "    name: str\n",
    "    color: str\n",
    "    tau: float          # reservoir time constant\n",
    "    gain: float         # runoff gain of the reservoir response\n",
    "    area: float         # static descriptors fed to the model\n",
    "    slope: float\n",
    "    impervious: float\n",
    "    storm_scale: float  # multiplier on storm-burst heights\n",
    "    convective: float   # probability of a late convective burst (>= 1 means always)\n",
    "\n",
    "REGIONS = [\n",
    "    RegionSpec(\"WAF\", \"West Africa\", \"#1f66b1\", 30.0, 0.020, 0.74, 0.34, 0.23, 1.18, 0.70),\n",
    "    RegionSpec(\"EAF\", \"East Africa\", \"#f05a28\", 19.0, 0.018, 0.55, 0.55, 0.18, 1.02, 1.10),\n",
    "    RegionSpec(\"SAF\", \"South Africa\", \"#2f8f3b\", 36.0, 0.015, 0.63, 0.30, 0.15, 0.92, 0.55),\n",
    "]\n",
    "LOOKBACK, HORIZON = 48, 24\n",
    "TOTAL_STEPS = LOOKBACK + HORIZON\n",
    "INPUT_DIM, STATIC_DIM = 8, 3\n",
    "FLOOD_THRESHOLD = 0.08  # flood depth threshold (m)\n",
    "GAIN_PROXY = 0.12       # rainfall gain used in the physics residual\n",
    "\n",
    "\n",
    "def rolling_mean(v, w):\n",
    "    # Trailing (causal) mean. mode=\"same\" would centre the window and leak\n",
    "    # forecast-horizon rain into the last lookback steps.\n",
    "    return np.convolve(v, np.ones(w) / w, mode=\"full\")[: len(v)]\n",
    "\n",
    "\n",
    "def make_storm(rng, spec, intensify=1.0):\n",
    "    t = np.arange(TOTAL_STEPS)\n",
    "    rain = rng.gamma(1.15, 0.10, TOTAL_STEPS)  # background rainfall\n",
    "    for _ in range(rng.integers(1, 4)):  # one to three Gaussian storm bursts\n",
    "        c = rng.uniform(10, LOOKBACK + 10)\n",
    "        width = rng.uniform(2.2, 7.5)\n",
    "        height = rng.uniform(1.1, 4.4) * spec.storm_scale * intensify\n",
    "        rain += height * np.exp(-0.5 * ((t - c) / width) ** 2)\n",
    "    if rng.random() < spec.convective:  # short convective burst near the forecast start\n",
    "        c = rng.uniform(LOOKBACK - 5, LOOKBACK + 8)\n",
    "        rain += 1.5 * intensify * np.exp(-0.5 * ((t - c) / 1.8) ** 2)\n",
    "    return np.maximum(rain, 0.0)\n",
    "\n",
    "\n",
    "def make_event(rng, spec, year, intensify=1.0):\n",
    "    rain = make_storm(rng, spec, intensify)\n",
    "    depth = linear_reservoir_response(rain, tau=spec.tau, gain=spec.gain)\n",
    "    depth = 8.0 * depth + 0.014 * np.sin(np.linspace(0, 2 * np.pi, TOTAL_STEPS))\n",
    "    depth += rng.normal(0.0, 0.004, TOTAL_STEPS)\n",
    "    depth = np.maximum(depth, 0.0)\n",
    "    fast = linear_reservoir_response(rain, tau=max(8.0, spec.tau * 0.6), gain=spec.gain * 0.8)\n",
    "    hour = np.arange(TOTAL_STEPS) / 24.0\n",
    "    dyn = np.column_stack([\n",
    "        rain, rolling_mean(rain, 6), rolling_mean(rain, 18), fast,\n",
    "        np.sin(2 * np.pi * hour), np.cos(2 * np.pi * hour),\n",
    "        np.full(TOTAL_STEPS, (year - 2001) / 24.0),\n",
    "        # Backward difference; np.gradient is central and would peek one step ahead.\n",
    "        np.diff(rain, prepend=rain[0]),\n",
    "    ])\n",
    "    return {\n",
    "        \"region\": spec.code,\n",
    "        \"year\": year,\n",
    "        \"x\": dyn[:LOOKBACK].astype(\"float32\"),\n",
    "        \"static\": np.array([spec.area, spec.slope, spec.impervious], dtype=\"float32\"),\n",
    "        \"y\": depth[LOOKBACK:].astype(\"float32\")[:, None],\n",
    "        \"p\": rain[LOOKBACK:].astype(\"float32\")[:, None],\n",
    "    }\n",
    "\n",
    "\n",
    "def make_dataset(n=540, seed=33, intensify=1.0):\n",
    "    rng = np.random.default_rng(seed)\n",
    "    out = []\n",
    "    for _ in range(n):\n",
    "        spec = REGIONS[int(rng.integers(0, len(REGIONS)))]\n",
    "        out.append(make_event(rng, spec, int(rng.integers(2001, 2025)), intensify))\n",
    "    return out\n",
    "\n",
    "\n",
    "def arrays(events):\n",
    "    return {\n",
    "        \"X\": np.stack([e[\"x\"] for e in events]).astype(\"float32\"),\n",
    "        \"S\": np.stack([e[\"static\"] for e in events]).astype(\"float32\"),\n",
    "        \"Y\": np.stack([e[\"y\"] for e in events]).astype(\"float32\"),\n",
    "        \"P\": np.stack([e[\"p\"] for e in events]).astype(\"float32\"),\n",
    "        \"region\": np.array([e[\"region\"] for e in events]),\n",
    "        \"year\": np.array([e[\"year\"] for e in events]),\n",
    "    }\n",
    "\n",
    "data = arrays(make_dataset())\n",
    "print(data[\"X\"].shape, data[\"S\"].shape, data[\"Y\"].shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34219a4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "peaks = data[\"Y\"].max(axis=(1, 2))\n",
    "fig, ax = plt.subplots(figsize=(8.2, 3.5))\n",
    "for i, spec in enumerate(REGIONS):\n",
    "    vals = peaks[data[\"region\"] == spec.code]\n",
    "    parts = ax.violinplot(vals, positions=[i], widths=0.65, showmeans=True)\n",
    "    for body in parts[\"bodies\"]:\n",
    "        body.set_facecolor(spec.color); body.set_alpha(0.35)\n",
    "    ax.scatter(np.full(vals.size, i), vals, s=8, color=spec.color, alpha=0.28)\n",
    "ax.axhline(FLOOD_THRESHOLD, color=\"#8b1e1e\", ls=\"--\", lw=1.2)\n",
    "ax.set_xticks(range(len(REGIONS)), [r.code for r in REGIONS])\n",
    "ax.set_ylabel(\"peak depth (m)\")\n",
    "ax.set_title(\"Regional peak-depth distribution\")\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "769c53e1",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 3. Reusable Experiment Helpers\n",
    "\n",
    "These helpers keep the workflow compact. The model returns a dictionary,\n",
    "so the training loop explicitly uses its `depth` output. The loss combines\n",
    "MSE with a hydrological residual penalty, a mass-bias penalty (relative\n",
    "error of the summed forecast depth), and a smoothness penalty on\n",
    "consecutive depth changes, weighted by `lambda_physics`, `lambda_mass`,\n",
    "and `lambda_smooth`. The residual is that of a linear-reservoir update,\n",
    "\n",
    "$$r_t = (h_{t+1} - h_t) - \\left(g\\,P_{t+1} - \\frac{h_t}{\\tau}\\right),$$\n",
    "\n",
    "with $g$ = `GAIN_PROXY` and $\\tau$ = `reservoir_tau`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bef93739",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_config(lambda_physics=0.25, lambda_mass=0.05, dropout=0.08):\n",
    "    return PADRNetConfig(\n",
    "        input_dim=INPUT_DIM, static_dim=STATIC_DIM, hidden_dim=48,\n",
    "        num_heads=4, num_layers=2, forecast_horizon=HORIZON,\n",
    "        dropout=dropout, lambda_physics=lambda_physics,\n",
    "        lambda_mass=lambda_mass, lambda_smooth=0.01,\n",
    "        flood_threshold=FLOOD_THRESHOLD, reservoir_tau=24.0,\n",
    "    )\n",
    "\n",
    "\n",
    "def ds(a, idx, batch=48, shuffle=False):\n",
    "    out = tf.data.Dataset.from_tensor_slices((a[\"X\"][idx], a[\"S\"][idx], a[\"Y\"][idx], a[\"P\"][idx]))\n",
    "    if shuffle:\n",
    "        out = out.shuffle(len(idx), seed=123, reshuffle_each_iteration=True)\n",
    "    return out.batch(batch).prefetch(tf.data.AUTOTUNE)\n",
    "\n",
    "\n",
    "def optimizer():\n",
    "    try:\n",
    "        return tf.keras.optimizers.Adam(3e-3)\n",
    "    except ImportError:  # also covers ModuleNotFoundError\n",
    "        # Fall back to the legacy Keras optimizer path.\n",
    "        from tensorflow.python.keras.optimizer_v2.adam import Adam\n",
    "        return Adam(3e-3)\n",
    "\n",
    "\n",
    "def loss_terms(y, pred, rain, cfg):\n",
    "    h = tf.squeeze(pred, -1); y = tf.squeeze(y, -1); p = tf.squeeze(rain, -1)\n",
    "    mse = tf.reduce_mean(tf.square(y - h))\n",
    "    dh = h[:, 1:] - h[:, :-1]\n",
    "    # Linear-reservoir residual: dh_t - (GAIN_PROXY * P_{t+1} - h_t / tau).\n",
    "    residual = dh - (GAIN_PROXY * p[:, 1:] - h[:, :-1] / cfg.reservoir_tau)\n",
    "    physics = tf.reduce_mean(tf.square(residual))\n",
    "    # Relative error of the summed forecast depth, averaged over events.\n",
    "    mass = tf.reduce_mean(tf.abs(tf.reduce_sum(h, 1) - tf.reduce_sum(y, 1)) / (tf.reduce_sum(y, 1) + 1e-4))\n",
    "    smooth = tf.reduce_mean(tf.square(dh))\n",
    "    total = mse + cfg.lambda_physics * physics + cfg.lambda_mass * mass + cfg.lambda_smooth * smooth\n",
    "    return total, {\"mse\": mse, \"physics\": physics, \"mass\": mass, \"smooth\": smooth}\n",
    "\n",
    "\n",
    "def train_model(a, train_idx, val_idx, cfg, epochs=12, verbose=False):\n",
    "    model = PADRNet(cfg, backend=\"tensorflow\")\n",
    "    # One dummy forward pass builds the weights before the steps are traced.\n",
    "    _ = model(tf.zeros((1, LOOKBACK, INPUT_DIM)), tf.zeros((1, STATIC_DIM)))\n",
    "    opt = optimizer()\n",
    "    train_ds, val_ds = ds(a, train_idx, shuffle=True), ds(a, val_idx)\n",
    "    hist = {\"train\": [], \"val\": [], \"mse\": [], \"physics\": [], \"mass\": []}\n",
    "\n",
    "    @tf.function\n",
    "    def train_step(xb, sb, yb, pb):\n",
    "        with tf.GradientTape() as tape:\n",
    "            out = model(xb, sb, training=True)\n",
    "            loss, parts = loss_terms(yb, out[\"depth\"], pb, cfg)\n",
    "        grads = tape.gradient(loss, model.trainable_variables)\n",
    "        opt.apply_gradients(zip(grads, model.trainable_variables))\n",
    "        return loss, parts\n",
    "\n",
    "    @tf.function\n",
    "    def eval_step(xb, sb, yb, pb):\n",
    "        out = model(xb, sb, training=False)\n",
    "        return loss_terms(yb, out[\"depth\"], pb, cfg)\n",
    "\n",
    "    def mean_epoch(data_ds, train=False):\n",
    "        losses, parts = [], {\"mse\": [], \"physics\": [], \"mass\": []}\n",
    "        for xb, sb, yb, pb in data_ds:\n",
    "            loss, detail = train_step(xb, sb, yb, pb) if train else eval_step(xb, sb, yb, pb)\n",
    "            losses.append(float(loss.numpy()))\n",
    "            for k in parts: parts[k].append(float(detail[k].numpy()))\n",
    "        return float(np.mean(losses)), {k: float(np.mean(v)) for k, v in parts.items()}\n",
    "\n",
    "    for epoch in range(1, epochs + 1):\n",
    "        tr, _ = mean_epoch(train_ds, train=True)\n",
    "        va, detail = mean_epoch(val_ds)\n",
    "        hist[\"train\"].append(tr); hist[\"val\"].append(va)\n",
    "        for k in [\"mse\", \"physics\", \"mass\"]: hist[k].append(detail[k])\n",
    "        if verbose and (epoch == 1 or epoch == epochs or epoch % 5 == 0):\n",
    "            print(f\"epoch {epoch:02d} train={tr:.4f} val={va:.4f}\")\n",
    "    return model, hist\n",
    "\n",
    "\n",
    "def predict(model, a, idx, training=False):\n",
    "    out = model(tf.convert_to_tensor(a[\"X\"][idx]), tf.convert_to_tensor(a[\"S\"][idx]), training=training)\n",
    "    return out[\"depth\"].numpy(), out[\"exceedance_probability\"].numpy(), out[\"features\"].numpy()\n",
    "\n",
    "\n",
    "def score(y, pred, threshold=FLOOD_THRESHOLD):\n",
    "    yt, yp = y.reshape(-1), pred.reshape(-1)\n",
    "    return {\n",
    "        \"NSE\": nash_sutcliffe_efficiency(yt, yp),\n",
    "        \"CSI\": critical_success_index(yt, yp, threshold=threshold),\n",
    "        \"TSS\": true_skill_statistic(yt, yp, threshold=threshold),\n",
    "        \"DeltaM\": delta_mass(yt, yp),\n",
    "        \"RMSE\": float(np.sqrt(np.mean((yt - yp) ** 2))),\n",
    "    }\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd495887",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 4. Baseline Temporal Split\n",
    "\n",
    "The baseline trains on 2001-2018, validates on 2019-2020, and tests on\n",
    "2021-2024, using all three regions. It is the reference for the harder\n",
    "transfer and ablation experiments.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff54d76d",
   "metadata": {},
   "outputs": [],
   "source": [
    "years = data[\"year\"]\n",
    "train_idx = np.where(years <= 2018)[0]\n",
    "val_idx = np.where((years >= 2019) & (years <= 2020))[0]\n",
    "test_idx = np.where(years >= 2021)[0]\n",
    "\n",
    "base_cfg = make_config()\n",
    "base_model, base_hist = train_model(data, train_idx, val_idx, base_cfg, epochs=16, verbose=True)\n",
    "base_pred, base_prob, base_feat = predict(base_model, data, test_idx)\n",
    "base_score = score(data[\"Y\"][test_idx], base_pred)\n",
    "base_score\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82e99e62",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(10.5, 3.5))\n",
    "e = np.arange(1, len(base_hist[\"train\"]) + 1)\n",
    "axes[0].plot(e, base_hist[\"train\"], label=\"train\", lw=2)\n",
    "axes[0].plot(e, base_hist[\"val\"], label=\"validation\", lw=2)\n",
    "axes[0].set_title(\"Baseline training\")\n",
    "axes[0].set_xlabel(\"epoch\"); axes[0].set_ylabel(\"loss\"); axes[0].legend()\n",
    "axes[1].plot(e, base_hist[\"mse\"], label=\"MSE\", lw=2)\n",
    "axes[1].plot(e, base_hist[\"physics\"], label=\"physics\", lw=2)\n",
    "axes[1].plot(e, base_hist[\"mass\"], label=\"mass\", lw=2)\n",
    "# The unweighted components differ by orders of magnitude, so use a log axis.\n",
    "axes[1].set_yscale(\"log\"); axes[1].set_xlabel(\"epoch\")\n",
    "axes[1].set_title(\"Validation components (unweighted)\"); axes[1].legend()\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62acc8de",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 5. Leave-One-Region-Out Transfer\n",
    "\n",
    "Train on two regions and test on the held-out target region. The test\n",
    "events come from the target region *and* from 2021 onward, so the model\n",
    "has seen neither that region nor those years during training. This is a\n",
    "stronger diagnostic than the temporal baseline, which is trained and\n",
    "tested on all three regions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30aaa72c",
   "metadata": {},
   "outputs": [],
   "source": [
    "transfer = []\n",
    "for target in [r.code for r in REGIONS]:\n",
    "    source = data[\"region\"] != target\n",
    "    target_mask = data[\"region\"] == target\n",
    "    tr = np.where(source & (years <= 2019))[0]\n",
    "    va = np.where(source & (years == 2020))[0]\n",
    "    te = np.where(target_mask & (years >= 2021))[0]\n",
    "    m, _ = train_model(data, tr, va, make_config(), epochs=10)\n",
    "    p, _, _ = predict(m, data, te)\n",
    "    row = score(data[\"Y\"][te], p); row.update({\"target\": target, \"n\": len(te)})\n",
    "    transfer.append(row)\n",
    "\n",
    "for r in transfer:\n",
    "    print(f\"hold out {r['target']}: NSE={r['NSE']:.3f} CSI={r['CSI']:.3f} TSS={r['TSS']:.3f} DeltaM={r['DeltaM']:.1f}%\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00680fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [r[\"target\"] for r in transfer]\n",
    "x = np.arange(len(labels))\n",
    "fig, axes = plt.subplots(1, 2, figsize=(10.5, 3.5))\n",
    "axes[0].bar(x - 0.18, [r[\"NSE\"] for r in transfer], 0.36, label=\"NSE\")\n",
    "axes[0].bar(x + 0.18, [r[\"CSI\"] for r in transfer], 0.36, label=\"CSI\")\n",
    "# NSE can fall well below zero under transfer; widen the lower limit if needed.\n",
    "ymin = min(-0.2, min(r[\"NSE\"] for r in transfer) - 0.05)\n",
    "axes[0].set_xticks(x, labels); axes[0].set_ylim(ymin, 1.05)\n",
    "axes[0].set_title(\"Leave-one-region-out skill\"); axes[0].legend()\n",
    "axes[1].bar(labels, [r[\"DeltaM\"] for r in transfer], color=\"#7c6bb0\")\n",
    "axes[1].axhline(0, color=\"#222\", lw=1); axes[1].set_ylabel(\"DeltaM (%)\")\n",
    "axes[1].set_title(\"Transfer mass bias\")\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82abe0be",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 6. Physics-Loss Ablation\n",
    "\n",
    "A reviewer may ask whether the physics term helps. We compare the\n",
    "physics-aware model against the same architecture trained with the\n",
    "physics and mass weights set to zero; the smoothness penalty\n",
    "(`lambda_smooth=0.01`) is kept in both runs. `ResidualRMSE` measures how\n",
    "far each prediction departs from the linear-reservoir balance (lower is\n",
    "better).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "542d0f82",
   "metadata": {},
   "outputs": [],
   "source": [
    "ablation = []\n",
    "for name, cfg in [(\"PADR-Net\", make_config(0.25, 0.05)), (\"No physics\", make_config(0.0, 0.0))]:\n",
    "    m, _ = train_model(data, train_idx, val_idx, cfg, epochs=12)\n",
    "    p, _, _ = predict(m, data, test_idx)\n",
    "    row = score(data[\"Y\"][test_idx], p)\n",
    "    resid = mass_balance_residual(data[\"P\"][test_idx, :, 0], p[:, :, 0], tau=cfg.reservoir_tau, gain=GAIN_PROXY)\n",
    "    row.update({\"model\": name, \"ResidualRMSE\": float(np.sqrt(np.mean(resid**2)))})\n",
    "    ablation.append(row)\n",
    "\n",
    "for r in ablation:\n",
    "    print(f\"{r['model']:10s} NSE={r['NSE']:.3f} CSI={r['CSI']:.3f} DeltaM={r['DeltaM']:.1f}% residual={r['ResidualRMSE']:.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f36432e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = [r[\"model\"] for r in ablation]\n",
    "fig, axes = plt.subplots(1, 3, figsize=(12, 3.4))\n",
    "for ax, key, title in zip(axes, [\"NSE\", \"DeltaM\", \"ResidualRMSE\"], [\"Depth skill\", \"Mass bias\", \"Physics residual\"]):\n",
    "    ax.bar(labels, [r[key] for r in ablation], color=[\"#2c7fb8\", \"#999999\"])\n",
    "    ax.set_title(title); ax.set_ylabel(key)\n",
    "    if key == \"DeltaM\": ax.axhline(0, color=\"#222\", lw=1)\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13e32a14",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 7. Threshold Calibration\n",
    "\n",
    "Warning performance depends on the flood threshold. The same threshold\n",
    "defines both observed and predicted exceedances, so scanning it shows at\n",
    "which depth level the model separates flood from no-flood steps best.\n",
    "The scan uses the validation years (2019-2020), so the test set is not\n",
    "used to choose the threshold.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d6dd7ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calibrate on the validation years so the test set stays untouched.\n",
    "val_pred, _, _ = predict(base_model, data, val_idx)\n",
    "y_val, p_val = data[\"Y\"][val_idx].reshape(-1), val_pred.reshape(-1)\n",
    "thresholds = np.linspace(0.03, 0.18, 32)\n",
    "csi = [critical_success_index(y_val, p_val, threshold=t) for t in thresholds]\n",
    "tss = [true_skill_statistic(y_val, p_val, threshold=t) for t in thresholds]\n",
    "# nanargmax ignores thresholds where a score is undefined (e.g. no observed floods).\n",
    "print(\"best CSI threshold:\", round(float(thresholds[int(np.nanargmax(csi))]), 3))\n",
    "print(\"best TSS threshold:\", round(float(thresholds[int(np.nanargmax(tss))]), 3))\n",
    "fig, ax = plt.subplots(figsize=(7.5, 3.7))\n",
    "ax.plot(thresholds, csi, lw=2, label=\"CSI\")\n",
    "ax.plot(thresholds, tss, lw=2, label=\"TSS\")\n",
    "ax.axvline(FLOOD_THRESHOLD, color=\"#8b1e1e\", ls=\"--\", lw=1.2, label=\"configured\")\n",
    "ax.set_xlabel(\"flood threshold (m)\"); ax.set_ylabel(\"skill\")\n",
    "ax.set_title(\"Threshold calibration curve (validation years)\"); ax.legend()\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47346c34",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 8. Monte-Carlo Dropout Uncertainty\n",
    "\n",
    "Calling PADR-Net with `training=True` at inference keeps dropout active,\n",
    "so repeated forward passes give different forecasts. The spread of these\n",
    "samples is a lightweight, approximate epistemic-uncertainty diagnostic.\n",
    "Note that `training=True` switches every training-dependent layer, not\n",
    "only dropout.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9386f938",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mc_predict(model, a, idx, n=30):\n",
    "    # Each pass samples new dropout masks; stack to (n, events, HORIZON, 1).\n",
    "    samples = np.stack([predict(model, a, idx, training=True)[0] for _ in range(n)])\n",
    "    return samples.mean(0), np.quantile(samples, [0.1, 0.9], axis=0)\n",
    "\n",
    "sample_idx = test_idx[:40]\n",
    "mc_mean, mc_q = mc_predict(base_model, data, sample_idx, n=30)\n",
    "# Plot the sampled test event with the largest observed peak.\n",
    "local = int(np.argmax(data[\"Y\"][sample_idx].max(axis=(1, 2))))\n",
    "t = np.arange(1, HORIZON + 1)\n",
    "ref = data[\"Y\"][sample_idx[local], :, 0]\n",
    "fig, ax = plt.subplots(figsize=(8.5, 3.8))\n",
    "ax.fill_between(t, mc_q[0, local, :, 0], mc_q[1, local, :, 0], color=\"#6baed6\", alpha=0.28, label=\"80% MC interval\")\n",
    "ax.plot(t, mc_mean[local, :, 0], color=\"#1f66b1\", lw=2.2, label=\"PADR-Net mean\")\n",
    "ax.plot(t, ref, color=\"#222\", lw=2.0, label=\"reference\")\n",
    "ax.axhline(FLOOD_THRESHOLD, color=\"#8b1e1e\", ls=\"--\", lw=1.2)\n",
    "ax.set_xlabel(\"lead time\"); ax.set_ylabel(\"depth (m)\")\n",
    "ax.set_title(\"MC-dropout uncertainty\"); ax.legend()\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b113059",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## 9. Rainfall-Intensification Stress Test\n",
    "\n",
    "We scale the storm-burst heights by 1.0, 1.25, and 1.5 (the background\n",
    "rainfall is left unchanged) and check whether the predicted peak depth\n",
    "and exceedance frequency increase monotonically. The same seed is used\n",
    "for every factor, so each scenario contains the same storms and only\n",
    "their intensity changes.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e274aff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "FACTORS = [1.0, 1.25, 1.50]\n",
    "scenario_events, factors = [], []\n",
    "for factor in FACTORS:\n",
    "    # Same seed for every factor: identical storms, only their intensity is scaled.\n",
    "    ev = make_dataset(n=100, seed=100, intensify=factor)\n",
    "    scenario_events.extend(ev); factors.extend([factor] * len(ev))\n",
    "scenario = arrays(scenario_events)\n",
    "factors = np.array(factors)\n",
    "spred, sprob, _ = predict(base_model, scenario, np.arange(len(factors)))\n",
    "rows = []\n",
    "for factor in FACTORS:\n",
    "    mask = factors == factor\n",
    "    rows.append({\n",
    "        \"factor\": factor,\n",
    "        \"reference\": float(scenario[\"Y\"][mask].max(axis=(1, 2)).mean()),\n",
    "        \"pred\": float(spred[mask].max(axis=(1, 2)).mean()),\n",
    "        # Fraction of predicted exceedance probabilities above 0.5 (all events and lead times).\n",
    "        \"exceed\": float((sprob[mask] > 0.5).mean()),\n",
    "    })\n",
    "for r in rows:\n",
    "    print(f\"x{r['factor']:.2f}: ref={r['reference']:.3f} pred={r['pred']:.3f} frac(P>0.5)={r['exceed']:.3f}\")\n",
    "print(\"monotone peak:\", bool(np.all(np.diff([r[\"pred\"] for r in rows]) >= 0)),\n",
    "      \"monotone exceedance:\", bool(np.all(np.diff([r[\"exceed\"] for r in rows]) >= 0)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ca416ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(10.5, 3.6))\n",
    "f = [r[\"factor\"] for r in rows]\n",
    "axes[0].plot(f, [r[\"reference\"] for r in rows], \"o-\", lw=2, label=\"reference\")\n",
    "axes[0].plot(f, [r[\"pred\"] for r in rows], \"o-\", lw=2, label=\"PADR-Net\")\n",
    "axes[0].set_xlabel(\"storm-intensity multiplier\"); axes[0].set_ylabel(\"mean peak depth (m)\")\n",
    "axes[0].set_title(\"Stress-test peak response\"); axes[0].legend()\n",
    "axes[1].plot(f, [r[\"exceed\"] for r in rows], \"o-\", lw=2, color=\"#8b1e1e\")\n",
    "axes[1].set_xlabel(\"storm-intensity multiplier\"); axes[1].set_ylabel(\"fraction with P(exceed) > 0.5\")\n",
    "axes[1].set_title(\"Exceedance response\")\n",
    "fig.tight_layout(); plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c1c69a4",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Exercises\n",
    "\n",
    "### Exercise 1: Transfer gap\n",
    "\n",
    "The transfer loop already holds out both the target region and the\n",
    "years from 2021 onward. Compare each held-out score with the baseline\n",
    "model's score on the same target-region test events. Which target region\n",
    "has the largest performance drop relative to the baseline?\n",
    "\n",
    "### Exercise 2: Warning threshold operations\n",
    "\n",
    "Use the validation threshold curve to choose the threshold that\n",
    "maximizes `TSS`, then recompute CSI and TSS on the test years at that\n",
    "threshold. `DeltaM` does not depend on the threshold, so it is unchanged.\n",
    "\n",
    "### Exercise 3: Uncertainty screening\n",
    "\n",
    "Flag events whose MC-dropout interval width is above the 90th\n",
    "percentile. Are those events concentrated in one region?\n",
    "\n",
    "### Exercise 4: Real scenario forcing\n",
    "\n",
    "Replace the synthetic storm multipliers with real rainfall ensemble or\n",
    "climate-perturbation members. The stress-test code remains unchanged as\n",
    "long as `X`, `S`, and `Y` keep the shapes used here: `(n, 48, 8)`,\n",
    "`(n, 3)`, and `(n, 24, 1)`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ba3ff31",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Summary\n",
    "\n",
    "This advanced workflow turns PADR-Net into a compact experiment suite:\n",
    "\n",
    "1. baseline temporal evaluation;\n",
    "2. leave-one-region-out transfer;\n",
    "3. physics-loss ablation;\n",
    "4. threshold calibration;\n",
    "5. MC-dropout uncertainty;\n",
    "6. rainfall-intensification stress testing.\n",
    "\n",
    "These diagnostics make PADR-Net results easier to defend in papers and\n",
    "operational reports.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}