@@ -3,6 +3,7 @@
"""

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


@@ -16,6 +17,15 @@ class BetaDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "beta",
"info:metric_name": "BetaDistributionLoss",
"requires:data_type": "beta_distribution_forecast",
"clip_target": True,
Member:
I was wondering if we need clip_target and data_loader_kwargs in the tags, as they do not convey any useful info; rather, they are the params that make the dataloaders more compatible with each metric. Maybe we should move them somewhere else in the class?

Author:
Yes, you're right. Where do you think I should move them?

Member:
Maybe they can be one of the properties? Or args in __init__? What do you think is the better idea: a property or an arg?

Member:
@fkiraly what is your opinion about this?

Author (@ParamThakkar123, Sep 12, 2025):
@phoeenniixx I think we can keep it as an arg?

Collaborator:
Where is it being passed?
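
A minimal sketch of the __init__-arg option floated above (hypothetical, not code from this PR): the dataloader-compatibility settings come off _tags and become constructor arguments. The GroupNormalizer call is taken from the diff; the class name, defaults, and plain-class framing are illustrative assumptions.

from pytorch_forecasting.data.encoders import GroupNormalizer


class BetaDistributionLossPkgSketch:
    # Illustrative stand-in; the real package object would subclass _BasePtMetric.

    _tags = {
        # only capability/metadata tags remain here
        "distribution_type": "beta",
        "info:metric_name": "BetaDistributionLoss",
        "requires:data_type": "beta_distribution_forecast",
    }

    def __init__(self, clip_target=True, data_loader_kwargs=None):
        # dataloader-compatibility params live on the instance instead of _tags
        self.clip_target = clip_target
        self.data_loader_kwargs = data_loader_kwargs or {
            "target_normalizer": GroupNormalizer(
                groups=["agency", "sku"], transformation="logit"
            )
        }

The property route would instead expose the same dict from a classmethod or property, which keeps construction argument-free but makes per-test overrides clunkier; that trade-off is what the thread is weighing.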

"data_loader_kwargs": {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], transformation="logit"
)
},
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"expected_loss_ndim": 2,
}

@classmethod
@@ -30,3 +40,10 @@ def get_encoder(cls):
Returns a TorchNormalizer instance for rescaling parameters.
"""
return TorchNormalizer(transformation="logit")

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for BetaDistributionLoss.
"""
return super()._get_test_dataloaders_from(params, target="agency")
@@ -18,6 +18,8 @@ class ImplicitQuantileNetworkDistributionLoss_pkg(_BasePtMetric):
"requires:data_type": "implicit_quantile_network_distribution_forecast",
"capability:quantile_generation": True,
"shape:adds_quantile_dimension": True,
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
}

@classmethod
@@ -44,3 +46,10 @@ def get_metric_test_params(cls):
fixture for testing the ImplicitQuantileNetworkDistributionLoss metric.
"""
return [{"input_size": 5}]

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for ImplicitQuantileNetworkDistributionLoss.
"""
return super()._get_test_dataloaders_from(params)
@@ -5,6 +5,7 @@
import torch

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


@@ -18,6 +19,15 @@ class LogNormalDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "log_normal",
"info:metric_name": "LogNormalDistributionLoss",
"requires:data_type": "log_normal_distribution_forecast",
"clip_target": True,
"data_loader_kwargs": {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], transformation="log1p"
)
},
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"expected_loss_ndim": 2,
}

@classmethod
@@ -48,3 +58,10 @@ def prepare_test_inputs(cls, test_case):
)

return y_pred, y

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for LogNormalDistributionLoss.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
@@ -3,6 +3,7 @@
"""

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


@@ -18,6 +19,13 @@ class MQF2DistributionLoss_pkg(_BasePtMetric):
"python_dependencies": ["cpflows"],
"capability:quantile_generation": True,
"requires:data_type": "mqf2_distribution_forecast",
"clip_target": True,
"data_loader_kwargs": {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], center=False, transformation="log1p"
)
},
"trainer_kwargs": dict(accelerator="cpu"),
}

@classmethod
@@ -2,6 +2,7 @@
Package container for multivariate normal distribution loss metric.
"""

from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


@@ -17,6 +18,14 @@ class MultivariateNormalDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "multivariate_normal",
"info:metric_name": "MultivariateNormalDistributionLoss",
"requires:data_type": "multivariate_normal_distribution_forecast",
"data_loader_kwargs": {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], transformation="log1p"
)
},
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"expected_loss_ndim": 2,
}

@classmethod
@@ -26,3 +35,10 @@ def get_cls(cls):
)

return MultivariateNormalDistributionLoss

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for MultivariateNormalDistributionLoss.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
@@ -3,6 +3,7 @@
"""

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


@@ -16,6 +17,13 @@ class NegativeBinomialDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "negative_binomial",
"info:metric_name": "NegativeBinomialDistributionLoss",
"requires:data_type": "negative_binomial_distribution_forecast",
"clip_target": False,
"data_loader_kwargs": {
"target_normalizer": GroupNormalizer(groups=["agency", "sku"], center=False)
},
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"expected_loss_ndim": 2,
}

@classmethod
@@ -32,3 +40,10 @@ def get_encoder(cls):
Returns a TorchNormalizer instance for rescaling parameters.
"""
return TorchNormalizer(center=False)

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for NegativeBinomialDistributionLoss.
"""
return super()._get_test_dataloaders_from(params, target="agency")
@@ -17,10 +17,20 @@ class NormalDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "normal",
"info:metric_name": "NormalDistributionLoss",
"requires:data_type": "normal_distribution_forecast",
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"expected_loss_ndim": 2,
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.distributions import NormalDistributionLoss

return NormalDistributionLoss

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for NormalDistributionLoss.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
@@ -18,10 +18,19 @@ class CrossEntropy_pkg(_BasePtMetric):
"requires:data_type": "classification_forecast",
"info:metric_name": "CrossEntropy",
"no_rescaling": True,
"info:pred_type": ["point"],
"info:y_type": ["category"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics import CrossEntropy

return CrossEntropy

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for CrossEntropy.
"""
return super()._get_test_dataloaders_from(params=params, target="category")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_mae/_mae_pkg.py
@@ -16,10 +16,19 @@ class MAE_pkg(_BasePtMetric):
"metric_type": "point",
"requires:data_type": "point_forecast",
"info:metric_name": "MAE",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics import MAE

return MAE

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for MAE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_mape/_mape_pkg.py
@@ -18,10 +18,19 @@ class MAPE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "MAPE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import MAPE

return MAPE

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for MAPE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_mase/_mase_pkg.py
@@ -14,10 +14,19 @@ class MASE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "MASE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics import MASE

return MASE

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for MASE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
@@ -18,10 +18,19 @@ class PoissonLoss_pkg(_BasePtMetric):
"requires:data_type": "point_forecast",
"capability:quantile_generation": True,
"shape:adds_quantile_dimension": True,
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import PoissonLoss

return PoissonLoss

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for PoissonLoss.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_rmse/_rmse_pkg.py
@@ -16,10 +16,19 @@ class RMSE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "RMSE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
} # noqa: E501

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import RMSE

return RMSE

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for RMSE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_smape/_smape_pkg.py
@@ -18,10 +18,19 @@ class SMAPE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "SMAPE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
} # noqa: E501

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import SMAPE

return SMAPE

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for SMAPE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
@@ -16,10 +16,19 @@ class TweedieLoss_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "TweedieLoss",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_types": ["numeric"],
} # noqa: E501

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import TweedieLoss

return TweedieLoss

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for TweedieLoss.
"""
return super()._get_test_dataloaders_from(params, target="agency")
@@ -16,6 +16,8 @@ class QuantileLoss_pkg(_BasePtMetric):
"metric_type": "quantile",
"info:metric_name": "QuantileLoss",
"requires:data_type": "quantile_forecast",
"info:pred_type": ["quantile"],
"info:y_type": ["numeric"],
} # noqa: E501

@classmethod
@@ -34,3 +36,10 @@ def get_metric_test_params(cls):
"quantiles": [0.2, 0.5],
},
]

@classmethod
def _get_test_dataloaders_from(cls, params=None):
"""
Returns test dataloaders configured for QuantileLoss.
"""
return super()._get_test_dataloaders_from(params, target="agency")