
Model validation for FS models #165

Merged (12 commits) on Jul 25, 2024
2 changes: 1 addition & 1 deletion dev-requirements.txt
@@ -1,3 +1,3 @@
 pytest==7.4.4
 pytest-black
-mlflow==2.0.1
+mlflow==2.11.3
@@ -77,7 +77,7 @@ assert model_name != "", "model_name notebook parameter must be specified"
 {{- if (eq .input_include_models_in_unity_catalog "no") }}
 stage = get_deployed_model_stage_for_env(env)
 model_uri = f"models:/{model_name}/{stage}"{{else}}
-alias = "Champion"
+alias = "champion"
 model_uri = f"models:/{model_name}@{alias}"{{end}}

 # COMMAND ----------
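With Unity Catalog registries the alias is part of the model URI that downstream consumers resolve, so the lowercase rename has to be consistent everywhere the model is loaded. A minimal sketch of such a consumer, assuming a hypothetical three-level model name (not part of this PR):

import mlflow

mlflow.set_registry_uri("databricks-uc")
model_name = "dev.my_schema.my_model"  # hypothetical catalog.schema.model name
# Resolves to whichever version currently holds the "champion" alias.
champion_model = mlflow.pyfunc.load_model(f"models:/{model_name}@champion")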
@@ -41,20 +41,20 @@ def deploy(model_uri, env):
     _, model_name, version = model_uri.split("/")
     client = MlflowClient(registry_uri="databricks-uc")
     mv = client.get_model_version(model_name, version)
-    target_alias = "Champion"
+    target_alias = "champion"
     if target_alias not in mv.aliases:
         client.set_registered_model_alias(
             name=model_name,
             alias=target_alias,
             version=version)
         print(f"Assigned alias '{target_alias}' to model version {model_uri}.")

-        # remove "Challenger" alias if assigning "Champion" alias
-        if target_alias == "Champion" and "Challenger" in mv.aliases:
-            print(f"Removing 'Challenger' alias from model version {model_uri}.")
+        # remove "challenger" alias if assigning "champion" alias
+        if target_alias == "champion" and "challenger" in mv.aliases:
+            print(f"Removing 'challenger' alias from model version {model_uri}.")
             client.delete_registered_model_alias(
                 name=model_name,
-                alias="Challenger")
+                alias="challenger")
 {{end}}


@@ -1,4 +1,4 @@
-mlflow==2.7.1
+mlflow==2.11.3
 numpy>=1.23.0
 pandas==1.5.3
 scikit-learn>=1.1.1
@@ -188,7 +188,7 @@ To get started, open `{{template `project_name_alphanumeric_underscore` .}}/reso
 new_cluster: &new_cluster
   new_cluster:
     num_workers: 3
-    spark_version: 13.3.x-cpu-ml-scala2.12
+    spark_version: 15.3.x-cpu-ml-scala2.12
     node_type_id: {{template `cloud_specific_node_type_id` .}}
     custom_tags:
       clusterSource: mlops-stacks_{{template `stacks_version` .}}
@@ -243,7 +243,7 @@ targets:
 new_cluster: &new_cluster
   new_cluster:
     num_workers: 3
-    spark_version: 13.3.x-cpu-ml-scala2.12
+    spark_version: 15.3.x-cpu-ml-scala2.12
     node_type_id: {{template `cloud_specific_node_type_id` .}}
     custom_tags:
       clusterSource: mlops-stacks_{{template `stacks_version` .}}
@@ -1,7 +1,7 @@
 new_cluster: &new_cluster
   new_cluster:
     num_workers: 3
-    spark_version: 13.3.x-cpu-ml-scala2.12
+    spark_version: 15.3.x-cpu-ml-scala2.12
     node_type_id: {{template `cloud_specific_node_type_id` .}}
     custom_tags:
       clusterSource: mlops-stacks_{{template `stacks_version` .}}
@@ -1,7 +1,7 @@
 new_cluster: &new_cluster
   new_cluster:
     num_workers: 3
-    spark_version: 13.3.x-cpu-ml-scala2.12
+    spark_version: 15.3.x-cpu-ml-scala2.12
     node_type_id: {{template `cloud_specific_node_type_id` .}}
     custom_tags:
       clusterSource: mlops-stacks_{{template `stacks_version` .}}
@@ -1,7 +1,7 @@
 new_cluster: &new_cluster
   new_cluster:
     num_workers: 3
-    spark_version: 13.3.x-cpu-ml-scala2.12
+    spark_version: 15.3.x-cpu-ml-scala2.12
     node_type_id: {{template `cloud_specific_node_type_id` .}}
     custom_tags:
       clusterSource: mlops-stacks_{{template `stacks_version` .}}
@@ -5,7 +5,7 @@
 new_cluster: &new_cluster
   new_cluster:
     num_workers: 3
-    spark_version: 13.3.x-cpu-ml-scala2.12
+    spark_version: 15.3.x-cpu-ml-scala2.12
     node_type_id: {{template `cloud_specific_node_type_id` .}}
     custom_tags:
       clusterSource: mlops-stacks_{{template `stacks_version` .}}
@@ -3,7 +3,7 @@
 # Model Validation Notebook
 ##
 # This notebook uses the MLflow model validation API to run model validation after training and registering a model
-# in model registry, before deploying it to the {{- if (eq .input_include_models_in_unity_catalog "no") }}"Production" stage{{else}} "Champion" alias{{end -}}.
+# in model registry, before deploying it to the {{- if (eq .input_include_models_in_unity_catalog "no") }}"Production" stage{{else}} "champion" alias{{end -}}.
 #
 # It runs as part of CD and by an automated model training job -> validation -> deployment job defined under ``{{template `project_name_alphanumeric_underscore` .}}/resources/model-workflow-resource.yml``
 #
@@ -14,13 +14,13 @@
 # * `run_mode` - The `run_mode` defines whether model validation is enabled or not. It can be one of the three values:
 #   * `disabled` : Do not run the model validation notebook.
 #   * `dry_run`  : Run the model validation notebook. Ignore failed model validation rules and proceed to move
-#                  model to the {{- if (eq .input_include_models_in_unity_catalog "no") }}"Production" stage{{else}} "Champion" alias{{end -}}.
-#   * `enabled`  : Run the model validation notebook. Move model to the {{- if (eq .input_include_models_in_unity_catalog "no") }} "Production" stage {{else}} "Champion" alias {{end -}} only if all model validation
+#                  model to the {{- if (eq .input_include_models_in_unity_catalog "no") }}"Production" stage{{else}} "champion" alias{{end -}}.
+#   * `enabled`  : Run the model validation notebook. Move model to the {{- if (eq .input_include_models_in_unity_catalog "no") }} "Production" stage {{else}} "champion" alias {{end -}} only if all model validation
 #                  rules are passing.
 {{- if (eq .input_include_models_in_unity_catalog "no") }}
 # * enable_baseline_comparison - Whether to load the current registered "Production" stage model as baseline.
 {{else}}
-# * enable_baseline_comparison - Whether to load the current registered "Champion" model as baseline.
+# * enable_baseline_comparison - Whether to load the current registered "champion" model as baseline.
 {{end -}}
 # Baseline model is a requirement for relative change and absolute change validation thresholds.
 # * validation_input - Validation input. Please refer to data parameter in mlflow.evaluate documentation https://mlflow.org/docs/latest/python_api/mlflow.html#mlflow.evaluate
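To make the three `run_mode` values concrete, the gating amounts to roughly the following (a sketch, not the notebook's exact control flow; run_validation is a hypothetical stand-in for the mlflow.evaluate call further down):

run_mode = dbutils.widgets.get("run_mode").lower()

if run_mode == "disabled":
    print("Model validation is disabled; skipping.")
    dbutils.notebook.exit(0)

try:
    run_validation()  # stand-in for the evaluation and threshold checks below
except Exception:
    if run_mode == "enabled":
        raise  # enabled: a failed validation rule blocks deployment
    print("dry_run: ignoring failed validation rules and proceeding.")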
@@ -89,12 +89,6 @@ dbutils.widgets.text("model_name", "dev.{{ .input_schema_name }}.{{template `mod
 dbutils.widgets.text("model_version", "", "Candidate Model Version")

 # COMMAND ----------
-{{ if (eq .input_include_feature_store `yes`) }}
-print(
-    "Currently model validation is not supported for models registered with feature store. Please refer to "
-    "issue https://github.com/databricks/mlops-stacks/issues/70 for more details."
-)
-dbutils.notebook.exit(0){{ end }}
 run_mode = dbutils.widgets.get("run_mode").lower()
 assert run_mode == "disabled" or run_mode == "dry_run" or run_mode == "enabled"

@@ -172,7 +166,7 @@ if model_uri == "":
 {{ if (eq .input_include_models_in_unity_catalog "no") }}
 baseline_model_uri = "models:/" + model_name + "/Production"
 {{else}}
-baseline_model_uri = "models:/" + model_name + "@Champion"
+baseline_model_uri = "models:/" + model_name + "@champion"
 {{ end }}
 evaluators = "default"
 assert model_uri != "", "model_uri notebook parameter must be specified"
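The baseline URI matters because relative- and absolute-change validation thresholds are computed against the baseline model's metrics. A sketch of such a threshold, mirroring the commented-out options in validation.py further down (the variable name is illustrative):

from mlflow.models import MetricThreshold

# Candidate MSE must be at least 5% lower than the baseline's MSE.
mse_vs_baseline = MetricThreshold(
    min_relative_change=0.05,
    higher_is_better=False,
)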
@@ -183,6 +177,14 @@ assert model_version != "", "model_version notebook parameter must be specified"

 # take input
 enable_baseline_comparison = dbutils.widgets.get("enable_baseline_comparison")
+
+{{ if (eq .input_include_feature_store `yes`) }}
+enable_baseline_comparison = "false"
+print(
+    "Currently baseline model comparison is not supported for models registered with feature store. Please refer to "
+    "issue https://github.com/databricks/mlops-stacks/issues/70 for more details."
+)
+{{ end }}
 assert enable_baseline_comparison == "true" or enable_baseline_comparison == "false"
 enable_baseline_comparison = enable_baseline_comparison == "true"

@@ -255,8 +257,76 @@ def log_to_model_description(run, success):
         name=model_name, version=model_version, description=description
     )

+{{ if (eq .input_include_feature_store `yes`) }}
+
+from datetime import timedelta, timezone
+import math
+import pyspark.sql.functions as F
+from pyspark.sql.types import IntegerType
+
+
+def rounded_unix_timestamp(dt, num_minutes=15):
+    """
+    Rounds datetime dt up to the next num_minutes interval, then returns the unix timestamp.
+    """
+    nsecs = dt.minute * 60 + dt.second + dt.microsecond * 1e-6
+    delta = math.ceil(nsecs / (60 * num_minutes)) * (60 * num_minutes) - nsecs
+    return int((dt + timedelta(seconds=delta)).replace(tzinfo=timezone.utc).timestamp())
+
+
+rounded_unix_timestamp_udf = F.udf(rounded_unix_timestamp, IntegerType())
+
+
+def rounded_taxi_data(taxi_data_df):
+    # Round the taxi data timestamps to 15 and 30 minute intervals so we can join with the pickup and dropoff features
+    # respectively.
+    taxi_data_df = (
+        taxi_data_df.withColumn(
+            "rounded_pickup_datetime",
+            F.to_timestamp(
+                rounded_unix_timestamp_udf(
+                    taxi_data_df["tpep_pickup_datetime"], F.lit(15)
+                )
+            ),
+        )
+        .withColumn(
+            "rounded_dropoff_datetime",
+            F.to_timestamp(
+                rounded_unix_timestamp_udf(
+                    taxi_data_df["tpep_dropoff_datetime"], F.lit(30)
+                )
+            ),
+        )
+        .drop("tpep_pickup_datetime")
+        .drop("tpep_dropoff_datetime")
+    )
+    taxi_data_df.createOrReplaceTempView("taxi_data")
+    return taxi_data_df
+
+
+data = rounded_taxi_data(data)
+
+{{ end }}
+
 # COMMAND ----------
+
+{{ if (eq .input_include_feature_store `yes`) }}
+# Temporary fix as FS model can't predict as a pyfunc model.
+# mlflow.evaluate can take a function in place of a model URI,
+# but it does not work for the baseline model, which requires a model_uri (so baseline comparison is set to false).
+
+from databricks.feature_store import FeatureStoreClient
+
+def get_fs_model(df):
+    fs_client = FeatureStoreClient()
+    return (
+        fs_client.score_batch(model_uri, spark.createDataFrame(df))
+        .select("prediction")
+        .toPandas()
+    )
+{{ end }}

 training_run = get_training_run(model_name, model_version)

 # run evaluate
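As a quick sanity check of rounded_unix_timestamp above (a hypothetical input, not from the PR): 16:52:13 is 3,133 seconds past the hour, which ceils to the 3,600-second (15-minute) boundary, i.e. 17:00:00 UTC:

from datetime import datetime

ts = rounded_unix_timestamp(datetime(2016, 2, 14, 16, 52, 13), num_minutes=15)
print(ts)  # 1455469200, i.e. 2016-02-14 17:00:00 UTC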
@@ -277,7 +347,11 @@ with mlflow.start_run(

     try:
         eval_result = mlflow.evaluate(
+{{ if (eq .input_include_feature_store `yes`) }}
+            model=get_fs_model,
+{{ else }}
             model=model_uri,
+{{ end }}
             data=data,
             targets=targets,
             model_type=model_type,
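The workaround hinges on mlflow.evaluate accepting a plain Python function in place of a model URI. A self-contained sketch of that pattern with toy data (none of these names come from the notebook):

import mlflow
import pandas as pd

def predict_fn(df: pd.DataFrame):
    # Stand-in for get_fs_model: any callable from input DataFrame to predictions.
    return df["feature"] * 2.0

toy = pd.DataFrame({"feature": [1.0, 2.0, 3.0], "label": [2.0, 4.1, 5.9]})
result = mlflow.evaluate(
    model=predict_fn,
    data=toy,
    targets="label",
    model_type="regressor",
)
print(result.metrics["mean_squared_error"])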
@@ -312,9 +386,9 @@ with mlflow.start_run(
         mlflow.log_artifact(metrics_file)
         log_to_model_description(run, True)
 {{ if (eq .input_include_models_in_unity_catalog "yes") }}
-        # Assign "Challenger" alias to indicate model version has passed validation checks
-        print("Validation checks passed. Assigning 'Challenger' alias to model version.")
-        client.set_registered_model_alias(model_name, "Challenger", model_version)
+        # Assign "challenger" alias to indicate model version has passed validation checks
+        print("Validation checks passed. Assigning 'challenger' alias to model version.")
+        client.set_registered_model_alias(model_name, "challenger", model_version)
 {{ end }}
     except Exception as err:
         log_to_model_description(run, False)
@@ -26,7 +26,7 @@ def validation_thresholds():
             threshold=500, higher_is_better=False  # max_error should be <= 500
         ),
         "mean_squared_error": MetricThreshold(
-            threshold=20,  # mean_squared_error should be <= 20
+            threshold=500,  # mean_squared_error should be <= 500
             # min_absolute_change=0.01,  # candidate mean_squared_error should improve on the baseline by at least 0.01
             # min_relative_change=0.01,  # candidate mean_squared_error should improve on the baseline by at least 1 percent
             higher_is_better=False,
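For context, these thresholds are enforced inside mlflow.evaluate; the wiring looks roughly like this (a sketch assuming MLflow 2.x's validation_thresholds argument, with model_uri, data, and targets as defined in the notebook):

eval_result = mlflow.evaluate(
    model=model_uri,
    data=data,
    targets=targets,
    model_type="regressor",
    validation_thresholds=validation_thresholds(),
)
# Raises an mlflow.exceptions.MlflowException subclass if any threshold fails,
# which the notebook's run_mode logic turns into a hard failure or a dry-run warning.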
6 changes: 3 additions & 3 deletions tests/test_create_project.py
@@ -233,9 +233,9 @@ def prepareContext(
     if include_mlflow_recipes != "":
         context["input_include_mlflow_recipes"] = include_mlflow_recipes
     if include_models_in_unity_catalog != "":
-        context[
-            "input_include_models_in_unity_catalog"
-        ] = include_models_in_unity_catalog
+        context["input_include_models_in_unity_catalog"] = (
+            include_models_in_unity_catalog
+        )
     return context

