Commit 0eb06be

Merge branch 'develop' into adapt/ernie-vl
2 parents 985ccf7 + b61a272 commit 0eb06be

File tree

1 file changed: +157 -124 lines changed

tests/output/test_get_save_output_v1.py

Lines changed: 157 additions & 124 deletions
@@ -12,133 +12,166 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
-import signal
-import socket
-import subprocess
+import queue
 import time
-import traceback
-
-import pytest
-
-from fastdeploy import LLM, SamplingParams
-
-FD_ENGINE_QUEUE_PORT = int(os.getenv("FD_ENGINE_QUEUE_PORT", 8313))
-FD_CACHE_QUEUE_PORT = int(os.getenv("FD_CACHE_QUEUE_PORT", 8333))
-MAX_WAIT_SECONDS = 60
-
-os.environ["LD_LIBRARY_PATH"] = "/usr/local/nccl/"
-# enable get_save_output_v1
-os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
-
-
-def is_port_open(host: str, port: int, timeout=1.0):
-    """
-    Check if a TCP port is open on the given host.
-    Returns True if connection succeeds, False otherwise.
-    """
-    try:
-        with socket.create_connection((host, port), timeout):
-            return True
-    except Exception:
-        return False
-
-
-@pytest.fixture(scope="module")
-def model_path():
-    """
-    Get model path from environment variable MODEL_PATH,
-    default to "./ERNIE-4.5-0.3B-Paddle" if not set.
-    """
-    base_path = os.getenv("MODEL_PATH")
-    if base_path:
-        return os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle")
-    else:
-        return "./ERNIE-4.5-0.3B-Paddle"
-
-
-@pytest.fixture(scope="module")
-def llm(model_path):
-    """
-    Fixture to initialize the LLM model with a given model path
-    """
-    try:
-        output = subprocess.check_output(f"lsof -i:{FD_ENGINE_QUEUE_PORT} -t", shell=True).decode().strip()
-        for pid in output.splitlines():
-            os.kill(int(pid), signal.SIGKILL)
-            print(f"Killed process on port {FD_ENGINE_QUEUE_PORT}, pid={pid}")
-    except subprocess.CalledProcessError:
+import unittest
+from threading import Thread
+from unittest.mock import Mock
+
+import paddle
+import zmq
+
+from fastdeploy import envs
+from fastdeploy.inter_communicator import ZmqIpcClient
+from fastdeploy.model_executor.pre_and_post_process import _build_stream_transfer_data
+from fastdeploy.output.token_processor import TokenProcessor
+from fastdeploy.worker.gpu_model_runner import GPUModelRunner
+
+paddle.set_device("cpu")
+
+
+# Mock classes and constants needed for the test
+class MockConfig:
+    class ParallelConfig:
+        local_data_parallel_id = 0
+        enable_expert_parallel = False
+        data_parallel_size = 1
+
+    class SpeculativeConfig:
+        method = None
+
+    class ModelConfig:
+        enable_logprob = False
+
+    class SchedulerConfig:
+        name = "default"
+
+    parallel_config = ParallelConfig()
+    speculative_config = SpeculativeConfig()
+    model_config = ModelConfig()
+    scheduler_config = SchedulerConfig()
+
+
+class MockTask:
+    def __init__(self):
+        self.request_id = "test_request_1"
+        self.arrival_time = time.time()
+        self.inference_start_time = time.time()
+        self.schedule_start_time = time.time()
+        self.preprocess_end_time = time.time() - 0.1
+        self.preprocess_start_time = time.time() - 0.2
+        self.eos_token_ids = [2]
+        self.output_token_ids = []
+        self.messages = "Test prompt"
+        self.num_cached_tokens = 0
+        self.disaggregate_info = None
+        self.prefill_chunk_info = None
+        self.prefill_chunk_num = 0
+        self.pooling_params = None
+
+    def get(self, key: str, default_value=None):
+        if hasattr(self, key):
+            return getattr(self, key)
+        elif hasattr(self, "sampling_params") and hasattr(self.sampling_params, key):
+            return getattr(self.sampling_params, key)
+        else:
+            return default_value
+
+
+class MockResourceManager:
+    def __init__(self):
+        self.stop_flags = [False]
+        self.tasks_list = [MockTask()]
+        self.to_be_rescheduled_request_id_set = set()
+
+    def info(self):
+        return "Mock resource manager info"
+
+    def reschedule_preempt_task(self, task_id):
         pass

-    try:
-        start = time.time()
-        llm = LLM(
-            model=model_path,
-            tensor_parallel_size=2,
-            num_gpu_blocks_override=1024,
-            engine_worker_queue_port=FD_ENGINE_QUEUE_PORT,
-            cache_queue_port=FD_CACHE_QUEUE_PORT,
-            max_model_len=8192,
-            seed=1,
-        )
-
-        # Wait for the port to be open
-        wait_start = time.time()
-        while not is_port_open("127.0.0.1", FD_ENGINE_QUEUE_PORT):
-            if time.time() - wait_start > MAX_WAIT_SECONDS:
-                pytest.fail(
-                    f"Model engine did not start within {MAX_WAIT_SECONDS} seconds on port {FD_ENGINE_QUEUE_PORT}"
-                )
-            time.sleep(1)
-
-        print(f"Model loaded successfully from {model_path} in {time.time() - start:.2f}s.")
-        yield llm
-    except Exception:
-        print(f"Failed to load model from {model_path}.")
-        traceback.print_exc()
-        pytest.fail(f"Failed to initialize LLM model from {model_path}")
-
-
-def test_generate_prompts(llm):
-    """
-    Test basic prompt generation
-    """
-
-    # Only one prompt enabled for testing currently
-    prompts = [
-        "请介绍一下中国的四大发明。",
-        "太阳和地球之间的距离是多少?",
-        "写一首关于春天的古风诗。",
-    ]
-
-    sampling_params = SamplingParams(
-        temperature=0.8,
-        top_p=0.95,
-    )
-
-    try:
-        outputs = llm.generate(prompts, sampling_params)
-
-        # Verify basic properties of the outputs
-        assert len(outputs) == len(prompts), "Number of outputs should match number of prompts"
-
-        for i, output in enumerate(outputs):
-            assert output.prompt == prompts[i], f"Prompt mismatch for case {i + 1}"
-            assert isinstance(output.outputs.text, str), f"Output text should be string for case {i + 1}"
-            assert len(output.outputs.text) > 0, f"Generated text should not be empty for case {i + 1}"
-            assert isinstance(output.finished, bool), f"'finished' should be boolean for case {i + 1}"
-            assert output.metrics.model_execute_time > 0, f"Execution time should be positive for case {i + 1}"
-
-            print(f"=== Prompt generation Case {i + 1} Passed ===")
-
-    except Exception:
-        print("Failed during prompt generation.")
-        traceback.print_exc()
-        pytest.fail("Prompt generation test failed")
+
+class MockCachedGeneratedTokens:
+    def __init__(self):
+        self.cache = []
+
+    def put_results(self, results):
+        self.cache.extend(results)
+
+
+class TestGetSaveOutputV1(unittest.TestCase):
+    def setup_model_runner(self):
+        """Helper method to setup GPUModelRunner with different configurations"""
+        cfg = MockConfig()
+        cfg.speculative_config.method = None
+        cfg.model_config.enable_logprob = False
+
+        model_runner = GPUModelRunner.__new__(GPUModelRunner)
+
+        model_runner.zmq_client = None
+        model_runner.async_output_queue = None
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            model_runner.zmq_client = ZmqIpcClient(
+                name=f"get_save_output_rank{cfg.parallel_config.local_data_parallel_id}", mode=zmq.PUSH
+            )
+            model_runner.zmq_client.connect()
+            model_runner.zmq_client.socket.SNDTIMEO = 3000
+            model_runner.async_output_queue: queue.Queue = queue.Queue()
+            model_runner.async_output_copy_thread = Thread(
+                target=model_runner._async_output_busy_loop,
+                daemon=True,
+                name="WorkerAsyncOutputCopy",
+            )
+            model_runner.async_output_copy_thread.start()
+
+        return model_runner
+
+    def setup_token_processor(self):
+        """Helper method to setup TokenProcessor with different configurations"""
+        cfg = MockConfig()
+        cfg.speculative_config.method = None
+        cfg.model_config.enable_logprob = False
+
+        processor = TokenProcessor.__new__(TokenProcessor)
+        processor.cfg = cfg
+        processor.cached_generated_tokens: MockCachedGeneratedTokens = MockCachedGeneratedTokens()
+        processor.executor = Mock()
+        processor.engine_worker_queue = Mock()
+        processor.split_connector = Mock()
+        processor.worker = None
+        processor.resource_manager = MockResourceManager()
+        task1 = MockTask()
+        task2 = MockTask()
+        processor.resource_manager.tasks_list = [task1, task2]
+        processor.resource_manager.stop_flags = [False, False]
+        processor.tokens_counter = {task1.request_id: 0, task2.request_id: 0}
+        processor.total_step = 0
+        processor.speculative_decoding = False
+        processor.use_logprobs = False
+
+        processor.number_of_output_tokens = 0
+        processor.prefill_result_status = {}
+
+        processor.run()
+        return processor
+
+    def test_normal(self):
+        """Test normal scenario (without speculative decoding and logprobs)"""
+        # init token_processor, model_runner and start zmq_client
+        envs.FD_USE_GET_SAVE_OUTPUT_V1 = 1
+        processor = self.setup_token_processor()
+        model_runner = self.setup_model_runner()
+
+        # put data into zmq client
+        data = paddle.to_tensor([[100]], dtype="int64")
+        output_tokens = _build_stream_transfer_data(data)
+        model_runner.async_output_queue.put(output_tokens)
+
+        # check result
+        cached_generated_tokens: MockCachedGeneratedTokens = processor.cached_generated_tokens
+        for c in cached_generated_tokens.cache:
+            assert c.outputs.token_ids == [100]


 if __name__ == "__main__":
-    """
-    Main entry point for the test script.
-    """
-    pytest.main(["-sv", __file__])
+    unittest.main(verbosity=2, buffer=False)
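
Note on running the rewritten test: it no longer launches an LLM engine; it mocks the TokenProcessor/GPUModelRunner wiring, runs on CPU, and its __main__ block now calls unittest.main instead of pytest.main. It can presumably be invoked either directly or through pytest (which also collects unittest.TestCase classes); the exact CI invocation is an assumption here:

    python tests/output/test_get_save_output_v1.py
    # or, if pytest remains the repo's test runner:
    python -m pytest -sv tests/output/test_get_save_output_v1.py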
