@@ -166,6 +166,41 @@ def _validate_kneighbors_bounds(self, n_neighbors, query_is_train, X):
166166 f"n_samples = { X .shape [0 ]} " # include n_samples for common tests
167167 )
168168
169+ def _process_classification_targets (self , y ):
170+ """Process classification targets and set class-related attributes."""
171+ import numpy as np
172+
173+ # Handle shape processing
174+ shape = getattr (y , "shape" , None )
175+ self ._shape = shape if shape is not None else y .shape
176+
177+ if y .ndim == 1 or y .ndim == 2 and y .shape [1 ] == 1 :
178+ self .outputs_2d_ = False
179+ y = y .reshape ((- 1 , 1 ))
180+ else :
181+ self .outputs_2d_ = True
182+
183+ # Process classes
184+ self .classes_ = []
185+ self ._y = np .empty (y .shape , dtype = int )
186+ for k in range (self ._y .shape [1 ]):
187+ classes , self ._y [:, k ] = np .unique (y [:, k ], return_inverse = True )
188+ self .classes_ .append (classes )
189+
190+ if not self .outputs_2d_ :
191+ self .classes_ = self .classes_ [0 ]
192+ self ._y = self ._y .ravel ()
193+
194+ return y
195+
196+ def _process_regression_targets (self , y ):
197+ """Process regression targets and set shape-related attributes."""
198+ # Handle shape processing for regression
199+ shape = getattr (y , "shape" , None )
200+ self ._shape = shape if shape is not None else y .shape
201+ self ._y = y
202+ return y
203+
169204 def _fit_validation (self , X , y = None ):
170205 if sklearn_check_version ("1.2" ):
171206 self ._validate_params ()
0 commit comments