Commit 3f41adc

[SPARK-54186][PYTHON][TESTS] Fix doctests for PandasCogroupedOps.applyInPandas
### What changes were proposed in this pull request?

Enable doctests for `PandasCogroupedOps.applyInPandas`.

### Why are the changes needed?

To improve test coverage and make sure the examples are correct.

### Does this PR introduce _any_ user-facing change?

Yes, doc-only changes.

### How was this patch tested?

CI.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52885 from zhengruifeng/enable_apply_in_pandas.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 8b1ee1c · commit 3f41adc

File tree: 1 file changed, +50 −39 lines

python/pyspark/sql/pandas/group_ops.py

Lines changed: 50 additions & 39 deletions
@@ -73,20 +73,20 @@ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
         ...     ("id", "v"))
-        >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
+        >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
         ... def normalize(pdf):
         ...     v = pdf.v
         ...     return pdf.assign(v=(v - v.mean()) / v.std())
         ...
-        >>> df.groupby("id").apply(normalize).show()  # doctest: +SKIP
+        >>> df.groupby("id").apply(normalize).sort("id", "v").show()
         +---+-------------------+
         | id|                  v|
         +---+-------------------+
-        | 1|-0.7071067811865475|
-        | 1| 0.7071067811865475|
-        | 2|-0.8320502943378437|
-        | 2|-0.2773500981126146|
-        | 2| 1.1094003924504583|
+        | 1|-0.7071067811865...|
+        | 1| 0.7071067811865...|
+        | 2|-0.8320502943378...|
+        | 2|-0.2773500981126...|
+        | 2| 1.1094003924504...|
         +---+-------------------+

         See Also
@@ -159,25 +159,26 @@ def applyInPandas(

         Examples
         --------
-        >>> import pandas as pd  # doctest: +SKIP
-        >>> from pyspark.sql.functions import ceil
+        >>> import pandas as pd
+        >>> from pyspark.sql import functions as sf
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def normalize(pdf):
         ...     v = pdf.v
         ...     return pdf.assign(v=(v - v.mean()) / v.std())
         ...
         >>> df.groupby("id").applyInPandas(
-        ...     normalize, schema="id long, v double").show()  # doctest: +SKIP
+        ...     normalize, schema="id long, v double"
+        ... ).sort("id", "v").show()
         +---+-------------------+
         | id|                  v|
         +---+-------------------+
-        | 1|-0.7071067811865475|
-        | 1| 0.7071067811865475|
-        | 2|-0.8320502943378437|
-        | 2|-0.2773500981126146|
-        | 2| 1.1094003924504583|
+        | 1|-0.7071067811865...|
+        | 1| 0.7071067811865...|
+        | 2|-0.8320502943378...|
+        | 2|-0.2773500981126...|
+        | 2| 1.1094003924504...|
         +---+-------------------+

         Alternatively, the user can pass a function that takes two arguments.
@@ -189,14 +190,15 @@ def applyInPandas(

         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def mean_func(key, pdf):
         ...     # key is a tuple of one numpy.int64, which is the value
         ...     # of 'id' for the current group
         ...     return pd.DataFrame([key + (pdf.v.mean(),)])
         ...
-        >>> df.groupby('id').applyInPandas(
-        ...     mean_func, schema="id long, v double").show()  # doctest: +SKIP
+        >>> df.groupby("id").applyInPandas(
+        ...     mean_func, schema="id long, v double"
+        ... ).sort("id").show()
         +---+---+
         | id|  v|
         +---+---+
@@ -209,34 +211,36 @@ def applyInPandas(
         ...     # of 'id' and 'ceil(df.v / 2)' for the current group
         ...     return pd.DataFrame([key + (pdf.v.sum(),)])
         ...
-        >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
-        ...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()  # doctest: +SKIP
+        >>> df.groupby(df.id, sf.ceil(df.v / 2)).applyInPandas(
+        ...     sum_func, schema="id long, `ceil(v / 2)` long, v double"
+        ... ).sort("id", "v").show()
         +---+-----------+----+
         | id|ceil(v / 2)|   v|
         +---+-----------+----+
-        | 2|          5|10.0|
         | 1|          1| 3.0|
-        | 2|          3| 5.0|
         | 2|          2| 3.0|
+        | 2|          3| 5.0|
+        | 2|          5|10.0|
         +---+-----------+----+

         The function can also take and return an iterator of `pandas.DataFrame` using type
         hints.

-        >>> from typing import Iterator  # doctest: +SKIP
+        >>> from typing import Iterator
         >>> df = spark.createDataFrame(
         ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
-        ...     ("id", "v"))  # doctest: +SKIP
+        ...     ("id", "v"))
         >>> def filter_func(
         ...     batches: Iterator[pd.DataFrame]
-        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ... ) -> Iterator[pd.DataFrame]:
         ...     for batch in batches:
         ...         # Process and yield each batch independently
         ...         filtered = batch[batch['v'] > 2.0]
         ...         if not filtered.empty:
         ...             yield filtered[['v']]
         >>> df.groupby("id").applyInPandas(
-        ...     filter_func, schema="v double").show()  # doctest: +SKIP
+        ...     filter_func, schema="v double"
+        ... ).sort("v").show()
         +----+
         |   v|
         +----+
@@ -250,25 +254,26 @@ def applyInPandas(
         be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
         data types. The data will still be passed in as an iterator of `pandas.DataFrame`.

-        >>> from typing import Iterator, Tuple, Any  # doctest: +SKIP
+        >>> from typing import Iterator, Tuple, Any
         >>> def transform_func(
         ...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
-        ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
+        ... ) -> Iterator[pd.DataFrame]:
         ...     for batch in batches:
         ...         # Yield transformed results for each batch
         ...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
         ...         yield result[['id', 'v_doubled']]
         >>> df.groupby("id").applyInPandas(
-        ...     transform_func, schema="id long, v_doubled double").show()  # doctest: +SKIP
-        +---+----------+
-        | id|v_doubled |
-        +---+----------+
-        | 1|       2.0|
-        | 1|       4.0|
-        | 2|       6.0|
-        | 2|      10.0|
-        | 2|      20.0|
-        +---+----------+
+        ...     transform_func, schema="id long, v_doubled double"
+        ... ).sort("id", "v_doubled").show()
+        +---+---------+
+        | id|v_doubled|
+        +---+---------+
+        | 1|      2.0|
+        | 1|      4.0|
+        | 2|      6.0|
+        | 2|     10.0|
+        | 2|     20.0|
+        +---+---------+

         Notes
         -----
@@ -1187,8 +1192,14 @@ def _test() -> None:
     import doctest
     from pyspark.sql import SparkSession
     import pyspark.sql.pandas.group_ops
+    from pyspark.testing.utils import have_pandas, have_pyarrow

     globs = pyspark.sql.pandas.group_ops.__dict__.copy()
+
+    if not have_pandas or not have_pyarrow:
+        del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.apply.__doc__
+        del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInPandas.__doc__
+
     spark = SparkSession.builder.master("local[4]").appName("sql.pandas.group tests").getOrCreate()
     globs["spark"] = spark
     (failure_count, test_count) = doctest.testmod(
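A note on why the un-skipped examples are stable as doctests: every `show()` is now preceded by a `sort(...)` so the row order is fixed, and the long floating-point values in the expected tables are truncated with `...`, which doctest treats as a wildcard when ellipsis matching is in effect. The snippet below is a minimal, self-contained sketch of that matching behavior, not code from this commit; it enables the option explicitly with a `# doctest: +ELLIPSIS` directive, whereas the real `_test()` configures `doctest.testmod` with its own option flags (the call is truncated in the last hunk above).

import doctest


def _ellipsis_sketch() -> None:
    """
    Stand-in for the enabled examples: the expected value is truncated with
    "..." and matched via doctest's ELLIPSIS option.

    >>> 2 ** 0.5 / 2  # doctest: +ELLIPSIS
    0.7071067811865...
    """


if __name__ == "__main__":
    # Runs the docstring above; failed=0 means the ellipsis match succeeded.
    print(doctest.testmod())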