
[BUG] LightGBMRanker fit causes org.apache.spark.SparkException: Job aborted due to stage failure #2235

shubhanshumishra-doordash opened this issue Jun 12, 2024 · 0 comments


SynapseML version

1.0.4

System information

  • Language version (e.g. python 3.8, scala 2.12): python 3.9.5, 12.2 LTS ML (Scala 2.12)
  • Spark Version (e.g. 3.2.3): Apache Spark 3.3.2
  • Spark Platform (e.g. Synapse, Databricks): Databricks
  • SynapseML package: com.microsoft.azure:synapseml_2.12:1.0.4

Describe the problem

Databricks cluster details:

  • Workers: 8–14 × r5d.24xlarge (6,144–10,752 GB memory, 768–1,344 cores total)
  • Driver: 1 × r5d.12xlarge (384 GB memory, 48 cores)
  • Runtime: 12.2.x-cpu-ml-scala2.12
  • Cost: 183–314 DBU/h

When training the LightGBMRanker I am seeing the following error. The fit runs as a barrier stage, so Spark cannot retry the failed task individually and aborts the whole job:

org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(172, 516) finished unsuccessfully.
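
For what it's worth, the ExecutorLostFailure deeper in the trace reports "Command exited with code 134", which by the usual POSIX convention is a signal rather than a Spark-level error, so my working assumption is that the native LightGBM process aborted (e.g. from native memory pressure) and took the barrier stage with it:

# Standard POSIX exit-code decoding (general convention, not specific to this job):
# codes above 128 mean 128 + signal number, so 134 -> signal 6 (SIGABRT).
import signal
assert 134 - 128 == signal.SIGABRT == 6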

Code to reproduce issue

from synapse.ml.lightgbm import LightGBMRanker
from pyspark.ml.feature import VectorAssembler

# feature_columns, label_cols, session_group_id_cols, final_columns, data_dir,
# and the train/val date bounds are defined earlier in the notebook.
spark_feature_col = "spark_features"

def process_spark_data(store_ranker_train_dataset, store_ranker_eval_dataset):
  featurizer = VectorAssembler(inputCols=feature_columns, outputCol=spark_feature_col)
  train_data = featurizer.transform(store_ranker_train_dataset).select(
    spark_feature_col, label_cols[0], session_group_id_cols[0])
  eval_data = featurizer.transform(store_ranker_eval_dataset).select(
    spark_feature_col, label_cols[0], session_group_id_cols[0])
  return train_data, eval_data


eval_dataset = (
  spark.read.format("delta").load(data_dir.replace("/dbfs", ""))
  .where(f"active_date >= {val_start_date!r} AND active_date <= {val_end_date!r}")
  .select(final_columns))

train_dataset = (
  spark.read.format("delta").load(data_dir.replace("/dbfs", ""))
  .where(f"active_date >= {train_start_date!r} AND active_date <= {train_end_date!r}")
  .select(final_columns))

train_data, eval_data = process_spark_data(train_dataset, eval_dataset)
lgbm_ranker = LightGBMRanker(
    labelCol=label_cols[0],
    featuresCol=spark_feature_col,
    groupCol=session_group_id_cols[0],
    predictionCol="preds",
    leafPredictionCol="leafPreds",
    featuresShapCol="importances",
    repartitionByGroupingColumn=True,
    objective="lambdarank",
    metric="ndcg",
    dataTransferMode="streaming",
    numIterations=500,  # number of trees
    labelGain=[0, 1, 3],
    passThroughArgs="importance_type=gain",
    seed=1234,
    numThreads=48,  # tune threads
    maxDepth=10,
    # maxDepth=-1,
    minDataInLeaf=20,
    minSumHessianInLeaf=0.01,
    minGainToSplit=0.0,
    numLeaves=254,
    # featurePreFilter=False,
    lambdaL1=0.5,
    featureFraction=0.72,
    evalAt=[1, 2, 3, 4, 5, 10, 20, 30],
    earlyStoppingRound=200,  # decrease
    useBarrierExecutionMode=True
)

lgbm_ranker.fit(train_data)
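
Two things I am trying while debugging, sketched below with illustrative values (not a confirmed fix): checking for partition skew, since in streaming mode each barrier task materializes its partition in native LightGBM memory and one oversized partition can abort its executor, and a retry configuration that reduces the blast radius of a single task failure.

from pyspark.sql import functions as F

# 1) Look for partition skew in the training data.
(train_data
 .groupBy(F.spark_partition_id().alias("pid"))
 .count()
 .summary("min", "25%", "75%", "max")
 .show())

# 2) Retry with settings that are easier on native memory and the scheduler
#    (same data/columns as above; values are illustrative).
lgbm_ranker_retry = LightGBMRanker(
    labelCol=label_cols[0],
    featuresCol=spark_feature_col,
    groupCol=session_group_id_cols[0],
    objective="lambdarank",
    metric="ndcg",
    dataTransferMode="bulk",        # avoid the streaming code path
    numThreads=24,                  # half the original 48, to leave headroom
    useBarrierExecutionMode=False,  # per-task retries instead of whole-stage abort
)
# lgbm_ranker_retry.fit(train_data)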

Other info / logs

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
File <command-3527623459345888>:1
----> 1 lgbm_ranker.fit(train_data)

File /databricks/python_shell/dbruntime/MLWorkloadsInstrumentation/_pyspark.py:30, in _create_patch_function.<locals>.patched_method(self, *args, **kwargs)
     28 call_succeeded = False
     29 try:
---> 30     result = original_method(self, *args, **kwargs)
     31     call_succeeded = True
     32     return result

File /databricks/python/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py:555, in safe_patch.<locals>.safe_patch_function(*args, **kwargs)
    553     patch_function.call(call_original, *args, **kwargs)
    554 else:
--> 555     patch_function(call_original, *args, **kwargs)
    557 session.state = "succeeded"
    559 try_log_autologging_event(
    560     AutologgingEventLogger.get_logger().log_patch_function_success,
    561     session,
   (...)
    565     kwargs,
    566 )

File /databricks/python/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py:254, in with_managed_run.<locals>.patch_with_managed_run(original, *args, **kwargs)
    251     managed_run = create_managed_run()
    253 try:
--> 254     result = patch_function(original, *args, **kwargs)
    255 except (Exception, KeyboardInterrupt):
    256     # In addition to standard Python exceptions, handle keyboard interrupts to ensure
    257     # that runs are terminated if a user prematurely interrupts training execution
    258     # (e.g. via sigint / ctrl-c)
    259     if managed_run:

File /databricks/python/lib/python3.9/site-packages/mlflow/pyspark/ml/__init__.py:1109, in autolog.<locals>.patched_fit(original, self, *args, **kwargs)
   1107 if t.should_log():
   1108     with _AUTOLOGGING_METRICS_MANAGER.disable_log_post_training_metrics():
-> 1109         fit_result = fit_mlflow(original, self, *args, **kwargs)
   1110     # In some cases the `fit_result` may be an iterator of spark models.
   1111     if should_log_post_training_metrics and isinstance(fit_result, Model):

File /databricks/python/lib/python3.9/site-packages/mlflow/pyspark/ml/__init__.py:1095, in autolog.<locals>.fit_mlflow(original, self, *args, **kwargs)
   1093 _log_pretraining_metadata(estimator, params)
   1094 input_training_df = args[0].persist(StorageLevel.MEMORY_AND_DISK)
-> 1095 spark_model = original(self, *args, **kwargs)
   1096 _log_posttraining_metadata(estimator, spark_model, params, input_training_df)
   1097 input_training_df.unpersist()

File /databricks/python/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py:536, in safe_patch.<locals>.safe_patch_function.<locals>.call_original(*og_args, **og_kwargs)
    533         original_result = original(*_og_args, **_og_kwargs)
    534         return original_result
--> 536 return call_original_fn_with_event_logging(_original_fn, og_args, og_kwargs)

File /databricks/python/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py:471, in safe_patch.<locals>.safe_patch_function.<locals>.call_original_fn_with_event_logging(original_fn, og_args, og_kwargs)
    462 try:
    463     try_log_autologging_event(
    464         AutologgingEventLogger.get_logger().log_original_function_start,
    465         session,
   (...)
    469         og_kwargs,
    470     )
--> 471     original_fn_result = original_fn(*og_args, **og_kwargs)
    473     try_log_autologging_event(
    474         AutologgingEventLogger.get_logger().log_original_function_success,
    475         session,
   (...)
    479         og_kwargs,
    480     )
    481     return original_fn_result

File /databricks/python/lib/python3.9/site-packages/mlflow/utils/autologging_utils/safety.py:533, in safe_patch.<locals>.safe_patch_function.<locals>.call_original.<locals>._original_fn(*_og_args, **_og_kwargs)
    525 # Show all non-MLflow warnings as normal (i.e. not as event logs)
    526 # during original function execution, even if silent mode is enabled
    527 # (`silent=True`), since these warnings originate from the ML framework
    528 # or one of its dependencies and are likely relevant to the caller
    529 with set_non_mlflow_warnings_behavior_for_current_thread(
    530     disable_warnings=False,
    531     reroute_warnings=False,
    532 ):
--> 533     original_result = original(*_og_args, **_og_kwargs)
    534     return original_result

File /databricks/spark/python/pyspark/ml/base.py:205, in Estimator.fit(self, dataset, params)
    203         return self.copy(params)._fit(dataset)
    204     else:
--> 205         return self._fit(dataset)
    206 else:
    207     raise TypeError(
    208         "Params must be either a param map or a list/tuple of param maps, "
    209         "but got %s." % type(params)
    210     )

File /local_disk0/spark-69d54278-6df5-43cd-91cb-def9861fb93a/userFiles-8ff55ffb-4948-40f0-acf4-441ba8fd8534/addedFile6359067566179330225com_microsoft_azure_synapseml_lightgbm_2_12_1_0_4-f4a0f.jar/synapse/ml/lightgbm/LightGBMRanker.py:2148, in LightGBMRanker._fit(self, dataset)
   2147 def _fit(self, dataset):
-> 2148     java_model = self._fit_java(dataset)
   2149     return self._create_model(java_model)

File /databricks/spark/python/pyspark/ml/wrapper.py:380, in JavaEstimator._fit_java(self, dataset)
    377 assert self._java_obj is not None
    379 self._transfer_params_to_java()
--> 380 return self._java_obj.fit(dataset._jdf)

File /databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/java_gateway.py:1321, in JavaMember.__call__(self, *args)
   1315 command = proto.CALL_COMMAND_NAME +\
   1316     self.command_header +\
   1317     args_command +\
   1318     proto.END_COMMAND_PART
   1320 answer = self.gateway_client.send_command(command)
-> 1321 return_value = get_return_value(
   1322     answer, self.gateway_client, self.target_id, self.name)
   1324 for temp_arg in temp_args:
   1325     temp_arg._detach()

File /databricks/spark/python/pyspark/errors/exceptions.py:228, in capture_sql_exception.<locals>.deco(*a, **kw)
    226 def deco(*a: Any, **kw: Any) -> Any:
    227     try:
--> 228         return f(*a, **kw)
    229     except Py4JJavaError as e:
    230         converted = convert_exception(e.java_exception)

File /databricks/spark/python/lib/py4j-0.10.9.5-src.zip/py4j/protocol.py:326, in get_return_value(answer, gateway_client, target_id, name)
    324 value = OUTPUT_CONVERTER[type](answer[2:], gateway_client)
    325 if answer[1] == REFERENCE_TYPE:
--> 326     raise Py4JJavaError(
    327         "An error occurred while calling {0}{1}{2}.\n".
    328         format(target_id, ".", name), value)
    329 else:
    330     raise Py4JError(
    331         "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n".
    332         format(target_id, ".", name, value))

Py4JJavaError: An error occurred while calling o1923.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Could not recover from a failed barrier ResultStage. Most recent failure reason: Stage failed because barrier task ResultTask(172, 516) finished unsuccessfully.
ExecutorLostFailure (executor 23 exited caused by one of the running tasks) Reason: Command exited with code 134
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:3424)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:3346)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:3335)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:3335)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskCompletion(DAGScheduler.scala:2816)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3629)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3573)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:3561)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:51)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$runJob$1(DAGScheduler.scala:1193)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at com.databricks.spark.util.FrameProfiler$.record(FrameProfiler.scala:80)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:1181)
	at org.apache.spark.SparkContext.runJobInternal(SparkContext.scala:2758)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1070)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:165)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:125)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:445)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1068)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.executePartitionTasks(LightGBMBase.scala:621)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.executeTraining(LightGBMBase.scala:598)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.trainOneDataBatch(LightGBMBase.scala:446)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.$anonfun$train$2(LightGBMBase.scala:62)
	at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logVerb(SynapseMLLogging.scala:163)
	at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logVerb$(SynapseMLLogging.scala:160)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMRanker.logVerb(LightGBMRanker.scala:26)
	at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logFit(SynapseMLLogging.scala:153)
	at com.microsoft.azure.synapse.ml.logging.SynapseMLLogging.logFit$(SynapseMLLogging.scala:152)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMRanker.logFit(LightGBMRanker.scala:26)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.train(LightGBMBase.scala:64)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMBase.train$(LightGBMBase.scala:36)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMRanker.train(LightGBMRanker.scala:26)
	at com.microsoft.azure.synapse.ml.lightgbm.LightGBMRanker.train(LightGBMRanker.scala:26)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:380)
	at py4j.Gateway.invoke(Gateway.java:306)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:195)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:115)
	at java.lang.Thread.run(Thread.java:750)

What component(s) does this bug affect?

  • [ ] area/cognitive: Cognitive project
  • [ ] area/core: Core project
  • [ ] area/deep-learning: DeepLearning project
  • [x] area/lightgbm: Lightgbm project
  • [ ] area/opencv: Opencv project
  • [ ] area/vw: VW project
  • [ ] area/website: Website
  • [ ] area/build: Project build system
  • [ ] area/notebooks: Samples under notebooks folder
  • [ ] area/docker: Docker usage
  • [ ] area/models: models related issue

What language(s) does this bug affect?

  • [ ] language/scala: Scala source code
  • [x] language/python: Pyspark APIs
  • [ ] language/r: R APIs
  • [ ] language/csharp: .NET APIs
  • [ ] language/new: Proposals for new client languages

What integration(s) does this bug affect?

  • [ ] integrations/synapse: Azure Synapse integrations
  • [ ] integrations/azureml: Azure ML integrations
  • [x] integrations/databricks: Databricks integrations