Skip to content

Commit 22a9781

Browse files
authored
[ENH] Standardize output format for tslib v2 models (#1965)
Fixes #1964 This PR standardizes the output format of `forward` of `tslib` models to: - [x] 3D tensors for single-target for `DLinear` - [x] 3D tensors for single-target for `TimeXer`
1 parent e1cc1ce commit 22a9781

File tree

6 files changed

+22
-21
lines changed

6 files changed

+22
-21
lines changed

pytorch_forecasting/layers/_output/_flatten_head.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ def forward(self, x):
3838
x = self.flatten(x)
3939
x = self.linear(x)
4040
x = self.dropout(x)
41+
x = x.permute(0, 2, 1)
4142

4243
if self.n_quantiles is not None:
43-
batch_size, n_vars = x.shape[0], x.shape[1]
44-
x = x.reshape(batch_size, n_vars, -1, self.n_quantiles)
44+
batch_size = x.shape[0]
45+
x = x.reshape(batch_size, -1, self.n_quantiles)
4546
return x

pytorch_forecasting/models/dlinear/_dlinear_v2.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,22 +239,17 @@ def _reshape_output(self, output: torch.Tensor) -> torch.Tensor:
239239
Returns
240240
-------
241241
output: torch.Tensor
242-
Reshaped tensor (batch_size, prediction_length, n_features, n_quantiles)
242+
Reshaped tensor (batch_size, prediction_length, n_quantiles)
243243
or (batch_size, prediction_length, n_features) if n_quantiles is None.
244244
"""
245245
if self.n_quantiles is not None:
246-
batch_size, n_features = output.shape[0], output.shape[1]
246+
batch_size = output.shape[0]
247247
output = output.reshape(
248-
batch_size, n_features, self.prediction_length, self.n_quantiles
248+
batch_size, self.prediction_length, self.n_quantiles
249249
)
250-
output = output.permute(0, 2, 1, 3) # (batch, time, features, quantiles)
251250
else:
252251
output = output.permute(0, 2, 1) # (batch, time, features)
253252

254-
# univariate forecasting
255-
if self.target_dim == 1 and output.shape[-1] == 1:
256-
output = output.squeeze(-1)
257-
258253
return output
259254

260255
def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:

pytorch_forecasting/models/timexer/_timexer_v2.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -311,11 +311,6 @@ def _forecast(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
311311

312312
dec_out = self.head(enc_out)
313313

314-
if self.n_quantiles is not None:
315-
dec_out = dec_out.permute(0, 2, 1, 3)
316-
else:
317-
dec_out = dec_out.permute(0, 2, 1)
318-
319314
return dec_out
320315

321316
def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
@@ -330,10 +325,6 @@ def forward(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
330325
out = self._forecast(x)
331326
prediction = out[:, : self.prediction_length, :]
332327

333-
# check to see if the output shape is equal to number of targets
334-
if prediction.size(2) != self.target_dim:
335-
prediction = prediction[:, :, : self.target_dim]
336-
337328
if "target_scale" in x:
338329
prediction = self.transform_output(prediction, x["target_scale"])
339330

pytorch_forecasting/tests/test_all_estimators_v2.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import lightning.pytorch as pl
77
from lightning.pytorch.callbacks import EarlyStopping
88
from lightning.pytorch.loggers import TensorBoardLogger
9+
import torch
910
import torch.nn as nn
1011

1112
from pytorch_forecasting.tests.test_all_estimators import (
@@ -77,6 +78,19 @@ def _integration(
7778
test_outputs = trainer.test(net, dataloaders=test_dataloader)
7879
assert len(test_outputs) > 0
7980

81+
# todo: add the predict pipeline and make this test cleaner
82+
x, y = next(iter(test_dataloader))
83+
net.eval()
84+
with torch.no_grad():
85+
output = net(x)
86+
net.train()
87+
prediction = output["prediction"]
88+
n_dims = prediction.ndim
89+
assert n_dims == 3, (
90+
f"Prediction output must be 3D, but got {n_dims}D tensor "
91+
f"with shape {output.shape}"
92+
)
93+
8094
shutil.rmtree(tmp_path, ignore_errors=True)
8195

8296

tests/test_models/test_dlinear_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_quantile_loss_output(sample_dataset):
124124

125125
assert "prediction" in output
126126
pred = output["prediction"]
127-
assert pred.ndim == 4
127+
assert pred.ndim == 3
128128
assert pred.shape[-1] == len(quantiles)
129129
assert pred.shape[1] == metadata["prediction_length"]
130130

tests/test_models/test_timexer_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def test_quantile_predictions(basic_metadata):
297297
output = model(sample_input_data)
298298

299299
predictions = output["prediction"]
300-
assert predictions.shape == (batch_size, 8, 1, 3)
300+
assert predictions.shape == (batch_size, 8, 3)
301301

302302

303303
def test_missing_history_target_handling(basic_metadata):

0 commit comments

Comments
 (0)