
fix: Count full dataset in DatasetShard.__len__ #3414

Open · wants to merge 12 commits into master
Conversation

@jeffkinnison (Contributor) commented May 21, 2023

When the new test runs against the previous code, DatasetShard.__len__ returns 256 rather than the full dataset size defined in the test. Summing the sizes of the batches returned by DatasetIterator.iter_batches yields the full dataset size.
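A minimal sketch of the approach, assuming the shard exposes a Ray DatasetIterator as self.epoch_iter and that pandas batches are an acceptable format for counting (the actual PR code may differ in details):

def __len__(self):
    # Sum batch sizes over one full pass of the iterator rather than
    # relying on a fixed window size (which previously reported 256).
    return sum(
        len(batch)
        for batch in self.epoch_iter.iter_batches(batch_format="pandas")
    )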

@jeffkinnison jeffkinnison marked this pull request as ready for review May 21, 2023 21:39
github-actions bot commented May 21, 2023

Unit Test Results

  6 files  ±0    6 suites  ±0   1h 19m 29s ⏱️ + 8m 4s
33 tests ±0  29 ✔️ ±0    4 💤 ±0  0 ±0 
99 runs  ±0  87 ✔️ ±0  12 💤 ±0  0 ±0 

Results for commit 8ff5f4d. ± Comparison against base commit cb37535.

♻️ This comment has been updated with latest results.

Comment on lines +283 to +284
# Sum over all batches when using a DatasetIterator
count = sum(map(lambda b: b.count(), self.epoch_iter.iter_batches()))
Contributor

A couple of quick questions:

  1. When would we hit the TypeError? Is it the case that self.epoch_iter is not a DatasetIterator object?
  2. Does this end up doing a pass over the entire dataset, calling count on each batch and then summing the totals?

Contributor Author

  1. Yes, that's for backwards compatibility. I'm leaving it in for now just in case, but I think 2.3+ should follow the DatasetIterator path.
  2. Yes. iter_batches returns an iterator over the batches of one epoch. It batches by block if batch_size isn't provided, but we pass batch_size to the actual batcher in a separate iter_batches call (see the sketch below).
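Illustrative only (not Ludwig code): the same DatasetIterator consumed both ways, assuming Ray >= 2.3 and a small in-memory dataset:

import pandas as pd
import ray

ds = ray.data.from_pandas(pd.DataFrame({"x": range(1000)}))
it = ds.iterator()

# Without batch_size, batches follow the underlying block boundaries,
# which is enough for counting rows.
total_rows = sum(len(b) for b in it.iter_batches(batch_format="pandas"))

# With an explicit batch_size, the iterator yields fixed-size batches,
# which is what the actual batcher consumes in its own iter_batches call.
for batch in it.iter_batches(batch_size=128, batch_format="pandas"):
    pass  # hand off to the training loop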

Contributor

  1. Makes sense!
  2. How expensive is that pass over the data?

Contributor Author

  2. The cost scales with dataset size and object store memory, but running time across a number of different dataset sizes seems to be roughly the same as the other methods of getting the dataset size from Ray Data.

Comment on lines +292 to +293
# Sum over all batches when using a DatasetIterator
count = sum(map(lambda b: b.count(), self.epoch_iter.iter_batches()))
Contributor

At this point, what is the type of epoch_iter?

Contributor Author (@jeffkinnison, May 24, 2023)

If we get to that point, we're using a DatasetIterator with no underlying pipeline. AFAIK this should not happen, but if it does this catchall should prevent a crash. We typically see PipelinedDataIterator, which has the _base_dataset_pipeline attribute. That lets us call count directly on a pipeline object, similar to how counting worked with Ray<2.3.
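A hedged sketch of that branching, mirroring the diff above (it assumes _base_dataset_pipeline yields countable pipeline windows when passed to next, and uses pandas batches for the catch-all path):

def _shard_size(epoch_iter):
    # PipelinedDataIterator exposes its pipeline through a private Ray
    # attribute; treat its absence as the bare-DatasetIterator catch-all.
    pipeline_iter = getattr(epoch_iter, "_base_dataset_pipeline", None)
    if pipeline_iter is not None:
        # Count directly on the pipeline, as with Ray < 2.3.
        return next(pipeline_iter).count()
    # Catch-all: sum batch sizes from the iterator itself.
    return sum(len(b) for b in epoch_iter.iter_batches(batch_format="pandas"))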

# expand_paths returns two lists, so get the first element of each
read_path = read_path[0]
file_size = file_size[0]
try:
Collaborator

Rather than doing a try/except, can you just make this conditional on the Ray version? We have a lot of examples of this in the code already:

https://github.com/ludwig-ai/ludwig/blob/master/ludwig/backend/ray.py#L365

Contributor Author

Sure, I'll convert it.

}
import ray

if version.parse(ray.__version__) >= version.parse("2.3.0"):
Collaborator

You can set this to a variable at the top; the parsing is somewhat costly, so it's better to avoid doing it more than once.
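Something like the following, parsed once at module import and reused at every call site (the _ray_230 flag already appears later in this diff; packaging.version is assumed, as in the snippet above):

from packaging import version

import ray

# Parse the Ray version a single time when the module is imported.
_ray_230 = version.parse(ray.__version__) >= version.parse("2.3.0")

# Call sites then branch on the flag instead of re-parsing the version string:
# if _ray_230:
#     count = next(self.epoch_iter._base_dataset_pipeline).count()
# else:
#     count = next(self.epoch_iter).count()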

import ray

if version.parse(ray.__version__) >= version.parse("2.3.0"):

Collaborator

Is 2.3.0 the right minimum version? It was working before with 2.3, right? So maybe we just want to check against 2.4?

Contributor Author

Will do. The new mlflow integration also works with 2.3, but we can introduce it with the 2.4 bump.

count = next(self.epoch_iter._base_dataset_pipeline).count()
else:
count = next(self.epoch_iter).count()
except TypeError:
Collaborator

This is a bit hard to follow. Can we make this conditional on the Ray version instead of catching a TypeError?

Contributor Author

I can break this out into more fine-grained conditions.

pipeline = next(self.dataset_epoch_iterator)
try:
pipeline = next(self.dataset_epoch_iterator)
except TypeError:
Collaborator

Similar to the above: it isn't clear why or when the TypeError happens. I would prefer checking the Ray version, or at least leaving a comment explaining what can cause the error.

pipeline = pipeline.map_batches(augment_batch, batch_size=batch_size, batch_format="pandas")

for batch in pipeline.iter_batches(prefetch_blocks=0, batch_size=batch_size, batch_format="pandas"):
if _ray_230:
batch = augment_batch(batch)
Collaborator

Oof, not good. This does augmentation in the worker process. We definitely don't want that, as it could slow down training by sucking up CPU cycles. Why does the map_batches call above no longer work? Can we fix it?

Contributor Author

DatasetIterator objects don't have a map_batches method, and calling map_batches on the iterator's underlying _base_dataset_pipeline leads to downstream problems with reusing the pipeline that aren't resolved by calling repeat. I'll play around with this and see if we can get the augmentation back into a pipeline.

# Explicitly raise a RuntimeError if an error is encountered during a Ray trial.
# NOTE: Cascading the exception with "raise _ from e" still results in hanging.
raise RuntimeError(f"Encountered Ray Tune error: {e}")
raise RuntimeError(f"Encountered Ray Tune error: {traceback.format_exc()}")
Collaborator

You can change this to:

raise RuntimeError(...) from e

That way you get the whole traceback without needing to turn it into a string.
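For example (run_trial is a placeholder for the Tune call being wrapped):

try:
    run_trial()  # placeholder for the call that raised inside the trial
except Exception as e:
    # Chaining keeps the original traceback attached to the raised error,
    # so there is no need to format it into the message string.
    raise RuntimeError("Encountered Ray Tune error") from e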
