[BUG] Cannot Load LightGBM Model When Placed in a Spark Pipeline with Custom Transformers #2293

Open

dsmith111 opened this issue Oct 6, 2024 · 0 comments
dsmith111 commented Oct 6, 2024

SynapseML version

1.0.5

System information

  • Language version: Python 3.11.0rc1 and Scala version 2.12.15
  • Spark Version: 3.5.0
  • Spark Platform: Databricks

Describe the problem

When attempting to create a Spark pipeline that contains both a custom transformer and a LightGBM model, loading the saved pipeline fails with "AttributeError: module 'com.microsoft.azure.synapse.ml.lightgbm' has no attribute". Note: saving and loading work fine if either the custom transformer or the LightGBM model is absent; it is the combination of the two that triggers the issue.

Related to issues #614 and #1701.

Code to reproduce issue

Example Custom Transformers

from pyspark.ml import Transformer
from pyspark.ml.util import MLReadable, MLWritable, DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml.param.shared import HasInputCol, HasOutputCol
from pyspark.ml.param import Param, Params
import pyspark.sql.functions as f 
from pyspark.ml.linalg import VectorUDT, Vectors
from pyspark.sql.types import *
import json

class ColumnSelector(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """
    Custom Transformer to select and rename columns from a DataFrame.
    Enhanced to be MLWritable.
    """
    
    def __init__(self, selectExpr=None):
        super(ColumnSelector, self).__init__()
        self.selectExpr = Param(self, "selectExpr", "The SQL expression used for selecting and renaming columns")
        self._setDefault(selectExpr=selectExpr)
        if selectExpr is not None:
            self.setSelectExpr(selectExpr)

    def setSelectExpr(self, value):
        """
        Sets the SQL expression for selecting and renaming columns.
        """
        return self._set(selectExpr=value)

    def getSelectExpr(self):
        """
        Gets the current SQL expression for selecting and renaming columns.
        """
        return self.getOrDefault(self.selectExpr)

    def _transform(self, dataset):
        """
        The method that defines the operations to produce the selected and renamed columns.
        """
        return dataset.selectExpr(*self.getSelectExpr())


class StringArrayToVectorTransformer(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """
    Custom Transformer that converts a string representation of an integer array to a VectorUDT.
    Enhanced to be MLWritable.
    """
    
    def __init__(self, inputCol=None, outputCol=None):
        super(StringArrayToVectorTransformer, self).__init__()
        self.inputCol = Param(self, "inputCol", "The input column which is a string representation of an array")
        self.outputCol = Param(self, "outputCol", "The output column which will be a Dense Vector")
        self._setDefault(inputCol=inputCol, outputCol=outputCol)
        if inputCol is not None:
            self.setInputCol(inputCol)
        if outputCol is not None:
            self.setOutputCol(outputCol)

    def setInputCol(self, value):
        """
        Sets the value of `inputCol`.
        """
        return self._set(inputCol=value)

    def setOutputCol(self, value):
        """
        Sets the value of `outputCol`.
        """
        return self._set(outputCol=value)

    def _transform(self, dataset):
        """
        The method that defines the operations to produce the `outputCol` from `inputCol`.
        Converts string array "[1,2,3]" to a DenseVector.
        """
        # Minimal implementation so the repro runs: parse the JSON-style string
        # (e.g. "[1,2,3]") into a DenseVector written to `outputCol`.
        to_vector = f.udf(lambda s: Vectors.dense(json.loads(s)), VectorUDT())
        return dataset.withColumn(self.getOutputCol(), to_vector(f.col(self.getInputCol())))

    def getOutputCol(self):
        return self.getOrDefault(self.outputCol)

    def getInputCol(self):
        return self.getOrDefault(self.inputCol)

Repro Code

from pyspark.ml import Pipeline
from pyspark.ml.pipeline import PipelineModel
import synapse.ml.lightgbm as lgbm

import CustomTransformers # Or just include the classes directly

string_array_to_vector = CustomTransformers.StringArrayToVectorTransformer(inputCol="embedded_object_keys_string", outputCol="features")
select_columns = CustomTransformers.ColumnSelector(selectExpr=["objectKey", "PreciseTimeStamp", "prediction"])

# load any training_df
lgbm_model = lgbm.LightGBMClassifier(featuresCol="features", labelCol="label").fit(string_array_to_vector.transform(training_df))

pipeline = Pipeline(
    stages=[
        string_array_to_vector,
        lgbm_model,
        select_columns
        ]
    )

pipeline.write().overwrite().save("/tmp/pipeline")
reloaded_pipe = PipelineModel.load("/tmp/pipeline") # Fails with the attribute error here

Other info / logs

More Logs for Attribute Error

return PipelineModel.load(model_uri)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/util.py", line 465, in load
    return cls.read().load(path)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/pipeline.py", line 288, in load
    uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/pipeline.py", line 442, in load
    stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/util.py", line 749, in loadParamsInstance
    py_type: Type[RL] = DefaultParamsReader.__get_class(pythonClassName)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/util.py", line 655, in __get_class
    return getattr(m, parts[-1])
           ^^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'com.microsoft.azure.synapse.ml.lightgbm' has no attribute 'LightGBMClassificationModel'

Additionally, quoting issue #1701:

In this case the PipelineModel.write method returned a non java writer. The classes synapse.ml.lightgbm.LightGBMClassifier and synapse.ml.lightgbm.LightGBMRegressor inherit correct java reader (pyspark.ml.util.JavaMLReadable) and writer (pyspark.ml.util.JavaMLWritable). The problem is with the superclass synapse.ml.core.schema.Utils.ComplexParamsMixin that inherits only from the pyspark.ml.util.MLReadable.

I could bypass the problem by wrapping the estimator with the pyspark.ml.Pipeline. In this situation the write method of the last stage will return the JavaMLWriter not the PipelineModelWriter.

This seems to be related to an incorrect writer reference. However, my custom transformers inherit from the relevant PySpark ML classes and should handle this correctly.
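
For reference, a minimal sketch of the bypass described in that quote, reusing the repro variable names from above (the lgbm_estimator name is mine and this sketch is untested here): wrap the unfitted estimator in a pyspark.ml.Pipeline so that the fitted last stage is written with the JavaMLWriter rather than the default params writer.

from pyspark.ml import Pipeline
import synapse.ml.lightgbm as lgbm

# Wrap the unfitted LightGBM estimator in its own Pipeline; fitting it yields a
# PipelineModel whose last stage is written via JavaMLWriter, per the quote above.
lgbm_estimator = lgbm.LightGBMClassifier(featuresCol="features", labelCol="label")
wrapped_lgbm_model = Pipeline(stages=[lgbm_estimator]).fit(
    string_array_to_vector.transform(training_df)
)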

Workarounds
Building on those issues, I've been able to create some workarounds, but they turn out to be insufficient in certain contexts. The first workaround is to fit the LightGBM model and then nest it inside another pipeline as a PipelineModel. This allows the pipeline to be saved and loaded from the interactive Spark driver just fine.

pipeline = Pipeline(
    stages=[
        string_array_to_vector,
        PipelineModel(stages=[lgbm_model]),
        select_columns
        ]
    )

The second workaround is to intercept the class lookup, like so:

from pyspark.ml.util import DefaultParamsReader
try:
    from unittest import mock
except ImportError:
    # For Python 2 you might have to pip install mock
    import mock

class MmlShim(object):
    mangled_name = '_DefaultParamsReader__get_class'
    prev_get_clazz = getattr(DefaultParamsReader, mangled_name)

    @classmethod
    def __get_class(cls, clazz):
        try:
            return cls.prev_get_clazz(clazz)
        except AttributeError as outer:
            try:
                alt_clazz = clazz.replace('com.microsoft.azure.synapse', 'synapse')
                return cls.prev_get_clazz(alt_clazz)
            except AttributeError:
                raise outer

    def __enter__(self):
        self.mock = mock.patch.object(DefaultParamsReader, self.mangled_name, self.__get_class)
        self.mock.__enter__()
        return self

    def __exit__(self, *exc_info):
        self.mock.__exit__(*exc_info)

with MmlShim():
    reloaded_pipe = PipelineModel.load("/tmp/pipeline")

However, this workaround breaks down when I attempt to use it with Databricks-specific commands such as "score_batch": when that command is run, the attribute error returns. I assume this is because the other tasks spin up new Python instances without these fixes applied, so I set up an initialization script to patch the library at launch to see if it would resolve the problem:

synapse_init_patch.sh

#!/bin/bash

# Locate the pyspark ml util.py file
UTIL_PY=$(find /databricks -name util.py | grep pyspark/ml/util.py)

# Backup the original file
cp $UTIL_PY ${UTIL_PY}.bak

# Modify the file using sed or awk to insert the patch
sed -i "/def __get_class(clazz: str) -> Type\[RL\]:/a \\
        try:\\
            parts = clazz.split('.')\\
            module = '.'.join(parts[:-1])\\
            m = __import__(module, fromlist=[parts[-1]])\\
            return getattr(m, parts[-1])\\
        except AttributeError:\\
            if 'com.microsoft.azure.synapse' in clazz:\\
                clazz = clazz.replace('com.microsoft.azure.synapse', 'synapse')\\
            parts = clazz.split('.')\\
            module = '.'.join(parts[:-1])\\
            m = __import__(module, fromlist=[parts[-1]])\\
            return getattr(m, parts[-1])\\
            # Ignore the rest" $UTIL_PY


exit 0

This gets past the attribute error; however, it results in a new error:

                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/python/lib/python3.11/site-packages/mlflow/spark/__init__.py", line 836, in _load_model
    return PipelineModel.load(model_uri)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/util.py", line 465, in load
    return cls.read().load(path)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/pipeline.py", line 288, in load
    uid, stages = PipelineSharedReadWrite.load(metadata, self.sc, path)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/pipeline.py", line 442, in load
    stage: "PipelineStage" = DefaultParamsReader.loadParamsInstance(stagePath, sc)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/util.py", line 763, in loadParamsInstance
    instance = py_type.load(path)
               ^^^^^^^^^^^^^^^^^^
  File "/databricks/spark/python/pyspark/ml/util.py", line 465, in load
    return cls.read().load(path)
           ^^^^^^^^^^
  File "/local_disk0/spark-ff9371a9-07c1-49fb-99e9-490767d2edf5/userFiles-8d89c9e1-eef5-4d7d-a724-0a863d9d2d54/com_microsoft_azure_synapseml_lightgbm_2_12_1_0_5.jar/synapse/ml/lightgbm/_LightGBMClassificationModel.py", line 142, in read
    return JavaMMLReader(cls)
           ^^^^^^^^^^^^^^^^^^
  File "/local_disk0/spark-ff9371a9-07c1-49fb-99e9-490767d2edf5/userFiles-8d89c9e1-eef5-4d7d-a724-0a863d9d2d54/com_microsoft_azure_synapseml_core_2_12_1_0_5.jar/synapse/ml/core/schema/Utils.py", line 149, in __init__
    super(JavaMMLReader, self).__init__(clazz)
  File "/databricks/spark/python/pyspark/ml/util.py", line 408, in __init__
    self._jread = self._load_java_obj(clazz).read()
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'JavaPackage' object is not callable

This issue is a hard blocker for my project. Even if a patch isn't feasible in the short term, I'm at least looking to get a monkey-patch init script working so the project can progress until this issue is resolved upstream.
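
As a point of reference, here is a sketch of how the same MmlShim lookup fallback could be packaged as an importable module instead of sed-editing util.py, so an init script only needs to install it and have each Python process import it on startup. The synapse_shim module name is mine, this is untested, and it only addresses the AttributeError, not the JavaPackage error above.

# synapse_shim.py (hypothetical module name) -- import once per Python process,
# e.g. from a cluster-wide library installed by the init script.
from pyspark.ml.util import DefaultParamsReader

_MANGLED = '_DefaultParamsReader__get_class'
_orig_get_class = getattr(DefaultParamsReader, _MANGLED)


def _patched_get_class(clazz):
    # Retry with the synapse.* Python package name when the pipeline metadata
    # stores the JVM-style com.microsoft.azure.synapse.* class name.
    try:
        return _orig_get_class(clazz)
    except AttributeError:
        return _orig_get_class(
            clazz.replace('com.microsoft.azure.synapse', 'synapse'))


# Patch the reader for the lifetime of this Python process.
setattr(DefaultParamsReader, _MANGLED, staticmethod(_patched_get_class))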

What component(s) does this bug affect?

  • area/cognitive: Cognitive project
  • area/core: Core project
  • area/deep-learning: DeepLearning project
  • area/lightgbm: Lightgbm project
  • area/opencv: Opencv project
  • area/vw: VW project
  • area/website: Website
  • area/build: Project build system
  • area/notebooks: Samples under notebooks folder
  • area/docker: Docker usage
  • area/models: models related issue

What language(s) does this bug affect?

  • language/scala: Scala source code
  • language/python: Pyspark APIs
  • language/r: R APIs
  • language/csharp: .NET APIs
  • language/new: Proposals for new client languages

What integration(s) does this bug affect?

  • integrations/synapse: Azure Synapse integrations
  • integrations/azureml: Azure ML integrations
  • integrations/databricks: Databricks integrations
dsmith111 added the bug label Oct 6, 2024
github-actions bot added the triage label Oct 6, 2024