
Conversation

phoeenniixx
Member

@phoeenniixx phoeenniixx commented Aug 28, 2025

Fixes #1956, Fixes #1844
This PR makes v2 compatible with Metrics. It also changes the tensor contract to match v1: a list of tensors for multi-target and a single tensor for single-target.

Stacks on #1965

@phoeenniixx
Member Author

phoeenniixx commented Aug 28, 2025

The important question right now is: should TimeSeries return

  • 2D tensors (timesteps, num_target) for target, like it does now, and we change the contract in data_module - TimeSeries keeps the contract that D1 returns only the raw tensors.
  • Or a list of tensors from the start of the pipeline - this breaks the contract that D1 returns only the raw tensors (it sends a list of tensors for multi-target), but the contract for target stays consistent across the pipeline.

@phoeenniixx
Member Author

phoeenniixx commented Aug 31, 2025

Hi @fkiraly, @agobbifbk, @PranavBhatP
Some context on why we need this change:
Right now, data_module returns 3D tensors for target (or ground truth) - (batch, timesteps, num_targets). To make Metrics compatible with v2, we need to feed it tensors of the correct shape. Metrics can handle both 2D and 3D inputs for the predictions, and if we use MultiLoss we need a list of tensors - as we already agreed.

But the actual issue is not with the prediction (the output of forward) but with the target. Look at the loss function of MAPE:

def loss(self, y_pred, target):
    loss = (self.to_prediction(y_pred) - target).abs() / (target.abs() + 1e-8)
    return loss

The base Metric handles 2D/3D y_pred, but it can only handle a 2D target - and right now data_module returns a 3D tensor. The first question is why it can only handle a 2D target: inside loss we always end up with a 2D y_pred. See to_prediction(), whose return value is subtracted from target in MAPE:

if y_pred.ndim == 3:
    if self.quantiles is None:
        assert (
            y_pred.size(-1) == 1
        ), "Prediction should only have one extra dimension"
        y_pred = y_pred[..., 0]
    else:
        y_pred = y_pred.mean(-1)
return y_pred

As you can see, even a 3D y_pred is reduced to 2D (batch, timesteps), so to_prediction() always returns a 2D tensor.
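As a sanity check, the reduction can be reproduced standalone (a sketch re-implemented here for illustration, not imported from pytorch-forecasting):

```python
import torch

def to_prediction_sketch(y_pred, quantiles=None):
    # mirrors the Metric.to_prediction logic quoted above:
    # collapse a 3D prediction to 2D (batch, timesteps)
    if y_pred.ndim == 3:
        if quantiles is None:
            assert y_pred.size(-1) == 1, "Prediction should only have one extra dimension"
            y_pred = y_pred[..., 0]
        else:
            y_pred = y_pred.mean(-1)
    return y_pred

point = to_prediction_sketch(torch.randn(4, 10, 1))  # point loss: squeeze last dim
quantile = to_prediction_sketch(torch.randn(4, 10, 7), quantiles=[0.1] * 7)  # mean over quantiles
```

Both calls return a (4, 10) tensor, which is exactly why the target has to be 2D as well.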
Now let's go back to MAPE:

 loss = (self.to_prediction(y_pred) - target).abs() / (target.abs() + 1e-8)

Here, if to_prediction(y_pred) is 2D, target has to be 2D as well, otherwise there is a shape mismatch. And this is not limited to point metrics: the to_prediction you see above comes from Metric, the base class for all metrics.
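To make the mismatch concrete, here is a small hypothetical sketch; with these shapes the subtraction fails outright (and if batch happened to equal timesteps, broadcasting would instead expand the tensors silently, which is worse):

```python
import torch

y_pred_2d = torch.randn(4, 10)     # what to_prediction() returns
target_3d = torch.randn(4, 10, 1)  # what data_module currently yields

try:
    (y_pred_2d - target_3d).abs()  # (4, 10) vs (4, 10, 1) does not broadcast
    mismatch = False
except RuntimeError:
    mismatch = True

target_2d = target_3d[..., 0]      # dropping the trailing dim fixes the shapes
loss = (y_pred_2d - target_2d).abs() / (target_2d.abs() + 1e-8)
```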

So we either change the Metrics to handle a 3D target - which could be the harder route, as:

  • we would need to make changes to a lot of metrics and their loss logic
  • This could introduce bugs, and testing them would take a lot of effort and time

Or

We change the data_module to return target as 2D tensors - which is possible in only one way: use lists for multi-target.

This is what happens in v1 as well; see the docstring of to_dataloader() of TimeSeriesDataSet from v1 here:

* target : float (batch_size x n_decoder_time_steps) or list thereof
if list, with each entry for a different target.
unscaled (continuous) or encoded (categories) targets,
list of tensors for multiple targets

(NOTE: "data_module return" here means what the data_module returns via its dataloaders - so strictly this is the return of the dataloaders that live inside the data_module.)
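A minimal sketch of that contract (pack_target is a hypothetical helper name, not pytorch-forecasting API): a single target stays a 2D tensor, multiple targets become a list of 2D tensors:

```python
import torch

def pack_target(raw: torch.Tensor):
    """Hypothetical helper: map a raw (batch, timesteps, num_targets)
    tensor onto the v1-style target contract."""
    if raw.size(-1) == 1:
        return raw[..., 0]  # single target: plain 2D tensor
    return [raw[..., i] for i in range(raw.size(-1))]  # multi-target: list of 2D

single = pack_target(torch.randn(8, 12, 1))
multi = pack_target(torch.randn(8, 12, 3))
```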

EDIT: I manually checked the shape of the target passed to MultivariateDistributionLoss as well; it also uses a 2D target.

@agobbifbk

Ok, thanks for the explanation. Do you confirm that in v1, in the multi-target case, the batching process creates a list of 2D tensors? If that is the case, somewhere in the loss code there must be a place that iterates over this list and sums up the loss for each target, right? If yes, could we just change that iteration based on the type/dimension combination of the target tensor and reuse the whole v1 metric module?

@phoeenniixx
Member Author

If that is the case, somewhere in the loss code there must be a place that iterates over this list and sums up the loss for each target

Yes, that happens in MultiLoss. For multi-target, from what I understand, we always have to use MultiLoss; if you want to use the same loss for all targets you do something like this:

multiloss = MultiLoss(
    metrics=[MASE(), MASE(), MASE()],  # One MASE per target
    weights=[1.0, 1.0, 1.0]           # Optional weights
)

See update, compute and forward of MultiLoss here

If yes, could we just change that iteration based on the type/dimension combination of the target tensor and reuse the whole v1 metric module?

Yes, we could do this: iterate over the last dimension of the tensor (for a 3D target). But then we would also have to change MultiLoss and the step functions in v2.

The change in v2 would be to drop the last dimension (of the 3D target) and then pass the target to loss in the single-target case.

@fkiraly
Collaborator

fkiraly commented Sep 1, 2025

Thanks for the useful explanation, @phoeenniixx! Also @PranavBhatP - could you kindly point me to the best and most precise description of the current metric API? (or paste one here)

I understand the change would be necessary if we keep the current metrics API - so I would like to think a little bit about whether the metrics API could be improved instead.

@PranavBhatP
Contributor

PranavBhatP commented Sep 1, 2025

could you kindly point me to the best and most precise description of the current metric API? (or paste one here)

key methods of Metric and child classes

update(y_pred, target, **more_args) and loss(y_pred, target, **more_args). Additionally, rescale_parameters is required for normalizing the parameters to the output scale in distribution metrics.

  • target must be one of:

    • 2D (batch, timepoints) torch.Tensor for both methods
    • a tuple of tensors (tensor, weights), where tensor is any acceptable 2D tensor representing the ground-truth values and weights is a tensor of shape (batch_size, 1), (1, timesteps), (batch_size,), or the same shape as tensor. Broadcasting of weights applies in the first three cases.
    • rnn.PackedSequences - a 2D array-like (list of 1D).
  • y_pred must be one of:

    • 3D (batch, timepoints, loss_dim) torch.Tensor, with loss_dim an integer specific to the loss (loss_dim >= 2).
    • 2D (batch, timepoints) - only permitted for point prediction metrics. Interpreted the same as (batch, timepoints, 1), which is coerced/squeezed to obtain (batch, timepoints).

All metrics used in code inherit from MultiHorizonMetric, the lowest-level abstract interface class. For multiple targets, a list of these input tensors is used.

The following vignette lists all input conventions for the target dtypes and metric class methods.


# Sample data
# loss_dim is specific to the metric
y_pred = torch.randn(2, 5, loss_dim)  # Network predictions
y_actual = torch.randn(2, 5)  # Ground truth targets - 2d tensor.

# option 1 - Simple tensors
loss.update(y_pred=y_pred, target=y_actual)

# option 2 - Weighted targets 
weights = torch.ones(2, 5) * 0.8  # Give less weight to some samples
loss.update(y_pred=y_pred, target=(y_actual, weights))

# option 3 - Variable-length sequences (PackedSequence)
from torch.nn.utils.rnn import pack_sequence

sequences = [torch.randn(3), torch.randn(5), torch.randn(2)]  # Different lengths
packed_targets = pack_sequence(sequences, enforce_sorted=False)
# predictions would need to be appropriately shaped for packed format
loss.update(predictions, packed_targets)

IMO this is a good example for this context; it details the expected input formats for a metric in v1. Please let me know if this is still not sufficient.


@phoeenniixx
Member Author

The current example notebook used nn.MSELoss; changed that to MAE from Metrics.

@phoeenniixx
Member Author

I understand the change would be necessary if we keep the current metrics API - so I would like to think a little bit about whether the metrics API could be improved instead.

The only change I can see is the one @agobbifbk suggested - changing the looping logic - but we would still need to change BaseModel_v2, specifically the step functions (training_step etc.), to drop the last dimension for single-target.
Or we could make Metrics handle 3D targets:

  • looping over the last dim for multi-target
  • ignoring the last dim for single-target


codecov bot commented Sep 1, 2025

Codecov Report

❌ Patch coverage is 69.76744% with 13 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (main@e1cc1ce). Learn more about missing BASE report.

Files with missing lines                         Patch %   Lines
pytorch_forecasting/data/_tslib_data_module.py   41.66%    7 Missing ⚠️
pytorch_forecasting/data/data_module.py          50.00%    6 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1960   +/-   ##
=======================================
  Coverage        ?   87.21%           
=======================================
  Files           ?      158           
  Lines           ?     9301           
  Branches        ?        0           
=======================================
  Hits            ?     8112           
  Misses          ?     1189           
  Partials        ?        0           
Flag Coverage Δ
cpu 87.21% <69.76%> (?)
pytest 87.21% <69.76%> (?)


@agobbifbk

I understand the change would be necessary if we keep the current metrics API - so I would like to think a little bit about whether the metrics API could be improved instead.

The only change I can see is the one @agobbifbk suggested - changing the looping logic - but we would still need to change BaseModel_v2, specifically the step functions (training_step etc.), to drop the last dimension for single-target. Or we could make Metrics handle 3D targets:

  • looping over the last dim for multi-target
  • ignoring the last dim for single-target

What I understand:

  • in v1 the target is either a list of 2D tensors or a single 2D tensor
  • in v2 the target is always a 3D tensor

Since there is no ambiguity here, I think it is sufficient to change the looping strategy - or am I missing something?

@phoeenniixx
Member Author

phoeenniixx commented Sep 2, 2025

  • Since there is no ambiguity here I think it is sufficient to change the looping strategy

Yes, I agree, but there is one thing.
In v1:
From my understanding, this "looping" happens only for multi-target; for single-target, the target is passed through directly, as it is assumed to be a tensor and not a single-element list.
So, for v2:
We not only have to update the looping logic (which covers multi-target) but also how we pass the single target: we have to drop the last dimension (which is always 1 for single-target).

We have to do this because metrics accept a 2D target only...
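For the single-target path, the adjustment in the v2 step functions could look like this hedged sketch (prepare_target_for_metric is a hypothetical name, not existing API):

```python
import torch

def prepare_target_for_metric(target: torch.Tensor) -> torch.Tensor:
    """Hypothetical step-function helper: metrics expect a 2D
    (batch, timesteps) target, so drop a trailing singleton dim."""
    if target.ndim == 3 and target.size(-1) == 1:
        return target.squeeze(-1)
    return target

y = prepare_target_for_metric(torch.randn(16, 24, 1))  # -> shape (16, 24)
```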

@phoeenniixx
Member Author

Hi @fkiraly, @agobbifbk, so right now we have two possibilities:

  1. Change the contract of data_module to return a list of tensors for multi-target and a single tensor for single-target - same as v1.
  2. Change the looping logic in loss for multi-target, and the step functions in BaseModel (training_step etc., see this comment), to pass a 2D tensor for single-target and a 3D one for multi-target. Could we then move to 4D output formats, since we will be changing the looping logic anyway?

@PranavBhatP
Contributor

Change the looping logic in loss for multi-target, and the step functions in BaseModel (training_step etc., see this #1960 (comment)), to pass a 2D tensor for single-target and a 3D one for multi-target. Could we then move to 4D output formats, since we will be changing the looping logic anyway?

I agree that these are the options, but I would be against a change in the metrics, since many existing metric implementations might break with an API change and the whole test framework would need major changes. I would go ahead with option 1 :)

@fkiraly
Collaborator

fkiraly commented Sep 3, 2025

@phoeenniixx, I think the example you posted is not what we are talking about? Because there is talk about lists of tensors. None of the specs you posted involve lists explicitly.

Can I please repeat my query, @phoeenniixx, @PranavBhatP?
Could you kindly point me to the best and most precise description of the current metric API? (or paste one here)

please make sure to include the case we are actually discussing here, i.e., lists of something - and the alternative designs

@phoeenniixx
Member Author

phoeenniixx commented Sep 4, 2025

What @PranavBhatP shared here: #1960 (comment) sums up the API, I guess (mainly for single-target)? I will add some clarifications based on my understanding:

  • Single-target:
    Input y_pred:

    • shape can be 2D (batch, timesteps) or 3D (batch, timesteps, params) tensors. Here params depend on the loss being used. For point losses, it will be 1. For QuantileLoss, it is the number of quantiles and for DistributionLoss, it is the number of distribution_arguments.

    Input target:

    • 2D tensors of shape (batch, timesteps).
    • a tuple of tensors (tensor, weights), where tensor is any acceptable 2D tensor representing the ground-truth values and weights is a tensor of shape (batch_size, 1), (1, timesteps), (batch_size,), or the same shape as tensor. Broadcasting of weights applies in the first three cases.
    • rnn.PackedSequences - for metrics that support variable-length sequences
  • Multi-target:
    Input y_pred:

    • For multi-target, y_pred is a list of tensors, where the length of the list is the number of targets. Each element can have shape 2D (batch, timesteps) or 3D (batch, timesteps, params) tensors. Here params depend on the loss being used. For point losses, it will be 1. For QuantileLoss, it is the number of quantiles and for DistributionLoss, it is the number of distribution_arguments.

    Input target:

    • target is a list of any one of the three cases of single-target:
      • 2D tensor (batch, timesteps)
      • a tuple of tensors (tensor, weights)
      • rnn.PackedSequences

    For multi-target, MultiLoss handles lists of metrics and targets appropriately, mapping each metric to its corresponding prediction and target pair.
    if you want to use same loss for all the targets you do something like this:

    multiloss = MultiLoss(
      metrics=[MASE(), MASE(), MASE()],  # One MASE per target
      weights=[1.0, 1.0, 1.0]           # Optional weights
    )

    See update, compute and forward of MultiLoss here

@PranavBhatP, please correct me if there are some misconceptions from my side

Based on this info, I came up with the two approaches mentioned here: #1960 (comment)

@fkiraly
Collaborator

fkiraly commented Sep 4, 2025

@phoeenniixx, thanks, I was looking exactly for this, the multi-target specs!

Can you clarify how option 2 in #1960 (comment) would work if we have losses with different loss_dims (called params in #1960 (comment)), i.e., targets that cause 3D tensors of different size due to the 3rd dimension (integer index 2)?

Or would it not work at all?

I have a slight preference towards option 1 currently since the above may pose an issue - but it might be worth exploring if it can be prevented in option 2 (or if I misunderstand this, please help clarify)

@phoeenniixx
Member Author

phoeenniixx commented Sep 4, 2025

Can you clarify how option 2 in #1960 (comment) would work if we have losses with different loss_dims (called params in #1960 (comment)), i.e., targets that cause 3D tensors of different size due to the 3rd dimension (integer index 2)?

In my mental model, what happens right now for multi-target is that we have to use MultiLoss, where each entry in MultiLoss is for exactly one target. This means that if there are 3 targets, there are three loss entries in the list passed to MultiLoss (see my previous comment for the example).
MultiLoss then maps these losses to the corresponding targets: if there are 3 targets and the list passed to MultiLoss is [MAE(), SMAPE(), MASE()], the first target is paired with MAE, the second with SMAPE, and so on.
Currently, a list of target tensors is passed to MultiLoss along with the list of losses, so the first elements of both lists are paired together.

If we were to pass a 3D tensor for target and a 4D tensor for y_pred, with the last dim being num_targets, we would do something like this:
Imagine the target is a 3D tensor (batch, timesteps, 4), meaning there are 4 targets, so y_pred would be something like (batch, timesteps, loss_dim, 4).

  • We iterate over the last dim, so the loop runs for 4 iterations, and in each iteration we take the first 2 dims (for target) and the first 3 dims (for y_pred). This creates 4 target/y_pred pairs.
  • For these 4 pairs, there are 4 losses passed as a list to MultiLoss, and we pair them accordingly.
    # inside MultiLoss
    last_dim = target.shape[-1]
    for idx in range(last_dim):
        res = metric[idx](y_pred[..., idx], target[..., idx])
    # similar for weights - still need to look into this

So, based on this, different params may pose a problem: we would not get a rectangular tensor in that case, so we would need padding? I think we need to consider whether this would even work... Padding can make things complex, I guess, as we would need to remove it again when each tensor is passed to its corresponding loss.
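The ragged case can be checked directly. This sketch assumes two targets whose losses need 7 and 1 parameters respectively: a rectangular 4D (batch, timesteps, loss_dim, num_targets) tensor cannot be built from them, while a list works fine:

```python
import torch

pred_quantile = torch.randn(8, 12, 7)  # e.g. a quantile loss with 7 quantiles
pred_point = torch.randn(8, 12, 1)     # e.g. a point loss, loss_dim == 1

try:
    torch.stack([pred_quantile, pred_point], dim=-1)  # stack requires equal shapes
    stackable = True
except RuntimeError:
    stackable = False

preds = [pred_quantile, pred_point]  # the list-based contract has no such constraint
```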

@phoeenniixx
Member Author

Hi @fkiraly, I have removed the usage of nn losses from the notebooks: based on our recent discussion, we first need to write some adapter classes to actually use them, so for now I think we should refrain from using them. Once we have the adapter classes, we'll add examples using them?

@phoeenniixx
Member Author

Also, the notebooks are working locally, so I think the fix of #1965 should work...

Collaborator

@fkiraly fkiraly left a comment


Looks good!

The increased complexity in the case distinctions makes me feel a bit uneasy, but I think it is ok, based on our discussion. The concrete classes also seem to become more streamlined - a good sign.

Some requests:

  • can you kindly write full docstrings for the returns of __getitem__ in the changed D2 classes? I am aware these were not there before, but it feels crucial now to ensure the contract is well documented.
  • the __init__ warning message in the D2 layer classes refers to the wrong class; this should be fixed (pre-existing but easy to fix)
  • docstring in BaseModel should be attached to the class, not the __init__ method
  • docstring in BaseModel should refer to Metric more clearly, e.g., descendant of what? And "metric in pytorch-forecasting API"?
  • in Tide, can we add a probabilistic loss to replace the Poisson loss?

@phoeenniixx
Member Author

  • in Tide, can we add a probabilistic loss to replace the Poisson loss?

The model is only compatible with point prediction losses right now, but I think we can improve that in the future.

@agobbifbk

Are you referring to the v1 or v2 implementation? Maybe it is worth fixing since we are working on this - what do you think?

@phoeenniixx
Member Author

phoeenniixx commented Sep 11, 2025

Are you referring to the v1 or v2 implementation? Maybe it is worth fixing since we are working on this - what do you think?

Actually both :)
We should make changes to both implementations to handle probabilistic losses, and we should also think about what to do with the code... the v1 and v2 implementations are very different and were written by entirely different people.

@agobbifbk

Well, I think in the DSIP implementation it is sufficient to use the self.mul parameter, while in v1 you modify the output channels according to the number of outputs you need and finally reshape the final tensor into the correct form (a list of 3D tensors, right?). We can have a look together on Discord!

@fkiraly fkiraly moved this from PR in progress to PR under review in May - Sep 2025 mentee projects Sep 15, 2025
@phoeenniixx phoeenniixx requested a review from fkiraly September 15, 2025 12:50
Successfully merging this pull request may close these issues.

[ENH] Add Metric support to ptf-v2 [ENH] change loss dtype from nn.Module to Metric.
4 participants