diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index f15c819c43..fa1406a2bf 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -381,7 +381,10 @@ def generate_all_output_tokens_static_batch(
             request.generated_log_probs = (
                 None
                 if output_log_probs is None
-                else output_log_probs[idx, input_prompt_length:required_sequence_length]
+                else output_log_probs[
+                    idx,
+                    input_prompt_length - 1 : (input_prompt_length + required_sequence_length - 1),
+                ]
             )
             request.status = Status.COMPLETED
             request.generated_text = self.detokenize_generations(required_result_tokens)
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
index 1db360f232..e9ab941ab3 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_simple_text_generation_controller.py
@@ -173,3 +173,41 @@ def test_generate_all_output_tokens_static_batch(self):
             ), f"Status should be completed but its {request.status}"
             assert request.generated_length > 0, f"Generated length should be greater than zero"
             assert request.generated_text is not None, "Generated text should not be None"
+
+    def test_output_log_probs(self):
+        self.mock_tokenizer.vocab_size = self.vocab_size
+        self.mock_tokenizer.bos = 0
+        self.mock_tokenizer.eod = self.vocab_size - 1
+        self.mock_tokenizer.detokenize.return_value = ''.join(
+            random.choices(string.ascii_letters, k=random.randint(4, 10))
+        )
+
+        prompt = ""
+        active_requests: Dict[int, InferenceRequest] = OrderedDict()
+        for i in range(self.batch_size):
+            self.mock_tokenizer.tokenize.return_value = torch.randn(
+                self.batch_size, self.vocab_size
+            ).cuda()
+            inference_request = InferenceRequest(
+                request_id=i,
+                prompt=prompt,
+                inference_parameters=SamplingParams(
+                    num_tokens_to_generate=1, return_log_probs=True
+                ),
+                arrival_time=time.time(),
+                prompt_tokens=[self.mock_tokenizer.bos],
+                status=Status.ACTIVE_BUT_NOT_GENERATING_TOKENS,
+            )
+            active_requests[i] = inference_request
+
+        requests = self.text_generation_controller.generate_all_output_tokens_static_batch(
+            active_requests
+        )
+
+        for request_id, request in requests.items():
+            assert (
+                request.status == Status.COMPLETED
+            ), f"Status should be completed but its {request.status}"
+            assert request.generated_length > 0, f"Generated length should be greater than zero"
+            assert request.generated_text is not None, "Generated text should not be None"
+            assert len(request.generated_log_probs) == request.generated_length
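
For context, a minimal standalone sketch of the off-by-one this patch appears to fix (illustration only, not code from the patch): the log probs computed at position `t` score the token emitted at position `t + 1`, so the log probs of the generated tokens start at `input_prompt_length - 1`, and the corrected slice contains exactly `generated_length` entries, which the new test asserts.

```python
# Hypothetical sketch; variable names (P, S, vocab) are illustrative, not from the patch.
import torch

P, S, vocab = 4, 3, 10                    # prompt length, generated tokens, vocab size
tokens = torch.randint(vocab, (P + S,))   # prompt followed by generated tokens
logits = torch.randn(P + S, vocab)        # one logits row per position

log_probs = torch.log_softmax(logits, dim=-1)
# The log prob of tokens[t + 1] comes from the distribution predicted at position t,
# so there are P + S - 1 per-token log probs, shifted left by one.
token_log_probs = log_probs[:-1].gather(1, tokens[1:].unsqueeze(1)).squeeze(1)

# Generated tokens occupy positions P .. P + S - 1, so their log probs sit at
# indices P - 1 .. P + S - 2, i.e. the slice [P - 1 : P + S - 1] of length S.
generated_log_probs = token_log_probs[P - 1 : P + S - 1]
assert generated_log_probs.shape[0] == S  # matches request.generated_length
```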