
Commit 79da1ab

[SPARK-41725][PYTHON][TESTS][FOLLOW-UP] Remove collect for SQL command execution in tests
### What changes were proposed in this pull request?

This PR removes the `sql("command").collect()` workaround from PySpark test code.

### Why are the changes needed?

The `collect` calls were previously added as a workaround for Spark Connect, which did not execute SQL commands eagerly. That has since been fixed, so we no longer need to call `collect`.

### Does this PR introduce _any_ user-facing change?

No, test-only.

### How was this patch tested?

CI in this PR should test it out.

Closes apache#40251 from HyukjinKwon/SPARK-41725.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 03187e2 commit 79da1ab
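For context, a minimal sketch of the pattern this commit removes (the table name is taken from the surrounding doctests, and an active SparkSession `spark` is assumed):

```python
# Before: under Spark Connect, sql() built the command lazily, so tests
# forced execution by appending .collect() to the returned DataFrame.
_ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()

# After: per the commit message, spark.sql() now executes commands eagerly
# under Spark Connect as well, so the trailing .collect() is redundant.
_ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
```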

File tree: 6 files changed, +83 −93 lines

python/pyspark/sql/catalog.py (+41 −43)
@@ -246,8 +246,6 @@ def getDatabase(self, dbName: str) -> Database:
             locationUri=jdb.locationUri(),
         )
 
-    # TODO(SPARK-41725): we don't have to `collect` for every `sql` but
-    # Spark Connect requires it. We should remove them out.
     def databaseExists(self, dbName: str) -> bool:
         """Check if the database with the specified name exists.
 
@@ -275,15 +273,15 @@ def databaseExists(self, dbName: str) -> bool:
 
         >>> spark.catalog.databaseExists("test_new_database")
         False
-        >>> _ = spark.sql("CREATE DATABASE test_new_database").collect()
+        >>> _ = spark.sql("CREATE DATABASE test_new_database")
         >>> spark.catalog.databaseExists("test_new_database")
         True
 
         Using the fully qualified name with the catalog name.
 
         >>> spark.catalog.databaseExists("spark_catalog.test_new_database")
         True
-        >>> _ = spark.sql("DROP DATABASE test_new_database").collect()
+        >>> _ = spark.sql("DROP DATABASE test_new_database")
         """
         return self._jcatalog.databaseExists(dbName)
 
@@ -372,8 +370,8 @@ def getTable(self, tableName: str) -> Table:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.getTable("tbl1")
         Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
 
@@ -383,7 +381,7 @@ def getTable(self, tableName: str) -> Table:
         Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
         >>> spark.catalog.getTable("spark_catalog.default.tbl1")
         Table(name='tbl1', catalog='spark_catalog', namespace=['default'], ...
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
 
         Throw an analysis exception when the table does not exist.
 
@@ -535,7 +533,7 @@ def getFunction(self, functionName: str) -> Function:
         Examples
         --------
         >>> _ = spark.sql(
-        ...     "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'").collect()
+        ...     "CREATE FUNCTION my_func1 AS 'test.org.apache.spark.sql.MyDoubleAvg'")
         >>> spark.catalog.getFunction("my_func1")
         Function(name='my_func1', catalog='spark_catalog', namespace=['default'], ...
 
@@ -602,11 +600,11 @@ def listColumns(self, tableName: str, dbName: Optional[str] = None) -> List[Column]:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tblA (name STRING, age INT) USING parquet")
         >>> spark.catalog.listColumns("tblA")
         [Column(name='name', description=None, dataType='string', nullable=True, ...
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         if dbName is None:
             iter = self._jcatalog.listColumns(tableName).toLocalIterator()
@@ -667,8 +665,8 @@ def tableExists(self, tableName: str, dbName: Optional[str] = None) -> bool:
 
         >>> spark.catalog.tableExists("unexisting_table")
         False
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.tableExists("tbl1")
         True
 
@@ -680,13 +678,13 @@ def tableExists(self, tableName: str, dbName: Optional[str] = None) -> bool:
         True
         >>> spark.catalog.tableExists("tbl1", "default")
         True
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
 
         Check if views exist:
 
         >>> spark.catalog.tableExists("view1")
         False
-        >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1").collect()
+        >>> _ = spark.sql("CREATE VIEW view1 AS SELECT 1")
         >>> spark.catalog.tableExists("view1")
         True
 
@@ -698,14 +696,14 @@ def tableExists(self, tableName: str, dbName: Optional[str] = None) -> bool:
         True
         >>> spark.catalog.tableExists("view1", "default")
         True
-        >>> _ = spark.sql("DROP VIEW view1").collect()
+        >>> _ = spark.sql("DROP VIEW view1")
 
         Check if temporary views exist:
 
-        >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1").collect()
+        >>> _ = spark.sql("CREATE TEMPORARY VIEW view1 AS SELECT 1")
         >>> spark.catalog.tableExists("view1")
         True
-        >>> df = spark.sql("DROP VIEW view1").collect()
+        >>> df = spark.sql("DROP VIEW view1")
         >>> spark.catalog.tableExists("view1")
         False
         """
@@ -806,15 +804,15 @@ def createTable(
         Creating a managed table.
 
         >>> _ = spark.catalog.createTable("tbl1", schema=spark.range(1).schema, source='parquet')
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
 
         Creating an external table
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
         ...     _ = spark.catalog.createTable(
         ...         "tbl2", schema=spark.range(1).schema, path=d, source='parquet')
-        >>> _ = spark.sql("DROP TABLE tbl2").collect()
+        >>> _ = spark.sql("DROP TABLE tbl2")
         """
         if path is not None:
             options["path"] = path
@@ -954,8 +952,8 @@ def isCached(self, tableName: str) -> bool:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.cacheTable("tbl1")
         >>> spark.catalog.isCached("tbl1")
         True
@@ -972,7 +970,7 @@ def isCached(self, tableName: str) -> bool:
         >>> spark.catalog.isCached("spark_catalog.default.tbl1")
         True
         >>> spark.catalog.uncacheTable("tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         return self._jcatalog.isCached(tableName)
 
@@ -994,8 +992,8 @@ def cacheTable(self, tableName: str) -> None:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.cacheTable("tbl1")
 
         Throw an analysis exception when the table does not exist.
@@ -1009,7 +1007,7 @@ def cacheTable(self, tableName: str) -> None:
 
         >>> spark.catalog.cacheTable("spark_catalog.default.tbl1")
         >>> spark.catalog.uncacheTable("tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.cacheTable(tableName)
 
@@ -1031,8 +1029,8 @@ def uncacheTable(self, tableName: str) -> None:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.cacheTable("tbl1")
         >>> spark.catalog.uncacheTable("tbl1")
         >>> spark.catalog.isCached("tbl1")
@@ -1050,7 +1048,7 @@ def uncacheTable(self, tableName: str) -> None:
         >>> spark.catalog.uncacheTable("spark_catalog.default.tbl1")
         >>> spark.catalog.isCached("tbl1")
         False
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.uncacheTable(tableName)
 
@@ -1064,12 +1062,12 @@ def clearCache(self) -> None:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
-        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tbl1")
+        >>> _ = spark.sql("CREATE TABLE tbl1 (name STRING, age INT) USING parquet")
         >>> spark.catalog.clearCache()
         >>> spark.catalog.isCached("tbl1")
         False
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.clearCache()
 
@@ -1095,10 +1093,10 @@ def refreshTable(self, tableName: str) -> None:
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     _ = spark.sql(
-        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
-        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
         ...     spark.catalog.cacheTable("tbl1")
         ...     spark.table("tbl1").show()
         +---+
@@ -1121,7 +1119,7 @@ def refreshTable(self, tableName: str) -> None:
         Using the fully qualified name for the table.
 
         >>> spark.catalog.refreshTable("spark_catalog.default.tbl1")
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.refreshTable(tableName)
 
@@ -1149,12 +1147,12 @@ def recoverPartitions(self, tableName: str) -> None:
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     spark.range(1).selectExpr(
         ...         "id as key", "id as value").write.partitionBy("key").mode("overwrite").save(d)
         ...     _ = spark.sql(
         ...         "CREATE TABLE tbl1 (key LONG, value LONG)"
-        ...         "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d)).collect()
+        ...         "USING parquet OPTIONS (path '{}') PARTITIONED BY (key)".format(d))
         ...     spark.table("tbl1").show()
         ...     spark.catalog.recoverPartitions("tbl1")
         ...     spark.table("tbl1").show()
@@ -1167,7 +1165,7 @@ def recoverPartitions(self, tableName: str) -> None:
         +-----+---+
         |    0|  0|
         +-----+---+
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.recoverPartitions(tableName)
 
@@ -1191,10 +1189,10 @@ def refreshByPath(self, path: str) -> None:
 
         >>> import tempfile
         >>> with tempfile.TemporaryDirectory() as d:
-        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1").collect()
+        ...     _ = spark.sql("DROP TABLE IF EXISTS tbl1")
         ...     _ = spark.sql(
-        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d)).collect()
-        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'").collect()
+        ...         "CREATE TABLE tbl1 (col STRING) USING TEXT LOCATION '{}'".format(d))
+        ...     _ = spark.sql("INSERT INTO tbl1 SELECT 'abc'")
         ...     spark.catalog.cacheTable("tbl1")
         ...     spark.table("tbl1").show()
         +---+
@@ -1214,7 +1212,7 @@ def refreshByPath(self, path: str) -> None:
         >>> spark.table("tbl1").count()
         0
 
-        >>> _ = spark.sql("DROP TABLE tbl1").collect()
+        >>> _ = spark.sql("DROP TABLE tbl1")
         """
         self._jcatalog.refreshByPath(path)
python/pyspark/sql/readwriter.py (+9 −9)
@@ -466,7 +466,7 @@ def table(self, tableName: str) -> "DataFrame":
         |  8|
         |  9|
         +---+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         return self._df(self._jreader.table(tableName))
 
@@ -1232,7 +1232,7 @@ def bucketBy(
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
+        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1246,7 +1246,7 @@ def bucketBy(
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
+        >>> _ = spark.sql("DROP TABLE bucketed_table")
         """
         if not isinstance(numBuckets, int):
             raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
@@ -1296,7 +1296,7 @@ def sortBy(
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a sorted-bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table").collect()
+        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1311,7 +1311,7 @@ def sortBy(
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
+        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
         """
         if isinstance(col, (list, tuple)):
             if cols:
@@ -1417,7 +1417,7 @@ def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None:
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
         >>> df = spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1438,7 +1438,7 @@ def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None:
         |140| Haejoon Lee|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")
@@ -1495,7 +1495,7 @@ def saveAsTable(
         --------
         Creates a table from a DataFrame, and read it back.
 
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1508,7 +1508,7 @@ def saveAsTable(
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA").collect()
+        >>> _ = spark.sql("DROP TABLE tblA")
         """
         self.mode(mode).options(**options)
         if partitionBy is not None:
