@@ -266,9 +266,10 @@ def test_mismatching_default_values():
266
266
# Check that round-tripping works when different versions of scikit-learn
267
267
# have different default values for the same hyper-parameter.
268
268
class SklearnEstimatorV1 (BaseEstimator ):
269
- def __init__ (self , foo = 42 ):
269
+ def __init__ (self , n_init = 42 , solver = "lbfgs" ):
270
270
super ().__init__ ()
271
- self .foo = foo
271
+ self .n_init = n_init
272
+ self .solver = solver
272
273
273
274
def fit (self , X , y ):
274
275
self .bar_ = 42
@@ -277,9 +278,10 @@ def transform(self, X):
277
278
return X
278
279
279
280
class SklearnEstimatorV2 (BaseEstimator ):
280
- def __init__ (self , foo = "auto" ):
281
+ def __init__ (self , n_init = "auto" , solver = "auto" ):
281
282
super ().__init__ ()
282
- self .foo = foo
283
+ self .n_init = n_init
284
+ self .solver = solver
283
285
284
286
def fit (self , X , y ):
285
287
self .bar_ = 42
@@ -288,17 +290,25 @@ def transform(self, X):
288
290
return X
289
291
290
292
class CuMLEstimator (UniversalBase ):
293
+ # Setting these by hand as `import_cpu_model` is a bit tricky to use with
294
+ # classes defined in a test function.
291
295
_cpu_model_class = SklearnEstimatorV1
292
- _cpu_hyperparams = ["foo " ]
296
+ _cpu_hyperparams = ["n_init" , "solver " ]
293
297
294
298
@device_interop_preparation
295
299
def __init__ (
296
- self , foo = 42 , handle = None , verbose = False , output_type = None
300
+ self ,
301
+ n_init = 42 ,
302
+ solver = "qn" ,
303
+ handle = None ,
304
+ verbose = False ,
305
+ output_type = None ,
297
306
):
298
307
super ().__init__ (
299
308
handle = handle , verbose = verbose , output_type = output_type
300
309
)
301
- self .foo = foo
310
+ self .n_init = n_init
311
+ self .solver = solver
302
312
303
313
@enable_device_interop
304
314
def fit (self , X , y , convert_dtype = True , sample_weight = None ):
@@ -309,7 +319,7 @@ def transform(self, X):
309
319
310
320
@classmethod
311
321
def _get_param_names (cls ):
312
- return super ()._get_param_names () + ["foo " ]
322
+ return super ()._get_param_names () + ["n_init" , "solver " ]
313
323
314
324
def get_attr_names (self ):
315
325
return ["bar_" ]
@@ -321,15 +331,15 @@ def get_attr_names(self):
321
331
cml = CuMLEstimator ()
322
332
assert_estimator_roundtrip (cml , SklearnEstimatorV1 , X , y = y , transform = True )
323
333
324
- cml = CuMLEstimator (foo = 12 )
334
+ cml = CuMLEstimator (n_init = 12 )
325
335
assert_estimator_roundtrip (cml , SklearnEstimatorV1 , X , y = y , transform = True )
326
336
327
337
# Check against v2 of the scikit-learn estimator
328
338
CuMLEstimator ._cpu_model_class = SklearnEstimatorV2
329
339
330
- # With explicit value for `foo`
331
- cml = CuMLEstimator (foo = 12 )
340
+ # XXX In this case should we pass "auto" or 42 to the scikit-learn constructor?
341
+ cml = CuMLEstimator ()
332
342
assert_estimator_roundtrip (cml , SklearnEstimatorV2 , X , y = y , transform = True )
333
343
334
- cml = CuMLEstimator ()
344
+ cml = CuMLEstimator (n_init = 12 )
335
345
assert_estimator_roundtrip (cml , SklearnEstimatorV2 , X , y = y , transform = True )
0 commit comments