@@ -73,20 +73,20 @@ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
7373 >>> df = spark.createDataFrame(
7474 ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
7575 ...     ("id", "v"))
76- >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
76+ >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
7777 ... def normalize(pdf):
7878 ...     v = pdf.v
7979 ...     return pdf.assign(v=(v - v.mean()) / v.std())
8080 ...
81- >>> df.groupby("id").apply(normalize).show()  # doctest: +SKIP
81+ >>> df.groupby("id").apply(normalize).sort("id", "v").show()
8282 +---+-------------------+
8383 | id|                  v|
8484 +---+-------------------+
85- |  1|-0.7071067811865475|
86- |  1| 0.7071067811865475|
87- |  2|-0.8320502943378437|
88- |  2|-0.2773500981126146|
89- |  2| 1.1094003924504583|
85+ |  1|-0.7071067811865...|
86+ |  1| 0.7071067811865...|
87+ |  2|-0.8320502943378...|
88+ |  2|-0.2773500981126...|
89+ |  2| 1.1094003924504...|
9090 +---+-------------------+
9191
9292 See Also
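The truncated values in the expected output above (for example `-0.7071067811865...`) rely on doctest's ELLIPSIS option, which the module's `_test()` presumably enables via `optionflags` (that argument falls outside the hunks shown here). A minimal standalone sketch of why a truncated expectation still matches the full-precision value:

    import doctest

    def _example():
        """
        >>> print(-0.7071067811865475)
        -0.7071067811865...
        """

    # With ELLIPSIS enabled, "..." in the expected output matches the remaining
    # digits, so rows like the ones above pass once +SKIP is removed.
    doctest.run_docstring_examples(_example, {}, optionflags=doctest.ELLIPSIS)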
@@ -159,25 +159,26 @@ def applyInPandas(
159159
160160 Examples
161161 --------
162- >>> import pandas as pd  # doctest: +SKIP
163- >>> from pyspark.sql.functions import ceil
162+ >>> import pandas as pd
163+ >>> from pyspark.sql import functions as sf
164164 >>> df = spark.createDataFrame(
165165 ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
166- ...     ("id", "v"))  # doctest: +SKIP
166+ ...     ("id", "v"))
167167 >>> def normalize(pdf):
168168 ...     v = pdf.v
169169 ...     return pdf.assign(v=(v - v.mean()) / v.std())
170170 ...
171171 >>> df.groupby("id").applyInPandas(
172- ...     normalize, schema="id long, v double").show()  # doctest: +SKIP
172+ ...     normalize, schema="id long, v double"
173+ ... ).sort("id", "v").show()
173174 +---+-------------------+
174175 | id|                  v|
175176 +---+-------------------+
176- |  1|-0.7071067811865475|
177- |  1| 0.7071067811865475|
178- |  2|-0.8320502943378437|
179- |  2|-0.2773500981126146|
180- |  2| 1.1094003924504583|
177+ |  1|-0.7071067811865...|
178+ |  1| 0.7071067811865...|
179+ |  2|-0.8320502943378...|
180+ |  2|-0.2773500981126...|
181+ |  2| 1.1094003924504...|
181182 +---+-------------------+
182183
183184 Alternatively, the user can pass a function that takes two arguments.
@@ -189,14 +190,15 @@ def applyInPandas(
189190
190191 >>> df = spark.createDataFrame(
191192 ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
192- ...     ("id", "v"))  # doctest: +SKIP
193+ ...     ("id", "v"))
193194 >>> def mean_func(key, pdf):
194195 ...     # key is a tuple of one numpy.int64, which is the value
195196 ...     # of 'id' for the current group
196197 ...     return pd.DataFrame([key + (pdf.v.mean(),)])
197198 ...
198- >>> df.groupby('id').applyInPandas(
199- ...     mean_func, schema="id long, v double").show()  # doctest: +SKIP
199+ >>> df.groupby("id").applyInPandas(
200+ ...     mean_func, schema="id long, v double"
201+ ... ).sort("id").show()
200202 +---+---+
201203 | id|  v|
202204 +---+---+
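For intuition only (not part of the diff): `applyInPandas` hands each group to the function as a `pandas.DataFrame`, with the grouping key as a tuple, exactly as the `mean_func` example above describes. A plain-pandas sketch of the same computation, assuming the same five rows and no Spark at all:

    import pandas as pd

    # Local analogue of the grouped mean_func doctest above: each group's rows
    # arrive as a DataFrame, and the grouping key as a one-element tuple.
    pdf = pd.DataFrame({"id": [1, 1, 2, 2, 2], "v": [1.0, 2.0, 3.0, 5.0, 10.0]})

    def mean_func(key, group):
        return pd.DataFrame([key + (group.v.mean(),)], columns=["id", "v"])

    parts = [mean_func((gid,), gdf) for gid, gdf in pdf.groupby("id")]
    print(pd.concat(parts, ignore_index=True))  # id 1 -> v 1.5, id 2 -> v 6.0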
@@ -209,34 +211,36 @@ def applyInPandas(
209211 ...     # of 'id' and 'ceil(df.v / 2)' for the current group
210212 ...     return pd.DataFrame([key + (pdf.v.sum(),)])
211213 ...
212- >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
213- ...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()  # doctest: +SKIP
214+ >>> df.groupby(df.id, sf.ceil(df.v / 2)).applyInPandas(
215+ ...     sum_func, schema="id long, `ceil(v / 2)` long, v double"
216+ ... ).sort("id", "v").show()
214217 +---+-----------+----+
215218 | id|ceil(v / 2)|   v|
216219 +---+-----------+----+
217- |  2|          5|10.0|
218220 |  1|          1| 3.0|
219- |  2|          3| 5.0|
220221 |  2|          2| 3.0|
222+ |  2|          3| 5.0|
223+ |  2|          5|10.0|
221224 +---+-----------+----+
222225
223226 The function can also take and return an iterator of `pandas.DataFrame` using type
224227 hints.
225228
226- >>> from typing import Iterator  # doctest: +SKIP
229+ >>> from typing import Iterator
227230 >>> df = spark.createDataFrame(
228231 ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
229- ...     ("id", "v"))  # doctest: +SKIP
232+ ...     ("id", "v"))
230233 >>> def filter_func(
231234 ...     batches: Iterator[pd.DataFrame]
232- ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
235+ ... ) -> Iterator[pd.DataFrame]:
233236 ...     for batch in batches:
234237 ...         # Process and yield each batch independently
235238 ...         filtered = batch[batch['v'] > 2.0]
236239 ...         if not filtered.empty:
237240 ...             yield filtered[['v']]
238241 >>> df.groupby("id").applyInPandas(
239- ...     filter_func, schema="v double").show()  # doctest: +SKIP
242+ ...     filter_func, schema="v double"
243+ ... ).sort("v").show()
240244 +----+
241245 |   v|
242246 +----+
@@ -250,25 +254,26 @@ def applyInPandas(
250254 be passed as the second argument. The grouping key(s) will be passed as a tuple of numpy
251255 data types. The data will still be passed in as an iterator of `pandas.DataFrame`.
252256
253- >>> from typing import Iterator, Tuple, Any  # doctest: +SKIP
257+ >>> from typing import Iterator, Tuple, Any
254258 >>> def transform_func(
255259 ...     key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
256- ... ) -> Iterator[pd.DataFrame]:  # doctest: +SKIP
260+ ... ) -> Iterator[pd.DataFrame]:
257261 ...     for batch in batches:
258262 ...         # Yield transformed results for each batch
259263 ...         result = batch.assign(id=key[0], v_doubled=batch['v'] * 2)
260264 ...         yield result[['id', 'v_doubled']]
261265 >>> df.groupby("id").applyInPandas(
262- ...     transform_func, schema="id long, v_doubled double").show()  # doctest: +SKIP
263- +---+----------+
264- | id|v_doubled |
265- +---+----------+
266- |  1|       2.0|
267- |  1|       4.0|
268- |  2|       6.0|
269- |  2|      10.0|
270- |  2|      20.0|
271- +---+----------+
266+ ...     transform_func, schema="id long, v_doubled double"
267+ ... ).sort("id", "v_doubled").show()
268+ +---+---------+
269+ | id|v_doubled|
270+ +---+---------+
271+ |  1|      2.0|
272+ |  1|      4.0|
273+ |  2|      6.0|
274+ |  2|     10.0|
275+ |  2|     20.0|
276+ +---+---------+
272277
273278 Notes
274279 -----
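Again for intuition only: in the iterator variant above, each group reaches the UDF as an iterator of `pandas.DataFrame` batches plus the grouping key. A local sketch of the same `transform_func` logic, feeding one hypothetical batch for the group with key `(2,)`:

    import pandas as pd
    from typing import Any, Iterator, Tuple

    def transform_func(
        key: Tuple[Any, ...], batches: Iterator[pd.DataFrame]
    ) -> Iterator[pd.DataFrame]:
        for batch in batches:
            # Yield transformed results for each batch, as in the doctest above.
            result = batch.assign(id=key[0], v_doubled=batch["v"] * 2)
            yield result[["id", "v_doubled"]]

    one_group = pd.DataFrame({"v": [3.0, 5.0, 10.0]})  # rows of the id=2 group
    out = pd.concat(transform_func((2,), iter([one_group])), ignore_index=True)
    print(out)  # v_doubled becomes 6.0, 10.0, 20.0, all with id=2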
@@ -1187,8 +1192,14 @@ def _test() -> None:
11871192 import doctest
11881193 from pyspark.sql import SparkSession
11891194 import pyspark.sql.pandas.group_ops
1195+ from pyspark.testing.utils import have_pandas, have_pyarrow
11901196
11911197 globs = pyspark.sql.pandas.group_ops.__dict__.copy()
1198+
1199+ if not have_pandas or not have_pyarrow:
1200+     del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.apply.__doc__
1201+     del pyspark.sql.pandas.group_ops.PandasGroupedOpsMixin.applyInPandas.__doc__
1202+
11921203 spark = SparkSession.builder.master("local[4]").appName("sql.pandas.group tests").getOrCreate()
11931204 globs["spark"] = spark
11941205 (failure_count, test_count) = doctest.testmod(
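The `_test()` hunk above gates the pandas-dependent doctests by deleting the docstrings when pandas or PyArrow is unavailable, so `doctest.testmod` simply finds no examples to run on those methods. A minimal self-contained sketch of the same gating pattern, with a hypothetical function and dependency check:

    import doctest
    import importlib.util

    def needs_pandas():
        """
        >>> needs_pandas()
        42
        """
        return 42

    # If the optional dependency is missing, drop the docstring so doctest has
    # nothing to collect for this function (mirrors the `del ...__doc__` above).
    if importlib.util.find_spec("pandas") is None:
        needs_pandas.__doc__ = None

    if __name__ == "__main__":
        doctest.testmod()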