class TwoEventsWashover(Washover):
    """Washover that drops records whose treatment changes between two events.

    Each record in the experiment carries two timestamps: a start event and
    an end event (for instance a login and a logout). Both timestamps are
    floored to the hour and looked up, by time of day only (the date part is
    discarded), in a calendar of treatments. A record is kept only when the
    treatment found for its start equals the treatment found for its end.
    Records that straddle a change of treatment are dropped, and so are
    records whose floored start or end time has no entry in the calendar.

    Ex:
    Calendar:
    +-----------+------------+
    | treatment | start_time |
    +-----------+------------+
    | A         | 10:00      |
    | B         | 11:00      |
    | B         | 12:00      |
    +-----------+------------+
    Record data:

    +----+------------+----------+
    | id | start_time | end_time |
    +----+------------+----------+
    |  1 | 10:50      | 10:59    |
    |  2 | 10:51      | 11:01    |
    |  3 | 11:01      | 11:05    |
    |  4 | 11:01      | 12:01    |
    |  5 | 12:01      | 12:05    |
    +----+------------+----------+

    This washover will drop record number 2, because it starts under
    treatment A and ends under treatment B.
    """

    def __init__(
        self,
        calendar_df: pd.DataFrame,
        time_column_calendar: str = "start_time",
        treatment_column_calendar: str = "treatment",
    ):
        """
        Args:
            calendar_df (pd.DataFrame): df with the calendar of treatments.
            time_column_calendar (str): Name of the time column in the calendar df.
            treatment_column_calendar (str): Name of the treatment column in the calendar df.
        """
        self.calendar_df = calendar_df
        self.time_column_calendar = time_column_calendar
        self.treatment_column_calendar = treatment_column_calendar

    def _validate_columns(
        self, record_df: pd.DataFrame, start_time_column: str, end_time_column: str
    ) -> None:
        """Check that every column the washover needs is present.

        Args:
            record_df (pd.DataFrame): Record dataframe.
            start_time_column (str): Name of the start time column in the record dataframe.
            end_time_column (str): Name of the end time column in the record dataframe.

        Raises:
            ValueError: If a calendar column is missing from the calendar
                dataframe or a record column is missing from ``record_df``.
        """
        if self.time_column_calendar not in self.calendar_df.columns:
            raise ValueError(
                f"{self.time_column_calendar = } is not in the calendar dataframe columns and/or not specified as an input."
            )
        if self.treatment_column_calendar not in self.calendar_df.columns:
            raise ValueError(
                f"{self.treatment_column_calendar = } is not in the calendar dataframe columns and/or not specified as an input."
            )
        if start_time_column not in record_df.columns:
            raise ValueError(
                f"{start_time_column = } is not in the record dataframe columns and/or not specified as an input."
            )
        if end_time_column not in record_df.columns:
            raise ValueError(
                f"{end_time_column = } is not in the record dataframe columns and/or not specified as an input."
            )

    def washover(
        self,
        record_df: pd.DataFrame,
        start_time_column: str = "start_time",
        end_time_column: str = "end_time",
    ) -> pd.DataFrame:
        """Return the records whose treatment is the same at start and end.

        Neither ``record_df`` nor the calendar stored on the instance is
        modified, so the method can be called repeatedly.

        Args:
            record_df (pd.DataFrame): Record dataframe.
            start_time_column (str): Name of the start time column in the record dataframe.
            end_time_column (str): Name of the end time column in the record dataframe.

        Returns:
            pd.DataFrame: A new dataframe with the surviving records. The
            start and end columns hold time-of-day values, the helper
            columns ``start_time_floor`` and ``end_time_floor`` hold the
            hour each event falls in, and the treatment assigned to the
            record is stored under the calendar's treatment column name
            (or ``<name>_start`` if ``record_df`` already has a column with
            that name).

        Raises:
            ValueError: If any required column is missing.
        """
        self._validate_columns(record_df, start_time_column, end_time_column)

        time_col = self.time_column_calendar
        treatment_col = self.treatment_column_calendar

        # Work on copies: converting the stored calendar in place would leave
        # time objects behind that pd.to_datetime cannot parse on the next
        # call, and the caller's record dataframe must not be altered.
        calendar = self.calendar_df[[time_col, treatment_col]].copy()
        calendar[time_col] = pd.to_datetime(calendar[time_col]).dt.time

        records = record_df.copy()
        start_times = pd.to_datetime(records[start_time_column])
        end_times = pd.to_datetime(records[end_time_column])
        records["start_time_floor"] = start_times.dt.floor("h").dt.time
        records["end_time_floor"] = end_times.dt.floor("h").dt.time
        records[end_time_column] = end_times.dt.time
        records[start_time_column] = start_times.dt.time

        # Merge the calendar under private column names so the result does
        # not depend on which user-chosen names happen to collide and pick up
        # merge suffixes.
        start_key = "__washover_start_key"
        end_key = "__washover_end_key"
        start_treatment = "__washover_start_treatment"
        end_treatment = "__washover_end_treatment"

        merged = records.merge(
            calendar.rename(
                columns={time_col: start_key, treatment_col: start_treatment}
            ),
            how="left",
            left_on="start_time_floor",
            right_on=start_key,
        ).merge(
            calendar.rename(
                columns={time_col: end_key, treatment_col: end_treatment}
            ),
            how="left",
            left_on="end_time_floor",
            right_on=end_key,
        )

        # Missing treatments never compare equal, so records without a
        # calendar match at either end are dropped as well.
        same_treatment = merged[start_treatment] == merged[end_treatment]
        kept = merged.loc[same_treatment].drop(
            columns=[start_key, end_key, end_treatment]
        )

        # Avoid producing a duplicated column name when the records already
        # carry a column named like the calendar's treatment column.
        if treatment_col in record_df.columns:
            output_treatment = f"{treatment_col}_start"
        else:
            output_treatment = treatment_col
        return kept.rename(columns={start_treatment: output_treatment})


# This is kept in here because of circular imports, need to rethink this
washover_mapping = {"": EmptyWashover, "constant_washover": ConstantWashover}