Replies: 2 comments
-
Indeed, our conversion can be improved. This is one of the reasons I plan to refactor the "features" logic sometime soon. However, we store data in Arrow to support larger-than-RAM datasets, so there is a slight overhead when accessing samples and converting them from Arrow to Python/PyTorch/TF/JAX. This means it's extremely hard to be as fast as, or faster than, in-memory datasets.
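Here's a minimal sketch of where that overhead shows up, assuming the MNIST setup from the linked tutorial (the dataset name and batch size are just for illustration):

```python
from datasets import load_dataset

# Minimal sketch, assuming the MNIST setup from the linked tutorial.
# The data lives in Arrow (possibly memory-mapped from disk), so every
# indexed access below pays the Arrow -> Python/JAX conversion described
# above, a cost that a purely in-memory dataset never has.
ds = load_dataset("mnist", split="train").with_format("jax")

batch = ds[:64]              # conversion to jax.numpy arrays happens here
print(type(batch["image"]))  # a JAX array, built from Arrow buffers on access
```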
-
@mariosasko Thanks a lot for your reply. It clarifies a lot. While you're looking at that logic, it might be worth looking at multi-worker loading as well; I don't know the extent of its support with JAX. I'm eager to see the end result. Cheers!
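For reference, one common pattern for multi-worker loading with JAX today is to reuse torch's `DataLoader` and convert batches at the edge. This is only a sketch of that idea, not necessarily what the refactor will do; the worker count and batch size are arbitrary:

```python
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader

# Sketch only: a datasets.Dataset implements __getitem__/__len__, so it can
# be wrapped in a torch DataLoader. Workers fetch and collate in parallel;
# the training step then converts numpy batches to JAX via jnp.asarray.
ds = load_dataset("mnist", split="train").with_format("numpy")

def collate(examples):
    return {
        "image": np.stack([ex["image"] for ex in examples]),
        "label": np.array([ex["label"] for ex in examples]),
    }

loader = DataLoader(ds, batch_size=64, num_workers=4, collate_fn=collate)
```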
-
The tutorial on how to use Datasets with JAX is very interesting: https://huggingface.co/docs/datasets/use_with_jax
However, there seems to be a performance issue, at least in the MNIST case. In that example, the naive approach of simply loading a batch and manually converting it to JAX is miles better. Here's a link to some benchmarks where, depending on how `datasets` is installed, PyTorch's dataloaders can provide up to a 5x speedup: https://colab.research.google.com/drive/1kWPxP4hhyBMrebK5kDoeV8HLB639_Bhh?usp=sharing
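By "naive approach" I mean roughly the following sketch (this is my reading of the notebook, not a verbatim copy; dataset name and batch size are illustrative): pull the split into memory once, so the Arrow conversion is paid a single time rather than per batch.

```python
import jax.numpy as jnp
from datasets import load_dataset

# Rough sketch of the naive baseline: one up-front Arrow -> numpy
# conversion, then plain slicing, so no per-batch conversion cost remains.
ds = load_dataset("mnist", split="train").with_format("numpy")
images, labels = ds["image"], ds["label"]

def batches(batch_size=64):
    for start in range(0, len(labels), batch_size):
        yield (jnp.asarray(images[start:start + batch_size]),
               jnp.asarray(labels[start:start + batch_size]))
```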
This might not be an issue, but are there any suggestions as to what might be the cause, or how to make this better?
Thanks,