
Commit b1840bc

Add script for benchmarking simulators with different parameters (#621)
Add scripts and plotting notebook for benchmarking (#454)
1 parent 1f3dc2e commit b1840bc

File tree

8 files changed: +316 -11 lines

.pre-commit-config.yaml

Lines changed: 5 additions & 5 deletions

```diff
@@ -4,12 +4,12 @@ repos:
     hooks:
       - id: black
         language_version: python3
-        exclude: "^autoemulate/experimental/|^tests/experimental/"
+        exclude: "^autoemulate/experimental/|^tests/experimental/|^benchmarks/"
   - repo: https://github.com/asottile/reorder-python-imports
     rev: v3.12.0
     hooks:
       - id: reorder-python-imports
-        exclude: "^autoemulate/experimental/|^tests/experimental/"
+        exclude: "^autoemulate/experimental/|^tests/experimental/|^benchmarks/"
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
     rev: v0.11.4
@@ -18,13 +18,13 @@ repos:
       - id: ruff
         types_or: [ python, pyi ]
         args: [ --fix ]
-        files: ^autoemulate/experimental/|^tests/experimental/
+        files: ^autoemulate/experimental/|^tests/experimental/|^benchmarks/
       # Run the formatter.
       - id: ruff-format
         types_or: [ python, pyi ]
-        files: ^autoemulate/experimental/|^tests/experimental/
+        files: ^autoemulate/experimental/|^tests/experimental/|^benchmarks/
   - repo: https://github.com/RobertCraigie/pyright-python
     rev: v1.1.398
     hooks:
       - id: pyright
-        files: ^autoemulate/experimental/|^tests/experimental/
+        files: ^autoemulate/experimental/|^tests/experimental/|^benchmarks/
```
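Note the asymmetry: `^benchmarks/` is added to the `exclude` patterns for black and reorder-python-imports, but to the `files` patterns for ruff, ruff-format, and pyright, so the new scripts are formatted, linted, and type-checked by the ruff/pyright toolchain only. To exercise the updated hooks locally (assuming pre-commit is installed in the active environment):

```bash
pre-commit run --files benchmarks/benchmark.py
```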
autoemulate/experimental/simulations/__init__.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -1,8 +1,9 @@
 from .epidemic import Epidemic
+from .flow_problem import FlowProblem
 from .projectile import Projectile, ProjectileMultioutput
 
-ALL_SIMULATORS = [Epidemic, Projectile, ProjectileMultioutput]
+ALL_SIMULATORS = [Epidemic, FlowProblem, Projectile, ProjectileMultioutput]
 
-__all__ = ["Epidemic", "Projectile", "ProjectileMultioutput"]
+__all__ = ["Epidemic", "FlowProblem", "Projectile", "ProjectileMultioutput"]
 
 SIMULATOR_REGISTRY = dict(zip(__all__, ALL_SIMULATORS, strict=False))
```
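For context, adding `FlowProblem` to `__all__` and `ALL_SIMULATORS` makes it resolvable by name through `SIMULATOR_REGISTRY`, which is how `benchmarks/benchmark.py` instantiates simulators. A minimal sketch of the lookup:

```python
from autoemulate.experimental.simulations import SIMULATOR_REGISTRY

# Resolve the simulator class by its __all__ name and build it with defaults
simulator = SIMULATOR_REGISTRY["FlowProblem"]()
```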

autoemulate/experimental/simulations/flow_problem.py

Lines changed: 25 additions & 2 deletions

```diff
@@ -19,8 +19,8 @@ class FlowProblem(Simulator):
 
     def __init__(
         self,
-        parameters_range: dict[str, tuple[float, float]],
-        output_names: list[str],
+        parameters_range: dict[str, tuple[float, float]] | None = None,
+        output_names: list[str] | None = None,
         log_level: str = "progress_bar",
         ncycles: int = 10,
         ncomp: int = 10,
@@ -47,6 +47,29 @@ def __init__(
         ncomp: int
             Number of compartments in the tube.
         """
+        if parameters_range is None:
+            parameters_range = {
+                # Cardiac cycle period (s)
+                "T": (0.5, 2.0),
+                # Pulse duration (s)
+                "td": (0.1, 0.5),
+                # Amplitude (e.g., pressure or flow rate)
+                "amp": (100.0, 1000.0),
+                # Time step (s)
+                "dt": (0.0001, 0.01),
+                # Compliance (unit varies based on context)
+                "C": (20.0, 60.0),
+                # Resistance (unit varies based on context)
+                "R": (0.01, 0.1),
+                # Inductance (unit varies based on context)
+                "L": (0.001, 0.005),
+                # Outflow resistance (unit varies based on context)
+                "R_o": (0.01, 0.05),
+                # Initial pressure (unit varies based on context)
+                "p_o": (5.0, 15.0),
+            }
+        if output_names is None:
+            output_names = ["pressure"]
         super().__init__(parameters_range, output_names, log_level)
         self.ncycles = ncycles
         self.ncomp = ncomp
```
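Making both arguments optional is what allows the no-argument construction used by the benchmark registry. A usage sketch, assuming the `sample_inputs`/`forward_batch` API as called in `benchmarks/benchmark.py`:

```python
import torch

from autoemulate.experimental.simulations import FlowProblem

sim = FlowProblem()  # falls back to the default parameter ranges above
x = sim.sample_inputs(10, random_seed=42).to(torch.float32)  # 10 parameter sets
y = sim.forward_batch(x).to(torch.float32)  # run the simulator on each set
```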

benchmarks/README.md

Lines changed: 14 additions & 0 deletions (new file)

````markdown
# Benchmarks

- [benchmark.py](./benchmark.py): a script with a CLI for running batches of simulations with AutoEmulate for different numbers of tuning iterations
- [run_benchmark.sh](./run_benchmark.sh): runs batches of simulations with some parallelisation
- [plot_benchmark.ipynb](./plot_benchmark.ipynb): notebook for plotting results

## Quickstart
- Install [pueue](https://github.com/Nukesor/pueue): it is used by [run_benchmark.sh](./run_benchmark.sh) and simplifies running multiple Python scripts
- Run:
```bash
./run_benchmark.sh
```
````
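A minimal sequence for the quickstart, assuming a default pueue installation (commands as documented in the pueue README; the daemon must be running before jobs can be enqueued):

```bash
pueued -d          # start the pueue daemon in the background
./run_benchmark.sh # enqueue one benchmark job per parameter combination
pueue status       # inspect queued/running/finished jobs
pueue log          # view the output of finished jobs
```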

benchmarks/benchmark.py

Lines changed: 118 additions & 0 deletions (new file)

```python
import itertools
from typing import cast

import click
import numpy as np
import pandas as pd
import torch
from autoemulate.experimental.compare import AutoEmulate
from autoemulate.experimental.emulators import ALL_EMULATORS
from autoemulate.experimental.emulators.base import Emulator
from autoemulate.experimental.simulations import SIMULATOR_REGISTRY
from autoemulate.experimental.simulations.base import Simulator
from tqdm import tqdm


def run_benchmark(
    x: torch.Tensor, y: torch.Tensor, n_iter: int, n_splits: int, log_level: str
) -> pd.DataFrame:
    ae = AutoEmulate(
        x,
        y,
        models=cast(list[type[Emulator] | str], ALL_EMULATORS),
        n_iter=n_iter,
        n_splits=n_splits,
        log_level=log_level,
    )
    return ae.summarise()


@click.command()
@click.option(
    "--simulators",
    type=str,
    multiple=True,
    default=["ProjectileMultioutput"],
    help="Simulators to benchmark",
)
@click.option(
    "--n_samples_list",
    type=int,
    multiple=True,
    default=[20, 50, 100, 200, 500],
    help="Number of samples to generate",
)
@click.option(
    "--n_iter_list",
    type=int,
    multiple=True,
    default=[10, 50, 100],
    help="Number of iterations to run",
)
@click.option(
    "--n_splits_list",
    type=int,
    multiple=True,
    default=[2, 5],
    help="Number of splits for cross-validation",
)
@click.option(
    "--seed",
    type=int,
    default=42,
    help="Seed for the permutations over params",
)
@click.option(
    "--output_file",
    type=str,
    default="benchmark_results.csv",
    help="File name for output",
)
@click.option("--log_level", default="progress_bar", help="Logging level")
def main(  # noqa: PLR0913
    simulators, n_samples_list, n_iter_list, n_splits_list, seed, output_file, log_level
):
    print(f"Running benchmark with simulators: {simulators}")
    print(f"Number of samples: {n_samples_list}")
    print(f"Number of iterations: {n_iter_list}")
    print(f"Number of splits: {n_splits_list}")
    print(f"Seed: {seed}")
    print(f"Output file: {output_file}")
    print(f"Log level: {log_level}")
    print("-" * 50)

    dfs = []
    for simulator_str in simulators:
        # Generate samples
        simulator: Simulator = SIMULATOR_REGISTRY[simulator_str]()
        max_samples = max(n_samples_list)
        x_all = simulator.sample_inputs(max_samples, random_seed=seed).to(torch.float32)
        y_all = simulator.forward_batch(x_all).to(torch.float32)

        params = list(itertools.product(n_samples_list, n_iter_list, n_splits_list))
        np.random.seed(seed)
        params = np.random.permutation(params)
        for n_samples, n_iter, n_splits in tqdm(params):
            print(
                f"Running benchmark for {simulator_str} with {n_samples} samples, "
                f"{n_iter} iterations, and {n_splits} splits"
            )
            try:
                x = x_all[:n_samples]
                y = y_all[:n_samples]
                df = run_benchmark(x, y, n_iter, n_splits, log_level)
                df["simulator"] = simulator_str
                df["n_samples"] = n_samples
                df["n_iter"] = n_iter
                df["n_splits"] = n_splits
                dfs.append(df)
                final_df = pd.concat(dfs, ignore_index=True)
                final_df.sort_values("r2_test", ascending=False).to_csv(
                    output_file, index=False
                )
            except Exception as e:
                print(f"Error raised while testing:\n{e}")


if __name__ == "__main__":
    main()
```
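Because the list options are declared with `multiple=True`, click expects each value as a repeated flag. A hypothetical invocation over a reduced grid:

```bash
python benchmarks/benchmark.py \
    --simulators Epidemic --simulators FlowProblem \
    --n_samples_list 20 --n_samples_list 50 \
    --n_iter_list 10 \
    --n_splits_list 2 \
    --log_level info \
    --output_file small_benchmark.csv
```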

benchmarks/plot_benchmark.ipynb

Lines changed: 127 additions & 0 deletions (new file)

The notebook contains four code cells (Python 3.12 kernel; the notebook JSON scaffolding is omitted here, with `# %%` marking cell boundaries):

```python
# %% Load benchmark results
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.read_csv("https://github.com/user-attachments/files/21469860/benchmark_results.csv")

# %% Plotting helper
N_BOOTSTRAPS = 100

def generate_plots(df, metric="r2_train", exclude=["SupportVectorMachine", "LightGBM"], fontsize="small"):
    simulator_list = sorted(df["simulator"].unique().tolist())
    n_iter_list = sorted(df["n_iter"].unique().tolist())
    n_splits_list = sorted(df["n_splits"].unique().tolist())
    color = {name: f"C{idx}" for idx, name in enumerate(sorted(df["model_name"].unique().tolist()))}
    for plot_idx, simulator in enumerate(simulator_list):
        fig, axs = plt.subplots(len(n_splits_list), len(n_iter_list), figsize=(12, 6), squeeze=False)
        handles = []
        labels = []
        for row_idx, n_splits in enumerate(n_splits_list):
            for col_idx, n_iter in enumerate(n_iter_list):
                subset = df[df["simulator"].eq(simulator) & df["n_splits"].eq(n_splits) & df["n_iter"].eq(n_iter)]
                ax = axs[row_idx][col_idx]
                for idx, ((name,), group) in enumerate(subset.groupby(["model_name"], sort=True)):
                    if name in exclude:
                        continue
                    group_sorted = group.sort_values("n_samples")
                    line = ax.plot(group_sorted["n_samples"], group_sorted[metric], label=name, c=color[name])

                    if row_idx == 0 and col_idx == 0:
                        handles.append(line[0])
                        labels.append(name)

                    # Shaded band: standard error of the mean over bootstrap resamples
                    mean = group_sorted[metric]
                    ste = group_sorted[f"{metric}_std"] / np.sqrt(N_BOOTSTRAPS)
                    ax.fill_between(group_sorted["n_samples"], mean - ste, mean + ste, alpha=0.2, lw=0, color=color[name])
                ax.set_ylim(-0.1, 1.05)
                # ax.set_xlim(df["n_samples"].min(), df["n_samples"].max())
                ax.set_xlim(10, df["n_samples"].max())
                ax.axhline(0.0, lw=0.5, ls="--", c="grey", alpha=0.5, zorder=-1)

                ax.set_xscale("log")
                # ax.set_yscale("log")
                if col_idx == 0:
                    ax.set_ylabel(metric, size=fontsize)
                if row_idx == len(n_splits_list) - 1:
                    ax.set_xlabel("n_samples", size=fontsize)
                ax.tick_params(labelsize=fontsize)
                ax.set_title(f"{simulator} (n_iter={n_iter}, n_splits={n_splits})", size=fontsize)
                ax.grid(True, which="both", linestyle=":", linewidth=0.5, alpha=0.7)
        fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.98), ncol=df["model_name"].nunique() - len(exclude), fontsize=fontsize)

        # Adjust layout to make room for legend
        plt.tight_layout()
        plt.subplots_adjust(top=0.88)

        plt.show()

# %% All models
generate_plots(df, metric="r2_test", exclude=[])

# %% GPs, ensembles and MLPs only
generate_plots(df, metric="r2_test", exclude=["RandomForest", "LightGBM", "SupportVectorMachine", "RadialBasisFunctions"])
```
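The shaded bands are standard errors of the mean, `std / sqrt(N_BOOTSTRAPS)`, so the results CSV is expected to carry a `<metric>_std` column alongside each metric. To plot locally generated results rather than the hosted example file, point `read_csv` at your own combined CSV (the path below is hypothetical):

```python
df = pd.read_csv("benchmarks/data/2025-01-01_000000/benchmark_results.csv")
generate_plots(df, metric="r2_test", exclude=[])
```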

benchmarks/run_benchmark.sh

Lines changed: 22 additions & 0 deletions (new file)

```bash
#!/bin/bash
set -e
source .venv/bin/activate

# Run the benchmark script with the specified parameters
date_time=$(date +"%Y-%m-%d_%H%M%S")
outpath="./benchmarks/data/${date_time}/"
mkdir -p "$outpath"
for simulator in Epidemic FlowProblem Projectile ProjectileMultioutput; do
    for n_iter_pair in "10 100" "150 50" "200 20"; do
        for n_splits in 5 2; do
            n_iter_array=($n_iter_pair)
            n_iter1=${n_iter_array[0]}
            n_iter2=${n_iter_array[1]}
            echo "Running benchmark for simulator: $simulator, n_splits: $n_splits, n_iter: $n_iter1 $n_iter2"
            pueue add "python benchmarks/benchmark.py --simulators \"$simulator\" --n_splits_list \"$n_splits\" --n_iter_list \"$n_iter1\" --n_iter_list \"$n_iter2\" --log_level info --output_file \"${outpath}/benchmark_results_${simulator}_n_splits_${n_splits}_n_iter_${n_iter1}_${n_iter2}.csv\""
        done
    done
done

# Combine outputs with:
# xsv cat rows benchmarks/data/${date_time}/benchmark_*.csv > benchmark_results.csv
```
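If xsv is not available, the per-run CSVs can be combined with pandas instead — a sketch assuming the `benchmarks/data/<date_time>/` layout created by the script above:

```python
import glob

import pandas as pd

# Stack the rows of every per-run CSV written by run_benchmark.sh
files = sorted(glob.glob("benchmarks/data/*/benchmark_results_*.csv"))
combined = pd.concat((pd.read_csv(f) for f in files), ignore_index=True)
combined.to_csv("benchmark_results.csv", index=False)
```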

pyproject.toml

Lines changed: 2 additions & 2 deletions

```diff
@@ -64,12 +64,12 @@ source = [".", "/tmp"]
 [tool.pyright]
 venvPath = "."
 venv = ".venv"
-include = ["autoemulate/experimental/*", "tests/experimental/*"]
+include = ["autoemulate/experimental/*", "tests/experimental/*", "benchmarks/*"]
 
 [tool.ruff]
 src = ["autoemulate/"]
 line-length = 88
-include = ["autoemulate/experimental/**/*.py", "tests/experimental/**/*.py"]
+include = ["autoemulate/experimental/**/*.py", "tests/experimental/**/*.py", "benchmarks/**/*.py"]
 target-version = "py310"
 
 [tool.ruff.format]
```
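With `benchmarks/` added to both tools' `include` lists, the new code can also be checked directly, assuming ruff and pyright are installed in the environment:

```bash
ruff check benchmarks/
pyright benchmarks/
```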
