[ENH] Duet Implementation #1968
Draft
Shvrth wants to merge 5 commits into sktime:main from Shvrth:duet_implementation
385,705 changes: 385,705 additions & 0 deletions
pytorch_forecasting/models/duet/dataset/AQWan.csv
Large diffs are not rendered by default.
@@ -0,0 +1,310 @@
from copy import copy
from logging import config
from typing import Callable, Optional, Union

from einops import rearrange
import numpy as np
import torch
import torch.nn as nn
from torchmetrics import Metric

from pytorch_forecasting.data import TimeSeriesDataSet
from pytorch_forecasting.metrics import (
    MAE,
    MAPE,
    RMSE,
    SMAPE,
    MultiHorizonMetric,
    QuantileLoss,
)
from pytorch_forecasting.models.base import BaseModelWithCovariates
from pytorch_forecasting.models.duet.sub_modules.layers.linear_extractor_cluster import (  # noqa: E501
    Linear_extractor_cluster,
)
from pytorch_forecasting.models.duet.sub_modules.utils.masked_attention import (
    AttentionLayer,
    Encoder,
    EncoderLayer,
    FullAttention,
    Mahalanobis_mask,
)
from pytorch_forecasting.models.nn.embeddings import MultiEmbedding

# Default hyperparameters as specified in the official implementation.
DEFAULT_DUET_HYPER_PARAMS = {
    "enc_in": 1,
    "dec_in": 1,
    "c_out": 1,
    "e_layers": 2,
    "d_layers": 1,
    "d_model": 512,
    "d_ff": 2048,
    "hidden_size": 256,
    "freq": "h",
    "factor": 1,
    "n_heads": 8,
    "activation": "gelu",
    "output_attention": 0,
    "patch_len": 16,
    "stride": 8,
    "dropout": 0.2,
    "fc_dropout": 0.2,
    "moving_avg": 25,
    "batch_size": 256,
    "lradj": "type3",
    "num_workers": 0,
    "patience": 10,
    "num_experts": 4,
    "noisy_gating": True,
    "k": 1,
    "CI": False,
}


class DUETModel(BaseModelWithCovariates):
    """
    Initial implementation of DUET: Dual Clustering Enhanced Multivariate Time
    Series Forecasting.

    Original paper: https://arxiv.org/pdf/2412.10859
    """

    def __init__(
        self,
        # Parameters for BaseModelWithCovariates
        static_categoricals: list[str],
        static_reals: list[str],
        time_varying_categoricals_encoder: list[str],
        time_varying_categoricals_decoder: list[str],
        time_varying_reals_encoder: list[str],
        time_varying_reals_decoder: list[str],
        x_reals: list[str],
        x_categoricals: list[str],
        embedding_sizes: dict,
        embedding_labels: dict,
        embedding_paddings: list[str],
        categorical_groups: dict,
        # Parameters for BaseModel
        loss: Metric,
        learning_rate: float,
        dataset_parameters: dict,
        optimizer: str = "adam",
        output_transformer: Callable = None,
        # Parameters for DUET's architecture
        dec_in: int = 1,
        c_out: int = 1,
        e_layers: int = 2,
        d_layers: int = 1,
        d_model: int = 512,
        d_ff: int = 2048,
        hidden_size: int = 256,
        freq: str = "h",
        factor: int = 1,
        n_heads: int = 8,
        activation: str = "gelu",
        output_attention: int = 0,
        patch_len: int = 16,
        stride: int = 8,
        dropout: float = 0.2,
        fc_dropout: float = 0.2,
        moving_avg: int = 25,
        lradj: str = "type3",
        num_workers: int = 0,
        patience: int = 10,
        num_experts: int = 4,
        noisy_gating: bool = True,
        k: int = 1,
        CI: bool = False,
        **kwargs,
    ):
        """
        Initialize DUET model.

        Args:
            static_categoricals: names of static categorical variables
            static_reals: names of static continuous variables
            time_varying_categoricals_encoder: names of categorical variables for encoder
            time_varying_categoricals_decoder: names of categorical variables for decoder
            time_varying_reals_encoder: names of continuous variables for encoder
            time_varying_reals_decoder: names of continuous variables for decoder
            x_reals: order of continuous variables in tensor passed to forward function
            x_categoricals: order of categorical variables in tensor passed to forward function
            embedding_sizes: dictionary mapping (string) indices to tuple of number of categorical classes and
                embedding size
            embedding_paddings: list of indices for embeddings which transform the zero's embedding to a zero vector
            embedding_labels: dictionary mapping (string) indices to list of categorical labels
            categorical_groups: dictionary where values are lists of categorical variables that together form a
                new categorical variable, which is the key in the dictionary
            loss: loss function to use. Can be any instance of torchmetrics.Metric
            learning_rate: learning rate to use
            dataset_parameters: dictionary containing parameters of the dataset
            optimizer: optimizer to use (default: "adam")
            output_transformer: function to transform the output of the network
            dec_in: number of input features to the decoder
            c_out: number of output features
            e_layers: number of encoder layers
            d_layers: number of decoder layers
            d_model: dimensionality of the model's hidden states. It is the size of the
                vectors used to represent the time series data throughout the model.
            d_ff: dimensionality of the feedforward network
            hidden_size: hidden size for the distributional router's encoder,
                which is part of the Mixture of Experts mechanism
            freq: frequency of the time series data. This is used for generating
                time-based features.
            factor: factor for the attention mechanism
            n_heads: number of attention heads
            activation: activation function used in the model
            output_attention: whether to output attention weights
            patch_len: length of each patch when using the patching mechanism
            stride: stride for the patching mechanism
            dropout: dropout rate applied within the encoder layers to prevent overfitting
            fc_dropout: dropout rate applied to the final fully connected layer
            moving_avg: window size for the moving average used in the decomposition of the time series
            batch_size: batch size used during training
            lradj: learning rate adjustment strategy
            num_workers: number of workers for data loading
            patience: number of epochs with no improvement after which training will be stopped
            num_experts: number of experts in the Mixture of Experts mechanism
            noisy_gating: whether to use noisy gating in the Mixture of Experts mechanism
            k: number of experts to be selected by the router for each input
            CI: whether to use the channel-independent configuration. If True, the model
                processes each channel (variate) of the time series independently before combining them.
            **kwargs: additional arguments
        """  # noqa: E501
        self.save_hyperparameters()
        super().__init__(
            loss=loss,
            learning_rate=learning_rate,
            optimizer=optimizer,
            output_transformer=output_transformer,
            dataset_parameters=dataset_parameters,
        )

        self.cat_embeddings = MultiEmbedding(
            embedding_sizes=self.hparams.embedding_sizes,
            embedding_paddings=self.hparams.embedding_paddings,
            categorical_groups=self.hparams.categorical_groups,
            x_categoricals=self.hparams.x_categoricals,
        )

        self.cluster = Linear_extractor_cluster(self.hparams)
        self.CI = self.hparams.CI
        self.n_vars = self.hparams.enc_in
        self.mask_generator = Mahalanobis_mask(self.hparams.seq_len)
        self.Channel_transformer = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(
                            True,
                            self.hparams.factor,
                            attention_dropout=self.hparams.dropout,
                            output_attention=self.hparams.output_attention,
                        ),
                        self.hparams.d_model,
                        self.hparams.n_heads,
                    ),
                    self.hparams.d_model,
                    self.hparams.d_ff,
                    dropout=self.hparams.dropout,
                    activation=self.hparams.activation,
                )
                for _ in range(self.hparams.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.hparams.d_model),
        )

        self.linear_head = nn.Sequential(
            nn.Linear(self.hparams.d_model, self.hparams.pred_len),
            nn.Dropout(self.hparams.fc_dropout),
        )

    def forward(self, x: dict) -> dict:
        embedded_features = self.cat_embeddings(x["encoder_cat"])
        cont_tensor = x["encoder_cont"]

        input_tensor = torch.cat([cont_tensor, embedded_features["cols"]], dim=-1)

        batch_size = input_tensor.shape[0]  # noqa: F841

        if self.hparams.CI:
            channel_independent_input = rearrange(input_tensor, "b l n -> (b n) l 1")

            reshaped_output, L_importance = self.cluster(channel_independent_input)

            temporal_feature = rearrange(
                reshaped_output, "(b n) l 1 -> b l n", b=input_tensor.shape[0]
            )  # noqa: E501
        else:
            temporal_feature, L_importance = self.cluster(input_tensor)

        temporal_feature = rearrange(temporal_feature, "b d n -> b n d")

        if self.n_vars > 1:
            # Multivariate case: apply channel transformer

            changed_input = rearrange(input_tensor, "b l n -> b n l")

            channel_mask = self.mask_generator(changed_input)

            channel_group_feature, _ = self.Channel_transformer(
                x=temporal_feature, attn_mask=channel_mask
            )
        else:
            # For the univariate case, the group feature is just the temporal feature
            channel_group_feature = temporal_feature

        # Select the features for the target variable(s)
        target_features = torch.stack(
            [
                channel_group_feature[i, self.target_positions]
                for i in range(channel_group_feature.size(0))
            ],
            dim=0,
        )

        # if torch.isnan(target_features).any():
        #     raise ValueError("Target features contain NaN values.")

        # Passing only the target features to the prediction head
        normalized_prediction = self.linear_head(target_features)

        # if torch.isnan(normalized_prediction).any():
        #     raise ValueError("Predictions contain NaN values.")

        normalized_prediction = rearrange(normalized_prediction, "b n d -> b d n")

        prediction = self.transform_output(
            prediction=normalized_prediction, target_scale=x["target_scale"]
        )

        return self.to_network_output(prediction=prediction, L_importance=L_importance)

    @classmethod
    def from_dataset(
        cls,
        dataset: TimeSeriesDataSet,
        allowed_encoder_known_variable_names: list[str] = None,
        **kwargs,
    ):
        """
        Create a DUET model from a TimeSeriesDataSet.
        """
        new_kwargs = DEFAULT_DUET_HYPER_PARAMS.copy()
        new_kwargs.update(kwargs)

        # Adding parameters we infer from the dataset
        new_kwargs.update(
            {
                "seq_len": dataset.max_encoder_length,
                "pred_len": dataset.max_prediction_length,
                "enc_in": len(dataset.reals) + len(dataset.categoricals),
            }
        )

        return super().from_dataset(
            dataset,
            allowed_encoder_known_variable_names=allowed_encoder_known_variable_names,
            **new_kwargs,
        )
Can we use _safe_import here? einops is not a core dep of ptf, so it should be imported using _safe_import. See here
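To make the suggestion concrete, one possible shape of the change is sketched below. The module path and call signature of _safe_import are assumptions here and should be checked against the reference the reviewer links; only the idea of replacing the top-level einops import with a guarded import comes from the review.

```python
# Sketch of the reviewer's suggestion, not a verified diff. The import path and
# signature of _safe_import are assumed and must be confirmed against the
# linked reference.
from pytorch_forecasting.utils._dependencies import _safe_import

# einops is a soft (non-core) dependency of pytorch-forecasting, so resolve
# rearrange lazily instead of `from einops import rearrange` at module level.
rearrange = _safe_import("einops.rearrange")
```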