Skip to content

Commit 1f3dc2e

Browse files
authored
Merge pull request #663 from alan-turing-institute/467-update-docs
Add docs lint config and update docstrings (#467, #555)
2 parents 5075324 + 5e2b8dc commit 1f3dc2e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+1164
-401
lines changed

autoemulate/experimental/calibration/bayes.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313

1414
class BayesianCalibration(TorchDeviceMixin):
1515
"""
16+
Bayesian calibration using Markov Chain Monte Carlo (MCMC).
17+
1618
Bayesian calibration estimates the probability distribution over input parameters
1719
given observed data, providing uncertainty estimates.
1820
"""
@@ -120,9 +122,7 @@ def __init__( # noqa: PLR0913
120122
raise ValueError(msg)
121123

122124
def _get_kernel(self, sampler: str, **sampler_kwargs):
123-
"""
124-
Get the appropriate MCMC kernel based on sampler choice.
125-
"""
125+
"""Get the appropriate MCMC kernel based on sampler choice."""
126126
sampler = sampler.lower()
127127

128128
if sampler == "nuts":
@@ -158,7 +158,6 @@ def model(self, predict: bool = False):
158158
Whether to run the model with existing samples to generate posterior
159159
predictive distribution. Used with `pyro.infer.Predictive`.
160160
"""
161-
162161
# Pre-allocate tensor for all input parameters, shape [1, n_inputs]
163162
param_list = []
164163
# Each param is either sampled (if calibrated) or set to a constant value
@@ -234,7 +233,6 @@ def run_mcmc(
234233
MCMC
235234
The Pyro MCMC object. Methods include `summary()` and `get_samples()`.
236235
"""
237-
238236
# Check initial param values match number of chains
239237

240238
if initial_params is not None:

autoemulate/experimental/calibration/history_matching.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414

1515
class HistoryMatching(TorchDeviceMixin):
16-
"""
16+
r"""
17+
History Matching class for model calibration.
18+
1719
History matching is a model calibration method, which uses observed data to
1820
rule out ``implausible`` parameter values. The implausibility metric is:
1921
@@ -130,8 +132,9 @@ def get_nroy(
130132
self, implausibility: TensorLike, x: TensorLike | None = None
131133
) -> TensorLike:
132134
"""
133-
Get indices of NROY points from implausibility scores. If `x`
134-
is provided, returns parameter values at NROY indices.
135+
Get indices of NROY points from implausibility scores.
136+
137+
If `x` is provided, returns parameter values at NROY indices.
135138
136139
Parameters
137140
----------
@@ -155,8 +158,9 @@ def get_ro(
155158
self, implausibility: TensorLike, x: TensorLike | None = None
156159
) -> TensorLike:
157160
"""
158-
Get indices of RO points from implausibility scores. If `x`
159-
is provided, returns parameter values at RO indices.
161+
Get indices of RO points from implausibility scores.
162+
163+
If `x` is provided, returns parameter values at RO indices.
160164
161165
Parameters
162166
----------
@@ -255,6 +259,8 @@ def generate_param_bounds(
255259

256260
class HistoryMatchingWorkflow(HistoryMatching):
257261
"""
262+
History Matching Workflow class.
263+
258264
Run history matching workflow:
259265
- sample parameter values to test from the current NROY parameter space
260266
- use emulator to rule out implausible parameters and update NROY space
@@ -321,6 +327,8 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
321327

322328
def generate_samples(self, n: int) -> tuple[TensorLike, TensorLike]:
323329
"""
330+
Generate parameter samples and evaluate implausibility.
331+
324332
Draw `n` samples from the simulator min/max parameter bounds and
325333
evaluate implausibility given emulator predictions.
326334

autoemulate/experimental/calibration/history_matching_dashboard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
class HistoryMatchingDashboard:
1111
"""
12+
History Matching Dashboard.
13+
1214
Interactive dashboard for exploring history matching with UI controls that adapt
1315
based on selected plot type.
1416
"""
@@ -787,7 +789,6 @@ def _plot_implausibility_radar(self, df: pd.DataFrame, impl_scores: NumpyLike):
787789

788790
def display(self):
789791
"""Display the dashboard."""
790-
791792
heading = widgets.HTML(value="<h2>History Matching Dashboard</h2>")
792793

793794
# Display the heading and instructions first

autoemulate/experimental/callbacks/early_stopping.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55

66

77
class EarlyStoppingException(Exception):
8-
"""
9-
Custom exception to signal early stopping during training.
10-
"""
8+
"""Custom exception to signal early stopping during training."""
119

1210
def __init__(
1311
self, message: str = "Training stopped early due to early stopping criteria."
@@ -17,6 +15,8 @@ def __init__(
1715

1816
class EarlyStopping:
1917
"""
18+
Early stopping callback for PyTorch models.
19+
2020
Stop training early if the training loss did not improve in `patience` number of
2121
epochs by at least `threshold` value. Can be used inside the training loop of any
2222
PyTorch model.
@@ -65,12 +65,14 @@ def __init__(
6565
self.load_best = load_best
6666

6767
def __getstate__(self):
68+
"""Return state without pickling the best model weights."""
6869
# Avoids having to save the module_ weights twice when pickling the model
6970
state = self.__dict__.copy()
7071
state["best_model_weights_"] = None
7172
return state
7273

7374
def on_train_begin(self):
75+
"""Initialize early stopping parameters at the start of training."""
7476
if self.threshold_mode not in ["rel", "abs"]:
7577
raise ValueError(f"Invalid threshold mode: '{self.threshold_mode}'")
7678
self.misses_ = 0
@@ -91,7 +93,6 @@ def on_epoch_end(self, model: nn.Module, curr_epoch: int, curr_score: float):
9193
curr_score: float
9294
The current training loss.
9395
"""
94-
9596
if not self._is_score_improved(curr_score):
9697
self.misses_ += 1
9798
else:

autoemulate/experimental/compare.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@
3131

3232

3333
class AutoEmulate(ConversionMixin, TorchDeviceMixin, Results):
34+
"""
35+
Automated emulator fitting.
36+
37+
The AutoEmulate class is the main class of the AutoEmulate package.
38+
It is used to set up and compare different emulator models on a given dataset.
39+
It can also be used to summarise and visualise results, and to save and load models.
40+
41+
"""
42+
3443
def __init__( # noqa: PLR0913
3544
self,
3645
x: InputLike,
@@ -48,10 +57,7 @@ def __init__( # noqa: PLR0913
4857
log_level: str = "progress_bar",
4958
):
5059
"""
51-
The AutoEmulate class is the main class of the AutoEmulate package.
52-
It is used to set up and compare different emulator models on a given dataset.
53-
It can also be used to summarise and visualise results,
54-
and to save and load models.
60+
Initialize the AutoEmulate class.
5561
5662
Parameters
5763
----------
@@ -143,13 +149,12 @@ def __init__( # noqa: PLR0913
143149

144150
@staticmethod
145151
def all_emulators() -> list[type[Emulator]]:
152+
"""Return a list of all available emulators."""
146153
return ALL_EMULATORS
147154

148155
@staticmethod
149156
def list_emulators() -> pd.DataFrame:
150-
"""
151-
Return a dataframe with the model_name and short_name
152-
of all available emulators.
157+
"""Return a dataframe with model names of all available emulators.
153158
154159
Returns
155160
-------
@@ -167,6 +172,7 @@ def list_emulators() -> pd.DataFrame:
167172
def get_models(
168173
self, models: list[type[Emulator] | str] | None = None
169174
) -> list[type[Emulator]]:
175+
"""Return a list of the model classes for comparisons."""
170176
if models is None:
171177
return self.all_emulators()
172178

@@ -186,6 +192,7 @@ def get_models(
186192
def get_transforms(
187193
self, transforms: list[AutoEmulateTransform | dict[str, object]]
188194
) -> list[AutoEmulateTransform]:
195+
"""Process and return a list of transforms."""
189196
processed_transforms = []
190197
for transform in transforms:
191198
if isinstance(transform, dict):
@@ -201,6 +208,7 @@ def get_transforms(
201208
def filter_models_if_multioutput(
202209
self, models: list[type[Emulator]], warn: bool
203210
) -> list[type[Emulator]]:
211+
"""Filter models to only include those that support multi-output data."""
204212
updated_models = []
205213
for model in models:
206214
if not model.is_multioutput():
@@ -223,6 +231,7 @@ def log_compare( # noqa: PLR0913
223231
r2_score,
224232
rmse_score,
225233
):
234+
"""Log the comparison results."""
226235
msg = (
227236
"Comparison results:\n"
228237
f"Best Model: {best_model_name}, "
@@ -236,8 +245,16 @@ def log_compare( # noqa: PLR0913
236245

237246
def compare(self):
238247
"""
239-
Tune hyperparameters of all emulators using the train/validation data
240-
and evaluate performance of all tuned emulators on the test data.
248+
Compare different models on the provided dataset.
249+
250+
The method will:
251+
- Loop over all combinations of x and y transforms and models.
252+
- Set up the tuner with the training/validation data.
253+
- Tune hyperparameters for each model.
254+
- Fit the best model with the tuned hyperparameters.
255+
- Evaluate the performance of the best model on the test data.
256+
- Log the results.
257+
- Save the best model and its configuration.
241258
"""
242259
tuner = Tuner(self.train_val, y=None, n_iter=self.n_iter, device=self.device)
243260
self.logger.info(
@@ -508,7 +525,7 @@ def save(
508525
path: str | Path | None = None,
509526
use_timestamp: bool = True,
510527
) -> Path:
511-
"""Saves model to disk.
528+
"""Save model to disk.
512529
513530
Parameters
514531
----------
@@ -549,7 +566,7 @@ def save(
549566
return self.model_serialiser._save_model(model, filename, path)
550567

551568
def load(self, path: str | Path) -> Emulator | Result:
552-
"""Loads a stored model or result from disk.
569+
"""Load a stored model or result from disk.
553570
554571
Parameters
555572
----------

autoemulate/experimental/data/utils.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,15 @@
1313

1414

1515
class ConversionMixin:
16-
"""
17-
Mixin class to convert input data to pytorch Datasets and DataLoaders.
18-
"""
16+
"""Mixin class to convert input data to PyTorch Datasets and DataLoaders."""
1917

2018
@classmethod
2119
def _convert_to_dataset(
2220
cls,
2321
x: InputLike,
2422
y: InputLike | None = None,
2523
) -> Dataset:
26-
"""
27-
Convert input data to pytorch Dataset.
28-
"""
24+
"""Convert input data to PyTorch Dataset."""
2925
# Convert input to Dataset if not already
3026
if isinstance(x, np.ndarray):
3127
x = torch.tensor(x, dtype=torch.float32)
@@ -58,9 +54,7 @@ def _convert_to_dataloader(
5854
batch_size: int = 16,
5955
shuffle: bool = True,
6056
) -> DataLoader:
61-
"""
62-
Convert input data to pytorch DataLoaders.
63-
"""
57+
"""Convert input data to PyTorch DataLoaders."""
6458
if isinstance(x, DataLoader) and y is None:
6559
dataloader = x
6660
elif isinstance(x, DataLoader) and y is not None:
@@ -79,9 +73,7 @@ def _convert_to_tensors(
7973
y: InputLike | None = None,
8074
dtype: torch.dtype = torch.float32,
8175
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
82-
"""
83-
Convert InputLike x, y to Tensor or tuple of Tensors.
84-
"""
76+
"""Convert InputLike x, y to Tensor or tuple of Tensors."""
8577
dataset = cls._convert_to_dataset(x, y)
8678

8779
# Handle Subset of TensorDataset
@@ -133,9 +125,7 @@ def _convert_to_numpy(
133125
x: InputLike,
134126
y: InputLike | None = None,
135127
) -> tuple[np.ndarray, np.ndarray | None]:
136-
"""
137-
Convert InputLike x, y to tuple of numpy arrays.
138-
"""
128+
"""Convert InputLike x, y to tuple of numpy arrays."""
139129
if isinstance(x, np.ndarray) and (y is None or isinstance(y, np.ndarray)):
140130
return x, y
141131

@@ -227,6 +217,7 @@ def set_random_seed(seed: int = 42, deterministic: bool = True):
227217
class ValidationMixin:
228218
"""
229219
Mixin class for validation methods.
220+
230221
This class provides static methods for checking the types and shapes of
231222
input and output data, as well as validating specific tensor shapes.
232223
"""
@@ -235,9 +226,9 @@ class ValidationMixin:
235226
def _check(x: TensorLike, y: TensorLike | None):
236227
"""
237228
Check the types and shape are correct for the input data.
229+
238230
Checks are equivalent to sklearn's check_array.
239231
"""
240-
241232
if not isinstance(x, TensorLike):
242233
raise ValueError(f"Expected x to be TensorLike, got {type(x)}")
243234

@@ -271,10 +262,7 @@ def _check(x: TensorLike, y: TensorLike | None):
271262

272263
@staticmethod
273264
def _check_output(output: OutputLike):
274-
"""
275-
Check the types and shape are correct
276-
for the output data.
277-
"""
265+
"""Check the types and shape are correct for the output data."""
278266
if not isinstance(output, OutputLike):
279267
raise ValueError(f"Expected OutputLike, got {type(output)}")
280268

@@ -424,6 +412,8 @@ def trace(Sigma: TensorLike, d: int) -> TensorLike:
424412
@staticmethod
425413
def logdet(Sigma: TensorLike, dim: int) -> TensorLike:
426414
"""
415+
Return the log-determinant of the covariance matrix.
416+
427417
Compute the log-determinant of the covariance matrix (D-optimal design
428418
criterion).
429419
@@ -455,6 +445,8 @@ def logdet(Sigma: TensorLike, dim: int) -> TensorLike:
455445
@staticmethod
456446
def max_eigval(Sigma: TensorLike) -> TensorLike:
457447
"""
448+
Return the maximum eigenvalue of the covariance matrix.
449+
458450
Compute the maximum eigenvalue of the covariance matrix (E-optimal design
459451
criterion).
460452

0 commit comments

Comments
 (0)