From 1eed0635f2619a61f3a42159a71abfcb55d688cb Mon Sep 17 00:00:00 2001 From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com> Date: Fri, 19 Jul 2024 13:15:56 +0800 Subject: [PATCH] Add lightweight serving and support tgi parameter (#11600) * init tgi request * update openai api * update for pp * update and add readme * add to docker * add start bash * update * update * update --- docker/llm/inference/xpu/docker/Dockerfile | 3 + docker/llm/serving/xpu/docker/Dockerfile | 1 + docker/llm/serving/xpu/docker/README.md | 8 + .../start-lightweight_serving-service.sh | 4 + .../example/GPU/Lightweight-Serving/README.md | 196 ++++++++++++++++++ .../lightweight_serving.py} | 0 .../GPU/Pipeline-Parallel-Serving/README.md | 23 +- .../wrk_script_1024.lua | 2 +- python/llm/example/GPU/README.md | 3 +- .../ipex_llm/serving/fastapi/api_server.py | 52 +++-- .../ipex_llm/serving/fastapi/model_worker.py | 11 +- .../ipex_llm/serving/fastapi/tgi_protocol.py | 63 ++++++ .../transformers/pipeline_parallel.py | 4 +- 13 files changed, 326 insertions(+), 44 deletions(-) create mode 100644 docker/llm/serving/xpu/docker/start-lightweight_serving-service.sh create mode 100644 python/llm/example/GPU/Lightweight-Serving/README.md rename python/llm/example/GPU/{Pipeline-Parallel-Serving/serving.py => Lightweight-Serving/lightweight_serving.py} (100%) create mode 100644 python/llm/src/ipex_llm/serving/fastapi/tgi_protocol.py diff --git a/docker/llm/inference/xpu/docker/Dockerfile b/docker/llm/inference/xpu/docker/Dockerfile index ce629ad2b69..9c717556da5 100644 --- a/docker/llm/inference/xpu/docker/Dockerfile +++ b/docker/llm/inference/xpu/docker/Dockerfile @@ -62,6 +62,9 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO # Download pp_serving mkdir -p /llm/pp_serving && \ cp ./ipex-llm/python/llm/example/GPU/Pipeline-Parallel-Serving/*.py /llm/pp_serving/ && \ + # Download lightweight_serving + mkdir -p /llm/lightweight_serving && \ + cp ./ipex-llm/python/llm/example/GPU/Lightweight-Serving/*.py /llm/lightweight_serving/ && \ # Install related library of benchmarking pip install pandas omegaconf && \ chmod +x /llm/benchmark.sh && \ diff --git a/docker/llm/serving/xpu/docker/Dockerfile b/docker/llm/serving/xpu/docker/Dockerfile index 6bc43042273..7a5d962e32c 100644 --- a/docker/llm/serving/xpu/docker/Dockerfile +++ b/docker/llm/serving/xpu/docker/Dockerfile @@ -33,6 +33,7 @@ COPY ./start-vllm-service.sh /llm/ COPY ./benchmark_vllm_throughput.py /llm/ COPY ./start-fastchat-service.sh /llm/ COPY ./start-pp_serving-service.sh /llm/ +COPY ./start-lightweight_serving-service.sh /llm/ WORKDIR /llm/ diff --git a/docker/llm/serving/xpu/docker/README.md b/docker/llm/serving/xpu/docker/README.md index 4a4b07ab953..5b3f00cda9f 100644 --- a/docker/llm/serving/xpu/docker/README.md +++ b/docker/llm/serving/xpu/docker/README.md @@ -44,6 +44,14 @@ After the container is booted, you could get into the container through `docker Currently, we provide two different serving engines in the image, which are FastChat serving engine and vLLM serving engine. + +#### Lightweight serving engine + +To run Lightweight serving on one intel gpu using `IPEX-LLM` as backend, you can refer to this [readme](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Lightweight-Serving). + +For convenience, we have included a file `/llm/start-lightweight_serving-service` in the image. 
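As a quick usage sketch, assuming a model has been mounted at the script's default path `/llm/models/Llama-2-7b-chat-hf` (adjust it to your own mount point), the service could be started inside the running container with the bundled script:

```bash
# Hypothetical usage inside the running container: launch the lightweight serving
# example with the defaults baked into the start script (sym_int4, Llama-2-7b-chat-hf).
bash /llm/start-lightweight_serving-service.sh
```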
+
+
 #### Pipeline parallel serving engine
 
 To run Pipeline parallel serving using `IPEX-LLM` as backend, you can refer to this [readme](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-FastAPI).
 
diff --git a/docker/llm/serving/xpu/docker/start-lightweight_serving-service.sh b/docker/llm/serving/xpu/docker/start-lightweight_serving-service.sh
new file mode 100644
index 00000000000..b51e4fc3e13
--- /dev/null
+++ b/docker/llm/serving/xpu/docker/start-lightweight_serving-service.sh
@@ -0,0 +1,4 @@
+cd /llm/lightweight_serving
+model_path="/llm/models/Llama-2-7b-chat-hf"
+low_bit="sym_int4"
+python lightweight_serving.py --repo-id-or-model-path $model_path --low-bit $low_bit
\ No newline at end of file
diff --git a/python/llm/example/GPU/Lightweight-Serving/README.md b/python/llm/example/GPU/Lightweight-Serving/README.md
new file mode 100644
index 00000000000..791bf1666cb
--- /dev/null
+++ b/python/llm/example/GPU/Lightweight-Serving/README.md
@@ -0,0 +1,196 @@
+# Running Lightweight Serving using IPEX-LLM on one Intel GPU
+
+## Requirements
+
+To run this example with IPEX-LLM on one Intel GPU, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information.
+
+## Example
+
+### 1. Install
+
+#### 1.1 Installation on Linux
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.11
+conda activate llm
+# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
+pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install fastapi uvicorn openai
+pip install gradio # for gradio web UI
+conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+```
+
+#### 1.2 Installation on Windows
+We suggest using conda to manage the environment:
+```bash
+conda create -n llm python=3.11 libuv
+conda activate llm
+
+# below command will install intel_extension_for_pytorch==2.1.10+xpu as default
+pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install fastapi uvicorn openai
+pip install gradio # for gradio web UI
+conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+```
+
+### 2. Configure OneAPI environment variables for Linux
+
+> [!NOTE]
+> Skip this step if you are running on Windows.
+
+This is a required step on Linux for APT or offline installed oneAPI. Skip this step for PIP-installed oneAPI.
+
+```bash
+source /opt/intel/oneapi/setvars.sh
+```
+
+### 3. Runtime Configurations
+For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
+#### 3.1 Configurations for Linux
+<details>
+
+<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>
+
+```bash
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+export SYCL_CACHE_PERSISTENT=1
+```
+
+</details>
+
+<details>
+
+<summary>For Intel Data Center GPU Max Series</summary>
+
+```bash
+export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+export SYCL_CACHE_PERSISTENT=1
+export ENABLE_SDP_FUSION=1
+```
+> Note: Please note that `libtcmalloc.so` can be installed by `conda install -c conda-forge -y gperftools=2.10`.
+</details>
+
+<details>
+
+<summary>For Intel iGPU</summary>
+
+```bash
+export SYCL_CACHE_PERSISTENT=1
+export BIGDL_LLM_XMX_DISABLED=1
+```
+
+</details>
+
+#### 3.2 Configurations for Windows
+<details>
+
+<summary>For Intel iGPU</summary>
+
+```cmd
+set SYCL_CACHE_PERSISTENT=1
+set BIGDL_LLM_XMX_DISABLED=1
+```
+
+</details>
+
+<details>
+
+<summary>For Intel Arc™ A-Series Graphics</summary>
+
+```cmd
+set SYCL_CACHE_PERSISTENT=1
+```
+
+</details>
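Putting the steps together, a minimal end-to-end launch sketch on Linux with an Intel Arc A-Series GPU might look like the following (the local model path is illustrative; pick the environment variables that match your device from section 3 above):

```bash
# Activate the conda environment created in step 1 and the oneAPI toolchain from step 2
conda activate llm
source /opt/intel/oneapi/setvars.sh

# Runtime configuration for Arc A-Series / Flex GPUs (step 3.1)
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export SYCL_CACHE_PERSISTENT=1

# Start the lightweight serving example (step 4); the model path is an assumption
python ./lightweight_serving.py --repo-id-or-model-path ./Llama-2-7b-chat-hf --low-bit sym_int4 --port 8000
```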
+
+> [!NOTE]
+> For the first time that each model runs on Intel iGPU/Intel Arc™ A300-Series or Pro A60, it may take several minutes to compile.
+
+### 4. Running example
+
+```bash
+python ./lightweight_serving.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --low-bit LOW_BIT --port PORT
+```
+
+Arguments info:
+- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id of the model (e.g. `meta-llama/Llama-2-7b-chat-hf` or `meta-llama/Llama-2-13b-chat-hf`) to be downloaded, or the path to a local checkpoint folder. The default is `'meta-llama/Llama-2-7b-chat-hf'`.
+- `--low-bit LOW_BIT`: Sets the low-bit optimization (such as 'sym_int4', 'fp16', 'fp8' and 'fp6') for the model. The default is `sym_int4`.
+- `--port PORT`: The serving access port. The default is `8000`.
+
+
+### 5. Sample Input and Output
+
+We can use `curl` to test the serving API. Remember to set `no_proxy` so that requests to localhost are not forwarded by a proxy: `export no_proxy=localhost,127.0.0.1`
+
+#### /generate
+
+```bash
+curl -X POST -H "Content-Type: application/json" -d '{
+  "inputs": "What is AI?",
+  "parameters": {
+    "max_new_tokens": 32,
+    "min_new_tokens": 32,
+    "repetition_penalty": 1.0,
+    "temperature": 1.0,
+    "do_sample": false,
+    "top_k": 5,
+    "top_p": 0.95
+  },
+  "stream": false
+}' http://localhost:8000/generate
+```
+
+#### /generate_stream
+
+```bash
+curl -X POST -H "Content-Type: application/json" -d '{
+  "inputs": "What is AI?",
+  "parameters": {
+    "max_new_tokens": 32,
+    "min_new_tokens": 32,
+    "repetition_penalty": 1.0,
+    "temperature": 1.0,
+    "do_sample": false,
+    "top_k": 5,
+    "top_p": 0.95
+  },
+  "stream": false
+}' http://localhost:8000/generate_stream
+```
+
+#### /v1/chat/completions
+
+```bash
+curl http://localhost:8000/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "Llama-2-7b-chat-hf",
+    "messages": [{"role": "user", "content": "Hello! What is your name?"}]
+  }'
+```
+
+#### /v1/completions
+
+```bash
+curl http://localhost:8000/v1/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "Llama-2-7b-chat-hf",
+    "prompt": "Once upon a time",
+    "max_tokens": 32
+  }'
+```
+
+### 6. Benchmark with wrk
+
+Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#4-benchmark-with-wrk) for more details.
+
+### 7. Using the `benchmark.py` Script
+
+Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#5-using-the-benchmarkpy-script) for more details.
+
+### 8. Gradio Web UI
+
+Please refer to [here](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Pipeline-Parallel-Serving#6-gradio-web-ui) for more details.
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Serving/serving.py b/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
similarity index 100%
rename from python/llm/example/GPU/Pipeline-Parallel-Serving/serving.py
rename to python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py
diff --git a/python/llm/example/GPU/Pipeline-Parallel-Serving/README.md b/python/llm/example/GPU/Pipeline-Parallel-Serving/README.md
index 73607b2e042..32745ceaf92 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-Serving/README.md
+++ b/python/llm/example/GPU/Pipeline-Parallel-Serving/README.md
@@ -52,6 +52,7 @@ pip install trl==0.8.1
 ### 2-1.
Run ipex-llm serving on one GPU card +Refer to [Lightweight-Serving](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Lightweight-Serving), get the [lightweight_serving.py](https://github.com/intel-analytics/ipex-llm/tree/main/python/llm/example/GPU/Lightweight-Serving/lightweight_serving.py) and then to run. ```bash # Need to set NUM_GPUS=1 and MODEL_PATH in run.sh first bash run.sh @@ -78,18 +79,16 @@ We can use `curl` to test serving api #### generate() ```bash -# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy. -export http_proxy= -export https_proxy= - -curl -X 'POST' \ - 'http://127.0.0.1:8000/generate' \ - -H 'accept: application/json' \ - -H 'Content-Type: application/json' \ - -d '{ - "prompt": "What is AI?", - "n_predict": 32 -}' +# Set no_proxy to ensure that requests are not forwarded by a proxy. +export no_proxy=localhost,127.0.0.1 + +curl -X POST -H "Content-Type: application/json" -d '{ + "inputs": "What is AI?", + "parameters": { + "max_new_tokens": 32 + }, + "stream": false +}' http://localhost:8000/generate ``` diff --git a/python/llm/example/GPU/Pipeline-Parallel-Serving/wrk_script_1024.lua b/python/llm/example/GPU/Pipeline-Parallel-Serving/wrk_script_1024.lua index 8118dec088e..f01415a04e2 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-Serving/wrk_script_1024.lua +++ b/python/llm/example/GPU/Pipeline-Parallel-Serving/wrk_script_1024.lua @@ -1,7 +1,7 @@ wrk.method = "POST" wrk.headers["accept"] = "application/json" wrk.headers["Content-Type"] = "application/json" -wrk.body = '{"prompt": "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. However, her parents were always telling her to stay close to home, to be careful, and to avoid any danger. But the little girl was stubborn, and she wanted to see what was on the other side of the mountain. So she sneaked out of the house one night, leaving a note for her parents, and set off on her journey. As she climbed the mountain, the little girl felt a sense of excitement and wonder. She had never been this far away from home before, and she couldnt wait to see what she would find on the other side. She climbed higher and higher, her lungs burning from the thin air, until she finally reached the top of the mountain. And there, she found a beautiful meadow filled with wildflowers and a sparkling stream. The little girl danced and played in the meadow, feeling free and alive. She knew she had to return home eventually, but for now, she was content to enjoy her adventure. As the sun began to set, the little girl reluctantly made her way back down the mountain, but she knew that she would never forget her adventure and the joy of discovering something new and exciting. And whenever she felt scared or unsure, she would remember the thrill of climbing the mountain and the beauty of the meadow on the other side, and she would know that she could face any challenge that came her way, with courage and determination. She carried the memories of her journey in her heart, a constant reminder of the strength she possessed. The little girl returned home to her worried parents, who had discovered her note and anxiously awaited her arrival. They scolded her for disobeying their instructions and venturing into the unknown. But as they looked into her sparkling eyes and saw the glow on her face, their anger softened. 
They realized that their little girl had grown, that she had experienced something extraordinary. The little girl shared her tales of the mountain and the meadow with her parents, painting vivid pictures with her words. She spoke of the breathtaking view from the mountaintop, where the world seemed to stretch endlessly before her. She described the delicate petals of the wildflowers, vibrant hues that danced in the gentle breeze. And she recounted the soothing melody of the sparkling stream, its waters reflecting the golden rays of the setting sun. Her parents listened intently, captivated by her story. They realized that their daughter had discovered a part of herself on that journey—a spirit of curiosity and a thirst for exploration. They saw that she had learned valuable lessons about independence, resilience, and the beauty that lies beyond ones comfort zone. From that day forward, the little girls parents encouraged her to pursue her dreams and embrace new experiences. They understood that while there were risks in the world, there were also rewards waiting to be discovered. They supported her as she continued to embark on adventures, always reminding her to stay safe but never stifling her spirit. As the years passed, the little girl grew into a remarkable woman, fearlessly exploring the world and making a difference wherever she went. The lessons she had learned on that fateful journey stayed with her, guiding her through challenges and inspiring her to live life to the fullest. And so, the once timid little girl became a symbol of courage and resilience, a reminder to all who knew her that the greatest joys in life often lie just beyond the mountains we fear to climb. Her story spread far and wide, inspiring others to embrace their own journeys and discover the wonders that awaited them. In the end, the little girls adventure became a timeless tale, passed down through generations, reminding us all that sometimes, the greatest rewards come to those who dare to step into the unknown and follow their hearts. With each passing day, the little girls story continued to inspire countless individuals, igniting a spark within their souls and encouraging them to embark on their own extraordinary adventures. The tale of her bravery and determination resonated deeply with people from all walks of life, reminding them of the limitless possibilities that awaited them beyond the boundaries of their comfort zones. People marveled at the little girls unwavering spirit and her unwavering belief in the power of dreams. They saw themselves reflected in her journey, finding solace in the knowledge that they too could overcome their fears and pursue their passions. The little girl\'s story became a beacon of hope, a testament to the human spirit", "n_predict":128}' +wrk.body = '{"inputs": "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. However, her parents were always telling her to stay close to home, to be careful, and to avoid any danger. But the little girl was stubborn, and she wanted to see what was on the other side of the mountain. So she sneaked out of the house one night, leaving a note for her parents, and set off on her journey. As she climbed the mountain, the little girl felt a sense of excitement and wonder. She had never been this far away from home before, and she couldnt wait to see what she would find on the other side. 
She climbed higher and higher, her lungs burning from the thin air, until she finally reached the top of the mountain. And there, she found a beautiful meadow filled with wildflowers and a sparkling stream. The little girl danced and played in the meadow, feeling free and alive. She knew she had to return home eventually, but for now, she was content to enjoy her adventure. As the sun began to set, the little girl reluctantly made her way back down the mountain, but she knew that she would never forget her adventure and the joy of discovering something new and exciting. And whenever she felt scared or unsure, she would remember the thrill of climbing the mountain and the beauty of the meadow on the other side, and she would know that she could face any challenge that came her way, with courage and determination. She carried the memories of her journey in her heart, a constant reminder of the strength she possessed. The little girl returned home to her worried parents, who had discovered her note and anxiously awaited her arrival. They scolded her for disobeying their instructions and venturing into the unknown. But as they looked into her sparkling eyes and saw the glow on her face, their anger softened. They realized that their little girl had grown, that she had experienced something extraordinary. The little girl shared her tales of the mountain and the meadow with her parents, painting vivid pictures with her words. She spoke of the breathtaking view from the mountaintop, where the world seemed to stretch endlessly before her. She described the delicate petals of the wildflowers, vibrant hues that danced in the gentle breeze. And she recounted the soothing melody of the sparkling stream, its waters reflecting the golden rays of the setting sun. Her parents listened intently, captivated by her story. They realized that their daughter had discovered a part of herself on that journey—a spirit of curiosity and a thirst for exploration. They saw that she had learned valuable lessons about independence, resilience, and the beauty that lies beyond ones comfort zone. From that day forward, the little girls parents encouraged her to pursue her dreams and embrace new experiences. They understood that while there were risks in the world, there were also rewards waiting to be discovered. They supported her as she continued to embark on adventures, always reminding her to stay safe but never stifling her spirit. As the years passed, the little girl grew into a remarkable woman, fearlessly exploring the world and making a difference wherever she went. The lessons she had learned on that fateful journey stayed with her, guiding her through challenges and inspiring her to live life to the fullest. And so, the once timid little girl became a symbol of courage and resilience, a reminder to all who knew her that the greatest joys in life often lie just beyond the mountains we fear to climb. Her story spread far and wide, inspiring others to embrace their own journeys and discover the wonders that awaited them. In the end, the little girls adventure became a timeless tale, passed down through generations, reminding us all that sometimes, the greatest rewards come to those who dare to step into the unknown and follow their hearts. With each passing day, the little girls story continued to inspire countless individuals, igniting a spark within their souls and encouraging them to embark on their own extraordinary adventures. 
The tale of her bravery and determination resonated deeply with people from all walks of life, reminding them of the limitless possibilities that awaited them beyond the boundaries of their comfort zones. People marveled at the little girls unwavering spirit and her unwavering belief in the power of dreams. They saw themselves reflected in her journey, finding solace in the knowledge that they too could overcome their fears and pursue their passions. The little girl\'s story became a beacon of hope, a testament to the human spirit", "parameters": {"max_new_tokens": 128, "min_new_tokens": 128}}' logfile = io.open("wrk.log", "w"); diff --git a/python/llm/example/GPU/README.md b/python/llm/example/GPU/README.md index dc7600c6bf0..41a86c7678b 100644 --- a/python/llm/example/GPU/README.md +++ b/python/llm/example/GPU/README.md @@ -9,7 +9,8 @@ This folder contains examples of running IPEX-LLM on Intel GPU: - [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with IPEX-LLM low-bit optimized models) on Intel GPUs - [Deepspeed-AutoTP-FastAPI](Deepspeed-AutoTP-FastAPI): running distributed inference using ***DeepSpeed AutoTP*** and start serving with ***FastAPI***(with IPEX-LLM low-bit optimized models) on Intel GPUs - [Pipeline-Parallel-Inference](Pipeline-Parallel-Inference): running IPEX-LLM optimized low-bit model vertically partitioned on multiple Intel GPUs -- [Pipeline-Parallel-FastAPI](Pipeline-Parallel-FastAPI): running IPEX-LLM serving with **FastAPI** on multiple Intel GPUs in pipeline parallel fasion +- [Pipeline-Parallel-Serving](Pipeline-Parallel-Serving): running IPEX-LLM serving with **FastAPI** on multiple Intel GPUs in pipeline parallel fasion +- [Lightweight-Serving](Lightweight-Serving): running IPEX-LLM serving with **FastAPI** on one Intel GPU In a lightweight way - [LangChain](LangChain): running ***LangChain*** applications on IPEX-LLM - [PyTorch-Models](PyTorch-Models): running any PyTorch model on IPEX-LLM (with "one-line code change") - [Speculative-Decoding](Speculative-Decoding): running any ***Hugging Face Transformers*** model with ***self-speculative decoding*** on Intel GPUs diff --git a/python/llm/src/ipex_llm/serving/fastapi/api_server.py b/python/llm/src/ipex_llm/serving/fastapi/api_server.py index b02cbeaf43b..75bd49d0611 100644 --- a/python/llm/src/ipex_llm/serving/fastapi/api_server.py +++ b/python/llm/src/ipex_llm/serving/fastapi/api_server.py @@ -25,15 +25,17 @@ import asyncio import uuid from typing import List, Optional, Union, Dict +from .tgi_protocol import Parameters result_dict: Dict[str, str] = {} logger = logging.get_logger(__name__) -class PromptRequest(BaseModel): - prompt: str - n_predict: Optional[int] = 256 +class InputsRequest(BaseModel): + inputs: str + parameters: Optional[Parameters] = None + stream: Optional[bool] = False req_type: str = 'completion' @@ -194,9 +196,12 @@ async def generator(local_model, delta_text_queue, request_id): @app.post("/generate") -async def generate(prompt_request: PromptRequest): +async def generate(inputs_request: InputsRequest): + if inputs_request.stream: + result = await generate_stream_api(inputs_request) + return result request_id = str(uuid.uuid4()) - await local_model.waiting_requests.put((request_id, prompt_request)) + await local_model.waiting_requests.put((request_id, inputs_request)) while True: await asyncio.sleep(0) cur_streamer = local_model.streamer.get(request_id, None) @@ -208,25 +213,24 @@ async def generate(prompt_request: PromptRequest): 
@app.post("/generate_stream") -async def generate_stream_api(prompt_request: PromptRequest): - request_id, result = await generate_stream(prompt_request) +async def generate_stream_api(inputs_request: InputsRequest): + request_id, result = await generate_stream(inputs_request) return result -async def generate_stream(prompt_request: PromptRequest): +async def generate_stream(inputs_request: InputsRequest): request_id = str(uuid.uuid4()) + "stream" - await local_model.waiting_requests.put((request_id, prompt_request)) + await local_model.waiting_requests.put((request_id, inputs_request)) while True: await asyncio.sleep(0) cur_streamer = local_model.streamer.get(request_id, None) if cur_streamer is not None: - if prompt_request.req_type == 'completion': + if inputs_request.req_type == 'completion': cur_generator = completion_stream_generator(local_model, cur_streamer, request_id) - elif prompt_request.req_type == 'chat': + elif inputs_request.req_type == 'chat': cur_generator = chat_stream_generator(local_model, cur_streamer, request_id) else: invalidInputError(False, "Invalid Request Type.") - return request_id, StreamingResponse( content=cur_generator, media_type="text/event-stream" ) @@ -255,15 +259,16 @@ async def create_chat_completion(request: ChatCompletionRequest): n_predict = 256 else: n_predict = request.max_tokens - prompt_request = PromptRequest( - prompt=get_prompt(request.messages), - n_predict=n_predict, + inputs_request = InputsRequest( + inputs=get_prompt(request.messages), + parameters=Parameters(max_new_tokens=n_predict), + stream=request.stream, req_type="chat" ) if request.stream: - request_id, result = await generate_stream(prompt_request) + request_id, result = await generate_stream(inputs_request) else: - request_id, result = await generate(prompt_request) + request_id, result = await generate(inputs_request) choice_data = ChatCompletionResponseChoice( index=0, message=ChatMessage(role="assistant", content=result), @@ -280,18 +285,19 @@ async def create_chat_completion(request: ChatCompletionRequest): async def create_completion(request: CompletionRequest): model_name = local_model.model_name if request.max_tokens is None: - n_predict = 256 + n_predict = 32 else: n_predict = request.max_tokens - prompt_request = PromptRequest( - prompt=request.prompt, - n_predict=n_predict, + inputs_request = InputsRequest( + inputs=request.prompt, + parameters=Parameters(max_new_tokens=n_predict), + stream=request.stream, req_type="completion" ) if request.stream: - request_id, result = await generate_stream(prompt_request) + request_id, result = await generate_stream(inputs_request) else: - request_id, result = await generate(prompt_request) + request_id, result = await generate(inputs_request) choice_data = CompletionResponseChoice( index=0, text=result, diff --git a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py index 90f195a6cb4..c88e1434ab7 100644 --- a/python/llm/src/ipex_llm/serving/fastapi/model_worker.py +++ b/python/llm/src/ipex_llm/serving/fastapi/model_worker.py @@ -59,21 +59,22 @@ async def add_request(self, tokenizer): return tmp_result = await self.waiting_requests.get() request_id, prompt_request = tmp_result - plain_texts = prompt_request.prompt + plain_texts = prompt_request.inputs inputs = tokenizer(plain_texts, return_tensors="pt", padding=True) input_ids = inputs.input_ids.to('xpu') - max_tokens = prompt_request.n_predict - return input_ids, max_tokens, request_id + parameters = 
prompt_request.parameters + return input_ids, parameters, request_id @torch.no_grad() async def process_step(self, tokenizer, result_dict): if not self.waiting_requests.empty(): - input_ids, max_tokens, request_id = await self.add_request(tokenizer) + input_ids, parameters, request_id = await self.add_request(tokenizer) self.streamer[request_id] = TextIteratorStreamer(tokenizer, skip_prompt=True) def model_generate(): + generate_kwargs = {k: v for k, v in parameters.dict().items() if v is not None} self.model.generate(input_ids, - streamer=self.streamer[request_id], max_new_tokens=max_tokens) + streamer=self.streamer[request_id], **generate_kwargs) torch.xpu.empty_cache() torch.xpu.synchronize() diff --git a/python/llm/src/ipex_llm/serving/fastapi/tgi_protocol.py b/python/llm/src/ipex_llm/serving/fastapi/tgi_protocol.py new file mode 100644 index 00000000000..63f1ff91f55 --- /dev/null +++ b/python/llm/src/ipex_llm/serving/fastapi/tgi_protocol.py @@ -0,0 +1,63 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Adapted from +# https://github.com/huggingface/text-generation-inference/blob/main/clients/python/text_generation/types.py + + +from pydantic import BaseModel, field_validator +from typing import List, Optional +from ipex_llm.utils.common import invalidInputError + + +class Parameters(BaseModel): + max_new_tokens: int = 32 + do_sample: Optional[bool] = None + min_new_tokens: Optional[int] = None + repetition_penalty: Optional[float] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + typical_p: Optional[float] = None + + @field_validator("repetition_penalty") + def valid_repetition_penalty(cls, v): + if v is not None and v <= 0: + invalidInputError(False, "`repetition_penalty` must be strictly positive") + return v + + @field_validator("temperature") + def valid_temp(cls, v): + if v is not None and v <= 0: + invalidInputError(False, "`temperature` must be strictly positive") + return v + + @field_validator("top_k") + def valid_top_k(cls, v): + if v is not None and v <= 0: + invalidInputError(False, "`top_k` must be strictly positive") + return v + + @field_validator("top_p") + def valid_top_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + invalidInputError(False, "`top_p` must be > 0.0 and < 1.0") + return v + + @field_validator("typical_p") + def valid_typical_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + invalidInputError(False, "`typical_p` must be > 0.0 and < 1.0") + return v diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index e8dff53d54f..bced1ec9925 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -699,14 +699,14 @@ async def add_request(self, tokenizer): request_ids.append(request_id) prompt_requests.append(prompt_request) - plain_texts = [req.prompt for req in prompt_requests] + 
plain_texts = [req.inputs for req in prompt_requests] inputs = tokenizer(plain_texts, return_tensors="pt", padding=True) input_ids = inputs.input_ids.to(f'xpu:{self.rank}') attention_mask = inputs.attention_mask.to(f'xpu:{self.rank}') new_batch = BatchTask( batch_id="batch_" + str(uuid.uuid4()), request_ids=request_ids, - max_tokens=max([req.n_predict for req in prompt_requests]), + max_tokens=max([req.parameters.max_new_tokens for req in prompt_requests]), batch_size=input_ids.size(0), input_len=input_ids.size(1), prompt_lengths=[sum(attention_mask[i, :]) for i in range(input_ids.size(0))],
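To illustrate the new TGI-style request schema end to end, here is a minimal Python client sketch for the `/generate` endpoint. It mirrors the curl examples in the Lightweight-Serving README; the use of the `requests` library and the timeout value are illustrative assumptions, not part of this change.

```python
# Minimal client sketch for the TGI-style /generate endpoint
# (assumes a lightweight_serving.py instance is listening on localhost:8000).
import requests

payload = {
    "inputs": "What is AI?",
    "parameters": {
        # These keys map to the Parameters model in tgi_protocol.py;
        # unset fields fall back to the server-side defaults.
        "max_new_tokens": 32,
        "do_sample": False,
    },
    "stream": False,  # set True to be redirected to the streaming path
}

resp = requests.post("http://localhost:8000/generate", json=payload, timeout=300)
print(resp.status_code, resp.text)
```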