@@ -10,24 +10,22 @@ from libcpp cimport bool
10
10
from libcpp.memory cimport unique_ptr
11
11
from libcpp.utility cimport move
12
12
13
- import pylibcudf
13
+ import pylibcudf as plc
14
14
15
15
import cudf
16
- from cudf._lib.types import LIBCUDF_TO_SUPPORTED_NUMPY_TYPES
17
16
from cudf.core.dtypes import ListDtype, StructDtype
17
+ from cudf._lib.types import PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES
18
+ from cudf._lib.types cimport dtype_from_column_view, underlying_type_t_type_id
18
19
from cudf.core.missing import NA, NaT
19
20
20
- cimport pylibcudf.libcudf.types as libcudf_types
21
21
# We currently need this cimport because some of the implementations here
22
22
# access the c_obj of the scalar, and because we need to be able to call
23
23
# pylibcudf.Scalar.from_libcudf. Both of those are temporarily acceptable until
24
24
# DeviceScalar is phased out entirely from cuDF Cython (at which point
25
25
# cudf.Scalar will be directly backed by pylibcudf.Scalar).
26
- from pylibcudf cimport Scalar as plc_Scalar
26
+ from pylibcudf cimport Scalar as plc_Scalar, type_id as plc_TypeID
27
27
from pylibcudf.libcudf.scalar.scalar cimport list_scalar, scalar, struct_scalar
28
28
29
- from cudf._lib.types cimport dtype_from_column_view, underlying_type_t_type_id
30
-
31
29
32
30
def _replace_nested (obj , check , replacement ):
33
31
if isinstance (obj, list ):
@@ -62,12 +60,12 @@ def gather_metadata(dtypes):
62
60
"""
63
61
out = []
64
62
for name, dtype in dtypes.items():
65
- v = pylibcudf .interop.ColumnMetadata(name)
63
+ v = plc .interop.ColumnMetadata(name)
66
64
if isinstance (dtype, cudf.StructDtype):
67
65
v.children_meta = gather_metadata(dtype.fields)
68
66
elif isinstance (dtype, cudf.ListDtype):
69
67
# Offsets column is unnamed and has no children
70
- v.children_meta.append(pylibcudf .interop.ColumnMetadata(" " ))
68
+ v.children_meta.append(plc .interop.ColumnMetadata(" " ))
71
69
v.children_meta.extend(
72
70
gather_metadata({" " : dtype.element_type})
73
71
)
@@ -81,7 +79,7 @@ cdef class DeviceScalar:
81
79
# that from_unique_ptr is implemented is probably dereferencing this in an
82
80
# invalid state. See what the best way to fix that is.
83
81
def __cinit__ (self , *args , **kwargs ):
84
- self .c_value = pylibcudf .Scalar.__new__ (pylibcudf .Scalar)
82
+ self .c_value = plc .Scalar.__new__ (plc .Scalar)
85
83
86
84
def __init__ (self , value , dtype ):
87
85
"""
@@ -127,20 +125,20 @@ cdef class DeviceScalar:
127
125
pa_array = pa.array([pa.scalar(value, type = pa_type)])
128
126
129
127
pa_table = pa.Table.from_arrays([pa_array], names = [" " ])
130
- table = pylibcudf .interop.from_arrow(pa_table)
128
+ table = plc .interop.from_arrow(pa_table)
131
129
132
130
column = table.columns()[0 ]
133
131
if isinstance (dtype, cudf.core.dtypes.DecimalDtype):
134
132
if isinstance (dtype, cudf.core.dtypes.Decimal32Dtype):
135
- column = pylibcudf .unary.cast(
136
- column, pylibcudf .DataType(pylibcudf .TypeId.DECIMAL32, - dtype.scale)
133
+ column = plc .unary.cast(
134
+ column, plc .DataType(plc .TypeId.DECIMAL32, - dtype.scale)
137
135
)
138
136
elif isinstance (dtype, cudf.core.dtypes.Decimal64Dtype):
139
- column = pylibcudf .unary.cast(
140
- column, pylibcudf .DataType(pylibcudf .TypeId.DECIMAL64, - dtype.scale)
137
+ column = plc .unary.cast(
138
+ column, plc .DataType(plc .TypeId.DECIMAL64, - dtype.scale)
141
139
)
142
140
143
- self .c_value = pylibcudf .copying.get_element(column, 0 )
141
+ self .c_value = plc .copying.get_element(column, 0 )
144
142
self ._dtype = dtype
145
143
146
144
def _to_host_scalar (self ):
@@ -150,7 +148,7 @@ cdef class DeviceScalar:
150
148
null_type = NaT if is_datetime or is_timedelta else NA
151
149
152
150
metadata = gather_metadata({" " : self .dtype})[0 ]
153
- ps = pylibcudf .interop.to_arrow(self .c_value, metadata)
151
+ ps = plc .interop.to_arrow(self .c_value, metadata)
154
152
if not ps.is_valid:
155
153
return null_type
156
154
@@ -225,43 +223,42 @@ cdef class DeviceScalar:
225
223
return s
226
224
227
225
cdef void _set_dtype(self , dtype = None ):
228
- cdef libcudf_types.data_type cdtype = self .get_raw_ptr()[0 ].type()
229
-
226
+ cdef plc_TypeID cdtype_id = self .c_value.type().id()
230
227
if dtype is not None :
231
228
self ._dtype = dtype
232
- elif cdtype.id() in {
233
- libcudf_types.type_id .DECIMAL32,
234
- libcudf_types.type_id .DECIMAL64,
235
- libcudf_types.type_id .DECIMAL128,
229
+ elif cdtype_id in {
230
+ plc_TypeID .DECIMAL32,
231
+ plc_TypeID .DECIMAL64,
232
+ plc_TypeID .DECIMAL128,
236
233
}:
237
234
raise TypeError (
238
235
" Must pass a dtype when constructing from a fixed-point scalar"
239
236
)
240
- elif cdtype.id() == libcudf_types.type_id .STRUCT:
237
+ elif cdtype_id == plc_TypeID .STRUCT:
241
238
struct_table_view = (< struct_scalar* > self .get_raw_ptr())[0 ].view()
242
239
self ._dtype = StructDtype({
243
240
str (i): dtype_from_column_view(struct_table_view.column(i))
244
241
for i in range (struct_table_view.num_columns())
245
242
})
246
- elif cdtype.id() == libcudf_types.type_id .LIST:
243
+ elif cdtype_id == plc_TypeID .LIST:
247
244
if (
248
245
< list_scalar* > self .get_raw_ptr()
249
- )[0 ].view().type().id() == libcudf_types.type_id .LIST:
246
+ )[0 ].view().type().id() == plc_TypeID .LIST:
250
247
self ._dtype = dtype_from_column_view(
251
248
(< list_scalar* > self .get_raw_ptr())[0 ].view()
252
249
)
253
250
else :
254
251
self ._dtype = ListDtype(
255
- LIBCUDF_TO_SUPPORTED_NUMPY_TYPES [
252
+ PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES [
256
253
< underlying_type_t_type_id> (
257
254
(< list_scalar* > self .get_raw_ptr())[0 ]
258
255
.view().type().id()
259
256
)
260
257
]
261
258
)
262
259
else :
263
- self ._dtype = LIBCUDF_TO_SUPPORTED_NUMPY_TYPES [
264
- < underlying_type_t_type_id> (cdtype.id() )
260
+ self ._dtype = PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES [
261
+ < underlying_type_t_type_id> (cdtype_id )
265
262
]
266
263
267
264
0 commit comments