Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


Expand All @@ -16,6 +17,16 @@ class BetaDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "beta",
"info:metric_name": "BetaDistributionLoss",
"requires:data_type": "beta_distribution_forecast",
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"loss_ndim": 2,
}

clip_target = True
data_loader_kwargs = {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], transformation="logit"
)
}

@classmethod
Expand All @@ -30,3 +41,14 @@ def get_encoder(cls):
Returns a TorchNormalizer instance for rescaling parameters.
"""
return TorchNormalizer(transformation="logit")

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for BetaDistributionLoss.
"""
kwargs = dict(target="agency")
kwargs.update(cls.data_loader_kwargs)
return super()._get_test_dataloaders_from(
params, clip_target=cls.clip_target, **kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class ImplicitQuantileNetworkDistributionLoss_pkg(_BasePtMetric):
"requires:data_type": "implicit_quantile_network_distribution_forecast",
"capability:quantile_generation": True,
"shape:adds_quantile_dimension": True,
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
}

@classmethod
Expand All @@ -44,3 +46,10 @@ def get_metric_test_params(cls):
fixture for testing the ImplicitQuantileNetworkDistributionLoss metric.
"""
return [{"input_size": 5}]

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for ImplicitQuantileNetworkDistributionLoss.
"""
return super()._get_test_dataloaders_from(params)
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


Expand All @@ -18,6 +19,16 @@ class LogNormalDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "log_normal",
"info:metric_name": "LogNormalDistributionLoss",
"requires:data_type": "log_normal_distribution_forecast",
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"loss_ndim": 2,
}

clip_target = True
data_loader_kwargs = {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], transformation="log1p"
)
}

@classmethod
Expand Down Expand Up @@ -48,3 +59,14 @@ def prepare_test_inputs(cls, test_case):
)

return y_pred, y

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for LogNormalDistributionLoss.
"""
kwargs = dict(target="agency")
kwargs.update(cls.data_loader_kwargs)
return super()._get_test_dataloaders_from(
params, clip_target=cls.clip_target, **kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


Expand All @@ -18,6 +19,13 @@ class MQF2DistributionLoss_pkg(_BasePtMetric):
"python_dependencies": ["cpflows"],
"capability:quantile_generation": True,
"requires:data_type": "mqf2_distribution_forecast",
"clip_target": True,
"data_loader_kwargs": {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], center=False, transformation="log1p"
)
},
"trainer_kwargs": dict(accelerator="cpu"),
}

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Package container for multivariate normal distribution loss metric.
"""

from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


Expand All @@ -17,6 +18,15 @@ class MultivariateNormalDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "multivariate_normal",
"info:metric_name": "MultivariateNormalDistributionLoss",
"requires:data_type": "multivariate_normal_distribution_forecast",
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"loss_ndim": 2,
}

data_loader_kwargs = {
"target_normalizer": GroupNormalizer(
groups=["agency", "sku"], transformation="log1p"
)
}

@classmethod
Expand All @@ -26,3 +36,12 @@ def get_cls(cls):
)

return MultivariateNormalDistributionLoss

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for MultivariateNormalDistributionLoss.
"""
kwargs = dict(target="agency")
kwargs.update(cls.data_loader_kwargs)
return super()._get_test_dataloaders_from(params, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

from pytorch_forecasting.data import TorchNormalizer
from pytorch_forecasting.data.encoders import GroupNormalizer
from pytorch_forecasting.metrics.base_metrics._base_object import _BasePtMetric


Expand All @@ -16,6 +17,14 @@ class NegativeBinomialDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "negative_binomial",
"info:metric_name": "NegativeBinomialDistributionLoss",
"requires:data_type": "negative_binomial_distribution_forecast",
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"loss_ndim": 2,
}

clip_target = False
data_loader_kwargs = {
"target_normalizer": GroupNormalizer(groups=["agency", "sku"], center=False)
}

@classmethod
Expand All @@ -32,3 +41,14 @@ def get_encoder(cls):
Returns a TorchNormalizer instance for rescaling parameters.
"""
return TorchNormalizer(center=False)

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for NegativeBinomialDistributionLoss.
"""
kwargs = dict(target="agency")
kwargs.update(cls.data_loader_kwargs)
return super()._get_test_dataloaders_from(
params, clip_target=cls.clip_target, **kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,20 @@ class NormalDistributionLoss_pkg(_BasePtMetric):
"distribution_type": "normal",
"info:metric_name": "NormalDistributionLoss",
"requires:data_type": "normal_distribution_forecast",
"info:pred_type": ["distr"],
"info:y_type": ["numeric"],
"loss_ndim": 2,
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.distributions import NormalDistributionLoss

return NormalDistributionLoss

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for NormalDistributionLoss.
"""
super()._get_test_dataloaders_from(params=params, target="agency")
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ class CrossEntropy_pkg(_BasePtMetric):
"requires:data_type": "classification_forecast",
"info:metric_name": "CrossEntropy",
"no_rescaling": True,
"info:pred_type": ["point"],
"info:y_type": ["category"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics import CrossEntropy

return CrossEntropy

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for CrossEntropy.
"""
super()._get_test_dataloaders_from(params=params, target="category")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_mae/_mae_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ class MAE_pkg(_BasePtMetric):
"metric_type": "point",
"requires:data_type": "point_forecast",
"info:metric_name": "MAE",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics import MAE

return MAE

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for MAE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_mape/_mape_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ class MAPE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "MAPE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import MAPE

return MAPE

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for MAPE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_mase/_mase_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,19 @@ class MASE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "MASE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics import MASE

return MASE

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for MASE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ class PoissonLoss_pkg(_BasePtMetric):
"requires:data_type": "point_forecast",
"capability:quantile_generation": True,
"shape:adds_quantile_dimension": True,
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
}

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import PoissonLoss

return PoissonLoss

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for PoissonLoss.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_rmse/_rmse_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ class RMSE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "RMSE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
} # noqa: E501

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import RMSE

return RMSE

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for RMSE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
9 changes: 9 additions & 0 deletions pytorch_forecasting/metrics/_point_pkg/_smape/_smape_pkg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@ class SMAPE_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "SMAPE",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_type": ["numeric"],
} # noqa: E501

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import SMAPE

return SMAPE

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for SMAPE.
"""
return super()._get_test_dataloaders_from(params=params, target="agency")
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@ class TweedieLoss_pkg(_BasePtMetric):
"metric_type": "point",
"info:metric_name": "TweedieLoss",
"requires:data_type": "point_forecast",
"info:pred_type": ["point"],
"info:y_types": ["numeric"],
} # noqa: E501

@classmethod
def get_cls(cls):
from pytorch_forecasting.metrics.point import TweedieLoss

return TweedieLoss

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for TweedieLoss.
"""
return super()._get_test_dataloaders_from(params, target="agency")
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class QuantileLoss_pkg(_BasePtMetric):
"metric_type": "quantile",
"info:metric_name": "QuantileLoss",
"requires:data_type": "quantile_forecast",
"info:pred_type": ["quantile"],
"info:y_type": ["numeric"],
} # noqa: E501

@classmethod
Expand All @@ -34,3 +36,10 @@ def get_metric_test_params(cls):
"quantiles": [0.2, 0.5],
},
]

@classmethod
def _get_test_dataloaders(cls, params=None):
"""
Returns test dataloaders configured for QuantileLoss.
"""
return super()._get_test_dataloaders_from(params, target="agency")
Loading
Loading