diff --git a/onedal/datatypes/_data_conversion.py b/onedal/datatypes/_data_conversion.py
index 0caac10884..4f9080ed6a 100644
--- a/onedal/datatypes/_data_conversion.py
+++ b/onedal/datatypes/_data_conversion.py
@@ -31,10 +31,10 @@ import dpctl.tensor as dpt
 
 
-def _apply_and_pass(func, *args):
+def _apply_and_pass(func, *args, **kwargs):
     if len(args) == 1:
-        return func(args[0])
-    return tuple(map(func, args))
+        return func(args[0], **kwargs)
+    return tuple(func(arg, **kwargs) for arg in args)
 
 
 def from_table(*args):
@@ -58,7 +62,7 @@ def to_table(*args):
 if _is_dpc_backend:
     from ..common._policy import _HostInteropPolicy
 
-    def _convert_to_supported(policy, *data):
+    def _convert_to_supported(policy, *data, xp=np):
         def func(x):
             return x
 
@@ -70,13 +74,13 @@ def func(x):
         device = policy._queue.sycl_device
 
        def convert_or_pass(x):
-            if (x is not None) and (x.dtype == np.float64):
+            if (x is not None) and (x.dtype == xp.float64):
                 warnings.warn(
                     "Data will be converted into float32 from "
                     "float64 because device does not support it",
                     RuntimeWarning,
                 )
-                return x.astype(np.float32)
+                return xp.astype(x, xp.float32)
             else:
                 return x
 
@@ -87,7 +91,7 @@ def convert_or_pass(x):
 
 else:
 
-    def _convert_to_supported(policy, *data):
+    def _convert_to_supported(policy, *data, xp=np):
         def func(x):
             return x
diff --git a/onedal/decomposition/incremental_pca.py b/onedal/decomposition/incremental_pca.py
index 7199c1e1c2..bc57d03779 100644
--- a/onedal/decomposition/incremental_pca.py
+++ b/onedal/decomposition/incremental_pca.py
@@ -23,6 +23,8 @@
 from .pca import BasePCA
 
 
