dynamic padding via collate_fn #761

Open · Jomonsugi opened this issue Jul 12, 2022 · 11 comments

Jomonsugi commented Jul 12, 2022

I would like to dynamically pad my tensors via the collate_fn argument that can be passed to petastorm.pytorch.DataLoader, but I am seemingly thwarted by make_batch_reader here: it appears make_batch_reader prevents the user from fixing up tensor sizes through the data loader.

Or is this possible and I'm just missing how to do it? A collate_fn could take care of the variable-length values on a batch-by-batch basis. Otherwise it seems I'd need to pad all the data in my Spark DataFrame, which increases data size substantially, slows training, and I assume slows I/O through petastorm in general.

What I would like to do looks something like the snippet below, where the function passed to collate_fn would dynamically pad my variable-length values.

reader = make_batch_reader(
        channel,
        workers_count=2,
        num_epochs=1,
        schema_fields=['input', 'labels']
    )

dl = DataLoader(reader,
                batch_size = 8,
                shuffling_queue_capacity = 100000,
                collate_fn=some_padding_function
               )
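
For illustration, here is a minimal sketch of what such a padding function could look like. It assumes the collate function receives a list of per-row dicts (which is what petastorm's default collate appears to handle, judging by the stack trace later in this thread) and uses the 'input' and 'labels' field names from schema_fields above:

import numpy as np
import torch

def some_padding_function(batch):
    # `batch` is assumed to be a list of per-row dicts, e.g.
    # [{'input': np.array([...]), 'labels': 0}, ...].
    # Pad every 'input' array to the longest one in this batch.
    max_len = max(len(row['input']) for row in batch)
    inputs = np.stack([
        np.pad(row['input'], (0, max_len - len(row['input'])), constant_values=0)
        for row in batch
    ])
    labels = np.asarray([row['labels'] for row in batch])
    return {'input': torch.as_tensor(inputs), 'labels': torch.as_tensor(labels)}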

Jomonsugi commented Jul 14, 2022

@KamWithK I see you explored a similar issue a while ago in #603. Did you ever figure out a workaround? What I did find is that I can use the Hugging Face Dataset API to convert the pandas DataFrame that TransformSpec uses into a Dataset, do whatever transformations are needed, and then convert back to a pandas DataFrame. However, the error is still thrown if the tensors are not the same length when output by TransformSpec, so I am unable to dynamically pad batches via collate_fn as I'd like.

Here is an example that works because I am padding within TransformSpec:

from datasets import Dataset
from transformers import AutoTokenizer

model_type = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_type)

def _transformer(rows):
    
    dataset=Dataset.from_pandas(rows)
    
    def tokenize_function(example):
        return tokenizer(example['text_a'],
                         example['text_b'],
                         max_length=96,
                         padding='max_length',
                         truncation=True,
                         return_tensors='np'
                        )

    tokenized_dataset = dataset.map(tokenize_function, batched=True)

    return tokenized_dataset.to_pandas()
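
For completeness, attaching a function like this to the reader looks roughly like the sketch below. Note that if the transform adds or removes columns (as the tokenizer outputs do here), TransformSpec's edit_fields/removed_fields arguments need to describe the resulting schema:

from petastorm import make_batch_reader
from petastorm.transform import TransformSpec

# `channel` is the dataset URL from the snippet in my first comment.
with make_batch_reader(channel,
                       workers_count=2,
                       num_epochs=1,
                       transform_spec=TransformSpec(_transformer)) as reader:
    ...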

Of course, without Hugging Face any custom padding function would do, but with HF it is trivial: you'd only have to pass a DataCollatorWithPadding to the DataLoader:

from transformers import DataCollatorWithPadding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
...
dl = DataLoader(reader,
                batch_size = 8,
                shuffling_queue_capacity = 100000,
                collate_fn=data_collator
               )
...

@Jomonsugi

It appears my long-term memory was not at work. I found I had essentially the same problem over a year ago in #650, where control within the data loader would also have been useful (that time with BatchedDataLoader). @selitvin, am I right to say this was never addressed?

@ianbenlolo

I would also like to add a collate_fn, but I am creating a torch dataloader with SparkDatasetConverter.make_torch_dataloader. The workaround I am using is to do my collating in the TransformSpec, but it would be easier to work with torch.Tensors via a collate_fn instead.
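
Roughly what I mean, as a sketch (df stands in for my actual Spark DataFrame, and transform_row is a pandas-in/pandas-out function like the one I show further down):

from petastorm import TransformSpec
from petastorm.spark import make_spark_converter

converter = make_spark_converter(df)

with converter.make_torch_dataloader(
        transform_spec=TransformSpec(transform_row),
        batch_size=8) as dataloader:
    for batch in dataloader:
        ...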

Thanks


selitvin commented Aug 29, 2022

Sorry for the delayed response... been super busy at my day job.

Would something like this user-defined collate_lists_fn help? Is the API what you'd expect? We could also add some off-the-shelf collate implementations, like padding with a constant (as in this test example):

import numpy as np
import pandas as pd

from petastorm import make_batch_reader
from petastorm.unischema import Unischema


def test_read_with_collate(reader_factory, tmp_path):
    data = pd.DataFrame({"str": ["a", "bc"], "varlen_nums": [[1], [3, 4]]})
    path = tmp_path / 'data'
    url = f"file://{path}"
    data.to_parquet(path)

    def collate_lists_fn(column_name: str, schema: Unischema, values):
        # Pad each value in the column to the length of the longest one.
        max_len = max(map(len, values))
        result = np.asarray([np.pad(v, (0, max_len - len(v)), 'constant', constant_values=0) for v in values])
        return result

    with make_batch_reader(url, collate_lists_fn=collate_lists_fn) as reader:
        actual = list(reader)

    assert len(actual) == 1
    np.testing.assert_equal(actual[0].varlen_nums, [[1, 0], [3, 4]])
    np.testing.assert_equal(actual[0].str, ["a", "bc"])

Here is a draft PR: #772

@ianbenlolo

Yes, this is what I'd expect! How would this differ from TransformSpec, though? Would the input also be a dataframe, or would it be a torch Tensor?
Thank you for getting back to me!

@selitvin

Hmmm. I think I take it back. It's not conceptually different from TransformSpec, so we probably should not introduce another tool that does the same thing. Moreover, the PR would do the collation in the main process/thread and not in the workers, which is unnecessary overhead in this case.

So, I am not sure then. How do you think we could improve the API to suit your needs better? Is there anything that needs to be done?

@Jomonsugi

I personally ended up subclassing the petastorm DataLoader class and overriding methods to accomplish what I wanted. To handle each of my examples having a different length, I built on the strategy petastorm already uses to fill the buffer with tuples, which allowed me to "dynamically pad" my arrays to the longest array within each batch after it is retrieved from the buffer. retrieve_as_batch can be ignored for this discussion; I implemented it elsewhere so that I could experiment with retrieving batches in randomized chunks where order in the buffer mattered.

Below is what I did.

self.tokenizer and self.data_collator are both assigned to objects imported from the Hugging Face library. text_column is one string of text, while text_column_b is another expected by the particular model I am using. Compared to the original implementation, I simply removed lines of code that are of no use to my needs and added a few others.

from typing import Callable, List

import torch
from transformers import AutoTokenizer, DataCollatorWithPadding

from petastorm.pytorch import DataLoader
from petastorm.reader import Reader

# Note: BatchShufflingBuffer is my own buffer variant (implemented elsewhere, not
# part of petastorm); it supports both retrieve() and retrieve_batch().


class TransformersDataLoader(DataLoader):
    """
    A data loader adaptor for `torch.utils.data.DataLoader`.

    This class iterates and returns items from the Reader in batches.

    This loader can be used as an iterator and will terminate when the reader
    used in the construction of the class runs out of samples.
    """

    def __init__(
        self,
        reader: Reader,
        retrieve_as_batch: bool = False,
        batch_size: int = None,
        shuffling_queue_capacity: int = None,
        extra_capacity: int = None,
        collate_fn: Callable = None,
        model_name: str = None,
        max_length: int = None,
        tokenizer_input_keys: List[str] = None,
        text_column: str = None,
        text_column_b: str = None,
        label_column: str = None,
        **kwargs
    ):
        """
        Initializes a parquet data loader object.

        Number of epochs is defined by the configuration of the reader argument.

        An optional shuffling queue is created if shuffling_queue_capacity is
        greater than 0. No samples will be returned to a user by the
        `DataLoader` until the queue is full. After that, batches of
        `batch_size` will be created by uniformly sampling the shuffling queue.
        Once no more samples are available from the data reader, the shuffling
        queue is allowed to be consumed till no further samples are available.

        NOTE: if you are using ``make_batch_reader``, this shuffling queue will
        be randomizing the order of the entire batches and not changing the
        order of elements within a batch. This is likely not what you intend to
        do.

        Arguments:
            reader (Reader): petastorm Reader instance
            retrieve_as_batch (bool): Whether to retrieve batches from the queue by
                group or randomized across the entire queue. Retrieving batches by
                group is only desired if the order of the queue matters, as this
                strategy is slightly slower.
            batch_size (int): the number of items to return per batch
            shuffling_queue_capacity (int): Queue capacity is passed to the
                underlying :class:`tf.RandomShuffleQueue`. If set to 0, no
                shuffling will be done.
            extra_capacity (int): extra capacity is passed to the
                underlying :class:`tf.RandomShuffleQueue`. 
            collate_fn (Callable): callable to merge a list of samples from the queue to form a
                batch input to the model. If not provided the default _collate_and_pad 
                will be called. See method for structure of input/output.   
            model_name (str): Name of HuggingFace Transformer model. Defaults
                to `bert-base-uncased`.
            max_length (int): Length of sequences for model inputs. Defaults to
                16.
            tokenizer_input_keys (list): Tokenizer keys to use in
                model training.
            text_column (str): Name of text input column. Defaults to 'text'.
            text_column_b (str): Name of second text input column (for text
                similarity problems). Defaults to None.
            label_column (str): Name of label column. Defaults to 'label'.
        """
        super().__init__(
            reader,
            batch_size=batch_size or 1,
            collate_fn=collate_fn or self._collate_and_pad,
            shuffling_queue_capacity=shuffling_queue_capacity or 0,
        )
        
        # If not retrieving as batch, not so important. Otherwise, the lower, the faster.
        self.extra_capacity = extra_capacity or 1000000
        self.retrieve_as_batch = retrieve_as_batch
        self.model_name = model_name or 'bert-base-uncased'
        self.max_length = max_length or 16
        self.tokenizer_input_keys = tokenizer_input_keys
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.data_collator = DataCollatorWithPadding(tokenizer=self.tokenizer)
        self.text_column = text_column or "text"
        self.text_column_b = text_column_b
        self.label_column = label_column or "label"
        self.keys = None
        
    def _iter_impl(self):
        """
        The Data Loader iterator stops the for-loop when reader runs out of samples.
        """

        min_after_dequeue = self.shuffling_queue_capacity - 1
        self._shuffling_buffer = BatchShufflingBuffer(
            self.batch_size,
            self.shuffling_queue_capacity,
            self.extra_capacity,
            min_after_retrieve=min_after_dequeue
            )

        for row in self.reader:
            row_as_dict = row._asdict()
           
            row_as_dict.update({"labels": torch.tensor(row_as_dict[self.label_column])})
            del row_as_dict[self.label_column]

            row_as_dict.update(self._tokenize_text(row_as_dict))

            self.keys = row_as_dict.keys()
            
            # Add rows to shuffling buffer.

            #   row_as_dict:        {'a': [1,2,3], 'b':[4,5,6]}
            #   row_group_as_tuple: [(1, 4), (2, 5), (3, 6)]
            row_group_as_tuple = list(zip(*(row_as_dict[k] for k in self.keys)))
            # Adding data as 'row-by-row' into a shuffling buffer. Opportunity to optimize.
            self._shuffling_buffer.add_many(row_group_as_tuple)

            # Yield batches from the shuffling buffer.
            for batch in self._yield_batches():
                yield batch

        # Yield remaining rows in shuffling buffer.
        self._shuffling_buffer.finish()
        print('shuffling buffer finished')
        for batch in self._yield_batches():
            yield batch
        # Yield the last and partial batch.
        if self._batch_acc:
            yield self.collate_fn(self._batch_acc) 
                
    def _yield_batches(self):
        
        while self._shuffling_buffer.can_retrieve():
            
            if self.retrieve_as_batch:
                self._batch_acc = self._shuffling_buffer.retrieve_batch()
            else:
                post_shuffled_row = self._shuffling_buffer.retrieve()
                self._batch_acc.append(post_shuffled_row)
                
            # Batch is ready? Collate and emit.
            if len(self._batch_acc) == self.batch_size:
                yield self.collate_fn(self._batch_acc)
                self._batch_acc = []
    
    def _tokenize_text(self, row_as_dict):
        
        text_array = row_as_dict[self.text_column]
        del row_as_dict[self.text_column]
        
        if self.text_column_b:
            text_array_b = row_as_dict[self.text_column_b]
            text_inputs = list(zip(text_array.tolist(), text_array_b.tolist()))
            del row_as_dict[self.text_column_b]
        else:
            text_inputs = text_array.tolist()
            
        tokens = self.tokenizer(
            text_inputs,
            max_length=self.max_length,
            truncation=True,
        )
        
        return {
            k: tokens[k] for k in self.tokenizer_input_keys
        }
    
    def _collate_and_pad(self, batch):
        
        #   batch:      [(1,2,3), (4,5,6)]
        #   keys:       ['a', 'b', 'c']
        #   batch_dict: {'a': [1,4], 'b': [2,5], 'c': [3,6]}
        batch_dict = dict(zip(self.keys, list(map(list, zip(*batch)))))
        
        return self.data_collator(batch_dict)

At a high level, the same idea can be used to store almost anything as a tuple in the buffer, as long as it is properly transformed at the batch level before being fed to training. Hopefully this idea can be used by others where the pandas transform falls short, and/or by those contributing to the petastorm library as a base for added functionality.
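
For context, instantiating the subclass looks roughly like this (a sketch; the dataset URL, column names, and tokenizer keys are placeholders for my actual setup):

from petastorm import make_batch_reader

reader = make_batch_reader('file:///path/to/parquet',
                           workers_count=2,
                           num_epochs=1)

loader = TransformersDataLoader(
    reader,
    batch_size=16,
    shuffling_queue_capacity=1000,
    max_length=128,
    tokenizer_input_keys=['input_ids', 'token_type_ids', 'attention_mask'],
    text_column='text_a',
    text_column_b='text_b',
    label_column='label',
)

for batch in loader:
    # `batch` is the dict of padded tensors produced by DataCollatorWithPadding,
    # ready to be passed to the model, e.g. model(**batch).
    ...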

@ianbenlolo

Interesting. I wanted to use converter.make_torch_dataloader, so what I ended up with is just this TransformSpec function:

def transform_row(pd_batch):
    """
    The input and output of this function must be pandas dataframes.
    """
    max_len = pd_batch.deltats.map(len).max()
    pd_batch['deltats'] = pd_batch.deltats.apply(lambda v: np.pad(v, (0, max_len - len(v)), 'constant', constant_values=-1) )
    
    max_len = pd_batch.embed_vectors.map(len).max()
    pd_batch['embed_vectors'] = pd_batch.embed_vectors.apply(lambda v: np.pad(v, (0, max_len - len(v)), 'constant', constant_values=0) )
    return pd_batch

To answer @selitvin: I'm not sure. I was hoping to be able to pass a collate_fn to PyTorch the same way you do when instantiating a DataLoader, but I suppose this is equivalent (except that it does not operate on tensors).

@Jomonsugi

@ianbenlolo I think your approach is certainly preferable as it works within the petastorm library's current functionality. In many cases, dynamically padding tensors as I did isn't needed (though it can increase training speed). I assume I could have worked out what I did with pandas, but personally I found too much overhead in the expectations around data input and output. You'll notice that in my code I removed _sanitize_pytorch_types(row_as_dict).

@ianbenlolo

Hm, so it turns out what I wanted to do (pass a pad_collate function) is in fact possible with the data_loader_fn argument, like so:

from functools import partial
from petastorm.pytorch import DataLoader

data_loader_fn = partial(DataLoader, collate_fn=pad_collate)
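
For reference, a pad_collate along these lines would work (a sketch, not my exact function; it assumes the collate receives a list of per-row dicts, which is what decimal_friendly_collate in the stack trace below is handed, and that deltats is the variable-length column):

import numpy as np
from petastorm.pytorch import decimal_friendly_collate

def pad_collate(batch):
    # Pad the variable-length 'deltats' arrays to the longest one in this batch
    # (the same would be done for 'embed_vectors' or any other ragged column),
    # then let petastorm's default collate stack everything into tensors.
    max_len = max(len(row['deltats']) for row in batch)
    for row in batch:
        row['deltats'] = np.pad(row['deltats'],
                                (0, max_len - len(row['deltats'])),
                                constant_values=-1)
    return decimal_friendly_collate(batch)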

This confuses me, though, as to the purpose of transform_spec, because I was padding in it and yet I still got the following stack trace:

RuntimeError                              Traceback (most recent call last)
<command-2054687199870552> in <cell line: 1>()
      5     print(train_dataloader)
      6     print('*'*50)
----> 7     for a in train_dataloader:
      9         pass

/databricks/python/lib/python3.9/site-packages/petastorm/pytorch.py in __iter__(self)
    119 
    120         try:
--> 121             for batch in self._iter_impl():
    122                 yield batch
    123         except Exception as e:

/databricks/python/lib/python3.9/site-packages/petastorm/pytorch.py in _iter_impl(self)
    218             # _yield_batches will emit as much batches as are allowed by the shuffling_buffer (RandomShufflingBuffer
    219             # will avoid underflowing below a certain number of samples to guarantee some samples decorrelation)
--> 220             for batch in self._yield_batches(keys):
    221                 yield batch
    222 

/databricks/python/lib/python3.9/site-packages/petastorm/pytorch.py in _yield_batches(self, keys)
    245             # Batch is ready? Collate and emmit
    246             if len(self._batch_acc) == self.batch_size:
--> 247                 yield self.collate_fn(self._batch_acc)
    248                 self._batch_acc = []
    249 

/databricks/python/lib/python3.9/site-packages/petastorm/pytorch.py in decimal_friendly_collate(batch)
     86         return batch
     87     elif isinstance(batch[0], collections.Mapping):
---> 88         return {key: decimal_friendly_collate([d[key] for d in batch]) for key in batch[0]}
     89     elif isinstance(batch[0], _string_classes):
     90         return batch

/databricks/python/lib/python3.9/site-packages/petastorm/pytorch.py in <dictcomp>(.0)
     86         return batch
     87     elif isinstance(batch[0], collections.Mapping):
---> 88         return {key: decimal_friendly_collate([d[key] for d in batch]) for key in batch[0]}
     89     elif isinstance(batch[0], _string_classes):
     90         return batch

/databricks/python/lib/python3.9/site-packages/petastorm/pytorch.py in decimal_friendly_collate(batch)
     93         return [decimal_friendly_collate(samples) for samples in transposed]
     94     else:
---> 95         return default_collate(batch)
     96 
     97 

/databricks/python/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
    144                 raise TypeError(default_collate_err_msg_format.format(elem.dtype))
    145 
--> 146             return default_collate([torch.as_tensor(b) for b in batch])
    147         elif elem.shape == ():  # scalars
    148             return torch.as_tensor(batch)

/databricks/python/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py in default_collate(batch)
    136             storage = elem.storage()._new_shared(numel)
    137             out = elem.new(storage).resize_(len(batch), *list(elem.size()))
--> 138         return torch.stack(batch, 0, out=out)
    139     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
    140             and elem_type.__name__ != 'string_':

RuntimeError: stack expects each tensor to be equal size, but got [120] at entry 0 and [252] at entry 80

So I am confused: why is the padding in TransformSpec not sufficient?
Thanks.


selitvin commented Sep 8, 2022

Agree, it's confusing. Indeed there are two different ways of doing this: either prepare all the data in TransformSpec so that it can be automatically collated, or do the transformation during collation. To give some context on the design choices that led to the current implementation:

  1. The constraint on the uniformity of tensor types is indeed redundant when working with PyTorch (and, I suspect, modern TF). The original design worked well with TF 1.0, where all tensors had to have uniform types/shapes to be fed into the processing graph.
  2. Preparing for collation in TransformSpec is probably preferable, since it runs on worker threads/processes and does not take CPU/GIL time from the main thread/process.
