Skip to content

Commit

Permalink
Merge pull request #48 from databricks-industry-solutions/log-foundat…
Browse files Browse the repository at this point in the history
…ion-models

model logging and registry for foundation models
  • Loading branch information
ryuta-yoshimatsu authored Jun 3, 2024
2 parents 1a35856 + d4f2178 commit d830c69
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 6 deletions.
8 changes: 5 additions & 3 deletions mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,11 @@ def evaluate_foundation_model(self, model_conf):
with mlflow.start_run(experiment_id=self.experiment_id) as run:
model_name = model_conf["name"]
model = self.model_registry.get_model(model_name)
model.register(
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
)
# For now, only support registering chronos, moirai and moment models
if model_conf["framework"] in ["Chronos", "Moirai", "Moment"]:
model.register(
registered_model_name=f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
)
hist_df, removed = self.prepare_data_for_global_model("evaluating") # Reuse the same as global
train_df, val_df = self.split_df_train_val(hist_df)
model_uri = f"runs:/{run.info.run_id}/model"
Expand Down
2 changes: 1 addition & 1 deletion mmf_sa/models/models_conf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ models:
framework: Chronos
model_type: foundation
num_samples: 10
batch_size: 4
batch_size: 2

MoiraiBase:
module: mmf_sa.models.moiraiforecast.MoiraiPipeline
Expand Down
67 changes: 66 additions & 1 deletion mmf_sa/models/moiraiforecast/MoiraiPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pandas as pd
import numpy as np
import torch
import mlflow
from mlflow.types import Schema, TensorSpec
from mlflow.models.signature import ModelSignature
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
from typing import Iterator
from pyspark.sql.functions import collect_list, pandas_udf
Expand All @@ -19,9 +22,34 @@ def __init__(self, params):
self.model = None
self.install("git+https://github.com/SalesforceAIResearch/uni2ts.git")

def install(self, package: str):
@staticmethod
def install(package: str):
    """Install ``package`` with pip, using the interpreter running this code.

    Args:
        package: Requirement specifier or VCS URL handed to ``pip install``.

    Raises:
        subprocess.CalledProcessError: If the pip process exits non-zero.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package, "--quiet"]
    subprocess.check_call(pip_command)

def register(self, registered_model_name: str):
    """Log the Moirai pyfunc wrapper with MLflow and register it.

    A ``MoiraiModel`` is built from ``self.repo`` and the
    ``prediction_length``, ``patch_size`` and ``num_samples`` entries of
    ``self.params``, then logged under the artifact path ``"model"`` via
    ``mlflow.pyfunc.log_model``.

    Args:
        registered_model_name: Name under which the logged model is
            registered in the model registry.
    """
    pipeline = MoiraiModel(
        self.repo,
        self.params["prediction_length"],
        self.params["patch_size"],
        self.params["num_samples"],
    )
    # Input: one univariate series of doubles, any length.
    input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
    # Output: a 1-D real-valued forecast. The signature previously declared
    # uint8, which cannot describe the float median that MoiraiModel.predict
    # returns. NOTE(review): float32 assumes the model emits float32 samples
    # (its input is cast to float32) -- confirm.
    output_schema = Schema([TensorSpec(np.dtype(np.float32), (-1,))])
    signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    # Random series of 52 points, used only as the logged input example.
    input_example = np.random.rand(52)
    mlflow.pyfunc.log_model(
        "model",
        python_model=pipeline,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        pip_requirements=[
            "git+https://github.com/SalesforceAIResearch/uni2ts.git",
            "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git",
            "pyspark==3.5.0",
        ],
    )

def create_horizon_timestamps_udf(self):
@pandas_udf('array<timestamp>')
def horizon_timestamps_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
Expand Down Expand Up @@ -172,3 +200,40 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.repo = "Salesforce/moirai-1.0-R-base"


class MoiraiModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that returns a median Moirai forecast for one series."""

    def __init__(self, repository, prediction_length, patch_size, num_samples):
        """Load the pretrained Moirai module and store the forecast settings.

        Args:
            repository: Identifier passed to ``MoiraiModule.from_pretrained``.
            prediction_length: Forecast horizon passed to ``MoiraiForecast``.
            patch_size: Patch size passed to ``MoiraiForecast``.
            num_samples: Number of sample paths passed to ``MoiraiForecast``.
        """
        # uni2ts is imported inside the method, as in the original code;
        # presumably because it is pip-installed at runtime -- confirm.
        from uni2ts.model.moirai import MoiraiModule

        self.repository = repository
        self.prediction_length = prediction_length
        self.patch_size = patch_size
        self.num_samples = num_samples
        self.module = MoiraiModule.from_pretrained(self.repository)
        # Built in predict(), because the context length depends on the input.
        self.pipeline = None

    def predict(self, context, input_data, params=None):
        """Forecast ``prediction_length`` steps ahead for a single series.

        Args:
            context: MLflow model context (unused).
            input_data: 1-D sequence of historical values.
            params: Optional inference parameters (unused).

        Returns:
            ``numpy.ndarray`` holding the median over the sampled forecasts.
        """
        from einops import rearrange
        from uni2ts.model.moirai import MoiraiForecast

        self.pipeline = MoiraiForecast(
            module=self.module,
            prediction_length=self.prediction_length,
            context_length=len(input_data),
            patch_size=self.patch_size,
            num_samples=self.num_samples,
            target_dim=1,
            feat_dynamic_real_dim=0,
            past_feat_dynamic_real_dim=0,
        )
        # Time series values. Shape: (batch, time, variate)
        past_target = rearrange(torch.as_tensor(input_data, dtype=torch.float32), "t -> 1 t 1")
        # 1s if the value is observed, 0s otherwise. Shape: (batch, time, variate)
        past_observed_target = torch.ones_like(past_target, dtype=torch.bool)
        # 1s if the value is padding, 0s otherwise. Shape: (batch, time)
        past_is_pad = torch.zeros_like(past_target, dtype=torch.bool).squeeze(-1)
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            forecast = self.pipeline(
                past_target=past_target,
                past_observed_target=past_observed_target,
                past_is_pad=past_is_pad,
            )
        # Convert explicitly so NumPy never sees a CUDA or grad-tracking tensor.
        # Presumably forecast[0] is (num_samples, prediction_length) -- confirm.
        samples = forecast[0].detach().cpu().numpy()
        return np.median(samples, axis=0)
