diff --git a/README.md b/README.md index d64174c..9d2b525 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ run_forecast( - ```group_id``` is a column storing the unique id that groups your dataset to each time series. - ```date_col``` is your time column name. - ```target``` is your target column name. -- ```freq``` is your prediction frequency. Currently, "D" for daily and "M" for monthly are supported. Note that ```freq``` supported is as per the model basis, hence check the model documentation carefully. +- ```freq``` is your prediction frequency. Currently, "D" for daily and "M" for monthly are supported. Note that ```freq``` supported is as per the model basis, hence check the model documentation carefully. Monthly forecasting expects the timestamp column in ```train_data``` and ```scoring_output``` to be the last day of the month. - ```prediction_length``` is your forecasting horizon in the number of steps. - ```backtest_months``` specifies how many previous months you use for backtesting. - ```stride``` is the number of steps in which you update your backtesting trial start date when going from one trial to the next. diff --git a/examples/foundation_monthly.py b/examples/foundation_monthly.py index b0f404b..dd40016 100644 --- a/examples/foundation_monthly.py +++ b/examples/foundation_monthly.py @@ -115,6 +115,11 @@ def transform_group(df): # COMMAND ---------- +# MAGIC %md +# MAGIC Note that monthly forecasting requires the timestamp column to represent the last day of each month. + +# COMMAND ---------- + # MAGIC %md ### Models # MAGIC Let's configure a list of models we are going to apply to our time series for evaluation and forecasting. A comprehensive list of all supported models is available in [mmf_sa/models/models_conf.yaml](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/models_conf.yaml). Look for the models where `model_type: foundation`; these are the foundation models we import from [chronos](https://github.com/amazon-science/chronos-forecasting/tree/main), [uni2ts](https://github.com/SalesforceAIResearch/uni2ts) and [moment](https://github.com/moment-timeseries-foundation-model/moment). Check their documentation for the detailed description of each model. diff --git a/examples/global_monthly.py b/examples/global_monthly.py index 6f11c8e..236933c 100644 --- a/examples/global_monthly.py +++ b/examples/global_monthly.py @@ -115,6 +115,11 @@ def transform_group(df): # COMMAND ---------- +# MAGIC %md +# MAGIC Note that monthly forecasting requires the timestamp column to represent the last day of each month. + +# COMMAND ---------- + # MAGIC %md ### Models # MAGIC Let's configure a list of models we are going to apply to our time series for evaluation and forecasting. A comprehensive list of all supported models is available in [mmf_sa/models/models_conf.yaml](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/models_conf.yaml). Look for the models where `model_type: global`; these are the global models we import from [neuralforecast](https://github.com/Nixtla/neuralforecast). Check their documentation for the detailed description of each model. diff --git a/examples/local_univariate_monthly.py b/examples/local_univariate_monthly.py index 2bcf9ac..5a7e7bb 100644 --- a/examples/local_univariate_monthly.py +++ b/examples/local_univariate_monthly.py @@ -130,6 +130,11 @@ def transform_group(df): # COMMAND ---------- +# MAGIC %md +# MAGIC Note that monthly forecasting requires the timestamp column to represent the last day of each month. + +# COMMAND ---------- + # MAGIC %md ### Models # MAGIC Let's configure a list of models we are going to apply to our time series for evaluation and forecasting. A comprehensive list of all supported models is available in [mmf_sa/models/models_conf.yaml](https://github.com/databricks-industry-solutions/many-model-forecasting/blob/main/mmf_sa/models/models_conf.yaml). Look for the models where `model_type: local`; these are the local models we import from [statsforecast](https://github.com/Nixtla/statsforecast), [r fable](https://cran.r-project.org/web/packages/fable/vignettes/fable.html) and [sktime](https://github.com/sktime/sktime). Check their documentations for the description of each model. diff --git a/mmf_sa/Forecaster.py b/mmf_sa/Forecaster.py index efc1671..85cc4e0 100644 --- a/mmf_sa/Forecaster.py +++ b/mmf_sa/Forecaster.py @@ -75,6 +75,9 @@ def set_mlflow_experiment(self): Parameters: self (Forecaster): A Forecaster object. Returns: experiment_id (str): A string specifying the experiment id. """ + parent_dir = os.path.dirname(f'/Workspace{self.conf["experiment_path"]}') + if not os.path.exists(parent_dir): + os.makedirs(parent_dir) mlflow.set_experiment(self.conf["experiment_path"]) experiment_id = ( MlflowClient() diff --git a/mmf_sa/models/chronosforecast/ChronosPipeline.py b/mmf_sa/models/chronosforecast/ChronosPipeline.py index f966280..17367fb 100644 --- a/mmf_sa/models/chronosforecast/ChronosPipeline.py +++ b/mmf_sa/models/chronosforecast/ChronosPipeline.py @@ -42,7 +42,7 @@ def register(self, registered_model_name: str): "torchvision==0.18.1", "transformers==4.41.2", "cloudpickle==2.2.1", - "chronos-forecasting", + "chronos-forecasting==1.4.1", "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git", "pyspark==3.5.0", ], diff --git a/mmf_sa/models/moiraiforecast/MoiraiPipeline.py b/mmf_sa/models/moiraiforecast/MoiraiPipeline.py index 3f0daf6..636c23e 100644 --- a/mmf_sa/models/moiraiforecast/MoiraiPipeline.py +++ b/mmf_sa/models/moiraiforecast/MoiraiPipeline.py @@ -41,7 +41,7 @@ def register(self, registered_model_name: str): signature=signature, input_example=input_example, pip_requirements=[ - "uni2ts", + "uni2ts==1.2.0", "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git", "pyspark==3.5.0", ], diff --git a/mmf_sa/models/timesfmforecast/TimesFMPipeline.py b/mmf_sa/models/timesfmforecast/TimesFMPipeline.py index c48a692..5ec3a76 100644 --- a/mmf_sa/models/timesfmforecast/TimesFMPipeline.py +++ b/mmf_sa/models/timesfmforecast/TimesFMPipeline.py @@ -32,7 +32,7 @@ def register(self, registered_model_name: str): signature=signature, #input_example=input_example, pip_requirements=[ - "timesfm[torch]", + "timesfm[torch]==1.2.7", "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git", "pyspark==3.5.0", ], @@ -183,15 +183,30 @@ def __init__(self, params, repo): self.params = params self.repo = repo #self.backend = "gpu" if torch.cuda.is_available() else "cpu" - self.model = timesfm.TimesFm( - hparams=timesfm.TimesFmHparams( - backend="gpu", - per_core_batch_size=32, - horizon_len=self.params.prediction_length, - ), - checkpoint=timesfm.TimesFmCheckpoint( - huggingface_repo_id=self.repo, - ), + if self.repo == "google/timesfm-1.0-200m-pytorch": + self.model = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="gpu", + per_core_batch_size=32, + horizon_len=self.params.prediction_length, + ), + checkpoint=timesfm.TimesFmCheckpoint( + huggingface_repo_id=self.repo, + ), + ) + else: + self.model = timesfm.TimesFm( + hparams=timesfm.TimesFmHparams( + backend="gpu", + per_core_batch_size=32, + horizon_len=self.params.prediction_length, + num_layers=50, + use_positional_embedding=False, + context_len=2048, + ), + checkpoint=timesfm.TimesFmCheckpoint( + huggingface_repo_id=self.repo + ), ) def predict(self, context, input_df, params=None): diff --git a/requirements-global.txt b/requirements-global.txt index ddfcbae..a809a60 100644 --- a/requirements-global.txt +++ b/requirements-global.txt @@ -1,5 +1,5 @@ -r requirements.txt neuralforecast==2.0.0 -timesfm[torch] -chronos-forecasting -uni2ts \ No newline at end of file +timesfm[torch]==1.2.7 +chronos-forecasting==1.4.1 +uni2ts==1.2.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ede6897..91d91a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,10 +2,11 @@ rpy2==3.5.16 kaleido==0.2.1 Jinja2 omegaconf==2.3.0 +numba==0.60.0 statsforecast==1.7.4 missingno==0.5.2 tbats==1.1.3 sktime==0.29.0 lightgbm==4.3.0 datasetsforecast==0.0.8 -fugue==0.9.0 +fugue==0.9.0 \ No newline at end of file