
Commit d1d7441

Authored by andersonic and Jun Ru Anderson
[test] specify chunks for pipe/transformer benchmark (#52)
* specify chunks for pipe/transformer benchmark: set chunks equal to len(balance) for the pipe/transformer benchmark. The words-per-second and memory-usage checks will be updated in the next commit (must test on CircleCI to find appropriate values).
* change benchmark words-per-second and memory-usage values: six runs gave words-per-second results of 9144.40, 9163.91, 9993.01, 9082.82, 9155.09, and 9000.67. Peak allocated bytes per device (which do not change between runs) were 193206272, 645632, 562688, and 92688384 for devices 0, 1, 2, and 3, respectively.
* increase batch size: the batch size was small enough that GPU compute was not the bottleneck, which slowed training and in particular made using more chunks slower. Increasing the batch size therefore increased training speed.
* update benchmark numbers: ran six times, with wps of 36917.44, 36797.65, 37006.03, 36872.84, 37129.31, and 37003.31, and peak allocated bytes of 4061909504, 4050944, 10427392, and 2031824896 for devices 0, 1, 2, and 3, respectively.

Co-authored-by: Jun Ru Anderson <[email protected]>
1 parent: ab32cb7

File tree

1 file changed: +9 / -9 lines


benchmarks/transformer.py

Lines changed: 9 additions & 9 deletions
@@ -98,8 +98,8 @@ def get_data(device):
     TEXT.build_vocab(train_txt)
     ntokens = len(TEXT.vocab.stoi)
 
-    batch_size = 20
-    eval_batch_size = 10
+    batch_size = 500
+    eval_batch_size = 200
     train_data = batchify(train_txt, batch_size, TEXT, device)
     val_data = batchify(val_txt, eval_batch_size, TEXT, device)
     test_data = batchify(test_txt, eval_batch_size, TEXT, device)
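For context, batchify is defined elsewhere in benchmarks/transformer.py and is not touched by this diff. As a rough sketch only (the name batchify_sketch and its signature are assumptions, not code from the repo), a tutorial-style helper of this kind reshapes the flat token stream into batch_size parallel columns, so the larger batch_size means wider tensors and more work per GPU step:

import torch

def batchify_sketch(tokens: torch.Tensor, bsz: int, device: torch.device) -> torch.Tensor:
    # Hypothetical helper, not from the repo: trim the 1-D token stream so it
    # splits evenly into bsz columns, then reshape to (seq_len, bsz).
    nbatch = tokens.size(0) // bsz
    tokens = tokens.narrow(0, 0, nbatch * bsz)
    return tokens.view(bsz, -1).t().contiguous().to(device)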
@@ -131,7 +131,7 @@ def make_model(device, ntokens):
 
     model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).half().to(device)
     balance = generate_balance(min(num_devices, 4), len(model))
-    p = Pipe(model, balance)
+    p = Pipe(model, balance, chunks=len(balance))
 
     criterion = nn.CrossEntropyLoss()
     lr = 0.0005  # learning rate
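The chunks argument controls how many micro-batches Pipe splits each input batch into; with chunks equal to len(balance) there is one micro-batch per pipeline partition, so all stages can be busy at once. A minimal sketch of the call pattern outside this benchmark (the toy model and balance are illustrative, and it assumes a fairscale/torchgpipe-style Pipe with enough visible GPUs for the partitions):

import torch.nn as nn
from fairscale.nn import Pipe

# Illustrative two-stage model: two layers per partition.
model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32), nn.Sigmoid())
balance = [2, 2]

# Each batch is split into len(balance) == 2 micro-batches, so both
# partitions can work concurrently instead of idling in turn.
pipe = Pipe(model, balance, chunks=len(balance))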
@@ -161,7 +161,7 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
         optimizer.step()
 
         total_loss += loss.item()
-        log_interval = 200
+        log_interval = 50
         if batch % log_interval == 0 and batch > 0:
             cur_loss = total_loss / log_interval
             elapsed = time.time() - start_time
@@ -227,7 +227,7 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
     if can_benchmark and len(model.balance) == 4:
         # Assert that words per second is within 3 standard deviations of the average
         # of six golden runs
-        assert wps > 27799.2 - (3 * 522.145)
+        assert wps > 36954.4 - (3 * 116.825)
 
         print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
         print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
@@ -236,10 +236,10 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
 
         # Assert that memory usage on each GPU is within 10% of golden run
         # Right-hand-side is golden run bytes * 110%
-        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 193206272 * 1.1
-        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 640512 * 1.1
-        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 1412608 * 1.1
-        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 95364608 * 1.1
+        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 4061909504 * 1.1
+        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 4050944 * 1.1
+        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 10427392 * 1.1
+        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 2031824896 * 1.1
         print("No regression detected")