Skip to content

Commit

Permalink
Add fix for model validation with Feature Store models
Browse files Browse the repository at this point in the history
  • Loading branch information
aliazzzdat committed Jul 17, 2024
1 parent cf07d13 commit c4a52e4
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,6 @@ dbutils.widgets.text("model_name", "dev.{{ .input_schema_name }}.{{template `mod
dbutils.widgets.text("model_version", "", "Candidate Model Version")

# COMMAND ----------
{{ if (eq .input_include_feature_store `yes`) }}
print(
"Currently model validation is not supported for models registered with feature store. Please refer to "
"issue https://github.com/databricks/mlops-stacks/issues/70 for more details."
)
dbutils.notebook.exit(0){{ end }}
run_mode = dbutils.widgets.get("run_mode").lower()
assert run_mode == "disabled" or run_mode == "dry_run" or run_mode == "enabled"

Expand Down Expand Up @@ -183,6 +177,14 @@ assert model_version != "", "model_version notebook parameter must be specified"

# take input
enable_baseline_comparison = dbutils.widgets.get("enable_baseline_comparison")

{{ if (eq .input_include_feature_store `yes`) }}
enable_baseline_comparison = "false"
print(
"Currently baseline model comparison is not supported for models registered with feature store. Please refer to "
"issue https://github.com/databricks/mlops-stacks/issues/70 for more details."
)
{{ end }}
assert enable_baseline_comparison == "true" or enable_baseline_comparison == "false"
enable_baseline_comparison = enable_baseline_comparison == "true"

Expand Down Expand Up @@ -257,6 +259,26 @@ def log_to_model_description(run, success):

# COMMAND ----------

{{ if (eq .input_include_feature_store `yes`) }}
###########################################################################################
# Temporary fix as FS model can't predict as pyfunc model #
# MLflow evaluate can take a lambda function instead of model uri for model #
# but not for baseline model it requires model_uri (baseline comparison set to false) #
###########################################################################################

from databricks.feature_store import FeatureStoreClient

def get_fs_model(df):
fs_client = FeatureStoreClient()
return fs_client.score_batch(
model_uri,
spark.createDataFrame(df)
).select('prediction').toPandas()

###########################################################################################

{{ end }}

training_run = get_training_run(model_name, model_version)

# run evaluate
Expand All @@ -277,7 +299,11 @@ with mlflow.start_run(

try:
eval_result = mlflow.evaluate(
{{ if (eq .input_include_feature_store `yes`) }}
get_fs_model,
{{ else }}
model=model_uri,
{{ end }}
data=data,
targets=targets,
model_type=model_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def validation_thresholds():
threshold=500, higher_is_better=False # max_error should be <= 500
),
"mean_squared_error": MetricThreshold(
threshold=20, # mean_squared_error should be <= 20
threshold=200, # mean_squared_error should be <= 200
# min_absolute_change=0.01, # mean_squared_error should be at least 0.01 greater than baseline model accuracy
# min_relative_change=0.01, # mean_squared_error should be at least 1 percent greater than baseline model accuracy
higher_is_better=False,
Expand Down

0 comments on commit c4a52e4

Please sign in to comment.