Commit c9a87d3

[cherry-pick] KeyError: 'sparkConf' occurs when running a Databricks task without spark_conf (#3282)
Signed-off-by: Kevin Su <[email protected]>
1 parent 9188d1b commit c9a87d3

3 files changed: +48 -3 lines changed

plugins/flytekit-spark/flytekitplugins/spark/agent.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def _get_databricks_job_spec(task_template: TaskTemplate) -> dict:
         if not new_cluster.get("docker_image"):
             new_cluster["docker_image"] = {"url": container.image}
         if not new_cluster.get("spark_conf"):
-            new_cluster["spark_conf"] = custom["sparkConf"]
+            new_cluster["spark_conf"] = custom.get("sparkConf", {})
         if not new_cluster.get("spark_env_vars"):
             new_cluster["spark_env_vars"] = {k: v for k, v in envs.items()}
     else:
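
For context, a minimal sketch of the failure this one-line change fixes. The "custom" dict below is a hypothetical stand-in for a task template's custom field when the task sets no Spark configuration, so the "sparkConf" key is absent:

# Hypothetical custom payload for a Databricks task without spark_conf.
custom = {"databricksConf": {"new_cluster": {"num_workers": 4}}}
new_cluster = custom["databricksConf"]["new_cluster"]

# Before the fix: direct indexing raises KeyError: 'sparkConf'.
try:
    new_cluster["spark_conf"] = custom["sparkConf"]
except KeyError as e:
    print(f"KeyError: {e}")

# After the fix: dict.get falls back to an empty dict.
new_cluster["spark_conf"] = custom.get("sparkConf", {})
print(new_cluster["spark_conf"])  # {}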

plugins/flytekit-spark/tests/test_spark_task.py

Lines changed: 40 additions & 1 deletion
@@ -8,7 +8,7 @@

 from flytekit.core import context_manager
 from flytekitplugins.spark import Spark
-from flytekitplugins.spark.task import Databricks, new_spark_session
+from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session
 from pyspark.sql import SparkSession

 import flytekit
@@ -105,6 +105,45 @@ def my_databricks(a: int) -> int:
     assert my_databricks(a=3) == 3


+@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}])
+def test_databricks_v2(reset_spark_session, spark_conf):
+    databricks_conf = {
+        "name": "flytekit databricks plugin example",
+        "new_cluster": {
+            "spark_version": "11.0.x-scala2.12",
+            "node_type_id": "r3.xlarge",
+            "aws_attributes": {"availability": "ON_DEMAND"},
+            "num_workers": 4,
+            "docker_image": {"url": "pingsutw/databricks:latest"},
+        },
+        "timeout_seconds": 3600,
+        "max_retries": 1,
+        "spark_python_task": {
+            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
+            "parameters": "ls",
+        },
+    }
+
+    databricks_instance = "account.cloud.databricks.com"
+
+    @task(
+        task_config=DatabricksV2(
+            databricks_conf=databricks_conf,
+            databricks_instance=databricks_instance,
+            spark_conf=spark_conf,
+        )
+    )
+    def my_databricks(a: int) -> int:
+        session = flytekit.current_context().spark_session
+        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
+        return a
+
+    assert my_databricks.task_config is not None
+    assert my_databricks.task_config.databricks_conf == databricks_conf
+    assert my_databricks.task_config.databricks_instance == databricks_instance
+    assert my_databricks.task_config.spark_conf == (spark_conf or {})
+    assert my_databricks(a=3) == 3
+
 def test_new_spark_session():
     name = "SessionName"
     spark_conf = {"spark1": "1", "spark2": "2"}

tests/flytekit/unit/experimental/test_eager_workflows.py

Lines changed: 7 additions & 1 deletion
@@ -287,9 +287,15 @@ async def eager_wf_flyte_directory() -> str:
 @mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")
 @mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
 @mock.patch("flytekit.core.utils.write_proto_to_file")
-def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto, event_loop):
+def test_eager_workflow_dispatch(mock_write_to_file, mock_put_data, mock_get_data, mock_load_proto):
     """Test that event loop is preserved after executing eager workflow via dispatch."""

+    try:
+        event_loop = asyncio.get_running_loop()
+    except RuntimeError:
+        event_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(event_loop)
+
     @eager
     async def eager_wf():
         await asyncio.sleep(0.1)
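
The removed event_loop fixture argument is replaced inside the test by the standard asyncio get-or-create pattern; a self-contained sketch of that pattern in isolation:

import asyncio

# Reuse the running loop if one exists in this thread; otherwise
# create a fresh loop and install it as the current one.
try:
    event_loop = asyncio.get_running_loop()
except RuntimeError:  # no loop is running in this thread
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

print(event_loop.is_running())  # False when newly created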
