Merged
281 changes: 157 additions & 124 deletions tests/output/test_get_save_output_v1.py
@@ -12,133 +12,166 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import signal
import socket
import subprocess
import queue
import time
import traceback

import pytest

from fastdeploy import LLM, SamplingParams

FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
MAX_WAIT_SECONDS = 60

os.environ["LD_LIBRARY_PATH"] = "/usr/local/nccl/"
# enable get_save_output_v1
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"


def is_port_open(host: str, port: int, timeout=1.0):
"""
Check if a TCP port is open on the given host.
Returns True if connection succeeds, False otherwise.
"""
try:
with socket.create_connection((host, port), timeout):
return True
except Exception:
return False


@pytest.fixture(scope="module")
def model_path():
"""
Get model path from environment variable MODEL_PATH,
default to "./ERNIE-4.5-0.3B-Paddle" if not set.
"""
base_path = os.getenv("MODEL_PATH")
if base_path:
return os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
else:
return "./ERNIE-4.5-0.3B-Paddle"


@pytest.fixture(scope="module")
def llm(model_path):
"""
Fixture to initialize the LLM model with a given model path
"""
try:
output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
for pid in output.splitlines():
os.kill(int(pid), signal.SIGKILL)
print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
except subprocess.CalledProcessError:
import unittest
from threading import Thread
from unittest.mock import Mock

import paddle
import zmq

from fastdeploy import envs
from fastdeploy.inter_communicator import ZmqIpcClient
from fastdeploy.model_executor.pre_and_post_process import _build_stream_transfer_data
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.worker.gpu_model_runner import GPUModelRunner

paddle.set_device("cpu")


# Mock classes and constants needed for the test
class MockConfig:
class ParallelConfig:
local_data_parallel_id = 0
enable_expert_parallel = False
data_parallel_size = 1

class SpeculativeConfig:
method = None

class ModelConfig:
enable_logprob = False

class SchedulerConfig:
name = "default"

parallel_config = ParallelConfig()
speculative_config = SpeculativeConfig()
model_config = ModelConfig()
scheduler_config = SchedulerConfig()


class MockTask:
def __init__(self):
self.request_id = "test_request_1"
self.arrival_time = time.time()
self.inference_start_time = time.time()
self.schedule_start_time = time.time()
self.preprocess_end_time = time.time() - 0.1
self.preprocess_start_time = time.time() - 0.2
self.eos_token_ids = [2]
self.output_token_ids = []
self.messages = "Test prompt"
self.num_cached_tokens = 0
self.disaggregate_info = None
self.prefill_chunk_info = None
self.prefill_chunk_num = 0
self.pooling_params = None

def get(self, key: str, default_value=None):
if hasattr(self, key):
return getattr(self, key)
elif hasattr(self, "sampling_params") and hasattr(self.sampling_params, key):
return getattr(self.sampling_params, key)
else:
return default_value


class MockResourceManager:
def __init__(self):
self.stop_flags = [False]
self.tasks_list = [MockTask()]
self.to_be_rescheduled_request_id_set = set()

def info(self):
return "Mock resource manager info"

def reschedule_preempt_task(self, task_id):
pass

try:
start = time.time()
llm = LLM(
model=model_path,
tensor_parallel_size=2,
num_gpu_blocks_override=1024,
engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
cache_queue_port=FD_CACHE_QUEUE_PORT,
max_model_len=8192,
seed=1,
)

# Wait for the port to be open
wait_start = time.time()
while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT):
if time.time() - wait_start > MAX_WAIT_SECONDS:
pytest.fail(
f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}"
)
time.sleep(1)

print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.")
yield llm
except Exception:
print(f"Failed to load model from {model_path}.")
traceback.print_exc()
pytest.fail(f"Failed to initialize LLM model from {model_path}")


def test_generate_prompts(llm):
"""
Test basic prompt generation
"""

# Only one prompt enabled for testing currently
prompts = [
"请介绍一下中国的四大发明。",
"太阳和地球之间的距离是多少?",
"写一首关于春天的古风诗。",
]

sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
)

try:
outputs = llm.generate(prompts, sampling_params)

# Verify basic properties of the outputs
assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"

for i, output in enumerate(outputs):
assert output.prompt == prompts[i], f"Prompt mismatch for case {i + 1}"
assert isinstance(output.outputs.text, str), f"Output text should be string for case {i + 1}"
assert len(output.outputs.text) > 0, f"Generated text should not be empty for case {i + 1}"
assert isinstance(output.finished, bool), f"'finished' should be boolean for case {i + 1}"
assert output.metrics.model_execute_time > 0, f"Execution time should be positive for case {i + 1}"

print(f"=== Prompt generation Case {i + 1} Passed ===")

except Exception:
print("Failed during prompt generation.")
traceback.print_exc()
pytest.fail("Prompt generation test failed")

class MockCachedGeneratedTokens:
def __init__(self):
self.cache = []

def put_results(self, results):
self.cache.extend(results)


class TestGetSaveOutputV1(unittest.TestCase):
def setup_model_runner(self):
"""Helper method to setup GPUModelRunner with different configurations"""
cfg = MockConfig()
cfg.speculative_config.method = None
cfg.model_config.enable_logprob = False

model_runner = GPUModelRunner.__new__(GPUModelRunner)

model_runner.zmq_client = None
model_runner.async_output_queue = None
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
model_runner.zmq_client = ZmqIpcClient(
name=f"get_save_output_rank{cfg.parallel_config.local_data_parallel_id}", mode=zmq.PUSH
)
model_runner.zmq_client.connect()
model_runner.zmq_client.socket.SNDTIMEO = 3000
model_runner.async_output_queue: queue.Queue = queue.Queue()
model_runner.async_output_copy_thread = Thread(
target=model_runner._async_output_busy_loop,
daemon=True,
name="WorkerAsyncOutputCopy",
)
model_runner.async_output_copy_thread.start()

return model_runner

def setup_token_processor(self):
"""Helper method to setup TokenProcessor with different configurations"""
cfg = MockConfig()
cfg.speculative_config.method = None
cfg.model_config.enable_logprob = False

processor = TokenProcessor.__new__(TokenProcessor)
processor.cfg = cfg
processor.cached_generated_tokens: MockCachedGeneratedTokens = MockCachedGeneratedTokens()
processor.executor = Mock()
processor.engine_worker_queue = Mock()
processor.split_connector = Mock()
processor.worker = None
processor.resource_manager = MockResourceManager()
task1 = MockTask()
task2 = MockTask()
processor.resource_manager.tasks_list = [task1, task2]
processor.resource_manager.stop_flags = [False, False]
processor.tokens_counter = {task1.request_id: 0, task2.request_id: 0}
processor.total_step = 0
processor.speculative_decoding = False
processor.use_logprobs = False

processor.number_of_output_tokens = 0
processor.prefill_result_status = {}

processor.run()
return processor

def test_normal(self):
"""Test normal senario(without speculative decoding and logprobs)"""
# init token_processor, model_runner and start zmq_client
envs.FD_USE_GET_SAVE_OUTPUT_V1 = 1
processor = self.setup_token_processor()
model_runner = self.setup_model_runner()

        # enqueue data for the async output copy loop, which pushes it to the token processor via the zmq client
data = paddle.to_tensor([[100]], dtype="int64")
output_tokens = _build_stream_transfer_data(data)
model_runner.async_output_queue.put(output_tokens)

        # check result (wait briefly for the output to reach the token processor over zmq)
        deadline = time.time() + 10
        while not processor.cached_generated_tokens.cache and time.time() < deadline:
            time.sleep(0.1)
        cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens
        assert cached_generated_tokens.cache, "no outputs reached the token processor"
        for c in cached_generated_tokens.cache:
            assert c.outputs.token_ids == [100]


if __name__ == "__main__":
"""
Main entry point for the test script.
"""
pytest.main(["-sv", __file__])
unittest.main(verbosity=2, buffer=False)
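
For readers unfamiliar with the transport the new test relies on: the worker-side async output loop pushes step outputs over a ZMQ IPC socket, and the TokenProcessor consumes them on the other end. The following is a minimal standalone sketch of that PUSH/PULL pattern using plain pyzmq; the socket path and payload dict are illustrative stand-ins, not FastDeploy's ZmqIpcClient API or its actual transfer format.

# Standalone sketch (not part of the PR): the PUSH/PULL IPC pattern the new test
# exercises, written against plain pyzmq instead of FastDeploy's ZmqIpcClient.
# The socket path and payload below are illustrative, not the library's real format.
import zmq

IPC_ADDR = "ipc:///tmp/get_save_output_demo"  # hypothetical socket path

ctx = zmq.Context()

# consumer side (analogous to TokenProcessor): binds and pulls results
pull_sock = ctx.socket(zmq.PULL)
pull_sock.bind(IPC_ADDR)

# producer side (analogous to the worker's async output copy loop): connects and pushes
push_sock = ctx.socket(zmq.PUSH)
push_sock.connect(IPC_ADDR)
push_sock.SNDTIMEO = 3000  # give up sending after 3 s, mirroring the timeout set in the test

push_sock.send_pyobj({"request_id": "test_request_1", "token_ids": [100]})
result = pull_sock.recv_pyobj()
assert result["token_ids"] == [100]

push_sock.close()
pull_sock.close()
ctx.term()

PUSH/PULL gives one-way delivery without the request/reply lock-step of REQ/REP, which fits a fire-and-forget token stream from the worker to the output processor.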