Dask-ML ParallelPostFit prediction fails on empty partitions
Minimal Complete Verifiable Example:
import pandas as pd
import dask.dataframe as dd
from sklearn.linear_model import LogisticRegression
from dask_ml.wrappers import ParallelPostFit
from dask_ml.utils import assert_eq_ar  # dask-ml's array-comparison test helper

df = pd.DataFrame({"x": [1, 2, 3, 4, 5, 6, 7, 8], "y": [True, False] * 4})
ddf = dd.from_pandas(df, npartitions=4)

clf = ParallelPostFit(LogisticRegression())
clf = clf.fit(df[["x"]], df["y"])

# Filtering on x < 5 leaves trailing partitions empty.
ddf_with_empty_part = ddf[ddf.x < 5][["x"]]

result = clf.predict(ddf_with_empty_part).compute()
expected = clf.estimator.predict(ddf_with_empty_part.compute())
assert_eq_ar(result, expected)
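The empty partitions can be confirmed with a quick per-partition row count (the [2, 2, 0, 0] split assumes from_pandas divides the eight rows evenly across the four partitions):

ddf_with_empty_part.map_partitions(len).compute()
# values: [2, 2, 0, 0] -> the last two partitions are empty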
Traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [3], in <cell line: 1>()
----> 1 result.compute()
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/base.py:290, in DaskMethodsMixin.compute(self, **kwargs)
266 def compute(self, **kwargs):
267 """Compute this dask collection
268
269 This turns a lazy Dask collection into its in-memory equivalent.
(...)
288 dask.base.compute
289 """
--> 290 (result,) = compute(self, traverse=False, **kwargs)
291 return result
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/base.py:573, in compute(traverse, optimize_graph, scheduler, get, *args, **kwargs)
570 keys.append(x.__dask_keys__())
571 postcomputes.append(x.__dask_postcompute__())
--> 573 results = schedule(dsk, keys, **kwargs)
574 return repack([f(r, *a) for r, (f, a) in zip(results, postcomputes)])
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/threaded.py:81, in get(dsk, result, cache, num_workers, pool, **kwargs)
78 elif isinstance(pool, multiprocessing.pool.Pool):
79 pool = MultiprocessingPoolExecutor(pool)
---> 81 results = get_async(
82 pool.submit,
83 pool._max_workers,
84 dsk,
85 result,
86 cache=cache,
87 get_id=_thread_get_id,
88 pack_exception=pack_exception,
89 **kwargs,
90 )
92 # Cleanup pools associated to dead threads
93 with pools_lock:
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/local.py:506, in get_async(submit, num_workers, dsk, result, cache, get_id, rerun_exceptions_locally, pack_exception, raise_exception, callbacks, dumps, loads, chunksize, **kwargs)
504 _execute_task(task, data) # Re-execute locally
505 else:
--> 506 raise_exception(exc, tb)
507 res, worker_id = loads(res_info)
508 state["cache"][key] = res
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/local.py:314, in reraise(exc, tb)
312 if exc.__traceback__ is not tb:
313 raise exc.with_traceback(tb)
--> 314 raise exc
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/local.py:219, in execute_task(key, task_info, dumps, loads, get_id, pack_exception)
217 try:
218 task, data = loads(task_info)
--> 219 result = _execute_task(task, data)
220 id = get_id()
221 result = dumps((result, id))
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/optimization.py:969, in SubgraphCallable.__call__(self, *args)
967 if not len(args) == len(self.inkeys):
968 raise ValueError("Expected %d args, got %d" % (len(self.inkeys), len(args)))
--> 969 return core.get(self.dsk, self.outkey, dict(zip(self.inkeys, args)))
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/core.py:149, in get(dsk, out, cache)
147 for key in toposort(dsk):
148 task = dsk[key]
--> 149 result = _execute_task(task, cache)
150 cache[key] = result
151 result = _execute_task(out, cache)
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/core.py:119, in _execute_task(arg, cache, dsk)
115 func, args = arg[0], arg[1:]
116 # Note: Don't assign the subtask results to a variable. numpy detects
117 # temporaries by their reference count and can execute certain
118 # operations in-place.
--> 119 return func(*(_execute_task(a, cache) for a in args))
120 elif not ishashable(arg):
121 return arg
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/utils.py:39, in apply(func, args, kwargs)
37 def apply(func, args, kwargs=None):
38 if kwargs:
---> 39 return func(*args, **kwargs)
40 else:
41 return func(*args)
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/dask/dataframe/core.py:6259, in apply_and_enforce(*args, **kwargs)
6257 func = kwargs.pop("_func")
6258 meta = kwargs.pop("_meta")
-> 6259 df = func(*args, **kwargs)
6260 if is_dataframe_like(df) or is_series_like(df) or is_index_like(df):
6261 if not len(df):
File ~/dask_ml_dev/dask-ml/dask_ml/wrappers.py:630, in _predict(part, estimator)
629 def _predict(part, estimator):
--> 630 return estimator.predict(part)
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/linear_model/_base.py:425, in LinearClassifierMixin.predict(self, X)
411 def predict(self, X):
412 """
413 Predict class labels for samples in X.
414
(...)
423 Vector containing the class labels for each sample.
424 """
--> 425 scores = self.decision_function(X)
426 if len(scores.shape) == 1:
427 indices = (scores > 0).astype(int)
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/linear_model/_base.py:407, in LinearClassifierMixin.decision_function(self, X)
387 """
388 Predict confidence scores for samples.
389
(...)
403 this class would be predicted.
404 """
405 check_is_fitted(self)
--> 407 X = self._validate_data(X, accept_sparse="csr", reset=False)
408 scores = safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
409 return scores.ravel() if scores.shape[1] == 1 else scores
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/base.py:566, in BaseEstimator._validate_data(self, X, y, reset, validate_separately, **check_params)
564 raise ValueError("Validation should be done on X, y or both.")
565 elif not no_val_X and no_val_y:
--> 566 X = check_array(X, **check_params)
567 out = X
568 elif no_val_X and not no_val_y:
File /datasets/vjawa/miniconda3/envs/dask-ml-dev/lib/python3.8/site-packages/sklearn/utils/validation.py:805, in check_array(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)
803 n_samples = _num_samples(array)
804 if n_samples < ensure_min_samples:
--> 805 raise ValueError(
806 "Found array with %d sample(s) (shape=%s) while a"
807 " minimum of %d is required%s."
808 % (n_samples, array.shape, ensure_min_samples, context)
809 )
811 if ensure_min_features > 0 and array.ndim == 2:
812 n_features = array.shape[1]
ValueError: Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required.
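The final ValueError comes from scikit-learn's input validation rather than Dask itself: check_array uses ensure_min_samples=1 by default, so predicting on a zero-row partition fails even outside of Dask. A minimal reduction, using the clf fitted in the example above:

import numpy as np
clf.estimator.predict(np.empty((0, 1)))
# ValueError: Found array with 0 sample(s) (shape=(0, 1)) while a minimum of 1 is required.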
Anything else we need to know?:
Related Issue: dask-contrib/dask-sql#414
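Until ParallelPostFit guards against this, one user-side workaround is to short-circuit empty partitions before they reach the estimator. A minimal sketch, not dask-ml API (predict_part is a hypothetical helper; the output dtype is taken from estimator.classes_):

import pandas as pd

def predict_part(part, estimator):
    # Skip scikit-learn entirely on empty partitions and return an
    # empty Series with the dtype the estimator would have produced.
    if len(part) == 0:
        return pd.Series([], index=part.index, dtype=estimator.classes_.dtype)
    return pd.Series(estimator.predict(part), index=part.index)

result = ddf_with_empty_part.map_partitions(
    predict_part,
    clf.estimator,
    meta=pd.Series(dtype=clf.estimator.classes_.dtype),
).compute()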