Add optional argument original_output to vLLM's generate function
original_output is a boolean. If True, generate returns the original
(full) output of the vLLM API call.
LouSalaun committed Oct 19, 2024
1 parent 6cff654 commit c5f6aaf
Showing 1 changed file with 10 additions and 0 deletions.
outlines/models/vllm.py
@@ -50,6 +50,7 @@ def generate(
         *,
         sampling_params: Optional["SamplingParams"] = None,
         use_tqdm: bool = True,
+        original_output: bool = False,
     ):
         """Generate text using vLLM.
@@ -74,6 +75,9 @@
             vLLM documentation for more details: https://docs.vllm.ai/en/latest/dev/sampling_params.html.
         use_tqdm
             A boolean in order to display progress bar while inferencing
+        original_output
+            A boolean, if True then returns the original (full) output of the
+            vLLM model
         Returns
         -------
@@ -82,6 +86,8 @@
         this is a batch with several sequences but only one sample the list is
         of shape `(n_batch)`. If there is only one sequence and one sample, a
         string is returned.
+        If original_output is True, then this function returns the original
+        (full) output of the vLLM model.
         """
         from vllm.sampling_params import SamplingParams
@@ -134,6 +140,10 @@ def generate(
             lora_request=self.lora_request,
             use_tqdm=use_tqdm,
         )
+
+        if original_output:
+            return results
+
         results = [[sample.text for sample in batch.outputs] for batch in results]
 
         batch_size = len(results)

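A hedged usage sketch of the new flag: the model name is arbitrary, and the assumption that outlines' generator API forwards extra keyword arguments such as original_output down to VLLM.generate is mine, not stated in this commit.

import outlines

# Any vLLM-supported model works here; the name is illustrative.
model = outlines.models.vllm("facebook/opt-125m")
generator = outlines.generate.text(model)

# Default behavior: the extracted text, returned as a string for a single
# prompt, or as nested lists of strings for batches / multiple samples.
text = generator("What is the capital of France?")

# Assumed pass-through of the new flag: returns vLLM's raw RequestOutput
# objects, which also carry token IDs, logprobs, and finish reasons.
raw = generator("What is the capital of France?", original_output=True)

Returning the raw vLLM results before the text-extraction step means callers who need token-level metadata no longer lose it to the list-of-strings reshaping that follows.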