Skip to content

Commit 6676f56

Browse files
trivialfis, galipremsagar, and hcho3
authored
[bp] Handle cudf.pandas proxy objects properly (dmlc#11014) (dmlc#11018)
---------
Co-authored-by: GALI PREM SAGAR <[email protected]>
Co-authored-by: Hyunsu Cho <[email protected]>
1 parent 7b675da commit 6676f56

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

python-package/xgboost/core.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2488,6 +2488,7 @@ def inplace_predict(
24882488
_arrow_transform,
24892489
_is_arrow,
24902490
_is_cudf_df,
2491+
_is_cudf_pandas,
24912492
_is_cupy_alike,
24922493
_is_list,
24932494
_is_np_array_like,
@@ -2497,6 +2498,9 @@ def inplace_predict(
24972498
_transform_pandas_df,
24982499
)
24992500

2501+
if _is_cudf_pandas(data):
2502+
data = data._fsproxy_fast # pylint: disable=protected-access
2503+
25002504
enable_categorical = True
25012505
if _is_arrow(data):
25022506
data = _arrow_transform(data)

python-package/xgboost/data.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -835,6 +835,16 @@ def _is_cudf_df(data: DataType) -> bool:
835835
return lazy_isinstance(data, "cudf.core.dataframe", "DataFrame")
836836

837837

838+
def _is_cudf_pandas(data: DataType) -> bool:
839+
"""Must go before both pandas and cudf checks."""
840+
return (
841+
lazy_isinstance(data, "pandas.core.frame", "DataFrame")
842+
or lazy_isinstance(data, "pandas.core.series", "Series")
843+
) and lazy_isinstance(
844+
type(data), "cudf.pandas.fast_slow_proxy", "_FastSlowProxyMeta"
845+
)
846+
847+
838848
def _get_cudf_cat_predicate() -> Callable[[Any], bool]:
839849
try:
840850
from cudf import CategoricalDtype
@@ -1187,6 +1197,8 @@ def dispatch_data_backend(
11871197
)
11881198
if _is_arrow(data):
11891199
data = _arrow_transform(data)
1200+
if _is_cudf_pandas(data):
1201+
data = data._fsproxy_fast # pylint: disable=protected-access
11901202
if _is_pandas_series(data):
11911203
import pandas as pd
11921204

@@ -1327,6 +1339,8 @@ def dispatch_meta_backend(
13271339
return
13281340
if _is_arrow(data):
13291341
data = _arrow_transform(data)
1342+
if _is_cudf_pandas(data):
1343+
data = data._fsproxy_fast # pylint: disable=protected-access
13301344
if _is_pandas_df(data):
13311345
_meta_from_pandas_df(data, name, dtype=dtype, handle=handle)
13321346
return
@@ -1398,6 +1412,8 @@ def _proxy_transform(
13981412
feature_types: Optional[FeatureTypes],
13991413
enable_categorical: bool,
14001414
) -> TransformedData:
1415+
if _is_cudf_pandas(data):
1416+
data = data._fsproxy_fast # pylint: disable=protected-access
14011417
if _is_cudf_df(data) or _is_cudf_ser(data):
14021418
return _transform_cudf_df(
14031419
data, feature_names, feature_types, enable_categorical

0 commit comments

Comments
 (0)