ADLR/megatron-lm!2430 - Fix log probs output for inference
Co-authored-by: William Dykas <[email protected]>
Co-authored-by: Mcore Bot <[email protected]>
Co-authored-by: William Dykas <[email protected]>
Co-authored-by: root <[email protected]>
5 people authored and jaredcasper committed Jan 7, 2025
1 parent 24e0126 commit 5ff34d0
Showing 2 changed files with 42 additions and 1 deletion.
@@ -381,7 +381,10 @@ def generate_all_output_tokens_static_batch(
             request.generated_log_probs = (
                 None
                 if output_log_probs is None
-                else output_log_probs[idx, input_prompt_length:required_sequence_length]
+                else output_log_probs[
+                    idx,
+                    input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
+                ]
             )
             request.status = Status.COMPLETED
             request.generated_text = self.detokenize_generations(required_result_tokens)
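Note on the new bounds: the slice is shifted one position left and given an explicit length of required_sequence_length. A plausible reading, consistent with the new test below, is that output_log_probs[idx, t] holds the log probability of the token at position t + 1, so the entries covering the generated tokens begin at input_prompt_length - 1 rather than input_prompt_length. The sketch that follows only illustrates the slicing arithmetic; prompt_len, num_generated, and batch_size are made-up stand-ins, and treating required_sequence_length as the number of generated tokens is an assumption, not something stated in the diff.

import torch

# Made-up sizes for illustration only.
prompt_len = 4        # stand-in for input_prompt_length
num_generated = 3     # stand-in for required_sequence_length (assumed = tokens generated)
batch_size = 2
total_len = prompt_len + num_generated

# Assumption: entry [b, t] scores the token at position t + 1 (next-token convention),
# so each row holds total_len - 1 log probs.
output_log_probs = torch.randn(batch_size, total_len - 1)

idx = 0
old_slice = output_log_probs[idx, prompt_len:num_generated]                          # old bound
new_slice = output_log_probs[idx, prompt_len - 1 : prompt_len + num_generated - 1]   # new bound

print(old_slice.shape)  # torch.Size([0]) -- empty with these sizes
print(new_slice.shape)  # torch.Size([3]) -- one log prob per generated token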
@@ -173,3 +173,41 @@ def test_generate_all_output_tokens_static_batch(self):
             ), f"Status should be completed but its {request.status}"
             assert request.generated_length > 0, f"Generated length should be greater than zero"
             assert request.generated_text is not None, "Generated text should not be None"
+
+    def test_output_log_probs(self):
+        self.mock_tokenizer.vocab_size = self.vocab_size
+        self.mock_tokenizer.bos = 0
+        self.mock_tokenizer.eod = self.vocab_size - 1
+        self.mock_tokenizer.detokenize.return_value = ''.join(
+            random.choices(string.ascii_letters, k=random.randint(4, 10))
+        )
+
+        prompt = ""
+        active_requests: Dict[int, InferenceRequest] = OrderedDict()
+        for i in range(self.batch_size):
+            self.mock_tokenizer.tokenize.return_value = torch.randn(
+                self.batch_size, self.vocab_size
+            ).cuda()
+            inference_request = InferenceRequest(
+                request_id=i,
+                prompt=prompt,
+                inference_parameters=SamplingParams(
+                    num_tokens_to_generate=1, return_log_probs=True
+                ),
+                arrival_time=time.time(),
+                prompt_tokens=[self.mock_tokenizer.bos],
+                status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
+            )
+            active_requests[i] = inference_request
+
+        requests = self.text_generation_controller.generate_all_output_tokens_static_batch(
+            active_requests
+        )
+
+        for request_id, request in requests.items():
+            assert (
+                request.status == Status.COMPLETED
+            ), f"Status should be completed but its {request.status}"
+            assert request.generated_length > 0, f"Generated length should be greater than zero"
+            assert request.generated_text is not None, "Generated text should not be None"
+            assert len(request.generated_log_probs) == request.generated_length
