Skip to content

Commit de2731b

Browse files
committed
hap: output one line per sentence when batch size is >1
1 parent 4f635c1 commit de2731b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

ha/score.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ class Tok:
7171

7272
with torch.amp.autocast(device_type='cuda', dtype=dtype):
7373
logits = model.forward_all(input_ids=input_ids, target_ids=completions, reduction='none')
74-
if len(logits.shape) < 2:
75-
logits = logits[None, :]
76-
for sentence_logits, loss, tokens in zip(logits, logits.view(-1, input_ids.shape[-1]).sum(-1), completion_tokens):
74+
logits = logits.view(-1, input_ids.shape[-1])
75+
for sentence_logits, tokens in zip(logits, completion_tokens):
76+
loss = sentence_logits.sum(-1)
7777
num_tokens = min(model.config.block_size, len(tokens))
7878
loss_per_token = loss.item() / num_tokens
7979
if args.verbose:

0 commit comments

Comments
 (0)