+# TODO:
+# update IncrementalPCA for the revised BasePCA method signatures (new `xp` argument).
 class IncrementalPCA(BasePCA):
     """
     Incremental estimator for PCA based on oneDAL implementation.
diff --git a/onedal/decomposition/pca.py b/onedal/decomposition/pca.py
index 7d9e38deab..03c45f77e5 100644
--- a/onedal/decomposition/pca.py
+++ b/onedal/decomposition/pca.py
@@ -23,6 +23,12 @@
 from ..common._base import BaseEstimator
 from ..datatypes import _convert_to_supported, from_table, to_table
+from ..utils._array_api import (
+    _asarray,
+    _convert_to_numpy,
+    get_namespace,
+    sklearn_array_api_dispatch,
+)
 
 
 class BasePCA(BaseEstimator, metaclass=ABCMeta):
@@ -42,13 +48,13 @@ def __init__(
         self.is_deterministic = is_deterministic
         self.whiten = whiten
 
-    def _get_onedal_params(self, data, stage=None):
+    def _get_onedal_params(self, data, xp, stage=None):
         if stage is None:
             n_components = self._resolve_n_components_for_training(data.shape)
         elif stage == "predict":
             n_components = self.n_components_
         return {
-            "fptype": "float" if data.dtype == np.float32 else "double",
+            "fptype": "float" if data.dtype == xp.float32 else "double",
             "method": self.method,
             "n_components": n_components,
             "is_deterministic": self.is_deterministic,
@@ -95,6 +101,8 @@ def _resolve_n_components_for_result(self, shape_tuple):
         elif self.n_components == "mle":
             return _infer_dimension(self.explained_variance_, shape_tuple[0])
         elif 0.0 < self.n_components < 1.0:
+            # TODO:
+            # check `stable_cumsum` and `np.searchsorted` support for Array API inputs.
            ratio_cumsum = stable_cumsum(self.explained_variance_ratio_)
             return np.searchsorted(ratio_cumsum, self.n_components, side="right") + 1
         elif isinstance(self.n_components, float) and self.n_components == 1.0:
@@ -102,46 +110,65 @@
         else:
             return self.n_components
 
-    def _compute_noise_variance(self, n_components, n_sf_min):
+    def _compute_noise_variance(self, xp, n_components, n_sf_min):
         if n_components < n_sf_min:
             if len(self.explained_variance_) == n_sf_min:
                 return self.explained_variance_[n_components:].mean()
             elif len(self.explained_variance_) < n_sf_min:
                 # TODO Rename variances_ to var_ to align with sklearn/sklearnex IncrementalPCA
+                # TODO:
+                # check `xp.sum` support across Array API namespaces.
                 if hasattr(self, "variances_"):
-                    resid_var = self.variances_.sum()
+                    resid_var = xp.sum(self.variances_)
                 elif hasattr(self, "var_"):
-                    resid_var = self.var_.sum()
+                    resid_var = xp.sum(self.var_)
 
-                resid_var -= self.explained_variance_.sum()
+                resid_var -= xp.sum(self.explained_variance_)
                 return resid_var / (n_sf_min - n_components)
         else:
             return 0.0
 
-    def _create_model(self):
+    def _create_model(self, xp):
         m = self._get_backend("decomposition", "dim_reduction", "model")
-        m.eigenvectors = to_table(self.components_)
-        m.means = to_table(self.mean_)
+        m.eigenvectors = to_table(_convert_to_numpy(self.components_, xp=xp))
+        m.means = to_table(_convert_to_numpy(self.mean_, xp=xp))
         if self.whiten:
-            m.eigenvalues = to_table(self.explained_variance_)
+            m.eigenvalues = to_table(_convert_to_numpy(self.explained_variance_, xp=xp))
         self._onedal_model = m
         return m
 
-    def predict(self, X, queue=None):
+    def _predict(self, X, xp, queue=None):
         policy = self._get_policy(queue, X)
-        model = self._create_model()
+        model = self._create_model(xp)
 
-        X = _convert_to_supported(policy, X)
-        params = self._get_onedal_params(X, stage="predict")
+        X = _convert_to_supported(policy, X, xp=xp)
+        params = self._get_onedal_params(X, xp, stage="predict")
         result = self._get_backend(
-            "decomposition", "dim_reduction", "infer", policy, params, model, to_table(X)
+            "decomposition",
+            "dim_reduction",
+            "infer",
+            policy,
+            params,
+            model,
+            to_table(_convert_to_numpy(X, xp=xp)),
         )
-        return from_table(result.transformed_data)
+        # Since `from_table` data management is enabled only for the NumPy host,
+        # copy the NumPy host output into an xp namespace array.
+        return _asarray(from_table(result.transformed_data), xp=xp, sycl_queue=queue)
+
+    def predict(self, X, queue=None):
+        xp, is_array_api_compliant = get_namespace(X)
+        # TODO: revisit queue retrieval for non-SYCL inputs.
+        queue = getattr(X, "sycl_queue", queue)
+        return self._predict(X, xp, queue)
 
 
 class PCA(BasePCA):
-    def fit(self, X, y=None, queue=None):
+    @sklearn_array_api_dispatch()
+    def _fit(self, X, xp, is_array_api_compliant, y=None, queue=None):
         n_samples, n_features = X.shape
         n_sf_min = min(n_samples, n_features)
         self._validate_n_components(self.n_components, n_samples, n_features)
@@ -149,23 +176,50 @@
         policy = self._get_policy(queue, X)
         # TODO: investigate why np.ndarray with OWNDATA=FALSE flag
         # fails to be converted to oneDAL table
+        # TODO:
+        # check whether this affects only NumPy inputs.
        if isinstance(X, np.ndarray) and not X.flags["OWNDATA"]:
             X = X.copy()
-        X = _convert_to_supported(policy, X)
+        X = _convert_to_supported(policy, X, xp=xp)
 
-        params = self._get_onedal_params(X)
+        params = self._get_onedal_params(X, xp)
         result = self._get_backend(
-            "decomposition", "dim_reduction", "train", policy, params, to_table(X)
+            "decomposition",
+            "dim_reduction",
+            "train",
+            policy,
+            params,
+            to_table(_convert_to_numpy(X, xp=xp)),
         )
 
-        self.mean_ = from_table(result.means).ravel()
-        self.variances_ = from_table(result.variances)
-        self.components_ = from_table(result.eigenvectors)
-        self.singular_values_ = from_table(result.singular_values).ravel()
-        self.explained_variance_ = np.maximum(from_table(result.eigenvalues).ravel(), 0)
-        self.explained_variance_ratio_ = from_table(
-            result.explained_variances_ratio
-        ).ravel()
+        # Since `from_table` data management is enabled only for the NumPy host,
+        # copy the NumPy host outputs into xp namespace arrays.
+        self.mean_ = _asarray(
+            from_table(result.means).reshape(-1), xp=xp, sycl_queue=queue
+        )
+        self.variances_ = _asarray(from_table(result.variances), xp=xp, sycl_queue=queue)
+        self.components_ = _asarray(
+            from_table(result.eigenvectors), xp=xp, sycl_queue=queue
+        )
+        self.singular_values_ = _asarray(
+            from_table(result.singular_values).reshape(-1), xp=xp, sycl_queue=queue
+        )
+        # TODO:
+        # check elementwise `maximum` support across Array API namespaces.
+        self.explained_variance_ = xp.maximum(
+            _asarray(
+                from_table(result.eigenvalues).reshape(-1), xp=xp, sycl_queue=queue
+            ),
+            0.0,
+        )
+        self.explained_variance_ratio_ = _asarray(
+            from_table(result.explained_variances_ratio).reshape(-1),
+            xp=xp,
+            sycl_queue=queue,
+        )
 
         self.n_samples_ = n_samples
         self.n_features_ = n_features
 
@@ -175,8 +229,10 @@
         n_components = self._resolve_n_components_for_result(X.shape)
         self.n_components_ = n_components
-        self.noise_variance_ = self._compute_noise_variance(n_components, n_sf_min)
+        self.noise_variance_ = self._compute_noise_variance(xp, n_components, n_sf_min)
 
+        # TODO:
+        # check that slicing fitted attributes works for Array API inputs.
         if n_components < params["n_components"]:
             self.explained_variance_ = self.explained_variance_[:n_components]
             self.components_ = self.components_[:n_components]
@@ -184,3 +240,9 @@
             self.explained_variance_ratio_ = self.explained_variance_ratio_[:n_components]
 
         return self
+
+    def fit(self, X, y=None, queue=None):
+        xp, is_array_api_compliant = get_namespace(X)
+        # TODO: revisit queue retrieval for non-SYCL inputs.
+        queue = getattr(X, "sycl_queue", queue)
+        return self._fit(X, xp, is_array_api_compliant, y, queue)
diff --git a/onedal/decomposition/tests/test_pca.py b/onedal/decomposition/tests/test_pca.py
new file mode 100644
index 0000000000..1a2fd915f0
--- /dev/null
+++ b/onedal/decomposition/tests/test_pca.py
@@ -0,0 +1,17 @@
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# TODO:
+# TBD.
diff --git a/onedal/utils/_array_api.py b/onedal/utils/_array_api.py
index 8208406a28..a4e5b83d2c 100644
--- a/onedal/utils/_array_api.py
+++ b/onedal/utils/_array_api.py
@@ -17,12 +17,33 @@
 """Tools to support array_api."""
 
 from collections.abc import Iterable
+from functools import wraps
+
+import numpy as np
+from sklearn import config_context, get_config
+
+from daal4py.sklearn._utils import get_dtype
+from daal4py.sklearn._utils import make2d as d4p_make2d
 
 from ._dpep_helpers import dpctl_available, dpnp_available
 
 if dpctl_available:
+    import dpctl.tensor as dpt
     from dpctl.tensor import usm_ndarray
 
+
+# TODO:
+# move to Array API module.
+# TODO:
+# def make2d(arg, xp=None, is_array_api_compliant=None):
+def make2d(arg, xp=None):
+    if xp and not _is_numpy_namespace(xp) and arg.ndim == 1:
+        return xp.reshape(arg, (arg.size, 1))
+    # TODO:
+    # reimplement via is_array_api_compliant usage.
+    return d4p_make2d(arg)
+
+
 if dpnp_available:
     import dpnp
 
@@ -38,6 +59,20 @@ def _convert_to_dpnp(array):
     return array
 
 
+def _convert_to_numpy(array, xp):
+    """Convert `array` into a NumPy ndarray on the CPU."""
+    xp_name = xp.__name__
+
+    if dpctl_available and xp_name in {
+        "dpctl.tensor",
+    }:
+        return dpt.to_numpy(array)
+    elif dpnp_available and isinstance(array, dpnp.ndarray):
+        return dpnp.asnumpy(array)
+    else:
+        return _asarray(array, xp)
+
+
 def _asarray(data, xp, *args, **kwargs):
     """Converted input object to array format of xp namespace provided."""
     if hasattr(data, "__array_namespace__"):
@@ -63,19 +98,76 @@ def _get_sycl_namespace(*arrays):
     """Get namespace of sycl arrays."""
 
     # sycl support designed to work regardless of array_api_dispatch sklearn global value
-    sycl_type = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")}
+    sua_iface = {type(x): x for x in arrays if hasattr(x, "__sycl_usm_array_interface__")}
 
-    if len(sycl_type) > 1:
-        raise ValueError(f"Multiple SYCL types for array inputs: {sycl_type}")
+    if len(sua_iface) > 1:
+        raise ValueError(f"Multiple SYCL types for array inputs: {sua_iface}")
 
-    if sycl_type:
-        (X,) = sycl_type.values()
+    if sua_iface:
+        (X,) = sua_iface.values()
 
         if hasattr(X, "__array_namespace__"):
-            return sycl_type, X.__array_namespace__(), True
+            return sua_iface, X.__array_namespace__(), True
         elif dpnp_available and isinstance(X, dpnp.ndarray):
-            return sycl_type, dpnp, False
+            return sua_iface, dpnp, False
         else:
-            raise ValueError(f"SYCL type not recognized: {sycl_type}")
+            raise ValueError(f"SYCL type not recognized: {sua_iface}")
+
+    return sua_iface, None, False
+
+
+# TODO:
+# replace this hardcoded flag with a scikit-learn version/config check.
+sklearn_array_api_version = True
 
-    return sycl_type, None, False
 
+def sklearn_array_api_dispatch(freefunc=False):
+    """Run the wrapped callable with sklearn's `array_api_dispatch` config enabled."""
+
+    def decorator(func):
+        def wrapper_impl(obj, *args, **kwargs):
+            # if sklearn_array_api_version and not get_config()["array_api_dispatch"]:
+            if sklearn_array_api_version:
+                with config_context(array_api_dispatch=True):
+                    return func(obj, *args, **kwargs)
+            return func(obj, *args, **kwargs)
+
+        if freefunc:
+
+            @wraps(func)
+            def wrapper_free(*args, **kwargs):
+                return wrapper_impl(None, *args, **kwargs)
+
+            return wrapper_free
+
+        @wraps(func)
+        def wrapper_with_self(self, *args, **kwargs):
+            return wrapper_impl(self, *args, **kwargs)
+
+        return wrapper_with_self
+
+    return decorator
+
+
+def get_namespace(*arrays):
+    """Get namespace of arrays.
+
+    Returns the namespace of SYCL USM array inputs when present;
+    otherwise falls back to the NumPy namespace.
+
+    Parameters
+    ----------
+    *arrays : array objects
+        Array objects.
+
+    Returns
+    -------
+    namespace : module
+        Namespace shared by array objects.
+
+    is_array_api : bool
+        True if the arrays are containers that implement the Array API spec.
+    """
+    sua_iface, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
+
+    if sua_iface:
+        return xp, is_array_api_compliant
+    else:
+        return np, True
diff --git a/sklearnex/_device_offload.py b/sklearnex/_device_offload.py
index 8b52b3c395..bef91649ec 100644
--- a/sklearnex/_device_offload.py
+++ b/sklearnex/_device_offload.py
@@ -16,13 +16,14 @@
 
 from functools import wraps
 
