From 5609ee7faf76f875ae6845aabb84d67a0e483a48 Mon Sep 17 00:00:00 2001 From: Ryuta Yoshimatsu Date: Wed, 15 Jan 2025 07:34:34 +0100 Subject: [PATCH] small fix: removed num_samples from chronos pipeline --- mmf_sa/models/chronosforecast/ChronosPipeline.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/mmf_sa/models/chronosforecast/ChronosPipeline.py b/mmf_sa/models/chronosforecast/ChronosPipeline.py index 9b9708a..17fb7de 100644 --- a/mmf_sa/models/chronosforecast/ChronosPipeline.py +++ b/mmf_sa/models/chronosforecast/ChronosPipeline.py @@ -26,7 +26,6 @@ def register(self, registered_model_name: str): pipeline = ChronosModel( self.repo, self.params["prediction_length"], - self.params["num_samples"], self.device, ) input_schema = Schema([TensorSpec(np.dtype(np.double), (-1, -1))]) @@ -189,7 +188,6 @@ def predict_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: forecasts = pipeline.predict( context=contexts, prediction_length=self.params["prediction_length"], - #num_samples=self.params["num_samples"], ) median.extend([np.median(forecast, axis=0) for forecast in forecasts]) yield pd.Series(median) @@ -259,11 +257,10 @@ def __init__(self, params): class ChronosModel(mlflow.pyfunc.PythonModel): - def __init__(self, repository, prediction_length, num_samples, device): + def __init__(self, repository, prediction_length): import torch self.repository = repository self.prediction_length = prediction_length - self.num_samples = num_samples self.device = "cuda" if torch.cuda.is_available() else "cpu" # Initialize the ChronosPipeline with a pretrained model from the specified repository from chronos import BaseChronosPipeline, ChronosBoltPipeline @@ -285,7 +282,6 @@ def predict(self, context, input_data, params=None): forecast = self.pipeline.predict( context=history, prediction_length=self.prediction_length, - #num_samples=self.num_samples, ) return forecast.numpy()