fix: Count full dataset in DatasetShard.__len__ #3414
base: master
Conversation
# Sum over all batches when using a DatasetIterator
count = sum(map(lambda b: b.count(), self.epoch_iter.iter_batches()))
Couple of quick questions:
- When would we hit the TypeError? Is it in the case that self.epoch_iter is not a DatasetIterator object?
- Does this end up doing a pass over the entire data and call count on each batch, then sum that total?
- Yes, that's for backwards compatibility. I'm leaving it in for now just in case, but I think 2.3+ should follow the DatasetIterator path.
- Yes.
`iter_batches` returns an iterator over the batches of one epoch. It batches by block if `batch_size` isn't provided, but we pass `batch_size` to the actual batcher in a separate `iter_batches` call.
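As a minimal sketch of the counting approach, with a stand-in batch class in place of the real Ray Data batches (the names below are illustrative, not Ray's API):

```python
class FakeBatch:
    """Stand-in for a Ray Data batch; only the count() method is modeled."""

    def __init__(self, num_rows):
        self.num_rows = num_rows

    def count(self):
        # Number of rows in this batch.
        return self.num_rows


def shard_len(batches):
    # Sum over all batches, mirroring the one-liner in the diff:
    # sum(map(lambda b: b.count(), self.epoch_iter.iter_batches()))
    return sum(b.count() for b in batches)


print(shard_len([FakeBatch(256), FakeBatch(256), FakeBatch(128)]))  # 640
```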
- Makes sense!
- How expensive is that pass over the data?
- The cost is relative to dataset size and object store memory, but running time for a number of different dataset sizes seems to be roughly the same as the other methods of getting dataset size from Ray Data.
# Sum over all batches when using a DatasetIterator
count = sum(map(lambda b: b.count(), self.epoch_iter.iter_batches()))
At this point, what is the type of epoch_iter?
If we get to that point, we're using a `DatasetIterator` with no underlying pipeline. AFAIK this should not happen, but if it does, this catch-all should prevent a crash. We typically see `PipelinedDataIterator`, which has the `_base_dataset_pipeline` attribute. That lets us call `count` directly on a pipeline object, similar to how counting worked with Ray < 2.3.
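A sketch of that dispatch, with fake classes standing in for the Ray objects (the attribute and method names follow the discussion above; everything else is assumed):

```python
class FakePipeline:
    """Stand-in for a DatasetPipeline: iterating yields an object with count()."""

    def __init__(self, total_rows):
        self.total_rows = total_rows

    def __next__(self):
        return self

    def count(self):
        return self.total_rows


class FakeBatch:
    def __init__(self, num_rows):
        self.num_rows = num_rows

    def count(self):
        return self.num_rows


class FakePipelinedIterator:
    """Has _base_dataset_pipeline, like PipelinedDataIterator."""

    def __init__(self, total_rows):
        self._base_dataset_pipeline = FakePipeline(total_rows)


class FakePlainIterator:
    """A DatasetIterator with no underlying pipeline."""

    def __init__(self, batch_sizes):
        self.batch_sizes = batch_sizes

    def iter_batches(self):
        return (FakeBatch(n) for n in self.batch_sizes)


def shard_count(epoch_iter):
    base = getattr(epoch_iter, "_base_dataset_pipeline", None)
    if base is not None:
        # Typical case: count directly on the pipeline, as with Ray < 2.3.
        return next(base).count()
    # Catch-all: no underlying pipeline, so sum the batch counts
    # instead of crashing.
    return sum(b.count() for b in epoch_iter.iter_batches())


print(shard_count(FakePipelinedIterator(640)))     # 640
print(shard_count(FakePlainIterator([256, 256])))  # 512
```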
# expand_paths returns two lists, so get the first element of each
read_path = read_path[0]
file_size = file_size[0]
try:
Rather than do a try/except, can you just make this conditional on the Ray version? We have a lot of examples of this in the code already:
https://github.com/ludwig-ai/ludwig/blob/master/ludwig/backend/ray.py#L365
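The version-gating pattern being suggested looks roughly like this (the version string below is a stand-in; the real code imports ray and compares `ray.__version__`):

```python
from packaging import version

# Stand-in for ray.__version__; in the real code this comes from `import ray`.
ray_version = "2.3.1"

# version.parse compares release segments numerically, so "2.10.0"
# correctly sorts after "2.3.0" (plain string comparison would not).
if version.parse(ray_version) >= version.parse("2.3.0"):
    use_dataset_iterator = True   # Ray >= 2.3 code path
else:
    use_dataset_iterator = False  # legacy code path

print(use_dataset_iterator)  # True
```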
Sure, I'll convert it.
}
import ray

if version.parse(ray.__version__) >= version.parse("2.3.0"):
You can set this to a variable at the top; the parsing is somewhat costly, so it's better to avoid doing it more than once.
import ray

if version.parse(ray.__version__) >= version.parse("2.3.0"):
Is `2.3.0` the right min version? It was working before with 2.3, right? So maybe we just want to check against 2.4?
Will do. The new mlflow integration also works with 2.3, but we can introduce it with the 2.4 bump.
count = next(self.epoch_iter._base_dataset_pipeline).count()
else:
    count = next(self.epoch_iter).count()
except TypeError:
This is a bit hard to follow. Can we make this conditional on the Ray version instead of catching a TypeError?
I can break this out into more fine-grained conditions.
pipeline = next(self.dataset_epoch_iterator)
try:
    pipeline = next(self.dataset_epoch_iterator)
except TypeError:
Similar to above, it isn't clear why or when the TypeError happens. I would prefer checking the Ray version, or at least leaving a comment explaining what can cause the error.
pipeline = pipeline.map_batches(augment_batch, batch_size=batch_size, batch_format="pandas")

for batch in pipeline.iter_batches(prefetch_blocks=0, batch_size=batch_size, batch_format="pandas"):
    if _ray_230:
        batch = augment_batch(batch)
Oof, not good. This does augmentation in the worker process. We definitely don't want that, as it could slow down training by sucking up CPU cycles. Why does the map_batches call above no longer work? Can we fix it?
So `DatasetIterator` objects don't have a `map_batches` method, and calling `map_batches` on the underlying `_base_dataset_pipeline` in the iterator leads to downstream problems reusing the pipeline that aren't resolved by calling `repeat`. I'll play around with this and see if we can get it into a pipeline.
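To illustrate the placement issue without Ray, here is the same transform applied in the two positions under discussion (`augment_batch` below is a hypothetical stand-in; the real function is not shown in the PR):

```python
import pandas as pd


def augment_batch(batch: pd.DataFrame) -> pd.DataFrame:
    # Hypothetical augmentation: double a numeric column.
    out = batch.copy()
    out["x"] = out["x"] * 2
    return out


batches = [pd.DataFrame({"x": [1, 2]}), pd.DataFrame({"x": [3]})]

# (1) map_batches-style: the transform runs as a pipeline stage, so a
#     distributed runtime can execute it off the training worker.
pipelined = [augment_batch(b) for b in batches]

# (2) iter_batches-style: the training worker applies the transform while
#     consuming batches, spending its own CPU cycles on augmentation.
in_worker = [augment_batch(b) for b in iter(batches)]

print(pipelined[0]["x"].tolist())  # [2, 4]
```

The results are identical; the objection in the thread is about *where* the work runs, not *what* it computes.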
# Explicitly raise a RuntimeError if an error is encountered during a Ray trial.
# NOTE: Cascading the exception with "raise _ from e" still results in hanging.
raise RuntimeError(f"Encountered Ray Tune error: {e}")
raise RuntimeError(f"Encountered Ray Tune error: {traceback.format_exc()}")
You can change this to:
raise RuntimeError(...) from e
That way you get the whole traceback without needing to turn it into a string.
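The chaining mechanics look like this in isolation (note the diff's NOTE says `raise ... from e` still resulted in hanging in the Tune context, so this only demonstrates how the traceback is preserved):

```python
def run_trial():
    # Stand-in for a failing Ray Tune trial.
    raise ValueError("boom")


chained_cause = None
try:
    run_trial()
except ValueError as e:
    try:
        # `from e` attaches the original exception as __cause__, so Python
        # prints both tracebacks automatically; no manual
        # traceback.format_exc() is needed in the message.
        raise RuntimeError("Encountered Ray Tune error") from e
    except RuntimeError as err:
        chained_cause = err.__cause__

print(type(chained_cause).__name__)  # ValueError
```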
Running the new test with the previous code, `DatasetShard.__len__` returned 256 rather than the full dataset size defined in the test. Summing the sizes of the batches returned by `DatasetIterator.iter_batches` returns the full dataset size.