
Commit a3ac860

Backports v0.13.9 (#3076)
* Fix Rotbaum serialization and deserialization (#3068)
* Fix Rotbaum to handle short series (#3073)
* Fix after backport

Co-authored-by: Anurag Pant <[email protected]>
1 parent dd8449f commit a3ac860

File tree: 5 files changed (+133 −10 lines)


src/gluonts/ext/rotbaum/_model.py

Lines changed: 8 additions & 1 deletion
@@ -20,7 +20,7 @@
 import gc
 from collections import defaultdict
 
-from gluonts.core.component import validated
+from gluonts.core.component import equals, validated
 
 
 class QRF:
@@ -125,6 +125,13 @@ def _create_xgboost_model(model_params: Optional[dict] = None):
         }
         return xgboost.sklearn.XGBModel(**model_params)
 
+    def __eq__(self, that):
+        """
+        Two QRX instances are considered equal if they have the same
+        constructor arguments.
+        """
+        return equals(self, that)
+
     def fit(
         self,
         x_train: Union[pd.DataFrame, List],
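
The new __eq__ delegates to gluonts.core.component.equals, which compares the arguments recorded by the @validated constructor; this is what makes comparing lists of trained models (as asserted in the updated test_model.py below) meaningful. A minimal sketch of the pattern, using a hypothetical Dummy class rather than QRX itself:

from gluonts.core.component import equals, validated


class Dummy:
    # Hypothetical class, illustrating the same pattern the commit applies
    # to QRX: equality is defined by the @validated constructor arguments.
    @validated()
    def __init__(self, depth: int = 3) -> None:
        self.depth = depth

    def __eq__(self, that):
        return equals(self, that)


assert Dummy(depth=3) == Dummy(depth=3)
assert Dummy(depth=3) != Dummy(depth=5)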

src/gluonts/ext/rotbaum/_predictor.py

Lines changed: 27 additions & 0 deletions
@@ -13,12 +13,14 @@
 
 import concurrent.futures
 import logging
+import pickle
 from itertools import chain
 from typing import Iterator, List, Optional
 from toolz import first
 
 import numpy as np
 import pandas as pd
+from pathlib import Path
 from itertools import compress
 
 from gluonts.core.component import validated
@@ -337,6 +339,31 @@ def predict(
                 item_id=ts.get("item_id"),
             )
 
+    def serialize(self, path: Path) -> None:
+        """
+        This function calls parent class serialize() in order to serialize
+        the class name, version information and constuctor arguments. It
+        persists the tree predictor by pickling the model list that is
+        generated when pickling the TreePredictor.
+        """
+        super().serialize(path)
+        with (path / "predictor.pkl").open("wb") as f:
+            pickle.dump(self.model_list, f)
+
+    @classmethod
+    def deserialize(cls, path: Path, **kwargs) -> "TreePredictor":
+        """
+        This function loads and returns the serialized model. It loads
+        the predictor class with the serialized arguments. It then loads
+        the trained model list by reading the pickle file.
+        """
+
+        predictor = super().deserialize(path)
+        assert isinstance(predictor, cls)
+        with (path / "predictor.pkl").open("rb") as f:
+            predictor.model_list = pickle.load(f)
+        return predictor
+
     def explain(
         self, importance_type: str = "gain", percentage: bool = True
     ) -> ExplanationResult:
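
A round-trip of the new serialize/deserialize pair might look like the sketch below. This is illustrative only: it borrows the TreePredictor and ListDataset usage from the tests further down, uses a deliberately tiny dataset, and relies on train() having populated model_list before serialization.

import tempfile
from pathlib import Path

import numpy as np

from gluonts.dataset.common import ListDataset
from gluonts.ext.rotbaum import TreePredictor

# Deliberately tiny training set, just enough to fit a model.
dataset = ListDataset(
    data_iter=[{"start": "2017-10-11", "target": np.arange(30.0)}],
    freq="D",
)

predictor = TreePredictor(
    freq="D",
    prediction_length=7,
    quantiles=[0.1, 0.5, 0.9],
    method="QuantileRegression",
).train(dataset)

with tempfile.TemporaryDirectory() as temp_dir:
    # serialize() writes the constructor arguments via the parent class and
    # pickles model_list to predictor.pkl; deserialize() reverses both steps.
    predictor.serialize(Path(temp_dir))
    restored = TreePredictor.deserialize(Path(temp_dir))

assert predictor == restored
assert predictor.model_list == restored.model_list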

src/gluonts/ext/rotbaum/_preprocess.py

Lines changed: 10 additions & 4 deletions
@@ -448,9 +448,12 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
         end_index = starting_index + self.context_window_size
         if starting_index < 0:
             prefix = [None] * abs(starting_index)
+            time_series_window = time_series["target"]
         else:
             prefix = []
-        time_series_window = time_series["target"][starting_index:end_index]
+            time_series_window = time_series["target"][
+                starting_index:end_index
+            ]
         only_lag_features, transform_dict = self._pre_transform(
             time_series_window, self.subtract_mean, self.count_nans
         )
@@ -460,7 +463,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
             if self.use_feat_static_real
             else []
         )
-        if self.cardinality:
+        if (
+            self.cardinality
+            and time_series.get("feat_static_cat", None) is not None
+        ):
             feat_static_cat = (
                 self.encode_one_hot_all(time_series["feat_static_cat"])
                 if self.one_hot_encode
@@ -473,10 +479,10 @@ def make_features(self, time_series: Dict, starting_index: int) -> List:
             list(
                 chain(
                     *[
-                        list(ent[0]) + list(ent[1].values())
+                        prefix + list(ent[0]) + list(ent[1].values())
                         for ent in [
                             self._pre_transform(
-                                ts[starting_index:end_index],
+                                ts if prefix else ts[starting_index:end_index],
                                 self.subtract_mean,
                                 self.count_nans,
                             )
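
The substance of the short-series fix: when the requested context window starts before the beginning of the target (starting_index < 0), make_features keeps the full short series instead of slicing with a negative start (which would silently return a truncated, misaligned window) and prepends the existing None prefix to the per-series features. A standalone sketch of that branch logic (make_window is a hypothetical helper, not library code):

from typing import List, Optional


def make_window(
    target: List[float], starting_index: int, context_window_size: int
) -> List[Optional[float]]:
    # Hypothetical helper mirroring the branch added to make_features.
    end_index = starting_index + context_window_size
    if starting_index < 0:
        # Series shorter than the context window: pad on the left with None
        # and keep the whole target.
        prefix: List[Optional[float]] = [None] * abs(starting_index)
        window = list(target)
    else:
        prefix = []
        window = list(target[starting_index:end_index])
    return prefix + window


# The 5-point series from the new smoke test with a 7-point context window
# yields two None pads followed by the full target.
assert make_window([7.0, 0.0, 0.0, 23.0, 13.0], -2, 7) == [
    None,
    None,
    7.0,
    0.0,
    0.0,
    23.0,
    13.0,
]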

test/ext/rotbaum/test_model.py

Lines changed: 20 additions & 4 deletions
@@ -11,10 +11,11 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-
+from pathlib import Path
 import pytest
+import tempfile
 
-from gluonts.ext.rotbaum import TreeEstimator
+from gluonts.ext.rotbaum import TreeEstimator, TreePredictor
 
 
 @pytest.fixture()
@@ -33,5 +34,20 @@ def test_accuracy(accuracy_test, hyperparameters, quantiles):
     accuracy_test(TreeEstimator, hyperparameters, accuracy=0.20)
 
 
-def test_serialize(serialize_test, hyperparameters):
-    serialize_test(TreeEstimator, hyperparameters)
+def test_serialize(serialize_test, hyperparameters, dsinfo):
+    forecaster = TreeEstimator.from_hyperparameters(
+        freq=dsinfo.freq,
+        **{
+            "prediction_length": dsinfo.prediction_length,
+            "num_parallel_samples": dsinfo.num_parallel_samples,
+        },
+        **hyperparameters,
+    )
+
+    predictor_act = forecaster.train(dsinfo.train_ds)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        predictor_act.serialize(Path(temp_dir))
+        predictor_exp = TreePredictor.deserialize(Path(temp_dir))
+        assert predictor_act == predictor_exp
+        assert predictor_act.model_list == predictor_exp.model_list

test/ext/rotbaum/test_rotbaum_smoke.py

Lines changed: 68 additions & 1 deletion
@@ -12,10 +12,12 @@
 # permissions and limitations under the License.
 
 import pytest
+import numpy as np
 
-from gluonts.ext.rotbaum import TreeEstimator
+from gluonts.ext.rotbaum import TreeEstimator, TreePredictor
 
 from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features
+from gluonts.dataset.common import ListDataset
 
 # TODO: Add support for categorical and dynamic features.
 
@@ -59,3 +61,68 @@ def test_rotbaum_smoke(datasets):
     predictor = estimator.train(dataset_train)
     forecasts = list(predictor.predict(dataset_test))
     assert len(forecasts) == len(dataset_test)
+
+
+def test_short_history_item_pred():
+    prediction_length = 7
+    freq = "D"
+
+    dataset = ListDataset(
+        data_iter=[
+            {
+                "start": "2017-10-11",
+                "item_id": "item_1",
+                "target": np.array(
+                    [
+                        1.0,
+                        9.0,
+                        2.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        5.0,
+                        3.0,
+                        4.0,
+                        2.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        6.0,
+                    ]
+                ),
+                "feat_static_cat": np.array([0.0, 0.0], dtype=float),
+                "past_feat_dynamic_real": np.array(
+                    [
+                        [1.0222e06 for i in range(14)],
+                        [750.0 for i in range(14)],
+                    ]
+                ),
+            },
+            {
+                "start": "2017-10-11",
+                "item_id": "item_2",
+                "target": np.array([7.0, 0.0, 0.0, 23.0, 13.0]),
+                "feat_static_cat": np.array([0.0, 1.0], dtype=float),
+                "past_feat_dynamic_real": np.array(
+                    [[0 for i in range(5)], [750.0 for i in range(5)]]
+                ),
+            },
+        ],
+        freq=freq,
+    )
+
+    predictor = TreePredictor(
+        freq=freq,
+        prediction_length=prediction_length,
+        quantiles=[0.1, 0.5, 0.9],
+        max_n_datapts=50000,
+        method="QuantileRegression",
+        use_past_feat_dynamic_real=True,
+        use_feat_dynamic_real=False,
+        use_feat_dynamic_cat=False,
+        use_feat_static_real=False,
+        cardinality="auto",
+    )
+    predictor = predictor.train(dataset)
+    forecasts = list(predictor.predict(dataset))
+    assert forecasts[1].quantile(0.5).shape[0] == prediction_length