+from daal4py.sklearn._utils import sklearn_check_version
 from onedal._device_offload import (
     _copy_to_usm,
     _get_global_queue,
     _transfer_to_host,
     dpnp_available,
 )
-from onedal.utils._array_api import _asarray, _is_numpy_namespace
+from onedal.utils._array_api import _asarray
 
 if dpnp_available:
     import dpnp
@@ -76,7 +77,7 @@ def dispatch(obj, method_name, branches, *args, **kwargs):
             return branches[backend](obj, *hostargs, **hostkwargs, queue=q)
         if backend == "sklearn":
             if (
-                "array_api_dispatch" in get_config()
+                sklearn_check_version("1.4")
                 and get_config()["array_api_dispatch"]
                 and "array_api_support" in obj._get_tags()
                 and obj._get_tags()["array_api_support"]
diff --git a/sklearnex/dispatcher.py b/sklearnex/dispatcher.py
index a4a62556f6..9e3601ff14 100644
--- a/sklearnex/dispatcher.py
+++ b/sklearnex/dispatcher.py
@@ -128,6 +128,9 @@ def get_patch_map_core(preview=False):
     from ._config import get_config as get_config_sklearnex
     from ._config import set_config as set_config_sklearnex
 
+    if sklearn_check_version("1.4"):
+        import sklearn.utils._array_api as _array_api_module
+
     if sklearn_check_version("1.2.1"):
         from .utils.parallel import _FuncWrapper as _FuncWrapper_sklearnex
     else:
@@ -165,6 +168,10 @@
     from .svm import NuSVC as NuSVC_sklearnex
     from .svm import NuSVR as NuSVR_sklearnex
 
+    if sklearn_check_version("1.4"):
+        from .utils._array_api import _convert_to_numpy as _convert_to_numpy_sklearnex
+        from .utils._array_api import get_namespace as get_namespace_sklearnex
+
     # DBSCAN
     mapping.pop("dbscan")
     mapping["dbscan"] = [[(cluster_module, "DBSCAN", DBSCAN_sklearnex), None]]
@@ -440,6 +447,24 @@
     mapping["_funcwrapper"] = [
         [(parallel_module, "_FuncWrapper", _FuncWrapper_sklearnex), None]
     ]
+    if sklearn_check_version("1.4"):
+        # Necessary for array_api support
+        mapping["get_namespace"] = [
+            [
+                (
+                    _array_api_module,
+                    "get_namespace",
+                    get_namespace_sklearnex,
+                ),
+                None,
+            ]
+        ]
+        mapping["_convert_to_numpy"] = [
+            [
+                (_array_api_module, "_convert_to_numpy", _convert_to_numpy_sklearnex),
+                None,
+            ]
+        ]
     return mapping
diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py
index 9c383abaab..9e424fb159 100644
--- a/sklearnex/tests/test_memory_usage.py
+++ b/sklearnex/tests/test_memory_usage.py
@@ -45,6 +45,7 @@
 
 CPU_SKIP_LIST = (
+    "_convert_to_numpy",  # additional memory allocation is expected, proportional to the input data
     "TSNE",  # too slow for using in testing on common data size
     "config_context",  # does not malloc
     "get_config",  # does not malloc
@@ -59,6 +60,7 @@
 )
 
 GPU_SKIP_LIST = (
+    "_convert_to_numpy",  # additional memory allocation is expected, proportional to the input data
     "TSNE",  # too slow for using in testing on common data size
     "RandomForestRegressor",  # too slow for using in testing on common data size
     "KMeans",  # does not support GPU offloading
diff --git a/sklearnex/tests/test_patching.py b/sklearnex/tests/test_patching.py
index 897f19172d..c7ec3b1475 100755
--- a/sklearnex/tests/test_patching.py
+++ b/sklearnex/tests/test_patching.py
@@ -307,10 +307,13 @@ def list_all_attr(string):
 
     module_map = {i: i for i in sklearnex__all__.intersection(sklearn__all__)}
 
-    # _assert_all_finite patches an internal sklearn function which isn't
-    # exposed via __all__ in sklearn. It is a special case where this rule
-    # is not applied (e.g. it is grandfathered in).
+    # _assert_all_finite, _convert_to_numpy, and get_namespace patch internal
+    # sklearn functions which aren't exposed via __all__ in sklearn. They are
+    # special cases where this rule is not applied (e.g. grandfathered in).
     del patched["_assert_all_finite"]
+    if sklearn_check_version("1.4"):
+        del patched["_convert_to_numpy"]
+        del patched["get_namespace"]
 
     # remove all scikit-learn-intelex-only estimators
     for i in patched.copy():
diff --git a/sklearnex/utils/_array_api.py b/sklearnex/utils/_array_api.py
index bc30be5021..901e242851 100644
--- a/sklearnex/utils/_array_api.py
+++ b/sklearnex/utils/_array_api.py
@@ -19,64 +19,128 @@
 import numpy as np
 
 from daal4py.sklearn._utils import sklearn_check_version
-from onedal.utils._array_api import _get_sycl_namespace
+from onedal.utils._array_api import _asarray, _get_sycl_namespace
 
-if sklearn_check_version("1.2"):
+if sklearn_check_version("1.4"):
     from sklearn.utils._array_api import get_namespace as sklearn_get_namespace
+    from sklearn.utils._array_api import _convert_to_numpy as _sklearn_convert_to_numpy
 
+from onedal._device_offload import dpctl_available, dpnp_available
 
-def get_namespace(*arrays):
-    """Get namespace of arrays.
+if dpctl_available:
+    import dpctl.tensor as dpt
 
-    Introspect `arrays` arguments and return their common Array API
-    compatible namespace object, if any. NumPy 1.22 and later can
-    construct such containers using the `numpy.array_api` namespace
-    for instance.
+if dpnp_available:
+    import dpnp
 
-    This function will return the namespace of SYCL-related arrays
-    which define the __sycl_usm_array_interface__ attribute
-    regardless of array_api support, the configuration of
-    array_api_dispatch, or scikit-learn version.
 
-    See: https://numpy.org/neps/nep-0047-array-api-standard.html
+def _convert_to_numpy(array, xp):
+    """Convert `array` into a NumPy ndarray on the CPU."""
+    xp_name = xp.__name__
 
-    If `arrays` are regular numpy arrays, an instance of the
-    `_NumPyApiWrapper` compatibility wrapper is returned instead.
+    if dpctl_available and xp_name in {
+        "dpctl.tensor",
+    }:
+        return dpt.to_numpy(array)
+    elif dpnp_available and isinstance(array, dpnp.ndarray):
+        return dpnp.asnumpy(array)
+    elif sklearn_check_version("1.4"):
+        return _sklearn_convert_to_numpy(array, xp)
+    else:
+        return _asarray(array, xp)
 
-    Namespace support is not enabled by default. To enabled it
-    call:
 
-      sklearn.set_config(array_api_dispatch=True)
+if sklearn_check_version("1.5"):
 
-    or:
+    def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
+        """Get namespace of arrays.
 
-      with sklearn.config_context(array_api_dispatch=True):
-          # your code here
+        Extends stock scikit-learn's `get_namespace` primitive to support DPCTL
+        usm_ndarrays and DPNP ndarrays.
+        If there are no DPCTL usm_ndarray or DPNP ndarray inputs and the backend
+        scikit-learn version supports Array API, the result of
+        :obj:`sklearn.utils._array_api.get_namespace` is returned.
+        Otherwise, the NumPy namespace is returned.
-    Otherwise an instance of the `_NumPyApiWrapper`
-    compatibility wrapper is always returned irrespective of
-    the fact that arrays implement the `__array_namespace__`
-    protocol or not.
+        Designed to work for numpy.ndarray, DPCTL usm_ndarrays, and DPNP ndarrays
+        without `array-api-compat` or backend scikit-learn Array API support.
 
-    Parameters
-    ----------
-    *arrays : array objects
-        Array objects.
+        For full documentation refer to :obj:`sklearn.utils._array_api.get_namespace`.
 
-    Returns
-    -------
-    namespace : module
-        Namespace shared by array objects.
+        Parameters
+        ----------
+        *arrays : array objects
+            Array objects.
 
-    is_array_api : bool
-        True of the arrays are containers that implement the Array API spec.
-    """
+        remove_none : bool, default=True
+            Whether to ignore None objects passed in arrays.
 
-    sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
+        remove_types : tuple or list, default=(str,)
+            Types to ignore in the arrays.
 
-    if sycl_type:
-        return xp, is_array_api_compliant
-    elif sklearn_check_version("1.2"):
-        return sklearn_get_namespace(*arrays)
-    else:
-        return np, False
+        xp : module, default=None
+            Precomputed array namespace module. When passed, typically from a caller
+            that has already performed inspection of its own inputs, skips array
+            namespace inspection.
+
+        Returns
+        -------
+        namespace : module
+            Namespace shared by array objects.
+
+        is_array_api : bool
+            True if the arrays are containers that implement the Array API spec.
+        """
+
+        usm_iface, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace(
+            *arrays
+        )
+
+        if usm_iface:
+            return xp_sycl_namespace, is_array_api_compliant
+        elif sklearn_check_version("1.4"):
+            return sklearn_get_namespace(
+                *arrays, remove_none=remove_none, remove_types=remove_types, xp=xp
+            )
+        else:
+            return np, False
+
+else:
+
+    def get_namespace(*arrays):
+        """Get namespace of arrays.
+
+        Extends stock scikit-learn's `get_namespace` primitive to support DPCTL
+        usm_ndarrays and DPNP ndarrays.
+        If there are no DPCTL usm_ndarray or DPNP ndarray inputs and the backend
+        scikit-learn version supports Array API, the result of
+        :obj:`sklearn.utils._array_api.get_namespace(*arrays)` is returned.
+        Otherwise, the NumPy namespace is returned.
+
+        Designed to work for numpy.ndarray, DPCTL usm_ndarrays, and DPNP ndarrays
+        without `array-api-compat` or backend scikit-learn Array API support.
+
+        For full documentation refer to :obj:`sklearn.utils._array_api.get_namespace`.
+
+        Parameters
+        ----------
+        *arrays : array objects
+            Array objects.
+
+        Returns
+        -------
+        namespace : module
+            Namespace shared by array objects.
+
+        is_array_api : bool
+            True if the arrays are containers that implement the Array API spec.
+        """
+
+        usm_iface, xp_sycl_namespace, is_array_api_compliant = _get_sycl_namespace(
+            *arrays
+        )
+
+        if usm_iface:
+            return xp_sycl_namespace, is_array_api_compliant
+        elif sklearn_check_version("1.4"):
+            return sklearn_get_namespace(*arrays)
+        else:
+            return np, False
diff --git a/sklearnex/utils/tests/test_array_api.py b/sklearnex/utils/tests/test_array_api.py
new file mode 100644
index 0000000000..bc4756ba84
--- /dev/null
+++ b/sklearnex/utils/tests/test_array_api.py
@@ -0,0 +1,182 @@
+# ==============================================================================
+# Copyright 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import pytest
+from numpy.testing import assert_allclose
+
+from daal4py.sklearn._utils import sklearn_check_version
+from onedal.tests.utils._dataframes_support import (
+    _convert_to_dataframe,
+    get_dataframes_and_queues,
+)
+
+# TODO:
+# add a test suite for dpctl.tensor, dpnp.ndarray, and numpy.ndarray inputs
+# without `config_context(array_api_dispatch=True)`.
+# TODO:
+# extend for DPNP inputs.
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.4"),
+    reason="Array API dispatch requires scikit-learn >= 1.4",
+)
+@pytest.mark.parametrize(
+    "dataframe,queue",
+    get_dataframes_and_queues(
+        dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu"
+    ),
+)
+def test_get_namespace_with_config_context(dataframe, queue):
+    """Test `get_namespace` with `array_api_dispatch` enabled."""
+    from sklearnex import config_context
+    from sklearnex.utils._array_api import get_namespace
+
+    array_api_compat = pytest.importorskip("array_api_compat")
+
+    X_np = np.asarray([[1, 2, 3]])
+    X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe)
+
+    with config_context(array_api_dispatch=True):
+        xp_out, is_array_api_compliant = get_namespace(X)
+        assert is_array_api_compliant
+        if dataframe not in ("numpy", "array_api"):
+            # For `numpy.ndarray` and `array-api-strict` inputs `get_namespace`
+            # returns specific wrapper classes rather than the raw
+            # `array_api_compat.get_namespace` output.
+            assert xp_out == array_api_compat.get_namespace(X)
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.4"),
+    reason="Array API dispatch requires scikit-learn >= 1.4",
+)
+@pytest.mark.parametrize(
+    "dataframe,queue",
+    get_dataframes_and_queues(
+        dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu"
+    ),
+)
+def test_get_namespace_with_patching(dataframe, queue):
+    """Test `get_namespace` with `array_api_dispatch` and
+    `patch_sklearn` enabled.
+    """
+    array_api_compat = pytest.importorskip("array_api_compat")
+
+    from sklearnex import patch_sklearn
+
+    patch_sklearn()
+
+    from sklearn import config_context
+    from sklearn.utils._array_api import get_namespace
+
+    X_np = np.asarray([[1, 2, 3]])
+    X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe)
+
+    with config_context(array_api_dispatch=True):
+        xp_out, is_array_api_compliant = get_namespace(X)
+        assert is_array_api_compliant
+        if dataframe not in ("numpy", "array_api"):
+            # For `numpy.ndarray` and `array-api-strict` inputs `get_namespace`
+            # returns specific wrapper classes rather than the raw
+            # `array_api_compat.get_namespace` output.
+            assert xp_out == array_api_compat.get_namespace(X)
+
+
+@pytest.mark.skipif(
+    not sklearn_check_version("1.4"),
+    reason="Array API dispatch requires scikit-learn >= 1.4",
+)
+@pytest.mark.parametrize(
+    "dataframe,queue",
+    get_dataframes_and_queues(
+        dataframe_filter_="dpctl,array_api", device_filter_="cpu,gpu"
+    ),
+)
+def test_convert_to_numpy_with_patching(dataframe, queue):
+    """Test `_convert_to_numpy` with `array_api_dispatch` and
+    `patch_sklearn` enabled.
+ """ + pytest.importorskip("array_api_compat") + + from sklearnex import patch_sklearn + + patch_sklearn() + + from sklearn import config_context + from sklearn.utils._array_api import _convert_to_numpy, get_namespace + + X_np = np.asarray([[1, 2, 3]]) + X = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe) + + with config_context(array_api_dispatch=True): + xp, _ = get_namespace(X) + x_np = _convert_to_numpy(X, xp) + assert type(X_np) == type(x_np) + assert_allclose(X_np, x_np) + + +@pytest.mark.skipif( + not sklearn_check_version("1.4"), + reason="Array API dispatch requires sklearn 1.4 version", +) +@pytest.mark.parametrize( + "dataframe,queue", + get_dataframes_and_queues( + dataframe_filter_="numpy,dpctl,array_api", device_filter_="cpu,gpu" + ), +) +@pytest.mark.parametrize( + "dtype", + [ + pytest.param(np.float32, id=np.dtype(np.float32).name), + pytest.param(np.float64, id=np.dtype(np.float64).name), + ], +) +def test_validate_data_with_patching(dataframe, queue, dtype): + """Test validate_data with `array_api_dispatch` and + `patch_sklearn` enabled. + """ + pytest.importorskip("array_api_compat") + + from sklearnex import patch_sklearn + + patch_sklearn() + + from sklearn import config_context + from sklearn.base import BaseEstimator + + if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data + else: + validate_data = BaseEstimator._validate_data + + from sklearn.utils._array_api import _convert_to_numpy, get_namespace + + X_np = np.asarray([[1, 2, 3], [4, 5, 6]], dtype=dtype) + X_df = _convert_to_dataframe(X_np, sycl_queue=queue, target_df=dataframe) + with config_context(array_api_dispatch=True): + est = BaseEstimator() + xp, _ = get_namespace(X_df) + X_df_res = validate_data( + est, X_df, accept_sparse="csr", dtype=[xp.float64, xp.float32] + ) + assert type(X_df) == type(X_df_res) + if dataframe != "numpy": + # _convert_to_numpy not designed for numpy.ndarray inputs. + assert_allclose(_convert_to_numpy(X_df, xp), _convert_to_numpy(X_df_res, xp)) + else: + assert_allclose(X_df, X_df_res)