Commit 479c9cc
general data collator
Natooz committed Mar 27, 2023
1 parent 21df1bf
Showing 1 changed file with 22 additions and 9 deletions.
torchtoolkit/data.py: 31 changes (22 additions & 9 deletions)
@@ -57,23 +57,36 @@ def __call__(self, batch: List[Dict[str, Any]], return_tensors=None) -> Dict[str
         return {"input_ids": x, "labels": torch.arange(x.size(0)).long(), "attention_mask": attention_mask}  # rank
 
 
-class DataCollatorGen:
-    def __init__(self, pad_token: int, pad_on_left: bool = False, shift_labels: bool = False):
-        """Collator that simply pad the input sequences.
-        Input_ids will be padded with the pad token given, while labels will be
-        padded with -100.
-        :param pad_token: pas token
+class DataCollatorBasic:
+    def __init__(
+            self,
+            pad_token: int,
+            bos_token: int = None,
+            eos_token: int = None,
+            pad_on_left: bool = False,
+            shift_labels: bool = False
+    ):
+        """Multifunction data collator, that can pad the sequences (right or left), add BOS and EOS tokens.
+        Input_ids will be padded with the pad token given, while labels will be padded with -100.
+        :param pad_token: PAD token
+        :param bos_token: BOS token
+        :param eos_token: EOS token
         :param pad_on_left: will pad sequence on the left (default: False).
         :param shift_labels: will shift inputs and labels for autoregressive training / teacher forcing.
         """
         self.pad_token = pad_token
+        self.bos_token = bos_token
+        self.eos_token = eos_token
         self.pad_on_left = pad_on_left
         self.shift_labels = shift_labels
 
     def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, torch.LongTensor]:
+        _add_bos_eos_tokens_to_batch(batch, bos_tok=self.bos_token, eos_tok=self.eos_token)
         pad_on_left = batch[0]["pad_on_left"] if "pad_on_left" in batch[0] else self.pad_on_left
         x, y = _pad_batch(batch, self.pad_token, pad_on_left), _pad_batch(batch, -100, pad_on_left)
-        # causal attention mask handled in model
-        if self.shift_labels:  # otherwise it's handled in model such as GPT2LMHead
+        # attention mask handled in model
+        if self.shift_labels:  # otherwise it's handled in models such as GPT2LMHead
             x = x[:-1]
             y = y[1:]
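
The two helpers called in `__call__`, `_add_bos_eos_tokens_to_batch` and `_pad_batch`, are defined elsewhere in `torchtoolkit/data.py` and collapsed out of this view. Below is a minimal sketch of what they plausibly do, assuming samples are dicts holding a 1D `"input_ids"` tensor and that batches are stacked time-major (`pad_sequence`'s default layout, which is what makes the `x[:-1]` / `y[1:]` slicing above drop a time step rather than a sample); the actual implementations may differ.

```python
# Hypothetical reconstructions for illustration only, not the code collapsed
# out of this diff view.
from typing import Any, Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence


def _add_bos_eos_tokens_to_batch(batch: List[Dict[str, Any]], bos_tok: int = None, eos_tok: int = None):
    """Prepends a BOS token and/or appends an EOS token to each sample, in place."""
    for sample in batch:
        seq = sample["input_ids"]
        if bos_tok is not None:
            seq = torch.cat([torch.LongTensor([bos_tok]), seq])
        if eos_tok is not None:
            seq = torch.cat([seq, torch.LongTensor([eos_tok])])
        sample["input_ids"] = seq


def _pad_batch(batch: List[Dict[str, Any]], pad_token: int, pad_on_left: bool = False) -> torch.LongTensor:
    """Pads the sequences to the longest one and stacks them time-major, shape (T, N)."""
    seqs = [sample["input_ids"] for sample in batch]
    if pad_on_left:
        # Flip each sequence, right-pad, then flip the time axis back
        return pad_sequence([seq.flip(0) for seq in seqs], padding_value=pad_token).flip(0)
    return pad_sequence(seqs, padding_value=pad_token)
```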

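A hedged usage sketch: the toy dataset, token ids, and batch size below are made up, and it assumes the collapsed end of `__call__` returns a dict with `"input_ids"` and `"labels"`, as the docstring implies; only `DataCollatorBasic` and its constructor arguments come from the diff above.

```python
import torch
from torch.utils.data import DataLoader

from torchtoolkit.data import DataCollatorBasic

PAD, BOS, EOS = 0, 1, 2  # hypothetical special-token ids

# Toy map-style dataset: each sample is a dict with a 1D LongTensor of token ids
dataset = [
    {"input_ids": torch.LongTensor([5, 6, 7, 8])},
    {"input_ids": torch.LongTensor([9, 10])},
]

collator = DataCollatorBasic(pad_token=PAD, bos_token=BOS, eos_token=EOS, shift_labels=True)
loader = DataLoader(dataset, batch_size=2, collate_fn=collator)

batch = next(iter(loader))
# With shift_labels=True, a sample such as [BOS, 9, 10, EOS] becomes inputs
# [BOS, 9, 10] and labels [9, 10, EOS]: each position predicts the next token.
print(batch["input_ids"].shape, batch["labels"].shape)
```

Padding the labels with -100 lets loss functions such as `torch.nn.CrossEntropyLoss` (whose default `ignore_index` is -100) skip the padded positions, and `pad_on_left=True` is the usual choice when batching autoregressive generation, since newly generated tokens are appended on the right.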