59 changes: 58 additions & 1 deletion mmf_sa/models/momentforecast/MomentPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import pandas as pd
import numpy as np
import torch
import mlflow
from mlflow.types import Schema, TensorSpec
from mlflow.models.signature import ModelSignature
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
from typing import Iterator
from pyspark.sql.functions import collect_list, pandas_udf
Expand All @@ -19,9 +22,32 @@ def __init__(self, params):
self.model = None
self.install("git+https://github.com/moment-timeseries-foundation-model/moment.git")

def install(self, package: str):
@staticmethod
def install(package: str):
    """Run ``pip install <package> --quiet`` with the current Python interpreter.

    Args:
        package: Requirement specifier or VCS URL to install.

    Raises:
        subprocess.CalledProcessError: If pip returns a non-zero exit status.
    """
    interpreter = sys.executable
    subprocess.check_call(
        [interpreter, "-m", "pip", "install", package, "--quiet"]
    )

def register(self, registered_model_name: str):
    """Log the MOMENT pyfunc wrapper with MLflow and register it.

    A ``MomentModel`` is built from ``self.repo`` and
    ``self.params["prediction_length"]``, then logged under the artifact
    path ``"model"`` via ``mlflow.pyfunc.log_model``.

    Args:
        registered_model_name: Name under which the logged model is
            registered in the model registry.
    """
    pipeline = MomentModel(
        self.repo,
        self.params["prediction_length"],
    )
    # Input: one univariate series of doubles, any length.
    input_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
    # Output: a 1-D real-valued forecast. The signature previously declared
    # uint8, which cannot describe the list of Python floats that
    # MomentModel.predict returns.
    output_schema = Schema([TensorSpec(np.dtype(np.double), (-1,))])
    signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    # Random series of 52 points, used only as the logged input example.
    input_example = np.random.rand(52)
    mlflow.pyfunc.log_model(
        "model",
        python_model=pipeline,
        registered_model_name=registered_model_name,
        signature=signature,
        input_example=input_example,
        pip_requirements=[
            "git+https://github.com/moment-timeseries-foundation-model/moment.git",
            "git+https://github.com/databricks-industry-solutions/many-model-forecasting.git",
            "pyspark==3.5.0",
        ],
    )

def create_horizon_timestamps_udf(self):
@pandas_udf('array<timestamp>')
def horizon_timestamps_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
Expand Down Expand Up @@ -157,3 +183,34 @@ def __init__(self, params):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.repo = "AutonLab/MOMENT-1-large"


class MomentModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that forecasts one series with a MOMENT pipeline."""

    def __init__(self, repository, prediction_length):
        """Load the pretrained MOMENT pipeline configured for forecasting.

        Args:
            repository: Identifier passed to ``MOMENTPipeline.from_pretrained``.
            prediction_length: Forecast horizon, passed as ``forecast_horizon``.
        """
        # momentfm is imported inside the method, as in the original code;
        # presumably because it is pip-installed at runtime -- confirm.
        from momentfm import MOMENTPipeline

        self.repository = repository
        self.prediction_length = prediction_length
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipeline = MOMENTPipeline.from_pretrained(
            self.repository,
            device_map=self.device,
            model_kwargs={
                "task_name": "forecasting",
                "forecast_horizon": self.prediction_length,
            },
        )
        self.pipeline.init()
        self.pipeline = self.pipeline.to(self.device)

    def predict(self, context, input_data, params=None):
        """Forecast ``prediction_length`` steps ahead for a single series.

        The input is brought to exactly 512 points: shorter series are
        zero-padded on the right with those positions masked out, and longer
        series keep only their last 512 values.

        Args:
            context: MLflow model context (unused).
            input_data: 1-D iterable of historical values.
            params: Optional inference parameters (unused).

        Returns:
            The forecast as a list of floats. It is always a list, including
            when the horizon is a single step.
        """
        # Fixed window length fed to the pipeline.
        context_length = 512
        series = list(input_data)
        observed = min(len(series), context_length)
        padding = context_length - observed
        # Mask is 1 for real observations and 0 for the zero padding.
        input_mask = [1] * observed + [0] * padding
        series = series[-context_length:] + [0] * padding
        input_mask = torch.reshape(
            torch.tensor(input_mask), (1, context_length)
        ).to(self.device)
        series = (
            torch.reshape(torch.tensor(series), (1, 1, context_length))
            .to(dtype=torch.float32)
            .to(self.device)
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = self.pipeline(series, input_mask=input_mask)
        # reshape(-1) rather than squeeze(): squeeze() collapses a one-step
        # forecast to a 0-d tensor, whose tolist() is a bare float.
        # Presumably output.forecast is (1, 1, horizon) here -- confirm.
        return output.forecast.reshape(-1).tolist()

0 comments on commit d830c69

Please sign in to comment.