Add VILA API server which is compatible with OpenAI SDK (NVlabs#133)
Co-authored-by: Ligeng Zhu <[email protected]>
Edmund Wang and Lyken17 authored Jul 29, 2024
1 parent 3710e28 commit 6a781ea
Showing 6 changed files with 404 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -49,3 +49,5 @@ ckpts*

playground
*/visualization/*
.env
server.log
18 changes: 18 additions & 0 deletions Dockerfile
@@ -0,0 +1,18 @@
FROM nvcr.io/nvidia/pytorch:24.06-py3

WORKDIR /app

RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
&& sh ~/miniconda.sh -b -p /opt/conda \
&& rm ~/miniconda.sh

ENV PATH=/opt/conda/bin:$PATH
COPY pyproject.toml pyproject.toml
COPY llava llava

COPY environment_setup.sh environment_setup.sh
RUN bash environment_setup.sh vila


COPY server.py server.py
CMD ["conda", "run", "-n", "vila", "--no-capture-output", "python", "-u", "-W", "ignore", "server.py"]
73 changes: 67 additions & 6 deletions README.md
@@ -57,16 +57,16 @@ VILA is a visual language model (VLM) pretrained with interleaved image-text dat

| $~~~~~~$ | Precision | A100 | 4090 | Orin |
| ---------------------- | --------- | ----- | ----- | ---- |
| VILA1.5-3B | fp16 | 104.6 | 137.6 | 25.4 |
| VILA1.5-3B-AWQ | int4 | 182.8 | 215.5 | 42.5 |
| VILA1.5-3B-S2 | fp16 | 104.3 | 137.2 | 24.6 |
| VILA1.5-3B-S2-AWQ | int4 | 180.2 | 219.3 | 40.1 |
| Llama-3-VILA1.5-8B | fp16 | 74.9 | 57.4 | 10.2 |
| Llama-3-VILA1.5-8B-AWQ | int4 | 168.9 | 150.2 | 28.7 |
| VILA1.5-13B | fp16 | 50.9 | OOM | 6.1 |
| VILA1.5-13B-AWQ | int4 | 115.9 | 105.7 | 20.6 |
| VILA1.5-40B | fp16 | OOM | OOM | -- |
| VILA1.5-40B-AWQ | int4 | 57.0 | OOM | -- |

<sup>NOTE: Measured using the [TinyChat](https://github.com/mit-han-lab/llm-awq/tinychat) backend at batch size = 1.</sup>

@@ -232,6 +232,67 @@ We support AWQ-quantized 4bit VILA on GPU platforms via [TinyChat](https://githu

We further support our AWQ-quantized 4bit VILA models on various CPU platforms with both x86 and ARM architectures with our [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine). We also provide a detailed [tutorial](https://github.com/mit-han-lab/TinyChatEngine/tree/main?tab=readme-ov-file#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) to help the users deploy VILA on different CPUs.

### Running the VILA API server

A simple API server is provided for serving VILA models. It is built on top of [FastAPI](https://fastapi.tiangolo.com/) and [Huggingface Transformers](https://huggingface.co/transformers/), and it can be started in either of the following ways:

#### With CLI

```bash
python -W ignore server.py \
--port 8000 \
--model-path Efficient-Large-Model/VILA1.5-3B \
--conv-mode vicuna_v1
```

#### With Docker

```bash
docker build -t vila-server:latest .
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
-v ./hub:/root/.cache/huggingface/hub \
-it --rm -p 8000:8000 \
-e VILA_MODEL_PATH=Efficient-Large-Model/VILA1.5-3B \
-e VILA_CONV_MODE=vicuna_v1 \
vila-server:latest
```

Then you can call the endpoint with the OpenAI SDK as follows:

```python
from openai import OpenAI

client = OpenAI(
base_url="http://localhost:8000",
api_key="fake-key",
)
response = client.chat.completions.create(
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "What’s in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://blog.logomyway.com/wp-content/uploads/2022/01/NVIDIA-logo.jpg",
# Or you can pass in a base64 encoded image
# "url": "data:image/png;base64,<base64_encoded_image>",
},
},
],
}
],
max_tokens=300,
model="VILA1.5-3B",
# You can pass in extra parameters as follows
extra_body={"num_beams": 1, "use_cache": False},
)
print(response.choices[0].message.content)
```
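
The server also accepts images sent inline as base64 data URLs (`data:image/png;base64,...` or `data:image/jpeg;base64,...`), as hinted in the commented-out line above. A minimal sketch for building such a URL from a local file, assuming a hypothetical `nvidia_logo.png`:

```python
import base64

# Read a local image and wrap it in a data URL that the server's
# base64 pattern (png / jpeg / jpg) will accept.
with open("nvidia_logo.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("utf-8")

image_url = {"url": f"data:image/png;base64,{encoded}"}
```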

<sup>NOTE: This API server is intended for evaluation purposes only and has not been optimized for production use. It has only been tested on A100 and H100 GPUs.</sup>

## Checkpoints

We release [VILA1.5-3B](https://hf.co/Efficient-Large-Model/VILA1.5-3b), [VILA1.5-3B-S2](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2), [Llama-3-VILA1.5-8B](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8b), [VILA1.5-13B](https://hf.co/Efficient-Large-Model/VILA1.5-13b), [VILA1.5-40B](https://hf.co/Efficient-Large-Model/VILA1.5-40b) and the 4-bit [AWQ](https://arxiv.org/abs/2306.00978)-quantized models [VILA1.5-3B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-AWQ), [VILA1.5-3B-S2-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2-AWQ), [Llama-3-VILA1.5-8B-AWQ](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8b-AWQ), [VILA1.5-13B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-13b-AWQ), [VILA1.5-40B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-40b-AWQ).
261 changes: 261 additions & 0 deletions server.py
@@ -0,0 +1,261 @@
import argparse
import base64
import os
import re
import time
import uuid
from contextlib import asynccontextmanager
from io import BytesIO
from typing import List, Literal, Optional, Union, get_args

import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from PIL import Image as PILImage
from PIL.Image import Image
from pydantic import BaseModel

from llava.constants import (
DEFAULT_IM_END_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_TOKEN,
IMAGE_PLACEHOLDER,
IMAGE_TOKEN_INDEX,
)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (
KeywordsStoppingCriteria,
get_model_name_from_path,
process_images,
tokenizer_image_token,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init


class TextContent(BaseModel):
type: Literal["text"]
text: str


class ImageURL(BaseModel):
url: str


class ImageContent(BaseModel):
type: Literal["image_url"]
image_url: ImageURL


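# Matches inline images passed as base64 data URLs, e.g. "data:image/png;base64,<payload>".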
IMAGE_CONTENT_BASE64_REGEX = re.compile(r"^data:image/(png|jpe?g);base64,(.*)$")


class ChatMessage(BaseModel):
role: Literal["user", "assistant"]
content: Union[str, List[Union[TextContent, ImageContent]]]


class ChatCompletionRequest(BaseModel):
model: Literal[
"VILA1.5-3B",
"VILA1.5-3B-AWQ",
"VILA1.5-3B-S2",
"VILA1.5-3B-S2-AWQ",
"Llama-3-VILA1.5-8B",
"Llama-3-VILA1.5-8B-AWQ",
"VILA1.5-13B",
"VILA1.5-13B-AWQ",
"VILA1.5-40B",
"VILA1.5-40B-AWQ",
]
messages: List[ChatMessage]
max_tokens: Optional[int] = 512
top_p: Optional[float] = 0.9
temperature: Optional[float] = 0.2
stream: Optional[bool] = False
use_cache: Optional[bool] = True
num_beams: Optional[int] = 1

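# Model state shared across requests; populated at startup by the lifespan handler.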
model = None
model_name = None
tokenizer = None
image_processor = None
context_len = None


def load_image(image_url: str) -> Image:
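    """Fetch an image from an http(s) URL, or decode a base64 data URL."""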
    if image_url.startswith(("http://", "https://")):
        response = requests.get(image_url)
        image = PILImage.open(BytesIO(response.content)).convert("RGB")
else:
match_results = IMAGE_CONTENT_BASE64_REGEX.match(image_url)
if match_results is None:
raise ValueError(f"Invalid image url: {image_url}")
image_base64 = match_results.groups()[1]
image = PILImage.open(BytesIO(base64.b64decode(image_base64))).convert("RGB")
return image


def get_literal_values(cls, field_name: str):
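    """Return the allowed values of a Literal-annotated field on a pydantic model."""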
field_type = cls.__annotations__.get(field_name)
if field_type is None:
raise ValueError(f"{field_name} is not a valid field name")
if hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
return get_args(field_type)
raise ValueError(f"{field_name} is not a Literal type")


VILA_MODELS = get_literal_values(ChatCompletionRequest, "model")


def normalize_image_tags(qs: str) -> str:
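    """Rewrite IMAGE_PLACEHOLDER tags into the image token(s) the model expects."""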
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
if IMAGE_PLACEHOLDER in qs:
if model.config.mm_use_im_start_end:
qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
else:
qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)

if DEFAULT_IMAGE_TOKEN not in qs:
raise ValueError("No image was found in input messages.")
return qs


@asynccontextmanager
async def lifespan(app: FastAPI):
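    """Load the model once at server startup and store it in module-level globals."""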
global model, model_name, tokenizer, image_processor, context_len
disable_torch_init()
model_path = app.args.model_path
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path, model_name, None
)
print(f"Model {model_name} loaded successfully. Context length: {context_len}")
yield


app = FastAPI(lifespan=lifespan)


# OpenAI-compatible chat completions endpoint.
@app.post("/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
try:
global model, tokenizer, image_processor, context_len

if request.model != model_name:
raise ValueError(
f"The endpoint is configured to use the model {model_name}, "
f"but the request model is {request.model}"
)
max_tokens = request.max_tokens
temperature = request.temperature
top_p = request.top_p
use_cache = request.use_cache
num_beams = request.num_beams

messages = request.messages
conv_mode = app.args.conv_mode

images = []

conv = conv_templates[conv_mode].copy()
user_role = conv.roles[0]
assistant_role = conv.roles[1]

for message in messages:
if message.role == "user":
prompt = ""

if isinstance(message.content, str):
prompt += message.content
if isinstance(message.content, list):
for content in message.content:
if content.type == "text":
prompt += content.text
if content.type == "image_url":
image = load_image(content.image_url.url)
images.append(image)
prompt += IMAGE_PLACEHOLDER
normalized_prompt = normalize_image_tags(prompt)
conv.append_message(user_role, normalized_prompt)
if message.role == "assistant":
prompt = message.content
conv.append_message(assistant_role, prompt)

prompt_text = conv.get_prompt()
print("Prompt input: ", prompt_text)

images_tensor = process_images(images, image_processor, model.config).to(
model.device, dtype=torch.float16
)
input_ids = (
tokenizer_image_token(
prompt_text, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
.unsqueeze(0)
.to(model.device)
)

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=[
images_tensor,
],
do_sample=True if temperature > 0 else False,
temperature=temperature,
top_p=top_p,
num_beams=num_beams,
max_new_tokens=max_tokens,
use_cache=use_cache,
stopping_criteria=[stopping_criteria],
)

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
outputs = outputs[: -len(stop_str)]
outputs = outputs.strip()
print("\nAssistant: ", outputs)

resp_content = [TextContent(type="text", text=outputs)]
        return {
            "id": uuid.uuid4().hex,
            "object": "chat.completion",
            "created": int(time.time()),
            "model": request.model,
            "choices": [
                {
                    "index": 0,
                    "message": ChatMessage(role="assistant", content=resp_content),
                    "finish_reason": "stop",
                }
            ],
        }
except Exception as e:
return JSONResponse(
status_code=500,
content={"error": str(e)},
)


if __name__ == "__main__":

    host = os.getenv("VILA_HOST", "0.0.0.0")
    port = int(os.getenv("VILA_PORT", 8000))
    model_path = os.getenv("VILA_MODEL_PATH", "Efficient-Large-Model/VILA1.5-3B")
    conv_mode = os.getenv("VILA_CONV_MODE", "vicuna_v1")
    workers = int(os.getenv("VILA_WORKERS", 1))

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=host)
    parser.add_argument("--port", type=int, default=port)
    parser.add_argument("--model-path", type=str, default=model_path)
    parser.add_argument("--conv-mode", type=str, default=conv_mode)
    parser.add_argument("--workers", type=int, default=workers)
    app.args = parser.parse_args()

    # Command-line arguments take precedence over the VILA_* environment variables.
    uvicorn.run(app, host=app.args.host, port=app.args.port, workers=app.args.workers)