Commit 6af60fc

Update GP config, remove assignment
1 parent 28506b0 commit 6af60fc

File tree

  • autoemulate/experimental/emulators/gaussian_process

1 file changed: +23 additions, −12 deletions

autoemulate/experimental/emulators/gaussian_process/exact.py

Lines changed: 23 additions & 12 deletions
@@ -8,7 +8,7 @@
 from gpytorch.kernels import MultitaskKernel, ScaleKernel
 from gpytorch.likelihoods import MultitaskGaussianLikelihood
 from gpytorch.means import MultitaskMean
-from torch import nn, optim
+from torch import optim
 from torch.optim.lr_scheduler import LRScheduler

 from autoemulate.experimental.callbacks.early_stopping import (
@@ -63,7 +63,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
         mean_module_fn: MeanModuleFn = constant_mean,
         covar_module_fn: CovarModuleFn = rbf,
         epochs: int = 50,
-        activation: type[nn.Module] = nn.ReLU,
         lr: float = 1e-1,
         early_stopping: EarlyStopping | None = None,
         device: DeviceLike | None = None,
@@ -87,8 +86,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
             Function to create the covariance module.
         epochs : int, default=50
             Number of training epochs.
-        activation : type[nn.Module], default=nn.ReLU
-            Activation function to use in the model.
         lr : float, default=2e-1
             Learning rate for the optimizer.
         device : DeviceLike | None, default=None
@@ -130,7 +127,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
         self.covar_module = covar_module
         self.epochs = epochs
         self.lr = lr
-        self.activation = activation
         self.optimizer = self.optimizer_cls(self.parameters(), lr=self.lr)  # type: ignore[call-arg] since all optimizers include lr
         # Extract scheduler-specific kwargs if present
         scheduler_kwargs = kwargs.pop("scheduler_kwargs", {})
@@ -206,8 +202,27 @@ def _predict(self, x: TensorLike, with_grad: bool) -> GaussianProcessLike:
         x = x.to(self.device)
         return self(x)

+    @classmethod
+    def scheduler_config(cls) -> dict:
+        """
+        Returns a random configuration for the learning rate scheduler.
+        This should be added to the `get_tune_config()` method of subclasses
+        to allow tuning of the scheduler parameters.
+        """
+        all_params = [
+            {"scheduler_cls": None, "scheduler_kwargs": None},
+            {
+                "scheduler_cls": [LRScheduler],
+                "scheduler_kwargs": [
+                    {"policy": "ReduceLROnPlateau", "patience": 5, "factor": 0.5}
+                ],
+            },
+        ]
+        return np.random.choice(all_params)
+
     @staticmethod
     def get_tune_config():
+        scheduler_params = GaussianProcessExact.scheduler_config()
         return {
             "mean_module_fn": [
                 constant_mean,
@@ -226,12 +241,10 @@ def get_tune_config():
                 rbf_times_linear,
             ],
             "epochs": [50, 100, 200],
-            "activation": [
-                nn.ReLU,
-                nn.GELU,
-            ],
-            "lr": list(np.logspace(-3, -1)),
+            "lr": list(np.logspace(-3, 0, 100)),
             "likelihood_cls": [MultitaskGaussianLikelihood],
+            "scheduler_cls": scheduler_params["scheduler_cls"],
+            "scheduler_kwargs": scheduler_params["scheduler_kwargs"],
         }

@@ -255,7 +268,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
         mean_module_fn: MeanModuleFn = constant_mean,
         covar_module_fn: CovarModuleFn = rbf,
         epochs: int = 50,
-        activation: type[nn.Module] = nn.ReLU,
         lr: float = 2e-1,
         early_stopping: EarlyStopping | None = None,
         seed: int | None = None,
@@ -332,7 +344,6 @@ def __init__( # noqa: PLR0913 allow too many arguments since all currently requ
         self.covar_module = covar_module
         self.epochs = epochs
         self.lr = lr
-        self.activation = activation
         self.optimizer = self.optimizer_cls(self.parameters(), lr=self.lr)  # type: ignore[call-arg] since all optimizers include lr
         # Extract scheduler-specific kwargs if present
         scheduler_kwargs = kwargs.pop("scheduler_kwargs", {})
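The new docstring notes that `scheduler_config()` "should be added to the `get_tune_config()` method of subclasses", which is exactly what this commit does for `GaussianProcessExact` itself. A minimal sketch of the same pattern in a hypothetical subclass is below; the class name `MyGP`, the trimmed search space, and the assumption that the module path mirrors the file path (`autoemulate.experimental.emulators.gaussian_process.exact`) are illustrative and not part of this commit.

import numpy as np

from autoemulate.experimental.emulators.gaussian_process.exact import (
    GaussianProcessExact,
)


class MyGP(GaussianProcessExact):
    @staticmethod
    def get_tune_config():
        # Sample one scheduler option: either no scheduler, or the
        # ReduceLROnPlateau settings defined in scheduler_config().
        scheduler_params = GaussianProcessExact.scheduler_config()
        return {
            # Illustrative subset of the parent search space.
            "epochs": [50, 100, 200],
            "lr": list(np.logspace(-3, 0, 100)),
            "scheduler_cls": scheduler_params["scheduler_cls"],
            "scheduler_kwargs": scheduler_params["scheduler_kwargs"],
        }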
