diff --git a/docker/llm/inference/xpu/docker/Dockerfile b/docker/llm/inference/xpu/docker/Dockerfile
index 7a812482db7..ce629ad2b69 100644
--- a/docker/llm/inference/xpu/docker/Dockerfile
+++ b/docker/llm/inference/xpu/docker/Dockerfile
@@ -61,7 +61,7 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO
cp -r ./ipex-llm/python/llm/example/GPU/vLLM-Serving/ ./vLLM-Serving && \
# Download pp_serving
mkdir -p /llm/pp_serving && \
- cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-FastAPI/*.py /llm/pp_serving/ && \
+ cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-Serving/*.py /llm/pp_serving/ && \
# Install related library of benchmarking
pip install pandas omegaconf && \
chmod +x /llm/benchmark.sh && \
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py
deleted file mode 100644
index bbcb392f70a..00000000000
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py
+++ /dev/null
@@ -1,346 +0,0 @@
-#
-# Copyright 2016 The BigDL Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-import torch.nn.parallel
-import torch.distributed as dist
-import os
-
-from ipex_llm.utils.common import invalidInputError
-from ipex_llm.transformers import init_pipeline_parallel, ModelRunner
-import oneccl_bindings_for_pytorch
-import json
-
-from transformers.utils import logging
-logger = logging.get_logger(__name__)
-
-init_pipeline_parallel()
-
-my_rank = dist.get_rank()
-my_size = dist.get_world_size()
-device = f"xpu:{my_rank}"
-logger.info(f"rank: {my_rank}, size: {my_size}")
-
-import time
-from transformers import AutoTokenizer
-from fastapi import FastAPI, HTTPException, Request
-from fastapi.responses import StreamingResponse
-from pydantic import BaseModel
-import uvicorn
-import asyncio, uuid
-from typing import Dict, List, Optional, Any, Callable, Union
-import argparse
-
-
-class PromptRequest(BaseModel):
- prompt: str
- n_predict: Optional[int] = 256
- req_type: str = 'completion'
-
-from openai.types.chat import ChatCompletionMessageParam
-class ChatCompletionRequest(BaseModel):
- messages: List[ChatCompletionMessageParam]
- model: str
- max_tokens: Optional[int] = None
- stream: Optional[bool] = False
-
-
-class CompletionRequest(BaseModel):
- model: str
- prompt: Union[List[int], List[List[int]], str, List[str]]
- max_tokens: Optional[int] = None
- stream: Optional[bool] = False
-
-empty_req = PromptRequest(prompt="", n_predict=0)
-
-app = FastAPI()
-global tokenizer
-global local_model
-
-request_queue: asyncio.Queue = asyncio.Queue()
-result_dict: Dict[str, str] = {}
-streamer_dict = {}
-local_rank = my_rank
-
-
-from openai_protocol import (
- ChatCompletionResponseStreamChoice,
- ChatCompletionStreamResponse,
- ChatCompletionResponseChoice,
- ChatCompletionResponse,
- ChatMessage,
- DeltaMessage,
- CompletionResponseChoice,
- CompletionResponse,
- CompletionResponseStreamChoice,
- CompletionStreamResponse,
-)
-
-
-async def chat_stream_generator(local_model, delta_text_queue, request_id):
- model_name = local_model.model_name
- index = 0
- while True:
- if not delta_text_queue.empty():
- with local_model.dict_lock:
- remain, delta_text = await delta_text_queue.get()
- # print(remain)
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(role="assistant", content=delta_text),
- logprobs=None,
- finish_reason=None)
- chunk = ChatCompletionStreamResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- index = index + 1
- if remain == 0:
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(role="assistant", content=None),
- logprobs=None,
- finish_reason="length")
- chunk = ChatCompletionStreamResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- break
- else:
- await asyncio.sleep(0)
- local_model.streamer.pop(request_id, None)
-
-
-async def completion_stream_generator(local_model, delta_text_queue, request_id):
- model_name = local_model.model_name
- index = 0
- while True:
- if not delta_text_queue.empty():
- with local_model.dict_lock:
- remain, delta_text = await delta_text_queue.get()
- # print(remain)
- choice_data = CompletionResponseStreamChoice(
- index=index,
- text=delta_text,
- logprobs=None,
- finish_reason=None)
- chunk = CompletionStreamResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- index = index + 1
- if remain == 0:
- choice_data = CompletionResponseStreamChoice(
- index=index,
- text="",
- logprobs=None,
- finish_reason="length")
- chunk = CompletionStreamResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- data = chunk.model_dump_json(exclude_unset=True)
- yield f"data: {data}\n\n"
- break
- else:
- await asyncio.sleep(0)
- local_model.streamer.pop(request_id, None)
-
-
-async def generator(local_model, delta_text_queue, request_id):
- while True:
- if not delta_text_queue.empty():
- with local_model.dict_lock:
- remain, delta_text = await delta_text_queue.get()
- yield delta_text
- if remain == 0:
- break
- else:
- await asyncio.sleep(0)
- local_model.streamer.pop(request_id, None)
-
-
-@app.post("/generate/")
-async def generate(prompt_request: PromptRequest):
- request_id = str(uuid.uuid4())
- await local_model.waiting_requests.put((request_id, prompt_request))
- while True:
- await asyncio.sleep(0)
- cur_streamer = local_model.streamer.get(request_id, None)
- if cur_streamer is not None:
- output_str = []
- async for item in generator(local_model, cur_streamer, request_id):
- output_str.append(item)
- return request_id, "".join(output_str)
-
-
-async def generate_stream(prompt_request: PromptRequest):
- request_id = str(uuid.uuid4()) + "stream"
- await local_model.waiting_requests.put((request_id, prompt_request))
- while True:
- await asyncio.sleep(0)
- cur_streamer = local_model.streamer.get(request_id, None)
- if cur_streamer is not None:
- if prompt_request.req_type == 'completion':
- cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
- elif prompt_request.req_type == 'chat':
- cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
- else:
- invalidInputError(False, "Invalid Request Type.")
-
- return request_id, StreamingResponse(
- content=cur_generator, media_type="text/event-stream"
- )
-
-@app.post("/generate_stream/")
-async def generate_stream_api(prompt_request: PromptRequest):
- request_id, result = await generate_stream(prompt_request)
- return result
-
-
-DEFAULT_SYSTEM_PROMPT = """\
-"""
-
-def get_prompt(messages) -> str:
- prompt = ""
- for msg in messages:
- role = msg["role"]
- content = msg["content"]
- if role == "system":
- prompt += f"<>\n{content}\n<>\n\n"
- elif role == "user":
- prompt += f"[INST] {content} [/INST] "
- elif role == "assistant":
- prompt += f"{content} "
- else:
- raise ValueError(f"Unknown role: {role}")
- return prompt.strip()
-
-@app.post("/v1/chat/completions")
-async def create_chat_completion(request: ChatCompletionRequest):
- model_name = local_model.model_name
- if request.max_tokens is None:
- n_predict = 256
- else:
- n_predict = request.max_tokens
- prompt_request = PromptRequest(
- prompt=get_prompt(request.messages),
- n_predict=n_predict,
- req_type="chat"
- )
- if request.stream:
- request_id, result = await generate_stream(prompt_request)
- else:
- request_id, result = await generate(prompt_request)
- choice_data = ChatCompletionResponseChoice(
- index=0,
- message=ChatMessage(role="assistant", content=result),
- logprobs=None,
- finish_reason="length")
- result = ChatCompletionResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- return result
-
-@app.post("/v1/completions")
-async def create_completion(request: CompletionRequest):
- model_name = local_model.model_name
- if request.max_tokens is None:
- n_predict = 256
- else:
- n_predict = request.max_tokens
- prompt_request = PromptRequest(
- prompt=request.prompt,
- n_predict=n_predict,
- req_type="completion"
- )
- if request.stream:
- request_id, result = await generate_stream(prompt_request)
- else:
- request_id, result = await generate(prompt_request)
- choice_data = CompletionResponseChoice(
- index=0,
- text=result,
- logprobs=None,
- finish_reason="length")
- result = CompletionResponse(
- id=request_id,
- choices=[choice_data],
- model=model_name)
- return result
-
-
-async def process_requests(local_model, result_dict):
- while True:
- await asyncio.sleep(0)
- await local_model.process_step(tokenizer, result_dict)
-
-
-@app.on_event("startup")
-async def startup_event():
- asyncio.create_task(process_requests(local_model, result_dict))
-
-async def main():
- parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP')
- parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
- help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
- ', or the path to the huggingface checkpoint folder')
- parser.add_argument('--low-bit', type=str, default='sym_int4',
- help='The quantization type the model will convert to.')
- parser.add_argument('--port', type=int, default=8000,
- help='The port number on which the server will run.')
- parser.add_argument('--max-num-seqs', type=int, default=8,
- help='Max num sequences in a batch.')
- parser.add_argument('--max-prefilled-seqs', type=int, default=0,
- help='Max num sequences in a batch during prefilling.')
-
- args = parser.parse_args()
- model_path = args.repo_id_or_model_path
- low_bit = args.low_bit
- max_num_seqs = args.max_num_seqs
- max_prefilled_seqs = args.max_prefilled_seqs
-
- # serialize model initialization so that we do not run out of CPU memory
- for i in range(my_size):
- if my_rank == i:
- logger.info("start model initialization")
- global local_model
- local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
- logger.info("model initialized")
- dist.barrier()
- # Load tokenizer
- global tokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
- if tokenizer.pad_token is None:
- tokenizer.pad_token = tokenizer.eos_token
-
- if local_rank == 0:
- config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
- server = uvicorn.Server(config)
- await server.serve()
- else:
- while True:
- await asyncio.sleep(0)
- await local_model.process_step(tokenizer, result_dict)
-
-if __name__ == "__main__":
- asyncio.run(main())
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md b/python/llm/example/GPU/Pipeline-Parallel-Serving/README.md
similarity index 94%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/README.md
index 278c8c549b6..73607b2e042 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md
+++ b/python/llm/example/GPU/Pipeline-Parallel-Serving/README.md
@@ -50,7 +50,14 @@ pip install transformers==4.40.0
pip install trl==0.8.1
```
-### 2. Run pipeline parallel serving on multiple GPUs
+### 2-1. Run ipex-llm serving on a single GPU
+
+```bash
+# Need to set NUM_GPUS=1 and MODEL_PATH in run.sh first
+bash run.sh
+```
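+
+In single-GPU mode, `run.sh` launches `serving.py` directly (selecting the card via `ZE_AFFINITY_MASK`); when `NUM_GPUS` is greater than 1, it launches `pipeline_serving.py` through `torchrun`.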
+
+### 2-2. Run pipeline parallel serving on multiple GPUs
```bash
# Need to set MODEL_PATH in run.sh first
@@ -76,7 +83,7 @@ export http_proxy=
export https_proxy=
curl -X 'POST' \
- 'http://127.0.0.1:8000/generate/' \
+ 'http://127.0.0.1:8000/generate' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
@@ -99,7 +106,7 @@ Please change the test url accordingly.
```bash
# set t/c to the number of concurrencies to test full throughput.
-wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate/ --timeout 1m
+wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate --timeout 1m
```
## 5. Using the `benchmark.py` Script
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py b/python/llm/example/GPU/Pipeline-Parallel-Serving/benchmark.py
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/benchmark.py
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/gradio_webui.py b/python/llm/example/GPU/Pipeline-Parallel-Serving/gradio_webui.py
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/gradio_webui.py
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/gradio_webui.py
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Serving/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-Serving/pipeline_serving.py
new file mode 100644
index 00000000000..e3062714ca0
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Serving/pipeline_serving.py
@@ -0,0 +1,78 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch.distributed as dist
+from ipex_llm.transformers import init_pipeline_parallel, PPModelWorker
+from ipex_llm.serving.fastapi import FastApp
+from transformers.utils import logging
+from transformers import AutoTokenizer
+import uvicorn
+import asyncio
+from typing import Dict
+import argparse
+logger = logging.get_logger(__name__)
+
+init_pipeline_parallel()
+my_rank = dist.get_rank()
+my_size = dist.get_world_size()
+device = f"xpu:{my_rank}"
+logger.info(f"rank: {my_rank}, size: {my_size}")
+result_dict: Dict[str, str] = {}
+local_rank = my_rank
+
+async def main():
+    parser = argparse.ArgumentParser(description='Predict tokens using FastAPI with pipeline parallelism')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+ help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--low-bit', type=str, default='sym_int4',
+ help='The quantization type the model will convert to.')
+ parser.add_argument('--port', type=int, default=8000,
+ help='The port number on which the server will run.')
+ parser.add_argument('--max-num-seqs', type=int, default=8,
+ help='Max num sequences in a batch.')
+ parser.add_argument('--max-prefilled-seqs', type=int, default=0,
+ help='Max num sequences in a batch during prefilling.')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+ low_bit = args.low_bit
+ max_num_seqs = args.max_num_seqs
+ max_prefilled_seqs = args.max_prefilled_seqs
+
+ # serialize model initialization so that we do not run out of CPU memory
+ for i in range(my_size):
+ if my_rank == i:
+ logger.info("start model initialization")
+ local_model = PPModelWorker(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs)
+ logger.info("model initialized")
+ dist.barrier()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ myapp = FastApp(local_model, tokenizer)
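+    # Only rank 0 serves HTTP; the other pipeline stages keep stepping the model.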
+ if local_rank == 0:
+ config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
+ server = uvicorn.Server(config)
+ await server.serve()
+ else:
+ while True:
+ await asyncio.sleep(0)
+ await local_model.process_step(tokenizer, result_dict)
+
+if __name__ == "__main__":
+ asyncio.run(main())
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/1024.txt b/python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/1024.txt
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/1024.txt
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/1024.txt
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/128.txt b/python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/128.txt
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/128.txt
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/128.txt
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/2048.txt b/python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/2048.txt
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/2048.txt
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/2048.txt
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/32.txt b/python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/32.txt
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/32.txt
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/prompt/32.txt
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh b/python/llm/example/GPU/Pipeline-Parallel-Serving/run.sh
similarity index 77%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/run.sh
index 14d5a3e6e33..fc3c489ac8e 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
+++ b/python/llm/example/GPU/Pipeline-Parallel-Serving/run.sh
@@ -40,4 +40,9 @@ export LOW_BIT="fp8"
export MAX_NUM_SEQS="4"
export MAX_PREFILLED_SEQS=0
-CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT --max-num-seqs $MAX_NUM_SEQS --max-prefilled-seqs $MAX_PREFILLED_SEQS
+if [[ $NUM_GPUS -eq 1 ]]; then
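+  # Single GPU: pin to device 0 and start the plain FastAPI server (no torchrun needed)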
+ export ZE_AFFINITY_MASK=0
+ python serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT
+else
+ CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit $LOW_BIT --max-num-seqs $MAX_NUM_SEQS --max-prefilled-seqs $MAX_PREFILLED_SEQS
+fi
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Serving/serving.py b/python/llm/example/GPU/Pipeline-Parallel-Serving/serving.py
new file mode 100644
index 00000000000..003307a198f
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-Serving/serving.py
@@ -0,0 +1,53 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from transformers.utils import logging
+from transformers import AutoTokenizer
+import uvicorn
+import asyncio
+import argparse
+from ipex_llm.serving.fastapi import FastApp
+from ipex_llm.serving.fastapi import ModelWorker
+logger = logging.get_logger(__name__)
+
+async def main():
+    parser = argparse.ArgumentParser(description='Predict tokens using FastAPI with ipex-llm')
+ parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+ help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
+ ', or the path to the huggingface checkpoint folder')
+ parser.add_argument('--low-bit', type=str, default='sym_int4',
+ help='The quantization type the model will convert to.')
+ parser.add_argument('--port', type=int, default=8000,
+ help='The port number on which the server will run.')
+
+ args = parser.parse_args()
+ model_path = args.repo_id_or_model_path
+ low_bit = args.low_bit
+
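+    # Load the model in the requested low-bit precision onto the XPU and wrap it as a serving worker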
+ local_model = ModelWorker(model_path, low_bit)
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+ myapp = FastApp(local_model, tokenizer)
+ config = uvicorn.Config(app=myapp.app, host="0.0.0.0", port=args.port)
+ server = uvicorn.Server(config)
+ await server.serve()
+
+if __name__ == "__main__":
+ asyncio.run(main())
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/wrk_script_1024.lua b/python/llm/example/GPU/Pipeline-Parallel-Serving/wrk_script_1024.lua
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/wrk_script_1024.lua
rename to python/llm/example/GPU/Pipeline-Parallel-Serving/wrk_script_1024.lua
diff --git a/python/llm/src/ipex_llm/serving/fastapi/__init__.py b/python/llm/src/ipex_llm/serving/fastapi/__init__.py
new file mode 100644
index 00000000000..79ddfd9dcba
--- /dev/null
+++ b/python/llm/src/ipex_llm/serving/fastapi/__init__.py
@@ -0,0 +1,18 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from .api_server import FastApp
+from .model_worker import ModelWorker
\ No newline at end of file
diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
new file mode 100644
index 00000000000..b02cbeaf43b
--- /dev/null
+++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py
@@ -0,0 +1,315 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from ipex_llm.utils.common import invalidInputError
+from transformers.utils import logging
+from fastapi import FastAPI
+from fastapi.responses import StreamingResponse
+from openai.types.chat import ChatCompletionMessageParam
+from pydantic import BaseModel
+import asyncio
+import uuid
+from typing import List, Optional, Union, Dict
+
+
+result_dict: Dict[str, str] = {}
+logger = logging.get_logger(__name__)
+
+
+class PromptRequest(BaseModel):
+ prompt: str
+ n_predict: Optional[int] = 256
+ req_type: str = 'completion'
+
+
+class ChatCompletionRequest(BaseModel):
+ messages: List[ChatCompletionMessageParam]
+ model: str
+ max_tokens: Optional[int] = None
+ stream: Optional[bool] = False
+
+
+class CompletionRequest(BaseModel):
+ model: str
+ prompt: Union[List[int], List[List[int]], str, List[str]]
+ max_tokens: Optional[int] = None
+ stream: Optional[bool] = False
+
+
+app = FastAPI()
+global tokenizer
+global local_model
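+# Module-level handles filled in by FastApp so the route handlers below can reach the worker and tokenizer.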
+
+
+class FastApp():
+ def __init__(self, model, mytokenizer):
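+        # Bind the worker and tokenizer to the module-level FastAPI app; callers
+        # then serve self.app with uvicorn (see serving.py / pipeline_serving.py).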
+ global tokenizer
+ global local_model
+ local_model = model
+ tokenizer = mytokenizer
+ self.app = app
+
+
+from .openai_protocol import (
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponse,
+ ChatMessage,
+ DeltaMessage,
+ CompletionResponseChoice,
+ CompletionResponse,
+ CompletionResponseStreamChoice,
+ CompletionStreamResponse,
+)
+
+
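+# Used on the single-GPU path, where the streamer is a TextIteratorStreamer: its
+# blocking text_queue returns None when generation ends, which maps to remain == 0.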
+def get_queue_next_token(delta_text_queue):
+ timeout = int(os.getenv("IPEX_LLM_FASTAPI_TIMEOUT", 60))
+ delta_text = delta_text_queue.text_queue.get(timeout=timeout)
+ if delta_text is None:
+ remain = 0
+ else:
+ remain = 1
+ return delta_text, remain
+
+async def chat_stream_generator(local_model, delta_text_queue, request_id):
+ model_name = local_model.model_name
+ index = 0
+ while True:
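+        # A TextIteratorStreamer (single-GPU ModelWorker) has no `empty` method;
+        # the pipeline-parallel worker supplies an async queue that does.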
+ if not hasattr(delta_text_queue, 'empty'):
+ delta_text, remain = get_queue_next_token(delta_text_queue)
+ else:
+ if not delta_text_queue.empty():
+ with local_model.dict_lock:
+ remain, delta_text = await delta_text_queue.get()
+ else:
+ await asyncio.sleep(0)
+ continue
+        if (remain == 0 and delta_text is not None) or remain != 0:
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(role="assistant", content=delta_text),
+ logprobs=None,
+ finish_reason=None)
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ index = index + 1
+ if remain == 0:
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(role="assistant", content=None),
+ logprobs=None,
+ finish_reason="length")
+ chunk = ChatCompletionStreamResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ break
+ local_model.streamer.pop(request_id, None)
+
+
+async def completion_stream_generator(local_model, delta_text_queue, request_id):
+ model_name = local_model.model_name
+ index = 0
+ while True:
+ if not hasattr(delta_text_queue, 'empty'):
+ delta_text, remain = get_queue_next_token(delta_text_queue)
+ else:
+ if not delta_text_queue.empty():
+ with local_model.dict_lock:
+ remain, delta_text = await delta_text_queue.get()
+ else:
+ await asyncio.sleep(0)
+ continue
+        if (remain == 0 and delta_text is not None) or remain != 0:
+ choice_data = CompletionResponseStreamChoice(
+ index=index,
+ text=delta_text,
+ logprobs=None,
+ finish_reason=None)
+ chunk = CompletionStreamResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ index = index + 1
+ if remain == 0:
+ choice_data = CompletionResponseStreamChoice(
+ index=index,
+ text="",
+ logprobs=None,
+ finish_reason="length")
+ chunk = CompletionStreamResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ data = chunk.model_dump_json(exclude_unset=True)
+ yield f"data: {data}\n\n"
+ break
+ local_model.streamer.pop(request_id, None)
+
+
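+# Plain-text generator backing the non-streaming /generate endpoint: it yields raw
+# text deltas and discards the per-request streamer once generation finishes.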
+async def generator(local_model, delta_text_queue, request_id):
+ while True:
+ if not hasattr(delta_text_queue, 'empty'):
+ delta_text, remain = get_queue_next_token(delta_text_queue)
+ if delta_text is None:
+ break
+ else:
+ yield delta_text
+ else:
+ if not delta_text_queue.empty():
+ with local_model.dict_lock:
+ remain, delta_text = await delta_text_queue.get()
+ yield delta_text
+ if remain == 0:
+ break
+ else:
+ await asyncio.sleep(0)
+ continue
+ local_model.streamer.pop(request_id, None)
+
+
+@app.post("/generate")
+async def generate(prompt_request: PromptRequest):
+ request_id = str(uuid.uuid4())
+ await local_model.waiting_requests.put((request_id, prompt_request))
+ while True:
+ await asyncio.sleep(0)
+ cur_streamer = local_model.streamer.get(request_id, None)
+ if cur_streamer is not None:
+ output_str = []
+ async for item in generator(local_model, cur_streamer, request_id):
+ output_str.append(item)
+ return request_id, "".join(output_str)
+
+
+@app.post("/generate_stream")
+async def generate_stream_api(prompt_request: PromptRequest):
+ request_id, result = await generate_stream(prompt_request)
+ return result
+
+
+async def generate_stream(prompt_request: PromptRequest):
+ request_id = str(uuid.uuid4()) + "stream"
+ await local_model.waiting_requests.put((request_id, prompt_request))
+ while True:
+ await asyncio.sleep(0)
+ cur_streamer = local_model.streamer.get(request_id, None)
+ if cur_streamer is not None:
+ if prompt_request.req_type == 'completion':
+ cur_generator = completion_stream_generator(local_model, cur_streamer, request_id)
+ elif prompt_request.req_type == 'chat':
+ cur_generator = chat_stream_generator(local_model, cur_streamer, request_id)
+ else:
+ invalidInputError(False, "Invalid Request Type.")
+
+ return request_id, StreamingResponse(
+ content=cur_generator, media_type="text/event-stream"
+ )
+
+
+def get_prompt(messages) -> str:
+ prompt = ""
+ for msg in messages:
+ role = msg["role"]
+ content = msg["content"]
+ if role == "system":
+ prompt += f"<>\n{content}\n<>\n\n"
+ elif role == "user":
+ prompt += f"[INST] {content} [/INST] "
+ elif role == "assistant":
+ prompt += f"{content} "
+ else:
+ invalidInputError(False, f"Unknown role: {role}")
+ return prompt.strip()
+
+
+@app.post("/v1/chat/completions")
+async def create_chat_completion(request: ChatCompletionRequest):
+ model_name = local_model.model_name
+ if request.max_tokens is None:
+ n_predict = 256
+ else:
+ n_predict = request.max_tokens
+ prompt_request = PromptRequest(
+ prompt=get_prompt(request.messages),
+ n_predict=n_predict,
+ req_type="chat"
+ )
+ if request.stream:
+ request_id, result = await generate_stream(prompt_request)
+ else:
+ request_id, result = await generate(prompt_request)
+ choice_data = ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(role="assistant", content=result),
+ logprobs=None,
+ finish_reason="length")
+ result = ChatCompletionResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ return result
+
+
+@app.post("/v1/completions")
+async def create_completion(request: CompletionRequest):
+ model_name = local_model.model_name
+ if request.max_tokens is None:
+ n_predict = 256
+ else:
+ n_predict = request.max_tokens
+ prompt_request = PromptRequest(
+ prompt=request.prompt,
+ n_predict=n_predict,
+ req_type="completion"
+ )
+ if request.stream:
+ request_id, result = await generate_stream(prompt_request)
+ else:
+ request_id, result = await generate(prompt_request)
+ choice_data = CompletionResponseChoice(
+ index=0,
+ text=result,
+ logprobs=None,
+ finish_reason="length")
+ result = CompletionResponse(
+ id=request_id,
+ choices=[choice_data],
+ model=model_name)
+ return result
+
+
+@app.on_event("startup")
+async def startup_event():
+ asyncio.create_task(process_requests(local_model, result_dict))
+
+
+async def process_requests(local_model, result_dict):
+ while True:
+ await asyncio.sleep(0)
+ await local_model.process_step(tokenizer, result_dict)
diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
new file mode 100644
index 00000000000..90f195a6cb4
--- /dev/null
+++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py
@@ -0,0 +1,82 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from transformers.utils import logging
+import time
+import asyncio
+from transformers import TextIteratorStreamer
+logger = logging.get_logger(__name__)
+
+
+class ModelWorker:
+ def __init__(self, checkpoint, low_bit, torch_dtype=torch.float16):
+ self.dtype = torch_dtype
+ start = time.perf_counter()
+ model = self.load_model(checkpoint, low_bit)
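+        # BenchmarkWrapper(do_print=True) prints per-call generation performance metrics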
+ from ipex_llm.utils.benchmark_util import BenchmarkWrapper
+ self.model = BenchmarkWrapper(model, do_print=True)
+ end = time.perf_counter()
+ logger.info(f"Time to load weights: {end - start:.2f}s")
+ self.waiting_requests = asyncio.Queue()
+ self.streamer = {}
+ self.model_name = checkpoint
+
+ def load_model(self, model_path, low_bit='sym_int4'):
+ from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
+ try:
+ model = AutoModelForCausalLM.from_pretrained(model_path,
+ load_in_low_bit=low_bit,
+ torch_dtype=self.dtype,
+ optimize_model=True,
+ trust_remote_code=True,
+ use_cache=True,)
+        except Exception:
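+            # Fall back to AutoModel for checkpoints that AutoModelForCausalLM cannot load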
+ model = AutoModel.from_pretrained(model_path,
+ load_in_low_bit=low_bit,
+ torch_dtype=self.dtype,
+ optimize_model=True,
+ trust_remote_code=True,
+ use_cache=True,)
+ model = model.eval().to("xpu")
+ return model
+
+ async def add_request(self, tokenizer):
+ if self.waiting_requests.empty():
+ return
+ tmp_result = await self.waiting_requests.get()
+ request_id, prompt_request = tmp_result
+ plain_texts = prompt_request.prompt
+ inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
+ input_ids = inputs.input_ids.to('xpu')
+ max_tokens = prompt_request.n_predict
+ return input_ids, max_tokens, request_id
+
+ @torch.no_grad()
+ async def process_step(self, tokenizer, result_dict):
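+        # Take one pending request, register a streamer for it, and launch generation
+        # on a background thread so this coroutine keeps yielding to the event loop.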
+ if not self.waiting_requests.empty():
+ input_ids, max_tokens, request_id = await self.add_request(tokenizer)
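+            # skip_prompt=True streams only newly generated tokens, not the echoed prompt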
+ self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True)
+
+ def model_generate():
+ self.model.generate(input_ids,
+ streamer=self.streamer[request_id], max_new_tokens=max_tokens)
+ torch.xpu.empty_cache()
+ torch.xpu.synchronize()
+
+ from threading import Thread
+ t1 = Thread(target=model_generate)
+ t1.start()
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py b/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
similarity index 93%
rename from python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py
rename to python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
index a04fb03657a..1bc8f1e3a69 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py
+++ b/python/llm/src/ipex_llm/serving/fastapi/openai_protocol.py
@@ -15,6 +15,7 @@
#
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
+
import time
from typing import Dict, List, Literal, Optional, Union
@@ -22,11 +23,14 @@
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated
+from ipex_llm.utils.common import invalidInputError
+
# from vllm.sampling_params import SamplingParams
def random_uuid() -> str:
return str(uuid.uuid4().hex)
+
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
@@ -127,10 +131,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
)
add_generation_prompt: Optional[bool] = Field(
default=True,
- description=
- ("If true, the generation prompt will be added to the chat template. "
- "This is a parameter used by chat template in tokenizer config of the "
- "model."),
+ description=(
+ "If true, the generation prompt will be added to the chat template. "
+ "This is a parameter used by chat template in tokenizer config of the "
+ "model."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
@@ -179,9 +183,9 @@ def check_guided_decoding_count(cls, data):
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
- raise ValueError(
- "You can only use one kind of guided decoding "
- "('guided_json', 'guided_regex' or 'guided_choice').")
+ invalidInputError(False,
+ "You can only use one kind of guided decoding "
+ "('guided_json', 'guided_regex' or 'guided_choice').")
return data
@@ -232,10 +236,10 @@ class CompletionRequest(OpenAIBaseModel):
)
response_format: Optional[ResponseFormat] = Field(
default=None,
- description=
- ("Similar to chat completion, this parameter specifies the format of "
- "output. Only {'type': 'json_object'} or {'type': 'text' } is "
- "supported."),
+ description=(
+ "Similar to chat completion, this parameter specifies the format of "
+ "output. Only {'type': 'json_object'} or {'type': 'text' } is "
+ "supported."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
@@ -279,9 +283,9 @@ def check_guided_decoding_count(cls, data):
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
- raise ValueError(
- "You can only use one kind of guided decoding "
- "('guided_json', 'guided_regex' or 'guided_choice').")
+ invalidInputError(False,
+ "You can only use one kind of guided decoding "
+ "('guided_json', 'guided_regex' or 'guided_choice').")
return data
diff --git a/python/llm/src/ipex_llm/transformers/__init__.py b/python/llm/src/ipex_llm/transformers/__init__.py
index 07146fc04cb..6904e897fbe 100644
--- a/python/llm/src/ipex_llm/transformers/__init__.py
+++ b/python/llm/src/ipex_llm/transformers/__init__.py
@@ -22,4 +22,4 @@
AutoModelForNextSentencePrediction, AutoModelForMultipleChoice, \
AutoModelForTokenClassification
from .modelling_bigdl import *
-from .pipeline_parallel import init_pipeline_parallel, ModelRunner
+from .pipeline_parallel import init_pipeline_parallel, PPModelWorker
diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 14c4b8a1c74..0812a12aa87 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -468,7 +468,7 @@ def make_attention_mask(prompt_lengths):
return attention_mask
-class ModelRunner:
+class PPModelWorker:
"""Implementation for pipeline parallel multi-stage serving."""
def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs,
torch_dtype=torch.float16):