Skip to content

Commit

Permalink
small fix: removed num_samples from chronos pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed Jan 15, 2025
1 parent b5f933e commit 5609ee7
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions mmf_sa/models/chronosforecast/ChronosPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def register(self, registered_model_name: str):
pipeline = ChronosModel(
self.repo,
self.params["prediction_length"],
self.params["num_samples"],
self.device,
)
input_schema = Schema([TensorSpec(np.dtype(np.double), (-1, -1))])
Expand Down Expand Up @@ -189,7 +188,6 @@ def predict_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
forecasts = pipeline.predict(
context=contexts,
prediction_length=self.params["prediction_length"],
#num_samples=self.params["num_samples"],
)
median.extend([np.median(forecast, axis=0) for forecast in forecasts])
yield pd.Series(median)
Expand Down Expand Up @@ -259,11 +257,10 @@ def __init__(self, params):


class ChronosModel(mlflow.pyfunc.PythonModel):
def __init__(self, repository, prediction_length, num_samples, device):
def __init__(self, repository, prediction_length):
import torch
self.repository = repository
self.prediction_length = prediction_length
self.num_samples = num_samples
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize the ChronosPipeline with a pretrained model from the specified repository
from chronos import BaseChronosPipeline, ChronosBoltPipeline
Expand All @@ -285,7 +282,6 @@ def predict(self, context, input_data, params=None):
forecast = self.pipeline.predict(
context=history,
prediction_length=self.prediction_length,
#num_samples=self.num_samples,
)
return forecast.numpy()

0 comments on commit 5609ee7

Please sign in to comment.