diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index d1f97bde2..55fdade6c 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -50,6 +50,7 @@ def generate(
         *,
         sampling_params: Optional["SamplingParams"] = None,
         use_tqdm: bool = True,
+        original_output: bool = False,
     ):
         """Generate text using vLLM.
 
@@ -74,6 +75,9 @@ def generate(
             vLLM documentation for more details:
             https://docs.vllm.ai/en/latest/dev/sampling_params.html.
         use_tqdm
             A boolean in order to display progress bar while inferencing
+        original_output
+            A boolean; if True, returns the original (full) output of the
+            vLLM model
         Returns
         -------
@@ -82,6 +86,8 @@ def generate(
         this is a batch with several sequences but only one sample the list
         is of shape `(n_batch)`. If there is only one sequence and one sample,
         a string is returned.
+        If original_output is True, then this function returns the original
+        (full) output of the vLLM model.
         """
         from vllm.sampling_params import SamplingParams
 
@@ -134,6 +140,10 @@ def generate(
             lora_request=self.lora_request,
             use_tqdm=use_tqdm,
         )
+
+        if original_output:
+            return results
+
         results = [[sample.text for sample in batch.outputs] for batch in results]
 
         batch_size = len(results)