 from gpytorch.kernels import MultitaskKernel, ScaleKernel
 from gpytorch.likelihoods import MultitaskGaussianLikelihood
 from gpytorch.means import MultitaskMean
-from torch import nn, optim
+from torch import optim
 from torch.optim.lr_scheduler import LRScheduler

 from autoemulate.experimental.callbacks.early_stopping import (
@@ -63,7 +63,6 @@ def __init__(  # noqa: PLR0913 allow too many arguments since all currently requ
         mean_module_fn: MeanModuleFn = constant_mean,
         covar_module_fn: CovarModuleFn = rbf,
         epochs: int = 50,
-        activation: type[nn.Module] = nn.ReLU,
         lr: float = 1e-1,
         early_stopping: EarlyStopping | None = None,
         device: DeviceLike | None = None,
@@ -87,8 +86,6 @@ def __init__(  # noqa: PLR0913 allow too many arguments since all currently requ
             Function to create the covariance module.
         epochs : int, default=50
             Number of training epochs.
-        activation : type[nn.Module], default=nn.ReLU
-            Activation function to use in the model.
         lr : float, default=2e-1
             Learning rate for the optimizer.
         device : DeviceLike | None, default=None
@@ -130,7 +127,6 @@ def __init__(  # noqa: PLR0913 allow too many arguments since all currently requ
         self.covar_module = covar_module
         self.epochs = epochs
         self.lr = lr
-        self.activation = activation
         self.optimizer = self.optimizer_cls(self.parameters(), lr=self.lr)  # type: ignore[call-arg] since all optimizers include lr
         # Extract scheduler-specific kwargs if present
         scheduler_kwargs = kwargs.pop("scheduler_kwargs", {})
@@ -206,8 +202,27 @@ def _predict(self, x: TensorLike, with_grad: bool) -> GaussianProcessLike:
             x = x.to(self.device)
             return self(x)

+    @classmethod
+    def scheduler_config(cls) -> dict:
+        """
+        Returns a random configuration for the learning rate scheduler.
+        This should be added to the `get_tune_config()` method of subclasses
+        to allow tuning of the scheduler parameters.
+        """
+        all_params = [
+            {"scheduler_cls": None, "scheduler_kwargs": None},
+            {
+                "scheduler_cls": [LRScheduler],
+                "scheduler_kwargs": [
+                    {"policy": "ReduceLROnPlateau", "patience": 5, "factor": 0.5}
+                ],
+            },
+        ]
+        return np.random.choice(all_params)
+
     @staticmethod
     def get_tune_config():
+        scheduler_params = GaussianProcessExact.scheduler_config()
         return {
             "mean_module_fn": [
                 constant_mean,
@@ -226,12 +241,10 @@ def get_tune_config():
                 rbf_times_linear,
             ],
             "epochs": [50, 100, 200],
-            "activation": [
-                nn.ReLU,
-                nn.GELU,
-            ],
-            "lr": list(np.logspace(-3, -1)),
+            "lr": list(np.logspace(-3, 0, 100)),
             "likelihood_cls": [MultitaskGaussianLikelihood],
+            "scheduler_cls": scheduler_params["scheduler_cls"],
+            "scheduler_kwargs": scheduler_params["scheduler_kwargs"],
         }

@@ -255,7 +268,6 @@ def __init__(  # noqa: PLR0913 allow too many arguments since all currently requ
         mean_module_fn: MeanModuleFn = constant_mean,
         covar_module_fn: CovarModuleFn = rbf,
         epochs: int = 50,
-        activation: type[nn.Module] = nn.ReLU,
         lr: float = 2e-1,
         early_stopping: EarlyStopping | None = None,
         seed: int | None = None,
@@ -332,7 +344,6 @@ def __init__(  # noqa: PLR0913 allow too many arguments since all currently requ
         self.covar_module = covar_module
         self.epochs = epochs
         self.lr = lr
-        self.activation = activation
         self.optimizer = self.optimizer_cls(self.parameters(), lr=self.lr)  # type: ignore[call-arg] since all optimizers include lr
         # Extract scheduler-specific kwargs if present
         scheduler_kwargs = kwargs.pop("scheduler_kwargs", {})
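
As a quick illustration of the new tuning hook, the sketch below mimics what `scheduler_config()` returns and how `get_tune_config()` splices that result into the search space. It is a standalone approximation rather than the library's code: a string stands in for `torch.optim.lr_scheduler.LRScheduler` so it runs without torch or autoemulate installed, and only the dictionary shapes are taken from this diff.

import numpy as np

# Two candidate scheduler settings, mirroring scheduler_config():
# either no scheduler at all, or the base LRScheduler driven by
# ReduceLROnPlateau-style kwargs.
all_params = [
    {"scheduler_cls": None, "scheduler_kwargs": None},
    {
        "scheduler_cls": ["LRScheduler"],  # placeholder for the torch class used in the diff
        "scheduler_kwargs": [
            {"policy": "ReduceLROnPlateau", "patience": 5, "factor": 0.5}
        ],
    },
]

# np.random.choice over a 1-D object array of dicts returns one dict at random.
scheduler_params = np.random.choice(all_params)

# get_tune_config() then exposes the chosen pair alongside the other ranges.
tune_config = {
    "lr": list(np.logspace(-3, 0, 100)),
    "scheduler_cls": scheduler_params["scheduler_cls"],
    "scheduler_kwargs": scheduler_params["scheduler_kwargs"],
}
print(tune_config["scheduler_cls"], tune_config["scheduler_kwargs"])

Because the random choice is made once per `get_tune_config()` call, `scheduler_cls` and `scheduler_kwargs` stay paired: either both are None or both describe the ReduceLROnPlateau setup, rather than being sampled independently by the tuner.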