Skip to content

Commit 493661a

Browse files
committed
Make things work? Extend tests
1 parent 6195dab commit 493661a

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

python/cuml/cuml/internals/base.pyx

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,14 +641,23 @@ class UniversalBase(Base):
641641
if kwargs:
642642
filtered_kwargs = kwargs
643643
else:
644+
# XXX Which way around should this be, get_params updated with kwargs
645+
# XXX or kwargs updated with get_params?
646+
all_kwargs = self.get_params()
647+
all_kwargs.update(self._full_kwargs)
648+
644649
filtered_kwargs = {}
645-
for keyword, arg in self._full_kwargs.items():
650+
for keyword, arg in all_kwargs.items():
651+
# These are cuml specific arguments that should not be passed
652+
# to scikit-learn
653+
if keyword in ("output_type", "handle"):
654+
continue
646655
if keyword in self._cpu_hyperparams:
647656
filtered_kwargs[keyword] = arg
648657
else:
649-
logger.info("Unused keyword parameter: {} "
650-
"during CPU estimator "
651-
"initialization".format(keyword))
658+
logger.debug("Unused keyword parameter: {} "
659+
"during CPU estimator "
660+
"initialization".format(keyword))
652661

653662
# initialize model
654663
self._cpu_model = self._cpu_model_class(**filtered_kwargs)

python/cuml/cuml/tests/test_sklearn_import_export.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,10 @@ def test_mismatching_default_values():
266266
# Check that round-tripping works when different versions of scikit-learn
267267
# have different default values for the same hyper-parameter.
268268
class SklearnEstimatorV1(BaseEstimator):
269-
def __init__(self, foo=42):
269+
def __init__(self, n_init=42, solver="lbfgs"):
270270
super().__init__()
271-
self.foo = foo
271+
self.n_init = n_init
272+
self.solver = solver
272273

273274
def fit(self, X, y):
274275
self.bar_ = 42
@@ -277,9 +278,10 @@ def transform(self, X):
277278
return X
278279

279280
class SklearnEstimatorV2(BaseEstimator):
280-
def __init__(self, foo="auto"):
281+
def __init__(self, n_init="auto", solver="auto"):
281282
super().__init__()
282-
self.foo = foo
283+
self.n_init = n_init
284+
self.solver = solver
283285

284286
def fit(self, X, y):
285287
self.bar_ = 42
@@ -288,17 +290,25 @@ def transform(self, X):
288290
return X
289291

290292
class CuMLEstimator(UniversalBase):
293+
# Setting these by hand as `import_cpu_model` is a bit tricky to use with
294+
# classes defined in a test function.
291295
_cpu_model_class = SklearnEstimatorV1
292-
_cpu_hyperparams = ["foo"]
296+
_cpu_hyperparams = ["n_init", "solver"]
293297

294298
@device_interop_preparation
295299
def __init__(
296-
self, foo=42, handle=None, verbose=False, output_type=None
300+
self,
301+
n_init=42,
302+
solver="qn",
303+
handle=None,
304+
verbose=False,
305+
output_type=None,
297306
):
298307
super().__init__(
299308
handle=handle, verbose=verbose, output_type=output_type
300309
)
301-
self.foo = foo
310+
self.n_init = n_init
311+
self.solver = solver
302312

303313
@enable_device_interop
304314
def fit(self, X, y, convert_dtype=True, sample_weight=None):
@@ -309,7 +319,7 @@ def transform(self, X):
309319

310320
@classmethod
311321
def _get_param_names(cls):
312-
return super()._get_param_names() + ["foo"]
322+
return super()._get_param_names() + ["n_init", "solver"]
313323

314324
def get_attr_names(self):
315325
return ["bar_"]
@@ -321,15 +331,15 @@ def get_attr_names(self):
321331
cml = CuMLEstimator()
322332
assert_estimator_roundtrip(cml, SklearnEstimatorV1, X, y=y, transform=True)
323333

324-
cml = CuMLEstimator(foo=12)
334+
cml = CuMLEstimator(n_init=12)
325335
assert_estimator_roundtrip(cml, SklearnEstimatorV1, X, y=y, transform=True)
326336

327337
# Check against v2 of the scikit-learn estimator
328338
CuMLEstimator._cpu_model_class = SklearnEstimatorV2
329339

330-
# With explicit value for `foo`
331-
cml = CuMLEstimator(foo=12)
340+
# XXX In this case should we pass "auto" or 42 to the scikit-learn constructor?
341+
cml = CuMLEstimator()
332342
assert_estimator_roundtrip(cml, SklearnEstimatorV2, X, y=y, transform=True)
333343

334-
cml = CuMLEstimator()
344+
cml = CuMLEstimator(n_init=12)
335345
assert_estimator_roundtrip(cml, SklearnEstimatorV2, X, y=y, transform=True)

0 commit comments

Comments
 (0)