@@ -8,7 +8,7 @@
 8 |  8 |
 9 |  9 | from flytekit.core import context_manager
10 | 10 | from flytekitplugins.spark import Spark
11 |    | -from flytekitplugins.spark.task import Databricks, new_spark_session
   | 11 | +from flytekitplugins.spark.task import Databricks, DatabricksV2, new_spark_session
12 | 12 | from pyspark.sql import SparkSession
13 | 13 |
14 | 14 | import flytekit
@@ -105,6 +105,45 @@ def my_databricks(a: int) -> int:
105 | 105 |     assert my_databricks(a=3) == 3
106 | 106 |
107 | 107 |
    | 108 | +@pytest.mark.parametrize("spark_conf", [None, {"spark": "2"}])
    | 109 | +def test_databricks_v2(reset_spark_session, spark_conf):
    | 110 | +    databricks_conf = {
    | 111 | +        "name": "flytekit databricks plugin example",
    | 112 | +        "new_cluster": {
    | 113 | +            "spark_version": "11.0.x-scala2.12",
    | 114 | +            "node_type_id": "r3.xlarge",
    | 115 | +            "aws_attributes": {"availability": "ON_DEMAND"},
    | 116 | +            "num_workers": 4,
    | 117 | +            "docker_image": {"url": "pingsutw/databricks:latest"},
    | 118 | +        },
    | 119 | +        "timeout_seconds": 3600,
    | 120 | +        "max_retries": 1,
    | 121 | +        "spark_python_task": {
    | 122 | +            "python_file": "dbfs:///FileStore/tables/entrypoint-1.py",
    | 123 | +            "parameters": "ls",
    | 124 | +        },
    | 125 | +    }
    | 126 | +
    | 127 | +    databricks_instance = "account.cloud.databricks.com"
    | 128 | +
    | 129 | +    @task(
    | 130 | +        task_config=DatabricksV2(
    | 131 | +            databricks_conf=databricks_conf,
    | 132 | +            databricks_instance=databricks_instance,
    | 133 | +            spark_conf=spark_conf,
    | 134 | +        )
    | 135 | +    )
    | 136 | +    def my_databricks(a: int) -> int:
    | 137 | +        session = flytekit.current_context().spark_session
    | 138 | +        assert session.sparkContext.appName == "FlyteSpark: ex:local:local:local"
    | 139 | +        return a
    | 140 | +
    | 141 | +    assert my_databricks.task_config is not None
    | 142 | +    assert my_databricks.task_config.databricks_conf == databricks_conf
    | 143 | +    assert my_databricks.task_config.databricks_instance == databricks_instance
    | 144 | +    assert my_databricks.task_config.spark_conf == (spark_conf or {})
    | 145 | +    assert my_databricks(a=3) == 3
    | 146 | +
108 | 147 | def test_new_spark_session():
109 | 148 |     name = "SessionName"
110 | 149 |     spark_conf = {"spark1": "1", "spark2": "2"}
|