Skip to content

Commit

Permalink
fix(Events): added FixedForwardWindowIndexerNoTruncation class, mak…
Browse files Browse the repository at this point in the history
…e sure `truncate_if_timeframe_is_smaller==True` outputs numbers instead of nans for the right tail of the time series (#526)

* fix(Events): add `FixedForwardWindowIndexerNoTruncation` class and make sure `truncate_if_timeframe_is_smaller==True` outputs numbers instead of NaNs for the right tail of the time series

* fix: types
  • Loading branch information
almostintuitive committed Oct 12, 2023
1 parent 61ca286 commit 5488ad8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/fold/events/labeling/fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def label_events(
series=y,
period=self.time_horizon,
shift_by=self.shift_by,
truncate_end=self.truncate_if_timeframe_is_smaller,
)
if self.truncate_if_timeframe_is_smaller:
cutoff_point = y.index[-self.time_horizon]
Expand Down
63 changes: 60 additions & 3 deletions src/fold/utils/forward.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,56 @@
from typing import Callable, Optional
from typing import Callable, Optional, Tuple

import numpy as np
import pandas as pd
from pandas.core.indexers.objects import BaseIndexer


class FixedForwardWindowIndexerNoTruncation(BaseIndexer):
    """
    Forward-looking, fixed-length window indexer whose end bounds are not clipped.

    For every selected row ``i`` the window starts at ``i`` (so it includes the
    current row) and ends at ``i + window_size``. The end bound is returned
    as-is and is never limited to ``num_values``, so the trailing windows can
    point past the last row.

    NOTE(review): how the consuming rolling aggregation treats end bounds
    larger than the series length is not visible here -- confirm it is safe.

    Examples
    --------
    >>> indexer = FixedForwardWindowIndexerNoTruncation(window_size=2)
    >>> start, end = indexer.get_window_bounds(num_values=5)
    >>> start.tolist()
    [0, 1, 2, 3, 4]
    >>> end.tolist()
    [2, 3, 4, 5, 6]
    """

    def get_window_bounds(
        self,
        num_values: int = 0,
        min_periods: Optional[int] = None,
        center: Optional[bool] = None,
        closed: Optional[str] = None,
        step: Optional[int] = None,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute the start and end bounds of each forward-looking window.

        Parameters
        ----------
        num_values : int
            Number of rows the bounds are generated for.
        min_periods : int, optional
            Accepted for interface compatibility; not used here.
        center : bool, optional
            Must be falsy; centered windows are rejected.
        closed : str, optional
            Must be ``None``; any other value is rejected.
        step : int, optional
            Spacing between consecutive window starts. ``None`` means 1.

        Returns
        -------
        Tuple[np.ndarray, np.ndarray]
            ``(start, end)`` where ``start`` is an int64 array and
            ``end == start + self.window_size``.

        Raises
        ------
        ValueError
            If ``center`` is truthy or ``closed`` is not ``None``.
        """
        # Reject options that make no sense for a forward-looking window.
        if center:
            raise ValueError("Forward-looking windows can't have center=True")
        if closed is not None:
            raise ValueError(
                "Forward-looking windows don't support setting the closed argument"
            )

        stride = 1 if step is None else step
        window_starts = np.arange(0, num_values, stride, dtype="int64")
        # Deliberately no clipping: the end bound may exceed num_values.
        return window_starts, window_starts + self.window_size


def create_forward_rolling(
Expand All @@ -9,10 +59,17 @@ def create_forward_rolling(
series: pd.Series,
period: int,
shift_by: Optional[int],
truncate_end: bool,
) -> pd.Series:
assert period > 0
indexer = pd.api.indexers.FixedForwardWindowIndexer(window_size=period)
indexer = (
pd.api.indexers.FixedForwardWindowIndexer(window_size=period)
if truncate_end
else FixedForwardWindowIndexerNoTruncation(window_size=period)
)
shift_by = shift_by if shift_by is not None else -1
assert shift_by < 0
transformation_func = transformation_func if transformation_func else lambda x: x
return agg_func(transformation_func(series).rolling(window=indexer)).shift(shift_by)
return agg_func(
transformation_func(series).rolling(window=indexer, min_periods=1)
).shift(shift_by)

0 comments on commit 5488ad8

Please sign in to comment.