diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py index 4634bc839..e240c11cd 100644 --- a/tests/generate/test_integration_vllm.py +++ b/tests/generate/test_integration_vllm.py @@ -4,6 +4,7 @@ import pytest import torch from pydantic import BaseModel, constr +from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams import outlines.generate as generate @@ -42,15 +43,31 @@ def test_vllm_generation_api(model, generator_type, params): res = generator("test", stop_at=[".", "ab"]) assert isinstance(res, str) + res = generator("test", original_output=True) + assert isinstance(res, list) + assert len(res) == 1 + assert isinstance(res[0], RequestOutput) + res1 = generator("test", seed=1) res2 = generator("test", seed=1) assert isinstance(res1, str) assert isinstance(res2, str) assert res1 == res2 + res1 = generator("test", seed=1, original_output=True) + res2 = generator("test", seed=1) + assert isinstance(res1[0], RequestOutput) + assert isinstance(res2, str) + text1 = [sample.text for sample in res1[0].outputs] + assert len(text1) == 1 + assert text1[0] == res2 + res = generator(["test", "test1"]) assert len(res) == 2 + res = generator(["test", "test1"], original_output=True) + assert len(res) == 2 + def test_vllm_sampling_params(model): generator = generate.text(model)