Skip to content

Implement tensor.isin #2098

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dpctl/tensor/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ set(_reduction_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/reductions/sum.cpp
)
set(_sorting_sources
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/isin.cpp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't seem relating to sorting routine

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it uses common utilities with searchsorted (i.e., from rich_comparisons.hpp) which is why it lives there

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the code from rich_comparisons gets factored out, I can go ahead and move it elsewhere, I guess to _tensor_impl for now

${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_sort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/merge_argsort.cpp
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/sorting/radix_sort.cpp
Expand Down
2 changes: 2 additions & 0 deletions dpctl/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@
)
from ._searchsorted import searchsorted
from ._set_functions import (
isin,
unique_all,
unique_counts,
unique_inverse,
Expand Down Expand Up @@ -394,4 +395,5 @@
"top_k",
"dldevice_to_sycl_device",
"sycl_device_to_dldevice",
"isin",
]
10 changes: 5 additions & 5 deletions dpctl/tensor/_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
_empty_like_pair_orderK,
_empty_like_triple_orderK,
)
from dpctl.tensor._elementwise_common import (
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._type_utils import _can_cast
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._type_utils import (
_resolve_one_strong_one_weak_types,
_resolve_one_strong_two_weak_types,
Expand Down
89 changes: 6 additions & 83 deletions dpctl/tensor/_elementwise_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers

import numpy as np

import dpctl
import dpctl.memory as dpm
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_pair_orderK
from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_acceptance_fn_default_binary,
_acceptance_fn_default_unary,
_all_data_types,
_find_buf_dtype,
_find_buf_dtype2,
_find_buf_dtype_in_place_op,
_resolve_weak_types,
_to_device_supported_dtype,
)


Expand Down Expand Up @@ -289,78 +284,6 @@ def __call__(self, x, /, *, out=None, order="K"):
return out


def _get_queue_usm_type(o):
"""Return SYCL device where object `o` allocated memory, or None."""
if isinstance(o, dpt.usm_ndarray):
return o.sycl_queue, o.usm_type
elif hasattr(o, "__sycl_usm_array_interface__"):
try:
m = dpm.as_usm_memory(o)
return m.sycl_queue, m.get_usm_type()
except Exception:
return None, None
return None, None


def _get_dtype(o, dev):
if isinstance(o, dpt.usm_ndarray):
return o.dtype
if hasattr(o, "__sycl_usm_array_interface__"):
return dpt.asarray(o).dtype
if _is_buffer(o):
host_dt = np.array(o).dtype
dev_dt = _to_device_supported_dtype(host_dt, dev)
return dev_dt
if hasattr(o, "dtype"):
dev_dt = _to_device_supported_dtype(o.dtype, dev)
return dev_dt
if isinstance(o, bool):
return WeakBooleanType(o)
if isinstance(o, int):
return WeakIntegralType(o)
if isinstance(o, float):
return WeakFloatingType(o)
if isinstance(o, complex):
return WeakComplexType(o)
return np.object_


def _validate_dtype(dt) -> bool:
return isinstance(
dt,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
) or (
isinstance(dt, dpt.dtype)
and dt
in [
dpt.bool,
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
dpt.float16,
dpt.float32,
dpt.float64,
dpt.complex64,
dpt.complex128,
]
)


def _get_shape(o):
if isinstance(o, dpt.usm_ndarray):
return o.shape
if _is_buffer(o):
return memoryview(o).shape
if isinstance(o, numbers.Number):
return tuple()
return getattr(o, "shape", tuple())


class BinaryElementwiseFunc:
"""
Class that implements binary element-wise functions.
Expand Down
111 changes: 111 additions & 0 deletions dpctl/tensor/_scalar_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Data Parallel Control (dpctl)
#
# Copyright 2020-2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numbers

import numpy as np

import dpctl.memory as dpm
import dpctl.tensor as dpt
from dpctl.tensor._usmarray import _is_object_with_buffer_protocol as _is_buffer

from ._type_utils import (
WeakBooleanType,
WeakComplexType,
WeakFloatingType,
WeakIntegralType,
_to_device_supported_dtype,
)


def _get_queue_usm_type(o):
"""Return SYCL device where object `o` allocated memory, or None."""
if isinstance(o, dpt.usm_ndarray):
return o.sycl_queue, o.usm_type
elif hasattr(o, "__sycl_usm_array_interface__"):
try:
m = dpm.as_usm_memory(o)
return m.sycl_queue, m.get_usm_type()
except Exception:
return None, None
return None, None


def _get_dtype(o, dev):
if isinstance(o, dpt.usm_ndarray):
return o.dtype
if hasattr(o, "__sycl_usm_array_interface__"):
return dpt.asarray(o).dtype
if _is_buffer(o):
host_dt = np.array(o).dtype
dev_dt = _to_device_supported_dtype(host_dt, dev)
return dev_dt
if hasattr(o, "dtype"):
dev_dt = _to_device_supported_dtype(o.dtype, dev)
return dev_dt
if isinstance(o, bool):
return WeakBooleanType(o)
if isinstance(o, int):
return WeakIntegralType(o)
if isinstance(o, float):
return WeakFloatingType(o)
if isinstance(o, complex):
return WeakComplexType(o)
return np.object_


def _validate_dtype(dt) -> bool:
return isinstance(
dt,
(WeakBooleanType, WeakIntegralType, WeakFloatingType, WeakComplexType),
) or (
isinstance(dt, dpt.dtype)
and dt
in [
dpt.bool,
dpt.int8,
dpt.uint8,
dpt.int16,
dpt.uint16,
dpt.int32,
dpt.uint32,
dpt.int64,
dpt.uint64,
dpt.float16,
dpt.float32,
dpt.float64,
dpt.complex64,
dpt.complex128,
]
)


def _get_shape(o):
if isinstance(o, dpt.usm_ndarray):
return o.shape
if _is_buffer(o):
return memoryview(o).shape
if isinstance(o, numbers.Number):
return tuple()
return getattr(o, "shape", tuple())


__all__ = [
"_get_dtype",
"_get_queue_usm_type",
"_get_shape",
"_validate_dtype",
]
10 changes: 5 additions & 5 deletions dpctl/tensor/_search_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._tensor_impl as ti
from dpctl.tensor._elementwise_common import (
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._scalar_utils import (
_get_dtype,
_get_queue_usm_type,
_get_shape,
_validate_dtype,
)
from dpctl.tensor._manipulation_functions import _broadcast_shape_impl
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager

from ._copy_utils import _empty_like_orderK, _empty_like_triple_orderK
from ._type_utils import (
WeakBooleanType,
WeakComplexType,
Expand Down
Loading
Loading