
Conversation

Contributor

@sshleifer commented on Sep 23, 2019

Previously, each sequence was padded to the length of the longest sequence in the dataset.
In this PR, each batch is padded to the length of the longest sequence in the batch. This results in a 30% speedup with negligible impact on metrics.

Code Changes

  • ChatDataset yields example dicts like {'input_ids': [[hist + cand_1], ..., [hist + cand_n]], ...} for the PADDED_INPUTS, with mc_token_ids and mc_labels in the same format as previously.
  • ChatDataset().collate_fn(examples: list) turns a list of example dicts into a list of 5 tensors by batching and padding them (see the sketch after this list).
  • As a result, get_dataloaders does much less work.
  • To facilitate this, the data format changes in the step where we build the lists of examples.
  • convai_evaluation.py still calls the old pad_dataset.
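As a rough illustration of the change, a per-batch collate_fn might look like the sketch below. The field names follow the bullets above; the pad value for lm_labels (-1) and the exact tensor layout are assumptions, not a copy of the PR's implementation.

import torch

PADDED_INPUTS = ["input_ids", "token_type_ids", "lm_labels"]  # assumed, per the bullets above

class ChatDataset(torch.utils.data.Dataset):
    def __init__(self, examples, pad_id):
        self.examples = examples      # list of example dicts, one per utterance
        self.pad_id = pad_id

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]

    def collate_fn(self, examples):
        # Pad only to the longest sequence in *this* batch, then stack into tensors.
        max_len = max(len(seq) for ex in examples for seq in ex["input_ids"])
        batch = {}
        for name in PADDED_INPUTS:
            pad = -1 if name == "lm_labels" else self.pad_id   # assumption: labels padded with -1
            batch[name] = torch.tensor([
                [seq + [pad] * (max_len - len(seq)) for seq in ex[name]]
                for ex in examples
            ])
        batch["mc_token_ids"] = torch.tensor([ex["mc_token_ids"] for ex in examples])
        batch["mc_labels"] = torch.tensor([ex["mc_labels"] for ex in examples])
        return batch

The DataLoader would then be built with collate_fn=dataset.collate_fn, so the padded length varies batch to batch instead of being fixed by the longest sequence in the whole dataset.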

1 Epoch Sanity Check

Before Change: 85 minutes
Validation: {'accuracy': 0.7483655941545956,
'average_accuracy': 0.7483655941545956,
'average_nll': 2.6815188920676687,
'average_ppl': 14.607263311061963,
'nll': 2.6815188920676687}

After Change: 60 minutes
Validation: {'accuracy': 0.7466991411357519,
'average_accuracy': 0.7466991411357519,
'average_nll': 2.6821035040007972,
'average_ppl': 14.615805388160778,
'nll': 2.6821035040007972}

Command:

python train.py --model_checkpoint openai-gpt --dataset_cache dataset_cache --fp16 O1 --n_epochs 1 --train_batch_size 4

@sshleifer changed the title from "(WIP) Pad each batch, not the whole dataset" to "Pad each batch, not the whole dataset" on Sep 29, 2019
return train_loader, valid_loader, train_sampler, valid_sampler


def make_data_lists(args, personachat, tokenizer):
Contributor Author

docstring
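For example, a docstring here might read roughly as follows (inferred from the PR description, not from the function body):

def make_data_lists(args, personachat, tokenizer):
    """Turn the raw personachat dict into per-split lists of example dicts.

    Each utterance becomes one dict whose values hold one entry per candidate,
    e.g. {'input_ids': [[hist + cand_1], ..., [hist + cand_n]], ...}, plus
    mc_token_ids and mc_labels. No padding happens here; that is deferred to
    ChatDataset.collate_fn, which pads each batch to its own longest sequence.
    """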

  for utterance in dialog["utterances"]:
-     history = utterance["history"][-(2*args.max_history+1):]
+     candidate_instances = defaultdict(list)
+     history = utterance["history"][-(2 * args.max_history + 1):]
Contributor Author

could add assert len(utterance['candidates']) >= num_candidates
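In context, that guard might sit just before the candidate slice, roughly (hypothetical placement):

assert len(utterance["candidates"]) >= num_candidates, \
    "each utterance must provide at least num_candidates candidate replies"
for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
    ...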

return instance, sequence # TODO: second arg is never used, delete it


def pad_and_tensorize(batch_dict, padding):
Contributor Author

this and ChatDataset should be easy to unit test
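A possible unit test for the padding behaviour, assuming pad_and_tensorize returns tensors keyed like its input dict (the return format is not visible in this hunk):

def test_pad_and_tensorize_pads_to_batch_max():
    # One example with two candidates of unequal length.
    batch_dict = {"input_ids": [[[1, 2, 3], [4, 5]]]}
    out = pad_and_tensorize(batch_dict, padding=0)
    assert out["input_ids"].shape[-1] == 3          # padded to the longest sequence in the batch
    assert out["input_ids"][0, 1, -1].item() == 0   # the shorter candidate got the pad id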

valid_dataset = ChatDataset(datasets['valid'], pad_id)

logger.info("Build train and validation dataloaders")
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
Contributor Author

(maybe) put this in ChatDataset.to_loader(self, args, shuffle) -> sampler, loader
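A sketch of that refactor, using the hypothetical method name from this comment; the batch-size arguments mirror the surrounding get_dataloaders code only loosely:

def to_loader(self, args, shuffle):
    # Wrap this dataset in a (sampler, loader) pair.
    sampler = torch.utils.data.distributed.DistributedSampler(self) if args.distributed else None
    loader = torch.utils.data.DataLoader(
        self,
        sampler=sampler,
        batch_size=args.train_batch_size if shuffle else args.valid_batch_size,
        shuffle=(shuffle and sampler is None),
        collate_fn=self.collate_fn,
    )
    return sampler, loader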

Contributor Author

at some point might also want to document which tensors are 3D
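For reference, the shape notes could look something like this (an assumption based on the double-heads setup, not taken from the diff):

# input_ids, token_type_ids, lm_labels : 3D, (batch_size, num_candidates, seq_len)
# mc_token_ids                         : 2D, (batch_size, num_candidates)
# mc_labels                            : 1D, (batch_size,)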

  for input_name, input_array in instance.items():
-     datasets[dataset_name][input_name].append(input_array)
+     candidate_instances[input_name].append(input_array)
+ for k in candidate_instances.keys():
Contributor Author

.items() will save some chars
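i.e. something like:

for k, v in candidate_instances.items():
    ...  # use v directly instead of candidate_instances[k]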

train.py Outdated
for j, candidate in enumerate(utterance["candidates"][-num_candidates:]):
lm_labels = bool(j == num_candidates-1)
instance, _ = build_input_from_segments(persona, history, candidate, tokenizer, lm_labels)
lm_labels = bool(j == num_candidates - 1)
Contributor Author

better varname?
