Skip to content

Commit 25b3930

Browse files
committed
fix: fix shape
1 parent 8cd6f2b commit 25b3930

File tree

5 files changed

+67
-0
lines changed

5 files changed

+67
-0
lines changed

onedal/neighbors/neighbors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,13 @@ def _get_onedal_params(self, X, y=None, n_neighbors=None):
7070

7171
fptype = np.float64
7272

73+
# _fit_method should be set by sklearnex level before calling oneDAL
74+
if not hasattr(self, "_fit_method") or self._fit_method is None:
75+
raise ValueError(
76+
"_fit_method must be set by sklearnex level before calling oneDAL. "
77+
"This indicates improper usage - oneDAL neighbors should not be called directly."
78+
)
79+
7380
return {
7481
"fptype": fptype,
7582
"vote_weights": "uniform" if weights == "uniform" else "distance",

sklearnex/neighbors/common.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,41 @@ def _validate_kneighbors_bounds(self, n_neighbors, query_is_train, X):
166166
f"n_samples = {X.shape[0]}" # include n_samples for common tests
167167
)
168168

169+
def _process_classification_targets(self, y):
170+
"""Process classification targets and set class-related attributes."""
171+
import numpy as np
172+
173+
# Handle shape processing
174+
shape = getattr(y, "shape", None)
175+
self._shape = shape if shape is not None else y.shape
176+
177+
if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1:
178+
self.outputs_2d_ = False
179+
y = y.reshape((-1, 1))
180+
else:
181+
self.outputs_2d_ = True
182+
183+
# Process classes
184+
self.classes_ = []
185+
self._y = np.empty(y.shape, dtype=int)
186+
for k in range(self._y.shape[1]):
187+
classes, self._y[:, k] = np.unique(y[:, k], return_inverse=True)
188+
self.classes_.append(classes)
189+
190+
if not self.outputs_2d_:
191+
self.classes_ = self.classes_[0]
192+
self._y = self._y.ravel()
193+
194+
return y
195+
196+
def _process_regression_targets(self, y):
197+
"""Process regression targets and set shape-related attributes."""
198+
# Handle shape processing for regression
199+
shape = getattr(y, "shape", None)
200+
self._shape = shape if shape is not None else y.shape
201+
self._y = y
202+
return y
203+
169204
def _fit_validation(self, X, y=None):
170205
if sklearn_check_version("1.2"):
171206
self._validate_params()

sklearnex/neighbors/knn_classification.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ def _onedal_fit(self, X, y, queue=None):
191191

192192
_check_classification_targets(y)
193193

194+
# Handle shape and class processing at sklearnex level
195+
y = self._process_classification_targets(y)
196+
194197
onedal_params = {
195198
"n_neighbors": self.n_neighbors,
196199
"weights": self.weights,
@@ -204,6 +207,13 @@ def _onedal_fit(self, X, y, queue=None):
204207
self._onedal_estimator.effective_metric_ = self.effective_metric_
205208
self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
206209
self._onedal_estimator._fit_method = self._fit_method
210+
211+
# Set shape and class attributes on the onedal estimator
212+
self._onedal_estimator._shape = self._shape
213+
self._onedal_estimator.classes_ = self.classes_
214+
self._onedal_estimator._y = self._y
215+
self._onedal_estimator.outputs_2d_ = self.outputs_2d_
216+
207217
self._onedal_estimator.fit(X, y, queue=queue)
208218

209219
self._save_attributes()

sklearnex/neighbors/knn_regression.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def _onedal_fit(self, X, y, queue=None):
158158
# Parse auto method
159159
self._fit_method = self._parse_auto_method(self.algorithm, X.shape[0], X.shape[1])
160160

161+
# Handle shape processing at sklearnex level
162+
y = self._process_regression_targets(y)
163+
161164
onedal_params = {
162165
"n_neighbors": self.n_neighbors,
163166
"weights": self.weights,
@@ -171,6 +174,11 @@ def _onedal_fit(self, X, y, queue=None):
171174
self._onedal_estimator.effective_metric_ = self.effective_metric_
172175
self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
173176
self._onedal_estimator._fit_method = self._fit_method
177+
178+
# Set shape attributes on the onedal estimator
179+
self._onedal_estimator._shape = self._shape
180+
self._onedal_estimator._y = self._y
181+
174182
self._onedal_estimator.fit(X, y, queue=queue)
175183

176184
self._save_attributes()

sklearnex/neighbors/knn_unsupervised.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,9 @@ def _onedal_fit(self, X, y=None, queue=None):
149149
# Parse auto method
150150
self._fit_method = self._parse_auto_method(self.algorithm, X.shape[0], X.shape[1])
151151

152+
# Set basic attributes for unsupervised
153+
self.classes_ = None
154+
152155
onedal_params = {
153156
"n_neighbors": self.n_neighbors,
154157
"algorithm": self.algorithm,
@@ -161,6 +164,10 @@ def _onedal_fit(self, X, y=None, queue=None):
161164
self._onedal_estimator.effective_metric_ = self.effective_metric_
162165
self._onedal_estimator.effective_metric_params_ = self.effective_metric_params_
163166
self._onedal_estimator._fit_method = self._fit_method
167+
168+
# Set attributes on the onedal estimator
169+
self._onedal_estimator.classes_ = self.classes_
170+
164171
self._onedal_estimator.fit(X, y, queue=queue)
165172

166173
self._save_attributes()

0 commit comments

Comments
 (0)