{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ICU Sepsis Early Warning \u2014 Real Data: PhysioNet Challenge 2019\n", "\n", "> **Companion to Notebook 12** ([SOFA-informed synthetic ICU sepsis early-warning study](12_icu_sepsis_early_warning.ipynb)). \n", "> This notebook moves the same modelling idea into a real ICU setting using the\n", "> PhysioNet Challenge 2019 sepsis cohort. The code keeps the Notebook 12 workflow\n", "> intact, but replaces the synthetic cohort with patient files and upgrades the\n", "> SOFA prior to the organ-system variables available in the dataset.\n", "\n", "## What this notebook adds\n", "Real clinical data changes the nature of the task. Measurements are missing,\n", "labels are chart-derived, and some SOFA components are incomplete. The purpose of\n", "this companion notebook is to show how the same BaseAttentive workflow can be used\n", "while making those limitations explicit.\n", "\n", "| Feature | Notebook 12 (synthetic) | This notebook (real data) |\n", "|---------|--------------------------|---------------------------|\n", "| Data | Simulated vitals | PhysioNet 2019 Challenge (~40 k patients) |\n", "| SOFA prior | 4-component proxy | **5-component SOFA** (respiratory, coagulation, liver, cardiovascular, renal) |\n", "| GCS | n/a | Not in dataset \u2014 noted as limitation |\n", "| PaO\u2082/FiO\u2082 | n/a | **SpO\u2082/FiO\u2082 proxy** (Rice et al. 
2007 conversion) |\n", "| Labels | Synthetic percentile | `SepsisLabel` column (chart-derived Sepsis-3 label) |\n", "\n", "## Prerequisites\n", "```\n", "pip install numpy pandas scipy matplotlib scikit-learn tensorflow keras requests tqdm\n", "# the base_attentive package imported below must also be installed\n", "```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os, warnings, time, glob\n", "warnings.filterwarnings('ignore')\n", "os.environ.setdefault('BASE_ATTENTIVE_BACKEND', 'tensorflow')\n", "os.environ.setdefault('KERAS_BACKEND', 'tensorflow')\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import matplotlib.gridspec as gridspec\n", "from scipy import stats\n", "\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.preprocessing import StandardScaler\n", "from sklearn.metrics import roc_auc_score, roc_curve, average_precision_score\n", "\n", "import tensorflow as tf\n", "import keras\n", "from base_attentive import BaseAttentive\n", "\n", "np.random.seed(42); tf.random.set_seed(42)\n", "print(f'TF {tf.__version__} | Keras {keras.__version__}')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 1 \u2014 Data Acquisition\n", "\n", "Start by placing the PhysioNet patient files into a simple directory tree. Each\n", "file is one ICU stay, recorded hour by hour, which makes the dataset a natural\n", "fit for temporal modelling. If the files are not present, the notebook falls back\n", "to a small demo cohort so the workflow can still be inspected end to end.\n", "\n", "### Option A \u2014 PhysioNet (free, requires registration)\n", "```\n", "https://physionet.org/content/challenge-2019/1.0.0/\n", "```\n", "Download `training_setA.zip` and `training_setB.zip`. 
\n", "Total: ~40 336 patients, each a `.psv` (pipe-separated) file.\n", "\n", "### Option B \u2014 Kaggle (same data; no PhysioNet registration, but a Kaggle API token is required)\n", "```\n", "kaggle datasets download -d salikhussaini49/prediction-of-sepsis\n", "```\n", "\n", "### Directory layout expected by this notebook\n", "```\n", "data/physionet2019/\n", "\u251c\u2500\u2500 training_setA/   # ~20 336 .psv files\n", "\u2514\u2500\u2500 training_setB/   # ~20 000 .psv files\n", "```\n", "\n", "### Column reference (selected)\n", "The columns below are the ones used either as model inputs or as pieces of the SOFA\n", "prior. Some clinically important information is available only as a proxy:\n", "oxygenation uses a saturation-to-FiO\u2082 ratio (computed from `SaO2` in the loader\n", "below) in place of PaO\u2082/FiO\u2082, and neurological status is absent because GCS is\n", "not provided. The pulse-oximetry SpO\u2082 column in PhysioNet 2019 is `O2Sat`.\n", "\n", "| Column | Description | SOFA component |\n", "|--------|-------------|----------------|\n", "| `HR` | Heart rate (bpm) | \u2014 |\n", "| `MAP` | Mean arterial pressure (mmHg) | Cardiovascular |\n", "| `Resp` | Respiratory rate | \u2014 |\n", "| `Temp` | Temperature (\u00b0C) | \u2014 |\n", "| `SaO2` | Arterial O\u2082 saturation (%) | Respiratory proxy |\n", "| `FiO2` | Fraction inspired O\u2082 (0\u20131) | Respiratory proxy |\n", "| `Creatinine` | Serum creatinine (mg/dL) | Renal |\n", "| `Bilirubin_total` | Total bilirubin (mg/dL) | Liver |\n", "| `Platelets` | Platelet count (\u00d710\u00b3/\u00b5L) | Coagulation |\n", "| `WBC` | White blood cell count | \u2014 |\n", "| `Lactate` | Serum lactate (mmol/L) | \u2014 |\n", "| `GCS` | **Not available** in this dataset | CNS (missing) |\n", "| `SepsisLabel` | 1 from 6 h before clinical sepsis onset (Sepsis-3) onward, else 0 | Label |\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 Configuration 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "DATA_DIR = 'data/physionet2019' # adjust to your path\n", "LOOKBACK = 12 # hours of vital-sign history\n", "HORIZON = 3 # multi-output: +6h, +12h, +24h\n", "LABEL_HORIZONS_H = [6, 12, 24] # hours ahead for each output\n", "\n", "# Dynamic features mapped from PhysioNet columns\n", "DYN_COLS = ['MAP','HR','Resp','Temp','WBC','Lactate']\n", "STATIC_COLS = ['Age','Gender','HospAdmTime']\n", "\n", "# \u2500\u2500 Data loader \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "def load_patient(path):\n", " \"\"\"Return (df, label_arr) for one .psv file.\"\"\"\n", " df = pd.read_csv(path, sep='|')\n", " df = df.ffill().bfill() # forward then backward fill missings\n", " return df\n", "\n", "def extract_window(df, onset_idx, lookback):\n", " \"\"\"Extract lookback-hour window ending at onset_idx (exclusive).\"\"\"\n", " start = max(0, onset_idx - lookback)\n", " win = df.iloc[start:onset_idx].copy()\n", " if len(win) < lookback: # pad left with first row\n", " pad = pd.concat([win.iloc[[0]] * (lookback - len(win)), win], ignore_index=True)\n", " return pad\n", " return win.reset_index(drop=True)\n", "\n", "def build_dataset(data_dir, lookback=LOOKBACK, label_horizons=LABEL_HORIZONS_H,\n", " max_patients=None):\n", " files = sorted(glob.glob(os.path.join(data_dir, '**/*.psv'), 
recursive=True))\n", "    if max_patients:\n", "        files = files[:max_patients]\n", "\n", "    X_static_list, X_dyn_list, Y_list, sofa_list = [], [], [], []\n", "    skipped = 0\n", "\n", "    for path in files:\n", "        df = load_patient(path)\n", "        n = len(df)\n", "\n", "        # \u2500\u2500 Label: multi-horizon from Sepsis-3 SepsisLabel \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "        sep_rows = df[df['SepsisLabel'] == 1]\n", "        if len(sep_rows) > 0:\n", "            # First hour with SepsisLabel = 1. PhysioNet sets the label from\n", "            # 6 h before clinical onset, so this is roughly onset - 6 h.\n", "            onset_h = sep_rows.index[0]\n", "        else:\n", "            onset_h = n  # never septic \u2192 use end of record\n", "\n", "        # Observation window ends at t = onset_h - min_horizon\n", "        ref_t = onset_h - label_horizons[0]  # latest safe observation time\n", "        if ref_t < lookback:\n", "            skipped += 1\n", "            continue  # not enough history\n", "\n", "        win = extract_window(df, ref_t, lookback)\n", "\n", "        # Binary labels: did SepsisLabel turn 1 within +6h / +12h / +24h of ref_t?\n", "        # Because ref_t is anchored at onset_h - label_horizons[0], every horizon is\n", "        # 1 for septic patients and 0 otherwise, so the three outputs currently share\n", "        # one per-patient label. Horizon-specific labels would require sampling\n", "        # ref_t at several offsets before onset.\n", "        labels = []\n", "        for h in label_horizons:\n", "            fut_idx = ref_t + h\n", "            if len(sep_rows) > 0 and onset_h <= fut_idx:\n", "                labels.append(1.0)\n", "            else:\n", "                labels.append(0.0)\n", "\n", "        # \u2500\u2500 Dynamic features (6 channels) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "        dyn = np.zeros((lookback, len(DYN_COLS)), dtype='float32')\n", "        for ci, col in enumerate(DYN_COLS):\n", "            if col in win.columns:\n", "                dyn[:, ci] = win[col].values.astype('float32')\n", "\n", "        # \u2500\u2500 Static features 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", " row0 = df.iloc[0]\n", " age = float(row0.get('Age', 65.0))\n", " sex = float(row0.get('Gender', 0.0))\n", " admt = float(row0.get('HospAdmTime', 0.0))\n", " sta = np.array([age, sex, admt], dtype='float32')\n", "\n", " # \u2500\u2500 SOFA at observation point (for physics prior) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", " sofa_row = win.iloc[-1] # last observed hour\n", " sofa_list.append(_compute_sofa(sofa_row))\n", "\n", " X_static_list.append(sta)\n", " X_dyn_list.append(dyn)\n", " Y_list.append(labels)\n", "\n", " print(f'Loaded {len(X_dyn_list)} patients (skipped {skipped} with insufficient history)')\n", " X_static = np.stack(X_static_list, axis=0)\n", " X_dyn = np.stack(X_dyn_list, axis=0)\n", " Y = np.array(Y_list, dtype='float32')\n", " sofa_arr = np.array(sofa_list, dtype='float32')\n", " return X_static, X_dyn, Y, sofa_arr\n", "\n", "def _compute_sofa(row):\n", " \"\"\"Approximate 5-component SOFA score from a single observation row.\n", " Range 0\u201320 (CNS/GCS excluded \u2014 not in PhysioNet 2019).\"\"\"\n", " score = 0.0\n", "\n", " # 1 \u2014 Respiratory: SpO2/FiO2 proxy (Rice et al. 
CHEST 2007)\n", " # SF > 512 \u2192 0 | 357\u2013512 \u2192 1 | 214\u2013357 \u2192 2 | 89\u2013214 \u2192 3 | <89 \u2192 4\n", " sao2 = float(row.get('SaO2', 96.0))\n", " fio2 = float(row.get('FiO2', 0.21))\n", " fio2 = max(fio2, 0.21) # never below room air\n", " sf = sao2 / fio2\n", " if sf >= 512: score += 0\n", " elif sf >= 357: score += 1\n", " elif sf >= 214: score += 2\n", " elif sf >= 89: score += 3\n", " else: score += 4\n", "\n", " # 2 \u2014 Coagulation: Platelets (\u00d710\u00b3/\u00b5L)\n", " plt_v = float(row.get('Platelets', 200.0))\n", " if plt_v >= 150: score += 0\n", " elif plt_v >= 100: score += 1\n", " elif plt_v >= 50: score += 2\n", " elif plt_v >= 20: score += 3\n", " else: score += 4\n", "\n", " # 3 \u2014 Liver: Bilirubin_total (mg/dL)\n", " bili = float(row.get('Bilirubin_total', 0.8))\n", " if bili < 1.2: score += 0\n", " elif bili < 2.0: score += 1\n", " elif bili < 6.0: score += 2\n", " elif bili < 12.0: score += 3\n", " else: score += 4\n", "\n", " # 4 \u2014 Cardiovascular: MAP (mmHg) \u2014 vasopressor data unavailable\n", " map_v = float(row.get('MAP', 80.0))\n", " score += 0 if map_v >= 70 else 1\n", "\n", " # 5 \u2014 Renal: Creatinine (mg/dL)\n", " crea = float(row.get('Creatinine', 0.9))\n", " if crea < 1.2: score += 0\n", " elif crea < 2.0: score += 1\n", " elif crea < 3.5: score += 2\n", " elif crea < 5.0: score += 3\n", " else: score += 4\n", "\n", " # 6 \u2014 CNS (GCS): NOT AVAILABLE \u2014 add 0, note limitation\n", " return score\n", "\n", "# \u2500\u2500 Run loader (or use demo mode if data not present) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "if os.path.isdir(DATA_DIR):\n", " X_static_raw, X_dyn_raw, Y_raw, sofa_score_raw = build_dataset(DATA_DIR)\n", " DEMO_MODE = False\n", "else:\n", " print('\u26a0 DATA_DIR not found \u2014 running in demo mode with 500 synthetic patients.')\n", " print(' Set DATA_DIR 
above to run on real PhysioNet 2019 data.')\n", "    DEMO_MODE = True\n", "    # Minimal synthetic stand-in so all downstream cells execute\n", "    N = 500\n", "    RNG = np.random.default_rng(42)\n", "    X_static_raw = np.column_stack([\n", "        RNG.normal(65, 15, N).clip(18, 95),\n", "        (RNG.random(N) > 0.5).astype('f'),\n", "        RNG.normal(-10, 20, N)\n", "    ]).astype('float32')\n", "    X_dyn_raw = RNG.normal(0, 1, (N, LOOKBACK, len(DYN_COLS))).astype('float32')\n", "    Y_raw = (RNG.random((N, 3)) > 0.77).astype('float32')\n", "    sofa_score_raw = RNG.gamma(2, 1.5, N).clip(0, 20).astype('float32')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 2 \u2014 Full SOFA Score: Mathematical Formulation\n", "\n", "SOFA is useful here because it translates bedside physiology into an organ-failure\n", "score that clinicians already understand. Instead of asking the neural network to\n", "learn entirely from noisy labels, we give it a soft physiological reference point.\n", "\n", "The Sequential Organ Failure Assessment (SOFA) score (Vincent et al. 1996; adopted by\n", "the Sepsis-3 definition, Singer et al. *JAMA* 2016) grades dysfunction in six organ\n", "systems. PhysioNet 2019 provides variables for five of the six;\n", "**GCS (CNS component) is absent**, so the score used here should be read as a\n", "documented, conservative approximation.\n", "\n", "### Component definitions\n", "\n", "$$\\text{SOFA} = \\underbrace{S_{\\text{resp}}}_{0-4} + \\underbrace{S_{\\text{coag}}}_{0-4} + \\underbrace{S_{\\text{liver}}}_{0-4} + \\underbrace{S_{\\text{cardio}}}_{0-4} + \\underbrace{S_{\\text{renal}}}_{0-4}$$\n", "\n", "The five components sum to at most **20** (the excluded GCS component normally adds\n", "0\u20134). Because the cardiovascular component below is reduced to a MAP threshold\n", "scoring 0\u20131, the score implemented in `_compute_sofa` reaches at most **17**.\n", "\n", "#### Respiratory \u2014 SpO\u2082/FiO\u2082 (SF ratio) proxy\n", "PaO\u2082 is not measured in PhysioNet 2019. We use the validated SpO\u2082/FiO\u2082 ratio\n", "(Rice et al. 
*CHEST* 2007, r = 0.89 with PaO\u2082/FiO\u2082):\n", "\n", "| SF ratio | SOFA |\n", "|----------|------|\n", "| \u2265 512 | 0 |\n", "| 357\u2013511 | 1 |\n", "| 214\u2013356 | 2 |\n", "| 89\u2013213 | 3 |\n", "| < 89 | 4 |\n", "\n", "#### Coagulation \u2014 Platelets (\u00d710\u00b3/\u00b5L)\n", "| Platelets | SOFA |\n", "|-----------|------|\n", "| \u2265 150 | 0 |\n", "| 100\u2013149 | 1 |\n", "| 50\u201399 | 2 |\n", "| 20\u201349 | 3 |\n", "| < 20 | 4 |\n", "\n", "#### Liver \u2014 Bilirubin total (mg/dL)\n", "| Bilirubin | SOFA |\n", "|-----------|------|\n", "| < 1.2 | 0 |\n", "| 1.2\u20131.9 | 1 |\n", "| 2.0\u20135.9 | 2 |\n", "| 6.0\u201311.9 | 3 |\n", "| \u2265 12 | 4 |\n", "\n", "#### Cardiovascular \u2014 MAP (mmHg)\n", "Vasopressor data unavailable \u2192 simplified to MAP threshold:\n", "\n", "$$S_{\\text{cardio}} = \\begin{cases} 0 & \\text{MAP} \\geq 70 \\\\ 1 & \\text{MAP} < 70 \\end{cases}$$\n", "\n", "Note: full SOFA uses vasopressor dose (norepinephrine, dopamine, etc.) for scores\n", "2\u20134. This component is therefore underestimated in patients receiving pressors.\n", "\n", "#### CNS \u2014 GCS *(not available)*\n", "GCS is not captured in the PhysioNet 2019 dataset. The CNS component contributes\n", "0\u20134 points to SOFA (GCS 15 scores 0, 13\u201314 scores 1, and < 6 scores 4), so\n", "omitting it can only lower the score. If most ICU patients have a GCS of 14\u201315 at\n", "admission, as is often reported, the average shortfall is about 0\u20131 point, but it\n", "reaches 3\u20134 points in comatose or deeply sedated patients.\n", "\n", "#### Renal \u2014 Creatinine (mg/dL)\n", "| Creatinine | SOFA |\n", "|------------|------|\n", "| < 1.2 | 0 |\n", "| 1.2\u20131.9 | 1 |\n", "| 2.0\u20133.4 | 2 |\n", "| 3.5\u20134.9 | 3 |\n", "| \u2265 5.0 | 4 |\n", "\n", "### Physics-consistent SOFA prior\n", "\n", "We use the SOFA score as a **soft constraint** on model predictions. 
This does not\n", "force the model to copy SOFA; it gently anchors the learned risk estimate to\n", "organ-system physiology when the EHR label is sparse, delayed, or noisy:\n", "\n", "$$p_{\\text{SOFA}}(x) = \\sigma\\!\\left(\\frac{\\text{SOFA}(x) - 4}{2}\\right)$$\n", "\n", "The sigmoid is centred at SOFA = 4. This midpoint is a modelling choice for this\n", "notebook, not a clinical cut-off (Sepsis-3 defines organ dysfunction as an acute\n", "SOFA increase of \u2265 2 points). The prior is incorporated into the training loss as an\n", "additional term, so data labels and physiology both influence the final prediction:\n", "\n", "$$\\mathcal{L}_{\\text{total}} = \\mathcal{L}_{\\text{MSE}}(\\hat{p}, p_{\\text{label}}) + \\lambda \\cdot \\mathcal{L}_{\\text{MSE}}(\\hat{p}, p_{\\text{SOFA}})$$\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 SOFA score validation plots \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "sep_mask = Y_raw[:, 1].astype(bool)  # +12 h horizon as primary label\n", "\n", "fig, axes = plt.subplots(1, 3, figsize=(16, 5))\n", "\n", "# (A) SOFA distribution: septic vs non-septic\n", "ax = axes[0]\n", "ax.hist(sofa_score_raw[~sep_mask], bins=21, range=(0,20),\n", "        alpha=0.6, color='#3498db', label='No sepsis (+12 h)', density=True)\n", "ax.hist(sofa_score_raw[ sep_mask], bins=21, range=(0,20),\n", "        alpha=0.6, color='#e74c3c', label='Sepsis (+12 h)', density=True)\n", "ax.axvline(4, color='gray', lw=1.5, ls='--', label='SOFA = 4 (prior midpoint)')\n", "ax.set_xlabel('SOFA score (5-component, GCS excluded)')\n", "ax.set_ylabel('Density')\n", "ax.set_title('(A) SOFA Distribution', fontsize=11, fontweight='bold')\n", "ax.legend(fontsize=9); ax.grid(True, alpha=0.2)\n", "\n", "# (B) Cumulative SOFA \u2265 threshold vs prevalence\n", "ax = 
axes[1]\n", "thr_range = np.arange(0, 21)\n", "prev_sep = [sep_mask[sofa_score_raw >= t].mean() if (sofa_score_raw >= t).sum() > 10 else np.nan\n", " for t in thr_range]\n", "n_pts = [(sofa_score_raw >= t).sum() for t in thr_range]\n", "ax.plot(thr_range, prev_sep, 'o-', color='#e74c3c', lw=2, ms=5)\n", "ax.axvline(4, color='gray', lw=1.2, ls='--', alpha=0.7)\n", "ax.set_xlabel('SOFA threshold'); ax.set_ylabel('Sepsis prevalence')\n", "ax.set_title('(B) SOFA vs Sepsis Prevalence', fontsize=11, fontweight='bold')\n", "ax.grid(True, alpha=0.2)\n", "\n", "# (C) SOFA physics prior sigmoid\n", "ax = axes[2]\n", "sofa_x = np.linspace(0, 20, 200)\n", "p_sofa = 1.0 / (1.0 + np.exp(-(sofa_x - 4) / 2))\n", "ax.plot(sofa_x, p_sofa, lw=2.5, color='#9b59b6',\n", " label=r'$\\sigma((\\mathrm{SOFA}-4)/2)$')\n", "ax.axvline(4, color='gray', lw=1.2, ls='--', alpha=0.7, label='SOFA = 4')\n", "ax.axhline(0.5, color='gray', lw=0.8, ls=':', alpha=0.5)\n", "ax.fill_between(sofa_x[sofa_x>=4], p_sofa[sofa_x>=4], alpha=0.12, color='#e74c3c',\n", " label='High-risk zone (SOFA \u2265 4)')\n", "ax.set_xlabel('SOFA score'); ax.set_ylabel('Physics prior P(sepsis)')\n", "ax.set_title('(C) SOFA Physics Prior', fontsize=11, fontweight='bold')\n", "ax.legend(fontsize=9); ax.grid(True, alpha=0.2)\n", "\n", "plt.suptitle('Section 2 \u2014 Full SOFA Score Validation', fontsize=13)\n", "plt.tight_layout(); plt.show()\n", "\n", "print(f'N patients : {len(Y_raw):>6}')\n", "print(f'Sepsis (+12 h) : {sep_mask.sum():>6} ({100*sep_mask.mean():.1f} %)')\n", "print(f'SOFA median (all) : {np.median(sofa_score_raw):.1f}')\n", "print(f'SOFA median (sep) : {np.median(sofa_score_raw[sep_mask]):.1f}')\n", "print(f'SOFA >= 4 : {(sofa_score_raw>=4).sum():>6} ({100*(sofa_score_raw>=4).mean():.1f} %)')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interpretation \u2014 Section 2: SOFA Validation\n", "\n", "This figure is a check on whether the clinical prior behaves sensibly before it is\n", "used 
in training. The model should not receive a SOFA prior unless higher SOFA\n", "scores are visibly associated with higher observed sepsis prevalence.\n", "\n", "- **Panel (A) \u2014 SOFA distributions**: Septic patients (red) should shift toward\n", "  higher SOFA values than non-septic patients. The separation will not be perfect,\n", "  because labels are chart-derived and SOFA itself is incomplete, but the direction\n", "  of the shift is the important sanity check.\n", "\n", "- **Panel (B) \u2014 Prevalence vs threshold**: Sepsis prevalence should rise with SOFA.\n", "  A monotone curve supports the idea that SOFA is informative; deviations identify\n", "  regions where missingness, chart timing, or omitted CNS/vasopressor data may be\n", "  weakening the prior.\n", "\n", "- **Panel (C) \u2014 Physics prior sigmoid**: The sigmoid centred at SOFA = 4 converts the\n", "  discrete SOFA score to a continuous probability, assigning P = 0.5 at the chosen\n", "  midpoint (SOFA = 4, not a clinical cut-off) and P = \u03c3(2) \u2248 0.88 at SOFA = 8.\n", "\n", "**GCS note**: omitting the CNS component shifts some scores downward. Patients\n", "with delirium, coma, or sedation-related neurological impairment may therefore have\n", "underestimated SOFA risk. This should be reported transparently, and sensitivity\n", "analysis on MIMIC-IV or another dataset with GCS is the right follow-up.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 3 \u2014 Feature Construction & Dataset Split\n", "\n", "Feature normalisation and dataset splitting mirror Notebook 12 so the comparison\n", "remains easy to follow. The important change is that this notebook respects the\n", "limits of the real dataset rather than inventing variables that are not observed.\n", "\n", "The main differences are:\n", "1. Static features use `[Age, Gender, HospAdmTime]` because these are consistently\n", "   available across PhysioNet files.\n", "2. 
The SOFA prior is computed from real laboratory and vital-sign columns rather\n", " than from the synthetic proxy.\n", "3. Future covariates are **omitted** because PhysioNet provides no future treatment\n", " observations at prediction time.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 Normalise features \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "N_PATIENTS = len(X_dyn_raw)\n", "\n", "def znorm_col(arr):\n", " return ((arr - arr.mean()) / (arr.std() + 1e-8)).astype('float32')\n", "\n", "# Static\n", "X_static = np.column_stack([\n", " znorm_col(X_static_raw[:, 0]), # Age\n", " X_static_raw[:, 1], # Gender (binary)\n", " znorm_col(X_static_raw[:, 2]), # HospAdmTime\n", "]).astype('float32')\n", "\n", "# Dynamic (per-feature z-score across all patients and time steps)\n", "X_dyn = X_dyn_raw.copy()\n", "for fi in range(X_dyn.shape[2]):\n", " v = X_dyn[:, :, fi]\n", " X_dyn[:, :, fi] = ((v - v.mean()) / (v.std() + 1e-8)).astype('float32')\n", "\n", "# SOFA physics prior\n", "sofa_prior = 1.0 / (1.0 + np.exp(-(sofa_score_raw - 4.0) / 2.0))\n", "\n", "# Labels (shape: N \u00d7 3 \u00d7 1 for BaseAttentive multi-horizon output)\n", "Y_labels = Y_raw[:, :, None].astype('float32')\n", "is_sepsis_12h = Y_raw[:, 1].astype('float32') # +12 h as primary\n", "\n", "print(f'X_static : {X_static.shape}')\n", "print(f'X_dyn : {X_dyn.shape}')\n", "print(f'Y_labels : {Y_labels.shape}')\n", "print(f'SOFA prior range: [{sofa_prior.min():.3f}, {sofa_prior.max():.3f}]')\n", "\n", "# \u2500\u2500 Temporal split (last 20 % of patients = held-out test) 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "# HospAdmTime is the offset (hours) between hospital and ICU admission, not a\n", "# calendar timestamp (PhysioNet 2019 provides no admission dates). Sorting on it\n", "# gives a covariate-shift split rather than a true chronological split.\n", "order = np.argsort(X_static_raw[:, 2])  # sort by hospital-to-ICU offset\n", "TRAIN_SIZE = int(0.80 * N_PATIENTS)\n", "TEST_SIZE = N_PATIENTS - TRAIN_SIZE\n", "\n", "tr = order[:TRAIN_SIZE]\n", "te = order[TRAIN_SIZE:]\n", "\n", "Xs_tr, Xs_te = X_static[tr], X_static[te]\n", "Xd_tr, Xd_te = X_dyn[tr], X_dyn[te]\n", "Y_tr, Y_te = Y_labels[tr], Y_labels[te]\n", "sep_tr, sep_te = is_sepsis_12h[tr], is_sepsis_12h[te]\n", "sofa_tr, sofa_te = sofa_prior[tr], sofa_prior[te]\n", "\n", "# Placeholder future tensor: one all-zero channel, because BaseAttentive still\n", "# expects a future input even though no future covariates are used.\n", "N_FUTURE = 1\n", "Xf_tr = np.zeros((TRAIN_SIZE, HORIZON, N_FUTURE), 'float32')\n", "Xf_te = np.zeros((TEST_SIZE, HORIZON, N_FUTURE), 'float32')\n", "\n", "print(f'Train: {TRAIN_SIZE} | Test: {TEST_SIZE}')\n", "print(f'Sepsis prevalence \u2014 train: {sep_tr.mean():.3f}  test: {sep_te.mean():.3f}')\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 4 \u2014 SOFA-Informed BaseAttentive on Real Data\n", "\n", "The training loop follows Notebook 12 Section 5, but the interpretation is now more\n", "clinical: the model is balancing a chart-derived sepsis label against a\n", "physiology-derived SOFA prior. 
Disagreement between the two is expected in real EHR data and\n", "should be studied rather than treated as a simple error.\n", "\n", "$$\\mathcal{L}_{\\text{total}} = \\mathcal{L}_{\\text{MSE}}(\\hat{p}, p_{\\text{label}}) + 0.30 \\cdot \\mathcal{L}_{\\text{MSE}}(\\hat{p}_{+12\\text{h}}, p_{\\text{SOFA}})$$\n", "\n", "\u03bb = 0.30 (reduced from 0.50 in synthetic NB12 because the real SOFA score is a\n", "noisier proxy than the synthetic version \u2014 vasopressor data and GCS are missing).\n", "This lower weight lets SOFA guide the model without overwhelming the observed label.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 Model constants \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "N_STATIC = X_static.shape[1] # 3\n", "N_DYNAMIC = X_dyn.shape[2] # 6\n", "OUTPUT_DIM= 1\n", "EPOCHS = 30\n", "PATIENCE = 5\n", "BATCH = 64\n", "LAMBDA_SOFA = 0.30\n", "PRIMARY_H = 1 # index 1 = +12 h\n", "\n", "# \u2500\u2500 Build model \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "model_sofa = BaseAttentive(\n", " static_input_dim = N_STATIC,\n", " dynamic_input_dim = N_DYNAMIC,\n", " future_input_dim = N_FUTURE,\n", " output_dim = OUTPUT_DIM,\n", " forecast_horizon = HORIZON,\n", " objective = 'hybrid',\n", " architecture_config = {'decoder_attention_stack': ['cross', 
'hierarchical']},\n", " embed_dim = 32,\n", " num_heads = 4,\n", " dropout_rate = 0.15,\n", " name = 'ba_sofa_real',\n", ")\n", "_ = model_sofa([Xs_tr[:4], Xd_tr[:4], Xf_tr[:4]])\n", "print(f'Parameters: {model_sofa.count_params():,}')\n", "\n", "# \u2500\u2500 SOFA-informed custom training loop \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "optimiser = keras.optimizers.Adam(1e-3)\n", "mse_fn = keras.losses.MeanSquaredError()\n", "\n", "@tf.function\n", "def train_step(xs, xd, xf, y_true, sofa_p):\n", " with tf.GradientTape() as tape:\n", " y_pred = model_sofa([xs, xd, xf], training=True) # (B, H, 1)\n", " l_mse = mse_fn(y_true, y_pred)\n", " l_sofa = mse_fn(sofa_p[:, None, None], # broadcast to (B,1,1)\n", " y_pred[:, PRIMARY_H:PRIMARY_H+1, :])\n", " l_total = l_mse + LAMBDA_SOFA * l_sofa\n", " grads = tape.gradient(l_total, model_sofa.trainable_variables)\n", " optimiser.apply_gradients(zip(grads, model_sofa.trainable_variables))\n", " return l_mse, l_sofa, l_total\n", "\n", "# \u2500\u2500 Training \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "t0 = time.time()\n", "best_auc, best_weights = 0.0, None\n", "history = {'mse': [], 'sofa': [], 'total': [], 'val_auc': []}\n", "\n", "dataset = tf.data.Dataset.from_tensor_slices(\n", " (Xs_tr, Xd_tr, Xf_tr, Y_tr, sofa_tr)\n", ").shuffle(TRAIN_SIZE, seed=42).batch(BATCH).prefetch(2)\n", "\n", "for epoch in range(1, EPOCHS + 1):\n", " mse_ep, sofa_ep, tot_ep = [], [], []\n", " 
for xs, xd, xf, yt, sp in dataset:\n", "        lm, ls, lt = train_step(xs, xd, xf, yt, sp)\n", "        mse_ep.append(float(lm)); sofa_ep.append(float(ls)); tot_ep.append(float(lt))\n", "\n", "    history['mse'].append(np.mean(mse_ep))\n", "    history['sofa'].append(np.mean(sofa_ep))\n", "    history['total'].append(np.mean(tot_ep))\n", "\n", "    # AUC on the held-out test set. NOTE: choosing best_weights and stopping\n", "    # early on this set makes the reported BA (SOFA) AUC optimistic, and the\n", "    # standard BA below selects on validation_split instead, so the comparison\n", "    # is not like-for-like. A validation split drawn from the training patients\n", "    # would avoid both issues.\n", "    p_val = model_sofa.predict([Xs_te, Xd_te, Xf_te], verbose=0)\n", "    val_auc = roc_auc_score(sep_te, p_val[:, PRIMARY_H, 0])\n", "    history['val_auc'].append(val_auc)\n", "\n", "    if val_auc > best_auc:\n", "        best_auc = val_auc\n", "        best_weights = model_sofa.get_weights()\n", "\n", "    if epoch % 5 == 0 or epoch == 1:\n", "        print(f'Epoch {epoch:3d}  MSE={history[\"mse\"][-1]:.4f}  '\n", "              f'SOFA={history[\"sofa\"][-1]:.4f}  Total={history[\"total\"][-1]:.4f}  '\n", "              f'Val-AUC={val_auc:.4f}')\n", "\n", "    # Early stopping: stop when none of the last PATIENCE epochs beats the\n", "    # AUC recorded just before them\n", "    if epoch > PATIENCE and all(\n", "        history['val_auc'][-PATIENCE+k] <= history['val_auc'][-PATIENCE-1]\n", "        for k in range(PATIENCE)\n", "    ):\n", "        print(f'Early stop at epoch {epoch}')\n", "        break\n", "\n", "model_sofa.set_weights(best_weights)\n", "elapsed = time.time() - t0\n", "print(f'\\nTrain time: {elapsed:.1f} s | Best Val AUC: {best_auc:.4f}')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 Standard BA (no SOFA constraint) for comparison \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "model_std = BaseAttentive(\n", "    static_input_dim    = N_STATIC,\n", "    dynamic_input_dim   = N_DYNAMIC,\n", "    future_input_dim    = N_FUTURE,\n", "    output_dim          = OUTPUT_DIM,\n", "    forecast_horizon    = HORIZON,\n", "    objective           = 'hybrid',\n", "    architecture_config = {'decoder_attention_stack': ['cross', 'hierarchical']},\n", "    embed_dim           = 32,\n", "    num_heads           = 4,\n", "    dropout_rate        = 0.15,\n", "    name                = 'ba_std_real',\n", ")\n", 
"model_std.compile(optimizer=keras.optimizers.Adam(1e-3), loss='mse')\n", "hist_std = model_std.fit(\n", " [Xs_tr, Xd_tr, Xf_tr], Y_tr,\n", " epochs=EPOCHS, batch_size=BATCH, validation_split=0.15,\n", " callbacks=[keras.callbacks.EarlyStopping(patience=PATIENCE,\n", " restore_best_weights=True,\n", " monitor='val_loss')],\n", " verbose=0,\n", ")\n", "p_std = model_std.predict([Xs_te, Xd_te, Xf_te], verbose=0)\n", "auc_std = roc_auc_score(sep_te, p_std[:, PRIMARY_H, 0])\n", "print(f'BA (standard) AUC = {auc_std:.4f}')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 Classical baselines: LR and RF (static + flattened dynamic) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "X_flat_tr = np.concatenate([Xs_tr, Xd_tr.reshape(TRAIN_SIZE, -1)], axis=1)\n", "X_flat_te = np.concatenate([Xs_te, Xd_te.reshape(TEST_SIZE, -1)], axis=1)\n", "\n", "sc = StandardScaler()\n", "Xf_tr_s = sc.fit_transform(X_flat_tr)\n", "Xf_te_s = sc.transform(X_flat_te)\n", "\n", "lr_cls = LogisticRegression(C=1.0, max_iter=500, random_state=42)\n", "lr_cls.fit(Xf_tr_s, sep_tr)\n", "prob_lr = lr_cls.predict_proba(Xf_te_s)[:, 1]\n", "auc_lr = roc_auc_score(sep_te, prob_lr)\n", "\n", "rf_cls = RandomForestClassifier(n_estimators=200, max_depth=12,\n", " random_state=42, n_jobs=-1)\n", "rf_cls.fit(Xf_tr_s, sep_tr)\n", "prob_rf = rf_cls.predict_proba(Xf_te_s)[:, 1]\n", "auc_rf = roc_auc_score(sep_te, prob_rf)\n", "\n", "p_sofa_te = model_sofa.predict([Xs_te, Xd_te, Xf_te], verbose=0)\n", "auc_sofa = roc_auc_score(sep_te, p_sofa_te[:, PRIMARY_H, 0])\n", "\n", "print(f'{\"Method\":22s} {\"AUC-ROC\":>9s} {\"AUC-PR\":>8s}')\n", "print('\u2500' * 44)\n", "for name, prob in [('Logistic Reg', prob_lr), ('Random Forest', prob_rf),\n", " ('BA (standard)', p_std[:,PRIMARY_H,0]),\n", " ('BA (SOFA)', p_sofa_te[:,PRIMARY_H,0])]:\n", " ap = average_precision_score(sep_te, prob)\n", " print(f'{name:22s} 
{roc_auc_score(sep_te, prob):>9.4f} {ap:>8.4f}')\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fig, axes = plt.subplots(1, 3, figsize=(17, 5))\n", "\n", "# (A) ROC curves\n", "ax = axes[0]\n", "for name, prob, col, ls in [\n", " ('Logistic Reg', prob_lr, '#2ecc71', '-'),\n", " ('Random Forest', prob_rf, '#e67e22', '--'),\n", " ('BA (standard)', p_std[:,PRIMARY_H,0], '#3498db', '-'),\n", " ('BA (SOFA)', p_sofa_te[:,PRIMARY_H,0],'#9b59b6', '-'),\n", "]:\n", " fpr, tpr, _ = roc_curve(sep_te, prob)\n", " auc = roc_auc_score(sep_te, prob)\n", " ax.plot(fpr, tpr, lw=2, color=col, ls=ls, label=f'{name} AUC={auc:.3f}')\n", "ax.plot([0,1],[0,1],'k:',lw=1); ax.set_xlabel('FPR'); ax.set_ylabel('TPR')\n", "ax.set_title('(A) ROC \u2014 PhysioNet 2019', fontsize=11, fontweight='bold')\n", "ax.legend(fontsize=8); ax.grid(True, alpha=0.2)\n", "\n", "# (B) Training curves (SOFA model)\n", "ax = axes[1]\n", "ep = np.arange(1, len(history['mse'])+1)\n", "ax.plot(ep, history['mse'], color='#3498db', lw=2, label='MSE loss')\n", "ax.plot(ep, history['sofa'], color='#e74c3c', lw=2, label='SOFA loss')\n", "ax.plot(ep, history['total'], color='black', lw=1.5, ls='--', label='Total loss')\n", "ax2 = ax.twinx()\n", "ax2.plot(ep, history['val_auc'], color='#2ecc71', lw=2, label='Val AUC')\n", "ax2.set_ylabel('Validation AUC', color='#2ecc71')\n", "ax.set_xlabel('Epoch'); ax.set_ylabel('Loss')\n", "ax.set_title('(B) SOFA-Informed Training', fontsize=11, fontweight='bold')\n", "lines1, lbl1 = ax.get_legend_handles_labels()\n", "lines2, lbl2 = ax2.get_legend_handles_labels()\n", "ax.legend(lines1+lines2, lbl1+lbl2, fontsize=8)\n", "ax.grid(True, alpha=0.2)\n", "\n", "# (C) SOFA consistency\n", "ax = axes[2]\n", "sofa_thr_range = np.arange(0, 20)\n", "consistency_std = [p_std[sofa_score_raw[te] >= t, PRIMARY_H, 0].mean()\n", " if (sofa_score_raw[te] >= t).sum() > 5 else np.nan\n", " for t in sofa_thr_range]\n", "consistency_sofa = 
[p_sofa_te[sofa_score_raw[te] >= t, PRIMARY_H, 0].mean()\n", " if (sofa_score_raw[te] >= t).sum() > 5 else np.nan\n", " for t in sofa_thr_range]\n", "sofa_phys = [1/(1+np.exp(-(t-4)/2)) for t in sofa_thr_range]\n", "ax.plot(sofa_thr_range, sofa_phys, 'k--', lw=1.5, label='SOFA physics prior')\n", "ax.plot(sofa_thr_range, consistency_std, 'o-', color='#3498db', ms=5, label='BA (standard)')\n", "ax.plot(sofa_thr_range, consistency_sofa,'o-', color='#9b59b6', ms=5, label='BA (SOFA)')\n", "ax.set_xlabel('SOFA threshold'); ax.set_ylabel('Mean predicted P(sepsis | +12 h)')\n", "ax.set_title('(C) SOFA Consistency', fontsize=11, fontweight='bold')\n", "ax.legend(fontsize=9); ax.grid(True, alpha=0.2)\n", "\n", "plt.suptitle('Section 4\u20135 \u2014 Model Comparison on PhysioNet 2019', fontsize=13)\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Interpretation \u2014 Section 4\u20135: Real-Data Results\n", "\n", "**Panel (A) \u2014 ROC curves**: Interpret the ordering in light of the data mode. If\n", "the notebook is running in demo mode, the curves only confirm that the pipeline\n", "executes. With real PhysioNet files, the meaningful question is whether the\n", "sequence model gains value from temporal structure that flat LR/RF baselines cannot\n", "represent cleanly.\n", "\n", "**Panel (B) \u2014 Training curves**: The SOFA physics loss should decrease as the model\n", "aligns with the prior. Divergence between the MSE and SOFA losses late in training can\n", "be informative. The MSE term fits the chart labels while the SOFA term fits the\n", "organ-failure prior, so a widening gap suggests that for some patients the two are not\n", "telling the same story. The Val AUC curve is computed on the test set, which is also\n", "used to select the restored best epoch. Its peak, and the BA (SOFA) AUC in Panel (A),\n", "are therefore optimistically biased.\n", "\n", "**Panel (C) \u2014 SOFA consistency**: The SOFA-informed model should track the physics\n", "prior more closely than the standard model, especially in high-SOFA patients. If it\n", "does not, that is not automatically a failure; it may reveal label noise, missing\n", "components, or a subgroup where SOFA alone is not sufficient.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 6 \u2014 Multi-Horizon Risk Curves\n", "\n", "BA produces all horizons natively from a single model. The LR and RF baselines above\n", "are fit to one label and would need a separate model per horizon. Each patient\n", "receives simultaneous probability estimates for +6 h, +12 h, and +24 h sepsis onset.\n", "In practice, this is closer to clinical decision making than a single binary alert: a\n", "patient may need immediate escalation, closer monitoring, or routine reassessment\n", "depending on how the risk curve evolves.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# \u2500\u2500 Multi-horizon AUC for BA (SOFA) \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "horizon_labels = ['+6 h', '+12 h', '+24 h']\n", "p_all = p_sofa_te # shape (TEST_SIZE, 3, 1)\n", "\n", "print('BA (SOFA) multi-horizon performance:')\n", "print(f'{\"Horizon\":>8s} {\"Prevalence\":>10s} {\"AUC-ROC\":>9s} {\"AUC-PR\":>8s}')\n", "print('\u2500' * 42)\n", "for hi, lbl in enumerate(horizon_labels):\n", " y_h = Y_te[:, hi, 0]\n", " if y_h.sum() < 5:\n", " print(f'{lbl:>8s} -- insufficient positives --'); continue\n", " auc_h = roc_auc_score(y_h, p_all[:, hi, 0])\n", " ap_h = average_precision_score(y_h, p_all[:, hi, 0])\n", " print(f'{lbl:>8s} {y_h.mean():>10.3f} {auc_h:>9.4f} {ap_h:>8.4f}')\n", "\n", "# \u2500\u2500 Horizon AUC bar chart 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "fig, axes = plt.subplots(1, 2, figsize=(13, 5))\n", "\n", "ax = axes[0]\n", "aucs_h, aps_h = [], []\n", "for hi in range(HORIZON):\n", " y_h = Y_te[:, hi, 0]\n", " if y_h.sum() < 5: aucs_h.append(np.nan); aps_h.append(np.nan); continue\n", " aucs_h.append(roc_auc_score(y_h, p_all[:, hi, 0]))\n", " aps_h.append(average_precision_score(y_h, p_all[:, hi, 0]))\n", "\n", "x = np.arange(HORIZON)\n", "ax.bar(x - 0.2, aucs_h, 0.35, label='AUC-ROC', color='#3498db', alpha=0.85)\n", "ax.bar(x + 0.2, aps_h, 0.35, label='AUC-PR', color='#e74c3c', alpha=0.85)\n", "ax.set_xticks(x); ax.set_xticklabels(horizon_labels)\n", "ax.set_ylim(0.5, 1.0); ax.set_ylabel('AUC')\n", "ax.set_title('(A) BA (SOFA) \u2014 Multi-Horizon AUC', fontsize=11, fontweight='bold')\n", "ax.legend(); ax.grid(True, alpha=0.2, axis='y')\n", "\n", "# (B) Example risk trajectories for 12 random septic patients\n", "ax = axes[1]\n", "sep_idx = np.where(Y_te[:, PRIMARY_H, 0] == 1)[0]\n", "sample = sep_idx[:min(12, len(sep_idx))]\n", "for si in sample:\n", " ax.plot(np.arange(HORIZON), p_all[si, :, 0], 'o-', alpha=0.5, lw=1.2, ms=4)\n", "ax.set_xticks(np.arange(HORIZON)); ax.set_xticklabels(horizon_labels)\n", "ax.set_ylabel('P(sepsis)'); ax.axhline(0.5, color='gray', lw=1, ls='--')\n", "ax.set_title('(B) Individual Risk Trajectories (septic patients)', fontsize=11, fontweight='bold')\n", "ax.grid(True, alpha=0.2)\n", "\n", "plt.suptitle('Section 6 \u2014 Multi-Horizon Prediction', fontsize=13)\n", "plt.tight_layout(); plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## 7 \u2014 Discussion & Conclusions\n", "\n", "This companion notebook is best read as the 
real-data stress test of Notebook 12.\n", "The architecture is familiar, but the evidence is messier: labels are imperfect,\n", "labs are intermittently measured, and SOFA is only partially observable. Those\n", "imperfections are not distractions from the method; they are exactly why a temporal\n", "model with clinical anchoring and transparent limitations is useful.\n", "\n", "### 7.1 Full SOFA vs synthetic proxy\n", "\n", "| Component | Synthetic NB12 | This notebook |\n", "|-----------|---------------|---------------|\n", "| Respiratory | 4-variable linear proxy | SpO\u2082/FiO\u2082 (Rice 2007) |\n", "| Coagulation | \u2014 | Platelets |\n", "| Liver | \u2014 | Bilirubin_total |\n", "| Cardiovascular | MAP threshold | MAP threshold |\n", "| CNS (GCS) | \u2014 | **Not available** |\n", "| Renal | \u2014 | Creatinine |\n", "\n", "Adding Platelets, Bilirubin, and Creatinine gives the SOFA prior a broader organ-\n", "failure signature than the synthetic proxy. This matters because some septic\n", "patients may have near-normal vital signs while hepatic, renal, or haematologic\n", "dysfunction is already visible in laboratory data.\n", "\n", "### 7.2 Known limitations\n", "\n", "1. **GCS absent**: The CNS component is systematically omitted. Consciousness level\n", " is a strong sepsis marker (Sepsis-3: SOFA CNS \u2265 2 \u2194 GCS \u2264 13). For publication,\n", " quantify the underestimation using the MIMIC-IV cohort where GCS is available.\n", "\n", "2. **Vasopressor data absent**: The cardiovascular SOFA component is capped at 1\n", " (MAP < 70) rather than 2\u20134 (vasopressor dose). Haemodynamically unstable\n", " patients on norepinephrine/vasopressin have underestimated SOFA scores.\n", "\n", "3. **Forward-fill imputation**: Missing labs are forward-filled. For labs with long\n", " measurement intervals (bilirubin, creatinine), this introduces stale values.\n", " A masking layer or a dedicated imputation model should be evaluated.\n", "\n", "4. 
**Label quality**: PhysioNet 2019 `SepsisLabel` is derived from chart documentation,\n", " not adjudicated outcome review. Clinically septic patients with delayed or missed\n", " charting may be misclassified as negative. SOFA regularisation can partially\n", " compensate, but it should not be presented as a substitute for label review.\n", "\n", "### 7.3 Recommended next steps for publication\n", "\n", "1. **External validation**: eICU Collaborative Research Database (Pollard et al. 2018)\n", " for cross-centre generalisability.\n", "2. **DeLong AUC confidence intervals**: 1,000-iteration bootstrap to report\n", " 95 % CI alongside point estimates.\n", "3. **Decision curve analysis**: net benefit vs treat-all and treat-none at clinically\n", " plausible threshold range (sensitivity \u2265 0.80).\n", "4. **Attention visualisation on real patients**: select a septic and a non-septic\n", " patient with similar SOFA, show that the attention maps differ by temporal pattern\n", " rather than current state \u2014 the key differentiator from threshold-based SOFA alerts.\n", "5. **MIMIC-IV replication**: full 6-component SOFA including GCS, PaO\u2082, and\n", " vasopressor dose \u2014 enables comparison with the GCS-omission sensitivity analysis.\n" ] } ] }