From 55acb9f82c5a3ccf1de0a06564d823afe4fb80fa Mon Sep 17 00:00:00 2001 From: Yassine Date: Wed, 17 Apr 2024 20:06:52 -0700 Subject: [PATCH] this is a lie --- generate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/generate.py b/generate.py index 8446d115..5b63939d 100644 --- a/generate.py +++ b/generate.py @@ -57,12 +57,12 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): return idx_next, probs def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor: - # input_pos: [B, S] + # input_pos: [S] logits = model(x, input_pos) return sample(logits, **sampling_kwargs)[0] def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] + # input_pos: [1] assert input_pos.shape[-1] == 1 logits = model(x, input_pos) return sample(logits, **sampling_kwargs)