Skip to content

Commit 8eca106

Browse files
authored
Merge pull request #31 from BoFeng2477/patch-2
The reshape of input_id doesn't match HF OPT model's API
2 parents 9f2abd3 + 957ce15 commit 8eca106

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

opt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ def sync():
336336
for i in range(input_ids.numel()):
337337
tick = time.time()
338338
out = model(
339-
input_ids[:, i].reshape(-1),
339+
input_ids[:, i].reshape((1,-1)),
340340
past_key_values=cache['past'],
341341
attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1))
342342
)

0 commit comments

Comments
 (0)