Skip to content

Commit c0b0581

Browse files
fixed a bug on the chronos pipeline
1 parent 6d4b537 commit c0b0581

File tree

5 files changed

+13
-34
lines changed

5 files changed

+13
-34
lines changed

examples/foundation_daily.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def transform_group(df):
129129
"MoiraiLarge",
130130
"MoiraiMoESmall",
131131
"MoiraiMoEBase",
132-
"MoiraiMoELarge",
133132
"TimesFM_1_0_200m",
134133
"TimesFM_2_0_500m",
135134
]

examples/foundation_monthly.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ def transform_group(df):
135135
"MoiraiLarge",
136136
"MoiraiMoESmall",
137137
"MoiraiMoEBase",
138-
"MoiraiMoELarge",
139138
"TimesFM_1_0_200m",
140139
"TimesFM_2_0_500m",
141140
]

examples/m5-examples/foundation_daily_m5.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
"MoiraiLarge",
3636
"MoiraiMoESmall",
3737
"MoiraiMoEBase",
38-
"MoiraiMoELarge",
3938
"TimesFM_1_0_200m",
4039
"TimesFM_2_0_500m",
4140
]

mmf_sa/Forecaster.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,6 @@ def backtest_global_model(
383383
spark=self.spark,
384384
# backtest_retrain=self.conf["backtest_retrain"],
385385
))
386-
387386
group_id_dtype = IntegerType() \
388387
if train_df[self.conf["group_id"]].dtype == 'int' else StringType()
389388

@@ -399,7 +398,6 @@ def backtest_global_model(
399398
]
400399
)
401400
res_sdf = self.spark.createDataFrame(res_pdf, schema)
402-
403401
# Write evaluation results to a delta table
404402
if write:
405403
if self.conf.get("evaluation_output", None):
@@ -413,7 +411,6 @@ def backtest_global_model(
413411
.write.mode("append")
414412
.saveAsTable(self.conf.get("evaluation_output"))
415413
)
416-
417414
# Compute aggregated metrics
418415
res_df = (
419416
res_sdf.groupby(["metric_name"])

mmf_sa/models/chronosforecast/ChronosPipeline.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,13 @@ def predict(self,
9090
horizon_timestamps_udf(hist_df.ds).alias("ds"),
9191
forecast_udf(hist_df.y).alias("y"))
9292
).toPandas()
93-
9493
forecast_df = forecast_df.reset_index(drop=False).rename(
9594
columns={
9695
"unique_id": self.params.group_id,
9796
"ds": self.params.date_col,
9897
"y": self.params.target,
9998
}
10099
)
101-
102100
# Todo
103101
# forecast_df[self.params.target] = forecast_df[self.params.target].clip(0.01)
104102
return forecast_df, self.model
@@ -165,19 +163,13 @@ def predict_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
165163
import numpy as np
166164
import pandas as pd
167165
# Initialize the ChronosPipeline with a pretrained model from the specified repository
168-
from chronos import BaseChronosPipeline, ChronosBoltPipeline
169-
if "bolt" in self.repo:
170-
pipeline = ChronosBoltPipeline.from_pretrained(
171-
self.repo,
172-
device_map=self.device,
173-
torch_dtype=torch.bfloat16,
174-
)
175-
else:
176-
pipeline = BaseChronosPipeline.from_pretrained(
177-
self.repo,
178-
device_map=self.device,
179-
torch_dtype=torch.bfloat16,
180-
)
166+
from chronos import BaseChronosPipeline
167+
pipeline = BaseChronosPipeline.from_pretrained(
168+
self.repo,
169+
device_map='cuda',
170+
torch_dtype=torch.bfloat16,
171+
)
172+
181173
# inference
182174
for bulk in bulk_iterator:
183175
median = []
@@ -262,19 +254,12 @@ def __init__(self, repository, prediction_length):
262254
self.prediction_length = prediction_length
263255
self.device = "cuda" if torch.cuda.is_available() else "cpu"
264256
# Initialize the ChronosPipeline with a pretrained model from the specified repository
265-
from chronos import BaseChronosPipeline, ChronosBoltPipeline
266-
if "bolt" in self.repository:
267-
self.pipeline = ChronosBoltPipeline.from_pretrained(
268-
self.repository,
269-
device_map=self.device,
270-
torch_dtype=torch.bfloat16,
271-
)
272-
else:
273-
self.pipeline = BaseChronosPipeline.from_pretrained(
274-
self.repository,
275-
device_map=self.device,
276-
torch_dtype=torch.bfloat16,
277-
)
257+
from chronos import BaseChronosPipeline
258+
self.pipeline = BaseChronosPipeline.from_pretrained(
259+
self.repository,
260+
device_map='cuda',
261+
torch_dtype=torch.bfloat16,
262+
)
278263

279264
def predict(self, context, input_data, params=None):
280265
history = [torch.tensor(list(series)) for series in input_data]

0 commit comments

Comments
 (0)